In [1]:
import json
from matchms.importing import load_from_mgf
from rdkit import Chem
from massspecgym.tools.analyzers import analyze_canonical_smiles
import os

In [2]:
file_mgf = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_msn_library_pos_all_lib_MSn.mgf"
file_json = "/Users/macbook/CODE/Majer:MassSpecGym/data/Retrieval/MassSpecGym_retrieval_candidates_mass.json"
split_file = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_split.tsv"

In [10]:
def canonicalize_smiles(smiles):
    """
    Canonicalize a SMILES string using RDKit.
    Returns the canonical SMILES or None if invalid.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        return Chem.MolToSmiles(mol, canonical=True)
    else:
        return None

In [17]:
# def analyze_canonical_smiles(data, mode='spectra'):
#     """
#     Processes SMILES strings from either filtered spectra or JSON data,
#     extracting original and canonical SMILES sets, and computing relevant statistics.
# 
#     Parameters:
#     - data (list or dict): 
#         - If mode='spectra', a list of Spectrum objects.
#         - If mode='json', a dictionary with SMILES strings as keys.
#     - mode (str): Mode of processing. Either 'spectra' or 'json'.
# 
#     Returns:
#     - original_smiles_set (set): SMILES strings as extracted from the input.
#     - canonical_smiles_set (set): Canonicalized SMILES strings.
#     - invalid_smiles (set): SMILES strings that could not be canonicalized.
#     """
#     # Initialize sets to store SMILES
#     original_smiles_set = set()
#     canonical_smiles_set = set()
#     invalid_smiles = set()
# 
#     # Validate mode
#     if mode not in ['spectra', 'json']:
#         raise ValueError("Invalid mode. Choose 'spectra' or 'json'.")
# 
#     if mode == 'spectra':
#         for spectrum in data:
#             # Extract SMILES from spectrum metadata
#             smiles = spectrum.metadata.get("SMILES") or spectrum.metadata.get("smiles")
#             if smiles:
#                 original_smiles_set.add(smiles)
#                 # Canonicalize SMILES using RDKit
#                 mol = Chem.MolFromSmiles(smiles)
#                 if mol:
#                     canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
#                     canonical_smiles_set.add(canonical_smiles)
#                 else:
#                     invalid_smiles.add(smiles)
#     elif mode == 'json':
#         if not isinstance(data, dict):
#             raise TypeError("For mode 'json', data must be a dictionary with SMILES as keys.")
#         for smiles in data.keys():
#             original_smiles_set.add(smiles)
#             # Canonicalize SMILES using RDKit
#             mol = Chem.MolFromSmiles(smiles)
#             if mol:
#                 canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
#                 canonical_smiles_set.add(canonical_smiles)
#             else:
#                 invalid_smiles.add(smiles)
# 
#     # Compute statistics
#     total_smiles = len(original_smiles_set)
#     unique_original = len(original_smiles_set)
#     unique_canonical = len(canonical_smiles_set)
#     num_invalid = len(invalid_smiles)
#     # Intersection counts how many original SMILES are already canonical
#     intersection_count = len(original_smiles_set.intersection(canonical_smiles_set))
# 
#     print("=== SMILES Processing Statistics ===")
#     print(f"Mode: {mode.upper()}")
#     print(f"Total SMILES extracted: {total_smiles}")
#     print(f"Unique original SMILES: {unique_original}")
#     print(f"Unique canonical SMILES: {unique_canonical}")
#     print(f"Number of invalid SMILES: {num_invalid}")
#     print(f"Number of SMILES unchanged after canonicalization: {intersection_count}")
#     print("====================================\n")
# 
#     return original_smiles_set, canonical_smiles_set, invalid_smiles

In [11]:
print("Loading spectra from MGF file...")
spectra = list(load_from_mgf(file_mgf))
print(f"Total number of spectra loaded: {len(spectra)}")

Loading spectra from MGF file...
Total number of spectra loaded: 803405


In [12]:
print("Filtering spectra with SPECTYPE=ALL_ENERGIES and MS_LEVEL=2...")
filtered_spectra = [
    s for s in spectra
    if s.metadata.get("spectype") == "ALL_ENERGIES" and int(s.metadata["ms_level"]) == 2
]
print(f"Number of spectra after filtering: {len(filtered_spectra)}")

Filtering spectra with SPECTYPE=ALL_ENERGIES and MS_LEVEL=2...
Number of spectra after filtering: 16476


In [18]:
analyze_canonical_smiles(filtered_spectra)

=== SMILES Processing Statistics ===
Mode: SPECTRA
Total SMILES extracted: 13984
Unique original SMILES: 13984
Unique canonical SMILES: 13984
Number of invalid SMILES: 0
Number of SMILES unchanged after canonicalization: 6427



({'CC(C)c1nn(C(=O)NC(C)(C)C)c(=O)n1N',
  'CC(=O)N[C@@H](CCC(N)=O)C(=O)O',
  'CNc1nc2nc(C)c(-c3cc(NC(=O)NCCC(C)(C)C)c(F)cc3C)cc2cn1',
  'CNc1nc(-c2ncc(Cl)c(C(=O)N3CCN(Cc4ccc(Cl)cc4)CC3)c2)cs1',
  'O=C(Nc1c(Cl)ccnc1)N1CCN(Cc2cc3c(cc2)OC(F)(F)O3)CC1',
  'CC(=O)c1nnc(-c2cccs2)n1Nc1ccc(I)cc1',
  'C[C@H]1[C@H](C)CC[C@]2(C(=O)OC3OC(CO)C(O)C(O)C3O)CC[C@]3(C(=O)O)C(=CCC4[C@@]5(C)CC[C@H](O[C@@H]6O[C@@H](C)[C@H](OC7OC(CO)C(O)C(O)C7O)[C@@H](O)[C@H]6O)C(C)(C)C5CC[C@]43C)[C@H]12',
  'COc1cccc2[n+](C)c3ccccc3nc12',
  'CN1CCN(c2cc3c(cc2)nc(-c2cc4c(cc2)nc(-c2cc(O)ccc2)[nH]4)[nH]3)CC1',
  'CC[C@@H](C)Nc1nc(C)nc2c(-c3c(C)nc(OC)cc3)c(C)nn12',
  'Nc1c2c(-c3ccc(Oc4ccccc4)cc3)nn(C3CCCC3)c2ncn1',
  'O=C1C=CC(O)(CC(=O)OC[C@H]2O[C@@H](OC(=O)CC3(O)C=CC(=O)C=C3)[C@H](OC(=O)Cc3ccccc3)[C@@H](O)[C@@H]2O)C=C1',
  'NC(=O)c1cccc(-c2c(O)ccc(OC(=O)NC3CCCCC3)c2)c1',
  'COc1c2cc3c(c1)CCN(C)[C@H]3Cc1ccc(cc1)Oc1c(O)ccc(c1)C[C@@H]1c3c(O2)c(OC)c(OC)cc3CCN1C',
  'COc1c(OC)cc(C(=O)OC2C3C=COC(OC4OC(CO)C(O)C(O)C4O)C3C3(CO)OC23)cc1

In [13]:
print("Extracting and canonicalizing SMILES from filtered spectra...")
smiles_set = set()
invalid_smiles_mgf = set()

for spectrum in filtered_spectra:
    smiles = spectrum.get("smiles")
    if smiles:
        canonical_smiles = canonicalize_smiles(smiles)
        if canonical_smiles:
            smiles_set.add(canonical_smiles)
        else:
            invalid_smiles_mgf.add(smiles)

Extracting and canonicalizing SMILES from filtered spectra...


In [14]:
if len(invalid_smiles_mgf) > 0:
    print(f"Number of invalid SMILES skipped from MGF: {len(invalid_smiles_mgf)}")
else:
    print("No valid SMILES skipped from MGF")

No valid SMILES skipped from MGF


In [19]:
with open(file_json, 'r') as f:
    smiles_dict = json.load(f)

In [20]:
analyze_canonical_smiles(smiles_dict, mode='json')

=== SMILES Processing Statistics ===
Mode: JSON
Total SMILES extracted: 32010
Unique original SMILES: 32010
Unique canonical SMILES: 32010
Number of invalid SMILES: 0
Number of SMILES unchanged after canonicalization: 1447



({'CC1=CC(=O)C2=C(O1)C=C3C(=C2O)C(=O)C(=C(C3=O)OC)C4=C(C(=O)C5=CC6=C(C(=O)C=C(O6)C)C(=C5C4=O)O)OC',
  'C1CC(=O)OC1CC2=CC(=C(C=C2)O)O',
  'C1=CC=C(C=C1)COC2=C(C=C(C=C2)CCN)OCC3=CC=CC=C3',
  'COC1=CC(=CC(=C1OC)OC)C2=NN=C(O2)SCC3=CC=C(C=C3)C(=O)OC',
  'CC(=CCC1=C(C=CC(=C1)C2CC(=O)C3=C(O2)C(=C(C=C3)O)CC=C(C)C)O)C',
  'COC1=CC(=CC(=C1C(=O)/C=C/C2=CC=CC=C2)O)O',
  'CCC1(C2=C(COC1=O)C(=O)N3CC4=CC5=C(C=CC=C5OC)N=C4C3=C2)O',
  'CC1=CC(=O)OC2=C1C=CC3=C2C(=CO3)NC4=CC(=CC=C4)[N+](=O)[O-]',
  'CCCCCCCCCCCCCCCCCC(=O)N[C@@H](CO)[C@@H]([C@@H](CCCCCCCCCCCCCC)O)O',
  'C[C@@]12CCC3[C@@](C1CC=C4[C@]2(CC[C@@]5(C4CC(CC5)(C)CO)C(=O)O[C@H]6[C@@H]([C@H]([C@@H]([C@H](O6)CO)O)O)O)C)(C[C@@H]([C@@H]([C@@]3(C)CO)O[C@H]7[C@@H]([C@H]([C@@H](CO7)O)O)O)O)C',
  'CC1=CC=CC=C1CN2C=CC=C(C2=O)C(=O)NC3=CC(=NC4=CC=CC=C43)C',
  'C1=C(C=C(C(=C1Cl)N2C(=C(C(=N2)C#N)SC(F)(F)F)N)Cl)C(F)(F)F',
  'CC(C)(C)OC(=O)C1=CC2=CC3=C4C(=C2OC1=O)CCCN4CCC3',
  'C[C@H](CCCC(C)C(=O)O)[C@H]1CC[C@@H]2[C@@]1(CC[C@H]3C2=CC[C@@H]4[C@@]3(CCC(=O)C4)C)C',

In [15]:

json_keys_set = set()
invalid_smiles_json = set()

for key in smiles_dict.keys():
    canonical_key = canonicalize_smiles(key)
    if canonical_key:
        json_keys_set.add(canonical_key)
    else:
        invalid_smiles_json.add(key)
if len(invalid_smiles_json) > 0:
    print(f"Number of invalid SMILES skipped from JSON: {len(invalid_smiles_json)}")
else:
    print("No valid SMILES skipped from JSON")


No valid SMILES skipped from JSON


In [16]:
print("Comparing SMILES from MGF with JSON keys...")
smiles_in_json = smiles_set.intersection(json_keys_set)
smiles_not_in_json = smiles_set.difference(json_keys_set)

all_present = len(smiles_not_in_json) == 0

print("\n--- Comparison Results ---")
if all_present:
    print("All SMILES from the filtered MGF file are present in the JSON file.")
else:
    print(f"Not all SMILES from the filtered MGF file are present in the JSON file.")
    print(f"Number of SMILES present in JSON: {len(smiles_in_json)}")
    print(f"Number of SMILES NOT present in JSON: {len(smiles_not_in_json)}")

print("\n--- Detailed Summary ---")
print(f"Total SMILES extracted from MGF: {len(smiles_set)}")
print(f"Total SMILES in JSON: {len(json_keys_set)}")
print(f"SMILES present in JSON: {len(smiles_in_json)}")
print(f"SMILES not present in JSON: {len(smiles_not_in_json)}")

Comparing SMILES from MGF with JSON keys...

--- Comparison Results ---
Not all SMILES from the filtered MGF file are present in the JSON file.
Number of SMILES present in JSON: 12786
Number of SMILES NOT present in JSON: 1198

--- Detailed Summary ---
Total SMILES extracted from MGF: 13984
Total SMILES in JSON: 32010
SMILES present in JSON: 12786
SMILES not present in JSON: 1198


# MSnRetrieval

In [3]:
from massspecgym.data.transforms import MolFingerprinter, MolToInChIKey, MolToFormulaVector
from massspecgym.data.datasets import MSnDataset, MSnRetrievalDataset
from massspecgym.featurize import SpectrumFeaturizer
from massspecgym.data.data_module import MassSpecDataModule

In [4]:
config = {
    'features': ['collision_energy', 'ionmode', 'adduct', 'spectrum_stats', 'atom_counts', 'value', "retention_time", 'ion_source', 'binned_peaks'],
    'feature_attributes': {
        'atom_counts': {
            'top_n_atoms': 12,
            'include_other': True,
        },
    },
}

In [5]:
featurizer = SpectrumFeaturizer(config, mode='torch')

In [None]:
# Instantiate the dataset
mol_transform = MolFingerprinter(fp_size=2048)
msn_retrieval_dataset = MSnRetrievalDataset(
    pth=file_mgf,
    mol_transform=mol_transform,
    featurizer=featurizer,
    candidates_pth=file_json,
    max_allowed_deviation=0.005
)


In [7]:
# Initialize the data module
data_module = MassSpecDataModule(
    dataset=msn_retrieval_dataset,
    batch_size=12,
    num_workers=0,
    split_pth=split_file
)

In [8]:
data_module.prepare_data()
data_module.setup()

train_loader = data_module.train_dataloader()

In [9]:
# Test the data loader
for batch in train_loader:
    print(batch['spec'])  # PyG Batch object
    print(f"batch['mol'] shape: {batch['mol'].shape}")  # Should be [batch_size, fp_size]
    print(f"batch['candidates'] shape: {batch['candidates'].shape}")  # [total_candidates, fp_size]
    print(f"batch['labels'] shape: {batch['labels'].shape}")  # [total_candidates]
    print(f"batch['batch_ptr']: {batch['batch_ptr']}")  # [batch_size]
    break

IndexError: list index out of range