In [2]:
import models.basic as bs

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import argparse
import json
import pickle
from datetime import datetime
from time import time, ctime

import networkx as nx
import numpy as np
import torch
from rdkit import Chem
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloader import PretrainDataset
from envs import environment as env
from models.graphflow import squeeze_adj
from models.MolHF import MolHF
from utils import construct_mol, correct_mol, set_random_seed

In [None]:
# Pretrained model configuration for generation.
#
# A plain argparse.Namespace is used as the settings container: nothing is ever
# parsed from the command line here, so an ArgumentParser object was only being
# used as an attribute bag (mixing these settings with the parser's internals).
# The run cell below reads `args = parser`, including the run-control fields at
# the end of this Namespace, so they must be present for a fresh Run All.
parser = argparse.Namespace(
    dataset='zinc250k',
    device='cuda',
    deq_scale=0.6,
    batch_size=256,
    lr=1e-3,
    squeeze_fold=2,
    n_block=4,
    a_num_flows=6,
    num_layers=2,
    hid_dim=256,
    b_num_flows=3,
    filter_size=256,
    temperature=0.6,
    learn_prior=True,
    inv_conv=True,
    inv_rotate=True,
    condition=True,
    init_checkpoint='./save_pretrain/zinc250k_model/checkpoint.pth',
    gen_num=100,
    # Run-control fields read by the run cell (seed, mode flags, loader workers).
    seed=42,          # chosen default seed; change here to get a different run
    save=False,       # if True, the run cell also needs `model` and `order`
    train=False,
    resample=False,
    num_workers=0,    # same as torch DataLoader's default
)


In [None]:
# Pretrained model configuration for optimization.
#
# NOTE(review): this cell rebinds the same global `parser` as the generation
# config cell, so the run cell (`args = parser`) uses whichever config cell ran
# last. On Restart & Run All that is this one, which is why `gen_num` and the
# run-control fields are included here as well.
#
# A plain argparse.Namespace is used as the settings container: nothing is ever
# parsed from the command line here.
parser = argparse.Namespace(
    dataset='zinc250k',
    device='cuda',
    deq_scale=0.6,
    batch_size=256,
    lr=1e-3,
    squeeze_fold=2,
    n_block=4,
    a_num_flows=6,
    num_layers=2,
    hid_dim=256,
    b_num_flows=3,
    filter_size=256,
    temperature=0.6,
    learn_prior=True,
    inv_conv=True,
    inv_rotate=True,
    condition=True,
    init_checkpoint='./save_pretrain/zinc250k_model/checkpoint.pth',
    topk=30,
    num_iter=10,
    opt_lr=0.5,
    consopt=True,
    # Read by the generation branch of the run cell.
    gen_num=100,
    # Run-control fields read by the run cell (seed, mode flags, loader workers).
    seed=42,          # chosen default seed; change here to get a different run
    save=False,       # if True, the run cell also needs `model` and `order`
    train=False,
    resample=False,
    num_workers=0,    # same as torch DataLoader's default
)

10

In [None]:
# Build the data loader and trainer from the config in `parser`, then train,
# resample, or generate molecules depending on `args.train` / `args.resample`.
#
# NOTE(review): `Trainer` is not imported or defined anywhere visible in this
# notebook; it must be brought into scope before this cell runs — TODO confirm
# which project module provides it.
args = parser

# Run-control options this cell reads. Config cells that predate these fields
# would otherwise fail with AttributeError; existing values are never overridden.
_RUN_DEFAULTS = {
    'seed': 42,
    'save': False,      # when True, args.model and args.order are also required
    'train': False,
    'resample': False,
    'num_workers': 0,   # same as torch DataLoader's default
}
for _name, _default in _RUN_DEFAULTS.items():
    if not hasattr(args, _name):
        setattr(args, _name, _default)

N_GEN_ROUNDS = 5  # independent generation rounds used for the mean / sd statistics


def _report(label, values):
    """Print the mean and population sd of per-round percentages plus the raw values."""
    print("{}: mean={:.2f}%, sd={:.2f}%, vals={}".format(
        label, np.mean(values), np.std(values), values))


set_random_seed(args.seed)

log_dir = None
if args.save:
    dt = datetime.now()
    # TODO: Add more information.
    log_dir = os.path.join('./save_pretrain', args.model, args.order, '{}_{:02d}-{:02d}-{:02d}'.format(
        dt.date(), dt.hour, dt.minute, dt.second))
    args.save_path = log_dir
    os.makedirs(args.save_path, exist_ok=True)

# Atom-number lookup (index -> atomic number) and valency per atomic number.
if args.dataset == 'polymer':
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 14, 5: 15, 6: 16}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1, 14: 4, 15: 3, 16: 2}
else:
    # zinc250k
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 15, 5: 16, 6: 17, 7: 35, 8: 53}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1,
                    15: 3, 16: 2, 17: 1, 35: 1, 53: 1}

# Load the preprocessed dataset.
data_path = os.path.join('./data_preprocessed', args.dataset)
with open(os.path.join(data_path, 'config.txt'), 'r') as f:
    # SECURITY: eval() executes arbitrary code. Acceptable only because this file
    # is produced locally by preprocessing; consider ast.literal_eval if it only
    # ever contains Python literals.
    data_config = eval(f.read())
dataset = PretrainDataset(data_path, data_config, args)
train_loader = DataLoader(dataset, batch_size=args.batch_size,
                          collate_fn=PretrainDataset.collate_fn, shuffle=True,
                          num_workers=args.num_workers, drop_last=True)

# Serialize the data loader and save it to a file.
with open('train_loader.pickle', 'wb') as file:
    pickle.dump(train_loader, file)

trainer = Trainer(train_loader, None, args)
if args.init_checkpoint is not None:
    trainer.initialize_from_checkpoint(train=args.train)

if args.train:
    mol_out_dir = None
    if args.save:
        mol_out_dir = os.path.join(log_dir, 'mols')
        os.makedirs(mol_out_dir, exist_ok=True)
    start = time()
    trainer.fit(mol_out_dir=mol_out_dir)
    print('Task model fitting done! Time {:.2f} seconds, Data: {}'.format(
        time() - start, ctime()))

elif args.resample:
    trainer.resampling_molecules(resample_mode=0)

else:
    print('Start generating!')
    start = time()
    valid_ratio, unique_ratio, novel_ratio = [], [], []
    valid_5atom_ratio, valid_39atom_ratio = [], []
    for _round in range(N_GEN_ROUNDS):
        # generate_molecule returns a 7-tuple; only some fields are used here.
        _, Validity, Validity_without_check, Uniqueness, Novelty, _, mol_atom_size = trainer.generate_molecule(
            args.gen_num)
        valid_ratio.append(Validity)
        unique_ratio.append(Uniqueness)
        novel_ratio.append(Novelty)
        # Assumes mol_atom_size holds one atom count per generated molecule — TODO confirm.
        # Percentages are taken relative to the requested number of molecules.
        atom_sizes = np.asarray(mol_atom_size)
        valid_5atom_ratio.append(np.sum(atom_sizes >= 5) / args.gen_num * 100)
        valid_39atom_ratio.append(np.sum(atom_sizes >= 39) / args.gen_num * 100)

    _report("validity", valid_ratio)
    _report("validity if atom >= 5", valid_5atom_ratio)
    _report("validity if atom >= 39", valid_39atom_ratio)
    _report("novelty", novel_ratio)
    _report("uniqueness", unique_ratio)
    print('Task random generation done! Time {:.2f} seconds, Data: {}'.format(
        time() - start, ctime()))


In [24]:
# Sanity check of the prior: 3 independent Normal components with mean 0 and
# standard deviation 2. With this many draws the per-dimension sample mean
# should be close to 0 and the sample std close to 2.
prior_dist = torch.distributions.Normal(loc=torch.zeros(3), scale=torch.full((3,), 2.0))
z = prior_dist.sample(torch.Size([20_000_000]))
(z.mean(dim=0), z.std(dim=0))

(tensor([ 0.0004, -0.0004,  0.0002]), tensor([2.0000, 2.0000, 1.9999]))

In [3]:
# Smoke test for the ZeroConv2d layer from models.basic (imported as `bs`).
# The captured output shows a (2, 1, 5, 5) all-ones input mapped to a
# (2, 2, 5, 5) all-zeros output. Presumably the layer is zero-initialised —
# confirm in models/basic.py.
bs.test_ZeroConv2d()

x.shape: torch.Size([2, 1, 5, 5])
tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]])
y.shape torch.Size([2, 2, 5, 5])
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0.,