In [1]:
from matchms.importing import load_from_mgf
from mol2dreams.featurizer.featurize import MoleculeFeaturizer
from mol2dreams.utils.data import construct_triplets, pre_featurize_molecules
from mol2dreams.datasets.TripletDataset import TripletDataset
from mol2dreams.utils.parser import build_trainer_from_config

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Batch


In [2]:
def load_mgf_with_folds(mgf_path):
    """Load an MGF file into a DataFrame with one row per spectrum.

    Every spectrum yielded by ``load_from_mgf`` is converted with ``to_dict()``
    and becomes one row. The ``collision_energy``, ``parent_mass`` and
    ``precursor_mz`` columns are cast to ``float``.

    Args:
        mgf_path: Path to the MGF file; passed unchanged to ``load_from_mgf``.

    Returns:
        pandas.DataFrame: One row per spectrum, with the three numeric columns
        above stored as floats.

    Raises:
        ValueError: If the ``fold`` column is missing, or if any of the three
            numeric columns is missing.
    """
    # Build all records first and create the frame once.
    records = [spec.to_dict() for spec in load_from_mgf(mgf_path)]
    df = pd.DataFrame(records)

    if 'fold' not in df.columns:
        raise ValueError("fold column is missing. Ensure the dataset has been split into train/val/test.")

    numeric_columns = ['collision_energy', 'parent_mass', 'precursor_mz']
    missing_columns = [col for col in numeric_columns if col not in df.columns]
    if missing_columns:
        # Fail with the same kind of explicit error as the 'fold' check above
        # instead of an opaque attribute lookup failure.
        raise ValueError(f"Required numeric column(s) missing from the MGF metadata: {missing_columns}")

    # Use bracket indexing for assignment: attribute-style access (df.col = ...)
    # is fragile when a column name collides with a DataFrame attribute.
    for col in numeric_columns:
        df[col] = df[col].astype(float)

    return df


In [3]:
# Location of the MGF file, given relative to the notebook's working directory.
spectra_path = "../../data/data/MassSpecGym.mgf"
# One row per spectrum; raises ValueError if the file has no 'fold' column.
df = load_mgf_with_folds(spectra_path)
# Plain-text preview of the first rows.
print(df.head())

             identifier                                         smiles  \
0  MassSpecGymID0000001  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
1  MassSpecGymID0000002  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
2  MassSpecGymID0000003  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
3  MassSpecGymID0000004  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   
4  MassSpecGymID0000005  CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC   

         inchikey    formula precursor_formula  parent_mass  precursor_mz  \
0  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4   287.115224      288.1225   
1  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4   287.115224      288.1225   
2  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4   287.115224      288.1225   
3  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4   287.115224      288.1225   
4  VFMQMACUYWGDOJ  C16H17NO4         C16H18NO4   287.115224      288.1225   

   adduct instrument_type  collision_energy   fold simulation_challenge  \
0  [M+H]+        

In [16]:
# Build the triplet table from the spectra frame. The printed preview shows one
# row per anchor with columns anchor_smiles, anchor_id, positive_ids and
# negative_ids.
# NOTE(review): the original comment here was an unfinished sentence. The output
# reports a "Filtered dataset size", so construct_triplets presumably applies a
# filter before building triplets - document that condition here once confirmed.
triplets_df = construct_triplets(df)
print(triplets_df.head())

Filtered dataset size: 11186 spectra
Constructed triplets for 11186 anchors.
                                       anchor_smiles             anchor_id  \
0  CC(C)[C@H]1C(=O)O[C@@H](C(=O)N([C@H](C(=O)O[C@...  MassSpecGymID0000261   
1  C[C@@H]1CC2=C(C=C(C(=C2C(=O)O1)O)C(=O)N[C@@H](...  MassSpecGymID0000740   
2  CC[C@H](C)C(=O)O[C@H]1CCC=C2[C@H]1[C@H]([C@H](...  MassSpecGymID0000882   
3       CC1=CC2=C(C(=C1)O)C(=O)C3=C(C2=O)C=C(C=C3O)O  MassSpecGymID0001133   
4  COC1=C2C3=C(C(=O)CC3)C(=O)OC2=C4[C@@H]5C=CO[C@...  MassSpecGymID0001358   

  positive_ids                                       negative_ids  
0           []  [MassSpecGymID0221255, MassSpecGymID0195900, M...  
1           []  [MassSpecGymID0218556, MassSpecGymID0218874, M...  
2           []  [MassSpecGymID0204692, MassSpecGymID0223906, M...  
3           []  [MassSpecGymID0226835, MassSpecGymID0092045, M...  
4           []  [MassSpecGymID0232803, MassSpecGymID0159715, M...  


In [5]:
# Number of positive / negative candidates stored for each anchor row.
triplets_df['num_positive'] = triplets_df['positive_ids'].map(len)
triplets_df['num_negative'] = triplets_df['negative_ids'].map(len)

# Distribution of list lengths, ordered by length rather than by frequency.
positive_counts = triplets_df['num_positive'].value_counts().sort_index()
negative_counts = triplets_df['num_negative'].value_counts().sort_index()

# Distinct anchor structures (several anchors can share one SMILES string).
unique_compounds = triplets_df['anchor_smiles'].unique()

# Report everything in the same order as before: positives, negatives, compounds.
print("Positive IDs counts:")
print(positive_counts)

print("\nNegative IDs counts:")
print(negative_counts)

print("\nUnique compounds:")
print(len(unique_compounds))

Positive IDs counts:
num_positive
0    7467
1    1756
2    1137
3     440
4     190
5     196
Name: count, dtype: int64

Negative IDs counts:
num_negative
0        2
1       13
2       11
3        8
4       12
5    11140
Name: count, dtype: int64

Unique compounds:
9154


## Join with embedding

In [6]:
from mol2dreams.utils.data import prepare_datasets
from dreams.utils.data import MSData
from dreams.definitions import DREAMS_EMBEDDING

  from .autonotebook import tqdm as notebook_tqdm
Determination of memory status is not supported on this 
 platform, measuring for memoryleaks will never fail


In [7]:
# HDF5 file read by MSData, given relative to the notebook's working directory.
# NOTE(review): judging by the file name this holds pre-computed DreaMS
# embeddings for the MassSpecGym spectra - confirm provenance.
hdf5_path = "../../data/data/MassSpecGym_DreaMS.hdf5"
msdata = MSData.from_hdf5(hdf5_path, prec_mz_col='precursor_mz')
# Embedding matrix stored under the DREAMS_EMBEDDING key.
embs = msdata[DREAMS_EMBEDDING]
# Shown as the cell output (rows, embedding width).
embs.shape

(213548, 1024)

In [8]:
# NOTE(review): extra_features is defined here but is not passed to
# prepare_datasets below and is not used in any other visible cell - either wire
# it in or remove it. TODO confirm whether prepare_datasets accepts it.
extra_features = ['COLLISION_ENERGY', 'adduct', 'precursor_mz', 'INSTRUMENT_TYPE']
# Prepare datasets
# Result is indexed by split name further down (datasets['train']).
# NOTE(review): embedding_col is spelled out as a string although the
# DREAMS_EMBEDDING constant is imported - verify the two agree.
datasets = prepare_datasets(
    msdata=msdata, 
    embs=embs, 
    splits=['train', 'val'],  # Include 'test' if present
    smiles_col='smiles', 
    embedding_col='DreaMS_embedding', 
    fold_col='FOLD'
)

Processing split 'train' with 194119 samples.


Featurizing train: 100%|██████████| 194119/194119 [00:10<00:00, 18025.02it/s]


Processing split 'val' with 19429 samples.


Featurizing val: 100%|██████████| 19429/19429 [00:02<00:00, 7871.15it/s] 


In [9]:
# Only the 'train' split is used for the identifier lookups in the following cells.
train_data = datasets['train']

In [10]:
# Index the training entries by spectrum identifier for constant-time lookup.
identifier_to_data = {
    sample['IDENTIFIER']: sample
    for sample in train_data
}

In [12]:
# Every identifier that appears as an anchor.
anchor_ids = set(triplets_df['anchor_id'])

# Flatten the per-anchor positive lists into one set of identifiers.
positive_ids_set = {
    pid
    for id_list in triplets_df['positive_ids']
    for pid in id_list
}

# Same flattening for the per-anchor negative lists.
negative_ids_set = {
    nid
    for id_list in triplets_df['negative_ids']
    for nid in id_list
}

# Every identifier referenced anywhere in the triplet table.
all_triplet_ids = anchor_ids | positive_ids_set | negative_ids_set

In [13]:
# Identifiers available in the training lookup.
dataset_ids = set(identifier_to_data)
# Identifiers the triplet table refers to but the training split does not contain.
identifiers_not_in_datasets = all_triplet_ids.difference(dataset_ids)

if not identifiers_not_in_datasets:
    print("All identifiers in triplets are present in datasets.")
else:
    print(f"Identifiers in triplets not in datasets: {len(identifiers_not_in_datasets)}")
    print(identifiers_not_in_datasets)

All identifiers in triplets are present in datasets.


In [14]:
# Per-identifier lookups consumed by the featurization step.
identifier_to_embedding = {
    sample['IDENTIFIER']: sample['embedding']
    for sample in train_data
}
identifier_to_smiles = {
    sample['IDENTIFIER']: sample['smiles']
    for sample in train_data
}

In [15]:
# Sanity check: both lookups should hold the same number of identifiers.
len(identifier_to_embedding), len(identifier_to_smiles)

(194119, 194119)

# Make TripletMargin Dataset

In [16]:
# Bond-level feature switches (stereochemistry is turned off).
bond_config = {
    'features': {
        'bond_type': True,
        'conjugated': True,
        'in_ring': True,
        'stereochemistry': False,
    },
}

# Atom-level feature switches plus per-feature options.
atom_config = {
    'features': {
        'atom_symbol': True,
        'total_valence': True,
        'aromatic': True,
        'hybridization': True,
        'formal_charge': True,
        'default_valence': True,
        'ring_size': True,
        'hydrogen_count': True,
    },
    'feature_attributes': {
        'atom_symbol': {
            'top_n_atoms': 42,
            'include_other': True,
        },
    },
}

# Spectrum-level (global) feature switches.
global_config = {
    'features': {
        'collision_energy': True,
        'adduct': True,
        'instrument_type': True,
        'precursor_mz': True,
    },
}

# Width of the spectrum embedding handed to the featurizer.
spectrum_embedding_size = 1024
featurizer = MoleculeFeaturizer(
    atom_config,
    bond_config,
    global_config=global_config,
    spectrum_embedding_size=spectrum_embedding_size,
)

In [18]:
# Featurize the molecules referenced by the triplet table, using the SMILES and
# embedding lookups built from the training split.
# NOTE(review): failed_identifiers is not inspected in any visible cell -
# consider asserting it is empty before building the dataset.
featurized_molecules, failed_identifiers = pre_featurize_molecules(triplets_df, identifier_to_smiles, identifier_to_embedding, featurizer)

Total unique identifiers to featurize: 11186
Featurizing molecules...


Featurizing dataset: 100%|██████████| 11186/11186 [00:49<00:00, 224.17it/s]

Featurized 11186 molecules out of 11186





In [21]:
# Combine the triplet table with the featurized molecules into a dataset.
triplet_dataset = TripletDataset(triplets_df, featurized_molecules)

# Create the DataLoader
# Batches are assembled by the dataset's own collate_fn; order is kept fixed
# (shuffle=False) and loading happens in the main process (num_workers=0).
batch_size = 32
triplet_loader = DataLoader(
    triplet_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=triplet_dataset.collate_fn,
    num_workers=0
)

Total valid triplets: 3719


In [22]:
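# One-off export, kept commented out: uncomment and run once to write the file
# that the next cell loads; it does not need to run on every execution.
# save_path = '../../data/data/triplet_dataset.pt'
# torch.save(triplet_dataset, save_path)
# print(f"TripletDataset saved to {save_path}")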

In [23]:
# Reload the dataset object previously written with torch.save.
load_path = '../../data/data/triplet_dataset.pt'
# SECURITY: this file holds a pickled Python object, not just tensors, so
# loading it can execute arbitrary code - only load files you created yourself.
# weights_only=False is passed explicitly because a full object cannot be
# restored in weights-only mode, which newer PyTorch releases use by default.
triplet_dataset_loaded = torch.load(load_path, weights_only=False)
print(f"TripletDataset loaded from {load_path}")

TripletDataset loaded from ../../data/data/triplet_dataset.pt


In [24]:
# DataLoader over the dataset restored from disk.
batch_size = 32
triplet_loader_loaded = DataLoader(
    triplet_dataset_loaded,
    batch_size=batch_size,
    shuffle=False,
    # Take collate_fn from the loaded dataset itself so this cell does not
    # depend on the in-memory `triplet_dataset` built in an earlier cell.
    collate_fn=triplet_dataset_loaded.collate_fn,
    num_workers=0
)

In [25]:
# Print only the first batch as a smoke test, then stop iterating.
# In the recorded output the batch is a tuple of three DataBatch objects.
for first in triplet_loader_loaded:
    print(first)
    break

(DataBatch(x=[527, 84], edge_index=[2, 1092], edge_attr=[1092, 7], y=[32, 1024], identifier=[32], batch=[527], ptr=[33]), DataBatch(x=[527, 84], edge_index=[2, 1092], edge_attr=[1092, 7], y=[32, 1024], identifier=[32], batch=[527], ptr=[33]), DataBatch(x=[542, 84], edge_index=[2, 1130], edge_attr=[1130, 7], y=[32, 1024], identifier=[32], batch=[542], ptr=[33]))


# Train from config

In [26]:
import os

import yaml

# Location of the training config. The default is the original machine-specific
# absolute path; set the MOL2DREAMS_CONFIG environment variable to point at the
# file on any other machine instead of editing this cell.
config_path = os.environ.get(
    "MOL2DREAMS_CONFIG",
    "/Users/macbook/CODE/mol2DreaMS/mol2dreams/configs/local_config_triplet_margin.yaml",
)

# safe_load only builds plain Python objects from the YAML document.
with open(config_path) as stream:
    config = yaml.safe_load(stream)

In [27]:
# Build the trainer from the parsed YAML config.
# NOTE(review): only `config` is passed in, so the trainer presumably builds its
# own data loaders rather than using the ones created above - confirm.
trainer = build_trainer_from_config(config)



In [28]:
# Run training; the recorded log reports per-epoch loss and positive/negative
# distances, with validation and a checkpoint every second epoch.
# NOTE(review): in the recorded log the validation "Pos Dist" is ~0.0002-0.0016
# while the training value is ~7-8 - worth checking how validation positives
# are constructed.
trainer.train()

Epoch [1/10], Loss: 0.1789, Pos Dist: 7.8130, Neg Dist: 15.6266
Epoch [2/10], Loss: 0.0966, Pos Dist: 7.9670, Neg Dist: 20.3984
Validation Loss: 0.0017, Pos Dist: 0.0003, Neg Dist: 21.9667
Best model saved at epoch 2 with validation loss 0.0017
Model checkpoint saved at ../../data/logs/mol2dreams/20241009_164437_mol2dreams/model_epoch_2.pt
Epoch [3/10], Loss: 0.0750, Pos Dist: 8.0283, Neg Dist: 22.3910
Epoch [4/10], Loss: 0.0716, Pos Dist: 7.7800, Neg Dist: 24.4136
Validation Loss: 0.0013, Pos Dist: 0.0002, Neg Dist: 25.5010
Best model saved at epoch 4 with validation loss 0.0013
Model checkpoint saved at ../../data/logs/mol2dreams/20241009_164437_mol2dreams/model_epoch_4.pt
Epoch [5/10], Loss: 0.0812, Pos Dist: 7.2210, Neg Dist: 22.8026
Epoch [6/10], Loss: 0.0497, Pos Dist: 7.0394, Neg Dist: 24.5281
Validation Loss: 0.0017, Pos Dist: 0.0016, Neg Dist: 24.5976
Model checkpoint saved at ../../data/logs/mol2dreams/20241009_164437_mol2dreams/model_epoch_6.pt
Epoch [7/10], Loss: 0.0491, Po