In [3]:
import pandas as pd
from matchms.importing import load_from_mgf
from rdkit import Chem

from collections import defaultdict
import pandas as pd
from tqdm import tqdm
import logging

import torch
from torch.utils.data import Dataset, DataLoader
import random

from dreams.algorithms.murcko_hist.murcko_hist import murcko_hist

In [6]:
def load_mgf_with_folds(mgf_path):
    """Read an MGF file into a DataFrame with one row per spectrum.

    Every spectrum yielded by ``load_from_mgf`` is converted with its
    ``to_dict()`` method and the resulting records are stacked into a frame.

    Args:
        mgf_path: Path of the MGF file, passed straight to ``load_from_mgf``.

    Returns:
        pandas.DataFrame built from the per-spectrum dictionaries.

    Raises:
        ValueError: If the resulting frame has no ``'fold'`` column.
    """
    rows = [spectrum.to_dict() for spectrum in load_from_mgf(mgf_path)]
    frame = pd.DataFrame(rows)

    # The downstream cells select rows by fold, so fail early without it.
    if 'fold' in frame.columns:
        return frame

    raise ValueError("fold column is missing. Ensure the dataset has been split into train/val/test.")

def extract_training_set(df):
    """Return the rows whose ``'fold'`` value is ``'train'``.

    The selected rows are re-indexed from 0 and the resulting row count is
    printed.

    Args:
        df: DataFrame containing a ``'fold'`` column.

    Returns:
        A new DataFrame holding only the training rows.
    """
    is_train = df['fold'].eq('train')
    train_rows = df.loc[is_train].reset_index(drop=True)
    n_train = len(train_rows)
    print(f"Training set size: {n_train} spectra")
    return train_rows

def get_unique_smiles(df_train, smiles_col='smiles'):
    """Keep the first row seen for every distinct value of ``smiles_col``.

    The original index labels are preserved and the result is an independent
    copy of the selected rows. The number of distinct values is printed.

    Args:
        df_train: DataFrame containing ``smiles_col``.
        smiles_col: Name of the column to de-duplicate on.

    Returns:
        A copied DataFrame with one row per distinct ``smiles_col`` value.
    """
    repeated = df_train.duplicated(subset=[smiles_col])
    first_seen = df_train.loc[~repeated].copy()
    n_unique = first_seen[smiles_col].nunique()
    print(f"Number of unique SMILES in training set: {n_unique}")
    return first_seen

def compute_murcko_histograms(df_us, smiles_col='smiles'):
    """Attach a Murcko histogram and its string form to every row of ``df_us``.

    Two columns are added to ``df_us`` in place: ``'murckohist'`` (the value
    returned by ``murcko_hist`` for the parsed molecule, or ``{}`` when the
    SMILES cannot be parsed) and ``'murckohiststr'`` (``str`` of that value,
    used as a grouping key). Summary statistics are printed.

    Args:
        df_us: DataFrame with one row per SMILES; it is modified in place.
        smiles_col: Name of the column holding the SMILES strings.

    Returns:
        The same ``df_us`` object, with the two columns added.
    """

    def _histogram_for(smiles):
        # Parse each SMILES exactly once; the original parsed it twice (once
        # for the validity check and once for the histogram call).
        mol = Chem.MolFromSmiles(smiles)
        # MolFromSmiles signals a parse failure by returning None.
        return murcko_hist(mol) if mol is not None else {}

    print("Computing Murcko histograms...")
    # Registers ``progress_apply`` on pandas objects.
    tqdm.pandas()
    df_us['murckohist'] = df_us[smiles_col].progress_apply(_histogram_for)

    # Convert dictionaries to strings for easier handling
    # NOTE(review): the string depends on key insertion order; this assumes
    # murcko_hist emits keys in a consistent order - confirm.
    df_us['murckohiststr'] = df_us['murckohist'].astype(str)

    print('Number of unique SMILES:', df_us[smiles_col].nunique(),
          'Number of unique Murcko histograms:', df_us['murckohiststr'].nunique())

    print('Top 20 most common Murcko histograms:')
    print(df_us['murckohiststr'].value_counts().head(20))

    return df_us

def group_by_murcko_histograms(df_us, smiles_col='smiles'):
    """Group rows by their Murcko-histogram string.

    Args:
        df_us: DataFrame containing ``smiles_col`` and ``'murckohiststr'``.
        smiles_col: Name of the column holding the SMILES strings.

    Returns:
        DataFrame with one row per distinct ``'murckohiststr'``, sorted by
        descending ``'count'``, with columns ``'murckohiststr'``, ``'count'``
        (non-null ``smiles_col`` values in the group), ``'smiles_list'`` (those
        values as a list) and ``'murckohist'`` (the string parsed back into a
        Python literal).

    Raises:
        ValueError / SyntaxError: If a ``'murckohiststr'`` value is not a
            plain Python literal.
    """
    # Local import keeps this block self-contained.
    import ast

    df_gb = df_us.groupby('murckohiststr').agg(
        count=(smiles_col, 'count'),
        smiles_list=(smiles_col, list)
    ).reset_index()

    # literal_eval only accepts literals, so a malformed or hostile string
    # cannot execute code the way the previous ``eval`` call could.
    df_gb['murckohist'] = df_gb['murckohiststr'].apply(ast.literal_eval)

    df_gb = df_gb.sort_values('count', ascending=False).reset_index(drop=True)

    print(f"Grouped into {len(df_gb)} Murcko histogram groups.")
    print(df_gb.head())

    return df_gb

def merge_murcko_hist(df_train, df_us_train, smiles_col='smiles'):
    """Left-join the ``'murckohiststr'`` column onto ``df_train``.

    Args:
        df_train: DataFrame with one row per spectrum, containing ``smiles_col``.
        df_us_train: DataFrame containing ``smiles_col`` and ``'murckohiststr'``.
        smiles_col: Name of the join column present in both frames.

    Returns:
        A new DataFrame with the rows of ``df_train`` plus a ``'murckohiststr'``
        column; rows without a match receive the string ``'{}'``.
    """
    # Join on the caller-supplied column (previously hard-coded to 'smiles'),
    # and keep one lookup row per key so the left join cannot multiply rows.
    lookup = df_us_train[[smiles_col, 'murckohiststr']].drop_duplicates(subset=[smiles_col])
    merged = df_train.merge(lookup, on=smiles_col, how='left')

    # Count real matches before filling; after fillna every row is non-null,
    # which made the reported number meaningless.
    n_assigned = merged['murckohiststr'].notna().sum()

    # Handle any missing murckohiststr by assigning an empty dictionary string
    merged['murckohiststr'] = merged['murckohiststr'].fillna('{}')
    print(f"After merge, 'murckohiststr' assigned to {n_assigned} identifiers.")
    return merged

def create_mappings(df_train, hist_col='murckohiststr', identifier_col='identifier'):
    """Build lookups between row positions and histogram keys.

    Both mappings use 0-based row positions, so they agree with each other
    whatever index ``df_train`` carries. Previously the reverse mapping was
    keyed by index label, which only matched for a default ``RangeIndex``.

    Args:
        df_train: DataFrame containing ``hist_col``.
        hist_col: Name of the column holding the histogram key of each row.
        identifier_col: Unused; accepted for backward compatibility.

    Returns:
        Tuple ``(hist_to_indices, index_to_hist)`` where ``hist_to_indices`` is
        a ``defaultdict(list)`` mapping each key to the positions of its rows
        and ``index_to_hist`` is a dict mapping each position to its key.
    """
    hist_to_indices = defaultdict(list)
    index_to_hist = {}
    for position, hist in enumerate(df_train[hist_col]):
        hist_to_indices[hist].append(position)
        index_to_hist[position] = hist

    print(f"Total unique Murcko histograms: {len(hist_to_indices)}")
    return hist_to_indices, index_to_hist

In [5]:
# Location of the MGF file, relative to the notebook's working directory.
spectra_path = "../../data/data/MassSpecGym.mgf"
# Load every spectrum into one DataFrame; raises ValueError if no 'fold' column is present.
df = load_mgf_with_folds(spectra_path)
# Preview the first rows to check the parsed columns.
print(df.head())

             identifier                                         smiles  \
0  MassSpecGymID0000001  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
1  MassSpecGymID0000002  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
2  MassSpecGymID0000003  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
3  MassSpecGymID0000004  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
4  MassSpecGymID0000005  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   

         inchikey    formula precursor_formula parent_mass  precursor_mz  \
0  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
1  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
2  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
3  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
4  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   

   adduct instrument_type collision_energy   fold simulation_challenge  \
0  [M+H]+        Orbitra

In [7]:
# Keep only the rows whose fold is 'train'.
df_train = extract_training_set(df)

# One row per distinct SMILES, so each structure is processed only once.
df_us_train = get_unique_smiles(df_train)

# Adds the 'murckohist' and 'murckohiststr' columns to the unique-SMILES frame.
df_us_train = compute_murcko_histograms(df_us_train)

# One row per distinct histogram string, with its SMILES count and list.
df_gb_train = group_by_murcko_histograms(df_us_train)

# Attach 'murckohiststr' to every training row by joining on SMILES.
df_train = merge_murcko_hist(df_train, df_us_train)

Training set size: 194119 spectra
Number of unique SMILES in training set: 25046
Computing Murcko histograms...


100%|██████████| 25046/25046 [02:33<00:00, 163.46it/s] 


Number of unique SMILES: 25046 Number of unique Murcko histograms: 338
Top 20 most common Murcko histograms:
murckohiststr
{'0_1': 1, '1_0': 1, '1_1': 1}              3351
{}                                          2953
{'0_1': 2}                                  2948
{'0_0': 1}                                  2875
{'1_0': 2}                                  1626
{'0_1': 2, '0_2': 1}                        1436
{'1_0': 2, '2_0': 2}                        1134
{'1_0': 2, '2_0': 1}                        1096
{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}     751
{'0_1': 2, '1_1': 2}                         545
{'0_1': 1, '1_0': 1, '1_1': 1, '2_0': 1}     495
{'0_1': 2, '1_0': 1, '1_2': 1}               481
{'1_0': 2, '1_1': 2}                         456
{'0_1': 3}                                   341
{'1_0': 2, '2_0': 3}                         288
{'0_1': 2, '0_2': 2}                         281
{'0_1': 1, '1_0': 1, '1_1': 1, '2_0': 2}     246
{'0_1': 2, '0_2': 1, '1_1': 2}              

In [8]:
# Number of distinct histogram strings across all training rows.
len(df_train['murckohiststr'].unique())

338

In [10]:
# Build the histogram -> row indices and row index -> histogram lookups.
hist_to_indices, index_to_hist = create_mappings(df_train)

Total unique Murcko histograms: 338


In [9]:
# List the columns available on the merged training frame.
df_train.columns

Index(['identifier', 'smiles', 'inchikey', 'formula', 'precursor_formula',
       'parent_mass', 'precursor_mz', 'adduct', 'instrument_type',
       'collision_energy', 'fold', 'simulation_challenge', 'peaks_json',
       'murckohiststr'],
      dtype='object')