# ViBench

In [1]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase
from utils.base import seed_everything

# Fix the random seed (presumably seeds all RNGs — see utils.base.seed_everything).
seed_everything(624)
# This section runs on CPU; later sections of the notebook rebind `device` to 'cuda:0'.
device = 'cpu'

# Build the 'vib2mol_phase' model with two spectral channels and load its checkpoint.
model = build_model('vib2mol_phase', spectral_channel=2).to(device)
ckpt = torch.load('../checkpoints/mols/raman-ir-kekule_smiles/vib2mol_phase.pth', 
                  map_location=device, weights_only=True)

# Remove the 'module.' substring from every key
# (presumably the prefix added by DataParallel/DDP wrappers — TODO confirm).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
# NOTE(review): strict=False does not raise on missing/unexpected keys, so inspect the
# returned report before trusting downstream numbers. The other sections load strictly.
model.load_state_dict(ckpt, strict=False)

<All keys matched successfully>

## Evaluate contrastive retrieval

In [2]:
from utils.dataloader import Dataloader
from utils.collators import BaseCollator
from utils.base import BaseEngine


# Loader over the 'mols' LMDB under ../datasets/vibench; the collator is configured for
# IR + Raman spectra and tokenises SMILES with the local MolTokenizer.
dataloader = Dataloader(lmdb_path='mols', 
                            data_dir='../datasets/vibench', 
                            target_keys=['raman', 'ir', 'kekule_smiles'], 
                            collate_fn=BaseCollator(spectral_types=['ir', 'raman'], tokenizer_path='../models/MolTokenizer'), 
                            device=device)

test_loader = dataloader.generate_dataloader(mode='test', batch_size=64)

# Run inference over the test split; `out` is later indexed with
# 'spectral_proj_output' and 'molecular_proj_output'.
engine = BaseEngine(test_loader=test_loader, model=model, device=device, device_rank=0)
out = engine.infer()

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 588/588 [08:35<00:00,  1.14it/s]


In [3]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Return the cosine-similarity matrix between query rows and key rows.

    Both inputs are L2-normalised along dim 1, so entry ``[i, j]`` of the result
    is the cosine similarity between query ``i`` and key ``j``.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, verbose=False):
    """Compute recall@k for an index-aligned query/key similarity matrix.

    Query ``i`` counts as a hit when key ``i`` is among the ``k`` highest-scoring
    keys of row ``i``.

    Args:
        similarity_matrix: 2-D tensor of shape (num_queries, num_keys).
        k: number of top-scoring keys inspected per query. Values larger than
            ``num_keys`` are clamped instead of raising inside ``topk``.
        verbose: if True, print ``recall@k`` and return None; otherwise return it.

    Returns:
        The recall as a float in [0, 1] (0.0 for an empty matrix), or None when
        ``verbose`` is True.
    """
    num_queries, num_keys = similarity_matrix.shape
    if num_queries == 0:
        # Nothing to retrieve; avoid dividing by zero.
        recall_at_k = 0.0
    else:
        _, topk_indices = similarity_matrix.topk(min(k, num_keys), dim=1, largest=True, sorted=True)
        # Row i is correct when index i appears anywhere in its top-k list;
        # a single vectorised comparison replaces the per-row Python loop.
        targets = torch.arange(num_queries, device=similarity_matrix.device).unsqueeze(1)
        correct = int((topk_indices == targets).any(dim=1).sum())
        recall_at_k = correct / num_queries

    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    else:
        return recall_at_k

# Spectrum -> molecule retrieval: rows are spectral embeddings, columns molecular ones.
similarity_matrix = calculate_similarity_matrix(out['spectral_proj_output'], out['molecular_proj_output'])
# Print recall at each cut-off.
for k in (1, 3, 5, 10, 100):
    compute_recall(similarity_matrix, k=k, verbose=True)

recall@1:0.78829
recall@3:0.95917
recall@5:0.98126
recall@10:0.99377
recall@100:0.99952


# 3 spectrum-guided causal decoding

In [None]:
import lmdb
# Open the LMDB read-only: in the default writable mode a mistyped path can create a
# new, empty database instead of failing. The environment is closed on leaving the
# `with` block, so `db` must not be used afterwards.
with lmdb.open('../datasets/vibench/qm9/qm9_test.lmdb', subdir=False, lock=False,
               readonly=True, map_size=int(1e11)) as db:
    # Copy every (key, value) record out within a single read transaction.
    with db.begin() as txn:
        test_data = list(txn.cursor())

In [None]:
import pickle
import pandas as pd
import multiprocessing as mp 
from tqdm import tqdm 
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')
# One DataFrame row per LMDB record (the record value is a pickled object).
# NOTE(review): pickle.loads can execute arbitrary code — only read trusted LMDB files.
test_df =  pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])
# Decoding length budget: longest target SMILES, counted in characters, plus 2
# (presumably room for the <s> and </s> special tokens — TODO confirm).
length = [len(item) for item in test_df['kekule_smiles'].to_list()]
max_len = max(length)+2
print(f'max_len:{max_len}')


100%|██████████| 26687/26687 [00:00<00:00, 42281.18it/s]


max_len:34


## 3.1 greedy generation

In [17]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of Raman and IR spectra."""

    def __init__(self, raman, ir):
        self.raman, self.ir = raman, ir

    def __len__(self):
        # The Raman sequence defines the dataset size.
        return len(self.raman)

    def __getitem__(self, idx):
        # Return the (raman, ir) pair stored at position `idx`.
        return (self.raman[idx], self.ir[idx])

class TestCollator:
    """Collate (raman, ir) samples into a dict of float32 tensors on the global `device`."""

    @staticmethod
    def _stack(spectra):
        # Stack the samples, insert a singleton axis at dim 1 and move to `device`.
        return torch.as_tensor(np.array(spectra), dtype=torch.float32).unsqueeze(1).to(device)

    def __call__(self, batch):
        raman_batch, ir_batch = zip(*batch)
        return {'raman': self._stack(raman_batch), 'ir': self._stack(ir_batch)}

# Greedy decoding: collects one predicted SMILES string per test spectrum pair.
all_pred_smiles = []
# test_dataset = TestDataset(test_df['ir'].to_list(), test_df['ir'].to_list())
test_dataset = TestDataset(test_df['raman'].to_list(), test_df['ir'].to_list())

test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids = model.infer_lm(batch, max_len=max_len)['pred_ids']
    pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
    # Keep only the text before the first '</s>' and strip the '<s>' marker.
    pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
    all_pred_smiles.extend(pred_smiles)

100%|██████████| 209/209 [01:39<00:00,  2.09it/s]


In [18]:
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Canonical, non-isomeric Kekulé SMILES for every prediction; '*' marks strings
# that RDKit could not parse.
res_smiles = []
for item in all_pred_smiles:
    tmp_mol = Chem.MolFromSmiles(item)
    res_smiles.append(
        '*' if tmp_mol is None
        else Chem.MolToSmiles(tmp_mol, isomericSmiles=False, kekuleSmiles=True, canonical=True)
    )

In [19]:
from rdkit import Chem
from tqdm import trange

def check_mols(pred_smiles, tgt_smiles):
    """Return 1 when both SMILES parse and share the same InChIKey, otherwise 0."""
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    # An unparsable string on either side can never count as a match.
    if pred_mol is None or tgt_mol is None:
        return 0
    return 1 if Chem.MolToInchiKey(pred_mol) == Chem.MolToInchiKey(tgt_mol) else 0


In [20]:
import pandas as pd

# Materialise the target column once; rebuilding the list inside the comprehension
# for every row made this cell quadratic in the number of samples.
greedy_tgt_smiles = test_df['kekule_smiles'].to_list()

# Per-sample table for greedy decoding: prediction, target, source file, exact-match flag.
df = pd.DataFrame({'pred':res_smiles, 
                   'tgt':greedy_tgt_smiles, 
                   'filename':test_df['filename'].to_list(), 
                   'correct':[check_mols(res_smiles[i], greedy_tgt_smiles[i]) for i in trange(len(res_smiles))]})
# Fraction of predictions whose InChIKey matches the target.
print(df.correct.mean())
df.head()

100%|██████████| 26687/26687 [00:25<00:00, 1049.92it/s]

0.6651927904972459





Unnamed: 0,pred,tgt,filename,correct
0,COC1C2CC1C(=N)O2,COCC12CC1OC2=N,dsgdb9nsd_119549,0
1,O=C1C2CC1(CO)CO2,O=C1C2COC1(CO)C2,dsgdb9nsd_106611,0
2,OCC1C(O)C1O,OCC1C(O)C1O,dsgdb9nsd_003107,1
3,CC1(O)CC2OC(=N)C21,CC1C(=N)OC2CC21O,dsgdb9nsd_075828,0
4,CC1OC(C=O)C1(C)C,CC1OC(C)C1(C)C=O,dsgdb9nsd_086297,0


## 3.2 beam search

In [22]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of Raman and IR spectra."""

    def __init__(self, raman, ir):
        self.raman, self.ir = raman, ir

    def __len__(self):
        # The Raman sequence defines the dataset size.
        return len(self.raman)

    def __getitem__(self, idx):
        # Return the (raman, ir) pair stored at position `idx`.
        return (self.raman[idx], self.ir[idx])

class TestCollator:
    """Collate (raman, ir) samples into a dict of float32 tensors on the global `device`."""

    @staticmethod
    def _stack(spectra):
        # Stack the samples, insert a singleton axis at dim 1 and move to `device`.
        return torch.as_tensor(np.array(spectra), dtype=torch.float32).unsqueeze(1).to(device)

    def __call__(self, batch):
        raman_batch, ir_batch = zip(*batch)
        return {'raman': self._stack(raman_batch), 'ir': self._stack(ir_batch)}

# Number of beams requested from beam_infer_lm for each sample.
beam_size = 10

# One entry per test sample: the list of decoded candidate SMILES for that sample.
all_pred_smiles = []
test_dataset = TestDataset(test_df['raman'].to_list(), test_df['ir'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids_list = model.beam_infer_lm(batch, max_len=max_len, beam_size=beam_size, temperature=3.5)['pred_ids']
    # Decode the candidates of each element of 'pred_ids' separately
    # (presumably one element per sample, ordered by beam score — TODO confirm).
    for pred_smiles_ids in pred_smiles_ids_list:
        pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
        # Keep only the text before the first '</s>' and strip the '<s>' marker.
        pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
        all_pred_smiles.append(pred_smiles)

100%|██████████| 209/209 [10:43<00:00,  3.08s/it]


### rank by beam score

In [23]:
from rdkit import Chem
from tqdm import trange
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def check_beam_mols(pred_smiles_list, tgt_smiles):
    """Return 1 if any candidate SMILES has the same InChIKey as the target, else 0.

    Args:
        pred_smiles_list: candidate SMILES strings (e.g. the top-k beam outputs).
        tgt_smiles: reference SMILES string.

    Returns:
        1 on the first candidate whose InChIKey equals the target's, otherwise 0.
        Returns 0 when the target cannot be parsed or yields an empty InChIKey.
    """
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is None:
        # An unparsable target can never be matched (previously this raised).
        return 0
    tgt_key = Chem.MolToInchiKey(tgt_mol)
    if not tgt_key:
        # An empty key must not match anything: invalid candidates used to be
        # recorded as '' and would have produced a false positive here.
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            print(f"Error processing SMILES {item}: {e}")
            continue
        if inchi_key == tgt_key:
            return 1
    return 0

In [24]:
import pandas as pd
import torch

# One row per sample: target SMILES and the de-duplicated candidate list
# (dict.fromkeys keeps first-seen order).
df = pd.DataFrame({
    'tgt_smiles': test_df['kekule_smiles'].to_list(),
    'pred_smiles': [list(dict.fromkeys(item)) for item in all_pred_smiles],
})
# top_k is 1 when the target is among the first k distinct candidates.
for top_k in (1, 3, 5, 10):
    df[f'top_{top_k}'] = df.apply(
        lambda row, top_k=top_k: check_beam_mols(row['pred_smiles'][:top_k], row['tgt_smiles']),
        axis=1,
    )

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.69378
top-3:		0.79271
top-5:		0.79616
top-10:		0.79634


# MLM

In [None]:
import lmdb
# Open the LMDB read-only: in the default writable mode a mistyped path can create a
# new, empty database instead of failing. The environment is closed on leaving the
# `with` block, so `db` must not be used afterwards.
with lmdb.open('../datasets/vibench/mols/mols_test.lmdb', subdir=False, lock=False,
               readonly=True, map_size=int(1e11)) as db:
    # Copy every (key, value) record out within a single read transaction.
    with db.begin() as txn:
        test_data = list(txn.cursor())

In [None]:
import pickle
import pandas as pd
import multiprocessing as mp 
from tqdm import tqdm 
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')
# One DataFrame row per LMDB record (the record value is a pickled object).
# NOTE(review): pickle.loads can execute arbitrary code — only read trusted LMDB files.
test_df =  pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])

100%|██████████| 37570/37570 [00:00<00:00, 49084.88it/s]


In [55]:
def mask_manually(input_ids, masked_indices, pad_token_id=1, mask_token_id=4):
    """Replace the tokens at the given positions with the mask token.

    Args:
        input_ids: 2-D integer tensor of token ids, indexed as [row, position].
        masked_indices: positions to mask, counted without the leading token
            (each index is shifted by +1 before use). May be empty.
        pad_token_id: id of the padding token; padding is never masked.
        mask_token_id: id written at the masked positions.

    Returns:
        ``(masked_ids, mask)``: a masked copy of ``input_ids`` and the boolean
        mask of replaced positions. ``input_ids`` itself is not modified.
    """
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    # An explicit long dtype keeps an empty index list usable
    # (torch.tensor([]) would be float and is rejected as an index).
    positions = torch.as_tensor(masked_indices, dtype=torch.long, device=input_ids.device) + 1
    mask[:, positions] = True
    # Never mask padding.
    mask &= input_ids != pad_token_id
    masked_ids = input_ids.clone()
    masked_ids[mask] = mask_token_id
    return masked_ids, mask

def generate_mlmmask(input_ids, mask_prob=0.15, pad_token_id=1, mask_token_id=4):
    """Randomly mask tokens for masked-language-model evaluation.

    Every non-padding position is independently replaced by the mask token with
    probability ``mask_prob``.

    Args:
        input_ids: integer tensor of token ids.
        mask_prob: per-token masking probability in [0, 1].
        pad_token_id: id of the padding token; padding is never masked.
        mask_token_id: id written at the masked positions.

    Returns:
        ``(masked_ids, masked_indices)``: a masked copy of ``input_ids`` and the
        boolean mask of replaced positions. ``input_ids`` itself is not modified.
    """
    # Build the probabilities on the input's device (a CPU-only matrix cannot be
    # combined with a non-CPU id tensor) and as floats, which bernoulli requires.
    probability_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=input_ids.device)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Never mask padding.
    masked_indices &= input_ids != pad_token_id
    masked_ids = input_ids.clone()
    masked_ids[masked_indices] = mask_token_id

    return masked_ids, masked_indices

In [None]:
def extract_branches(s):
    """Return the text inside every balanced parenthesis pair of ``s``.

    Each top-level branch is listed first, immediately followed by the branches
    nested inside it. A stray ')' is ignored, and a branch whose '(' is never
    closed is not reported.
    """
    branches = []
    depth = 0      # number of currently open '('
    open_at = -1   # index of the '(' that opened the current top-level branch

    for pos, ch in enumerate(s):
        if ch == '(':
            if depth == 0:
                open_at = pos
            depth += 1
        elif ch == ')' and depth > 0:
            depth -= 1
            if depth == 0:
                inner = s[open_at + 1:pos]
                branches.append(inner)
                # Recurse so that nested branches are reported as well.
                branches += extract_branches(inner)

    return branches

# Demo: print every branch of a sample SMILES next to the string with that branch masked.
test_string = "NC1(C(=O)O)CC1C1=CC=CC=C1"
branches = extract_branches(test_string)
for branch in branches:
    # NOTE(review): str.replace substitutes every occurrence of `branch`, and this demo
    # emits one '<mask>' per character rather than per token.
    masked_input = test_string.replace(branch, '<mask>'*len(branch))
    print(branch, masked_input)

C(=O)O NC1(<mask><mask><mask><mask><mask><mask>)CC1C1=CC=CC=C1
=O NC1(C(<mask><mask>)O)CC1C1=CC=CC=C1


## 2.2 evaluate molecular accuracy

In [57]:
# Parallel lists, one entry per masked sample built below.
raw_smiles = []       # original SMILES
masked_smiles = []    # SMILES with one branch replaced by '<mask>' tokens
masked_spectra = []   # stacked Raman + IR spectra of the molecule
branches_list = []    # text of the branch that was masked
masked_smi_dict = {}  # used as a set to skip masked strings already produced

# None keeps every branch; a float keeps only branches whose character length lies in
# (mask_prob, mask_prob + 0.15] times the SMILES length.
mask_prob = None
for i in trange(len(test_df)):
    smi = test_df.iloc[i]['kekule_smiles']
    raman = test_df.iloc[i]['raman']
    ir = test_df.iloc[i]['ir']
    branches = extract_branches(smi)
    # Stack Raman above IR (presumably giving a (2, n_points) array — TODO confirm).
    spec = np.vstack([raman, ir])
    for branch in branches:
        if mask_prob is None or (len(smi) * mask_prob < len(branch) and len(branch) <= len(smi) * (mask_prob + 0.15)):
            # Token count of the branch without the two special tokens the tokenizer adds.
            len_token = len(tokenizer(branch)['input_ids'])-2
            # NOTE(review): replace() masks every occurrence of "(branch)" in the SMILES.
            masked_smi = smi.replace(f"({branch})", f"({'<mask>'*len_token})")
            if masked_smi not in masked_smi_dict:
                branches_list.append(branch)
                masked_smi_dict[masked_smi] = 1
                masked_smiles.append(masked_smi)
                masked_spectra.append(spec)
                raw_smiles.append(smi)

100%|██████████| 37570/37570 [00:04<00:00, 8544.30it/s]


In [58]:
from copy import deepcopy
import numpy as np 

# Decoded reconstructions, one per masked sample (filled by the inference loop below).
all_pred_smiles = []
# NOTE(review): `correct` and `total` are not referenced again in the visible cells.
correct, total = 0, 0

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of (masked) SMILES strings and their spectra."""

    def __init__(self, smiles, spectra):
        self.smiles, self.spectra = smiles, spectra

    def __len__(self):
        # The SMILES list defines the dataset size.
        return len(self.smiles)

    def __getitem__(self, idx):
        # Return the (smiles, spectra) pair stored at position `idx`.
        return (self.smiles[idx], self.spectra[idx])

class TestCollator:
    """Collate (smiles, spectra) samples into model inputs on the global `device`.

    SMILES are tokenised with padding/truncation to 256 tokens; spectra are
    stacked as-is into one float32 tensor.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        smiles_batch, spectra_batch = zip(*batch)
        encoded = self.tokenizer(list(smiles_batch), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        return {
            'smiles': {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')},
            'spectra': torch.as_tensor(np.array(spectra_batch), dtype=torch.float32).to(device),
        }

test_dataset = TestDataset(masked_smiles, masked_spectra)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_tokens_logits = model.infer_mlm(batch)
    # Most likely token at every position.
    pred_tokens = torch.argmax(pred_tokens_logits, dim=-1)
    
    # Start from the masked input and overwrite only the <mask> positions (token id 4)
    # with the predictions, so unmasked tokens are kept verbatim.
    output = deepcopy(batch['smiles']['input_ids'])
    mask = (batch['smiles']['input_ids'] == 4).cpu()
    output[mask] = pred_tokens[mask]
    
    preds = tokenizer.batch_decode(output, skip_special_tokens=True)
    all_pred_smiles.extend(preds)

100%|██████████| 580/580 [01:02<00:00,  9.25it/s]


In [59]:
import numpy as np
from tqdm import trange

tgt_fgs = []      # target text of every masked fragment
pred_fgs = []     # predicted text of every masked fragment
tgt_total = {}    # fragment text -> number of times it was masked
tgt_correct = {}  # fragment text -> number of exact reconstructions

for i in trange(len(raw_smiles)):

    # Token ids of the masked input, the ground truth and the reconstruction.
    # NOTE(review): assumes the three tokenisations are position-aligned — TODO confirm.
    masked_fg = np.array(tokenizer(masked_smiles[i])['input_ids'])
    tgt_fg = np.array(tokenizer(raw_smiles[i])['input_ids'])
    pred_fg = np.array(tokenizer(all_pred_smiles[i])['input_ids'])[:len(masked_fg)]

    # Positions of all <mask> tokens (token id 4).
    indices_of_mask = np.where(masked_fg == 4)[0]

    # Group the mask positions into runs of consecutive indices; each run is one fragment.
    consecutive_indices = []
    current_list = []

    # Use `j` here: reusing `i` would clobber the sample index of the outer loop.
    for j in range(len(indices_of_mask)):
        if j == 0 or indices_of_mask[j] == indices_of_mask[j - 1] + 1:
            current_list.append(indices_of_mask[j])
        else:
            if current_list:
                consecutive_indices.append(current_list)
            current_list = [indices_of_mask[j]]
    if current_list:
        consecutive_indices.append(current_list)


    for mask in consecutive_indices:
        # A single-token fragment is decoded from a scalar index.
        if len(mask) == 1: mask = mask[0]
        tgt_fg_str = tokenizer.decode(tgt_fg[mask])
        pred_fg_str = tokenizer.decode(pred_fg[mask])
        tgt_fgs.append(tgt_fg_str)
        pred_fgs.append(pred_fg_str)

        tgt_total[tgt_fg_str] = tgt_total.get(tgt_fg_str, 0) + 1
        if pred_fg_str == tgt_fg_str:
            tgt_correct[tgt_fg_str] = tgt_correct.get(tgt_fg_str, 0) + 1

100%|██████████| 37100/37100 [00:07<00:00, 5038.57it/s]


In [60]:
import pandas as pd

# Per-fragment occurrence counts and exact-match counts.
df1 = pd.DataFrame({'fg':tgt_total.keys(), 'count':tgt_total.values()})
df2 = pd.DataFrame({'fg':tgt_correct.keys(), 'correct':tgt_correct.values()})
# Left join keeps fragments that were never reconstructed; their 'correct' is NaN ...
df = pd.merge(df1, df2, on='fg', how='left')
# ... and is set to 0 here.
df[df.isna()] = 0
df['accuracy'] = df['correct'] / df['count']
# The overall accuracy is micro-averaged over all masked fragments.
print(f"unique: {len(df)} | total: {df['count'].sum()} | correct: {df['correct'].sum()} | accuracy: {df['correct'].sum() / df['count'].sum()}")

unique: 1744 | total: 41334 | correct: 38793.0 | accuracy: 0.93852518507766


# NIST

# IR-only

In [None]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase
from utils.base import seed_everything

# Fix the random seed (presumably seeds all RNGs — see utils.base.seed_everything).
seed_everything(624)
device = 'cuda:0'

# Single-channel (IR-only) 'vib2mol_cl' model.
model = build_model('vib2mol_cl', spectral_channel=1).to(device)
# NOTE(review): absolute, machine-specific checkpoint path — not portable to other hosts.
ckpt = torch.load('/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/checkpoints/nist_ir/exp_ir-kekule_smiles/vib2mol_cl/2025-01-03_07:39/epoch835_recall37.pth', 
                  map_location=device, weights_only=True)

# Remove the 'module.' substring from every key
# (presumably the prefix added by DataParallel/DDP wrappers — TODO confirm).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>

## Evaluate contrastive retrieval

In [2]:
from utils.base import Dataloader
from utils.pretrain import PretrainCollator, Engine

# Loader over the 'nist_ir' LMDB using experimental IR spectra and Kekulé SMILES.
# NOTE(review): this cell passes `root_dir=` and imports from utils.base / utils.pretrain,
# whereas the ViBench section uses `data_dir=` with utils.dataloader / utils.collators —
# it may target an older project API. The paths are absolute and machine-specific.
dataloader = Dataloader(lmdb_path='nist_ir', 
                            root_dir='/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/datasets/historical', 
                            target_keys=['exp_ir', 'kekule_smiles'], 
                            collate_fn=PretrainCollator(spectral_types=['exp_ir'], tokenizer_path='/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/models/MolTokenizer'), 
                            device=device)

test_loader = dataloader.generate_dataloader(mode='test', batch_size=64)

In [3]:
# Run inference over the test split; `out` is later indexed with
# 'spectral_proj_output' and 'molecular_proj_output'.
engine = Engine(test_loader=test_loader, model=model, device=device, device_rank=0)
out = engine.infer()

100%|██████████| 41/41 [00:01<00:00, 20.64it/s]


In [4]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Return the cosine-similarity matrix between query rows and key rows.

    Both inputs are L2-normalised along dim 1, so entry ``[i, j]`` of the result
    is the cosine similarity between query ``i`` and key ``j``.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, verbose=False):
    """Compute recall@k for an index-aligned query/key similarity matrix.

    Query ``i`` counts as a hit when key ``i`` is among the ``k`` highest-scoring
    keys of row ``i``.

    Args:
        similarity_matrix: 2-D tensor of shape (num_queries, num_keys).
        k: number of top-scoring keys inspected per query. Values larger than
            ``num_keys`` are clamped instead of raising inside ``topk``.
        verbose: if True, print ``recall@k`` and return None; otherwise return it.

    Returns:
        The recall as a float in [0, 1] (0.0 for an empty matrix), or None when
        ``verbose`` is True.
    """
    num_queries, num_keys = similarity_matrix.shape
    if num_queries == 0:
        # Nothing to retrieve; avoid dividing by zero.
        recall_at_k = 0.0
    else:
        _, topk_indices = similarity_matrix.topk(min(k, num_keys), dim=1, largest=True, sorted=True)
        # Row i is correct when index i appears anywhere in its top-k list;
        # a single vectorised comparison replaces the per-row Python loop.
        targets = torch.arange(num_queries, device=similarity_matrix.device).unsqueeze(1)
        correct = int((topk_indices == targets).any(dim=1).sum())
        recall_at_k = correct / num_queries

    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    else:
        return recall_at_k

# Spectrum -> molecule retrieval: rows are spectral embeddings, columns molecular ones.
similarity_matrix = calculate_similarity_matrix(out['spectral_proj_output'], out['molecular_proj_output'])
# Print recall at each cut-off.
for k in (1, 3, 5, 10, 100):
    compute_recall(similarity_matrix, k=k, verbose=True)

recall@1:0.31453
recall@3:0.46561
recall@5:0.52898
recall@10:0.60935
recall@100:0.85858


## MLM

In [None]:
import lmdb
# Open the LMDB read-only: in the default writable mode a mistyped path can create a
# new, empty database instead of failing. The environment is closed on leaving the
# `with` block, so `db` must not be used afterwards.
with lmdb.open('../datasets/nist_ir/nist_ir_test.lmdb', subdir=False, lock=False,
               readonly=True, map_size=int(1e11)) as db:
    # Copy every (key, value) record out within a single read transaction.
    with db.begin() as txn:
        test_data = list(txn.cursor())

In [None]:
import pickle
import multiprocessing as mp 
from tqdm import tqdm 

def get_smiles(idx):
    """Return the 'kekule_smiles' field of record ``idx`` in ``test_data``.

    Best effort: returns None when the record is missing, cannot be unpickled or
    has no 'kekule_smiles' entry. KeyboardInterrupt and SystemExit are no longer
    swallowed, so a worker pool can still be interrupted.

    NOTE(review): pickle.loads can execute arbitrary code — only read trusted data.
    """
    try:
        smiles = pickle.loads(test_data[idx][1])['kekule_smiles']
        return smiles
    except Exception:
        return None

# def get_raman(idx):
#     raman = pickle.loads(test_data[idx][1])['raman']
#     return raman

def get_ir(idx):
    """Return the 'exp_ir' field of record ``idx`` in the global ``test_data``."""
    record = pickle.loads(test_data[idx][1])
    return record['exp_ir']

# Unpickle the SMILES of every record with 4 worker processes; imap preserves input order.
# NOTE(review): the workers read the notebook-level `test_data` and call functions defined
# in the notebook, which presumably relies on a fork-based start method — TODO confirm.
with mp.Pool(4) as pool:
    smiles = list(tqdm(pool.imap(get_smiles, range(len(test_data))), total=len(test_data)))

# Same for the experimental IR spectra.
with mp.Pool(4) as pool:
    ir = list(tqdm(pool.imap(get_ir, range(len(test_data))), total=len(test_data)))

# This section uses the RoBERTa tokenizer, whereas the model-setup cell above passes the
# MolTokenizer path to its collator.
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained('../models/RoBERTa')


100%|██████████| 2588/2588 [00:00<00:00, 13742.21it/s]
100%|██████████| 2588/2588 [00:00<00:00, 19077.08it/s]


In [37]:

# Pick one test record and tokenise its SMILES, padded/truncated to 256 tokens.
demo_idx = 1
demo_smiles = smiles[demo_idx]
demo_spectra = ir[demo_idx]
demo_smiles_input = tokenizer(demo_smiles, return_tensors='pt', padding='max_length', max_length=256, truncation=True)

# Show the SMILES and its length in characters.
print(demo_smiles, len(demo_smiles))


CC(C)(C)O 9


In [38]:
def mask_manually(input_ids, masked_indices, pad_token_id=1, mask_token_id=4):
    """Replace the tokens at the given positions with the mask token.

    Args:
        input_ids: 2-D integer tensor of token ids, indexed as [row, position].
        masked_indices: positions to mask, counted without the leading token
            (each index is shifted by +1 before use). May be empty.
        pad_token_id: id of the padding token; padding is never masked.
        mask_token_id: id written at the masked positions.

    Returns:
        ``(masked_ids, mask)``: a masked copy of ``input_ids`` and the boolean
        mask of replaced positions. ``input_ids`` itself is not modified.
    """
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    # An explicit long dtype keeps an empty index list usable
    # (torch.tensor([]) would be float and is rejected as an index).
    positions = torch.as_tensor(masked_indices, dtype=torch.long, device=input_ids.device) + 1
    mask[:, positions] = True
    # Never mask padding.
    mask &= input_ids != pad_token_id
    masked_ids = input_ids.clone()
    masked_ids[mask] = mask_token_id
    return masked_ids, mask

def generate_mlmmask(input_ids, mask_prob=0.15, pad_token_id=1, mask_token_id=4):
    """Randomly mask tokens for masked-language-model evaluation.

    Every non-padding position is independently replaced by the mask token with
    probability ``mask_prob``.

    Args:
        input_ids: integer tensor of token ids.
        mask_prob: per-token masking probability in [0, 1].
        pad_token_id: id of the padding token; padding is never masked.
        mask_token_id: id written at the masked positions.

    Returns:
        ``(masked_ids, masked_indices)``: a masked copy of ``input_ids`` and the
        boolean mask of replaced positions. ``input_ids`` itself is not modified.
    """
    # Build the probabilities on the input's device (a CPU-only matrix cannot be
    # combined with a non-CPU id tensor) and as floats, which bernoulli requires.
    probability_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=input_ids.device)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Never mask padding.
    masked_indices &= input_ids != pad_token_id
    masked_ids = input_ids.clone()
    masked_ids[masked_indices] = mask_token_id

    return masked_ids, masked_indices

In [None]:
def extract_branches(s):
    """Return the text inside every balanced parenthesis pair of ``s``.

    Each top-level branch is listed first, immediately followed by the branches
    nested inside it. A stray ')' is ignored, and a branch whose '(' is never
    closed is not reported.
    """
    branches = []
    depth = 0      # number of currently open '('
    open_at = -1   # index of the '(' that opened the current top-level branch

    for pos, ch in enumerate(s):
        if ch == '(':
            if depth == 0:
                open_at = pos
            depth += 1
        elif ch == ')' and depth > 0:
            depth -= 1
            if depth == 0:
                inner = s[open_at + 1:pos]
                branches.append(inner)
                # Recurse so that nested branches are reported as well.
                branches += extract_branches(inner)

    return branches

# Demo: print every branch of a sample SMILES next to the string with that branch masked.
test_string = "NC1(C(=O)O)CC1C1=CC=CC=C1"
branches = extract_branches(test_string)
for branch in branches:
    # NOTE(review): str.replace substitutes every occurrence of `branch`, and this demo
    # emits one '<mask>' per character rather than per token.
    masked_input = test_string.replace(branch, '<mask>'*len(branch))
    print(branch, masked_input)

C(=O)O NC1(<mask><mask><mask><mask><mask><mask>)CC1C1=CC=CC=C1
=O NC1(C(<mask><mask>)O)CC1C1=CC=CC=C1


### Demonstrate the reconstruction of masked tokens

In [40]:
# Pick one test record and tokenise its SMILES, padded/truncated to 256 tokens.
demo_idx = 53
demo_smiles = smiles[demo_idx]
demo_spectra = ir[demo_idx]
demo_smiles_input = tokenizer(demo_smiles, return_tensors='pt', padding='max_length', max_length=256, truncation=True)

# mask smiles manually
masked_idx = [2]
# masked_idx = [i for i in range(13, 24)]
masked_smiles_tuple = mask_manually(demo_smiles_input['input_ids'], masked_idx)

# or mask smiles randomly
# masked_smiles_tuple = generate_mlmmask(demo_smiles_input['input_ids'], mask_prob=0.75)

# Human-readable masked string with the padding / start / end markers stripped.
masked_smiles = tokenizer.decode(masked_smiles_tuple[0].flatten()).replace('<pad>', '').replace('<s>', '').replace('</s>', '')

model.eval()
masked_smiles_input = {'input_ids':masked_smiles_tuple[0].to(device), 'attention_mask':demo_smiles_input['attention_mask'].to(device)}
# The spectrum gets two leading singleton axes before being passed to the model.
# NOTE(review): this call is not wrapped in torch.no_grad(), unlike the batched loops.
pred_tokens_logits = model.infer({'smiles': masked_smiles_input, 'spectra': torch.tensor(demo_spectra, dtype=torch.float32).to(device).unsqueeze(0).unsqueeze(0)})

pred_tokens = torch.argmax(pred_tokens_logits, dim=-1)
# The slice works on characters: it skips the 3 characters of '<s>' and keeps
# len(demo_smiles) characters of the decoded prediction.
pred_smiles = tokenizer.decode(pred_tokens.flatten())[3:len(demo_smiles)+3] # remove <s>

print('target: ', tokenizer.decode(demo_smiles_input['input_ids'].flatten()).replace('<pad>', '').replace('<s>', '').replace('</s>', ''))
print('masked: ', masked_smiles)
print('pred:   ', pred_smiles) # remove <s>

target:  N#CC#N
masked:  N#<mask>C#N
pred:    N#CC#N


### evaluate molecular accuracy

In [41]:
# Parallel lists, one entry per masked sample built below.
raw_smiles = []       # original SMILES
masked_smiles = []    # SMILES with one branch replaced by '<mask>' tokens
masked_spectra = []   # IR spectrum of the molecule
branches_list = []    # text of the branch that was masked
masked_smi_dict = {}  # used as a set to skip masked strings already produced

# None keeps every branch; a float keeps only branches whose character length lies in
# (mask_prob, mask_prob + 0.15] times the SMILES length.
mask_prob = None
# NOTE(review): get_smiles may return None, which extract_branches cannot iterate.
for smi, spec in tqdm(zip(smiles, ir), total=len(smiles)):
    branches = extract_branches(smi)
    for branch in branches:
        if mask_prob is None or (len(smi) * mask_prob < len(branch) and len(branch) <= len(smi) * (mask_prob + 0.15)):
            # Token count of the branch without the two special tokens the tokenizer adds.
            len_token = len(tokenizer(branch)['input_ids'])-2
            # NOTE(review): replace() masks every occurrence of "(branch)" in the SMILES.
            masked_smi = smi.replace(f"({branch})", f"({'<mask>'*len_token})")
            if masked_smi not in masked_smi_dict:
                branches_list.append(branch)
                masked_smi_dict[masked_smi] = 1
                masked_smiles.append(masked_smi)
                masked_spectra.append(spec)
                raw_smiles.append(smi)

100%|██████████| 2588/2588 [00:00<00:00, 7760.87it/s]


In [42]:
# Sanity check: both lists get one entry per masked sample, so the lengths must agree.
(len(raw_smiles), len(branches_list))

(3819, 3819)

In [43]:
import numpy as np
# Decoded reconstructions, one per masked sample (filled by the inference loop below).
all_pred_smiles = []
# NOTE(review): `correct` and `total` are not referenced again in the visible cells.
correct, total = 0, 0

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of (masked) SMILES strings and their spectra."""

    def __init__(self, smiles, spectra):
        self.smiles, self.spectra = smiles, spectra

    def __len__(self):
        # The SMILES list defines the dataset size.
        return len(self.smiles)

    def __getitem__(self, idx):
        # Return the (smiles, spectra) pair stored at position `idx`.
        return (self.smiles[idx], self.spectra[idx])

class TestCollator:
    """Collate (smiles, spectrum) samples into model inputs on the global `device`.

    SMILES are tokenised with padding/truncation to 256 tokens; spectra are
    stacked into one float32 tensor with a singleton axis inserted at dim 1.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        smiles_batch, spectra_batch = zip(*batch)
        encoded = self.tokenizer(list(smiles_batch), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        return {
            'smiles': {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')},
            'spectra': torch.as_tensor(np.array(spectra_batch), dtype=torch.float32).unsqueeze(1).to(device),
        }

test_dataset = TestDataset(masked_smiles, masked_spectra)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_tokens_logits = model.infer(batch)
    # Most likely token at every position.
    pred_tokens = torch.argmax(pred_tokens_logits, dim=-1)

    # NOTE(review): the whole predicted sequence is decoded here, whereas the ViBench MLM
    # loop only overwrites the <mask> positions of the input — confirm this is intended.
    preds = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
    all_pred_smiles.extend(preds)


100%|██████████| 60/60 [00:16<00:00,  3.57it/s]


In [44]:
import numpy as np
from tqdm import trange

tgt_fgs = []      # target text of every masked fragment
pred_fgs = []     # predicted text of every masked fragment
tgt_total = {}    # fragment text -> number of times it was masked
tgt_correct = {}  # fragment text -> number of exact reconstructions

for i in trange(len(raw_smiles)):

    # Token ids of the masked input, the ground truth and the reconstruction.
    # NOTE(review): assumes the three tokenisations are position-aligned — TODO confirm.
    masked_fg = np.array(tokenizer(masked_smiles[i])['input_ids'])
    tgt_fg = np.array(tokenizer(raw_smiles[i])['input_ids'])
    pred_fg = np.array(tokenizer(all_pred_smiles[i])['input_ids'])[:len(masked_fg)]

    # Positions of all <mask> tokens (token id 4).
    indices_of_mask = np.where(masked_fg == 4)[0]

    # Group the mask positions into runs of consecutive indices; each run is one fragment.
    consecutive_indices = []
    current_list = []

    # Use `j` here: reusing `i` would clobber the sample index of the outer loop.
    for j in range(len(indices_of_mask)):
        if j == 0 or indices_of_mask[j] == indices_of_mask[j - 1] + 1:
            current_list.append(indices_of_mask[j])
        else:
            if current_list:
                consecutive_indices.append(current_list)
            current_list = [indices_of_mask[j]]
    if current_list:
        consecutive_indices.append(current_list)


    for mask in consecutive_indices:
        # A single-token fragment is decoded from a scalar index.
        if len(mask) == 1: mask = mask[0]
        tgt_fg_str = tokenizer.decode(tgt_fg[mask])
        pred_fg_str = tokenizer.decode(pred_fg[mask])
        tgt_fgs.append(tgt_fg_str)
        pred_fgs.append(pred_fg_str)

        tgt_total[tgt_fg_str] = tgt_total.get(tgt_fg_str, 0) + 1
        if pred_fg_str == tgt_fg_str:
            tgt_correct[tgt_fg_str] = tgt_correct.get(tgt_fg_str, 0) + 1

100%|██████████| 3819/3819 [00:01<00:00, 2124.35it/s]


In [45]:
import pandas as pd

# Per-fragment occurrence counts and exact-match counts.
df1 = pd.DataFrame({'fg':tgt_total.keys(), 'count':tgt_total.values()})
df2 = pd.DataFrame({'fg':tgt_correct.keys(), 'correct':tgt_correct.values()})
# Left join keeps fragments that were never reconstructed; their 'correct' is NaN ...
df = pd.merge(df1, df2, on='fg', how='left')
# ... and is set to 0 here.
df[df.isna()] = 0
df['accuracy'] = df['correct'] / df['count']
# The overall accuracy is micro-averaged over all masked fragments.
print(f"unique: {len(df)} | total: {df['count'].sum()} | correct: {df['correct'].sum()} | accuracy: {df['correct'].sum() / df['count'].sum()}")

unique: 506 | total: 5035 | correct: 3502.0 | accuracy: 0.6955312810327706


# IR-Raman

In [None]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase
from utils.base import seed_everything

# Fix the random seed (presumably seeds all RNGs — see utils.base.seed_everything).
seed_everything(624)
device = 'cuda:0'

# Two-channel (Raman + IR) 'vib2mol_phase' model.
model = build_model('vib2mol_phase', spectral_channel=2).to(device)
# NOTE(review): absolute, machine-specific checkpoint path — not portable to other hosts.
ckpt = torch.load('/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/checkpoints/mols_custom/raman-ir-kekule_smiles/vib2mol_phase/2025-01-01_04:18/epoch990_acc76.pth', 
                  map_location=device, weights_only=True)

# model = build_model('vib2mol_cl_lm').to(device)
# ckpt = torch.load('/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/checkpoints/mols_custom/raman-kekule_smiles/vib2mol_cl_lm/2024-12-03_13:38/epoch952_recall74.pth', 
#                   map_location=device, weights_only=True)

# Remove the 'module.' substring from every key
# (presumably the prefix added by DataParallel/DDP wrappers — TODO confirm).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

<All keys matched successfully>

# 1 Evaluate contrastive retrieval

In [13]:
from utils.base import Dataloader
from utils.pretrain import PretrainCollator, Engine

# Loader over the 'mols' LMDB using Raman + IR spectra and Kekulé SMILES.
# NOTE(review): the recorded recall for this section is close to zero (recall@1 = 0.00005),
# so this configuration looks broken. Things to check: the directory is spelled 'vibbench'
# here but 'vibench' in the first section; spectral_types is ['raman','ir'] here but
# ['ir', 'raman'] in the first section; and this cell uses `root_dir=` with the
# utils.base / utils.pretrain imports instead of `data_dir=`.
dataloader = Dataloader(lmdb_path='mols', 
                            root_dir='/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/datasets/vibbench', 
                            # target_keys=['exp_ir', 'kekule_smiles'], 
                            # collate_fn=PretrainCollator(spectral_types=['exp_ir'], tokenizer_path='/data/xinyulu/vib2mol/models/RoBERTa'), 
                            target_keys=['raman', 'ir', 'kekule_smiles'], 
                            collate_fn=PretrainCollator(spectral_types=['raman','ir'], tokenizer_path='/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/luxinyu-240207020178/vib2mol/models/MolTokenizer'), 
                            device=device)

test_loader = dataloader.generate_dataloader(mode='test', batch_size=64)

In [14]:
# Run inference over the test split; `out` is later indexed with
# 'spectral_proj_output' and 'molecular_proj_output'.
engine = Engine(test_loader=test_loader, model=model, device=device, device_rank=0)
out = engine.infer()

100%|██████████| 294/294 [00:19<00:00, 15.02it/s]


In [15]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Return the cosine-similarity matrix between query rows and key rows.

    Both inputs are L2-normalised along dim 1, so entry ``[i, j]`` of the result
    is the cosine similarity between query ``i`` and key ``j``.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, verbose=False):
    """Compute recall@k for an index-aligned query/key similarity matrix.

    Query ``i`` counts as a hit when key ``i`` is among the ``k`` highest-scoring
    keys of row ``i``.

    Args:
        similarity_matrix: 2-D tensor of shape (num_queries, num_keys).
        k: number of top-scoring keys inspected per query. Values larger than
            ``num_keys`` are clamped instead of raising inside ``topk``.
        verbose: if True, print ``recall@k`` and return None; otherwise return it.

    Returns:
        The recall as a float in [0, 1] (0.0 for an empty matrix), or None when
        ``verbose`` is True.
    """
    num_queries, num_keys = similarity_matrix.shape
    if num_queries == 0:
        # Nothing to retrieve; avoid dividing by zero.
        recall_at_k = 0.0
    else:
        _, topk_indices = similarity_matrix.topk(min(k, num_keys), dim=1, largest=True, sorted=True)
        # Row i is correct when index i appears anywhere in its top-k list;
        # a single vectorised comparison replaces the per-row Python loop.
        targets = torch.arange(num_queries, device=similarity_matrix.device).unsqueeze(1)
        correct = int((topk_indices == targets).any(dim=1).sum())
        recall_at_k = correct / num_queries

    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    else:
        return recall_at_k

# Spectrum -> molecule retrieval: rows are spectral embeddings, columns molecular ones.
similarity_matrix = calculate_similarity_matrix(out['spectral_proj_output'], out['molecular_proj_output'])
# Print recall at each cut-off.
for k in (1, 3, 5, 10, 100):
    compute_recall(similarity_matrix, k=k, verbose=True)

recall@1:0.00005
recall@3:0.00027
recall@5:0.00032
recall@10:0.00059
recall@100:0.00511


## MLM

In [None]:
import lmdb
# Open the LMDB read-only: in the default writable mode a mistyped path can create a
# new, empty database instead of failing. The environment is closed on leaving the
# `with` block, so `db` must not be used afterwards.
with lmdb.open('../datasets/nist_ir/nist_ir_test.lmdb', subdir=False, lock=False,
               readonly=True, map_size=int(1e11)) as db:
    # Copy every (key, value) record out within a single read transaction.
    with db.begin() as txn:
        test_data = list(txn.cursor())

In [None]:
import pickle
import multiprocessing as mp 
from tqdm import tqdm 

def get_smiles(idx):
    """Return the 'kekule_smiles' field of record ``idx`` in ``test_data``.

    Best effort: returns None when the record is missing, cannot be unpickled or
    has no 'kekule_smiles' entry. KeyboardInterrupt and SystemExit are no longer
    swallowed, so a worker pool can still be interrupted.

    NOTE(review): pickle.loads can execute arbitrary code — only read trusted data.
    """
    try:
        smiles = pickle.loads(test_data[idx][1])['kekule_smiles']
        return smiles
    except Exception:
        return None

def get_raman(idx):
    """Unpickle test record ``idx`` and return its 'raman' entry.

    Unlike ``get_smiles``, errors are not caught: a missing key or corrupt
    record propagates to the caller.
    """
    return pickle.loads(test_data[idx][1])['raman']

def get_ir(idx):
    """Unpickle test record ``idx`` and return its 'exp_ir' entry.

    Errors are not caught: a missing key or corrupt record propagates to the
    caller.
    """
    return pickle.loads(test_data[idx][1])['exp_ir']

# A single pool of 4 workers serves all three extraction passes; starting and
# tearing down a fresh pool per field only adds process start-up cost.
# imap preserves input order, so the three lists stay aligned by record index.
num_records = len(test_data)
with mp.Pool(4) as pool:
    smiles = list(tqdm(pool.imap(get_smiles, range(num_records)), total=num_records))
    ir = list(tqdm(pool.imap(get_ir, range(num_records)), total=num_records))
    raman = list(tqdm(pool.imap(get_raman, range(num_records)), total=num_records))
    
# Load the tokenizer from the local '../models/RoBERTa' directory; the cells
# below use it to tokenize and decode SMILES strings.
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained('../models/RoBERTa')


100%|██████████| 2588/2588 [00:00<00:00, 22656.11it/s]
100%|██████████| 2588/2588 [00:00<00:00, 12259.45it/s]
100%|██████████| 2588/2588 [00:00<00:00, 16320.57it/s]


In [52]:
def mask_manually(input_ids, masked_indices, pad_token_id=1, mask_token_id=4):
    """Mask caller-chosen token positions in every sequence of a batch.

    Args:
        input_ids: integer tensor of token ids indexed as [row, position].
            It is not modified; a masked copy is returned.
        masked_indices: sequence of integer positions to mask. Each position
            is shifted by +1 before indexing (presumably to skip a leading
            start-of-sequence token; confirm against the tokenizer). May be
            empty, in which case nothing is masked.
        pad_token_id: id of the padding token; padded positions are never
            masked. Defaults to 1 (the value noted for this notebook's tokenizer).
        mask_token_id: id written into masked positions. Defaults to 4.

    Returns:
        ``(masked_ids, mask)``: the masked copy of ``input_ids`` and a boolean
        tensor of the same shape marking the positions that were replaced.
    """
    mask = torch.zeros_like(input_ids).bool()
    # An explicit integer dtype keeps an empty index list valid: torch.tensor([])
    # would be a float tensor, which cannot be used for indexing.
    positions = torch.as_tensor(masked_indices, dtype=torch.long, device=input_ids.device) + 1
    mask[:, positions] = True
    # Never mask padding, even if a requested position lands on it.
    mask[input_ids == pad_token_id] = False
    masked_ids = input_ids.clone()
    masked_ids[mask] = mask_token_id
    return masked_ids, mask

def generate_mlmmask(input_ids, mask_prob=0.15, pad_token_id=1, mask_token_id=4):
    """Randomly mask tokens for masked-language-model evaluation.

    Every non-padding position is masked independently with probability
    ``mask_prob``. Only padding is protected; any other token, including
    special tokens, may be selected.

    Args:
        input_ids: integer tensor of token ids. It is not modified.
        mask_prob: per-token masking probability in [0, 1].
        pad_token_id: id of the padding token, never masked. Defaults to 1
            (the value noted for this notebook's tokenizer).
        mask_token_id: id written into masked positions. Defaults to 4.

    Returns:
        ``(masked_ids, masked_indices)``: the masked copy of ``input_ids`` and
        a boolean tensor of the same shape marking the replaced positions.
    """
    masked_ids = input_ids.clone()
    # Build the probabilities as floats on the input's device so the boolean
    # mask can index `masked_ids` directly, and so an integer mask_prob
    # (e.g. 1) does not produce an integer tensor that bernoulli rejects.
    probability_matrix = torch.full(
        masked_ids.shape, mask_prob, dtype=torch.float, device=masked_ids.device
    )
    masked_indices = torch.bernoulli(probability_matrix).bool()
    masked_indices[masked_ids == pad_token_id] = False
    masked_ids[masked_indices] = mask_token_id

    return masked_ids, masked_indices

## 2.4 evaluate molecular accuracy

In [55]:
# Pair the Raman and IR spectra of each record along a new axis 1, so that
# spectra[i] holds both spectra of sample i.
spectra = np.stack([raman, ir], axis=1)

raw_smiles = []       # original SMILES of each kept sample
masked_smiles = []    # same SMILES with one branch replaced by <mask> tokens
masked_spectra = []   # spectra belonging to each kept sample
branches_list = []    # the branch text that was masked out
masked_smi_dict = {}  # masked SMILES already emitted (used as a set for de-duplication)

# None keeps every branch; a float keeps only branches whose length lies in
# (len(smi) * mask_prob, len(smi) * (mask_prob + 0.15)].
mask_prob = None
for smi, spec in tqdm(zip(smiles, spectra), total=len(smiles)):
    if smi is None:
        # get_smiles returns None for records it could not read; there is
        # nothing to mask for those.
        continue

    branches = extract_branches(smi)
    for branch in branches:
        if mask_prob is not None and not (
            len(smi) * mask_prob < len(branch) and len(branch) <= len(smi) * (mask_prob + 0.15)
        ):
            continue

        # One <mask> per branch token; the -2 presumably drops the start/end
        # tokens the tokenizer adds (confirm against the tokenizer).
        len_token = len(tokenizer(branch)['input_ids'])-2
        # str.replace substitutes every "(branch)" occurrence in the SMILES.
        masked_smi = smi.replace(f"({branch})", f"({'<mask>'*len_token})")
        if masked_smi in masked_smi_dict:
            continue

        branches_list.append(branch)
        masked_smi_dict[masked_smi] = 1
        masked_smiles.append(masked_smi)
        masked_spectra.append(spec)
        raw_smiles.append(smi)

  0%|          | 0/2588 [00:00<?, ?it/s]

100%|██████████| 2588/2588 [00:00<00:00, 7975.33it/s]


In [56]:
# Sanity check: one branch is recorded per kept sample, so both lengths should match.
(len(raw_smiles), len(branches_list))

(3819, 3819)

In [57]:
import numpy as np
# Decoded predictions, extended batch by batch by the inference loop in this cell.
all_pred_smiles = []
# NOTE(review): `correct` and `total` are never updated in this cell; they
# look like leftovers. Confirm nothing else reads them before removing.
correct, total = 0, 0

class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each SMILES string with its spectrum by position."""

    def __init__(self, smiles, spectra):
        # Both sequences are stored by reference and indexed in parallel.
        self.smiles, self.spectra = smiles, spectra

    def __len__(self):
        """Number of samples, taken from the SMILES sequence."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return the ``(smiles, spectrum)`` pair stored at position ``idx``."""
        pair = (self.smiles[idx], self.spectra[idx])
        return pair

class TestCollator:
    """Collate ``(smiles, spectrum)`` pairs into the batch dict passed to ``model.infer``.

    Args:
        tokenizer: callable tokenizer that accepts a list of strings with
            ``return_tensors``, ``padding``, ``max_length`` and ``truncation``
            and returns a mapping with 'input_ids' and 'attention_mask'.
        max_length: length every tokenized sequence is padded or truncated to.
            Defaults to 256.
        device: device the batch tensors are moved to. When None, the
            notebook-level ``device`` variable is used at call time.
    """

    def __init__(self, tokenizer, max_length=256, device=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device

    def __call__(self, batch):
        """Build ``{'smiles': {'input_ids', 'attention_mask'}, 'spectra'}`` from a list of pairs."""
        # Fall back to the notebook-global `device` only when none was given.
        target_device = device if self.device is None else self.device

        smiles, spectra = zip(*batch)
        # Stack through numpy first so a tuple of arrays becomes one float32 tensor.
        spectra = torch.as_tensor(np.array(spectra), dtype=torch.float32).to(target_device)
        encoded = self.tokenizer(list(smiles), return_tensors='pt', padding='max_length',
                                 max_length=self.max_length, truncation=True)
        input_ids = {'input_ids': encoded['input_ids'].to(target_device),
                     'attention_mask': encoded['attention_mask'].to(target_device)}
        return {'smiles': input_ids,  'spectra':spectra}

# Wrap the masked SMILES and their spectra in a loader of 64-sample batches.
test_dataset = TestDataset(masked_smiles, masked_spectra)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    collate_fn=test_collator,
)
test_bar = tqdm(test_loader)

# Predict a token for every position, decode the sequences back to text
# (special tokens stripped) and collect them in input order.
model.eval()
with torch.no_grad():
    for batch in test_bar:
        pred_tokens_logits = model.infer(batch)
        pred_tokens = pred_tokens_logits.argmax(dim=-1)
        decoded_batch = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
        all_pred_smiles.extend(decoded_batch)


100%|██████████| 60/60 [00:16<00:00,  3.55it/s]


In [None]:
import numpy as np
from tqdm import trange

def _group_consecutive(indices):
    """Split an ascending sequence of positions into runs of adjacent values."""
    runs = []
    for position in indices:
        if runs and position == runs[-1][-1] + 1:
            runs[-1].append(position)
        else:
            runs.append([position])
    return runs


tgt_fgs = []      # ground-truth text of every masked span
pred_fgs = []     # predicted text of every masked span
tgt_total = {}    # ground-truth span text -> number of occurrences
tgt_correct = {}  # ground-truth span text -> number of exact-match predictions

for sample_idx in trange(len(raw_smiles)):

    masked_fg = np.array(tokenizer(masked_smiles[sample_idx])['input_ids'])
    tgt_fg = np.array(tokenizer(raw_smiles[sample_idx])['input_ids'])
    # The prediction is re-tokenized and cut to the masked sequence's length.
    # NOTE(review): if it re-tokenizes to fewer tokens than the masked
    # sequence, the span indexing below raises IndexError.
    pred_fg = np.array(tokenizer(all_pred_smiles[sample_idx])['input_ids'])[:len(masked_fg)]

    # 4 is the mask token id used throughout this notebook.
    indices_of_mask = np.where(masked_fg == 4)[0]

    # Each run of adjacent <mask> positions is scored as one span.
    for mask in _group_consecutive(indices_of_mask):
        # A single-position span is indexed with a scalar, a longer one with a list.
        if len(mask) == 1: mask = mask[0]
        tgt_fg_str = tokenizer.decode(tgt_fg[mask])
        pred_fg_str = tokenizer.decode(pred_fg[mask])
        tgt_fgs.append(tgt_fg_str)
        pred_fgs.append(pred_fg_str)

        tgt_total[tgt_fg_str] = tgt_total.get(tgt_fg_str, 0) + 1
        if pred_fg_str == tgt_fg_str:
            tgt_correct[tgt_fg_str] = tgt_correct.get(tgt_fg_str, 0) + 1

100%|██████████| 3819/3819 [00:01<00:00, 2121.59it/s]


In [59]:
import pandas as pd

# Per-span tallies: how often each ground-truth span occurs and how often it
# was predicted exactly.
df1 = pd.DataFrame({'fg': list(tgt_total.keys()), 'count': list(tgt_total.values())})
df2 = pd.DataFrame({'fg': list(tgt_correct.keys()), 'correct': list(tgt_correct.values())})
df = pd.merge(df1, df2, on='fg', how='left')
# Spans that were never predicted correctly have no row in df2, so the left
# merge leaves NaN there (and turns the column into floats). Fill only that
# column and restore integer counts.
df['correct'] = df['correct'].fillna(0).astype(int)
df['accuracy'] = df['correct'] / df['count']
print(f"unique: {len(df)} | total: {df['count'].sum()} | correct: {df['correct'].sum()} | accuracy: {df['correct'].sum() / df['count'].sum()}")

unique: 506 | total: 5035 | correct: 3750.0 | accuracy: 0.7447864945382324
