### This file is meant to preprocess the mgf data for contrastive fine-tuning

In [2]:
from dreams.utils.data import MSData, evaluate_split
import numpy as np
import pandas as pd
from dreams.api import dreams_embeddings
from sklearn.metrics.pairwise import cosine_similarity
import umap
from dreams.utils.mols import formula_type
import seaborn as sns
import matplotlib.pyplot as plt
from dreams.utils.plots import init_plotting
from matplotlib.colors import LinearSegmentedColormap
from tqdm import tqdm
from rdkit import Chem

In [None]:
data_full = MSData.from_mgf('data/mgf_MoNA_experimental.mgf', in_mem=False)
print(data_full)
print(data_full.columns())

### Get meaningful subset of dataset

In [None]:
total_samples = 10000
SAMPLE_DATASET = f"data/MoNA_experimental_{total_samples}.hdf5"

spectrum_types = data_full['SPECTRUM_TYPE']
spectrum_types = np.array(spectrum_types)
unique_types, type_counts = np.unique(spectrum_types, return_counts=True)

print("Spectrum Types and their counts:")
for type, count in zip(unique_types, type_counts):
    print(f"{type}: {count}")


type_proportions = type_counts / len(spectrum_types)

samples_per_type = np.round(type_proportions * total_samples).astype(int)

sampled_indices = []

for spectrum_type, n_samples in zip(unique_types, samples_per_type):
    type_indices = np.where(spectrum_types == spectrum_type)[0]
    
    # If we have fewer spectra of this type than we want to sample, take all of them
    if len(type_indices) <= n_samples:
        sampled_indices.extend(type_indices)
    else:
        sampled_indices.extend(np.random.choice(type_indices, size=n_samples, replace=False))


# Ensure we have exactly total_samples samples
sampled_indices = sampled_indices[:total_samples]
sampled_spectra = data_full.get_spectra()[sampled_indices]
print(f"Total sampled indices: {len(sampled_indices)}")
print("Print 3 random samples (sanity check):")
print_indices = np.random.choice(sampled_indices, size=3, replace=False)
for i in print_indices:
    print(data_full.at(i))

data_short = data_full.form_subset(sampled_indices, SAMPLE_DATASET)
del data_short
data_short = MSData(SAMPLE_DATASET, mode='a')
data_short.add_column("ORIGINAL_INDEX", sampled_indices)
print("\nSampled dataset saved!\n")

### Create constrastive dataset according to paper and ContrastiveSpectraDataset class in data.py line 999

In [None]:
data_short_pd = data_short.to_pandas()
data_short_pd.columns

In [25]:
# reduced parameters from paper due to computational constraints
n_spectra=8000
n_unique_inchi=2000 
mass_diff=0.05

# Apply filters: [M+H]+ adducts and 60 eV collision energy
# how to filter by collision energy?
contrastive_df = data_short_pd[data_short_pd['charge'] == '[M+H]+']

# Group by first 14 characters of InChIKey
contrastive_df['INCHI_PREFIX'] = contrastive_df['INCHIKEY'].str[:14]
grouped = contrastive_df.groupby('INCHI_PREFIX')


# Sample InChI groups
sampled_groups = grouped.size().nlargest(n_unique_inchi).index

# Create dataset
dataset = []
for inchi in tqdm(sampled_groups):
    group = grouped.get_group(inchi)
    
    # Sample spectra from this group
    n_samples = min(len(group), max(1, int(n_spectra / len(sampled_groups))))
    sampled_spectra = group.sample(n=n_samples)
    
    for idx, ref in sampled_spectra.iterrows():
        # Positive examples: same InChI, different spectrum
        pos_candidates = group[group.index != idx]
        pos_idx = pos_candidates.index.tolist()
        
        # Negative examples: different InChI, similar mass
        neg_candidates = contrastive_df[
            (contrastive_df['INCHI_PREFIX'] != ref['INCHI_PREFIX']) & 
            (abs(contrastive_df['EXACTMASS'] - ref['EXACTMASS']) <= mass_diff)
        ]
        neg_idx = neg_candidates.index.tolist()
        
        if len(pos_idx) > 0 and len(neg_idx) > 0:
                entry = {
                    'index': idx,
                    'pos_idx': pos_idx,
                    'neg_idx': neg_idx,
                    'INCHI_PREFIX': inchi,
                    'PRECURSOR M/Z':ref['precursor_mz'],
                    'PARSED PEAKS': ref["spectrum"],
                    'CHARGE': ref['charge'],
                    'ROMol': Chem.MolFromSmiles(ref['smiles'])
                }
                # Add all original columns
                for col in contrastive_df.columns:
                    entry[col] = ref[col]
                dataset.append(entry)

# Create a new DataFrame from the dataset
contrastive_df = pd.DataFrame(dataset)

import dreams.utils.spectra as su 
contrastive_df = su.df_to_MSnSpectra(contrastive_df, as_new_column=True, assert_is_valid=False)

# Ensure we have the desired number of spectra
if len(contrastive_df) > n_spectra:
    contrastive_df = contrastive_df.sample(n=n_spectra)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  contrastive_df['INCHI_PREFIX'] = contrastive_df['INCHIKEY'].str[:14]
100%|██████████| 2000/2000 [00:03<00:00, 550.42it/s] 


In [28]:
import dreams.utils.spectra as su 
# add MSnSpectra as new column
contrastive_df['PRECURSOR M/Z'] = contrastive_df['precursor_mz']
contrastive_df['PARSED PEAKS'] = contrastive_df["spectrum"]
contrastive_df['CHARGE'] = contrastive_df['charge']
contrastive_df = su.df_to_MSnSpectra(contrastive_df, as_new_column=True, assert_is_valid=False)

### Create Murcko histogram split according to https://github.com/pluskal-lab/DreaMS/blob/main/tutorials/murcko_hist_split.ipynb

In [29]:
from dreams.algorithms.murcko_hist import murcko_hist
# some smiles are invalid?
def is_valid_smiles(smiles):
    if pd.isna(smiles) or smiles == 'n/a':
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None


# Compute Murcko histograms

contrastive_df = contrastive_df[contrastive_df['smiles'].apply(is_valid_smiles)]
df_us = contrastive_df.drop_duplicates(subset=['smiles']).copy()

df_us['MurckoHist'] = df_us['smiles'].progress_apply(
    lambda x: murcko_hist.murcko_hist(Chem.MolFromSmiles(x))
)

# Convert dictionaries to strings for easier handling
df_us['MurckoHistStr'] = df_us['MurckoHist'].astype(str)
print('Num. unique smiles:', df_us['smiles'].nunique(), 'Num. unique Murcko histograms:', df_us['MurckoHistStr'].nunique())
print('Top 20 most common Murcko histograms:')
df_us['MurckoHistStr'].value_counts()[:20]

# Group by MurckoHistStr and aggregate
df_gb = df_us.groupby('MurckoHistStr').agg(
    count=('smiles', 'count'),
    smiles_list=('smiles', list)
).reset_index()

# Convert MurckoHistStr to MurckoHist
df_gb['MurckoHist'] = df_gb['MurckoHistStr'].apply(eval)

# Sort by 'n' in descending order and reset index
df_gb = df_gb.sort_values('count', ascending=False).reset_index(drop=True)

100%|██████████| 1088/1088 [00:00<00:00, 1757.66it/s]

Num. unique smiles: 1088 Num. unique Murcko histograms: 52
Top 20 most common Murcko histograms:





In [30]:
#Split the dataset into training and validation sets based on Murcko histograms

median_i = len(df_gb) // 2
cum_val_mols = 0
val_mols_frac = 0.15  # Approximately 15% of the molecules go to validation set
val_idx, train_idx = [], []

# Iterate from median to start, assigning molecules to train or val sets
for i in range(median_i, -1, -1):
    current_hist = df_gb.iloc[i]['MurckoHist']
    is_val_subhist = any(
        murcko_hist.are_sub_hists(current_hist, df_gb.iloc[j]['MurckoHist'], k=3, d=4)
        for j in val_idx
    )

    if is_val_subhist:
        train_idx.append(i)
    else:
        if cum_val_mols / len(df_us) <= val_mols_frac:
            cum_val_mols += df_gb.iloc[i]['count']
            val_idx.append(i)
        else:
            train_idx.append(i)

# Add remaining indices to train set
train_idx.extend(range(median_i + 1, len(df_gb)))
assert(len(train_idx) + len(val_idx) == len(df_gb))

# Map SMILES to their assigned fold
smiles_to_fold = {}
for i, row in df_gb.iterrows():
    fold = 'val' if i in val_idx else 'train'
    for smiles in row['smiles_list']:
        smiles_to_fold[smiles] = fold
contrastive_df['fold'] = contrastive_df['smiles'].map(smiles_to_fold)

# Display fold distributions
print('Distribution of spectra:')
display(contrastive_df['fold'].value_counts(normalize=True))
print('Distribution of smiles:')
display(contrastive_df.drop_duplicates(subset=['smiles'])['fold'].value_counts(normalize=True))

Distribution of spectra:


fold
train    0.75724
val      0.24276
Name: proportion, dtype: float64

Distribution of smiles:


fold
train    0.756434
val      0.243566
Name: proportion, dtype: float64

In [None]:
# Evaluate data leakage
eval_res = evaluate_split(contrastive_df, n_workers=4)
init_plotting(figsize=(3, 3))
sns.histplot(eval_res['val'], bins=100)
plt.xlabel('Max Tanimoto similarity to training set')
plt.ylabel('Num. validation set molecules')
plt.show()

In [31]:
# Store the dataset to pickle (required by train.py)
contrastive_df.to_pickle(f'data/MoNA_experimental_split_{n_spectra}.pkl')

Contrastive fine-tuning