In [64]:
import pandas as pd
from matchms.importing import load_from_mgf
from rdkit import Chem

from collections import defaultdict
import pandas as pd
from tqdm import tqdm
import logging

import torch
from torch.utils.data import Dataset, DataLoader
import random

from dreams.algorithms.murcko_hist.murcko_hist import murcko_hist

# This is creation on dataset for on positive and multiple negative example based on Murcko Histogram, refined for InfoNCE loss

In [75]:
def load_mgf_with_folds(mgf_path):

    spectra = list(load_from_mgf(mgf_path))
    records = []
    for spec in spectra:
        record = spec.to_dict()
        records.append(record)
    df = pd.DataFrame(records)

    if 'fold' not in df.columns:
        raise ValueError("fold column is missing. Ensure the dataset has been split into train/val/test.")
    
    return df


In [76]:
spectra_path = "../../data/data/MassSpecGym.mgf"
df = load_mgf_with_folds(spectra_path)
print(df.head())

             identifier                                         smiles  \
0  MassSpecGymID0000001  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
1  MassSpecGymID0000002  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
2  MassSpecGymID0000003  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
3  MassSpecGymID0000004  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
4  MassSpecGymID0000005  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   

         inchikey    formula precursor_formula parent_mass  precursor_mz  \
0  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
1  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
2  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
3  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   
4  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4  287.115224      288.1225   

   adduct instrument_type collision_energy   fold simulation_challenge  \
0  [M+H]+        Orbitra

In [77]:
def extract_training_set(df):

    df_train = df[df['fold'] == 'train'].reset_index(drop=True)
    print(f"Training set size: {len(df_train)} spectra")
    return df_train

df_train = extract_training_set(df)

Training set size: 194119 spectra


In [78]:
def get_unique_smiles(df_train, smiles_col='smiles'):

    df_us = df_train.drop_duplicates(subset=[smiles_col]).copy()
    print(f"Number of unique SMILES in training set: {df_us[smiles_col].nunique()}")
    return df_us

df_us_train = get_unique_smiles(df_train)

Number of unique SMILES in training set: 25046


In [80]:
def compute_murcko_histograms(df_us, smiles_col='smiles'):

    print("Computing Murcko histograms...")
    tqdm.pandas()
    df_us['murckohist'] = df_us[smiles_col].progress_apply(
        lambda x: murcko_hist(Chem.MolFromSmiles(x)) if Chem.MolFromSmiles(x) else {}
    )
    
    # Convert dictionaries to strings for easier handling
    df_us['murckohiststr'] = df_us['murckohist'].astype(str)
    
    print('Number of unique SMILES:', df_us[smiles_col].nunique(), 
          'Number of unique Murcko histograms:', df_us['murckohiststr'].nunique())
    
    print('Top 20 most common Murcko histograms:')
    print(df_us['murckohiststr'].value_counts().head(20))
    
    return df_us

df_us_train = compute_murcko_histograms(df_us_train)

Computing Murcko histograms...


100%|██████████| 25046/25046 [02:37<00:00, 158.58it/s] 

Number of unique SMILES: 25046 Number of unique Murcko histograms: 338
Top 20 most common Murcko histograms:
murckohiststr
{'0_1': 1, '1_0': 1, '1_1': 1}              3351
{}                                          2953
{'0_1': 2}                                  2948
{'0_0': 1}                                  2875
{'1_0': 2}                                  1626
{'0_1': 2, '0_2': 1}                        1436
{'1_0': 2, '2_0': 2}                        1134
{'1_0': 2, '2_0': 1}                        1096
{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}     751
{'0_1': 2, '1_1': 2}                         545
{'0_1': 1, '1_0': 1, '1_1': 1, '2_0': 1}     495
{'0_1': 2, '1_0': 1, '1_2': 1}               481
{'1_0': 2, '1_1': 2}                         456
{'0_1': 3}                                   341
{'1_0': 2, '2_0': 3}                         288
{'0_1': 2, '0_2': 2}                         281
{'0_1': 1, '1_0': 1, '1_1': 1, '2_0': 2}     246
{'0_1': 2, '0_2': 1, '1_1': 2}              




In [81]:
def group_by_murcko_histograms(df_us, smiles_col='smiles'):

    df_gb = df_us.groupby('murckohiststr').agg(
        count=(smiles_col, 'count'),
        smiles_list=(smiles_col, list)
    ).reset_index()

    df_gb['murckohist'] = df_gb['murckohiststr'].apply(eval)

    df_gb = df_gb.sort_values('count', ascending=False).reset_index(drop=True)
    
    print(f"Grouped into {len(df_gb)} Murcko histogram groups.")
    print(df_gb.head())
    
    return df_gb

df_gb_train = group_by_murcko_histograms(df_us_train)

Grouped into 338 Murcko histogram groups.
                    murckohiststr  count  \
0  {'0_1': 1, '1_0': 1, '1_1': 1}   3351   
1                              {}   2953   
2                      {'0_1': 2}   2948   
3                      {'0_0': 1}   2875   
4                      {'1_0': 2}   1626   

                                         smiles_list  \
0  [C1=CC=C(C=C1)C2=C(C(=O)NC3=CC=CC=C32)O, CN1C(...   
1  [CC(=C)C(=O)/C(=C/C(=O)O)/OC, CC[C@@H](C)[C@H]...   
2  [CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC...   
3  [C[C@@H]1C[C@H]2[C@H](O2)/C=C\C(=O)CC(=O)O1, C...   
4  [C[C@H]1CCCC(=O)CCC/C=C/C2=C(C(=CC(=C2)O)O)C(=...   

                       murckohist  
0  {'0_1': 1, '1_0': 1, '1_1': 1}  
1                              {}  
2                      {'0_1': 2}  
3                      {'0_0': 1}  
4                      {'1_0': 2}  


In [82]:
def merge_murcko_hist(df_train, df_us_train, smiles_col='smiles'):

    df_train = df_train.merge(df_us_train[['smiles', 'murckohiststr']], on='smiles', how='left')
    # Handle any missing murckohiststr by assigning an empty dictionary string
    df_train['murckohiststr'] = df_train['murckohiststr'].fillna('{}')
    print(f"After merge, 'murckohiststr' assigned to {df_train['murckohiststr'].notna().sum()} identifiers.")
    return df_train

df_train = merge_murcko_hist(df_train, df_us_train)

After merge, 'murckohiststr' assigned to 194119 identifiers.


In [88]:
len(df_train['murckohiststr'].unique())

338

In [89]:
def create_mappings(df_train, hist_col='murckohiststr', identifier_col='identifier'):

    hist_to_indices = defaultdict(list)
    for idx, hist in enumerate(df_train[hist_col]):
        hist_to_indices[hist].append(idx)
    
    index_to_hist = df_train[hist_col].to_dict()
    
    print(f"Total unique Murcko histograms: {len(hist_to_indices)}")
    return hist_to_indices, index_to_hist


hist_to_indices, index_to_hist = create_mappings(df_train)

Total unique Murcko histograms: 338


In [91]:
hist_to_indices

defaultdict(list,
            {"{'0_1': 2}": [0,
              1,
              2,
              3,
              4,
              5,
              6,
              7,
              8,
              9,
              10,
              463,
              464,
              465,
              466,
              467,
              468,
              654,
              655,
              656,
              657,
              658,
              659,
              660,
              661,
              662,
              663,
              752,
              753,
              754,
              755,
              756,
              757,
              758,
              759,
              760,
              761,
              1485,
              1486,
              1487,
              1488,
              1489,
              1490,
              1491,
              1492,
              1493,
              1666,
              1667,
              1668,
              1669,
              1670,
      

In [92]:
index_to_hist

{0: "{'0_1': 2}",
 1: "{'0_1': 2}",
 2: "{'0_1': 2}",
 3: "{'0_1': 2}",
 4: "{'0_1': 2}",
 5: "{'0_1': 2}",
 6: "{'0_1': 2}",
 7: "{'0_1': 2}",
 8: "{'0_1': 2}",
 9: "{'0_1': 2}",
 10: "{'0_1': 2}",
 11: "{'1_0': 2}",
 12: "{'1_0': 2}",
 13: "{'1_0': 2}",
 14: "{'1_0': 2}",
 15: "{'1_0': 2}",
 16: "{'1_0': 2}",
 17: "{'1_0': 2}",
 18: "{'1_0': 2}",
 19: "{'1_0': 2}",
 20: "{'1_0': 2}",
 21: "{'1_0': 2}",
 22: "{'1_0': 2}",
 23: "{'1_0': 2}",
 24: "{'1_0': 2}",
 25: "{'1_0': 2}",
 26: "{'1_0': 2}",
 27: "{'1_0': 2}",
 28: "{'1_0': 2}",
 29: "{'1_0': 2}",
 30: "{'1_0': 2}",
 31: "{'1_0': 2}",
 32: "{'1_0': 2}",
 33: "{'1_0': 2}",
 34: "{'1_0': 2}",
 35: "{'1_0': 2}",
 36: "{'1_0': 2}",
 37: "{'1_0': 2}",
 38: "{'1_0': 2}",
 39: "{'1_0': 2}",
 40: "{'1_0': 2}",
 41: "{'1_0': 2}",
 42: "{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}",
 43: "{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}",
 44: "{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}",
 45: "{'0_1': 1, '0_2': 1, '1_0': 1, '1_1': 1}",
 46: "{'0_1': 1,

In [95]:
df_train.columns()

TypeError: 'Index' object is not callable

In [96]:
if isinstance(df_train, pd.Series):
    df_train = df_train.to_frame().T  # Transpose if necessary

In [97]:
df_train.columns

Index(['identifier', 'smiles', 'inchikey', 'formula', 'precursor_formula',
       'parent_mass', 'precursor_mz', 'adduct', 'instrument_type',
       'collision_energy', 'fold', 'simulation_challenge', 'peaks_json',
       'murckohiststr'],
      dtype='object')

## Dataset

In [117]:


def generate_triplets_precomputed(
    df, 
    index_to_hist, 
    hist_to_indices, 
    num_positives=5, 
    num_negatives=5, 
    num_negative_lists=5
):
    """
    Precompute positive and negative indices for each row in the DataFrame.

    Args:
        df (pd.DataFrame): The training DataFrame.
        index_to_hist (dict): Mapping from row index to murckohiststr.
        hist_to_indices (dict): Mapping from murckohiststr to list of row indices.
        num_positives (int): Number of positive samples per row.
        num_negatives (int): Number of negative samples per list.
        num_negative_lists (int): Number of negative lists per row.

    Returns:
        pd.DataFrame: The original DataFrame with two new columns:
                      - 'positive_indices': List of positive indices.
                      - 'negative_indices': List of lists containing negative indices.
    """
    # Configure logging
    logging.basicConfig(level=logging.WARNING, format='%(levelname)s: %(message)s')

    # Create a copy of the DataFrame to avoid SettingWithCopyWarning
    df = df.copy()

    # Initialize new columns with empty lists
    df['positive_indices'] = [[] for _ in range(len(df))]
    df['negative_indices'] = [[] for _ in range(len(df))]

    # Precompute the set of all indices for negative sampling
    all_indices_set = set(df.index.tolist())

    # Convert hist_to_indices values to sets for efficient operations
    hist_to_indices_set = {hist: set(indices) for hist, indices in hist_to_indices.items()}

    # Iterate over each row index with progress bar
    for idx in tqdm(df.index, desc="Generating triplets"):
        # Get the histogram for the current index
        anchor_hist = index_to_hist.get(idx)

        if anchor_hist is None:
            logging.warning(f"No histogram found for index {idx}. Skipping this index.")
            continue  # Skip to the next index

        # --- Positive Sampling ---
        positive_candidates = hist_to_indices_set.get(anchor_hist, set()).copy()
        positive_candidates.discard(idx)  # Remove the anchor index if present

        if len(positive_candidates) >= num_positives:
            # Sample without replacement
            positives = random.sample(list(positive_candidates), num_positives)
        elif len(positive_candidates) > 0:
            # Sample with replacement if not enough positives
            positives = list(positive_candidates)
            additional_needed = num_positives - len(positives)
            positives += random.choices(list(positive_candidates), k=additional_needed)
        else:
            # No positive candidates available
            # Strategy: Assign an empty list and log a warning
            positives = []
            logging.warning(f"No positive samples available for index {idx}.")

        df.at[idx, 'positive_indices'] = positives

        # --- Negative Sampling ---
        # Negative pool: all indices not sharing the same histogram
        negative_pool = all_indices_set - hist_to_indices_set.get(anchor_hist, set())

        # Ensure there are enough negatives
        required_negatives = num_negatives * num_negative_lists
        if len(negative_pool) < required_negatives:
            logging.warning(
                f"Not enough negative samples for index {idx}. "
                f"Required: {required_negatives}, Available: {len(negative_pool)}. "
                f"Skipping negative sampling for this index."
            )
            continue  # Skip negative sampling but retain positives

        # Convert negative_pool to a list for sampling
        negative_pool = list(negative_pool)

        # Generate negative lists
        negative_lists = []
        for _ in range(num_negative_lists):
            negs = random.sample(negative_pool, num_negatives)
            negative_lists.append(negs)

        df.at[idx, 'negative_indices'] = negative_lists

    return df

In [118]:
# Generate the triplets
df_train = generate_triplets_precomputed(
    df=df_train,
    index_to_hist=index_to_hist,
    hist_to_indices=hist_to_indices,
    num_positives=5,
    num_negatives=5,
    num_negative_lists=5
)

# Inspect the DataFrame
print(df_train[['identifier', 'smiles', 'positive_indices', 'negative_indices']])

Generating triplets: 100%|██████████| 194119/194119 [15:59<00:00, 202.34it/s]


                  identifier  \
0       MassSpecGymID0000001   
1       MassSpecGymID0000002   
2       MassSpecGymID0000003   
3       MassSpecGymID0000004   
4       MassSpecGymID0000005   
...                      ...   
194114  MassSpecGymID0414159   
194115  MassSpecGymID0414160   
194116  MassSpecGymID0414161   
194117  MassSpecGymID0414162   
194118  MassSpecGymID0414163   

                                                   smiles  \
0           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
1           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
2           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
3           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
4           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
...                                                   ...   
194114  CC(=O)OC[C@@H]1[C@H](C(C([C@@H](O1)OC2=C(OC3=C...   
194115  C[C@@H]1[C@@H]([C@@H]([C@H]([C@@H](O1)OC2=CC(=...   
194116  CC1[C@@H](C([C@@H]([C@@H](O1)OC2[C@@H](C(O[C@H...   
19411

In [119]:

print(df_train[['identifier', 'smiles', 'positive_indices', 'negative_indices']])

                  identifier  \
0       MassSpecGymID0000001   
1       MassSpecGymID0000002   
2       MassSpecGymID0000003   
3       MassSpecGymID0000004   
4       MassSpecGymID0000005   
...                      ...   
194114  MassSpecGymID0414159   
194115  MassSpecGymID0414160   
194116  MassSpecGymID0414161   
194117  MassSpecGymID0414162   
194118  MassSpecGymID0414163   

                                                   smiles  \
0           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
1           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
2           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
3           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
4           CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
...                                                   ...   
194114  CC(=O)OC[C@@H]1[C@H](C(C([C@@H](O1)OC2=C(OC3=C...   
194115  C[C@@H]1[C@@H]([C@@H]([C@H]([C@@H](O1)OC2=CC(=...   
194116  CC1[C@@H](C([C@@H]([C@@H](O1)OC2[C@@H](C(O[C@H...   
19411