# 1 retrieval

In [None]:
import torch
import lmdb
import pickle
from tqdm.auto import tqdm 
import numpy as np
import pandas as pd

# Read every record of the peptide test split into memory.
# The environment is opened read-only so running this cell can never create or modify the dataset file.
db = lmdb.open('../datasets/vibench/peptide/peptide_test.lmdb', subdir=False, readonly=True, lock=False, map_size=int(1e11))
try:
    with db.begin() as txn:
        # Iterating the cursor yields (key, value) pairs; materialise them before the transaction ends.
        test_data = list(txn.cursor())
finally:
    # Release the file handle / memory map as soon as the records have been copied out.
    db.close()

# NOTE(security): pickle.loads can execute code embedded in the payload - only load trusted dataset files.
test_df = pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])
# Device used by the collators and models defined in the following cells.
device = 'cpu'

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of (spectrum, text) served one item at a time.

    Args:
        tgt_spectra: indexable collection of spectra; its length defines the dataset size.
        pred_smiles: indexable collection of strings (SMILES or sequences), read with the same index.
    """

    def __init__(self, tgt_spectra, pred_smiles):
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        # Only the spectra are consulted for the size.
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        text = self.pred_smiles[idx]
        return spectrum, text

class TestCollator:
    """Collate (spectrum, text) pairs into the batch dict consumed by the model.

    Args:
        tokenizer: callable applied to the list of texts; its result must expose
            'input_ids' and 'attention_mask' entries that support `.to(device)`.

    Tensors are moved to the notebook-level `device` variable.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        tgt_spectra, pred_smiles = zip(*batch)
        # Stack the spectra and insert a singleton axis at position 1.
        stacked = np.array(tgt_spectra)
        spectra = torch.as_tensor(stacked, dtype=torch.float32).unsqueeze(1).to(device)
        # Every text is padded / truncated to exactly 256 tokens.
        encoded = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        sequence = {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}
        return {'sequence': sequence, 'spectra': spectra}

  test_df = pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])
100%|██████████| 5191/5191 [00:00<00:00, 99729.44it/s]


## smiles

In [None]:
import sys 
sys.path.append('../')
import torch

from models import build_model
from models import PretrainModel_Phase

# Seed torch's RNG before the model is built.
torch.manual_seed(624)
device = 'cpu'

# Build the 'vib2mol_phase' architecture and load the Raman -> kekulized-SMILES checkpoint.
model = build_model('vib2mol_phase').to(device)
ckpt_path = '../checkpoints/peptide/raman-kekule_smiles/vib2mol_phase.pth'
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

# Remove every 'module.' substring from the parameter names
# (presumably the prefix left by a DataParallel/DDP wrapper - confirm against the training script).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

from transformers import AutoTokenizer
# Plain string literal: the path has no placeholders, so no f-string is needed.
tokenizer_path = '../models/MolTokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Pair every Raman spectrum with the kekulized SMILES from the same row.
test_dataset = TestDataset(test_df.raman.to_list(), test_df.kekule_smiles.to_list())


## sequence

In [9]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase

# Seed torch's RNG before the model is built.
torch.manual_seed(624)
device = 'cpu'

# Build the 'vib2mol_phase' architecture and load the Raman -> sequence checkpoint.
model = build_model('vib2mol_phase').to(device)
ckpt_path = '../checkpoints/peptide/raman-sequence/vib2mol_phase.pth'
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

# Remove every 'module.' substring from the parameter names
# (presumably the prefix left by a DataParallel/DDP wrapper - confirm against the training script).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

from transformers import AutoTokenizer
tokenizer_path = '../models/PepTokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Pair every Raman spectrum with the sequence string from the same row.
# This rebinds `model`, `tokenizer` and `test_dataset` set by the SMILES cell.
test_dataset = TestDataset(test_df.raman.to_list(), test_df.sequence.to_list())

## calculate similarity between molecules and spectra

In [10]:
# Batch the (spectrum, text) pairs; the collator tokenizes the text and builds the tensors.
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
# One tensor per batch is appended to each list; they are concatenated in a later cell.
all_molecular_embeddings = []
all_spectral_embeddings = []

for batch in test_bar:
    # Inference only: no gradients are tracked.
    with torch.no_grad():
        # Both calls return a dict; only its 'proj_output' entry is kept.
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        
        all_molecular_embeddings.append(molecular_embedding)
        all_spectral_embeddings.append(spectral_embedding)

100%|██████████| 82/82 [02:23<00:00,  1.76s/it]


In [11]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Args:
        embedding_query: 2-D tensor, one embedding per row.
        embedding_key: 2-D tensor, one embedding per row.

    Returns:
        Tensor whose entry [i, j] is the dot product of the L2-normalised
        query row i and key row j.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, return_result=False):
    """Print recall@k for a query-by-key similarity matrix and optionally return per-query hits.

    Query i counts as a hit when column index i is among the k highest-scoring
    columns of row i, i.e. the matching key of query i is expected at column i.

    Args:
        similarity_matrix: 2-D tensor of scores, rows are queries and columns are keys.
        k: number of top-scoring keys inspected per query. Values larger than the
            number of keys are clamped, since `topk` would otherwise raise.
        return_result: when True, return the per-query hit list.

    Returns:
        A list with one 0/1 int per query when `return_result` is True, otherwise None.
    """
    num_queries = similarity_matrix.size(0)
    effective_k = min(k, similarity_matrix.size(1))
    _, topk_indices = similarity_matrix.topk(effective_k, dim=1, largest=True, sorted=True)

    # Compare each row of retrieved indices with the row's own index in one vectorised step.
    own_index = torch.arange(num_queries, device=topk_indices.device).unsqueeze(1)
    correct_list = (topk_indices == own_index).any(dim=1).int().tolist()
    recall_at_k = sum(correct_list) / num_queries

    print(f'recall@{k}:{recall_at_k:.5f}')
    if return_result:
        return correct_list
    return None

# Merge the per-batch tensors collected by the embedding loop.
# NOTE(review): the list names are rebound to tensors, so this cell is meant to run once
# after the embedding loop; re-run that loop before executing this cell again.
all_molecular_embeddings = torch.cat(all_molecular_embeddings)
all_spectral_embeddings = torch.cat(all_spectral_embeddings)

# Spectra are the queries (rows) and molecules the keys (columns).
similarity_matrix = calculate_similarity_matrix(all_spectral_embeddings, all_molecular_embeddings)
# Per-query hit lists are kept for k=1 and k=3; k=5 and k=10 are only printed.
top1 = compute_recall(similarity_matrix, k=1, return_result=True)
top3 = compute_recall(similarity_matrix, k=3, return_result=True)
compute_recall(similarity_matrix, k=5, return_result=False)
compute_recall(similarity_matrix, k=10, return_result=False)

recall@1:0.67579
recall@3:0.89886
recall@5:0.94722
recall@10:0.97765


In [12]:
# Work on an explicit copy: assigning columns to the slice `test_df[:]` is what
# produced the SettingWithCopyWarning shown in this cell's saved output.
df = test_df.copy()
df['top_1'] = top1
df['top_3'] = top3
# Number of '-'-separated residues in each sequence string.
df['length'] = df['sequence'].str.split('-').str.len()
# Average the two hit columns per peptide length. Select them explicitly instead of
# passing a column name as `mean`'s first positional argument (which is `numeric_only`).
df.groupby('length')[['top_1', 'top_3']].mean()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['top_1'] = top1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['top_3'] = top3
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['length'] = df.apply(lambda row: len(row['sequence'].split('-')), axis=1)


Unnamed: 0_level_0,top_1,top_3
length,Unnamed: 1_level_1,Unnamed: 2_level_1
2,0.947368,1.0
3,0.775586,0.944513
4,0.654767,0.889452


# 2 de novo generation

In [None]:
import lmdb
import pickle
from tqdm.auto import tqdm 
import pandas as pd

# Open the peptide test split read-only so running this cell can never create or modify the file.
db = lmdb.open('../datasets/vibench/peptide/peptide_test.lmdb', subdir=False, readonly=True, lock=False, map_size=int(1e11))

# Open a transaction and perform a read operation
try:
    with db.begin() as txn:
        # Iterating the cursor yields (key, value) pairs; materialise them before the transaction ends.
        test_data = list(txn.cursor())
finally:
    # Release the file handle / memory map as soon as the records have been copied out.
    db.close()

# NOTE(security): pickle.loads can execute code embedded in the payload - only load trusted dataset files.
test_df = pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])

  test_df = pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])
100%|██████████| 5191/5191 [00:00<00:00, 135771.77it/s]


In [14]:
# Preview the first rows of the freshly loaded test table.
test_df.head()

Unnamed: 0,smiles,norm_smiles,kekule_smiles,raman,ir,filename,sequence
0,CC[C@H](C)[C@H](NC(=O)[C@H](Cc1ccc(O)cc1)NC(=O...,CCC(C)C(NC(=O)C(Cc1ccc(O)cc1)NC(=O)C(Cc1ccccc1...,CCC(C)C(NC(=O)C(CC1=CC=C(O)C=C1)NC(=O)C(CC1=CC...,"[0.031817872173130125, 0.04253405263639632, 0....","[0.008984905034481022, 0.012997145252672785, 0...",G-F-Y-I,G-F-Y-I
1,C[C@H](NC(=O)[C@@H](N)Cc1ccccc1)C(=O)N[C@@H](C...,CC(NC(=O)C(N)Cc1ccccc1)C(=O)NC(Cc1ccccc1)C(=O)...,CC(NC(=O)C(N)CC1=CC=CC=C1)C(=O)NC(CC1=CC=CC=C1...,"[0.03639097157542858, 0.03839822624049318, 0.0...","[0.019057326865201937, 0.02214613556867351, 0....",F-A-F-H,F-A-F-H
2,CSCC[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH...,CSCCC(NC(=O)C(C)NC(=O)C(N)Cc1c[nH]cn1)C(=O)NC(...,CSCCC(NC(=O)C(C)NC(=O)C(N)CC1=CNC=N1)C(=O)NC(C...,"[0.010788252706046516, 0.012695604522622993, 0...","[0.01820227941387322, 0.019182739075950875, 0....",H-A-M-V,H-A-M-V
3,CC(C)C[C@H](NC(=O)[C@H](CO)NC(=O)[C@@H](NC(=O)...,CC(C)CC(NC(=O)C(CO)NC(=O)C(NC(=O)C(N)Cc1ccccc1...,CC(C)CC(NC(=O)C(CO)NC(=O)C(NC(=O)C(N)CC1=CC=CC...,"[0.010883635762013337, 0.011228595019176968, 0...","[0.010824394450446622, 0.012379132733776744, 0...",F-V-S-L,F-V-S-L
4,N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CS)C(=O)N[C@@H]...,NC(Cc1ccccc1)C(=O)NC(CS)C(=O)NC(CC(=O)O)C(=O)O,NC(CC1=CC=CC=C1)C(=O)NC(CS)C(=O)NC(CC(=O)O)C(=O)O,"[0.031162503659202024, 0.046240888283535654, 0...","[0.025249009656411414, 0.037494951225140724, 0...",F-C-D,F-C-D


## 2.1 smiles

In [None]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase

# Seed torch's RNG before the model is built.
torch.manual_seed(624)
# This section runs on the first CUDA device.
device = 'cuda:0'

# Build the 'vib2mol_phase' architecture and load the Raman -> kekulized-SMILES checkpoint.
model = build_model('vib2mol_phase').to(device)
ckpt_path = '../checkpoints/peptide/raman-kekule_smiles/vib2mol_phase.pth'
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

# Remove every 'module.' substring from the parameter names
# (presumably the prefix left by a DataParallel/DDP wrapper - confirm against the training script).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

from transformers import AutoTokenizer
tokenizer_path = '../models/MolTokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

### 2.1.1 greedy decoding

In [4]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Spectra-only dataset used for generation; item `idx` is `tgt_spectra[idx]`.

    Args:
        tgt_spectra: indexable collection of spectra.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        n_spectra = len(self.tgt_spectra)
        return n_spectra

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stack a list of spectra into the batch dict expected by the generation calls.

    Tensors are moved to the notebook-level `device` variable.
    """

    def __init__(self):
        """Stateless collator; nothing to configure."""

    def __call__(self, batch):
        # `batch` is the list of spectra returned by the dataset; stack it and
        # insert a singleton axis at position 1.
        stacked = np.array(batch)
        as_float = torch.as_tensor(stacked, dtype=torch.float32)
        return {'spectra': as_float.unsqueeze(1).to(device)}

# Character length of every target SMILES, read straight from the column
# (avoids building one row Series per sample with `iloc`).
length = [len(smiles) for smiles in test_df['kekule_smiles']]
# Decoding budget: longest target string plus 2.
# NOTE(review): the budget is derived from the test targets themselves and counts characters,
# not tokens - presumably the +2 leaves room for the start/end tokens; confirm.
max_len = max(length)+2
print(f'max_len:{max_len}')

all_pred_smiles = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        # `infer_lm` returns a dict; its 'pred_ids' entry holds the generated token ids.
        pred_smiles_ids = model.infer_lm(batch, max_len=max_len)['pred_ids']
    pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
    # Keep the text before the first '</s>' and drop every '<s>' marker.
    pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
    all_pred_smiles.extend(pred_smiles)


max_len:99


  0%|          | 0/41 [00:00<?, ?it/s]

100%|██████████| 41/41 [01:42<00:00,  2.50s/it]


In [5]:
import pandas as pd
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def check_mols(pred_smiles, tgt_smiles):
    """Return 1 when both SMILES parse and share the same InChIKey, otherwise 0.

    Args:
        pred_smiles: predicted SMILES string.
        tgt_smiles: reference SMILES string.

    The comparison is symmetric in its two arguments.
    """
    pred_mol = Chem.MolFromSmiles(pred_smiles)
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    # MolFromSmiles yields None for a string it cannot parse; such a pair never matches.
    if pred_mol is None or tgt_mol is None:
        return 0
    same_molecule = Chem.MolToInchiKey(pred_mol) == Chem.MolToInchiKey(tgt_mol)
    return 1 if same_molecule else 0
    
# One row per test sample: reference SMILES next to the greedy prediction.
df = pd.DataFrame({'tgt_seq':test_df['kekule_smiles'].to_list(), 'pred_seq':all_pred_smiles})
# Bind by keyword so each column reaches the parameter it is named after
# (the positional call passed them in the opposite order; harmless only because the check is symmetric).
df['top_1'] = df.apply(lambda row: check_mols(pred_smiles=row['pred_seq'], tgt_smiles=row['tgt_seq']), axis=1)
# Fraction of samples whose prediction is the same molecule as the target.
print(df.top_1.mean())

0.25794644577152764


### 2.1.2 beam searching

In [6]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Spectra-only dataset for beam-search decoding.

    Args:
        tgt_spectra: indexable collection of spectra; item `idx` is returned unchanged.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        size = len(self.tgt_spectra)
        return size

    def __getitem__(self, idx):
        chosen = self.tgt_spectra[idx]
        return chosen

class TestCollator:
    """Turn a list of spectra into `{'spectra': tensor}` on the notebook-level `device`."""

    def __init__(self):
        """Stateless collator; nothing to configure."""

    def __call__(self, batch):
        # Stack the spectra, cast to float32 and insert a singleton axis at position 1.
        spectra_array = np.array(batch)
        spectra = torch.as_tensor(spectra_array, dtype=torch.float32)
        spectra = spectra.unsqueeze(1)
        return {'spectra': spectra.to(device)}

# One list of decoded beam candidates is appended per test spectrum.
all_pred_smiles = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        # `max_len` is the value computed in the greedy-decoding cell, which must have run first.
        pred_smiles_ids_list = model.beam_infer_lm(batch, max_len=max_len, beam_size=10, temperature=3.5)['pred_ids']
    # 'pred_ids' is iterated once per batch element; each element decodes to several candidate strings.
    for pred_smiles_ids in pred_smiles_ids_list:
        pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
        # Keep the text before the first '</s>' and drop every '<s>' marker.
        pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
        all_pred_smiles.append(pred_smiles)

100%|██████████| 325/325 [23:12<00:00,  4.28s/it]


In [7]:
import pandas as pd
import torch

def check_beam_mols(tgt_smiles, pred_smiles_list):
    """Return 1 when any candidate SMILES is the same molecule as the target, else 0.

    Molecules are compared by InChIKey. Candidates that cannot be parsed, or whose
    key cannot be computed, are skipped; they are never compared as empty strings,
    so a failed candidate cannot match a failed target.

    Args:
        tgt_smiles: reference SMILES string.
        pred_smiles_list: iterable of candidate SMILES strings.

    Returns:
        1 on a match, 0 otherwise. An unparsable target yields 0, the same
        convention `check_mols` uses, instead of raising.
    """
    tgt_mol = Chem.MolFromSmiles(tgt_smiles)
    if tgt_mol is None:
        return 0
    tgt_key = Chem.MolToInchiKey(tgt_mol)
    if not tgt_key:
        # No key could be generated for the target, so nothing can match it.
        return 0

    for item in pred_smiles_list:
        mol = Chem.MolFromSmiles(item)
        if mol is None:
            continue
        try:
            inchi_key = Chem.MolToInchiKey(mol)
        except Exception as e:
            # Best effort: report the offending candidate and keep scanning the rest.
            print(f"Error processing SMILES {item}: {e}")
            continue
        if inchi_key == tgt_key:
            return 1
    return 0

# One row per test sample; dict.fromkeys removes duplicate candidates while keeping their beam order.
df = pd.DataFrame({'tgt_smiles':test_df['kekule_smiles'].to_list(), 'pred_smiles':[list(dict.fromkeys(item)) for item in all_pred_smiles]})
# top-k: the target must be among the first k de-duplicated candidates.
df['top_1'] = df.apply(lambda row: check_beam_mols(row['tgt_smiles'], row['pred_smiles'][:1]), axis=1)
df['top_3'] = df.apply(lambda row: check_beam_mols(row['tgt_smiles'], row['pred_smiles'][:3]), axis=1)
df['top_5'] = df.apply(lambda row: check_beam_mols(row['tgt_smiles'], row['pred_smiles'][:5]), axis=1)
df['top_10'] = df.apply(lambda row: check_beam_mols(row['tgt_smiles'], row['pred_smiles'][:10]), axis=1)

# Disabled: `token_accuracy` is not defined in the visible cells.
# df['token_acc'] = df.apply(lambda row: token_accuracy(row['tgt_smiles'], row['pred_smiles'][0]), axis=1)
print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.26296
top-3:		0.41090
top-5:		0.42728
top-10:		0.42901


### 2.1.3 rerank by retrieval module

In [8]:
# De-duplicate the beam candidates of each sample. dict.fromkeys keeps first-seen order,
# so the flattened lists are identical on every run; `set` ordering of strings varies
# between interpreter sessions.
candidate_smiles_list = [list(dict.fromkeys(item)) for item in all_pred_smiles]
# Repeat each sample's spectrum and target once per remaining candidate so the three lists stay aligned.
candidate_spectra_list = [[test_df.iloc[i]['raman']] * len(item) for i, item in enumerate(candidate_smiles_list)]
tgt_smiles_list = [[test_df.iloc[i]['kekule_smiles']] * len(item) for i, item in enumerate(candidate_smiles_list)]

# Flatten the per-sample lists into three parallel flat lists.
candidate_smiles_list = [subitem for item in candidate_smiles_list for subitem in item]
candidate_spectra_list = [subitem for item in candidate_spectra_list for subitem in item]
tgt_smiles_list = [subitem for item in tgt_smiles_list for subitem in item]

In [9]:
# calculate similarity between predicted molecules and target spectra
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Args:
        embedding_query: 2-D tensor, or array-like convertible to one, one embedding per row.
        embedding_key: 2-D tensor, or array-like convertible to one, one embedding per row.

    Returns:
        Tensor whose entry [i, j] is the dot product of the L2-normalised
        query row i and key row j.
    """
    # isinstance also accepts Tensor subclasses; as_tensor converts without an extra copy where possible.
    if not isinstance(embedding_query, torch.Tensor):
        embedding_query = torch.as_tensor(embedding_query)
    if not isinstance(embedding_key, torch.Tensor):
        embedding_key = torch.as_tensor(embedding_key)

    embedding_query = F.normalize(embedding_query, p=2, dim=1)
    embedding_key = F.normalize(embedding_key, p=2, dim=1)

    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix


class TestDataset(torch.utils.data.Dataset):
    """Index-aligned (spectrum, candidate SMILES) pairs for the reranking pass.

    Args:
        tgt_spectra: indexable collection of spectra; its length defines the dataset size.
        pred_smiles: indexable collection of candidate strings, read with the same index.
    """

    def __init__(self, tgt_spectra, pred_smiles):
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        pair_count = len(self.tgt_spectra)
        return pair_count

    def __getitem__(self, idx):
        pair = (self.tgt_spectra[idx], self.pred_smiles[idx])
        return pair

class TestCollator:
    """Collate (spectrum, candidate SMILES) pairs into the model's batch dict.

    Args:
        tokenizer: callable applied to the list of candidate strings; its result must
            expose 'input_ids' and 'attention_mask' entries that support `.to(device)`.

    Tensors are moved to the notebook-level `device` variable.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        tgt_spectra, pred_smiles = zip(*batch)
        # Stack the spectra and insert a singleton axis at position 1.
        spectra = torch.as_tensor(np.array(tgt_spectra), dtype=torch.float32)
        spectra = spectra.unsqueeze(1).to(device)
        # Every candidate is padded / truncated to exactly 256 tokens.
        tokens = self.tokenizer(list(pred_smiles), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        sequence = {
            'input_ids': tokens['input_ids'].to(device),
            'attention_mask': tokens['attention_mask'].to(device),
        }
        return {'sequence': sequence, 'spectra': spectra}

    
# One similarity score per (spectrum, candidate) pair, in the order of the flattened candidate lists.
valid_sim_list = []

test_dataset = TestDataset(candidate_spectra_list, candidate_smiles_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        sim = calculate_similarity_matrix(spectral_embedding, molecular_embedding)
    # Only the diagonal is kept: entry i scores spectrum i against its own candidate.
    valid_sim_list += torch.diag(sim).tolist()

100%|██████████| 233/233 [00:34<00:00,  6.72it/s]


In [10]:
df = pd.DataFrame({'target_smiles':tgt_smiles_list, 'pred_smiles':candidate_smiles_list, 'similarity':valid_sim_list})
# First sort by 'target_smiles', and within each target by descending 'similarity'.
df_sorted = df.sort_values(by=['target_smiles', 'similarity'], ascending=[True, False])

# Then group by 'target_smiles' and join the ranked 'pred_smiles' / 'similarity' values into
# comma-separated strings. Rows that share the same target string end up in one group.
grouped = df_sorted.groupby('target_smiles').agg({
    'pred_smiles': lambda x: ','.join(x),
    'similarity': lambda x: ','.join(map(str, x))
}).reset_index()

# Top-k recall per target: the exact target string must be among the k best-ranked candidates.
# The target is compared unchanged - '-' is a bond symbol in SMILES, so stripping it (as the
# sequence variant of this cell does for its separators) would alter the molecule string and
# disagree with the `rank` column below.
# NOTE(review): this is a string comparison, whereas the beam-search cell compares InChIKeys.
for top_k in [1, 3, 5, 10, 100]:
    grouped[f'top_{top_k}_recall'] = grouped.apply(lambda row: row['target_smiles'] in row['pred_smiles'].split(',')[:top_k], axis=1)

# 1-based position of the target among the ranked candidates, 0 when it is absent.
grouped['rank'] = grouped.apply(lambda row: (row['pred_smiles'].split(',').index(row['target_smiles']))+1 if row['target_smiles'] in row['pred_smiles'].split(',') else 0, axis=1)
# Character length of the target SMILES string.
grouped['length'] = grouped.apply(lambda row: len(row['target_smiles']), axis=1)

print(f"""
recall@1:\t{grouped.top_1_recall.mean():.5f} 
recall@3:\t{grouped.top_3_recall.mean():.5f} 
recall@5:\t{grouped.top_5_recall.mean():.5f} 
recall@10:\t{grouped.top_10_recall.mean():.5f}
recall@100:\t{grouped.top_100_recall.mean():.5f}
    """)


recall@1:	0.27555 
recall@3:	0.40879 
recall@5:	0.41689 
recall@10:	0.41728
recall@100:	0.41728
    


## 2.2 sequence

In [16]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase

# Seed torch's RNG before the model is built.
torch.manual_seed(624)
device = 'cpu'

# Build the 'vib2mol_phase' architecture and load the Raman -> sequence checkpoint.
model = build_model('vib2mol_phase').to(device)
ckpt_path = '../checkpoints/peptide/raman-sequence/vib2mol_phase.pth'
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

# Remove every 'module.' substring from the parameter names
# (presumably the prefix left by a DataParallel/DDP wrapper - confirm against the training script).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

from transformers import AutoTokenizer
tokenizer_path = '../models/PepTokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

### 2.2.1 greedy decoding

In [17]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Spectra-only dataset for greedy sequence decoding.

    Args:
        tgt_spectra: indexable collection of spectra; item `idx` is returned unchanged.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        total = len(self.tgt_spectra)
        return total

    def __getitem__(self, idx):
        item = self.tgt_spectra[idx]
        return item

class TestCollator:
    """Batch a list of spectra as `{'spectra': tensor}` on the notebook-level `device`."""

    def __init__(self):
        """Stateless collator; nothing to configure."""

    def __call__(self, batch):
        # Stack, cast to float32, then insert a singleton axis at position 1.
        tensor = torch.as_tensor(np.array(batch), dtype=torch.float32)
        with_axis = tensor.unsqueeze(1)
        return {'spectra': with_axis.to(device)}

# One greedy prediction string per test spectrum.
all_pred_sequences = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        # Decoding budget is fixed at 6 here.
        # NOTE(review): presumably up to 4 residues plus start/end tokens - confirm.
        pred_smiles_ids = model.infer_lm(batch, max_len=6)['pred_ids']
    # The variables keep their `smiles` names from the SMILES section but hold sequence predictions.
    pred_smiles = tokenizer.batch_decode(pred_smiles_ids)
    # Keep the text before the first '</s>' and drop every '<s>' marker.
    pred_smiles = [item.split('</s>')[0].replace('<s>', '') for item in pred_smiles]
    all_pred_sequences.extend(pred_smiles)


100%|██████████| 41/41 [01:31<00:00,  2.24s/it]


In [19]:
import pandas as pd
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def check_seq(tgt_seq, pred_seq):
    """Return 1 when the prediction equals the target with its '-' separators removed, else 0.

    Args:
        tgt_seq: reference sequence written with '-' between residues (e.g. 'G-F-Y-I').
        pred_seq: predicted sequence, compared as-is.
    """
    residues_only = tgt_seq.replace('-', '')
    return int(residues_only == pred_seq)

# One row per test sample: reference sequence next to the greedy prediction.
df = pd.DataFrame({'tgt_seq':test_df['sequence'].to_list(), 'pred_seq':all_pred_sequences})
df['top_1'] = df.apply(lambda row: check_seq(row['tgt_seq'], row['pred_seq']), axis=1)
# Fraction of exact matches.
print(df.top_1.mean())

0.3515700250433442


### 2.2.2 beam searching

In [None]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Spectra-only dataset for beam-search sequence decoding.

    Args:
        tgt_spectra: indexable collection of spectra; item `idx` is returned unchanged.
    """

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        count = len(self.tgt_spectra)
        return count

    def __getitem__(self, idx):
        selected = self.tgt_spectra[idx]
        return selected

class TestCollator:
    """Stack spectra into `{'spectra': tensor}` and move it to the notebook-level `device`."""

    def __init__(self):
        """Stateless collator; nothing to configure."""

    def __call__(self, batch):
        # The batch is a list of spectra: stack, cast to float32, insert a singleton axis at position 1.
        stacked = torch.as_tensor(np.array(batch), dtype=torch.float32)
        spectra = stacked.unsqueeze(1).to(device)
        return {'spectra': spectra}

# One list of decoded beam candidates is appended per test spectrum.
all_pred_sequence = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        # Same fixed decoding budget (6) as the greedy cell; beam width 10, temperature 17.5.
        pred_sequence_ids_list = model.beam_infer_lm(batch, max_len=6, beam_size=10, temperature=17.5)['pred_ids']
    # 'pred_ids' is iterated once per batch element; each element decodes to several candidate strings.
    for pred_sequence_ids in pred_sequence_ids_list:
        pred_sequence = tokenizer.batch_decode(pred_sequence_ids)
        # Keep the text before the first '</s>' and drop every '<s>' marker.
        pred_sequence = [item.split('</s>')[0].replace('<s>', '') for item in pred_sequence]
        all_pred_sequence.append(pred_sequence)

In [15]:
import pandas as pd
import torch

def check_beam_seq(tgt_seq, pred_seq_list):
    """Return 1 when the target, with its '-' separators removed, is among the candidates, else 0.

    Args:
        tgt_seq: reference sequence written with '-' between residues (e.g. 'F-A-F-H').
        pred_seq_list: container of candidate sequences written without separators.
    """
    tgt_seq = tgt_seq.replace('-', '')
    return 1 if tgt_seq in pred_seq_list else 0

# One row per test sample; dict.fromkeys removes duplicate candidates while keeping their beam order.
df = pd.DataFrame({'tgt_seq':test_df['sequence'].to_list(), 'pred_seq':[list(dict.fromkeys(item)) for item in all_pred_sequence]})
# top-k: the target must be among the first k de-duplicated candidates.
df['top_1'] = df.apply(lambda row: check_beam_seq(row['tgt_seq'], row['pred_seq'][:1]), axis=1)
df['top_3'] = df.apply(lambda row: check_beam_seq(row['tgt_seq'], row['pred_seq'][:3]), axis=1)
df['top_5'] = df.apply(lambda row: check_beam_seq(row['tgt_seq'], row['pred_seq'][:5]), axis=1)
df['top_10'] = df.apply(lambda row: check_beam_seq(row['tgt_seq'], row['pred_seq'][:10]), axis=1)

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.27259
top-3:		0.49586
top-5:		0.54749
top-10:		0.56694


In [16]:
# Inspect the first targets next to their de-duplicated beam candidates and hit flags.
df.head()

Unnamed: 0,tgt_seq,pred_seq,top_1,top_3,top_5,top_10
0,G-F-Y-I,"[GVYI, GYYI, GYFI, GIYI, GYIF, GLYI]",0,0,0,0
1,F-A-F-H,"[FAFH, FAPH, FAMH]",1,1,1,1
2,H-A-M-V,"[HMA, GHAV, HGAV, HMAV, HCAV, HFAV, GHAM, HMDV]",0,0,0,0
3,F-V-S-L,"[FTL, FTVL, FALL, FTCF, FTAL, FTFL, FTCL, FTTL...",0,0,0,0
4,F-C-D,[FDC],0,0,0,0


### 2.2.3 rerank by retrieval module

In [17]:
# De-duplicate the beam candidates of each sample. dict.fromkeys keeps first-seen order,
# so the flattened lists are identical on every run; `set` ordering of strings varies
# between interpreter sessions.
candidate_sequence_list = [list(dict.fromkeys(item)) for item in all_pred_sequence]
# Repeat each sample's spectrum and target once per remaining candidate so the three lists stay aligned.
candidate_spectra_list = [[test_df.iloc[i]['raman']] * len(item) for i, item in enumerate(candidate_sequence_list)]
tgt_sequence_list = [[test_df.iloc[i]['sequence']] * len(item) for i, item in enumerate(candidate_sequence_list)]

# Flatten the per-sample lists into three parallel flat lists.
candidate_sequence_list = [subitem for item in candidate_sequence_list for subitem in item]
candidate_spectra_list = [subitem for item in candidate_spectra_list for subitem in item]
tgt_sequence_list = [subitem for item in tgt_sequence_list for subitem in item]

In [18]:
# calculate similarity between predicted molecules and target spectra
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Args:
        embedding_query: 2-D tensor, or array-like convertible to one, one embedding per row.
        embedding_key: 2-D tensor, or array-like convertible to one, one embedding per row.

    Returns:
        Tensor whose entry [i, j] is the dot product of the L2-normalised
        query row i and key row j.
    """
    # isinstance also accepts Tensor subclasses; as_tensor converts without an extra copy where possible.
    if not isinstance(embedding_query, torch.Tensor):
        embedding_query = torch.as_tensor(embedding_query)
    if not isinstance(embedding_key, torch.Tensor):
        embedding_key = torch.as_tensor(embedding_key)

    embedding_query = F.normalize(embedding_query, p=2, dim=1)
    embedding_key = F.normalize(embedding_key, p=2, dim=1)

    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix


class TestDataset(torch.utils.data.Dataset):
    """Index-aligned (spectrum, candidate sequence) pairs for the reranking pass.

    Args:
        tgt_spectra: indexable collection of spectra; its length defines the dataset size.
        pred_sequence: indexable collection of candidate strings, read with the same index.
    """

    def __init__(self, tgt_spectra, pred_sequence):
        self.tgt_spectra, self.pred_sequence = tgt_spectra, pred_sequence

    def __len__(self):
        pair_count = len(self.tgt_spectra)
        return pair_count

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        candidate = self.pred_sequence[idx]
        return spectrum, candidate

class TestCollator:
    """Collate (spectrum, candidate sequence) pairs into the model's batch dict.

    Args:
        tokenizer: callable applied to the list of candidate strings; its result must
            expose 'input_ids' and 'attention_mask' entries that support `.to(device)`.

    Tensors are moved to the notebook-level `device` variable.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        tgt_spectra, pred_sequence = zip(*batch)
        # Stack the spectra and insert a singleton axis at position 1.
        spectra_array = np.array(tgt_spectra)
        spectra = torch.as_tensor(spectra_array, dtype=torch.float32).unsqueeze(1).to(device)
        # Every candidate is padded / truncated to exactly 256 tokens.
        encoded = self.tokenizer(list(pred_sequence), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        sequence = {name: encoded[name].to(device) for name in ('input_ids', 'attention_mask')}
        return {'sequence': sequence, 'spectra': spectra}

    
# One similarity score per (spectrum, candidate) pair, in the order of the flattened candidate lists.
valid_sim_list = []

test_dataset = TestDataset(candidate_spectra_list, candidate_sequence_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        sim = calculate_similarity_matrix(spectral_embedding, molecular_embedding)
    # Only the diagonal is kept: entry i scores spectrum i against its own candidate.
    valid_sim_list += torch.diag(sim).tolist()

100%|██████████| 386/386 [00:50<00:00,  7.69it/s]


In [None]:
# Strip the '-' separators so targets use the same spelling as the decoded candidates.
tgt_sequence_list = [item.replace('-', '') for item in tgt_sequence_list]
df = pd.DataFrame({'target_sequence':tgt_sequence_list, 'pred_sequence':candidate_sequence_list, 'similarity':valid_sim_list})

# Sort by target, and within each target by descending similarity.
df_sorted = df.sort_values(by=['target_sequence', 'similarity'], ascending=[True, False])

# Join the ranked candidates / scores of each target into comma-separated strings.
# Rows that share the same target string end up in one group. `df` is rebound to the grouped table.
df = df_sorted.groupby('target_sequence').agg({
    'pred_sequence': lambda x: ','.join(x),
    'similarity': lambda x: ','.join(map(str, x))
}).reset_index()

# Top-k recall per target. The targets were already stripped of '-' above, so the
# extra .replace below leaves them unchanged.
for top_k in [1, 3, 5, 10, 100]: 
    df[f'top_{top_k}_recall'] = df.apply(lambda row: row['target_sequence'].replace('-', '') in row['pred_sequence'].split(',')[:top_k], axis=1)

# 1-based position of the target among the ranked candidates, 0 when it is absent.
df['rank'] = df.apply(lambda row: (row['pred_sequence'].split(',').index(row['target_sequence']))+1 if row['target_sequence'] in row['pred_sequence'].split(',') else 0, axis=1)
# Character length of the separator-free target string.
df['length'] = df.apply(lambda row: len(row['target_sequence']), axis=1)

print(f"""
recall@1:\t{df.top_1_recall.mean():.5f} 
recall@3:\t{df.top_3_recall.mean():.5f} 
recall@5:\t{df.top_5_recall.mean():.5f} 
recall@10:\t{df.top_10_recall.mean():.5f}
recall@100:\t{df.top_100_recall.mean():.5f}
    """)


recall@1:	0.39915 
recall@3:	0.55288 
recall@5:	0.56502 
recall@10:	0.56694
recall@100:	0.56694
    


In [25]:
# Per-length averages of every numeric / boolean column.
# `mean`'s first positional parameter is `numeric_only`, not a column name, so state the intent explicitly.
df.groupby('length').mean(numeric_only=True)

Unnamed: 0_level_0,top_1_recall,top_3_recall,top_5_recall,top_10_recall,top_100_recall,rank
length,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2,0.894737,0.947368,0.947368,0.947368,0.947368,1.0
3,0.552404,0.683107,0.692972,0.692972,0.692972,0.885327
4,0.366191,0.525104,0.537771,0.540074,0.540074,0.777982


# 3 modified site retrieval

In [None]:
import lmdb
import pickle
from tqdm.auto import tqdm 
import pandas as pd

# Open the modified-peptide test split read-only so running this cell can never create or modify the file.
db = lmdb.open('../datasets/vibench/peptide_mod/peptide_mod_test.lmdb', subdir=False, readonly=True, lock=False, map_size=int(1e11))

# Open a transaction and perform a read operation
try:
    with db.begin() as txn:
        # Iterating the cursor yields (key, value) pairs; materialise them before the transaction ends.
        test_data = list(txn.cursor())
finally:
    # Release the file handle / memory map as soon as the records have been copied out.
    db.close()

# NOTE(security): pickle.loads can execute code embedded in the payload - only load trusted dataset files.
test_df = pd.DataFrame([pickle.loads(item[1]) for item in tqdm(test_data)])

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 2511/2511 [00:00<00:00, 63137.09it/s]


In [None]:
import sys 
sys.path.append('../')

import torch

from models import build_model
from models import PretrainModel_Phase

# Seed torch's RNG before the model is built.
torch.manual_seed(624)
# This section runs on the first CUDA device.
device = 'cuda:0'

# Build the 'vib2mol_phase' architecture and load the modified-peptide Raman -> sequence checkpoint.
model = build_model('vib2mol_phase').to(device)
ckpt_path = '../checkpoints/peptide_mod/raman-sequence/vib2mol_phase.pth'
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

# Remove every 'module.' substring from the parameter names
# (presumably the prefix left by a DataParallel/DDP wrapper - confirm against the training script).
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)

from transformers import AutoTokenizer
tokenizer_path = '../models/PepTokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

In [3]:
# Preview the first rows of the modified-peptide test table.
test_df.head()

Unnamed: 0,smiles,norm_smiles,kekule_smiles,raman,ir,filename,sequence
0,C[C@@H](OS(=O)(=O)O)[C@@H](NC(=O)[C@H](CCC(=O)...,CC(OS(=O)(=O)O)C(NC(=O)C(CCC(=O)NC(=O)CCC(N)C(...,CC(OS(=O)(=O)O)C(NC(=O)C(CCC(=O)NC(=O)CCC(N)C(...,"[0.0610237247023169, 0.07534843062887989, 0.08...","[0.011994992794439888, 0.01829167933971773, 0....",C-E-sT-Q,C-E-sT-Q
1,CC(C)[C@H](NC(=O)CNC(=O)[C@@H](N)COS(=O)(=O)O)...,CC(C)C(NC(=O)CNC(=O)C(N)COS(=O)(=O)O)C(=O)O,CC(C)C(NC(=O)CNC(=O)C(N)COS(=O)(=O)O)C(=O)O,"[0.057882732891363045, 0.04388163818948183, 0....","[0.019754579708913297, 0.015484649837671371, 0...",sS-G-V,sS-G-V
2,N[C@@H](Cc1ccc(O)cc1)C(=O)NCC(=O)N[C@@H](Cc1c[...,NC(Cc1ccc(O)cc1)C(=O)NCC(=O)NC(Cc1c[nH]c2ccccc...,NC(CC1=CC=C(O)C=C1)C(=O)NCC(=O)NC(CC1=CNC2=CC=...,"[0.025843460103289473, 0.022190718157910284, 0...","[0.06034964879213473, 0.04087290119317211, 0.0...",Y-G-W,Y-G-W
3,CC[C@H](C)[C@H](N)C(=O)N[C@@H](C(=O)N[C@@H](C)...,CCC(C)C(N)C(=O)NC(C(=O)NC(C)C(=O)O)C(C)OP(=O)(O)O,CCC(C)C(N)C(=O)NC(C(=O)NC(C)C(=O)O)C(C)OP(=O)(O)O,"[0.0727303876705453, 0.05813604322334051, 0.04...","[0.02712906135457303, 0.01848644633855094, 0.0...",I-pT-A,I-pT-A
4,NCC(=O)N[C@@H](COS(=O)(=O)O)C(=O)N[C@@H](CS)C(...,NCC(=O)NC(COS(=O)(=O)O)C(=O)NC(CS)C(=O)O,NCC(=O)NC(COS(=O)(=O)O)C(=O)NC(CS)C(=O)O,"[0.06237074766339497, 0.0889687593086402, 0.14...","[0.010306055339872014, 0.016394710928848952, 0...",G-sS-C,G-sS-C


## 3.1 site classification

In [None]:
import random

def masking_seq(seq, rng=None):
    """Replace one residue of a '-'-separated sequence with the '<mask>' token.

    Selection rule:
      1. The first residue token containing 's' or 'p' is masked (tokens such as
         'sT' or 'pT' in the dataset preview).
      2. If there is none, candidates are the first residue containing each of
         'H', 'S', 'T', 'Y' (checked in that order, skipping positions already
         collected), and one candidate is drawn at random.

    Args:
        seq: sequence string such as 'C-E-sT-Q'.
        rng: optional object with a `choice` method (e.g. `random.Random(seed)`)
            used for the random draw; defaults to the global `random` module.

    Returns:
        The sequence with exactly one residue replaced by '<mask>'.

    Raises:
        IndexError: if no residue qualifies under either rule.
    """
    chooser = random if rng is None else rng
    seq_list = seq.split('-')

    # Rule 1: deterministic - first token that carries an 's' or 'p' marker.
    mask_id = next((i for i, item in enumerate(seq_list) if 's' in item or 'p' in item), -1)

    if mask_id == -1:
        # Rule 2: at most one candidate position per target character.
        indices = []
        for char in ('H', 'S', 'T', 'Y'):
            for i, item in enumerate(seq_list):
                if char in item and i not in indices:
                    indices.append(i)
                    break
        if not indices:
            # Same exception type `random.choice` raises on an empty list, with an actionable message.
            raise IndexError(f"no maskable residue (s/p-marked or H/S/T/Y) in sequence {seq!r}")
        mask_id = chooser.choice(indices)

    masked_seq = seq_list.copy()
    masked_seq[mask_id] = '<mask>'
    return '-'.join(masked_seq)

# Quick sanity check: a sequence without s/p-marked residues, so the mask position is drawn at random.
masking_seq('F-H-M-Y-H')


'F-<mask>-M-Y-H'

In [None]:
masked_spectra = test_df.raman.to_list()
# masking_seq draws from the global `random` generator for sequences without an s/p-marked residue.
# Seed it (same value as torch.manual_seed above) so the masked positions, and therefore
# the reported accuracy, are reproducible across runs.
random.seed(624)
masked_sequence = [masking_seq(item) for item in test_df.sequence.to_list()]

In [None]:
from copy import deepcopy
import numpy as np 

# Decoded predictions, one string per test sample, filled by the inference loop below.
all_pred_sequence = []
# NOTE(review): `correct` and `total` are not read anywhere in the visible cells.
correct, total = 0, 0

class TestDataset(torch.utils.data.Dataset):
    """Index-aligned (masked sequence, spectrum) pairs for masked-token prediction.

    Args:
        sequence: indexable collection of sequence strings; its length defines the dataset size.
        spectra: indexable collection of spectra, read with the same index.
    """

    def __init__(self, sequence, spectra):
        self.sequence, self.spectra = sequence, spectra

    def __len__(self):
        # The sequences, not the spectra, determine the size.
        n_sequences = len(self.sequence)
        return n_sequences

    def __getitem__(self, idx):
        text = self.sequence[idx]
        spectrum = self.spectra[idx]
        return text, spectrum

class TestCollator:
    """Collate (masked sequence, spectrum) pairs into the batch dict used for masked-token inference.

    Args:
        tokenizer: callable applied to the list of sequence strings; its result must
            expose 'input_ids' and 'attention_mask' entries that support `.to(device)`.

    The tokenized text is returned under the 'smiles' key. Tensors are moved to the
    notebook-level `device` variable.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        sequence, spectra = zip(*batch)
        # Stack the spectra and insert a singleton axis at position 1.
        spectra = torch.as_tensor(np.array(spectra), dtype=torch.float32)
        spectra = spectra.unsqueeze(1).to(device)
        # Every sequence is padded / truncated to exactly 256 tokens.
        encoded = self.tokenizer(list(sequence), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        tokens = {name: encoded[name].to(device) for name in ('input_ids', 'attention_mask')}
        return {'smiles': tokens, 'spectra': spectra}

test_dataset = TestDataset(masked_sequence, masked_spectra)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_tokens_logits = model.infer_mlm(batch)
    # Highest-scoring token id along the last axis.
    pred_tokens = torch.argmax(pred_tokens_logits, dim=-1)
    
    # Start from a copy of the input ids and overwrite only the masked positions with the predictions.
    output = deepcopy(batch['smiles']['input_ids'])
    # NOTE(review): 4 is presumably the tokenizer's id for '<mask>' - confirm against PepTokenizer.
    mask = (batch['smiles']['input_ids'] == 4).cpu()
    output[mask] = pred_tokens[mask]
    
    # Decode the filled-in ids back to text without special tokens.
    preds = tokenizer.batch_decode(output, skip_special_tokens=True)
    all_pred_sequence.extend(preds)

100%|██████████| 40/40 [00:07<00:00,  5.05it/s]


In [None]:
import pandas as pd
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def check_seq(tgt_seq, pred_seq):
    """Return 1 when the target, with its '-' separators removed, equals the
    prediction exactly; otherwise return 0."""
    return int(tgt_seq.replace('-', '') == pred_seq)

def _modification_label(seq):
    # A lower-case 'p' is checked first, then a lower-case 's'; anything else
    # is labelled unmodified.
    if 'p' in seq:
        return 'phosphorylated'
    if 's' in seq:
        return 'sulfated'
    return 'unmodified'

# One row per test sample: target, prediction, exact-match flag, number of
# '-'-separated residues, and modification label.
df = pd.DataFrame({'tgt_seq': test_df['sequence'].to_list(), 'pred_seq': all_pred_sequence})
df['top_1'] = [check_seq(tgt, pred) for tgt, pred in zip(df['tgt_seq'], df['pred_seq'])]
df['length'] = [len(tgt.split('-')) for tgt in df['tgt_seq']]
df['mod'] = [_modification_label(tgt) for tgt in df['tgt_seq']]
print(df['top_1'].mean())

0.7502986857825568


In [None]:
# Per-modification means of the numeric columns. The first positional
# parameter of GroupBy.mean is `numeric_only`, so mean('top_1') never selected
# a column; it only worked because a non-empty string is truthy.
df.groupby('mod').mean(numeric_only=True)

Unnamed: 0_level_0,top_1,length
mod,Unnamed: 1_level_1,Unnamed: 2_level_1
negative,0.791612,3.858453
phosphorylation,0.749664,3.495289
sulfation,0.719403,3.638806


## 3.2 Retrieval

In [4]:
import torch
import lmdb
import pickle
from tqdm.auto import tqdm 
import numpy as np
import pandas as pd


class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairing of target spectra with sequence strings."""

    def __init__(self, tgt_spectra, pred_smiles):
        self.tgt_spectra, self.pred_smiles = tgt_spectra, pred_smiles

    def __len__(self):
        # The dataset size follows the spectra container.
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        # Returns a (spectrum, sequence) tuple for position `idx`.
        spectrum = self.tgt_spectra[idx]
        text = self.pred_smiles[idx]
        return spectrum, text

class TestCollator:
    """Turns a list of (spectrum, sequence) pairs into one model-ready batch.

    Relies on the notebook-level `device` for tensor placement.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        spec_items, text_items = zip(*batch)

        # Stack the spectra and insert a singleton axis at position 1.
        stacked = torch.as_tensor(np.array(spec_items), dtype=torch.float32)
        stacked = stacked.unsqueeze(1).to(device)

        # Tokenize with fixed-length padding/truncation to 256 tokens.
        encoded = self.tokenizer(list(text_items), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        tokens = {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}

        return {'sequence': tokens, 'spectra': stacked}

In [5]:
# Embed every (spectrum, sequence) pair of the test set, batch by batch.
test_dataset = TestDataset(test_df['raman'].to_list(), test_df['sequence'].to_list())
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
all_molecular_embeddings, all_spectral_embeddings = [], []

with torch.no_grad():
    for batch in test_bar:
        # Both embeddings are taken from the 'proj_output' entry of the model's result.
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']

        all_molecular_embeddings.append(molecular_embedding)
        all_spectral_embeddings.append(spectral_embedding)

100%|██████████| 40/40 [00:04<00:00,  8.15it/s]


In [6]:
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Both inputs are L2-normalised along dim 1, so the matrix product of the
    normalised rows gives an (n_query, n_key) matrix of cosine similarities.
    """
    unit_query = F.normalize(embedding_query, p=2, dim=1)
    unit_key = F.normalize(embedding_key, p=2, dim=1)
    return unit_query @ unit_key.t()

def compute_recall(similarity_matrix, k, return_result=False):
    """Print recall@k for a query-by-key similarity matrix.

    Query ``i`` counts as a hit when key index ``i`` is among the ``k``
    highest-scoring keys of row ``i``.

    Args:
        similarity_matrix: 2-D tensor with one row per query and one column per key.
        k: number of top-scoring keys to consider per query. Values larger
            than the number of keys are clamped, because ``topk`` raises when
            asked for more entries than the dimension holds.
        return_result: when True, return the per-query hit list.

    Returns:
        A list of 0/1 ints (one per query) if ``return_result`` is True,
        otherwise None.

    Raises:
        ZeroDivisionError: if the matrix has no rows.
    """
    num_queries, num_keys = similarity_matrix.size(0), similarity_matrix.size(1)
    effective_k = min(k, num_keys)
    _, topk_indices = similarity_matrix.topk(effective_k, dim=1, largest=True, sorted=True)

    # Compare each row of top-k indices with that row's own index in one
    # vectorised step instead of a per-row Python membership test.
    own_index = torch.arange(num_queries, device=topk_indices.device).unsqueeze(1)
    correct_list = (topk_indices == own_index).any(dim=1).int().tolist()
    recall_at_k = sum(correct_list) / num_queries

    print(f'recall@{k}:{recall_at_k:.5f}')
    if return_result:
        return correct_list
    return None

# Concatenate the per-batch embeddings into single tensors. These names are
# rebound from list to tensor, and torch.cat rejects a bare tensor, so the
# guard keeps this cell re-runnable after its first execution.
if not torch.is_tensor(all_molecular_embeddings):
    all_molecular_embeddings = torch.cat(all_molecular_embeddings)
if not torch.is_tensor(all_spectral_embeddings):
    all_spectral_embeddings = torch.cat(all_spectral_embeddings)

# Spectra are the queries and sequences the keys; row i and column i belong
# to the same test sample.
similarity_matrix = calculate_similarity_matrix(all_spectral_embeddings, all_molecular_embeddings)
top1 = compute_recall(similarity_matrix, k=1, return_result=True)
top3 = compute_recall(similarity_matrix, k=3, return_result=True)
compute_recall(similarity_matrix, k=5, return_result=False)
compute_recall(similarity_matrix, k=10, return_result=False)

recall@1:0.65034
recall@3:0.87774
recall@5:0.92991
recall@10:0.96655


In [None]:
# Work on an explicit copy: assigning columns to the slice `test_df[:]`
# triggers pandas' SettingWithCopyWarning.
df = test_df.copy()
df['top_1'] = top1
df['top_3'] = top3
# A lower-case 'p' is checked first, then a lower-case 's'; anything else is
# labelled unmodified.
df['mod'] = df.apply(lambda row: 'phosphorylated' if 'p' in row['sequence'] 
                             else 'sulfated' if 's' in row['sequence'] 
                             else 'unmodified', axis=1)
# The first positional parameter of GroupBy.mean is `numeric_only`; pass it
# explicitly rather than a truthy column name.
df.groupby('mod').mean(numeric_only=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['top_1'] = top1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['top_3'] = top3
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['mod'] = df.apply(lambda row: 'phosphorylation' if 'p' in row['sequence']


Unnamed: 0_level_0,top_1,top_3
mod,Unnamed: 1_level_1,Unnamed: 2_level_1
negative,0.760157,0.930537
phosphorylation,0.623149,0.86541
sulfation,0.587065,0.846766


## 3.3 de novo generation

### 3.3.1 greedy decoding

In [8]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Dataset that serves target spectra only, one per index."""

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stacks a list of spectra into a single float32 batch tensor.

    Relies on the notebook-level `device` for tensor placement.
    """

    def __init__(self):
        # Stateless collator: nothing to configure.
        pass

    def __call__(self, batch):
        # Stack the spectra and insert a singleton axis at position 1.
        stacked = torch.as_tensor(np.array(batch), dtype=torch.float32)
        stacked = stacked.unsqueeze(1).to(device)
        return {'spectra': stacked}

# Greedy decoding: one predicted sequence per spectrum.
all_pred_sequences = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_smiles_ids = model.infer_lm(batch, max_len=6)['pred_ids']
    # Keep the text before the first '</s>' and drop every '<s>' marker.
    pred_smiles = [
        decoded.split('</s>')[0].replace('<s>', '')
        for decoded in tokenizer.batch_decode(pred_smiles_ids)
    ]
    all_pred_sequences.extend(pred_smiles)


100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


In [9]:
import pandas as pd
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')

def check_seq(tgt_seq, pred_seq):
    """Exact-match score: 1 if the target without '-' separators is identical
    to the prediction, else 0."""
    stripped_target = tgt_seq.replace('-', '')
    return 1 if stripped_target == pred_seq else 0

# One row per test sample: target, greedy prediction and exact-match flag.
df = pd.DataFrame({'tgt_seq': test_df['sequence'].to_list(), 'pred_seq': all_pred_sequences})
df['top_1'] = [check_seq(tgt, pred) for tgt, pred in zip(df['tgt_seq'], df['pred_seq'])]
print(df['top_1'].mean())

0.2524890481879729


### 3.3.2 beam search

In [10]:
import numpy as np

class TestDataset(torch.utils.data.Dataset):
    """Dataset that serves target spectra only, one per index."""

    def __init__(self, tgt_spectra):
        self.tgt_spectra = tgt_spectra

    def __len__(self):
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        spectrum = self.tgt_spectra[idx]
        return spectrum

class TestCollator:
    """Stacks a list of spectra into a single float32 batch tensor.

    Relies on the notebook-level `device` for tensor placement.
    """

    def __init__(self):
        # Stateless collator: nothing to configure.
        pass

    def __call__(self, batch):
        # Stack the spectra and insert a singleton axis at position 1.
        stacked = torch.as_tensor(np.array(batch), dtype=torch.float32)
        stacked = stacked.unsqueeze(1).to(device)
        return {'spectra': stacked}

# Beam search: a list of candidate sequences per spectrum.
all_pred_sequence = []
test_dataset = TestDataset(test_df['raman'].to_list())
test_collator = TestCollator()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        pred_sequence_ids_list = model.beam_infer_lm(batch, max_len=6, beam_size=10, temperature=17.5)['pred_ids']
    for pred_sequence_ids in pred_sequence_ids_list:
        # Decode the candidates of one entry; keep the text before the first
        # '</s>' and drop every '<s>' marker.
        pred_sequence = [
            decoded.split('</s>')[0].replace('<s>', '')
            for decoded in tokenizer.batch_decode(pred_sequence_ids)
        ]
        all_pred_sequence.append(pred_sequence)

100%|██████████| 157/157 [00:14<00:00, 10.85it/s]


In [11]:
import pandas as pd
import torch

def check_beam_seq(tgt_seq, pred_seq_list):
    """Return 1 if the target sequence appears among the candidates, else 0.

    Args:
        tgt_seq: target sequence; its '-' separators are removed before comparison.
        pred_seq_list: candidate sequences, compared by exact string equality.

    Returns:
        1 on a hit, 0 otherwise (including an empty candidate list).
    """
    return int(tgt_seq.replace('-', '') in pred_seq_list)

# Candidates are de-duplicated in their original order before the top-k
# cut-offs are applied.
df = pd.DataFrame({
    'tgt_seq': test_df['sequence'].to_list(),
    'pred_seq': [list(dict.fromkeys(item)) for item in all_pred_sequence],
})
for top_k in (1, 3, 5, 10):
    df[f'top_{top_k}'] = [
        check_beam_seq(tgt, preds[:top_k])
        for tgt, preds in zip(df['tgt_seq'], df['pred_seq'])
    ]

print(f'top-1:\t\t{df.top_1.mean():.5f}\ntop-3:\t\t{df.top_3.mean():.5f}\ntop-5:\t\t{df.top_5.mean():.5f}\ntop-10:\t\t{df.top_10.mean():.5f}')

top-1:		0.17483
top-3:		0.34090
top-5:		0.40064
top-10:		0.43011


In [12]:
# Preview the first rows: targets, de-duplicated beam candidates and the top-k hit flags.
df.head()

Unnamed: 0,tgt_seq,pred_seq,top_1,top_3,top_5,top_10
0,C-E-sT-Q,"[CPsTD, CPID, CDEsT, CsTsTD, CsTED, CPsTsT, CP...",0,0,0,0
1,sS-G-V,"[sSG, sSGV, sSVG, sSVM, sSGpH1, sSVP]",0,1,1,1
2,Y-G-W,"[WYW, WYG, WYY, WGY, YGW]",0,0,1,1
3,I-pT-A,"[VpTA, VpTI, TIA, VpTAF]",0,0,0,0
4,G-sS-C,"[, sSGC, GsSC, WGC, sSsSC, WpSC]",0,1,1,1


### 3.3.3 rerank by retrieval module

In [13]:
# Flatten the beam candidates into three index-aligned lists, so that every
# unique candidate is paired with the spectrum and target of the sample it was
# generated for.
# dict.fromkeys de-duplicates while keeping beam order; a set would give an
# order that can differ between kernel sessions.
candidate_sequence_list, candidate_spectra_list, tgt_sequence_list = [], [], []
for item, raman, sequence in zip(all_pred_sequence, test_df['raman'].to_list(), test_df['sequence'].to_list()):
    unique_candidates = list(dict.fromkeys(item))
    candidate_sequence_list.extend(unique_candidates)
    candidate_spectra_list.extend([raman] * len(unique_candidates))
    tgt_sequence_list.extend([sequence] * len(unique_candidates))

In [14]:
# calculate similarity between predicted molecules and target spectra
import torch
import torch.nn.functional as F

def calculate_similarity_matrix(embedding_query, embedding_key):
    """Cosine similarity between every query row and every key row.

    Args:
        embedding_query: 2-D tensor, or array-like convertible to one, with
            one row per query.
        embedding_key: 2-D tensor, or array-like convertible to one, with one
            row per key.

    Returns:
        Tensor of shape (n_query, n_key) holding the cosine similarities.
    """
    # isinstance also accepts Tensor subclasses (e.g. nn.Parameter), which an
    # exact type comparison would needlessly re-wrap; as_tensor avoids a copy
    # where possible.
    if not isinstance(embedding_query, torch.Tensor):
        embedding_query = torch.as_tensor(embedding_query)
    if not isinstance(embedding_key, torch.Tensor):
        embedding_key = torch.as_tensor(embedding_key)

    # With unit-length rows, the matrix product yields cosine similarities.
    embedding_query = F.normalize(embedding_query, p=2, dim=1)
    embedding_key = F.normalize(embedding_key, p=2, dim=1)

    similarity_matrix = torch.matmul(embedding_query, embedding_key.t())
    return similarity_matrix


class TestDataset(torch.utils.data.Dataset):
    """Index-aligned pairing of target spectra with candidate sequences."""

    def __init__(self, tgt_spectra, pred_sequence):
        self.tgt_spectra, self.pred_sequence = tgt_spectra, pred_sequence

    def __len__(self):
        # The dataset size follows the spectra container.
        n_items = len(self.tgt_spectra)
        return n_items

    def __getitem__(self, idx):
        # Returns a (spectrum, candidate sequence) tuple for position `idx`.
        spectrum = self.tgt_spectra[idx]
        candidate = self.pred_sequence[idx]
        return spectrum, candidate

class TestCollator:
    """Turns a list of (spectrum, candidate sequence) pairs into one batch.

    Relies on the notebook-level `device` for tensor placement.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        spec_items, text_items = zip(*batch)

        # Stack the spectra and insert a singleton axis at position 1.
        stacked = torch.as_tensor(np.array(spec_items), dtype=torch.float32)
        stacked = stacked.unsqueeze(1).to(device)

        # Tokenize with fixed-length padding/truncation to 256 tokens.
        encoded = self.tokenizer(list(text_items), return_tensors='pt', padding='max_length', max_length=256, truncation=True)
        tokens = {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}

        return {'sequence': tokens, 'spectra': stacked}

    
# Similarity of every candidate sequence to the spectrum it was generated for.
valid_sim_list = []

test_dataset = TestDataset(candidate_spectra_list, candidate_sequence_list)
test_collator = TestCollator(tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=test_collator)
test_bar = tqdm(test_loader)

model.eval()
for batch in test_bar:
    with torch.no_grad():
        molecular_embedding = model.get_molecular_embeddings(batch, use_cls_token=True)['proj_output']
        spectral_embedding = model.get_spectral_embeddings(batch)['proj_output']
        sim = calculate_similarity_matrix(spectral_embedding, molecular_embedding)
    # Only the diagonal is kept: entry (i, i) pairs spectrum i with the
    # candidate at the same position in the batch.
    valid_sim_list.extend(sim.diagonal().tolist())

100%|██████████| 232/232 [00:31<00:00,  7.44it/s]


In [None]:
# Strip the '-' separators so targets are comparable with the decoded candidates.
tgt_sequence_list = [seq.replace('-', '') for seq in tgt_sequence_list]
df = pd.DataFrame({'target_sequence':tgt_sequence_list, 'pred_sequence':candidate_sequence_list, 'similarity':valid_sim_list})

# Within each target, order candidates from most to least similar.
df_sorted = df.sort_values(by=['target_sequence', 'similarity'], ascending=[True, False])

# Collapse to one row per target, storing the ranked candidates and their
# similarities as comma-joined strings.
# NOTE(review): grouping by the target string would merge test samples that
# share the same sequence — confirm targets are unique in the test set.
df = df_sorted.groupby('target_sequence').agg({
    'pred_sequence': lambda values: ','.join(values),
    'similarity': lambda values: ','.join(str(value) for value in values)
}).reset_index()

def _rank_of_target(target, joined_candidates):
    # 1-based position of the target among the ranked candidates; 0 when absent.
    ranked = joined_candidates.split(',')
    return ranked.index(target) + 1 if target in ranked else 0

def _modification_of(seq):
    if 'p' in seq:
        return 'phosphorylated'
    return 'sulfated' if 's' in seq else 'unmodified'

for top_k in [1, 3, 5, 10, 100]:
    df[f'top_{top_k}_recall'] = [
        target in joined.split(',')[:top_k]
        for target, joined in zip(df['target_sequence'], df['pred_sequence'])
    ]

df['rank'] = [_rank_of_target(target, joined) for target, joined in zip(df['target_sequence'], df['pred_sequence'])]
df['mod'] = [_modification_of(target) for target in df['target_sequence']]

print(f"""
recall@1:\t{df.top_1_recall.mean():.5f} 
recall@3:\t{df.top_3_recall.mean():.5f} 
recall@5:\t{df.top_5_recall.mean():.5f} 
recall@10:\t{df.top_10_recall.mean():.5f}
recall@100:\t{df.top_100_recall.mean():.5f}
    """)


recall@1:	0.27997 
recall@3:	0.41179 
recall@5:	0.42732 
recall@10:	0.43011
recall@100:	0.43011
    


In [16]:
# Per-modification means of the numeric columns. The first positional
# parameter of GroupBy.mean is `numeric_only`, so mean('top_1_recall') never
# selected a column; it only worked because a non-empty string is truthy.
df.groupby('mod').mean(numeric_only=True)

Unnamed: 0_level_0,top_1_recall,top_3_recall,top_5_recall,top_10_recall,top_100_recall,rank
mod,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
normal,0.272608,0.418087,0.435125,0.439056,0.439056,0.70249
phosphorylation,0.289367,0.41319,0.423957,0.425303,0.425303,0.621803
sulfation,0.278607,0.40597,0.423881,0.426866,0.426866,0.654726
