In [1]:
import os
import sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

# for linux env.
sys.path.insert(0,'..')
import argparse
# from distutils.util import strtobool
import json
import torch
# import torch.nn as nn
# import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
# import networkx as nx
# from rdkit import Chem, DataStructs
# from rdkit.Chem import AllChem
from copy import deepcopy
# from mflow.generate import generate_mols_along_axis
from dataloader import PretrainDataset
from models.MolHF import MolHF
from torch.utils.data import DataLoader
# from envs import environment as env
# from envs.timereport import TimeReport
# from envs.environment import penalized_logp, qed 
# from utils import check_validity, adj_to_smiles, smiles_to_adj, construct_mol
from multiprocessing import Pool
# from sklearn.metrics import r2_score, mean_absolute_error
# from dataloader import get_mol_data
from time import time, ctime
# import functools
import optimize_property as op



In [2]:
# Предобученная модель для оптимизации

parser = argparse.ArgumentParser()

#общие параметры
parser.dataset = 'zinc250k' 
parser.device = 'cpu' 
parser.seed = 42
parser.save = True
parser.model = 'MolHF'
parser.order = 'bfs'
parser.property_name = 'plogp'

parser.init_checkpoint = './save_pretrain/zinc250k_model/checkpoint.pth'
parser.model_dir = './save_optimization'
parser.property_model_path = 'plogp_moflow_zinc250k_10.pth'

# параметры модели
parser.deq_scale = 0.6 
parser.batch_size = 256
parser.lr = 1e-3 
parser.squeeze_fold = 2 
parser.n_block = 4 
parser.a_num_flows = 6 
parser.num_layers = 2 
parser.hid_dim = 256 
parser.b_num_flows = 3 
parser.filter_size = 256 
parser.temperature = 0.6 
parser.learn_prior = True 
parser.inv_conv = True 
parser.inv_rotate = True 
parser.condition = True
parser.hidden = '32'

parser.num_data = None
parser.is_test_idx = False
parser.num_workers = 0
parser.deq_type = 'random'
parser.debug = 'true'

# опциональные параметры для оптимизации
parser.split = 'moflow'
parser.topk = 30 
parser.num_iter = 10 
parser.opt_lr = 0.5
parser.topscore = False
parser.consopt = True

In [3]:
start = time()
print("Start at Time: {}".format(ctime()))
args = parser
# configuration
if args.dataset == 'polymer':
    # polymer
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 14, 5: 15, 6: 16}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1, 14: 4, 15: 3, 16: 2}
else:
    # zinc250k
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 15, 5: 16, 6: 17, 7: 35, 8: 53}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}

args.strides = [2, 2, 2]
data_path = os.path.join('./data_preprocessed', args.dataset)
with open(os.path.join(data_path, 'config.txt'), 'r') as f:
    data_config = eval(f.read())

with open("./dataset/zinc250k/{}_idx.json".format(args.split), "r") as f:
    train_idx, valid_idx = json.load(f)
dataset = PretrainDataset("./data_preprocessed/{}".format(args.dataset), data_config, args)
train_dataset = deepcopy(dataset)
train_dataset._indices = train_idx
valid_dataset = deepcopy(dataset)
valid_dataset._indices = valid_idx

if not os.path.exists(os.path.join("./data_preprocessed/{}".format(args.dataset), 'zinc250k_property.csv')):
    smiles_list = dataset.all_smiles
    property_list = []
    print(torch.multiprocessing.cpu_count())
    with Pool(processes=torch.multiprocessing.cpu_count()) as pool:
        iter = pool.imap(op.get_mol_property, smiles_list)
        for idx, data in tqdm(enumerate(iter), total=len(smiles_list)):
            property_list.append(data)
    mol_property = np.array(property_list)
    table = pd.DataFrame(mol_property, columns=['qed', 'plogp'])
    table['smile'] = smiles_list
    table.to_csv(os.path.join("./data_preprocessed/{}".format(args.dataset), 'zinc250k_property.csv'), index=False)

if args.hidden in ('', ','):
    hidden = []
else:
    hidden = [int(d) for d in args.hidden.strip(',').split(',')]
print('Hidden dim for output regression: ', hidden)

if args.property_model_path is None:
    property_list = op.load_property_csv(args.dataset, normalize=True)
    mol_property = np.array(property_list) 
    train_dataset.is_mol_property = True
    train_dataset.mol_property = mol_property
    valid_dataset.is_mol_property = True
    valid_dataset.mol_property = mol_property
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,collate_fn=PretrainDataset.collate_fn, num_workers=args.num_workers, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,collate_fn=PretrainDataset.collate_fn, num_workers=args.num_workers, drop_last=True)
    print('Prepare data done! Time {:.2f} seconds'.format(time() - start))
    property_model_path = os.path.join(args.model_dir, '{}_{}_{}_{}.pth'.format(args.property_name, args.split, args.dataset, args.ratio))
    
    model = MolHF(data_config, args).to(args.device)
    op.initialize_from_checkpoint(model, args)
    property_model = op.FlowProp(model, hidden).to(args.device)
    property_model = op.fit_model(property_model, train_loader, valid_loader, args, property_model_path)   
else:
    print("Loading trained regression model for optimization")
    print('Prepare data done! Time {:.2f} seconds'.format(time() - start))
    prop_list = op.load_property_csv(args.dataset, normalize=False)
    train_prop = [prop_list[i] for i in train_idx]
    test_prop = [prop_list[i] for i in valid_idx]
    property_model_path = os.path.join(args.model_dir, args.property_model_path)
    print("loading {} regression model from: {}".format(args.property_name, property_model_path))
    model = MolHF(data_config, args).to(args.device)
    op.initialize_from_checkpoint(model, args)
    property_model = op.FlowProp(model, hidden).to(args.device)
    property_model.load_state_dict(torch.load(property_model_path, map_location=args.device))
    print('Load model done! Time {:.2f} seconds'.format(time() - start))

    property_model.eval()

    if args.topscore:
        print('Finding top score:')
        op.find_top_score_smiles(property_model, test_prop if args.is_test_idx else train_prop, data_config, args)

    if args.consopt:
        print('Constrained optimization:')
        op.constrain_optimization_smiles(property_model, test_prop if args.is_test_idx else train_prop, data_config, args)
        
    print('Total Time {:.2f} seconds'.format(time() - start))



Start at Time: Tue Mar 25 14:23:03 2025
reading data from ./data_preprocessed/zinc250k
Atom order: bfs
Hidden dim for output regression:  [32]
Loading trained regression model for optimization
Prepare data done! Time 5.57 seconds
Load ./data_preprocessed/zinc250k/zinc250k_property.csv done, length: 249456
loading plogp regression model from: ./save_optimization\plogp_moflow_zinc250k_10.pth


  checkpoint = torch.load(args.init_checkpoint, map_location=args.device)


initialize from ./save_pretrain/zinc250k_model/checkpoint.pth Done!


  property_model.load_state_dict(torch.load(property_model_path, map_location=args.device))


Load model done! Time 9.66 seconds
Constrained optimization:
Constrained optimization of plogp score
the number of molecue is 0
Optimization 0/30, time: 0.16 seconds


KeyError: 9