# load model and dataset

In [1]:
import sys
sys.path.append('../')

import torch

from models import build_model
from models import RXNModel
from utils.base import seed_everything

seed_everything(624)
device = 'cpu'

model = build_model('vib2mol_rxn').to(device)

# yields
# ckpt = torch.load('../checkpoints/rxn/raman-kekule_smiles/vib2mol_phase/yield_fixed.pth', 
# 0.1
# ckpt = torch.load('../checkpoints/rxn/raman-kekule_smiles/vib2mol_phase/yield_10_100.pth',                 
# 1.0
ckpt = torch.load('../checkpoints/rxn/raman-kekule_smiles/vib2mol_phase/unmixed.pth',                
                  map_location=device, weights_only=True)

ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

<All keys matched successfully>

In [8]:
import lmdb
import pandas as pd
import pickle
from transformers import AutoTokenizer

db = lmdb.open('../datasets/vibbench/rxn/rxn_test.lmdb', subdir=False, lock=False, map_size=int(1e11))

# Open a transaction and perform a read operation
with db.begin() as txn:
    test_data = list(txn.cursor())

test_df = pd.DataFrame([pickle.loads(item[1]) for item in test_data])

tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')

In [9]:
import numpy as np
def mix_spectrum(row):
    spec_r1 = np.array(row['reactant1_raman'])
    spec_r2 = np.array(row['reactant2_raman'])
    spec_p = np.array(row['product_raman'])
    weight = row['Yield']
    return (0.5-0.5*weight) * (spec_r1 + spec_r2) + weight * spec_p

test_df['mix_raman'] = test_df.apply(lambda row: mix_spectrum(row), axis=1)
test_df = test_df[test_df.Yield <= 1.]

# 1 evaluating retrieval performance

In [14]:
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, spectra, smiles):
        self.spectra = spectra
        self.smiles = smiles
        
    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        return self.spectra[idx], self.smiles[idx]
    
class TestCollator:
    def __init__(self, tokenizer, spectral_types=None, smiles_types=None, mix_ratio=None):
        self.tokenizer = tokenizer
        self.spectral_types = spectral_types
        self.smiles_types = smiles_types
        self.mix_ratio = mix_ratio
        
    def __call__(self, batch):
        spectra, smiles = zip(*batch)
        spectra = torch.as_tensor(np.array(spectra), dtype=torch.float32).unsqueeze(1).to(device)
        
        input_ids = self.tokenizer(list(smiles), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        input_ids = {'input_ids':input_ids['input_ids'].to(device), 'attention_mask':input_ids['attention_mask'].to(device)}
        
        batch_data = {'raman':spectra, 'smiles':input_ids}

        return {'batch_size':len(spectra), 'target':None, 'data': batch_data}


In [15]:
from tqdm import tqdm

test_dataset = TestDataset(test_df['product_raman'].to_list(), test_df['product_kekule_smiles'].to_list())
# test_dataset = TestDataset(test_df['mix_raman'].to_list(), test_df['product_kekule_smiles'].to_list())

test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
all_smiles_embeddings = []
all_spectra_embeddings = []
with torch.no_grad():
    for batch in test_bar:
        data = batch['data']
        batch_size = batch['batch_size']
        output = model(data, return_proj_output=True, return_loss=False)
                    
        all_smiles_embeddings.append(output['molecular_proj_output'].detach().cpu())
        all_spectra_embeddings.append(output['spectral_proj_output'].detach().cpu())

    all_smiles_embeddings = torch.cat(all_smiles_embeddings, dim=0)
    all_spectra_embeddings = torch.cat(all_spectra_embeddings, dim=0)

100%|██████████| 49/49 [01:25<00:00,  1.74s/it]


In [16]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    embedding_query = F.normalize(embedding_query, p=2, dim=1)
    embedding_key = F.normalize(embedding_key, p=2, dim=1)

    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix

def compute_recall(similarity_matrix, k, verbose=False):
    num_queries = similarity_matrix.size(0)
    _, topk_indices = similarity_matrix.topk(k, dim=1, largest=True, sorted=True)
    
    correct_list = []
    for i in range(num_queries):
        if i in topk_indices[i]:
            correct_list.append(1)
        else:
            correct_list.append(0)
    recall_at_k = sum(correct_list) / num_queries
    
    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    else:
        return recall_at_k, correct_list

similarity_matrix = calculate_similarity_matrix(all_spectra_embeddings, all_smiles_embeddings)
compute_recall(similarity_matrix, k=1, verbose=True)
compute_recall(similarity_matrix, k=3, verbose=True)
compute_recall(similarity_matrix, k=5, verbose=True)
compute_recall(similarity_matrix, k=10, verbose=True)
compute_recall(similarity_matrix, k=100, verbose=True)

recall@1:0.84874
recall@3:0.97378
recall@5:0.98849
recall@10:0.99712
recall@100:1.00000


# 2 de novo generation

In [6]:
length = [len(item) for item in test_df['product_kekule_smiles']]
max_len = max(length)+2
print(f'max_len:{max_len}')

max_len:56


In [57]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        return self.tgt_spectra[idx]

class TestCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        tgt_spectra = batch
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(device)
        return {'spectra':spectra}

all_pred_smiles = []
test_dataset = TestDataset(test_df['product_raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids = model.infer_lm(batch, max_len=max_len)['pred_ids']
    pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
    pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
    all_pred_smiles.extend(pred_smiles)

100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


In [58]:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from rdkit import Chem
from tqdm import trange

def check_mols(pred_smiles, tgt_smiles):
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if pred_mol is not None and tgt_mol is not None:
        if Chem.MolToInchiKey(pred_mol) == Chem.MolToInchiKey(tgt_mol):
            return 1
    return 0


res_smiles = []
for item in all_pred_smiles:
    tmp_mol = Chem.MolFromSmiles(item)
    if tmp_mol is not None:
        tmp_smiles = Chem.MolToSmiles(tmp_mol, isomericSmiles=False, kekuleSmiles=True, canonical=True)
    else:
        tmp_smiles = '*'
    res_smiles.append(tmp_smiles)


In [None]:
import pandas as pd
df = pd.DataFrame({'pred':res_smiles, 'tgt':test_df['product_kekule_smiles'].to_list(), 'correct':[check_mols(res_smiles[i], test_df['product_kekule_smiles'].to_list()[i]) for i in trange(len(test_df))]})
print(f'{df.correct.mean() * 100} %')
print(df.correct.mean())

100%|██████████| 3127/3127 [00:03<00:00, 957.42it/s]

16.277582347297727 %
0.16277582347297728





In [60]:
# 100% train and 100% test -> 0.23529411764705882
# 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100 %
greedy = [0.16277582347297728,
          0.17742966751918157,
          0.17647058823529413,
          0.17998721227621484, 
          0.19629156010230178, 
          0.24168797953964194, 
          0.24008951406649617, 
          0.23497442455242967, 
          0.22730179028132994, 
          0.1969309462915601, 
          0.07768542199488492]

# 3 beam search

In [19]:
import numpy as np
from tqdm import tqdm

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        return self.tgt_spectra[idx]

class TestCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        tgt_spectra = batch
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(device)
        return {'spectra':spectra}

beam_size = 10

all_pred_smiles = []
# test_dataset = TestDataset(test_df['mix_raman'].to_list())
test_dataset = TestDataset(test_df['product_raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids_list = model.beam_infer_lm(batch, max_len=64, beam_size=beam_size, temperature=3.5)['pred_ids']
    for pred_smiles_ids in pred_smiles_ids_list:
        pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
        pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
        all_pred_smiles.append(pred_smiles)

100%|██████████| 25/25 [09:11<00:00, 22.08s/it]


## 3.1 rank by beam score

In [None]:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    pred_mol_list = []
    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is not None:
            try:
                inchi_key = Chem.MolToInchiKey(mol)
                pred_mol_list.append(inchi_key)
            except Exception as e:
                print(f"Error processing SMILES {item}: {e}")
                pred_mol_list.append('')
        else:
            pred_mol_list.append('')
    tgt_mol = Chem.MolToInchiKey(Chem.MolFromSmiles(tgt_smiles))
    if tgt_mol in pred_mol_list:
        return 1
    return 0

In [None]:
import pandas as pd

df = pd.DataFrame({'tgt_smiles':test_df['product_kekule_smiles'].to_list(), 
                   'pred_smiles':[list(dict.fromkeys(item)) for item in all_pred_smiles],
                   'rxntype':test_df['rxntype'].to_list(),
                   'yield':test_df['Yield'].to_list()})

df['top_1'] = df.apply(lambda row: check_beam_mols(row['pred_smiles'][:1], row['tgt_smiles']), axis=1)
df['top_3'] = df.apply(lambda row: check_beam_mols(row['pred_smiles'][:3], row['tgt_smiles']), axis=1)
df['top_5'] = df.apply(lambda row: check_beam_mols(row['pred_smiles'][:5], row['tgt_smiles']), axis=1)
df['top_10'] = df.apply(lambda row: check_beam_mols(row['pred_smiles'][:10], row['tgt_smiles']), axis=1)

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.25584
top-3:		0.34954
top-5:		0.36137
top-10:		0.36329


## 3.2 rerank by retrieval module

In [39]:
candidate_smiles_list = [list(set(item)) for item in all_pred_smiles]
candidate_spectra_list = [[test_df['mix_raman'].to_list()[i]] * len(item) for i, item in enumerate(candidate_smiles_list)]
tgt_smiles_list = [[test_df['product_kekule_smiles'].to_list()[i]] * len(item) for i, item in enumerate(candidate_smiles_list)]

candidate_smiles_list = [subitem for item in candidate_smiles_list for subitem in item]
candidate_spectra_list = [subitem for item in candidate_spectra_list for subitem in item]
tgt_smiles_list = [subitem for item in tgt_smiles_list for subitem in item]

In [40]:
# calculate similarity between predicted molecules and target spectra
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    if type(embedding_query) != torch.Tensor:
        embedding_query = torch.tensor(embedding_query)
    if type(embedding_key) != torch.Tensor:
        embedding_key = torch.tensor(embedding_key)
    
    embedding_query = F.normalize(embedding_query, p=2, dim=1)
    embedding_key = F.normalize(embedding_key, p=2, dim=1)

    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix


class TestDataset(torch.utils.data.Dataset):
    def __init__(self, tgt_spectra, pred_smiles):
        self.tgt_spectra = tgt_spectra
        self.pred_smiles = pred_smiles

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        return self.tgt_spectra[idx], self.pred_smiles[idx]

class TestCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        tgt_spectra, pred_smiles = zip(*batch)
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(device)
        input_ids = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        input_ids = {'input_ids':input_ids['input_ids'].to(device), 'attention_mask':input_ids['attention_mask'].to(device)}
        return {'smiles': input_ids,  'spectra':spectra}

    
valid_sim_list = []

test_dataset = TestDataset(candidate_spectra_list, candidate_smiles_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        sim = calculate_similarity_matrix(spectral_embedding, molecular_embedding)
    valid_sim_list += torch.diag(sim).tolist()

100%|██████████| 193/193 [00:11<00:00, 16.93it/s]


In [None]:
df = pd.DataFrame({'target_smiles':tgt_smiles_list, 'pred_smiles':candidate_smiles_list, 'similarity':valid_sim_list})
# sort by 'target_smiles' and 'similarity' in descending order
df_sorted = df.sort_values(by=['target_smiles', 'similarity'], ascending=[True, False])

# group by 'target_smiles' and aggregate 'pred_smiles' and 'similarity'
grouped = df_sorted.groupby('target_smiles').agg({
    'pred_smiles': lambda x: ','.join(x),
    'similarity': lambda x: ','.join(map(str, x))
}).reset_index()

# calculate recall@k
for top_k in [1, 3, 5, 10]: 
    grouped[f'top_{top_k}_recall'] = grouped.apply(lambda row: row['target_smiles'] in row['pred_smiles'].split(',')[:top_k], axis=1)

grouped['rank'] = grouped.apply(lambda row: (row['pred_smiles'].split(',').index(row['target_smiles']))+1 if row['target_smiles'] in row['pred_smiles'].split(',') else 0, axis=1)

In [42]:
print(f"""
recall@1:\t{grouped.top_1_recall.mean():.5f} 
recall@3:\t{grouped.top_3_recall.mean():.5f} 
recall@5:\t{grouped.top_5_recall.mean():.5f} 
recall@10:\t{grouped.top_10_recall.mean():.5f}
    """)


recall@1:	0.23148 
recall@3:	0.31420 
recall@5:	0.32126 
recall@10:	0.32190
    
