# CSU-MS2 cross-modal retrieval
### importing required libraries

In [1]:
# allow to import modules from the project root directory
import sys
import os

# Put the parent of the current working directory on the import path so
# modules that live one level up can be imported from this notebook.
sys.path.append(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir))
)

In [39]:
import bisect
from infer import ModelInference
import numpy as np
import torch
from tqdm import tqdm
from matchms.importing import load_from_mgf
from matchms.Fragments import Fragments
import pandas as pd
import IsoSpecPy
import json
import os
import requests
import pubchempy as pc
from bs4 import BeautifulSoup
import matchms.filtering as msfilters

### define spectrum processing function and reference dataset searching function

In [34]:
def spectrum_processing(s):
    """Run the spectrum through the two matchms filters used by this notebook.

    The spectrum is first passed to ``msfilters.normalize_intensities`` and
    the result is then passed to ``msfilters.select_by_mz`` with the window
    ``mz_from=0``, ``mz_to=1500``. Whatever the second filter returns is
    returned unchanged.
    """
    normalized = msfilters.normalize_intensities(s)
    return msfilters.select_by_mz(normalized, mz_from=0, mz_to=1500)

def search_formula(formulaDB, mass, ppm):
    """Return the formulas whose exact mass lies within ``ppm`` of ``mass``.

    ``formulaDB['Exact mass']`` is searched with ``bisect``, so it has to be
    sorted in ascending order; ``formulaDB['Formula']`` is sliced with the
    resulting positions. Both window bounds are inclusive.
    """
    tolerance = mass*ppm/10**6
    exact_masses = formulaDB['Exact mass']
    first = bisect.bisect_left(exact_masses, mass - tolerance)
    last = bisect.bisect_right(exact_masses, mass + tolerance)
    return list(formulaDB['Formula'][first:last])

def search_structure(structureDB, formulas):
    """Collect the rows of ``structureDB`` that match any of ``formulas``.

    Args:
        structureDB: DataFrame with a ``'MolecularFormula'`` column.
        formulas: iterable of formula strings to look up.

    Returns:
        A DataFrame with the matching rows, grouped in the order the formulas
        were given (a formula listed twice contributes its rows twice). The
        original index labels are kept. When ``formulas`` is empty, an empty
        DataFrame is returned.
    """
    matches = [
        structureDB[structureDB['MolecularFormula'] == formula]
        for formula in formulas
    ]
    # pd.concat raises on an empty list, so handle "nothing to look up" here.
    if not matches:
        return pd.DataFrame()
    # Concatenate once: growing a frame inside the loop re-copies every
    # previously collected row on each iteration.
    return pd.concat(matches)

def search_structure_from_mass(structureDB, mass, ppm, filter_by_mass=False):
    """Select candidate structures for a query mass.

    By default the mass window is NOT applied: ``structureDB`` is returned
    unchanged and ``mass`` / ``ppm`` are ignored, so every entry of the
    library is treated as a candidate. Pass ``filter_by_mass=True`` to keep
    only rows whose ``'MonoisotopicMass'`` lies within ``ppm`` parts per
    million of ``mass`` (both bounds inclusive).

    Args:
        structureDB: DataFrame of candidate structures. The
            ``'MonoisotopicMass'`` column is only required when
            ``filter_by_mass`` is true.
        mass: query mass.
        ppm: mass tolerance in parts per million.
        filter_by_mass: opt-in switch for the mass-window filter.

    Returns:
        ``structureDB`` itself when the filter is off, otherwise the filtered
        subset of its rows.
    """
    if not filter_by_mass:
        return structureDB

    tolerance = mass * ppm / 10**6
    monoisotopic = structureDB['MonoisotopicMass']
    in_window = (monoisotopic >= mass - tolerance) & (monoisotopic <= mass + tolerance)
    return structureDB[in_window]

### define structural feature vector calculation function

In [63]:
# Original code
# def get_feature(lst,save_name,model_inference,
#                 n=256,flag_get_value=False):

#     print("Size of the library: ", len(lst))
#     fn = model_inference.smiles_encode
#     contexts = []


#     print("start load batch")
#     for i in range(0, len(lst), n):
#        contexts.append(lst[i:i + n])

#     print("start encode batch")
#     result = [fn(i).cpu() for i in tqdm(contexts)]
#     result = torch.cat(result, 0)
#     if flag_get_value is True:
#         if save_name is not None:
#             torch.save((result, lst), save_name)
#         return result, lst

def get_feature(lst, save_name, model_inference, n=256, flag_get_value=False):
    """Encode ``lst`` in batches of ``n`` with ``model_inference.smiles_encode``.

    Each batch result is moved to the CPU and all of them are concatenated
    along dimension 0. Progress is printed roughly every tenth of the
    batches and after the final batch.

    Args:
        lst: sliceable sequence of inputs for the encoder.
        save_name: when not ``None`` (and ``flag_get_value`` is ``True``),
            ``(features, lst)`` is written there with ``torch.save``.
        model_inference: object exposing ``smiles_encode(batch)``.
        n: batch size.
        flag_get_value: must be exactly ``True`` for anything to be saved
            or returned.

    Returns:
        ``(features, lst)`` when ``flag_get_value is True``, else ``None``.
        NOTE(review): in the ``None`` case the features are computed and then
        discarded, even if ``save_name`` was given -- confirm this is intended.
    """
    print("Size of the library: ", len(lst))
    encode = model_inference.smiles_encode

    print("start load batch")
    batches = [lst[start:start + n] for start in range(0, len(lst), n)]

    print("start encode batch")
    total_batches = len(batches)
    print("Total_batches: ", total_batches)

    # Report about every 10% of the batches, and always after the last one.
    report_every = max(1, total_batches // 10)
    encoded = []
    for done, batch in enumerate(batches, start=1):
        encoded.append(encode(batch).cpu())
        if done % report_every == 0 or done == total_batches:
            progress = (done / total_batches) * 100
            print(f"Progress: {done}/{total_batches} batches ({progress:.1f}%)")

    result = torch.cat(encoded, 0)
    if flag_get_value is not True:
        return None
    if save_name is not None:
        torch.save((result, lst), save_name)
    return result, lst


def get_topK_result(library, ms_feature, smiles_feature, topK):
    """Return the ``topK`` best-scoring library entries for the first query row.

    Scores are the matrix product ``ms_feature @ smiles_feature.t()``; larger
    is better. ``topK`` is capped at ``len(library)``. Only row 0 of the
    score matrix is reported.

    Returns:
        ``(indices, scores, candidates)``, each a list holding a single inner
        list ordered from best to worst score.
    """
    k = min(topK, len(library))
    with torch.no_grad():
        similarity = (ms_feature @ smiles_feature.t()).cpu()
        top_scores, top_indices = similarity.topk(k, dim=1, largest=True, sorted=True)

    first_row_indices = top_indices.tolist()[0]
    first_row_scores = top_scores.tolist()[0]
    first_row_candidates = [library[position] for position in first_row_indices]
    return [first_row_indices], [first_row_scores], [first_row_candidates]

### loading model

In [9]:
def _load_hcd_model(energy_level):
    """Build the checkpoint paths for one energy level and load its model on CPU."""
    checkpoint_dir = f"model/hcd_model/{energy_level}_energy/checkpoints"
    config_path = checkpoint_dir + "/config.yaml"
    weights_path = checkpoint_dir + "/model.pth"
    model = ModelInference(config_path=config_path,
                           pretrain_model_path=weights_path,
                           device="cpu")
    return config_path, weights_path, model


# One model per energy level, loaded in the order low -> median -> high.
config_path_low, pretrain_model_path_low, model_inference_low = _load_hcd_model("low")
config_path_median, pretrain_model_path_median, model_inference_median = _load_hcd_model("median")
config_path_high, pretrain_model_path_high, model_inference_high = _load_hcd_model("high")

### create result save file

In [10]:
# Directory that receives one ranking CSV per query spectrum.
output_file='results/'
# exist_ok=True keeps this cell re-runnable: a plain os.mkdir raises
# FileExistsError once the directory is already there.
os.makedirs(output_file, exist_ok=True)

### upload spectra file, reference dataset file and collision energy file

In [80]:
# Query spectra: every record yielded by the MGF reader, materialised as a list.
ms_list = [*load_from_mgf("data/test_spectrum.mgf")]
# Candidate structures; the retrieval loop reads its 'SMILES' column.
reference_library = pd.read_csv('../../notebooks/output.csv')
# Optional per-spectrum collision energies from a separate file (disabled):
# collision_energy_file=pd.read_csv('collision_energy.csv')
# collision_energy_lst = list(collision_energy_file['CollisionEnergy'])

KeyboardInterrupt: 

 ### perform cross-modal retrieval for a list of spectra

In [62]:
# For every query spectrum: encode it with the low / median / high energy
# models, score it against the candidate structures, blend the three score
# vectors with collision-energy-dependent weights and write the ranked
# candidates to '<output_file>spectrum<i>.csv'.
#
# The candidate embeddings depend only on the candidate SMILES list, so they
# are re-encoded only when that list differs from the one used for the
# previous spectrum instead of on every iteration.
# NOTE(review): this assumes smiles_encode is deterministic for a given input.
cached_smiles_lst = None

for i, spectrum in enumerate(tqdm(ms_list)):
    spectrum = spectrum_processing(spectrum)
    ms_feature_low = model_inference_low.ms2_encode([spectrum])
    ms_feature_median = model_inference_median.ms2_encode([spectrum])
    ms_feature_high = model_inference_high.ms2_encode([spectrum])

    # The precursor mass and collision energy can also be set manually.
    # NOTE(review): subtracting 1.008 presumably assumes a singly protonated
    # precursor -- confirm for other adducts.
    query_ms = float(spectrum.metadata['precursor_mz'])-1.008
    collision_energy = int(spectrum.metadata.get('collision_energy', 20))
    # collision_energy = collision_energy_lst[i]
    search_res = search_structure_from_mass(reference_library, query_ms, 30)
    smiles_lst = list(search_res['SMILES'])

    if smiles_lst != cached_smiles_lst:
        print("smiles_feature_low prediction:")
        smiles_feature_low, smiles_lst1 = get_feature(
            smiles_lst, save_name=None,
            model_inference=model_inference_low, n=256, flag_get_value=True)

        print("smiles_feature_median prediction:")
        smiles_feature_median, smiles_lst2 = get_feature(
            smiles_lst, save_name=None,
            model_inference=model_inference_median, n=256, flag_get_value=True)

        print("smiles_feature_high prediction:")
        smiles_feature_high, smiles_lst3 = get_feature(
            smiles_lst, save_name=None,
            model_inference=model_inference_high, n=256, flag_get_value=True)

        cached_smiles_lst = smiles_lst

    low_similarity = (ms_feature_low @ smiles_feature_low.t()).numpy()
    median_similarity = (ms_feature_median @ smiles_feature_median.t()).numpy()
    high_similarity = (ms_feature_high @ smiles_feature_high.t()).numpy()

    # Inverse-distance weights relative to the energies 10 / 20 / 40 used for
    # the low / median / high scores; the 1e-10 offset avoids dividing by
    # zero when the collision energy equals one of them.
    inverse_low = 1/abs(collision_energy-10+1e-10)
    inverse_median = 1/abs(collision_energy-20+1e-10)
    inverse_high = 1/abs(collision_energy-40+1e-10)
    inverse_total = inverse_low + inverse_median + inverse_high
    weight1 = inverse_low/inverse_total
    weight2 = inverse_median/inverse_total
    weight3 = inverse_high/inverse_total

    weighted_similarity = weight1 * low_similarity + weight2 * median_similarity + weight3 * high_similarity
    weighted_similarity = np.squeeze(weighted_similarity, axis=0)

    # Pair each candidate with its blended score and rank best-first.
    weighted_similarity_scores = [
        (smiles_lst1[j], weighted_similarity[j]) for j in range(len(smiles_lst1))
    ]
    weighted_similarity_scores.sort(key=lambda x: x[1], reverse=True)
    results = pd.DataFrame({
        'SMILES': [x[0] for x in weighted_similarity_scores],
        'Score': [x[1] for x in weighted_similarity_scores],
        'Rank': list(range(1, len(smiles_lst1)+1)),
    })
    results.to_csv(output_file+'spectrum'+str(i)+'.csv')


[A

smiles_feature_low prediction:
Size of the library:  1000
start encode batch







Encoding SMILES:   0%|          | 0/1000 [00:06<?, ?it/s]
  0%|          | 0/1 [00:08<?, ?it/s]


KeyboardInterrupt: 

In [55]:
# Preview the first rows of `results`, i.e. the ranking produced by the most
# recently completed iteration of the retrieval loop.
print(results.head())

                                              SMILES     Score  Rank
0  C=C[C@@]1(C)CC(=O)[C@]2(O)[C@@]3(C)[C@@H](O)CC...  0.667050     1
1  CO/C1=C\C(C)=C\[C@@H](C)[C@@H](O)[C@@H](C)C/C(...  0.665769     2
2  C=C1[C@@H](O)[C@@H]2O[C@]3(CC[C@H](/C=C/[C@@H]...  0.661226     3
3  CC[C@@]1([C@@H]2O[C@@H]([C@H]3O[C@@](O)(CO)[C@...  0.649401     4
4  CO[C@@H]1C[C@@H](C[C@H]2CC[C@H](C)[C@H]([C@@H]...  0.642349     5


In [77]:
from rdkit import Chem
from rdkit.Chem import DataStructs
from rdkit.Chem import rdMolDescriptors
import pandas as pd

def compute_tanimoto_similarity(target_smiles, results_df):
    """Annotate ``results_df`` with Tanimoto similarity to ``target_smiles``.

    Both the target and every entry of ``results_df['SMILES']`` are parsed
    with RDKit and turned into Morgan bit-vector fingerprints (radius 2,
    2048 bits). Candidates that cannot be parsed receive a similarity of 0.0.

    Returns:
        A copy of ``results_df`` with an added ``'Tanimoto_Similarity'``
        column. If the target itself cannot be parsed, a warning is printed
        and ``results_df`` is returned as-is, without the extra column.
    """
    def _fingerprint(molecule):
        return rdMolDescriptors.GetMorganFingerprintAsBitVect(molecule, 2, nBits=2048)

    reference_mol = Chem.MolFromSmiles(target_smiles)
    if reference_mol is None:
        print(f"Warning: Could not parse target SMILES: {target_smiles}")
        return results_df
    reference_fp = _fingerprint(reference_mol)

    def _similarity_to_reference(candidate_smiles):
        candidate_mol = Chem.MolFromSmiles(candidate_smiles)
        if candidate_mol is None:
            # Unparseable candidate: score it as completely dissimilar.
            return 0.0
        return DataStructs.TanimotoSimilarity(reference_fp, _fingerprint(candidate_mol))

    similarities = [_similarity_to_reference(entry) for entry in results_df['SMILES']]

    # Work on a copy so the caller's DataFrame is left untouched.
    annotated = results_df.copy()
    annotated['Tanimoto_Similarity'] = similarities
    return annotated

In [78]:
# Usage example:
spectrum = ms_list[0]
target_smiles = spectrum.metadata['smiles']
results_with_tanimoto = compute_tanimoto_similarity(target_smiles, results)

# You can also sort by Tanimoto similarity
results_sorted_by_tanimoto = results_with_tanimoto.sort_values('Tanimoto_Similarity', ascending=False)
print(results_sorted_by_tanimoto.head())

                                                SMILES     Score  Rank  \
791              NC(=O)c1ccccc1Nc1ccnc(Nc2cccc(O)c2)n1 -0.089166   792   
679    C[C@@H](Oc1ccc(Oc2ncc(C(F)(F)F)cc2Cl)cc1)C(=O)O -0.036087   680   
786                         O=C(O)C(=O)Nc1ccccc1C(=O)O -0.085290   787   
568  CC(C)[C@H](CO)Nc1nc(Nc2ccc(C(=O)O)c(Cl)c2)c2nc...  0.018972   569   
374  O=C(NC(=O)c1c(F)cccc1F)Nc1ccc(Cl)c(Oc2ncc(C(F)...  0.138712   375   

     Tanimoto_Similarity  
791             0.301587  
679             0.265625  
786             0.265306  
568             0.263158  
374             0.260274  


In [79]:
# Display the full table sorted by Tanimoto similarity (notebook rich output).
results_sorted_by_tanimoto

Unnamed: 0,SMILES,Score,Rank,Tanimoto_Similarity
791,NC(=O)c1ccccc1Nc1ccnc(Nc2cccc(O)c2)n1,-0.089166,792,0.301587
679,C[C@@H](Oc1ccc(Oc2ncc(C(F)(F)F)cc2Cl)cc1)C(=O)O,-0.036087,680,0.265625
786,O=C(O)C(=O)Nc1ccccc1C(=O)O,-0.085290,787,0.265306
568,CC(C)[C@H](CO)Nc1nc(Nc2ccc(C(=O)O)c(Cl)c2)c2nc...,0.018972,569,0.263158
374,O=C(NC(=O)c1c(F)cccc1F)Nc1ccc(Cl)c(Oc2ncc(C(F)...,0.138712,375,0.260274
...,...,...,...,...
599,COCCOCCOCCOCCOCCOC,0.005190,600,0.000000
967,COCCOCCOCCOC,-0.247590,968,0.000000
968,Cl[Cu]Cl,-0.249652,969,0.000000
592,CCOCCOCCOCCOCCOCCOCC,0.007715,593,0.000000
