In [1]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import BertModel, PretrainModel_CL, PretrainModel_MLM, PretrainModel_LM, PretrainModel_CL_MLM, PretrainModel_CL_LM, PretrainModel_ALL, PretrainModel_Phase 
from utils.base import seed_everything

seed_everything(624)
device = 'cpu'

# Directory holding the Raman / kekule-SMILES checkpoints.
CHECKPOINT_DIR = '../checkpoints/mols/raman-kekule_smiles'

# build_model() name -> checkpoint file name.
# One table plus a single selector replaces the per-model commented-out
# copies of the same two loading statements.
CHECKPOINT_FILES = {
    'bert_mlm': 'bert_mlm.pth',
    'vib2mol_cl': 'vib2mol_cl.pth',
    'vib2mol_mlm': 'vib2mol_mlm.pth',
    'vib2mol_lm': 'vib2mol_lm.pth',
    'vib2mol_cl_mlm': 'vib2mol_cl_mlm.pth',
    'vib2mol_cl_lm': 'vib2mol_cl_lm.pth',
    'vib2mol_all': 'vib2mol_cl_mlm_lm.pth',  # file name differs from the model name
    'vib2mol_phase': 'vib2mol_phase.pth',
}

# Change this single name to evaluate a different pretrained variant.
MODEL_NAME = 'vib2mol_phase'

model = build_model(MODEL_NAME).to(device)
ckpt = torch.load(f'{CHECKPOINT_DIR}/{CHECKPOINT_FILES[MODEL_NAME]}',
                  map_location=device, weights_only=True)

# Drop the 'module.' component from checkpoint keys (presumably left by a
# DataParallel/DDP wrapper at training time -- TODO confirm).
# NOTE(review): str.replace removes 'module.' anywhere in a key, not only as a
# leading prefix; confirm no parameter name legitimately contains it.
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
# strict=False does not raise on missing/unexpected keys, so the returned
# report is left as the cell's last expression to keep mismatches visible.
model.load_state_dict(ckpt, strict=False)

<All keys matched successfully>

# 1 Evaluate contrastive retrieval

In [2]:
from utils.dataloader import Dataloader
from utils.collators import BaseCollator
from utils.base import BaseEngine

# Collator configured for Raman spectra and the project's molecule tokenizer.
raman_collator = BaseCollator(spectral_types=['raman'],
                              tokenizer_path='../models/MolTokenizer')

# Data source: the 'qm9' LMDB under ../datasets/vibench, exposing the Raman
# spectrum and kekule SMILES of every record.
dataloader = Dataloader(
    lmdb_path='qm9',
    data_dir='../datasets/vibench',
    target_keys=['raman', 'kekule_smiles'],
    collate_fn=raman_collator,
    device=device,
)
test_loader = dataloader.generate_dataloader(mode='test', batch_size=64)

# Run the model over the test loader; `out` is indexed later in the notebook by
# 'spectral_proj_output' and 'molecular_proj_output'.
engine = BaseEngine(test_loader=test_loader, model=model, device=device, device_rank=0)
out = engine.infer()

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 417/417 [04:47<00:00,  1.45it/s]


In [3]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Both inputs are L2-normalised along dim 1 and then multiplied, so entry
    ``(i, j)`` of the result is the cosine similarity of query ``i`` and
    key ``j``.

    Args:
        embedding_query: 2-D tensor with one embedding per row.
        embedding_key: 2-D tensor with one embedding per row and the same
            embedding width as ``embedding_query``.

    Returns:
        Tensor with one row per query and one column per key.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, verbose=False):
    """Recall@k for a retrieval task in which query ``i`` matches key ``i``.

    Args:
        similarity_matrix: 2-D tensor, one row per query and one column per
            key; larger values mean "more similar".
        k: number of top-scoring keys inspected per query. Values larger than
            the number of keys are clamped to that number instead of raising.
        verbose: when True the score is printed and nothing is returned; when
            False the score is returned. This split is kept unchanged so
            existing notebook cells display the same output.

    Returns:
        The fraction of queries whose own index is among their top-``k`` keys
        (``0.0`` when there are no queries), or ``None`` if ``verbose`` is True.
    """
    num_queries, num_keys = similarity_matrix.size(0), similarity_matrix.size(1)

    if num_queries == 0:
        # No queries: avoid the division by zero and report an empty score.
        recall_at_k = 0.0
    else:
        top_k = min(k, num_keys)  # topk() raises when asked for more than num_keys
        _, topk_indices = similarity_matrix.topk(top_k, dim=1, largest=True, sorted=True)
        # Row i is a hit when index i appears among its top-k columns; compared
        # for all rows at once instead of one Python-level lookup per query.
        targets = torch.arange(num_queries, device=topk_indices.device).unsqueeze(1)
        correct = int((topk_indices == targets).any(dim=1).sum())
        recall_at_k = correct / num_queries

    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    else:
        return recall_at_k

# Spectrum embeddings act as queries and molecule embeddings as keys; recall is
# printed at several cut-offs.
similarity_matrix = calculate_similarity_matrix(out['spectral_proj_output'], out['molecular_proj_output'])
for top_k in (1, 3, 5, 10, 100):
    compute_recall(similarity_matrix, k=top_k, verbose=True)

recall@1:0.81508
recall@3:0.96373
recall@5:0.98168
recall@10:0.99314
recall@100:0.99951


In [11]:
# Inspect the (num_queries, num_keys) size of the similarity matrix.
similarity_matrix.shape

torch.Size([26687, 26687])

# 2 MLM

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import lmdb
import pickle
import pandas as pd 
from tqdm import tqdm 
from transformers import AutoTokenizer

# Read every (key, value) pair of the QM9 test LMDB into memory, then release
# the environment handle (it was previously left open for the whole session).
db = lmdb.open('../datasets/vibench/qm9/qm9_test.lmdb', subdir=False, lock=False, map_size=int(1e11))
try:
    # Open a transaction and perform a read operation
    with db.begin() as txn:
        test_data = list(txn.cursor())
finally:
    db.close()

# NOTE(review): pickle.loads can execute arbitrary code contained in the
# payload; only run this against dataset files from a trusted source.
test_df = pd.DataFrame([pickle.loads(item[1]) for item in test_data])
tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')


  from .autonotebook import tqdm as notebook_tqdm
  test_df = pd.DataFrame([pickle.loads(item[1]) for item in test_data])


In [5]:
def extract_branches(s):
    """Collect the contents of every balanced parenthesised group in ``s``.

    Each top-level group's inner text is emitted first, immediately followed
    by the groups nested inside it (found by recursing on that inner text),
    before moving on to the next top-level group.

    A ``)`` with no open group is ignored, and a top-level ``(`` that is never
    closed contributes nothing.

    Args:
        s: string to scan (a SMILES string in this notebook).

    Returns:
        List of inner texts, without the surrounding parentheses.
    """
    branches = []
    depth = 0       # number of currently open parentheses
    open_pos = -1   # index of the '(' that opened the current top-level group

    for pos, symbol in enumerate(s):
        if symbol == '(':
            if depth == 0:
                open_pos = pos
            depth += 1
        elif symbol == ')' and depth > 0:
            depth -= 1
            if depth == 0:
                inner = s[open_pos + 1:pos]
                branches.append(inner)
                branches += extract_branches(inner)

    return branches

# Demo: print each extracted branch next to the string with that branch masked.
test_string = "NC1(C(=O)O)CC1C1=CC=CC=C1"
branches = extract_branches(test_string)
for branch in branches:
    mask_run = '<mask>' * len(branch)  # one <mask> per character of the branch
    masked_input = test_string.replace(branch, mask_run)
    print(branch, masked_input)

C(=O)O NC1(<mask><mask><mask><mask><mask><mask>)CC1C1=CC=CC=C1
=O NC1(C(<mask><mask>)O)CC1C1=CC=CC=C1


## evaluate molecular accuracy

In [6]:
# Parallel lists: entry j of each list describes the same masked sample.
raw_smiles = []
masked_smiles = []
masked_spectra = []
branches_list = []
masked_smi_dict = {}  # masked SMILES already emitted (used only for membership)

# None keeps every branch; otherwise only branches whose character length lies
# in (len(smi) * mask_prob, len(smi) * (mask_prob + 0.15)] are kept.
mask_prob = None
smiles_and_spectra = zip(test_df['kekule_smiles'].to_list(), test_df['raman'].to_list())
for smi, spec in tqdm(smiles_and_spectra, total=len(test_df)):
    branches = extract_branches(smi)
    for branch in branches:
        if mask_prob is not None:
            lower = len(smi) * mask_prob
            upper = len(smi) * (mask_prob + 0.15)
            if not (lower < len(branch) <= upper):
                continue

        # Token count of the branch; the -2 presumably discounts the two special
        # tokens the tokenizer adds around a sequence -- TODO confirm.
        len_token = len(tokenizer(branch)['input_ids'])-2
        masked_smi = smi.replace(f"({branch})", f"({'<mask>'*len_token})")
        if masked_smi in masked_smi_dict:
            continue

        masked_smi_dict[masked_smi] = 1
        branches_list.append(branch)
        masked_smiles.append(masked_smi)
        masked_spectra.append(spec)
        raw_smiles.append(smi)

print(len(raw_smiles), len(branches_list))

100%|██████████| 26687/26687 [00:00<00:00, 73311.93it/s]

20966 20966





In [None]:
import numpy as np 

# Decoded SMILES strings are appended here by the inference loop at the end of
# this cell.
all_pred_smiles = []
# NOTE(review): `correct` and `total` are initialised here but are not read or
# updated anywhere in the visible cells.
correct, total = 0, 0

class TestDataset(torch.utils.data.Dataset):
    """Pairs each SMILES string with the spectrum stored at the same index.

    NOTE(review): later cells define another class with this same name, which
    shadows this one once they run.
    """

    def __init__(self, smiles, spectra):
        self.smiles, self.spectra = smiles, spectra

    def __len__(self):
        # The dataset length follows the SMILES sequence.
        return len(self.smiles)

    def __getitem__(self, idx):
        pair = (self.smiles[idx], self.spectra[idx])
        return pair

class TestCollator:
    """Turns a list of (smiles, spectrum) pairs into one batch dictionary.

    The returned dict has a 'smiles' entry (tokenised input ids and attention
    mask) and a 'spectra' entry (float32 tensor); all tensors are moved to the
    notebook-level ``device``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        smiles, spectra = zip(*batch)
        # Stack the spectra and insert a singleton dimension at position 1.
        spectra_tensor = torch.as_tensor(np.array(spectra), dtype=torch.float32).unsqueeze(1).to(device)
        # Every sequence is padded/truncated to exactly 256 tokens.
        encoded = self.tokenizer(
            list(smiles),
            return_tensors='pt',
            padding='max_length',
            max_length=256,
            truncation=True,
        )
        tokens = {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}
        return {'smiles': tokens, 'spectra': spectra_tensor}

test_dataset = TestDataset(masked_smiles, masked_spectra)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

# Token id treated as "masked position" in input_ids.
# NOTE(review): presumably the tokenizer's <mask> id -- confirm against
# tokenizer.mask_token_id.
MASK_TOKEN_ID = 4

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_tokens_logits = model.infer_mlm(batch)
    # Most likely token at every position.
    pred_tokens = torch.argmax(pred_tokens_logits, dim=-1)

    input_ids = batch['smiles']['input_ids']
    mask = (input_ids == MASK_TOKEN_ID).cpu()
    # clone(): slicing a tensor with [:] returns a view, so the previous code
    # wrote the predictions back into the batch's own input_ids.
    output = input_ids.clone()
    # Fill only the masked positions with the model's predictions.
    output[mask] = pred_tokens[mask]

    preds = tokenizer.batch_decode(output, skip_special_tokens=True)
    all_pred_smiles.extend(preds)

In [None]:
import numpy as np
from tqdm import trange

# Token id treated as "masked position".
# NOTE(review): presumably the tokenizer's <mask> id -- confirm against
# tokenizer.mask_token_id.
MASK_TOKEN_ID = 4

tgt_fgs = []               # decoded target text of every masked run
pred_fgs = []              # decoded predicted text of every masked run
tgt_total_counter = {}     # target text -> number of occurrences
tgt_correct_counter = {}   # target text -> number of exact-match predictions

# The sample index and the position inside the mask list use distinct names;
# previously both loops used `i`, so the inner loop overwrote the sample index.
for sample_idx in trange(len(raw_smiles)):

    masked_fg = np.array(tokenizer(masked_smiles[sample_idx])['input_ids'])
    tgt_fg = np.array(tokenizer(raw_smiles[sample_idx])['input_ids'])
    # NOTE(review): the prediction is re-tokenised from its decoded text; if it
    # yields fewer tokens than the masked input, the indexing below raises
    # IndexError -- confirm alignment is guaranteed.
    pred_fg = np.array(tokenizer(all_pred_smiles[sample_idx])['input_ids'])[:len(masked_fg)]

    # Indices of all mask tokens.
    indices_of_mask = np.where(masked_fg == MASK_TOKEN_ID)[0]

    # Group the mask indices into runs of consecutive positions.
    consecutive_indices = []
    current_list = []

    for pos, token_idx in enumerate(indices_of_mask):
        if pos == 0 or token_idx == indices_of_mask[pos - 1] + 1:
            current_list.append(token_idx)
        else:
            if current_list:
                consecutive_indices.append(current_list)
            current_list = [token_idx]
    if current_list:
        consecutive_indices.append(current_list)

    for mask in consecutive_indices:
        # A run of length one is indexed with a scalar instead of a list.
        if len(mask) == 1: mask = mask[0]
        tgt_fg_str = tokenizer.decode(tgt_fg[mask])
        pred_fg_str = tokenizer.decode(pred_fg[mask])
        tgt_fgs.append(tgt_fg_str)
        pred_fgs.append(pred_fg_str)

        tgt_total_counter[tgt_fg_str] = tgt_total_counter.get(tgt_fg_str, 0) + 1
        if pred_fg_str == tgt_fg_str:
            tgt_correct_counter[tgt_fg_str] = tgt_correct_counter.get(tgt_fg_str, 0) + 1

100%|██████████| 37100/37100 [00:07<00:00, 5078.88it/s]


In [None]:
import pandas as pd

# Occurrence counts and exact-match counts, one row per masked fragment text.
df1 = pd.DataFrame({'fg': tgt_total_counter.keys(), 'count': tgt_total_counter.values()})
df2 = pd.DataFrame({'fg': tgt_correct_counter.keys(), 'correct': tgt_correct_counter.values()})

# Left join keeps fragments that were never predicted correctly; their missing
# 'correct' value becomes 0.
df = df1.merge(df2, on='fg', how='left').fillna(0)
df['accuracy'] = df['correct'] / df['count']
print(f"unique: {len(df)} | total: {df['count'].sum()} | correct: {df['correct'].sum()} | accuracy: {df['correct'].sum() / df['count'].sum()}")

# Show the 20 most frequent fragments.
df.sort_values('count', ascending=False).head(20)

unique: 1744 | total: 41334 | correct: 38390.0 | accuracy: 0.9287753423331881


# 3 spectrum-guided causal decoding

In [None]:
import lmdb
# Open LMDB 
# Read every (key, value) pair of the ZINC15 test LMDB into memory, then release
# the environment handle (it was previously left open for the whole session).
# Note that this rebinds `db` and `test_data`.
db = lmdb.open('../datasets/vibench/zinc15/zinc15_test.lmdb', subdir=False, lock=False, map_size=int(1e11))
try:
    # Open a transaction and perform a read operation
    with db.begin() as txn:
        test_data = list(txn.cursor())
finally:
    db.close()

In [None]:
import pickle
import multiprocessing as mp 
from tqdm.auto import tqdm 

# Decode each record once and read both fields from it; previously every record
# was unpickled twice, once per field.
# NOTE(review): pickle.loads can execute arbitrary code contained in the
# payload; only run this against dataset files from a trusted source.
records = [pickle.loads(item[1]) for item in tqdm(test_data)]
smiles = [record['kekule_smiles'] for record in records]
spectra = [record['raman'] for record in records]
del records  # the decoded dicts are not needed once the two fields are extracted

# Generation length budget: longest SMILES (in characters) plus 2.
# NOTE(review): the +2 presumably leaves room for start/end tokens, and len()
# counts characters rather than tokens -- confirm this bound is intended.
length = [len(item) for item in smiles]
max_len = max(length)+2
print(f'max_len:{max_len}')

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 10883/10883 [00:00<00:00, 74187.15it/s]
100%|██████████| 10883/10883 [00:00<00:00, 33787.08it/s]


max_len:102


## 3.1 greedy generation

In [58]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Index-addressable wrapper around a sequence of target spectra.

    NOTE(review): this name is also defined by other cells of the notebook;
    whichever cell ran last wins.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stacks a batch of spectra into one float32 tensor on ``device``.

    The result gets an extra singleton dimension at position 1 and is returned
    under the 'spectra' key. The collator keeps no state.
    """

    def __call__(self, batch):
        stacked = np.array(batch)
        spectra = torch.as_tensor(stacked, dtype=torch.float32)
        return {'spectra': spectra.unsqueeze(1).to(device)}

# One decoded SMILES string per input spectrum, in dataset order.
all_pred_smiles = []
test_dataset = TestDataset(spectra)
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
with torch.no_grad():
    for batch in test_bar:
        pred_smiles_ids = model.infer_lm(batch, max_len=max_len)['pred_ids']
        decoded = tokenizer.batch_decode(pred_smiles_ids)
        # Keep the text before the first '</s>' and drop any '<s>' markers.
        pred_smiles = [text.split('</s>')[0].replace('<s>', '') for text in decoded]
        all_pred_smiles.extend(pred_smiles)

100%|██████████| 209/209 [01:38<00:00,  2.11it/s]


In [59]:
from rdkit import Chem
from rdkit import RDLogger
# Disable RDKit logging for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

# Rewrite every prediction as a canonical, non-isomeric, kekulized SMILES;
# predictions RDKit cannot parse are recorded as '*'.
res_smiles = []
for item in all_pred_smiles:
    tmp_mol = Chem.MolFromSmiles(item)
    res_smiles.append(
        '*' if tmp_mol is None
        else Chem.MolToSmiles(tmp_mol, isomericSmiles=False, kekuleSmiles=True, canonical=True)
    )

In [60]:
from rdkit import Chem
from tqdm import trange

def check_mols(pred_smiles, tgt_smiles):
    """Tell whether two SMILES strings describe the same molecule.

    Molecules are compared through their InChIKeys.

    Args:
        pred_smiles: predicted SMILES string.
        tgt_smiles: reference SMILES string.

    Returns:
        1 if both strings parse and yield the same non-empty InChIKey,
        otherwise 0.
    """
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if pred_mol is None or tgt_mol is None:
        return 0

    pred_key = Chem.MolToInchiKey(pred_mol)
    tgt_key = Chem.MolToInchiKey(tgt_mol)
    # A falsy key means no InChIKey was produced for that molecule; two such
    # keys compare equal but must not be counted as a match.
    if not pred_key or not tgt_key:
        return 0
    return 1 if pred_key == tgt_key else 0


In [None]:
import pandas as pd
# Source file name of every test record, in the same order as `smiles`.
filenames = [pickle.loads(item[1])['filename'] for item in tqdm(test_data)]
# 1 where the canonicalised prediction and the target are the same molecule.
correct_flags = [check_mols(res_smiles[i], smiles[i]) for i in trange(len(res_smiles))]
df = pd.DataFrame({'pred': res_smiles, 'tgt': smiles, 'filename': filenames, 'correct': correct_flags})
print(df.correct.mean())
df.head()

100%|██████████| 26687/26687 [00:00<00:00, 73987.25it/s]
100%|██████████| 26687/26687 [00:12<00:00, 2107.01it/s]


0.6289579195863154


Unnamed: 0,pred,tgt,filename,correct
0,COCC12CC1OC2=N,COCC12CC1OC2=N,dsgdb9nsd_119549,1
1,O=C1C2COC1(CO)C2,O=C1C2COC1(CO)C2,dsgdb9nsd_106611,1
2,OCC1C(O)C1O,OCC1C(O)C1O,dsgdb9nsd_003107,1
3,CC1C(=N)OC2CC21O,CC1C(=N)OC2CC21O,dsgdb9nsd_075828,1
4,CC1OC(C)(C)C1C=O,CC1OC(C)C1(C)C=O,dsgdb9nsd_086297,0


In [62]:
# Mean accuracy for file names starting with 'd' and with 'Z'
# (presumably the two source datasets -- TODO confirm).
prefix_accuracy = [df[df.filename.str.startswith(prefix)].correct.mean() for prefix in ('d', 'Z')]
print(*prefix_accuracy)

0.6289579195863154 nan


## 3.2 beam search

In [4]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Serves target spectra by integer index.

    NOTE(review): this name is also defined by other cells of the notebook;
    whichever cell ran last wins.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        count = len(self.tgt_spectra)
        return count

    def __getitem__(self, idx):
        item = self.tgt_spectra[idx]
        return item

class TestCollator:
    """Batches spectra into a float32 tensor placed on ``device``.

    A singleton dimension is inserted at position 1 and the tensor is returned
    under the 'spectra' key. The collator keeps no state.
    """

    def __call__(self, batch):
        as_array = np.array(batch)
        tensor = torch.as_tensor(as_array, dtype=torch.float32).unsqueeze(1)
        return {'spectra': tensor.to(device)}

beam_size = 10

# One list of decoded candidate strings per entry of the model's 'pred_ids'
# (presumably one entry per input spectrum -- TODO confirm).
all_pred_smiles = []
test_dataset = TestDataset(spectra[:])
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
with torch.no_grad():
    for batch in test_bar:
        beam_output = model.beam_infer_lm(batch, max_len=max_len, beam_size=beam_size, temperature=3.5)
        pred_smiles_ids_list = beam_output['pred_ids']
        for pred_smiles_ids in pred_smiles_ids_list:
            decoded = tokenizer.batch_decode(pred_smiles_ids)
            # Keep the text before the first '</s>' and drop any '<s>' markers.
            all_pred_smiles.append([text.split('</s>')[0].replace('<s>', '') for text in decoded])

100%|██████████| 86/86 [19:41<00:00, 13.74s/it]


### rank by beam score

In [5]:
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
# Disable RDKit logging for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Tell whether any candidate SMILES is the same molecule as the target.

    Molecules are compared through their InChIKeys.

    Args:
        pred_smiles_list: candidate SMILES strings (e.g. the top-k beams).
        tgt_smiles: reference SMILES string.

    Returns:
        1 if at least one candidate parses and has the target's InChIKey,
        otherwise 0. An unparseable target, or a target without an InChIKey,
        yields 0.
    """
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is None:
        # Consistent with check_mols: an unparseable target is never a hit
        # (previously this case raised inside MolToInchiKey).
        return 0
    tgt_key = Chem.MolToInchiKey(tgt_mol)
    if not tgt_key:
        # Invalid candidates used to be recorded as '', so an empty target key
        # would have matched them; an empty key can never be a hit.
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue  # unparseable candidate
        try:
            pred_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        if pred_key == tgt_key:
            return 1  # first hit decides; remaining candidates are not needed
    return 0

In [6]:
import pandas as pd
import torch

# Remove repeated strings from each beam list while keeping the original order.
unique_beams = [list(dict.fromkeys(item)) for item in all_pred_smiles]
df = pd.DataFrame({'tgt_smiles': smiles, 'pred_smiles': unique_beams})

# top_k column: 1 when the target is among the first k distinct candidates.
for top_k in (1, 3, 5, 10):
    df[f'top_{top_k}'] = df.apply(
        lambda row: check_beam_mols(row['pred_smiles'][:top_k], row['tgt_smiles']),
        axis=1,
    )

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.49858
top-3:		0.58835
top-5:		0.59809
top-10:		0.59901
