In [1]:
import os
# Run everything relative to the project root so the relative paths used below
# (datasets/..., results/..., models/...) resolve; raises KeyError if MAIN_DIR is unset.
os.chdir(os.environ["MAIN_DIR"])

In [2]:
from rdkit import Chem
from IPython.display import display
import tqdm
from rewards.properties import logP, qed, drd2

# Short aliases for the SMILES <-> RDKit Mol conversions used throughout the notebook.
mfs = Chem.MolFromSmiles
mts = Chem.MolToSmiles

### Get start molecules

In [3]:
# All start molecules: every record of the Enamine building-block SD file, read eagerly
# into a list. NOTE(review): RDKit's SDMolSupplier may yield None for records it cannot
# parse — presumably those are dropped by the applicability filter in the next cell.
%time start_mols = list(Chem.SDMolSupplier('datasets/offlineRL/Enamine_Building_Blocks_Stock_266483cmpd_20230901.sdf'))

CPU times: user 27.6 s, sys: 2.11 s, total: 29.7 s
Wall time: 29.3 s


In [None]:
from action_utils import get_applicable_actions
from multiprocessing import Pool
import pandas as pd
import numpy as np
from rewards.properties import similarity

temp_mols = []

def error_handled_get_applicable_actions(mol):
    """Return the applicable actions for `mol`, or an empty (0, 0) array on failure.

    Runs inside Pool workers, so one molecule that makes get_applicable_actions
    raise must not abort the whole map. Only ordinary exceptions are caught:
    KeyboardInterrupt / SystemExit still propagate so the run stays interruptible.
    """
    try:
        return get_applicable_actions(mol)
    except Exception:
        return np.zeros(shape=(0, 0))

# Keep only building blocks that have at least one applicable action.
# imap preserves input order, so zipping with start_mols pairs each molecule with its own result.
with Pool(14) as pool:
    results = pool.imap(error_handled_get_applicable_actions, start_mols, chunksize=100)
    for mol, applicable in tqdm.tqdm(zip(start_mols, results), total=len(start_mols)):
        if applicable.shape[0] > 0:
            temp_mols.append(mol)

temp_smiles = list(map(mts, temp_mols))

start_mols_df = pd.DataFrame(temp_smiles, columns=["SMILES"])
start_mols_df

 22%|████████████▏                                           | 58001/266483 [01:05<04:20, 801.00it/s]

### Load dataset(s)

In [None]:
# Property being optimized -> scoring function used to evaluate it.
# Select the experiment by changing PROPERTY_NAME only; SCORING_FT follows automatically.
PROPERTY_SCORERS = {
    "drd2": drd2,
    "qed": qed,
    "logp04": logP,
    "logp06": logP,
}

PROPERTY_NAME = "logp06"
SCORING_FT = PROPERTY_SCORERS[PROPERTY_NAME]

In [None]:
# Input file: the COMA test split for the selected property,
# i.e. datasets/coma/<PROPERTY_NAME>/rdkit_test.txt
input_data_dir = os.path.join("datasets/coma", PROPERTY_NAME)
filepath_test = os.path.join(input_data_dir, "rdkit_test.txt")
print(filepath_test)

In [None]:
# Output location for this property's evaluation results.
output_dir = f"results/eval_on_coma_{PROPERTY_NAME}"
# makedirs also creates a missing "results/" parent and is a no-op if the directory exists.
os.makedirs(output_dir, exist_ok=True)
filepath_output = os.path.join(output_dir, f"{PROPERTY_NAME}.csv")
print(filepath_output)

In [None]:
# Load the test-set SMILES (single headerless column) into a "smiles" column.
import pandas as pd
test_data = pd.read_csv(filepath_test, header=None, names=["smiles"])
test_data

### Helper functions: select start molecules and generate molecules along reaction paths of length `n`

In [None]:
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol
import tqdm
from rdkit.Chem import rdFMCS
from functools import partial

def my_conv_sim_fn(mol, sm):
    """Similarity between the molecule parsed from SMILES `sm` and the reference `mol`.

    Kept at module level so it can be shipped to Pool workers via functools.partial.
    """
    candidate = mfs(sm)
    return similarity(candidate, mol)

def my_conv_scaf_sim_fn(scaf, sm):
    """Similarity between the Murcko scaffold of SMILES `sm` and the reference scaffold `scaf`."""
    candidate_scaffold = GetScaffoldForMol(mfs(sm))
    return similarity(candidate_scaffold, scaf)

def my_conv_mcs_fn(mol, sm):
    """Atom count of the maximum common substructure of `mol` and the Murcko scaffold of `sm`."""
    candidate_scaffold = GetScaffoldForMol(mfs(sm))
    mcs_result = rdFMCS.FindMCS([mol, candidate_scaffold])
    return mcs_result.numAtoms


def get_start_mols(smiles, n):
    """Pick up to 3*n candidate start molecules for the target molecule `smiles`.

    Three candidate lists are drawn from the global `start_mols_df["SMILES"]`:
      1. the 3*n building blocks most similar to the target molecule;
      2. the 3*n whose Murcko scaffold is most similar to the target's scaffold;
      3. among building blocks whose scaffold MCS with the target is at least as
         large as the n-th largest, the 3*n with the most similar scaffold.
    The union is de-duplicated and 3*n entries are sampled without replacement
    (all of them if the union holds fewer than 3*n).

    Args:
        smiles: SMILES of the target molecule.
        n: size parameter; up to 3*n molecules are returned.

    Returns:
        np.ndarray of SMILES strings.
    """
    mol = mfs(smiles)
    scaf = GetScaffoldForMol(mol)
    smiles_col = start_mols_df["SMILES"]
    k = 3 * n

    def top_k_smiles(scores):
        # SMILES of the rows with the k highest scores, highest first.
        best_rows = scores.sort_values(ascending=False).index[:k]
        return start_mols_df.loc[best_rows]["SMILES"].values.tolist()

    # One pool for all three scoring passes; the context manager terminates the
    # workers instead of leaving a live Pool behind on every call.
    with Pool(14) as pool:
        mol_sim = pd.Series(pool.map(partial(my_conv_sim_fn, mol), smiles_col, chunksize=100))
        scaf_sim = pd.Series(pool.map(partial(my_conv_scaf_sim_fn, scaf), smiles_col, chunksize=100))
        scaf_mcs = pd.Series(pool.map(partial(my_conv_mcs_fn, mol), smiles_col, chunksize=100))

    start_mol_list = []
    # Similarity
    start_mol_list.extend(top_k_smiles(mol_sim))
    # Scaffold by similarity
    start_mol_list.extend(top_k_smiles(scaf_sim))
    # MCS: threshold at the n-th largest MCS size, then rank survivors by scaffold similarity.
    val = scaf_mcs.sort_values(ascending=False).values[n - 1]
    start_mol_list.extend(top_k_smiles(scaf_sim[scaf_mcs >= val]))

    candidates = np.unique(start_mol_list)
    # Sampling without replacement cannot draw more items than exist.
    return np.random.choice(candidates, size=(min(k, len(candidates)),), replace=False)

# %time get_start_mols(test_data["smiles"][0], 10)

In [None]:
import torch

# Prefer the first GPU when one is visible; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Execute the helper notebook so its top-level definitions become available in this kernel.
# NOTE(review): names used below but never imported in this notebook (data,
# molecule_from_smile, apply_action, get_action_dataset_embeddings,
# get_emb_indices_and_correct_idx, functools, pickle) presumably come from here — confirm.
%run lead_optimization/experimental\ notebooks/supervised_functions.ipynb

In [None]:
# Reaction-action library: keep only actions that were tested and work, restricted to
# the columns (in this exact order) that apply_action receives positionally via *row.
action_dataset = pd.read_csv("datasets/my_uspto/action_dataset-filtered.csv", index_col=0)
action_dataset = action_dataset.loc[
    action_dataset["action_tested"] & action_dataset["action_works"],
    ["rsub", "rcen", "rsig", "rsig_cs_indices", "psub", "pcen", "psig", "psig_cs_indices"],
]
print(action_dataset.shape)

# Packed reactant/product signature graphs, addressed by row position in action_dataset.
action_rsigs = data.Molecule.pack([molecule_from_smile(s) for s in action_dataset["rsig"]])
action_psigs = data.Molecule.pack([molecule_from_smile(s) for s in action_dataset["psig"]])

In [None]:
import time
def get_topk_predictions(model, source_list, target_list, topk=10):
    """Rank applicable actions for each (source, target) pair with an actor-critic model.

    Pipeline:
      1. the actor head predicts an action embedding for every pair;
      2. among the actions applicable to the source, the 50 whose embeddings are
         closest (L2 norm) to that prediction are shortlisted;
      3. the shortlist is re-ranked by the critic head's Q-value.

    Args:
        model: actor-critic network, called as
            ``model(sources, targets, rsigs, psigs, "actor" | "critic")`` and exposing
            ``model.GIN`` for embedding the action dataset.
        source_list: SMILES of the current molecules (list, array or pd.Series).
        target_list: SMILES of the target molecules, aligned with ``source_list``.
        topk: number of actions returned per source.

    Returns:
        List with one integer array per source holding up to ``topk`` positional
        indices into ``action_dataset``, highest critic Q first. Sources with no
        applicable action get an empty int64 array.
    """
    import functools  # needed for the Pool worker partial below

    # Convert to mols
    if isinstance(source_list, pd.Series):
        source_list = source_list.tolist()
    if isinstance(target_list, pd.Series):
        target_list = target_list.tolist()
    tt = time.time()
    sources = data.Molecule.pack(list(map(molecule_from_smile, source_list)))
    targets = data.Molecule.pack(list(map(molecule_from_smile, target_list)))
    print(f"Took {time.time() - tt}s to pack molecules.")

    batch_size = 1024
    n_pairs = sources.batch_size
    # Explicit int64: a default float64 empty array would turn the concatenated
    # index arrays below into floats and break indexing.
    no_actions = np.array([], dtype=np.int64)

    # Actor predictions (inference only, so skip autograd bookkeeping).
    with torch.no_grad():
        pred = torch.cat([
            model(sources[i:min(i + batch_size, n_pairs)].to(device),
                  targets[i:min(i + batch_size, n_pairs)].to(device),
                  None, None, "actor").detach()
            for i in range(0, n_pairs, batch_size)
        ], dim=0)
        action_embeddings = get_action_dataset_embeddings(model.GIN)

    # Get applicable actions for source(s)
    with Pool(30) as p:
        applicable_action_indices_list = list(tqdm.tqdm(
            p.imap(functools.partial(get_emb_indices_and_correct_idx, no_correct_idx=True),
                   [{"reactant": source_list[i]} for i in range(pred.shape[0])], chunksize=10),
            total=pred.shape[0]))

    # Shortlist: the 50 applicable actions nearest to each actor prediction.
    shortlists = []
    for i in tqdm.tqdm(range(pred.shape[0])):
        adi = applicable_action_indices_list[i]
        if len(adi) == 0:
            shortlists.append(no_actions)
            continue
        dist = torch.linalg.norm(action_embeddings[adi] - pred[i], axis=1)
        shortlists.append(adi[torch.argsort(dist)[:50].cpu().numpy().astype(np.int64)])

    action_indices = np.concatenate(shortlists)
    state_indices = np.concatenate([np.full_like(sl, i) for i, sl in enumerate(shortlists)])
    if action_indices.size == 0:
        # No source has any applicable action: nothing for the critic to score.
        return [no_actions.copy() for _ in shortlists]

    # Critic Q for every (state, shortlisted action) pair, in batches.
    critic_qs = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, action_indices.shape[0], batch_size)):
            batch_reactants = sources[state_indices[i:i + batch_size]]
            batch_products = targets[state_indices[i:i + batch_size]]
            batch_rsigs = action_rsigs[action_indices[i:i + batch_size]]
            batch_psigs = action_psigs[action_indices[i:i + batch_size]]
            critic_qs.append(model(batch_reactants.to(device), batch_products.to(device),
                                   batch_rsigs.to(device), batch_psigs.to(device),
                                   "critic").detach().cpu().numpy())
    critic_qs = np.concatenate(critic_qs)

    # Per source: order its shortlist by descending critic Q and keep topk.
    action_pred_indices = []
    start = 0
    for sl in tqdm.tqdm(shortlists):
        end = start + sl.shape[0]
        order = critic_qs[start:end].reshape(-1).argsort()[::-1]
        action_pred_indices.append(sl[order][:topk])
        start = end

    return action_pred_indices
    
# %time pred = get_topk_predictions(ac, source_list[:100], target_list[target_list_idx][:100])
# print(len(pred))

In [None]:
def apply_actions_on_reactant(args):
    """Apply each selected action to one reactant and collect the product SMILES.

    Args:
        args: ``(reactant_smiles, action_indices)`` tuple (a single argument so it
            can be used with Pool.imap); indices are positional rows of action_dataset.

    Returns:
        List of product SMILES, one per action that applied successfully. Actions
        that raise any exception are skipped on purpose (best-effort).
    """
    reactant, action_dataset_idx = args
    products = []
    for action_row in action_dataset_idx:
        try:
            product_mol = apply_action(Chem.MolFromSmiles(reactant), *action_dataset.iloc[action_row])
            products.append(Chem.MolToSmiles(product_mol))
        except Exception:
            continue
    return products

In [None]:
def get_similarity_from_smiles(s1, s2):
    """Similarity between the molecules encoded by SMILES strings `s1` and `s2`."""
    mol_a, mol_b = Chem.MolFromSmiles(s1), Chem.MolFromSmiles(s2)
    return similarity(mol_a, mol_b)

In [None]:
def _select_most_similar_per_pair(pair_idx, sims, n_pairs, limit):
    """Positions of the `limit` most similar entries for each pair index, grouped by pair."""
    selected = []
    for t_i in range(n_pairs):
        members = np.flatnonzero(pair_idx == t_i)
        # Descending similarity; stable so ties keep generation order.
        order = np.argsort(-sims[members], kind="stable")
        selected.append(members[order][:limit])
    return np.concatenate(selected) if selected else np.array([], dtype=np.int64)


def run_predictions(source_list, target_list, steps=5, topk=5, limit=1000):
    """Roll out multi-step action trajectories from each start molecule toward its target.

    At every step the globally loaded model ``ac`` proposes ``topk`` actions per
    current molecule, the actions are applied, and for each original
    (source, target) pair only the ``limit`` products most similar to that
    pair's target are kept as the next step's sources.

    Trajectory keys encode lineage: the i-th start molecule is ``"i"``, and the
    j-th successful product of the molecule with key ``k`` is ``f"{k}_{j}"``.
    The leading integer of every key is thus the index of the original pair.

    Args:
        source_list: start-molecule SMILES; element i pairs with target i.
        target_list: np.ndarray of target SMILES aligned with ``source_list``.
        steps: number of action-application rounds.
        topk: actions tried per molecule per step.
        limit: products kept per original pair after each step.

    Returns:
        ``(trajectory_dict, sim_dict)``: key -> SMILES for every visited molecule,
        and key -> similarity to its pair's target for every generated product.
    """
    n_pairs = target_list.shape[0]
    target_list_idx = np.arange(n_pairs)
    sim_dict = {}
    # Lineage-keyed dict so any trajectory node can be looked up directly (hashable keys).
    trajectory_dict = {str(i): source_list[i] for i in range(len(source_list))}
    # Key of each current source molecule: position i -> key of source_list[i].
    source_keys = list(map(str, np.arange(len(source_list))))

    for i_step in range(1, steps + 1):
        print("Running prediction for step", i_step)
        # Get action predictions
        pred = get_topk_predictions(ac, source_list, target_list[target_list_idx], topk=topk)

        # Apply the predicted actions and score every product against its pair's target.
        temp_source_keys = []
        temp_source_list = []
        with Pool(14) as p:
            print("Applying actions for step", i_step)
            products_iter = p.imap(apply_actions_on_reactant, zip(source_list, pred), chunksize=10)
            for i, product_list in tqdm.tqdm(enumerate(products_iter), total=len(pred)):
                for _i, product in enumerate(product_list):
                    key = f"{source_keys[i]}_{_i}"
                    trajectory_dict[key] = product
                    sim_dict[key] = get_similarity_from_smiles(product, target_list[int(key.split("_")[0])])
                    temp_source_keys.append(key)
                    temp_source_list.append(product)

        if not temp_source_keys:
            # Nothing to expand further (every action failed on every molecule).
            print("No products generated at step", i_step, "- stopping early")
            break

        print("Getting top some sim products for each s-t pair")
        temp_source_keys = np.array(temp_source_keys)
        temp_source_list = np.array(temp_source_list)
        temp_source_sim = np.array([sim_dict[k] for k in temp_source_keys], dtype=float)
        temp_source_st_idx = np.array([int(k.split("_")[0]) for k in temp_source_keys])
        keep = _select_most_similar_per_pair(temp_source_st_idx, temp_source_sim, n_pairs, limit)

        # update source list and source_keys for next step
        source_keys = temp_source_keys[keep]
        source_list = temp_source_list[keep]
        target_list_idx = temp_source_st_idx[keep]

    return trajectory_dict, sim_dict

In [None]:
# Load model
import glob
file_string = "models/supervised/actor-critic/steps=5||actor_loss=PG||negative_selection=random/model.pth"
matches = glob.glob(file_string)
if not matches:
    raise FileNotFoundError(f"No model checkpoint matches {file_string!r}")
# The checkpoint holds a whole model object (not a state_dict), so weights_only=False
# is required on PyTorch >= 2.6, where the default became True. Unpickling can execute
# arbitrary code: only load checkpoints from a trusted source.
ac = torch.load(matches[0], map_location=device, weights_only=False).to(device)

In [None]:
for t_i, test_smile in enumerate(test_data["smiles"]): 
    %time source_list = np.array(get_start_mols(test_smile, 10))
    target_list = np.array([test_smile for i in range(len(source_list))])

    print("$$$$$$$$$$$$$$$$")
    print(f"$ Test mol: {t_i} $")
    print("$$$$$$$$$$$$$$$$")

    traj_d, sim_d = run_predictions(source_list, target_list, steps=5)

    pickle.dump({"traj": traj_d, "sim":sim_d}, open(f"results/eval_on_coma_{PROPERTY_NAME}/{t_i}.pickle", 'wb'))