In [1]:
import lmdb

# Read every (key, value) record of the GEOM test split into memory in one pass.
# readonly=True: this notebook only reads; a missing/mistyped path should raise instead of
# silently creating a brand-new empty database file. The environment is closed as soon as
# the records are copied out (with begin()'s default buffers=False the cursor yields bytes
# copies, so geom_data stays valid after close).
db = lmdb.open('../datasets/vibench/geom/geom_test.lmdb', subdir=False, lock=False,
               readonly=True, map_size=int(1e11))
try:
    with db.begin() as txn:
        geom_data = list(txn.cursor())
finally:
    db.close()

In [2]:
import pickle
import multiprocessing as mp 
from tqdm import tqdm 
import pandas as pd

# Every LMDB value is one pickled record; decode each into a DataFrame row.
# NOTE(review): pickle.loads can execute arbitrary code — only use trusted dataset files.
test_df = pd.DataFrame([pickle.loads(value) for _key, value in tqdm(geom_data)])

# SMILES strings plus the query ('q_raman') and key ('k_raman') spectra, row-aligned.
smiles, query_spectra, key_spectra = (
    test_df[column].to_list() for column in ('kekule_smiles', 'q_raman', 'k_raman')
)

from transformers import AutoTokenizer
# NOTE(review): the tokenizer is not referenced in the cells visible here.
tokenizer = AutoTokenizer.from_pretrained('../models/MolTokenizer')


100%|██████████| 5659/5659 [00:00<00:00, 77666.45it/s]
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
n_query, n_key = len(query_spectra), len(key_spectra)
# Recall is scored by treating query i and key i as the matching pair, so both lists
# must align one-to-one; fail loudly here rather than report a meaningless recall later.
assert n_query == n_key, f'query/key count mismatch: {n_query} vs {n_key}'
n_query, n_key

(5659, 5659)

In [4]:
import numpy as np
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Pairwise cosine similarity between every query row and every key row.

    Args:
        embedding_query: 2-D array-like or tensor of shape (num_queries, dim).
        embedding_key: 2-D array-like or tensor of shape (num_keys, dim).

    Returns:
        Tensor of shape (num_queries, num_keys); entry (i, j) is the cosine
        similarity between query i and key j. Its dtype is the promotion of the
        two input dtypes.
    """
    # as_tensor reuses tensor inputs instead of copy-constructing them, which is what
    # torch.tensor() does (and warns about) when handed a tensor.
    embedding_query = torch.as_tensor(embedding_query)
    embedding_key = torch.as_tensor(embedding_key)

    # matmul requires matching dtypes; promote so e.g. float32 vs float64 inputs work.
    common_dtype = torch.promote_types(embedding_query.dtype, embedding_key.dtype)
    embedding_query = F.normalize(embedding_query.to(common_dtype), p=2, dim=1)
    embedding_key = F.normalize(embedding_key.to(common_dtype), p=2, dim=1)

    # Rows are unit-length after normalization, so the dot product is the cosine.
    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix

def compute_recall(similarity_matrix, k, verbose=False):
    """Recall@k for paired retrieval, where query ``i`` should retrieve key ``i``.

    A query counts as a hit when its own index is among the ``k`` highest-scoring
    keys in its row of the similarity matrix.

    Args:
        similarity_matrix: 2-D tensor or array-like of shape (num_queries, num_keys);
            row ``i`` holds query ``i``'s scores against every key.
        k: number of top-scoring keys inspected per query. Values larger than
            ``num_keys`` are clamped to ``num_keys``, in which case every key counts
            as retrieved.
        verbose: if True, also print ``recall@{k}:<value>`` with 5 decimals.

    Returns:
        Fraction of queries whose matching key is in their top-k, as a float.
        The value is returned whether or not ``verbose`` is set.

    Raises:
        ValueError: if the matrix has no query rows (recall is undefined).
    """
    scores = torch.as_tensor(similarity_matrix)
    num_queries, num_keys = scores.shape
    if num_queries == 0:
        raise ValueError('similarity_matrix has no query rows; recall is undefined')

    # topk raises when k exceeds the row length, so clamp to the number of keys.
    _, topk_indices = scores.topk(min(k, num_keys), dim=1, largest=True, sorted=True)
    # Vectorized membership test: does row i's top-k contain index i?
    targets = torch.arange(num_queries, device=topk_indices.device).unsqueeze(1)
    hits = (topk_indices == targets).any(dim=1)
    recall_at_k = hits.sum().item() / num_queries

    if verbose:
        print(f'recall@{k}:{recall_at_k:.5f}')
    return recall_at_k

# Baseline without any model: cosine similarity between the raw query and key spectra.
# (np.corrcoef-based Pearson correlation is an alternative scorer; compute_recall needs a
# tensor, so wrap that result with torch.as_tensor if you switch to it.)
# calculate_similarity_matrix already returns a tensor, so it is not re-wrapped here —
# torch.tensor() on a tensor makes a redundant copy and emits a UserWarning.
similarity_matrix = calculate_similarity_matrix(np.array(query_spectra), np.array(key_spectra))
for top_k in (1, 3, 5, 10):
    compute_recall(similarity_matrix, k=top_k, verbose=True)

recall@1:0.34670
recall@3:0.43753
recall@5:0.47623
recall@10:0.53455


  similarity_matrix = torch.tensor(similarity_matrix)


In [5]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase
from utils.base import seed_everything

seed_everything(624)
device = 'cpu' if not torch.cuda.is_available() else 'cuda:0'

model = build_model('vib2mol_phase').to(device)
ckpt = torch.load('../checkpoints/mols/raman-kekule_smiles/vib2mol_phase.pth', 
                  map_location=device, weights_only=True)

# Checkpoint keys may carry a leading 'module.' (presumably from a DataParallel/DDP-wrapped
# model at save time). Strip only that leading prefix: str.replace would also rewrite any
# 'module.' substring appearing later in a parameter name.
ckpt = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

<All keys matched successfully>

In [6]:
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves stored spectra by position, unmodified."""

    def __init__(self, tgt_spectra):
        # Any sized, indexable sequence works; items are handed back as-is.
        self.tgt_spectra = tgt_spectra

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

    def __len__(self):
        n_items = len(self.tgt_spectra)
        return n_items

class TestCollator:
    """Collate a list of spectra into the ``{'spectra': tensor}`` batch the model consumes."""

    def __call__(self, batch):
        # Stack into one float32 array, then insert a singleton axis at position 1
        # (presumably a channel axis: (batch, length) -> (batch, 1, length) for 1-D spectra).
        stacked = torch.as_tensor(np.array(batch), dtype=torch.float32)
        spectra = stacked.unsqueeze(1)
        # Moved to the module-level `device` here, so the loop can pass batches straight in.
        return {'spectra': spectra.to(device)}


# Embed queries and keys in a single pass. The loader does not shuffle, so the first
# len(query_spectra) embedding rows belong to the queries and the remainder to the keys.
test_dataset = TestDataset(query_spectra + key_spectra)
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

spectral_embeddings = []

model.eval()
# One no_grad context around the whole loop instead of re-entering it for every batch.
with torch.no_grad():
    for batch in test_bar:
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        spectral_embeddings.append(spectral_embedding.cpu().numpy())

spectral_embeddings = np.vstack(spectral_embeddings)
# The query/key split below is only valid with exactly one embedding row per input spectrum.
assert spectral_embeddings.shape[0] == len(test_dataset), (
    f'expected {len(test_dataset)} embeddings, got {spectral_embeddings.shape[0]}')
query_spectral_embeddings = spectral_embeddings[:len(query_spectra)]
key_spectral_embeddings = spectral_embeddings[len(query_spectra):]

100%|██████████| 177/177 [02:34<00:00,  1.15it/s]


In [7]:
# Retrieval with the learned projections: cosine similarity between query and key embeddings,
# scored at the same cut-offs as the raw-spectrum baseline.
similarity_matrix = calculate_similarity_matrix(query_spectral_embeddings, key_spectral_embeddings)

for top_k in (1, 3, 5, 10):
    compute_recall(similarity_matrix, k=top_k, verbose=True)

recall@1:0.76462
recall@3:0.91005
recall@5:0.94487
recall@10:0.97473
