In [None]:
import argparse
import torch
import sys
sys.path.append('..')


from models.ECloudDecipher.models.io.coati import load_e3gnn_smiles_clip_e2e
from models.ECloudDecipher.models.regression.basic_due import basic_due
from models.ECloudDecipher.utils.chem import read_sdf, write_sdf, rm_radical, sa, qed, logp
from rdkit import Chem
import random
from models.ECloudDecipher.generative.molopt import gradient_opt
from models.ECloudDecipher.generative.coati_purifications import embed_smiles
from functools import partial
from torch.nn.functional import sigmoid
import torch.nn.functional as F
import numpy as np
from models.ECloudDecipher.generative.coati_purifications import force_decode_valid_batch, embed_smiles, force_decode_valid
import os.path as osp
from models.ECloudDecipher.optimize.scoring import ScoringFunction
from models.ECloudDecipher.optimize.mol_functions import qed_score, substructure_match_score, penalize_macrocycles, heavy_atom_count, penalized_logp_score
from models.ECloudDecipher.optimize.pso_optimizer import BasePSOptimizer
from models.ECloudDecipher.optimize.swarm import Swarm
from models.ECloudDecipher.optimize.rules.qsar_score import qsar_model


# Run configuration. The options are declared with argparse so they are
# collected in one place: compute device, RNG seed and checkpoint path.
arg_parser = argparse.ArgumentParser(description='molecular optimization on the chemical space')
arg_parser.add_argument('--device', choices=['cuda:0', 'cpu'],
                        default='cuda:0', help='Device')
arg_parser.add_argument('--seed', type=int, default=2024)
arg_parser.add_argument('--ecloudgen_ckpt', type=str, default='../model_ckpts/ecloud_smiles_67.pkl')
# An explicit empty argv stops argparse from reading the Jupyter kernel's own
# command line, so every option takes the default declared above.
args = arg_parser.parse_args([])

# Apply the seed to every RNG imported by this notebook. Previously --seed was
# parsed but never used, so repeated runs were not reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [None]:
# model loading
# Honour the --device option. The previous version overwrote DEVICE with a
# hard-coded 'cuda:0', so the option was silently ignored.
# DEVICE is kept as a plain string, which is the type the original cell
# effectively passed to the loader and to the later cells.
DEVICE = args.device
if torch.device(DEVICE).type == 'cuda' and not torch.cuda.is_available():
    # Fall back instead of failing later with an opaque CUDA error.
    print(f"CUDA is not available; using 'cpu' instead of {DEVICE!r}.")
    DEVICE = 'cpu'

# Load the encoder and its tokenizer from the checkpoint given by
# --ecloudgen_ckpt, with freeze=True.
encoder, tokenizer = load_e3gnn_smiles_clip_e2e(
    freeze=True,
    device=DEVICE,
    # model parameters to load.
    doc_url=args.ecloudgen_ckpt,
)

In [None]:
class EPSO_format_model():
    """Adapter exposing an encoder/tokenizer pair via seq_to_emb / emb_to_seq.

    An instance of this class is what the notebook passes to the PSO
    optimizer as ``inference_model``.

    Args:
        model: encoder handed to ``embed_smiles`` / ``force_decode_valid_batch``.
        tokenizer: tokenizer handed to the same helpers.
        device: device the embeddings are moved to (anything ``Tensor.to`` accepts).
        emb_dim: width of one embedding, used to flatten a batch of embeddings
            into a 2-D tensor. Defaults to 256, the value previously hard-coded.
    """

    def __init__(self, model, tokenizer, device, emb_dim=256):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.emb_dim = emb_dim

    def _embed_one(self, smi):
        """Embed a single SMILES string and move the result to ``self.device``."""
        # Use the device given to the constructor, not a module-level global,
        # so the class works wherever it is instantiated.
        return embed_smiles(smi, self.model, self.tokenizer).to(self.device)

    def seq_to_emb(self, smiles):
        """Embed one SMILES string or an iterable of SMILES strings.

        Args:
            smiles: a single ``str`` or an iterable of ``str``.

        Returns:
            For a ``str``: the tensor returned by ``embed_smiles``, unchanged in
            shape. For an iterable: the per-molecule embeddings stacked and
            reshaped to ``(-1, emb_dim)``.
        """
        if isinstance(smiles, str):
            return self._embed_one(smiles)
        emb_list = [self._embed_one(smi) for smi in smiles]
        return torch.stack(emb_list).reshape(-1, self.emb_dim)

    def emb_to_seq(self, embs):
        """Decode each embedding in ``embs`` back to a sequence.

        Args:
            embs: iterable of embeddings; each element is passed on its own to
                ``force_decode_valid_batch``.

        Returns:
            A list with one decoder result per element of ``embs``.
        """
        return [
            force_decode_valid_batch(emb, self.model, self.tokenizer)
            for emb in embs
        ]

In [None]:
# Wrap the loaded encoder/tokenizer in the adapter used by the optimizer.
ecloud_latent = EPSO_format_model(encoder, tokenizer, DEVICE)

# Starting molecule: the first entry of the example SDF file.
INIT_SDF_PATH = 'example/3uw9_starting.sdf'
init_mols = read_sdf(INIT_SDF_PATH)
# Fail early with a clear message instead of an IndexError, or an RDKit
# argument error raised from MolToSmiles(None).
if len(init_mols) == 0 or init_mols[0] is None:
    raise ValueError(f'No valid starting molecule could be read from {INIT_SDF_PATH!r}')
init_mol = init_mols[0]
init_smiles = Chem.MolToSmiles(init_mol)

# Embed the starting SMILES (passed twice, i.e. as a batch of two).
# NOTE(review): init_emb is not used by the cells visible here — confirm whether
# it is only a sanity check of the encoder.
init_emb = ecloud_latent.seq_to_emb([init_smiles, init_smiles])

In [None]:
from models.ECloudDecipher.optimize.mol_functions import obey_lipinski, get_HAcceptors, get_weight


# Desirability curves, each a list of {"x": raw value, "y": desirability} points.
# NOTE(review): presumably ScoringFunction interpolates between these points — confirm.
# In the first curve only x == 5 maps to 1.0; every smaller x maps to 0.
linpinski_desirability = [
    {"x": x_val, "y": y_val}
    for x_val, y_val in [(0, 0), (1, 0), (2, 0), (3, 0.0), (4, 0.0), (5, 1.0)]
]
# x == 0 maps to 0; x == 1 and x == 2 map to 1.
haccept_desirability = [
    {"x": x_val, "y": y_val}
    for x_val, y_val in [(0, 0), (1, 1), (2, 1)]
]

# Every objective shares the same molecule-level, untruncated settings.
_mol_scoring = partial(
    ScoringFunction,
    is_mol_func=True,
    truncate_left=False,
    truncate_right=False,
)
scoring_functions = [
    _mol_scoring(func=penalized_logp_score, name="plogp"),
    _mol_scoring(func=obey_lipinski, name="lipinski", desirability=linpinski_desirability),
    _mol_scoring(func=get_HAcceptors, name="HAcceptor", desirability=haccept_desirability),
]

# Optimizer settings: one swarm of 200 particles started from the initial
# SMILES, with positions bounded to [-10, 10].
_pso_settings = dict(
    init_smiles=init_smiles,
    num_part=200,
    num_swarms=1,
    inference_model=ecloud_latent,
    scoring_functions=scoring_functions,
    x_min=-10.,
    x_max=10.,
    device=DEVICE,
)
pso_opt = BasePSOptimizer.from_query(**_pso_settings)

In [None]:
# Run the optimizer. The argument is presumably the number of PSO
# iterations — TODO confirm against BasePSOptimizer.run.
NUM_PSO_STEPS = 20
pso_opt.run(NUM_PSO_STEPS)