In [1]:
import os
import sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

sys.path.insert(0,'..')
import argparse
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
from dataloader import PretrainDataset
from models.MolHF import MolHF
from torch.utils.data import DataLoader
from multiprocessing import Pool
from time import time, ctime
import optimize_property as op

In [2]:
# Configuration of the pretrained model used for optimization.
# The ArgumentParser instance is never asked to parse anything here; it only
# serves as a mutable attribute container that later cells read as `args`.
parser = argparse.ArgumentParser()

vars(parser).update(
    # general parameters
    dataset='zinc250k',
    device='cuda',
    seed=42,
    save=True,
    model='MolHF',
    order='bfs',
    property_name='qed',
    # checkpoint / output locations
    init_checkpoint='./save_pretrain/zinc250k_model/checkpoint.pth',
    model_dir='./save_optimization',
    property_model_path='qed_moflow_zinc250k_10.pth',
    # model parameters
    deq_scale=0.6,
    batch_size=256,
    lr=1e-3,
    squeeze_fold=2,
    n_block=4,
    a_num_flows=6,
    num_layers=2,
    hid_dim=256,
    b_num_flows=3,
    filter_size=256,
    temperature=0.6,
    learn_prior=True,
    inv_conv=True,
    inv_rotate=True,
    condition=True,
    hidden='32',  # comma-separated layer sizes, parsed into a list of ints later
    # data loading
    num_data=None,
    is_test_idx=False,
    num_workers=0,
    deq_type='random',
    debug='true',  # NOTE(review): a non-empty string, so it is truthy whatever its text
    lr_decay=1,
    # optional parameters for the optimization stage
    split='moflow',
    topk=1,
    num_iter=100,
    opt_lr=0.8,
    topscore=False,
    consopt=True,
)

In [3]:
# End-to-end driver cell: loads the preprocessed dataset and split indices,
# makes sure the property CSV exists, then either fits a property regressor
# (when no trained regressor path is configured) or loads a trained one and
# runs the requested optimization routines on the DMNP seed molecule.
start = time()
print("Start at Time: {}".format(ctime()))
args = parser

# Atom-index -> atomic-number table and per-element maximum valency.
# Both are read as globals by the optimization helpers defined further down.
if args.dataset == 'polymer':
    # polymer
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 14, 5: 15, 6: 16}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1, 14: 4, 15: 3, 16: 2}
else:
    # zinc250k
    num2atom = {0: 6, 1: 7, 2: 8, 3: 9, 4: 15, 5: 16, 6: 17, 7: 35, 8: 53}
    atom_valency = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}

args.strides = [2, 2, 2]
data_path = os.path.join('./data_preprocessed', args.dataset)
with open(os.path.join(data_path, 'config.txt'), 'r') as f:
    # NOTE(review): eval() executes whatever config.txt contains; only run this
    # on a trusted file (ast.literal_eval would be safer if the file is a plain
    # literal -- confirm against the preprocessing script before switching).
    data_config = eval(f.read())

# NOTE(review): the split file is read from the zinc250k folder regardless of
# args.dataset -- confirm this is intended for other datasets.
with open("./dataset/zinc250k/{}_idx.json".format(args.split), "r") as f:
    train_idx, valid_idx = json.load(f)
dataset = PretrainDataset("./data_preprocessed/{}".format(args.dataset), data_config, args)
train_dataset = deepcopy(dataset)
train_dataset._indices = train_idx
valid_dataset = deepcopy(dataset)
valid_dataset._indices = valid_idx

# Build the (qed, plogp, smile) property table once and cache it on disk.
property_csv_path = os.path.join("./data_preprocessed/{}".format(args.dataset), 'zinc250k_property.csv')
if not os.path.exists(property_csv_path):
    smiles_list = dataset.all_smiles
    n_workers = torch.multiprocessing.cpu_count()
    print(n_workers)
    with Pool(processes=n_workers) as pool:
        # The lazy imap result is consumed directly; binding it to a notebook
        # global called `iter` would shadow the builtin for all later cells.
        property_list = list(tqdm(pool.imap(op.get_mol_property, smiles_list), total=len(smiles_list)))
    mol_property = np.array(property_list)
    table = pd.DataFrame(mol_property, columns=['qed', 'plogp'])
    table['smile'] = smiles_list
    table.to_csv(property_csv_path, index=False)

# Parse the comma-separated hidden sizes of the regression head.
if args.hidden in ('', ','):
    hidden = []
else:
    hidden = [int(d) for d in args.hidden.strip(',').split(',')]
print('Hidden dim for output regression: ', hidden)

if args.property_model_path is None:
    # No trained regressor configured: fit one on the training split.
    property_list = op.load_property_csv(args.dataset, normalize=True)
    mol_property = np.array(property_list)
    train_dataset.is_mol_property = True
    train_dataset.mol_property = mol_property
    valid_dataset.is_mol_property = True
    valid_dataset.mol_property = mol_property
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,collate_fn=PretrainDataset.collate_fn, num_workers=args.num_workers, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,collate_fn=PretrainDataset.collate_fn, num_workers=args.num_workers, drop_last=True)
    print('Prepare data done! Time {:.2f} seconds'.format(time() - start))
    # NOTE(review): the configuration cell never sets `ratio`, so this line
    # raises AttributeError unless args.ratio is defined before running it.
    property_model_path = os.path.join(args.model_dir, '{}_{}_{}_{}.pth'.format(args.property_name, args.split, args.dataset, args.ratio))

    model = MolHF(data_config, args).to(args.device)
    op.initialize_from_checkpoint(model, args)
    property_model = op.FlowProp(model, hidden).to(args.device)
    property_model = op.fit_model(property_model, train_loader, valid_loader, args, property_model_path)
else:
    print("Loading trained regression model for optimization")
    print('Prepare data done! Time {:.2f} seconds'.format(time() - start))
    prop_list = op.load_property_csv(args.dataset, normalize=False)

    # DMNP: the optimization below starts from this single molecule, so the
    # training-split property list is not materialised (it used to be built
    # and then immediately overwritten by this one-element list).
    dmnp_smiles = 'CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O'
    train_prop = [tuple(op.get_mol_property(dmnp_smiles) + [dmnp_smiles])]

    test_prop = [prop_list[i] for i in valid_idx]
    property_model_path = os.path.join(args.model_dir, args.property_model_path)
    print("loading {} regression model from: {}".format(args.property_name, property_model_path))
    model = MolHF(data_config, args).to(args.device)
    op.initialize_from_checkpoint(model, args)
    property_model = op.FlowProp(model, hidden).to(args.device)
    property_model.load_state_dict(torch.load(property_model_path, map_location=args.device))
    print('Load model done! Time {:.2f} seconds'.format(time() - start))

    property_model.eval()

    if args.topscore:
        print('Finding top score:')
        op.find_top_score_smiles(property_model, test_prop if args.is_test_idx else train_prop, data_config, args)

    if args.consopt:
        print('Constrained optimization:')
        op.constrain_optimization_smiles(property_model, test_prop if args.is_test_idx else train_prop, data_config, args)

    print('Total Time {:.2f} seconds'.format(time() - start))



Start at Time: Wed Apr  2 23:13:26 2025
reading data from ./data_preprocessed/zinc250k
Atom order: bfs
Hidden dim for output regression:  [32]
Loading trained regression model for optimization
Prepare data done! Time 0.86 seconds
Load ./data_preprocessed/zinc250k/zinc250k_property.csv done, length: 249456
loading qed regression model from: ./save_optimization\qed_moflow_zinc250k_10.pth
initialize from ./save_pretrain/zinc250k_model/checkpoint.pth Done!
Load model done! Time 3.87 seconds
Constrained optimization:
Constrained optimization of qed score
tp:  [(0.866434891722508, 1.2139587974886887, 'CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O')]
tps:  [(0.866434891722508, 1.2139587974886887, 'CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O')]
the number of molecue is 0
Optimization 0/1, time: 0.00 seconds
one
CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O ['CC1=CC2=C(C=C1)C(C)CCC2C(C)CCC(=O)O']
two
['CC1=CC2=C(C=C1)C(C)CCC2C(C)CCC(=O)O'] ['CC1=CC2=C(C=C1)C(C)CCC2C(C)CCC(=O)O']
nan
nan
nan
valid molecules: 100/100
[0] Cc

In [25]:
def dock(sml):
    """Placeholder docking score: ignores ``sml`` and always returns -10."""
    placeholder_score = -10
    return placeholder_score

def MFp_dist(x, y):
    """Placeholder distance between ``x`` and ``y``; always returns 1.

    NOTE(review): the name suggests a Morgan-fingerprint distance -- confirm
    before replacing the stub.
    """
    constant_distance = 1
    return constant_distance

def Fr_dist(x, y):
    """Placeholder distance between ``x`` and ``y``; always returns 1.

    NOTE(review): presumably a fragment-based distance -- confirm before
    replacing the stub.
    """
    constant_distance = 1
    return constant_distance

In [22]:
def optimize_mol(optimize_model:op.FlowProp, smiles, data_config, args, random=False):
    """Walk the latent space of one molecule and collect scored, decoded candidates.

    The molecule is encoded with ``optimize_model.encode``; its latent vector is
    then moved for ``args.num_iter`` steps, either along the normalised gradient
    of ``optimize_model.propNN`` (default) or by random perturbations around the
    starting latent (``random=True``). Every visited latent is decoded with
    ``optimize_model.reverse``, validated with ``op.check_validity``, scored with
    the notebook-level ``dock`` function and bucketed by Tanimoto similarity to
    the input molecule.

    Relies on the notebook globals ``dock``, ``num2atom`` and ``atom_valency``.

    Args:
        optimize_model: property model exposing ``encode``, ``reverse``,
            ``propNN`` and ``model.to_latent_format``.
        smiles: SMILES string of the starting molecule.
        data_config: dataset configuration forwarded to ``op.get_mol_data``.
        args: configuration object; reads ``opt_lr``, ``num_iter``, ``lr_decay``,
            ``dataset``, ``device`` and ``debug``.
        random: when True, sample random steps around the start latent instead
            of following the gradient.

    Returns:
        ``(results, start)`` where ``results`` is a list of four lists holding
        ``(new_smiles, score, similarity, smiles)`` tuples for similarity
        thresholds 0, 0.2, 0.4 and 0.6, each sorted by score in descending
        order, and ``start`` is ``(smiles, dock(mol), None)``.
    """
    lr = args.opt_lr
    num_iter = args.num_iter
    sim_thresholds = (0, 0.2, 0.4, 0.6)

    optimize_model.eval()

    with torch.no_grad():
        atoms, bond = op.smiles_to_adj(smiles, args.dataset)
        atoms, bond = op.get_mol_data(atoms, bond, data_config)
        atoms, bond = torch.from_numpy(atoms).unsqueeze(0), torch.from_numpy(bond).unsqueeze(0)
        atoms, bond = atoms.to(args.device), bond.to(args.device)
        mol_z, _ = optimize_model.encode(atoms, bond)

    mol = op.Chem.MolFromSmiles(smiles)
    fp1 = op.AllChem.GetMorganFingerprint(mol, 2)
    start = (smiles, dock(mol), None)
    mol_x, mol_adj = mol_z

    def _as_leaf(tensor):
        # Detached copy that autograd can differentiate with respect to.
        return tensor.clone().detach().requires_grad_(True).to(args.device)

    # Use the model that was passed in, not whatever `property_model` happens to
    # exist in the notebook namespace.
    cur_vec = optimize_model.model.to_latent_format(
        [[_as_leaf(x) for x in mol_x], [_as_leaf(adj) for adj in mol_adj]])
    start_vec = optimize_model.model.to_latent_format(
        [[_as_leaf(x) for x in mol_x], [_as_leaf(adj) for adj in mol_adj]])

    visited = []
    # Iterate over the optimization steps.
    for _ in range(num_iter):
        if random:
            rad = torch.randn_like(cur_vec.data)
            # NOTE(review): rad / sqrt(rad * rad) is the element-wise sign of
            # rad, not a unit-norm direction -- confirm this is intended.
            cur_vec = start_vec.data + lr * rad / torch.sqrt(rad * rad)
        else:
            prop_val = optimize_model.propNN(cur_vec).squeeze()
            grad = torch.autograd.grad(prop_val, cur_vec)[0]
            # keepdim=True keeps the norm broadcastable against the gradient
            # for any batch size (same result as before for a single row).
            cur_vec = cur_vec.data + lr * grad.data / torch.norm(grad.data, dim=-1, keepdim=True)

        lr = lr*args.lr_decay
        cur_vec = cur_vec.clone().detach().requires_grad_(True).to(args.device)
        visited.append(cur_vec)

    hidden_z = torch.cat(visited, dim=0).to(args.device)  # stack every visited latent
    x, adj = optimize_model.reverse(hidden_z)  # decode to atom / bond tensors

    val_res = op.check_validity(x, adj, num2atom, atom_valency, debug=args.debug)
    valid_mols = val_res['valid_mols']
    valid_smiles = val_res['valid_smiles']

    results = [[] for _ in sim_thresholds]
    sm_set = {smiles}  # skip the input molecule itself and repeated decodings
    for m, s in zip(valid_mols, valid_smiles):
        if s in sm_set or s == "":
            continue
        sm_set.add(s)
        p = dock(m)
        # This call may be the source of the MorganFingerprint deprecation warning.
        fp2 = op.AllChem.GetMorganFingerprint(m, 2)
        sim = op.DataStructs.TanimotoSimilarity(fp1, fp2)  # plug a custom distance in here
        for bucket, threshold in zip(results, sim_thresholds):
            if sim >= threshold:
                bucket.append((s, p, sim, smiles))

    # Tuple layout: (new smile, property, similarity, source smile).
    for bucket in results:
        bucket.sort(key=lambda tup: tup[1], reverse=True)
    return results, start


def optimize_selected_mol(optimize_model, start_smiles, data_config, args, optim_mode='grad'):
    """Optimize one chosen molecule and print a per-similarity-bucket report.

    Calls ``optimize_mol`` for the molecule in ``start_smiles`` and, for every
    similarity bucket it returns, picks the candidate with the lowest score.
    A bucket counts as a success when that score is not higher than the
    starting score (lower is treated as better), otherwise as a failure.
    Prints ``DataFrame.describe()`` plus failure count and success rate for
    each bucket; returns nothing.

    Args:
        optimize_model: property model forwarded to ``optimize_mol``.
        start_smiles: 4-tuple ``(qed, plogp, score, smiles)`` of the start molecule.
        data_config: dataset configuration forwarded to ``optimize_mol``.
        args: configuration object forwarded to ``optimize_mol``.
        optim_mode: currently unused.
    """
    print('Optiimization {} for better {} score'.format(start_smiles[3], 'docking'))

    qed, plogp, d, smile = start_smiles
    property1 = d
    results, _ = optimize_mol(optimize_model, smile,  data_config, args, random=False)

    result_list = [[] for _ in results]
    nfail = [0] * len(results)

    for t, candidates in enumerate(results):
        if len(candidates) > 0:
            # The success test below treats lower scores as better, so take the
            # minimum instead of relying on the (descending) order of the bucket.
            smile2, property2, sim, _ = min(candidates, key=lambda tup: tup[1])
            prop_delta = property2 - property1
            if prop_delta <= 0:
                result_list[t].append((smile2, property2, sim, smile, qed, plogp, prop_delta))
            else:
                nfail[t] += 1
                print('Failure: for dist {} the best value is worse than initial'.format(t))
        else:
            nfail[t] += 1
            print('Failure: there is no moleculars with dist {}'.format(t))

    # Exactly one start molecule is processed per call, so rates are out of 1
    # (identical to the previous output for the configured topk of 1).
    n_total = 1
    for i in range(len(result_list)):
        df = pd.DataFrame(result_list[i],
                        columns=['smile_new', 'prop_new', 'sim', 'smile_old', 'qed_old', 'plogp_old', 'plogp_delta'])

        print(df.describe())
        print("For sim > {}:".format(0.2*i))
        print('nfail:{} in total:{}'.format(nfail[i], n_total))
        print('success rate: {}'.format((n_total-nfail[i])*1.0/n_total))


In [23]:
# Start record for DMNP: the values from op.get_mol_property, then the
# placeholder docking score (-10), then the SMILES string itself.
dmnp = tuple([
    *op.get_mol_property('CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O'),
    -10,
    'CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O',
])

In [26]:
# Run the single-molecule optimization report for the DMNP start record.
optimize_selected_mol(
    optimize_model=property_model,
    start_smiles=dmnp,
    data_config=data_config,
    args=args,
)

Optiimization CC1CCC(C2=C1C=CC(=C2)C)C(C)CCC(=O)O for better docking score
nan
nan
nan
valid molecules: 100/100
[0] Cc1ccc2c(c1)C(C(C)CCC(=O)O)CCC2C
[1] CC(CCC(=O)O)C1CCC(C)C2(C)CCC=CC12
[2] CC(CCC(=O)O)C1CCC(C)C2(C)CCC=CC12
[3] CC1=C(C2=C(F)CC2)C(C(C)CCC(=O)O)CCC1C
[4] CC1=C(C2=C(F)CC2)C(C(C)CCC(=O)O)CCC1C
[5] CC1=C(C2=C(F)CC2)C(C(C)CCC(=O)O)CNC1=O
[6] CC1=C(C2=C(F)CC2)C(C(C)CCC(=O)O)CNC1=O
[7] CC1=C(C2=C(F)CC2)C(C(C)CCC(=O)O)CNC1=O
[8] CC(=O)CCC(C)C1CNC(=O)C(C)=C1C1=C(F)CC1
[9] CC(=O)CCC(C)C1CNC(=O)C(C)=C1C1=C(F)CC1
[10] CC(=O)CCC(C)C(C)C1(C2=C(F)CC2)NC(=O)C1C
[11] CC(=O)CCC(C)C(C)C1(C2=C(F)CC2)NC(=O)C1C
[12] C=CC(C)C(C)(C1=C(F)CC1)C(C)C(C)CCC(C)=O
[13] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[14] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[15] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[16] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[17] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[18] C=CC(C)C(C)(C1CCC1=C)C(C)C(C)CCC(=C)C
[19] C=C1C2CN1C1(O)C(C)N21
[20] C=C1C2CN1C1(O)C(C)N21
[21] C=C=C1CCC2(C)C(C)N12
[22