# 1 Retrieval (given the functional groups, retrieve their substitution sites)

In [1]:
import lmdb
import pickle
import pandas as pd
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')

# Read every (key, value) record of the PAHs test split, then close the LMDB
# environment so its file handle is not leaked across re-runs of this cell.
# NOTE(review): values are unpickled — only load LMDB files from a trusted source.
db = lmdb.open('../datasets/vibench/pahs/pahs_test.lmdb', subdir=False, lock=False, map_size=int(1e11))
try:
    with db.begin() as txn:
        test_data = list(txn.cursor())
finally:
    db.close()

# One row per record; the columns are whatever keys the pickled values carry.
test_df = pd.DataFrame([pickle.loads(item[1]) for item in test_data])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys 
sys.path.append('..')

import torch

from models import build_model
from models import PretrainModel_Phase
from utils.base import seed_everything

# Seed every RNG before the model is built so initialisation is reproducible.
seed_everything(624)
device = 'cpu'

CKPT_PATH = '../checkpoints/pahs/raman-kekule_smiles/vib2mol_phase.pth'

model = build_model('vib2mol_phase')
model = model.to(device)
raw_state = torch.load(CKPT_PATH, map_location=device, weights_only=True)
# Strip the 'module.' key prefix (presumably added by a DataParallel/DDP wrapper
# at save time) so the names match the bare model.
ckpt = dict((name.replace('module.', ''), tensor) for name, tensor in raw_state.items())
model.load_state_dict(ckpt)

<All keys matched successfully>

In [3]:
# Regex templates used to locate the substitution pattern of a target SMILES,
# keyed by the ring count inferred in generate_candidate_list (1, 2 or 3).
# Each '(.*?)' captures one substituent; the value is the substitution-site label.
# Raw strings keep '\(' and '\Z' as regex escapes without triggering Python's
# invalid-escape-sequence warning; the string values are exactly as before.
retrieval_pattern_dict = {
    1: {
        r'(.*?)C1=CC=CC=C1': '1',
        r'(.*?)C1=C\((.*?)\)C=CC=C1': '1,2',
        r'(.*?)C1=CC=CC\((.*?)\)=C1': '1,3',
        r'(.*?)C1=CC\((.*?)\)=CC=C1': '1,3',
        r'(.*?)C1=CC=C\((.*?)\)C=C1': '1,4',
    },
    2: {
        r'(.*?)C1=CC2=CC=CC2=CC=CC=C1': '1',
        r'(.*?)C1=C2C=CC=CC2=CC=CC=C1': '1',
        r'(.*?)C1=CC=C2C=CC=CC2=C1(.*?)\Z': '1,2',
        r'(.*?)C1=C\((.*?)\)C=CC2=CC=CC=C12': '1,2',
        r'(.*?)C1=CC=CC2=C\((.*?)\)C=CC=C12': '1,5',
        r'(.*?)C1=CC=CC2=CC=CC\((.*?)\)=C12': '1,8',
        r'(.*?)C1=CC=C2C=C\((.*?)\)C=CC2=C1': '2,6',
        r'(.*?)C1=CC=C2C=CC\((.*?)\)=CC2=C1': '2,7',
    },
    3: {
        r'(.*?)C1=CC=C2C=C3C=CC=CC3=CC2=C1(.*?)\Z': '1,2',
        r'(.*?)C1=C\((.*?)\)C=CC2=CC3=CC=CC=C3C=C12': '1,2',
        r'(.*?)C1=C\((.*?)\)C=C2C=C3C=CC=CC3=CC2=C1': '2,3',
        r'(.*?)C1=CC2=CC3=CC=CC=C3C=C2C=C1(.*?)\Z': '2,3',
        r'(.*?)C1=CC=C2C=C3C=C\((.*?)\)C=CC3=CC2=C1': '2,6',
    },
}

In [4]:
# Regex templates whose backbones are re-filled with the target's substituents to
# enumerate alternative substitution sites, keyed by ring count (1, 2 or 3).
# Raw strings avoid Python's invalid-escape-sequence warning for '\(' and '\Z';
# the string values are exactly as before.
generate_pattern_dict = {
    1: {
        r'(.*?)C1=C\((.*?)\)C=CC=C1': '1,2',
        r'(.*?)C1=CC=CC\((.*?)\)=C1': '1,3',
        r'(.*?)C1=CC=C\((.*?)\)C=C1': '1,4',
    },
    2: {
        r'(.*?)C1=CC=C2C=CC=CC2=C1(.*?)\Z': '1,2',
        r'(.*?)C1=CC=CC2=C\((.*?)\)C=CC=C12': '1,5',
        r'(.*?)C1=CC=CC2=CC=CC\((.*?)\)=C12': '1,8',
        r'(.*?)C1=CC=C2C=C\((.*?)\)C=CC2=C1': '2,6',
        r'(.*?)C1=CC=C2C=CC\((.*?)\)=CC2=C1': '2,7',
    },
    3: {
        r'(.*?)C1=CC=C2C=C3C=CC=CC3=CC2=C1(.*?)\Z': '1,2',
        r'(.*?)C1=CC2=CC3=CC=CC=C3C=C2C=C1(.*?)\Z': '2,3',
        r'(.*?)C1=CC=C2C=C3C=C\((.*?)\)C=CC3=CC2=C1': '2,6',
    },
}

In [5]:
import re 
from tqdm import trange, tqdm
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def generate_candidate_list(smiles, retrieval_patterns=None, generate_patterns=None):
    """Enumerate positional isomers of a substituted aromatic SMILES.

    The target is matched against ``retrieval_patterns`` (default
    ``retrieval_pattern_dict``) to find its site label and substituent strings.
    Those substituents are then re-inserted into every template of
    ``generate_patterns`` (default ``generate_pattern_dict``) whose site label
    differs from the target's.

    Args:
        smiles: Kekule SMILES of the target molecule.
        retrieval_patterns: optional ``{ring_count: {regex: site}}`` mapping.
        generate_patterns: optional ``{ring_count: {regex: site}}`` mapping.

    Returns:
        ``(candidate_smiles_list, site_list)``. Index 0 is the target itself and
        its site label, followed by the generated isomers. A mono-substituted
        target (site ``'1'``) yields only itself.

    Raises:
        ValueError: if no retrieval template matches ``smiles``.
    """
    if retrieval_patterns is None:
        retrieval_patterns = retrieval_pattern_dict
    if generate_patterns is None:
        generate_patterns = generate_pattern_dict

    # Heuristic ring count from ring-closure labels: 'C3' => 3, 'C2' => 2, else 1.
    if 'C3' in smiles:
        ring_count = 3
    elif 'C2' in smiles:
        ring_count = 2
    else:
        ring_count = 1

    # The first template (in dict order) that matches decides the target's site.
    for pattern, site in retrieval_patterns[ring_count].items():
        match = re.search(pattern, smiles)
        if match is not None:
            tgt_v = site
            groups = match.groups()
            break
    else:
        raise ValueError("Re pattern Error")

    candidate_smiles_list = [smiles]
    site_list = [tgt_v]
    # A mono-substituted target has no other sites to permute into.
    if tgt_v == '1':
        return candidate_smiles_list, site_list

    # match.groups() is always a tuple (re.findall returned a bare str for
    # single-group templates), so both substituents are extracted correctly.
    sub1 = groups[0]
    sub2 = groups[1] if len(groups) > 1 else ''

    for pattern, site in generate_patterns[ring_count].items():
        if site == '1' or site == tgt_v:
            continue
        # Turn the regex back into a SMILES template: each '(.*?)' group becomes
        # PLACEHOLDER and the escaped brackets become literal branch brackets.
        backbone = re.sub(r'\(\.\*\?\)', 'PLACEHOLDER', pattern)
        backbone = backbone.replace(r'\(', '(').replace(r'\)', ')')
        # The second substituent is either a '(...)' branch or a trailing group
        # anchored by '\Z'; the remaining PLACEHOLDER is the first substituent.
        candidate = backbone.replace('(PLACEHOLDER)', f"({sub2})")
        candidate = candidate.replace(r'PLACEHOLDER\Z', sub2)
        candidate = candidate.replace('PLACEHOLDER', sub1)
        candidate_smiles_list.append(candidate)
        site_list.append(site)

    return candidate_smiles_list, site_list


batch_id = []
target_spectra_list = []
candidate_smiles_list = []
candidate_site_list = []

# For every test molecule, enumerate its positional isomers. Molecules that no
# template matches are skipped. batch_id records the test_df row of each candidate.
for row in trange(len(test_df)):
    smiles_i = test_df['kekule_smiles'].iloc[row]
    spectrum_i = test_df['raman'].iloc[row]
    try:
        isomers, sites = generate_candidate_list(smiles_i)
    except Exception:
        continue
    n_candidates = len(isomers)
    candidate_smiles_list.append(isomers)
    candidate_site_list.append(sites)
    target_spectra_list.append([spectrum_i] * n_candidates)
    batch_id.extend([row] * n_candidates)

100%|██████████| 860/860 [00:00<00:00, 51245.24it/s]


In [6]:
# NOTE(review): this cell overwrites its own inputs, so it must run exactly once
# after the candidate-building cell. Re-running it would split every SMILES into characters.
tgt_smiles_list = []
flat_smiles = []
for group in candidate_smiles_list:
    # Index 0 of each group is the ground-truth target (see generate_candidate_list).
    tgt_smiles_list.append(group[0])
    flat_smiles.extend(group)
flat_spectra = []
for group in target_spectra_list:
    flat_spectra.extend(group)
candidate_smiles_list = flat_smiles
target_spectra_list = flat_spectra

In [7]:
# calculate similarity between predicted molecules and target spectra
import numpy as np
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Pairwise cosine similarity between two batches of row embeddings.

    Non-tensor inputs are converted with ``torch.tensor``. Both inputs are
    L2-normalised along dim 1, so ``S[i, j]`` is the cosine similarity of query
    row ``i`` and key row ``j``; for 2-D inputs the shape is ``(n_query, n_key)``.
    """
    def _unit_rows(x):
        if type(x) != torch.Tensor:
            x = torch.tensor(x)
        return F.normalize(x, p=2, dim=1)

    query_unit = _unit_rows(embedding_query)
    key_unit = _unit_rows(embedding_key)
    return query_unit @ key_unit.t()


class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned ``(spectrum, candidate SMILES)`` pairs."""

    def __init__(self, tgt_spectra, pred_smiles):
        # Parallel sequences: item i pairs tgt_spectra[i] with pred_smiles[i].
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        smiles = self.pred_smiles[idx]
        return spectrum, smiles

class TestCollator:
    """Collate ``(spectrum, SMILES)`` pairs into the model's input dict.

    Spectra are stacked into a float32 tensor with a singleton axis inserted at
    dim 1. SMILES are tokenised with padding and truncation to ``max_length``.
    All tensors are moved to the collator's device.
    """

    def __init__(self, tokenizer, device=None, max_length=256):
        """
        Args:
            tokenizer: HF-style callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            device: target device. ``None`` defers to the notebook-level
                ``device`` variable at collate time, as before.
            max_length: pad/truncate length of the tokenised SMILES.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        tgt_spectra, pred_smiles = zip(*batch)
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(target_device)
        encoded = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length',
                                 max_length=self.max_length, truncation=True)
        smiles = {'input_ids': encoded['input_ids'].to(target_device),
                  'attention_mask': encoded['attention_mask'].to(target_device)}
        return {'smiles': smiles, 'spectra': spectra}

    
valid_sim_list = []

test_dataset = TestDataset(target_spectra_list, candidate_smiles_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, collate_fn=test_collator)
test_bar = tqdm(test_loader)

# Score each (spectrum, candidate) pair: the diagonal of the in-batch similarity
# matrix holds the cosine similarity of every aligned pair.
model.eval()
with torch.no_grad():
    for batch in test_bar:
        mol_emb = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spec_emb = model.get_spectral_embeddings(batch)['proj_output']
        pair_sim = torch.diag(calculate_similarity_matrix(spec_emb, mol_emb))
        valid_sim_list.extend(pair_sim.tolist())

100%|██████████| 421/421 [01:31<00:00,  4.59it/s]


In [8]:
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Return 1 if any SMILES in ``pred_smiles_list`` is the target molecule, else 0.

    Molecules are compared by RDKit InChIKey. A prediction never counts as a hit
    if RDKit cannot parse it or cannot compute its InChIKey.

    Args:
        pred_smiles_list: candidate SMILES (e.g. a top-k slice of a ranking).
        tgt_smiles: ground-truth SMILES.

    Returns:
        ``1`` on a hit, otherwise ``0``. This includes the case where the target
        cannot be turned into a non-empty InChIKey; a message is printed then.
    """
    tgt_key = ''
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is not None:
        try:
            tgt_key = Chem.MolToInchiKey(tgt_mol)
        except Exception:
            tgt_key = ''
    if not tgt_key:
        # Without a usable target key there is nothing to compare against.
        # Comparing '' keys would turn failed conversions into false hits.
        print(f"Cannot compute InChIKey for target SMILES {tgt_smiles}")
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        # A failed InChI generation may return '' instead of raising; '' can never
        # equal the (non-empty) tgt_key here.
        if inchi_key == tgt_key:
            return 1
    return 0

In [14]:
import pandas as pd

def get_ring_count(smiles):
    """Number of rings RDKit reports for ``smiles`` after symmetrized-SSSR perception."""
    molecule = Chem.MolFromSmiles(smiles)
    # Run symmetrized SSSR ring perception before querying RingInfo.
    Chem.GetSymmSSSR(molecule)
    ring_info = molecule.GetRingInfo()
    return ring_info.NumRings()
    
res = {}
valid_sim_list = np.array(valid_sim_list)
batch_id = np.array(batch_id)
candidate_smiles_arr = np.array(candidate_smiles_list)  # built once, not per molecule

# np.unique yields the test_df row ids in ascending order. That is the order in
# which candidate_site_list / tgt_smiles_list were appended, so position i lines
# up with those per-molecule lists. Iterating a set only did so by accident of
# hashing.
for i, bid in enumerate(np.unique(batch_id)):
    tmp_site_list = np.array(candidate_site_list[i])
    # Mono-substituted targets have no alternative isomers to rank.
    if len(tmp_site_list) < 2:
        continue
    in_group = batch_id == bid
    tmp_sim_list = valid_sim_list[in_group]
    order = np.argsort(-tmp_sim_list)  # highest similarity first
    res[i] = {'sim': tmp_sim_list[order],
              'smiles': candidate_smiles_arr[in_group][order].tolist(),
              'target_smiles': tgt_smiles_list[i],
              'sites': tmp_site_list[order],
              }

df_cl = pd.DataFrame(res).T
df_cl['ring_count'] = df_cl['target_smiles'].apply(get_ring_count)
for top_k in (1, 3, 5, 10):
    df_cl[f'top_{top_k}'] = df_cl.apply(
        lambda row: check_beam_mols(row['smiles'][:top_k], row['target_smiles']), axis=1)

print(f'top-1:\t\t{df_cl.top_1.mean():.5f}\ntop-3:\t\t{df_cl.top_3.mean():.5f}\ntop-5:\t\t{df_cl.top_5.mean():.5f}\ntop-10:\t\t{df_cl.top_10.mean():.5f}')
# numeric_only=True averages only the integer top-k columns. The object columns
# (sim, smiles, sites, target_smiles) are skipped.
df_cl.groupby('ring_count').mean(numeric_only=True)

top-1:		0.99535
top-3:		1.00000
top-5:		1.00000
top-10:		1.00000


Unnamed: 0_level_0,top_1,top_3,top_5,top_10
ring_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.995745,1.0,1.0,1.0
2,0.997455,1.0,1.0,1.0
3,0.991379,1.0,1.0,1.0


# 2 MLM (given the substitution sites, predict the functional groups)

In [10]:
# Flat list of all regex templates (ring count ignored) used to split a target
# SMILES into aromatic backbone + substituents for masking. The first match in
# insertion order wins. Raw strings avoid Python's invalid-escape-sequence
# warning; the string values are exactly as before.
pattern_dict = {
    r'(.*?)C1=CC=CC=C1': '1',
    r'(.*?)C1=CC2=CC=CC2=CC=CC=C1': '1',
    r'(.*?)C1=C2C=CC=CC2=CC=CC=C1': '1',
    r'(.*?)C1=C\((.*?)\)C=CC=C1': '1,2',
    r'(.*?)C1=CC=CC\((.*?)\)=C1': '1,3',
    r'(.*?)C1=CC\((.*?)\)=CC=C1': '1,3',
    r'(.*?)C1=CC=C\((.*?)\)C=C1': '1,4',
    r'(.*?)C1=CC=C2C=CC=CC2=C1(.*?)\Z': '1,2',
    r'(.*?)C1=C\((.*?)\)C=CC2=CC=CC=C12': '1,2',
    r'(.*?)C1=CC=CC2=C\((.*?)\)C=CC=C12': '1,5',
    r'(.*?)C1=CC=CC2=CC=CC\((.*?)\)=C12': '1,8',
    r'(.*?)C1=CC=C2C=C\((.*?)\)C=CC2=C1': '2,6',
    r'(.*?)C1=CC=C2C=CC\((.*?)\)=CC2=C1': '2,7',
    r'(.*?)C1=CC=C2C=C3C=CC=CC3=CC2=C1(.*?)\Z': '1,2',
    r'(.*?)C1=C\((.*?)\)C=CC2=CC3=CC=CC=C3C=C12': '1,2',
    r'(.*?)C1=C\((.*?)\)C=C2C=C3C=CC=CC3=CC2=C1': '2,3',
    r'(.*?)C1=CC2=CC3=CC=CC=C3C=C2C=C1(.*?)\Z': '2,3',
    r'(.*?)C1=CC=C2C=C3C=C\((.*?)\)C=CC3=CC2=C1': '2,6',
}

In [None]:
import re 
from tqdm import trange, tqdm
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def generate_masked_list(smiles, mode='double', patterns=None, tok=None):
    """Build masked-SMILES queries for the MLM head from a substituted aromatic SMILES.

    The first template in ``patterns`` (default ``pattern_dict``) that matches
    ``smiles`` gives the aromatic backbone and the substituent strings. A masked
    substituent is replaced by ``len(tok.encode(sub)) - 2`` literal ``'*'``
    characters. This assumes ``encode`` adds exactly two special tokens —
    TODO confirm for MolTokenizer.

    Args:
        smiles: Kekule SMILES of the target.
        mode: ``'double'`` masks every substituent in a single query.
            ``'single'`` emits one query per substituent, masking only that one.
        patterns: optional ``{regex: site}`` mapping (default ``pattern_dict``).
        tok: optional object with ``encode`` (default the notebook ``tokenizer``).

    Returns:
        List of masked SMILES strings. A single-substituent template yields one
        query in either mode.

    Raises:
        ValueError: if ``mode`` is unknown or no template matches ``smiles``.
    """
    if mode not in ('double', 'single'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'double' or 'single'")
    if patterns is None:
        patterns = pattern_dict
    if tok is None:
        tok = tokenizer

    for pattern in patterns:
        match = re.search(pattern, smiles)
        if match is not None:
            break
    else:
        raise ValueError("Re pattern Error")

    # match.groups() is always a tuple (re.findall returned a bare str for
    # single-group templates, which made sub1/sub2 single characters).
    groups = match.groups()
    sub1 = groups[0]
    sub2 = groups[1] if len(groups) > 1 else None

    backbone = re.sub(r'\(\.\*\?\)', 'PLACEHOLDER', pattern)
    backbone = backbone.replace(r'\(', '(').replace(r'\)', ')')

    def _mask(sub):
        # One '*' per substituent token, excluding the two special tokens.
        return '*' * (len(tok.encode(sub)) - 2)

    def _fill(second, first):
        # The second substituent is either a '(...)' branch or a trailing group
        # anchored by '\Z'; whatever PLACEHOLDER remains is the first substituent.
        out = backbone.replace('(PLACEHOLDER)', f'({second})')
        out = out.replace(r'PLACEHOLDER\Z', second)
        return out.replace('PLACEHOLDER', first)

    mask1 = _mask(sub1)
    if sub2 is None:
        # Only one substituent: a single query masking it serves both modes.
        return [_fill('', mask1)]
    mask2 = _mask(sub2)
    if mode == 'double':
        return [_fill(mask2, mask1)]
    return [_fill(mask2, sub1), _fill(sub2, mask1)]


target_smiles_list = []
target_spectra_list = []
candidate_smiles_list = []
batch_id = []

# Build the single-substituent masked queries for every test molecule. Molecules
# that no template matches are skipped. batch_id maps each query to its test_df row.
for row in trange(len(test_df)):
    smiles_i = test_df['kekule_smiles'].iloc[row]
    spectrum_i = test_df['raman'].iloc[row]
    try:
        queries = generate_masked_list(smiles_i, mode='single')
    except Exception:
        continue
    n_queries = len(queries)
    target_smiles_list.append(smiles_i)
    target_spectra_list.append([spectrum_i] * n_queries)
    candidate_smiles_list.append(queries)
    batch_id.extend([row] * n_queries)

(len(target_smiles_list), len(target_spectra_list), len(candidate_smiles_list))

100%|██████████| 860/860 [00:00<00:00, 6502.22it/s]


(860, 860, 860)

In [17]:
# Turn the '*' placeholders into the tokenizer's '<mask>' token and flatten both
# per-molecule lists so index k is one (masked query, spectrum) pair aligned with
# batch_id. NOTE(review): target_spectra_list is overwritten, so run this cell
# exactly once after the masked-query cell.
masked_smiles_list = []
for queries in candidate_smiles_list:
    masked_smiles_list.extend(q.replace('*', '<mask>') for q in queries)
flat_spectra = []
for spectra in target_spectra_list:
    flat_spectra.extend(spectra)
target_spectra_list = flat_spectra

In [18]:
from copy import deepcopy
import numpy as np 

# Accumulators filled by the MLM inference loop below (one entry per masked query).
# The former `correct, total` counters were never used and have been dropped.
all_pred_smiles = []
all_pred_scores = []

class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned ``(masked SMILES, spectrum)`` pairs."""

    def __init__(self, smiles, spectra):
        # Parallel sequences: item i is (smiles[i], spectra[i]).
        self.smiles, self.spectra = smiles, spectra

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        query = self.smiles[idx]
        spectrum = self.spectra[idx]
        return query, spectrum

class TestCollator:
    """Collate ``(masked SMILES, spectrum)`` pairs into the model's input dict.

    Spectra are stacked into a float32 tensor with a singleton axis inserted at
    dim 1. SMILES are tokenised with padding and truncation to ``max_length``.
    All tensors are moved to the collator's device.
    """

    def __init__(self, tokenizer, device=None, max_length=256):
        """
        Args:
            tokenizer: HF-style callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            device: target device. ``None`` defers to the notebook-level
                ``device`` variable at collate time, as before.
            max_length: pad/truncate length of the tokenised SMILES.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        smiles, spectra = zip(*batch)
        spectra = torch.as_tensor(np.array(spectra), dtype=torch.float32).unsqueeze(1).to(target_device)
        encoded = self.tokenizer(list(smiles), return_tensors='pt', padding='max_length',
                                 max_length=self.max_length, truncation=True)
        smiles_inputs = {'input_ids': encoded['input_ids'].to(target_device),
                         'attention_mask': encoded['attention_mask'].to(target_device)}
        return {'smiles': smiles_inputs, 'spectra': spectra}

test_dataset = TestDataset(masked_smiles_list, target_spectra_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

# Id of the literal '<mask>' token inserted into the queries, looked up from the
# tokenizer rather than hard-coded. NOTE(review): the previous code used the
# constant 4 — confirm this lookup returns 4 for MolTokenizer.
mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')

model.eval()
for batch in test_bar:
    input_ids = batch['smiles']['input_ids']
    with torch.no_grad():
        pred_tokens_logits = model.infer_mlm(batch)
    # Greedy fill: the most probable token per position and its softmax probability.
    pred_probs = torch.nn.functional.softmax(pred_tokens_logits, dim=-1)
    pred_score, pred_tokens = torch.max(pred_probs, dim=-1)
    # Keep everything on the inputs' device so boolean-mask indexing also works on GPU.
    pred_score = pred_score.to(input_ids.device)
    pred_tokens = pred_tokens.to(input_ids.device)

    mask = input_ids == mask_token_id
    # Query confidence = mean probability over its masked positions
    # (NaN if a query happens to contain no mask token).
    pred_score = [torch.mean(pred_score[i][mask[i]]).item() for i in range(len(pred_score))]

    output = input_ids.clone()
    output[mask] = pred_tokens[mask]

    preds = tokenizer.batch_decode(output, skip_special_tokens=True)
    all_pred_smiles.extend(preds)
    all_pred_scores.extend(pred_score)

100%|██████████| 27/27 [00:05<00:00,  5.04it/s]


In [19]:
# filter out invalid smiles
from tqdm import trange

valid_pred_smiles = []
valid_batch_id = []
valid_spectra_list = []
valid_scores = []

# Keep only predictions RDKit can parse, together with their row id, spectrum and
# MLM confidence, so the four lists stay index-aligned.
# NOTE(review): batch_id must still be the flat list from the masked-query cell;
# later evaluation cells rebind it to a filtered numpy array.
for idx in trange(len(all_pred_smiles)):
    pred = all_pred_smiles[idx]
    if Chem.MolFromSmiles(pred) is None:
        continue
    valid_pred_smiles.append(pred)
    valid_batch_id.append(batch_id[idx])
    valid_spectra_list.append(target_spectra_list[idx])
    valid_scores.append(all_pred_scores[idx])

print(len(valid_pred_smiles), len(valid_batch_id), len(valid_spectra_list), len(valid_scores))

100%|██████████| 1720/1720 [00:00<00:00, 8271.64it/s]

1684 1684 1684 1684





## 2.1 Rank by MLM score

In [None]:
import pandas as pd
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Return 1 if any SMILES in ``pred_smiles_list`` is the target molecule, else 0.

    Molecules are compared by RDKit InChIKey. A prediction never counts as a hit
    if RDKit cannot parse it or cannot compute its InChIKey.

    Args:
        pred_smiles_list: candidate SMILES (e.g. a top-k slice of a ranking).
        tgt_smiles: ground-truth SMILES.

    Returns:
        ``1`` on a hit, otherwise ``0``. This includes the case where the target
        cannot be turned into a non-empty InChIKey; a message is printed then.
    """
    tgt_key = ''
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is not None:
        try:
            tgt_key = Chem.MolToInchiKey(tgt_mol)
        except Exception:
            tgt_key = ''
    if not tgt_key:
        # Without a usable target key there is nothing to compare against.
        # Comparing '' keys would turn failed conversions into false hits.
        print(f"Cannot compute InChIKey for target SMILES {tgt_smiles}")
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        # A failed InChI generation may return '' instead of raising; '' can never
        # equal the (non-empty) tgt_key here.
        if inchi_key == tgt_key:
            return 1
    return 0

res = {}
valid_scores = np.array(valid_scores)
# A separate name keeps the global `batch_id` (the unfiltered per-query list)
# intact for the filtering cell above.
valid_bid = np.array(valid_batch_id)
valid_smiles_arr = np.array(valid_pred_smiles)
target_by_row = test_df['kekule_smiles']

# NOTE(review): rows whose predictions were all invalid never appear here. They
# are therefore missing from the top-k denominators below — confirm that is intended.
for i, bid in enumerate(np.unique(valid_bid)):
    in_group = valid_bid == bid
    tmp_pred_score = valid_scores[in_group]
    order = np.argsort(-tmp_pred_score)  # most confident query first
    # Deduplicate SMILES, keeping each one's first (best) score, so the two
    # columns stay aligned.
    best_score = {}
    for smi, score in zip(valid_smiles_arr[in_group][order].tolist(), tmp_pred_score[order].tolist()):
        best_score.setdefault(smi, score)
    res[i] = {
              'mlm_score': list(best_score.values()),
              'smiles': list(best_score.keys()),
              # bid is a test_df row index, while target_smiles_list only holds the
              # rows generate_masked_list accepted, so index test_df directly.
              'target_smiles': target_by_row.iloc[bid],
              }

df_mlm = pd.DataFrame(res).T
df_mlm['ring_count'] = df_mlm['target_smiles'].apply(lambda x: Chem.MolFromSmiles(x).GetRingInfo().NumRings())
for top_k in (1, 3, 5, 10):
    df_mlm[f'top_{top_k}'] = df_mlm.apply(
        lambda row: check_beam_mols(row['smiles'][:top_k], row['target_smiles']), axis=1)

print(f'top-1:\t\t{df_mlm.top_1.mean():.5f}\ntop-3:\t\t{df_mlm.top_3.mean():.5f}\ntop-5:\t\t{df_mlm.top_5.mean():.5f}\ntop-10:\t\t{df_mlm.top_10.mean():.5f}')
# numeric_only=True averages only the integer top-k columns; list/str columns are skipped.
df_mlm.groupby('ring_count').mean(numeric_only=True)

top-1:		0.98952
top-3:		0.99767
top-5:		0.99767
top-10:		0.99767


Unnamed: 0_level_0,top_1,top_3,top_5,top_10
ring_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.995745,1.0,1.0,1.0
2,0.987245,0.994898,0.994898,0.994898
3,0.987069,1.0,1.0,1.0


## 2.2 Rank by contrastive (retrieval) model

In [35]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Pairwise cosine similarity between two batches of row embeddings.

    Non-tensor inputs are converted with ``torch.tensor``. Both inputs are
    L2-normalised along dim 1, so ``S[i, j]`` is the cosine similarity of query
    row ``i`` and key row ``j``; for 2-D inputs the shape is ``(n_query, n_key)``.
    """
    def _unit_rows(x):
        if type(x) != torch.Tensor:
            x = torch.tensor(x)
        return F.normalize(x, p=2, dim=1)

    query_unit = _unit_rows(embedding_query)
    key_unit = _unit_rows(embedding_key)
    return query_unit @ key_unit.t()


class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned ``(spectrum, candidate SMILES)`` pairs."""

    def __init__(self, tgt_spectra, pred_smiles):
        # Parallel sequences: item i pairs tgt_spectra[i] with pred_smiles[i].
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        smiles = self.pred_smiles[idx]
        return spectrum, smiles

class TestCollator:
    """Collate ``(spectrum, SMILES)`` pairs into the model's input dict.

    Spectra are stacked into a float32 tensor with a singleton axis inserted at
    dim 1. SMILES are tokenised with padding and truncation to ``max_length``.
    All tensors are moved to the collator's device.
    """

    def __init__(self, tokenizer, device=None, max_length=256):
        """
        Args:
            tokenizer: HF-style callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            device: target device. ``None`` defers to the notebook-level
                ``device`` variable at collate time, as before.
            max_length: pad/truncate length of the tokenised SMILES.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        tgt_spectra, pred_smiles = zip(*batch)
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(target_device)
        encoded = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length',
                                 max_length=self.max_length, truncation=True)
        smiles = {'input_ids': encoded['input_ids'].to(target_device),
                  'attention_mask': encoded['attention_mask'].to(target_device)}
        return {'smiles': smiles, 'spectra': spectra}


valid_sim_list = []

test_dataset = TestDataset(valid_spectra_list, valid_pred_smiles)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

# Cosine similarity of every valid MLM prediction with its own target spectrum,
# read off the diagonal of the in-batch similarity matrix.
model.eval()
with torch.no_grad():
    for batch in test_bar:
        mol_emb = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spec_emb = model.get_spectral_embeddings(batch)['proj_output']
        pair_sim = torch.diag(calculate_similarity_matrix(spec_emb, mol_emb))
        valid_sim_list.extend(pair_sim.tolist())

100%|██████████| 12/12 [00:01<00:00,  9.87it/s]


In [36]:
import pandas as pd
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Return 1 if any SMILES in ``pred_smiles_list`` is the target molecule, else 0.

    Molecules are compared by RDKit InChIKey. A prediction never counts as a hit
    if RDKit cannot parse it or cannot compute its InChIKey.

    Args:
        pred_smiles_list: candidate SMILES (e.g. a top-k slice of a ranking).
        tgt_smiles: ground-truth SMILES.

    Returns:
        ``1`` on a hit, otherwise ``0``. This includes the case where the target
        cannot be turned into a non-empty InChIKey; a message is printed then.
    """
    tgt_key = ''
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is not None:
        try:
            tgt_key = Chem.MolToInchiKey(tgt_mol)
        except Exception:
            tgt_key = ''
    if not tgt_key:
        # Without a usable target key there is nothing to compare against.
        # Comparing '' keys would turn failed conversions into false hits.
        print(f"Cannot compute InChIKey for target SMILES {tgt_smiles}")
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        # A failed InChI generation may return '' instead of raising; '' can never
        # equal the (non-empty) tgt_key here.
        if inchi_key == tgt_key:
            return 1
    return 0

res = {}
valid_sim_list = np.array(valid_sim_list)
valid_scores = np.array(valid_scores)
# A separate name keeps the global `batch_id` (the unfiltered per-query list) intact.
valid_bid = np.array(valid_batch_id)
valid_smiles_arr = np.array(valid_pred_smiles)
target_by_row = test_df['kekule_smiles']

# NOTE(review): rows whose predictions were all invalid never appear here. They
# are therefore missing from the top-k denominators below — confirm that is intended.
for i, bid in enumerate(np.unique(valid_bid)):
    in_group = valid_bid == bid
    tmp_sim_list = valid_sim_list[in_group]
    order = np.argsort(-tmp_sim_list)  # highest spectrum-molecule similarity first
    ranked = zip(valid_smiles_arr[in_group][order].tolist(),
                 tmp_sim_list[order].tolist(),
                 valid_scores[in_group][order].tolist())
    # Deduplicate by SMILES, keeping the first (highest-similarity) occurrence, so
    # 'sim', 'mlm_score' and 'smiles' stay aligned element by element.
    best = {}
    for smi, sim, score in ranked:
        best.setdefault(smi, (sim, score))
    res[i] = {'sim': [v[0] for v in best.values()],
              'mlm_score': [v[1] for v in best.values()],
              'smiles': list(best.keys()),
              # bid is a test_df row index, while target_smiles_list only holds the
              # rows generate_masked_list accepted, so index test_df directly.
              'target_smiles': target_by_row.iloc[bid],
              }

df_mlm = pd.DataFrame(res).T
df_mlm['ring_count'] = df_mlm['target_smiles'].apply(lambda x: Chem.MolFromSmiles(x).GetRingInfo().NumRings())
for top_k in (1, 3, 5, 10):
    df_mlm[f'top_{top_k}'] = df_mlm.apply(
        lambda row: check_beam_mols(row['smiles'][:top_k], row['target_smiles']), axis=1)

print(f'top-1:\t\t{df_mlm.top_1.mean():.5f}\ntop-3:\t\t{df_mlm.top_3.mean():.5f}\ntop-5:\t\t{df_mlm.top_5.mean():.5f}\ntop-10:\t\t{df_mlm.top_10.mean():.5f}')

top-1:		0.92744
top-3:		0.92744
top-5:		0.92744
top-10:		0.92744


In [37]:
# numeric_only=True averages only the integer top-k columns. The list/str columns
# are skipped. (Passing 'top_1' positionally only worked because it is truthy.)
df_mlm.groupby('ring_count').mean(numeric_only=True)

Unnamed: 0_level_0,top_1,top_3,top_5,top_10
ring_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.953052,0.953052,0.953052,0.953052
2,0.923754,0.923754,0.923754,0.923754
3,0.906863,0.906863,0.906863,0.906863


# 3 De novo generation of PAHs

In [None]:
import sys 
# Guard so re-running the cell does not keep appending '..' to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

import torch

from models import build_model
from models import PretrainModel_PAHs, PretrainModel_Phase
from utils.base import seed_everything

seed_everything(624)
# Use the first GPU when one is available. Fall back to CPU so the notebook
# still runs end to end on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = build_model('vib2mol_phase').to(device)
ckpt = torch.load('../checkpoints/pahs/raman-kekule_smiles/vib2mol_phase.pth', 
                  map_location=device, weights_only=True)

# Strip the 'module.' key prefix (presumably from a DataParallel/DDP wrapper at
# save time) so the keys match the bare model.
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

<All keys matched successfully>

## 3.1 Greedy decoding

In [None]:
import numpy as np
from tqdm import tqdm

class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that indexes a sequence of target spectra."""

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stack a batch of raw spectra into ``{'spectra': float32 tensor}``.

    A singleton axis is inserted at dim 1 and the tensor is moved to the
    collator's device.
    """

    def __init__(self, device=None):
        # device=None defers to the notebook-level ``device`` at collate time.
        self.device = device

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        spectra = torch.as_tensor(np.array(batch), dtype=torch.float32).unsqueeze(1).to(target_device)
        return {'spectra': spectra}

all_pred_smiles = []
test_dataset = TestDataset(test_df.raman.tolist())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

def _strip_special(decoded):
    # Keep the text before the first '</s>' (presumably end-of-sequence) and drop '<s>'.
    return decoded.split('</s>')[0].replace('<s>', '')

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids = model.infer_lm(batch, max_len=64)['pred_ids']
    all_pred_smiles.extend(_strip_special(text) for text in tokenizer.batch_decode(pred_smiles_ids))

100%|██████████| 7/7 [00:09<00:00,  1.31s/it]


In [24]:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def _canonical_kekule(smiles):
    # Canonical, non-isomeric Kekule SMILES. '*' marks predictions RDKit cannot parse.
    # NOTE(review): '*' is itself a valid SMILES (dummy atom); a None sentinel
    # would be less ambiguous.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return '*'
    return Chem.MolToSmiles(mol, isomericSmiles=False, kekuleSmiles=True, canonical=True)

res_smiles = [_canonical_kekule(item) for item in all_pred_smiles]

In [25]:
from rdkit import Chem
from tqdm import trange

def check_mols(pred_smiles, tgt_smiles):
    """Return 1 if both SMILES parse and share a non-empty InChIKey, else 0.

    An empty InChIKey (which RDKit may return when InChI generation fails) never
    counts as a match. Otherwise two failed conversions would be scored as a hit.
    """
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if pred_mol is None or tgt_mol is None:
        return 0
    pred_key = Chem.MolToInchiKey(pred_mol)
    if pred_key and pred_key == Chem.MolToInchiKey(tgt_mol):
        return 1
    return 0

import pandas as pd
# Convert the target column once; the old comprehension rebuilt the full list on
# every iteration (accidentally O(n^2)).
tgt_smiles_all = test_df.kekule_smiles.tolist()
df_lm = pd.DataFrame({'tgt_smiles': tgt_smiles_all,
                      'pred_smiles': res_smiles,
                      'correct': [check_mols(pred, tgt) for pred, tgt in zip(res_smiles, tgt_smiles_all)]})
df_lm['ring_count'] = df_lm['tgt_smiles'].apply(lambda x: Chem.MolFromSmiles(x).GetRingInfo().NumRings())
# numeric_only=True averages the integer 'correct' column and skips the SMILES strings.
df_lm.groupby('ring_count').mean(numeric_only=True)

Unnamed: 0_level_0,correct
ring_count,Unnamed: 1_level_1
1,0.940426
2,0.829517
3,0.840517


## 3.2 Beam search

In [26]:
import numpy as np
from tqdm import tqdm 

class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that indexes a sequence of target spectra."""

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stack a batch of raw spectra into ``{'spectra': float32 tensor}``.

    A singleton axis is inserted at dim 1 and the tensor is moved to the
    collator's device.
    """

    def __init__(self, device=None):
        # device=None defers to the notebook-level ``device`` at collate time.
        self.device = device

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        spectra = torch.as_tensor(np.array(batch), dtype=torch.float32).unsqueeze(1).to(target_device)
        return {'spectra': spectra}

beam_size = 10

all_pred_smiles = []
test_dataset = TestDataset(test_df.raman.tolist())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        beams_per_spectrum = model.beam_infer_lm(batch, max_len=64, beam_size=beam_size, temperature=3.5)['pred_ids']
    # Append one entry per spectrum: the list of its decoded beam hypotheses
    # (presumably in beam-score order — the next section relies on that).
    for beam_ids in beams_per_spectrum:
        decoded = tokenizer.batch_decode(beam_ids)
        all_pred_smiles.append([text.split('</s>')[0].replace('<s>', '') for text in decoded])

100%|██████████| 7/7 [00:55<00:00,  7.90s/it]


### 3.2.1 Rank by beam score

In [27]:
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Return 1 if any SMILES in ``pred_smiles_list`` is the target molecule, else 0.

    Molecules are compared by RDKit InChIKey. A prediction never counts as a hit
    if RDKit cannot parse it or cannot compute its InChIKey.

    Args:
        pred_smiles_list: candidate SMILES (e.g. a top-k slice of a ranking).
        tgt_smiles: ground-truth SMILES.

    Returns:
        ``1`` on a hit, otherwise ``0``. This includes the case where the target
        cannot be turned into a non-empty InChIKey; a message is printed then.
    """
    tgt_key = ''
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is not None:
        try:
            tgt_key = Chem.MolToInchiKey(tgt_mol)
        except Exception:
            tgt_key = ''
    if not tgt_key:
        # Without a usable target key there is nothing to compare against.
        # Comparing '' keys would turn failed conversions into false hits.
        print(f"Cannot compute InChIKey for target SMILES {tgt_smiles}")
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        # A failed InChI generation may return '' instead of raising; '' can never
        # equal the (non-empty) tgt_key here.
        if inchi_key == tgt_key:
            return 1
    return 0

In [28]:
import pandas as pd
import torch

# One row per test spectrum. Duplicate beams are collapsed while keeping their order.
df = pd.DataFrame({'tgt_smiles': test_df.kekule_smiles.tolist(),
                   'pred_smiles': [list(dict.fromkeys(item)) for item in all_pred_smiles]})
df['ring_count'] = df['tgt_smiles'].apply(lambda x: Chem.MolFromSmiles(x).GetRingInfo().NumRings())
for top_k in (1, 3, 5, 10):
    df[f'top_{top_k}'] = df.apply(
        lambda row: check_beam_mols(row['pred_smiles'][:top_k], row['tgt_smiles']), axis=1)

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')
# numeric_only=True averages only the integer top-k columns; list/str columns are skipped.
df.groupby('ring_count').mean(numeric_only=True)

top-1:		0.86628
top-3:		0.91512
top-5:		0.91628
top-10:		0.91628


Unnamed: 0_level_0,top_1,top_3,top_5,top_10
ring_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.953191,0.957447,0.957447,0.957447
2,0.824427,0.898219,0.900763,0.900763
3,0.849138,0.900862,0.900862,0.900862


### 3.2.2 Rerank by contrastive (retrieval) model

In [29]:
# Deduplicate each spectrum's beams (order-preserving) and pair every unique
# candidate with its spectrum and ground-truth SMILES, producing three aligned
# flat lists. The test_df columns are converted once instead of once per row
# (the old comprehensions called .tolist() inside the loop).
raman_all = test_df.raman.tolist()
kekule_all = test_df.kekule_smiles.tolist()
candidate_smiles_list = []
candidate_spectra_list = []
tgt_smiles_list = []
for spectrum, tgt_smiles, beams in zip(raman_all, kekule_all, all_pred_smiles):
    unique_beams = list(dict.fromkeys(beams))
    candidate_smiles_list.extend(unique_beams)
    candidate_spectra_list.extend([spectrum] * len(unique_beams))
    tgt_smiles_list.extend([tgt_smiles] * len(unique_beams))

In [30]:
# calculate similarity between predicted molecules and target spectra
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Pairwise cosine similarity between two batches of row embeddings.

    Non-tensor inputs are converted with ``torch.tensor``. Both inputs are
    L2-normalised along dim 1, so ``S[i, j]`` is the cosine similarity of query
    row ``i`` and key row ``j``; for 2-D inputs the shape is ``(n_query, n_key)``.
    """
    def _unit_rows(x):
        if type(x) != torch.Tensor:
            x = torch.tensor(x)
        return F.normalize(x, p=2, dim=1)

    query_unit = _unit_rows(embedding_query)
    key_unit = _unit_rows(embedding_key)
    return query_unit @ key_unit.t()


class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned ``(spectrum, candidate SMILES)`` pairs."""

    def __init__(self, tgt_spectra, pred_smiles):
        # Parallel sequences: item i pairs tgt_spectra[i] with pred_smiles[i].
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        return len(self.tgt_spectra)

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        smiles = self.pred_smiles[idx]
        return spectrum, smiles

class TestCollator:
    """Collate ``(spectrum, SMILES)`` pairs into the model's input dict.

    Spectra are stacked into a float32 tensor with a singleton axis inserted at
    dim 1. SMILES are tokenised with padding and truncation to ``max_length``.
    All tensors are moved to the collator's device.
    """

    def __init__(self, tokenizer, device=None, max_length=256):
        """
        Args:
            tokenizer: HF-style callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            device: target device. ``None`` defers to the notebook-level
                ``device`` variable at collate time, as before.
            max_length: pad/truncate length of the tokenised SMILES.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length

    def __call__(self, batch):
        target_device = device if self.device is None else self.device
        tgt_spectra, pred_smiles = zip(*batch)
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32).unsqueeze(1).to(target_device)
        encoded = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length',
                                 max_length=self.max_length, truncation=True)
        smiles = {'input_ids': encoded['input_ids'].to(target_device),
                  'attention_mask': encoded['attention_mask'].to(target_device)}
        return {'smiles': smiles, 'spectra': spectra}

    
valid_sim_list = []

test_dataset = TestDataset(candidate_spectra_list, candidate_smiles_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

# Cosine similarity of every unique beam candidate with its own target spectrum,
# read off the diagonal of the in-batch similarity matrix.
model.eval()
with torch.no_grad():
    for batch in test_bar:
        mol_emb = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spec_emb = model.get_spectral_embeddings(batch)['proj_output']
        pair_sim = torch.diag(calculate_similarity_matrix(spec_emb, mol_emb))
        valid_sim_list.extend(pair_sim.tolist())

100%|██████████| 19/19 [00:02<00:00,  8.18it/s]


In [None]:
# Rank every unique candidate of a target by similarity (highest first) and
# collect each target's candidates as comma-joined strings.
# NOTE(review): grouping is keyed by target SMILES, so two test rows sharing the
# same target would be merged into one group. This is fine only while targets are unique.
df = pd.DataFrame({'target_smiles': tgt_smiles_list,
                   'pred_smiles': candidate_smiles_list,
                   'similarity': valid_sim_list})
df_sorted = df.sort_values(['target_smiles', 'similarity'], ascending=[True, False])

def _join_csv(values):
    return ','.join(map(str, values))

grouped = (df_sorted
           .groupby('target_smiles')
           .agg({'pred_smiles': _join_csv, 'similarity': _join_csv})
           .reset_index())

for top_k in (1, 3, 5, 10):
    grouped[f'top_{top_k}'] = grouped.apply(
        lambda row: check_beam_mols(row['pred_smiles'].split(',')[:top_k], row['target_smiles']), axis=1)


In [32]:
print(f"""
recall@1:\t{grouped.top_1.mean():.5f} 
recall@3:\t{grouped.top_3.mean():.5f} 
recall@5:\t{grouped.top_5.mean():.5f} 
recall@10:\t{grouped.top_10.mean():.5f}
    """)

grouped['ring_count'] = grouped['target_smiles'].apply(lambda x: Chem.MolFromSmiles(x).GetRingInfo().NumRings())
# numeric_only=True averages only the integer top-k columns. The string columns
# (target_smiles, pred_smiles, similarity) are skipped; passing 'top_1'
# positionally only worked because a non-empty string is truthy.
grouped.groupby('ring_count').mean(numeric_only=True)


recall@1:	0.88140 
recall@3:	0.91512 
recall@5:	0.91628 
recall@10:	0.91628
    


Unnamed: 0_level_0,top_1,top_3,top_5,top_10
ring_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.953191,0.957447,0.957447,0.957447
2,0.867684,0.900763,0.900763,0.900763
3,0.831897,0.896552,0.900862,0.900862
