In [3]:
import argparse
import torch
import sys
sys.path.append('..')


from models.ECloudDecipher.models.io.coati import load_e3gnn_smiles_clip_e2e
from models.ECloudDecipher.models.regression.basic_due import basic_due
from models.ECloudDecipher.utils.chem import read_sdf, write_sdf, rm_radical, sa, qed, logp
from rdkit import Chem
import random
from models.ECloudDecipher.generative.molopt import gradient_opt
from models.ECloudDecipher.generative.coati_purifications import embed_smiles
from functools import partial
from torch.nn.functional import sigmoid
import torch.nn.functional as F
import numpy as np
from models.ECloudDecipher.generative.coati_purifications import force_decode_valid_batch, embed_smiles, force_decode_valid
import os.path as osp
from models.ECloudDecipher.optimize.scoring import ScoringFunction
from models.ECloudDecipher.optimize.mol_functions import qed_score, substructure_match_score, penalize_macrocycles, heavy_atom_count, penalized_logp_score
from models.ECloudDecipher.optimize.pso_optimizer import BasePSOptimizer
from models.ECloudDecipher.optimize.swarm import Swarm
from models.ECloudDecipher.optimize.rules.qsar_score import qsar_model


# Notebook configuration.
# parse_args([]) parses an empty argument list instead of sys.argv, so the
# defaults declared below are the values the notebook actually runs with.
arg_parser = argparse.ArgumentParser(description='molecular optimization on the chemical space')
arg_parser.add_argument(
    '--device',
    choices=['cuda:0', 'cpu'],
    # Fall back to the CPU when no CUDA device is usable, so the notebook
    # still runs on CPU-only machines.
    default='cuda:0' if torch.cuda.is_available() else 'cpu',
    help='Device',
)
arg_parser.add_argument('--seed', type=int, default=2024)
arg_parser.add_argument('--ecloudgen_ckpt', type=str, default='../model_ckpts/ecloud_smiles_67.pkl')
args = arg_parser.parse_args([])

# Apply the seed to every random number generator imported by this notebook,
# so that a Restart & Run All starts from the same RNG state each time.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [4]:
# model loading
# Take the device from the parsed options rather than a hard-coded literal,
# so that --device (including 'cpu') is respected. DEVICE stays a plain
# string such as 'cuda:0'.
DEVICE = args.device
encoder, tokenizer = load_e3gnn_smiles_clip_e2e(
    freeze=True,
    device=DEVICE,
    # model parameters to load.
    doc_url=args.ecloudgen_ckpt,
)

Loading model from ../model_ckpts/ecloud_smiles_67.pkl
Loading tokenizer mar from ../model_ckpts/ecloud_smiles_67.pkl
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 19.60M Total: 22.04M 
Freezing encoder
44882816 params frozen!


In [5]:
class EPSO_format_model():
    """Adapter exposing an encoder/tokenizer pair via ``seq_to_emb`` and ``emb_to_seq``.

    Args:
        model: encoder passed to ``embed_smiles`` and ``force_decode_valid_batch``.
        tokenizer: tokenizer passed to the same two helpers.
        device: device that embeddings are moved to in ``seq_to_emb``.
    """

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def _embed_one(self, smi):
        """Embed a single SMILES string and move the result to ``self.device``."""
        # Use the device given to the constructor, not a notebook-level global,
        # so the instance works wherever it is created.
        return embed_smiles(smi, self.model, self.tokenizer).to(self.device)

    def seq_to_emb(self, smiles):
        """Embed one SMILES string or an iterable of SMILES strings.

        Args:
            smiles: a single SMILES ``str``, or an iterable of SMILES strings.

        Returns:
            For a ``str``: the result of ``embed_smiles`` moved to ``self.device``,
            with its shape left untouched.
            Otherwise: the per-string embeddings stacked and reshaped to
            ``(-1, emb_dim)``, where ``emb_dim`` is the size of the embeddings'
            last dimension.
        """
        if isinstance(smiles, str):
            return self._embed_one(smiles)

        stacked = torch.stack([self._embed_one(smi) for smi in smiles])
        # Derive the embedding width from the data instead of assuming 256.
        return stacked.reshape(-1, stacked.shape[-1])

    def emb_to_seq(self, embs):
        """Decode every embedding in ``embs``.

        Args:
            embs: iterable of embeddings; each element is decoded on its own.

        Returns:
            A list holding one ``force_decode_valid_batch`` result per element
            of ``embs``, in the same order.
        """
        return [
            force_decode_valid_batch(emb, self.model, self.tokenizer)
            for emb in embs
        ]

In [8]:
from models.ECloudDecipher.optimize.mol_functions import obey_lipinski, get_HAcceptors, get_weight, sa_score, sim_score

def _curve(*points):
    """Turn ``(x, y)`` pairs into the list of ``{"x": ..., "y": ...}`` dicts passed as ``desirability``."""
    return [{"x": x, "y": y} for x, y in points]


# Heavy atom count: desirability is 1.0 between 20 and 25 atoms and drops to 0 at 0 and 45.
heavy_atom_count_desirability = _curve(
    (0, 0), (5, 0.1), (15, 0.9), (20, 1.0),
    (25, 1.0), (30, 0.9), (40, 0.1), (45, 0.0),
)
heavy_atom_count_scoring = ScoringFunction(
    heavy_atom_count,
    "hac",
    desirability=heavy_atom_count_desirability,
    is_mol_func=True,
)

# Molecular weight: desirability rises from 0.5 at x=100 to 1.0 at x=300 and stays there up to x=500.
weight_desirability = _curve(
    (100, 0.5), (200, 0.7), (300, 1.0), (400, 1.0), (500, 1.0),
)
weight_scoring = ScoringFunction(
    get_weight,
    "weight",
    desirability=weight_desirability,
    is_mol_func=True,
)

# Substructure match: bind the query (benzene) as a fixed keyword argument.
# Note that this rebinds the imported name `substructure_match_score`.
substructure_mol = Chem.MolFromSmiles("c1ccccc1")
substructure_match_score = partial(substructure_match_score, query=substructure_mol)
# Inverted curve (x=0 -> 1, x=1 -> 0) so that a match is penalized.
miss_match_desirability = _curve((0, 1), (1, 0))

# QED score
qed_desirability = _curve((0.6, 1.0), (1, 1))
qed_scoring = ScoringFunction(
    qed_score,
    "qed",
    is_mol_func=True,
    truncate_left=True,
    truncate_right=False,
    desirability=qed_desirability,
)

# SA score
sa_desirability = _curve((0.6, 1.0), (1, 1))
sa_scoring = ScoringFunction(
    sa_score,
    'sa',
    is_mol_func=True,
    truncate_left=False,
    truncate_right=False,
    desirability=sa_desirability,
)

# Lipinski: desirability is 0 for x = 0..4 and 1.0 only at x = 5.
linpinski_desirability = _curve(
    (0, 0), (1, 0), (2, 0), (3, 0.0), (4, 0.0), (5, 1.0),
)
lipinski_scoring = ScoringFunction(
    func=obey_lipinski,
    name="lipinski",
    desirability=linpinski_desirability,
    is_mol_func=True,
    truncate_left=False,
    truncate_right=False,
)

# Hydrogen bond acceptors: desirability is 0 at x=0 and 1 at x=1 and x=2.
haccept_desirability = _curve((0, 0), (1, 1), (2, 1))
haccept_scoring = ScoringFunction(
    func=get_HAcceptors,
    name="HAcceptor",
    desirability=haccept_desirability,
    is_mol_func=True,
    truncate_left=False,
    truncate_right=False,
)

# Similarity to a reference molecule (benzene), bound via partial.
similarity_desirability = _curve((0.6, 1.0), (1, 1))
ref_mol = Chem.MolFromSmiles("c1ccccc1")
sim_to_ref_score = partial(sim_score, ref_mol=ref_mol)
similarity_scoring = ScoringFunction(
    func=sim_to_ref_score,
    name="similarity",
    desirability=similarity_desirability,
    is_mol_func=True,
    truncate_left=False,
    truncate_right=False,
)

In [None]:
# Settings for this run, named so they can be tuned in one place.
START_SDF = 'example/3uw9_starting.sdf'
NUM_PARTICLES = 200
NUM_SWARMS = 1
X_MIN, X_MAX = -10., 10.
NUM_PSO_STEPS = 20

ecloud_latent = EPSO_format_model(encoder, tokenizer, DEVICE)

# Fail with a clear message if the starting SDF yields no usable molecule,
# instead of an IndexError here or an RDKit error further down.
start_mols = read_sdf(START_SDF)
if len(start_mols) == 0 or start_mols[0] is None:
    raise ValueError(f"No readable molecule found in {START_SDF!r}")
init_mol = start_mols[0]
init_smiles = Chem.MolToSmiles(init_mol)
# init_emb is not read again in this cell; it is kept so the name stays defined.
# NOTE(review): the same SMILES is embedded twice here - confirm this is intended.
init_emb = ecloud_latent.seq_to_emb([init_smiles, init_smiles])

# Scorers handed to the optimizer.
scoring_functions = [
    weight_scoring,
    qed_scoring,
    sa_scoring,
    lipinski_scoring,
    haccept_scoring,
]

pso_opt = BasePSOptimizer.from_query(
    init_smiles=init_smiles,
    num_part=NUM_PARTICLES,
    num_swarms=NUM_SWARMS,
    inference_model=ecloud_latent,
    scoring_functions=scoring_functions,
    x_min=X_MIN,
    x_max=X_MAX,
    device=DEVICE)

pso_opt.run(NUM_PSO_STEPS)

Particle Swarm Optimization...


Step 0, max: 0.957, min: 0.957, mean: 0.957
Step 1, max: 1.000, min: 1.000, mean: 1.000
Step 2, max: 1.000, min: 1.000, mean: 1.000
Step 3, max: 1.000, min: 1.000, mean: 1.000


In [None]:
# Convert the optimizer's best SMILES strings into RDKit molecules.
opt_smiles = pso_opt.best_solutions['smiles'].to_list()
opt_mols = list(map(Chem.MolFromSmiles, opt_smiles))

# For each scorer, keep element [0] of its result on the optimized molecules
# and print the mean and standard deviation of those values.
results_stats = {}
for scoring_function in scoring_functions:
    scores = scoring_function(opt_mols)[0]
    results_stats[scoring_function.name] = scores
    mean_score, std_score = np.mean(scores), np.std(scores)
    print(f"{scoring_function.name}: {mean_score:.2f}, std:{std_score:.2f}")