In [1]:
import sys
import os
sys.path.append('/root/Projects/LocalRetro/scripts')
sys.path.append('/root/Projects/LocalRetro')
from Retrosynthesis import LocalRetro
import torch
import time
import pickle
from tqdm import tqdm
import pandas as pd
from LocalTemplate.template_decoder import *

In [2]:
retro190_routes = pickle.load(open('/root/Projects/LLM4Retro/dataset/retro_190/routes_possible_test_hard.pkl', 'rb'))
products, gt_reactants = [tmp[0].split(">>")[0] for tmp in retro190_routes], [tmp[0].split(">>")[1].split('.') for tmp in retro190_routes]

## Run Retrosynthesis

In [4]:
# Load the model and necessary files for prediction decoding
dataset = 'USPTO_50K' # USPTO_MIT or USPTO_50K
device = torch.device('cuda:0') # cpu or cuda:0
model_path = 'models/LocalRetro_%s.pth' % (dataset)
config_path = 'data/configs/default_config.json'
data_dir = 'data/%s' % dataset

args = {'data_dir': data_dir, 'model_path': model_path, 'config_path': config_path, 'device': device}
localretro = LocalRetro(args)

Parameters of loaded LocalRetro:
{'attention_heads': 8, 'attention_layers': 1, 'batch_size': 16, 'edge_hidden_feats': 64, 'node_out_feats': 320, 'num_step_message_passing': 6, 'activation': 'gelu', 'AtomTemplate_n': 124, 'BondTemplate_n': 548, 'in_node_feats': 80, 'in_edge_feats': 13}


In [5]:
test_file = pd.read_csv(f'{data_dir}/raw_test.csv')

In [6]:
rxn_ps = [demap(Chem.MolFromSmiles(rxn.split('>>')[1])) for rxn in test_file['reactants>reagents>production']]

In [7]:
ground_truth = [demap(Chem.MolFromSmiles(rxn.split('>>')[0])) for rxn in test_file['reactants>reagents>production']]

In [7]:
class_given = False

result_dir = 'outputs/decoded_prediction' 
if class_given:
    result_dir += '_class'

result_file = '%s/LocalRetro_%s.txt' % (result_dir, dataset)

results = {}       
results_MaxFrag = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        line = line.split('\n')[0]
        i = int(line.split('\t')[0])
        predictions = line.split('\t')[1:]
        MaxFrags = []
        results[i] = [eval(p)[0] for p in predictions]
        for p in results[i]:
            if p not in MaxFrags:
                MaxFrags.append(get_MaxFrag(p))
        results_MaxFrag[i] = MaxFrags

In [8]:
len(list(results.keys())), len(rxn_ps)

(5007, 5007)

In [8]:
pred_reactants = []
for prod in tqdm(rxn_ps):
    pred_reactants.append([tmp.split(".") for tmp in localretro.retrosnythesis(prod,top_k=5)['SMILES'].tolist()])

  0%|                                                                                                                                                                      | 0/5007 [00:00<?, ?it/s]

  2%|██▍                                                                                                                                                          | 77/5007 [00:18<12:06,  6.79it/s]

In [None]:
sorted(pred_reactants[0]), sorted(ground_truth[0].split('.')), rxn_ps[0]

([['CC(=O)Cl', 'CC(C)(C)OC(=O)n1ccc2ccccc21'],
  ['CC(=O)c1ccc2[nH]ccc2c1', 'CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'],
  ['CC(=O)c1ccc2c(ccn2C(=O)OC(C)(C)C)c1'],
  ['CC(O)c1ccc2c(ccn2C(=O)OC(C)(C)C)c1'],
  ['CON(C)C(=O)c1ccc2c(ccn2C(=O)OC(C)(C)C)c1', 'C[Mg+]']],
 ['CC(=O)c1ccc2[nH]ccc2c1', 'CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'],
 'CC(=O)c1ccc2c(ccn2C(=O)OC(C)(C)C)c1')

In [None]:
sum([sorted(ground_truth[i].split('.')) in sorted(pred_reactants[i]) for i in range(len(pred_reactants))]) / len(pred_reactants)

0.8434192131016577

In [None]:
import sys
sys.path.insert(0, '/root/Projects/LLM4Retro')
from chem_utils import *

In [None]:
top_n_rt_accuracy = topk_rt_accuracy_moltransformer(
    pred_reactants,
    rxn_ps,
    5
)

NameError: name 'topk_rt_accuracy_moltransformer' is not defined

In [13]:
top_n_rt_accuracy

[62.99181146395046,
 95.46634711404035,
 98.06271220291592,
 98.52206910325543,
 98.8615937687238]

### Predict on Retro190 dataset

In [None]:
pred_reactants = []
for prod in products:
    pred_reactants.append([tmp.split(".") for tmp in localretro.retrosnythesis(prod,top_k=5)['SMILES'].tolist()])

In [None]:
sum([gt_reactants[i] in pred_reactants[i] for i in range(190)]) / 190

0.3894736842105263

In [85]:
pickle.dump(pred_reactants, open('retro190_pred_reactants.pkl', 'wb+'))

## Predict Round-trip accuracy

In [23]:
def unflatten_list(flat_list, chunk_size):
    return [flat_list[i:i + chunk_size] for i in range(0, len(flat_list), chunk_size)]

In [24]:
def canocialize_smiles(smi):
    mol = Chem.MolFromSmiles(smi)
    if mol:
        return Chem.MolToSmiles(mol)
    else:
        return ""

In [25]:
def calculate_top_n_round_trip_accuracy(original_input, output, n):
    num_samples = len(original_input)
    rt_cnt = [0] * n
    for i in range(num_samples):
        fwd_idx, retro_idx = n, n
        original = canocialize_smiles(original_input[i])
        predicted = output[i][:n]
        for idx, cur_chunk in enumerate(predicted):
            cur_chunk = [canocialize_smiles(tmp.strip().replace(" ", "")) for tmp in cur_chunk]
            if original in cur_chunk:
                # print(f"Found in {idx} row at {cur_chunk.index(original)} col")
                # print(fwd_idx, retro_idx)
                fwd_idx = min(fwd_idx, cur_chunk.index(original))
                retro_idx = min(idx, retro_idx)
                # print(fwd_idx, retro_idx, "\n")
            else:
                # print(f"Not found at {idx}th row\n")
                pass
        k = max(fwd_idx, retro_idx)
        if k <= n:
            rt_cnt[k:] = [tmp+1 for tmp in rt_cnt[k:]]
    top_n_round_trip_accuracy = [(tmp / num_samples) * 100 for tmp in rt_cnt]
    return top_n_round_trip_accuracy

### On Local Retro

In [6]:
pred_reactants = pickle.load(open('/root/Projects/LocalRetro/retro190_pred_reactants.pkl', 'rb+'))

In [7]:
with open('/root/Projects/LocalRetro/MolecularTransformer/flatten_retro190_pred_reactants_translated.txt', 'r') as file:
    flatten_retro190_pred_reactants_translated = file.readlines()

In [8]:
unflatten_list(flatten_retro190_pred_reactants_translated, 5)[-190:][0]

['O = S ( = O ) ( c 1 c c ( C ( F ) ( F ) F ) c c c 1 Br ) C 1 C C C ( c 2 c c c ( Cl ) c c 2 ) C 1\n',
 'O = S ( = O ) ( c 1 c c ( C ( F ) ( F ) F ) c c c 1 Br ) C 1 C C O C ( c 2 c c c ( Cl ) c c 2 ) C 1\n',
 'O = S ( = O ) ( c 1 c c ( C ( F ) ( F ) F ) c c c 1 Br ) [C@H] 1 C C [C@@H] ( c 2 c c c ( Cl ) c c 2 ) C 1\n',
 'O = S ( = O ) ( c 1 c c ( C ( F ) ( F ) F ) c c c 1 Br ) C 1 C C C ( = C c 2 c c c ( Cl ) c c 2 ) C 1\n',
 'N S ( = O ) ( = O ) c 1 c c c c c 1\n']

In [9]:
len([item for sublist in pred_reactants for item in sublist])

1030

In [10]:
retro190_pred_prod = []
idx = 0
unflatten_retro190_pred_reactants_translated = unflatten_list(flatten_retro190_pred_reactants_translated, 5)
for i in range(len(pred_reactants)):
    cur = [tmp for tmp in unflatten_retro190_pred_reactants_translated[idx:idx+len(pred_reactants[i])]]
    idx += len(pred_reactants[i])
    retro190_pred_prod.append(cur)

In [18]:
len(retro190_pred_prod)

190

In [11]:
top_n = 5  # specify the value of n
top_n_round_trip_accuracy = calculate_top_n_round_trip_accuracy(products, retro190_pred_prod, top_n)
top_n_round_trip_accuracy

[62.63157894736842,
 91.05263157894737,
 95.78947368421052,
 96.84210526315789,
 97.36842105263158]

### On LLM output

In [12]:
agent_preds = pickle.load(open("/root/Projects/LLM4Retro/agent_output_test_512max_clean.pkl", 'rb'))

In [13]:
with open('/root/Projects/LocalRetro/MolecularTransformer/flatten_agent_pred_products.txt', 'r') as file:
    flatten_agent_pred_prod = file.readlines()

In [14]:
agent_pred_prod = []
idx = 0
unflatten_agent_pred_prod = unflatten_list(flatten_agent_pred_prod, 5)
for i in range(len(agent_preds)):
    cur = [tmp for tmp in unflatten_agent_pred_prod[idx:idx+len(agent_preds[i])]]
    idx += len(agent_preds[i])
    agent_pred_prod.append(cur)

In [15]:
calculate_top_n_round_trip_accuracy(products, agent_pred_prod, 5)

[33.68421052631579,
 51.578947368421055,
 62.63157894736842,
 68.94736842105263,
 73.15789473684211]