In [1]:
import random
from cgitb import small

import tqdm
import numpy as np
import h5py
import torch
import pandas as pd
import matplotlib.pyplot as plt
from dreams.utils.data import MSData
from dreams.api import dreams_embeddings
from dreams.utils.plots import init_plotting
from dreams.utils.mols import formula_type
from dreams.definitions import DREAMS_EMBEDDING
from matchms.importing import load_from_mgf
from matchms.exporting import save_as_mgf
from mol2dreams.featurizer.featurize import MoleculeFeaturizer
from mol2dreams.utils.data import prepare_datasets
from mol2dreams.featurizer.atom_features import AtomFeaturizer
from mol2dreams.featurizer.bond_features import BondFeaturizer

  from cgitb import small
  from .autonotebook import tqdm as notebook_tqdm
Determination of memory status is not supported on this 
 platform, measuring for memoryleaks will never fail


In [2]:
# Path of the MassSpecGym HDF5 file, relative to this notebook.
hdf5_path = "../../data/data/MassSpecGym_DreaMS.hdf5"
# Load the dataset; `prec_mz_col` names the column that holds the precursor m/z.
msdata = MSData.from_hdf5(hdf5_path, prec_mz_col='precursor_mz')
# Read the column keyed by the DREAMS_EMBEDDING constant.
embs = msdata[DREAMS_EMBEDDING]
# Show the embedding matrix shape as the cell output.
embs.shape

(213548, 1024)

## Featurizer configuration

In [3]:
# Bond-level feature switches handed to MoleculeFeaturizer further down.
# Stereochemistry is the only bond feature turned off.
bond_config = {
    'features': dict(
        bond_type=True,
        conjugated=True,
        in_ring=True,
        stereochemistry=False,
    ),
}

# Atom-level feature switches (all enabled) plus per-feature options.
atom_config = {
    'features': dict.fromkeys(
        (
            'atom_symbol',
            'total_valence',
            'aromatic',
            'hybridization',
            'formal_charge',
            'default_valence',
            'ring_size',
            'hydrogen_count',
        ),
        True,
    ),
    'feature_attributes': {
        # Options for the 'atom_symbol' feature.
        # NOTE(review): presumably the 42 most frequent elements get their own
        # category and the rest fall into an "other" bucket — confirm against
        # AtomFeaturizer.
        'atom_symbol': dict(top_n_atoms=42, include_other=True),
    },
}

## Computing embeddings

In [4]:
def extract_first_n_spectra(original_mgf_path, new_mgf_path, n=50):
    """Copy the leading spectra of an MGF file into a new MGF file.

    Args:
        original_mgf_path: Path of the MGF file to read spectra from.
        new_mgf_path: Path of the MGF file to write the selection to.
        n: Maximum number of spectra to keep (default 50). Fewer are written
            when the source yields fewer than ``n`` spectra.

    Returns:
        None. A one-line summary with the number of saved spectra is printed.
    """
    # NOTE(review): depending on the matchms version, save_as_mgf may append to
    # an already existing file — confirm before re-running with the same
    # output path.

    def _leading(stream):
        # Pass items through until position ``n`` is reached, then stop reading.
        for position, item in enumerate(stream):
            if position >= n:
                return
            yield item

    first_n_spectra = list(_leading(load_from_mgf(original_mgf_path)))
    save_as_mgf(first_n_spectra, new_mgf_path)
    print(f"Extracted {len(first_n_spectra)} spectra and saved to {new_mgf_path}")

# # Example usage (run once to create the 50-spectrum file loaded in the next cell)
# original_mgf = "../../data/data/MassSpecGym.mgf"
# new_mgf = "../../data/data/MassSpecGym_first50.mgf"
# extract_first_n_spectra(original_mgf, new_mgf, n=50)

In [17]:
# Path of the 50-spectrum MGF subset. It is the same path as in the
# commented-out extract_first_n_spectra example, so that file must already
# exist before this cell is run.
new_mgf = "../../data/data/MassSpecGym_first50.mgf"
# Compute embeddings straight from the MGF file; `prec_mz_col` names the
# precursor m/z field as it is spelled in the MGF metadata.
dreams_embs_50 = dreams_embeddings(new_mgf, prec_mz_col='PRECURSOR_MZ')

Computing DreaMS embedding: 100%|██████████| 50/50 [00:01<00:00, 43.97it/s]


## Preparing dataset

In [6]:
# def prepare_datasets(msdata, embs, splits=['train', 'val', 'test'], 
#                     smiles_col='smiles', embedding_col='DreaMS_embedding', fold_col='FOLD',
#                     extra_cols=['COLLISION_ENERGY', 'adduct', 'precursor_mz']):
#     """
#     Prepares train, validation, and test datasets from MSData and embeddings, skipping rows with a missing SMILES, a missing IDENTIFIER, or an invalid embedding.
# 
#     Args:
#         msdata (MSData): The MSData object loaded from HDF5.
#         embs (np.ndarray or torch.Tensor): Embeddings matrix with shape [num_samples, embedding_size].
#         splits (list of str): List of fold names to extract. Default is ['train', 'val', 'test'].
#         smiles_col (str): Column name for SMILES strings in the DataFrame. Default is 'smiles'.
#         embedding_col (str): Column name for embeddings in the DataFrame. Default is 'DreaMS_embedding'.
#         fold_col (str): Column name for fold attribute in the DataFrame. Default is 'FOLD'.
#         extra_cols (list of str): Additional columns to include in the datasets.
# 
#     Returns:
#         dict: A dictionary where keys are split names and values are lists of dictionaries 
#               with 'smiles', 'embedding', and extra attributes.
#     """
#     # Convert msdata to pandas DataFrame
#     df = msdata.to_pandas()
#     
#     # Check alignment
#     num_rows = df.shape[0]
#     if embs.shape[0] != num_rows:
#         raise ValueError(f"Number of embeddings ({embs.shape[0]}) does not match number of data samples ({num_rows}).")
#     
#     # Assign embeddings to the DataFrame
#     # Ensure embeddings are numpy arrays
#     if isinstance(embs, torch.Tensor):
#         embs = embs.numpy()
#     elif not isinstance(embs, np.ndarray):
#         embs = np.array(embs)
#     
#     df[embedding_col] = list(embs)
#     
#     datasets = {split: [] for split in splits}
#     
# 
#     for split in splits:
#         split_df = df[df[fold_col] == split].reset_index(drop=True)
#         print(f"Processing split '{split}' with {len(split_df)} samples.")
#         
#         for idx, row in tqdm.tqdm(split_df.iterrows(), total=split_df.shape[0], desc=f"Featurizing {split}"):
#             identifier = row.get('IDENTIFIER', None)
#             smiles = row.get(smiles_col, None)
#             embedding = row.get(embedding_col, None)
#             
#             extra_attrs = {col: row.get(col, None) for col in extra_cols}
#             
#             if pd.isna(smiles):
#                 print(f"Skipping row {idx} due to missing SMILES.")
#                 continue
#             if pd.isna(identifier):
#                 print(f"Skipping row {idx} due to missing IDENTIFIER.")
#                 continue
#             if embedding is None or len(embedding) != embs.shape[1]:
#                 print(f"Skipping row {idx} due to invalid embedding.")
#                 continue
#             
#             datasets[split].append({
#                 'IDENTIFIER': identifier,
#                 'smiles': smiles,
#                 'embedding': embedding,
#                 **extra_attrs
#             })
#     
#     unique_folds = df[fold_col].unique()
#     additional_folds = set(unique_folds) - set(splits)
#     for split in additional_folds:
#         split_df = df[df[fold_col] == split].reset_index(drop=True)
#         print(f"Processing additional split '{split}' with {len(split_df)} samples.")
#         datasets[split] = []
#         for idx, row in tqdm.tqdm(split_df.iterrows(), total=split_df.shape[0], desc=f"Featurizing {split}"):
#             identifier = row.get('IDENTIFIER', None)
#             smiles = row.get(smiles_col, None)
#             embedding = row.get(embedding_col, None)
#             
#             extra_attrs = {col: row.get(col, None) for col in extra_cols}
# 
#             if pd.isna(smiles):
#                 print(f"Skipping row {idx} in split '{split}' due to missing SMILES.")
#                 continue
#             if pd.isna(identifier):
#                 print(f"Skipping row {idx} in split '{split}' due to missing IDENTIFIER.")
#                 continue
#             if embedding is None or len(embedding) != embs.shape[1]:
#                 print(f"Skipping row {idx} in split '{split}' due to invalid embedding.")
#                 continue
#             
#             datasets[split].append({
#                 'IDENTIFIER': identifier,
#                 'smiles': smiles,
#                 'embedding': embedding,
#                 **extra_attrs
#             })
#     
#     return datasets

In [7]:
# hdf5_path = "../../data/data/MassSpecGym_DreaMS.hdf5"
# msdata = MSData.from_hdf5(hdf5_path, prec_mz_col='precursor_mz')

In [8]:
# List the column names available in the loaded dataset.
msdata.columns()

['COLLISION_ENERGY',
 'DreaMS_embedding',
 'FOLD',
 'FORMULA',
 'IDENTIFIER',
 'INCHIKEY',
 'INSTRUMENT_TYPE',
 'PARENT_MASS',
 'PRECURSOR_FORMULA',
 'SIMULATION_CHALLENGE',
 'adduct',
 'precursor_mz',
 'smiles',
 'spectrum']

In [9]:
# Embedding matrix read from the loaded dataset; passed to prepare_datasets
# as `embs` together with the dataset itself.
embs = msdata[DREAMS_EMBEDDING]
# Metadata columns to carry along with every sample.
extra_features = ['COLLISION_ENERGY', 'adduct', 'precursor_mz']
# Prepare datasets
# `extra_features` is passed explicitly so the carried columns are controlled
# here instead of by the library default.
# NOTE(review): assumes the imported prepare_datasets accepts `extra_cols`,
# as the commented-out draft of the function in this notebook does — confirm.
datasets = prepare_datasets(
    msdata=msdata,
    embs=embs,
    splits=['train', 'val'],  # Include 'test' if present
    smiles_col='smiles',
    embedding_col='DreaMS_embedding',
    fold_col='FOLD',
    extra_cols=extra_features,
)


Processing split 'train' with 194119 samples.


Featurizing train: 100%|██████████| 194119/194119 [00:03<00:00, 53170.64it/s]


Processing split 'val' with 19429 samples.


Featurizing val: 100%|██████████| 19429/19429 [00:00<00:00, 51265.86it/s]


In [10]:
# datasets

In [11]:
# datasets['val']

In [12]:
small_dataset = {}
small_dataset['valid'] = datasets['train'][:50]
small_dataset['valid'] = datasets['val'][:50]

In [13]:
spectrum_embedding_size = 1024 
featurizer = MoleculeFeaturizer(atom_config, bond_config, spectrum_embedding_size)

data_list_train = featurizer.featurize_dataset(
    small_dataset['valid'], include_extra_attr=True)

data_list_valid = featurizer.featurize_dataset(
    small_dataset['valid'], include_extra_attr=True)

print(f"Number of successfully featurized training molecules: {len(data_list_train)}")
print(f"Number of successfully featurized validation molecules: {len(data_list_valid)}")

Featurizing dataset: 100%|██████████| 50/50 [00:00<00:00, 145.01it/s]
Featurizing dataset: 100%|██████████| 50/50 [00:00<00:00, 146.77it/s]

Number of successfully featurized training molecules: 50
Number of successfully featurized validation molecules: 50





In [14]:
data_list_train

[Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1]),
 Data(x=[42, 84], edge_index=[2, 96], edge_attr=[96, 7], y=[1, 1024], IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=

## Try the PyTorch Geometric DataLoader

In [15]:
from torch_geometric.loader import DataLoader

# Graphs per mini-batch for each loader.
batch_size_train = 32
batch_size_valid = 32

# Only the training loader shuffles; the validation loader keeps list order.
loader_train = DataLoader(data_list_train, batch_size=batch_size_train, shuffle=True, num_workers=1)
loader_valid = DataLoader(data_list_valid, batch_size=batch_size_valid, shuffle=False, num_workers=1)

# Look at the first training batch only: print the batch object followed by
# the shapes of its main tensors, then stop iterating.
for batch in loader_train:
    print(batch)
    print(f"Batch size: {batch.num_graphs}")
    print(f"Node feature shape: {batch.x.shape}")
    print(f"Edge index shape: {batch.edge_index.shape}")
    print(f"Edge feature shape: {batch.edge_attr.shape}")
    print(f"Spectrum embedding shape: {batch.y.shape}")
    break

DataBatch(x=[1555, 84], edge_index=[2, 3164], edge_attr=[3164, 7], y=[32, 1024], IDENTIFIER=[32], COLLISION_ENERGY=[32, 1], adduct=[32], precursor_mz=[32, 1], batch=[1555], ptr=[33])
Batch size: 32
Node feature shape: torch.Size([1555, 84])
Edge index shape: torch.Size([2, 3164])
Edge feature shape: torch.Size([3164, 7])
Spectrum embedding shape: torch.Size([32, 1024])
