In [18]:
import os
from pathlib import Path

import torch

from massspecgym.data import MassSpecDataModule
from massspecgym.data.datasets import MSnDataset, MassSpecDataset
from massspecgym.data.transforms import MolFingerprinter, SpecTokenizer
from massspecgym.featurize import SpectrumFeaturizer

In [3]:
# Root of the local data checkout. Set the MASSSPECGYM_DATA_DIR environment
# variable to point the notebook at another machine's copy; the default keeps
# the location this notebook was originally run against.
DATA_DIR = Path(
    os.environ.get("MASSSPECGYM_DATA_DIR", "/Users/macbook/CODE/Majer:MassSpecGym/data")
)

# Input files, derived from DATA_DIR. They stay plain strings (as before)
# because they are handed on as `pth` / `split_pth` arguments further down.
file_mgf = str(DATA_DIR / "MSn" / "20241211_msn_library_pos_all_lib_MSn.mgf")
file_json = str(DATA_DIR / "Retrieval" / "MassSpecGym_retrieval_candidates_mass.json")
split_file = str(DATA_DIR / "MSn" / "20241211_split.tsv")
pth_massspecgym_original = str(DATA_DIR / "MassSpecGym" / "MassSpecGym.tsv")

In [6]:
# Hyperparameters: peaks kept per spectrum, fingerprint length, loader batch size.
n_peaks, fp_size, batch_size = 60, 4096, 12

# Original MassSpecGym dataset: spectra are passed through SpecTokenizer and
# molecules through MolFingerprinter.
dataset_original = MassSpecDataset(
    pth=pth_massspecgym_original,
    spec_transform=SpecTokenizer(n_peaks=n_peaks),
    mol_transform=MolFingerprinter(fp_size=fp_size),
)

In [8]:
# Data module around the original dataset (no split file is passed here).
data_module_original = MassSpecDataModule(
    dataset=dataset_original, batch_size=batch_size, num_workers=0
)

In [9]:
# Run the data module's prepare/setup steps, then take its training loader.
data_module_original.prepare_data()
data_module_original.setup()
train_loader_original = data_module_original.train_dataloader()

Train dataset size: 194119
Val dataset size: 19429


In [10]:
# Pull exactly one batch from the training loader for inspection.
# An empty loader now fails right here with a clear message, instead of
# leaving a `[]` placeholder behind that would only break later on
# `tmp_original.keys()` / `tmp_original['spec']`.
batch = next(iter(train_loader_original), None)
if batch is None:
    raise RuntimeError("train_loader_original yielded no batches")
print(batch)
tmp_original = batch

{'spec': tensor([[[2.4421e+02, 1.1000e+00],
         [3.9022e+01, 1.8018e-02],
         [4.1038e+01, 1.8018e-02],
         ...,
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00]],

        [[2.5510e+02, 1.1000e+00],
         [6.5039e+01, 5.0536e-04],
         [6.7055e+01, 3.7217e-04],
         ...,
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00]],

        [[2.8306e+02, 1.1000e+00],
         [1.8902e+02, 1.8000e-02],
         [2.1105e+02, 6.0000e-03],
         ...,
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00]],

        ...,

        [[3.0811e+02, 1.1000e+00],
         [6.5039e+01, 5.7057e-02],
         [9.1054e+01, 1.7918e-01],
         ...,
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00]],

        [[2.8202e+02, 1.1000e+00],
         [7.2080e+01, 3.0030e-03],
   

In [11]:
# Check the keys in the batch
print(tmp_original.keys())

# Specifically check if 'batch_ptr' is present
if 'batch_ptr' in tmp_original:
    print("batch_ptr is present:", tmp_original['batch_ptr'])
else:
    print("batch_ptr is missing")

dict_keys(['spec', 'mol', 'precursor_mz', 'adduct', 'mol_freq', 'identifier'])
batch_ptr is missing


In [13]:
# Shapes of the spectrum tensor and the fingerprint tensor in the sampled batch.
# NOTE(review): the recorded output has 61 rows per spectrum for n_peaks=60;
# presumably the extra row is the precursor peak — confirm.
(tmp_original['spec'].shape, tmp_original['mol'].shape)

(torch.Size([12, 61, 2]), torch.Size([12, 4096]))

In [14]:
# Precursor m/z values of the sampled batch (12 entries in the recorded output).
tmp_original['precursor_mz']

tensor([244.2060, 255.1020, 283.0600, 347.1700, 297.1849, 271.1300, 205.1180,
        217.0500, 279.1591, 308.1104, 282.0214, 325.0700])

In [15]:
# Adduct labels of the sampled batch (a plain list of strings in the recorded output).
tmp_original['adduct']

['[M+H]+',
 '[M+H]+',
 '[M+Na]+',
 '[M+Na]+',
 '[M+H]+',
 '[M+Na]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+Na]+']

In [16]:
# 'mol_freq' entry of the sampled batch.
# NOTE(review): presumably how often each molecule occurs in the dataset — confirm.
tmp_original['mol_freq']

tensor([ 14.,   8.,  74., 405.,  54.,   8.,   3., 180.,  46.,  45.,  26., 507.])

In [17]:
# Identifier strings of the entries in the sampled batch.
tmp_original['identifier']

['MassSpecGymID0057420',
 'MassSpecGymID0145355',
 'MassSpecGymID0151226',
 'MassSpecGymID0003159',
 'MassSpecGymID0025429',
 'MassSpecGymID0166389',
 'MassSpecGymID0401869',
 'MassSpecGymID0091234',
 'MassSpecGymID0067582',
 'MassSpecGymID0053953',
 'MassSpecGymID0042022',
 'MassSpecGymID0124214']

# MSn

In [19]:
# Featurizer configuration: the list of feature names to request, plus
# per-feature options keyed by feature name.
config = dict(
    features=[
        'collision_energy',
        'ionmode',
        'adduct',
        'spectrum_stats',
        'atom_counts',
        'value',
        'retention_time',
        'ion_source',
        'binned_peaks',
    ],
    feature_attributes=dict(
        atom_counts=dict(top_n_atoms=12, include_other=True),
    ),
)

In [20]:
# Build the spectrum featurizer from `config`.
# NOTE(review): mode='torch' presumably selects torch-tensor output — confirm.
featurizer = SpectrumFeaturizer(config, mode='torch')

In [21]:
# Hyperparameters for the MSn run (n_peaks is not referenced in this cell).
n_peaks, fp_size, batch_size = 60, 4096, 12

# MSn dataset read from the MGF library: spectra go through `featurizer`,
# molecules through MolFingerprinter.
# NOTE(review): the meaning and units of max_allowed_deviation=0.005 are not
# visible in this notebook — confirm against MSnDataset.
msn_dataset = MSnDataset(
    pth=file_mgf,
    featurizer=featurizer,
    mol_transform=MolFingerprinter(fp_size=fp_size),
    max_allowed_deviation=0.005,
)

In [22]:
# Data module around the MSn dataset; a split file is passed explicitly here.
data_module_msn = MassSpecDataModule(
    dataset=msn_dataset, batch_size=batch_size, split_pth=split_file, num_workers=0
)

In [23]:
# Run the data module's prepare/setup steps, then take its training loader.
data_module_msn.prepare_data()
data_module_msn.setup()
train_loader_msn = data_module_msn.train_dataloader()

Train dataset size: 12536
Val dataset size: 1952


In [24]:
# Pull exactly one batch from the MSn training loader for inspection.
# An empty loader now fails right here with a clear message, instead of
# leaving a `[]` placeholder behind that would only break later on
# `tmp_msn.keys()` / `tmp_msn['spec']`.
batch = next(iter(train_loader_msn), None)
if batch is None:
    raise RuntimeError("train_loader_msn yielded no batches")
print(batch)
tmp_msn = batch

{'spec': DataBatch(x=[134, 1039], edge_index=[2, 122], batch=[134], ptr=[13]), 'mol': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'precursor_mz': tensor([ 435.1485, 1065.3082,  466.1205,  620.1131,  303.0499,  865.4427,
         319.1666,  405.1113,  291.1856,  449.1078,  488.2310,  418.2350]), 'adduct': ['[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+', '[M+H]+'], 'identifier': ['0024427_0000000', '0003980_0000000', '0084088_0000000', '0096920_0000000', '0004159_0000000', '0001584_0000000', '0039720_0000000', '0044544_0000000', '0080851_0000000', '0059397_0000000', '0062342_0000000', '0091534_0000000'], 'mol_freq': tensor([1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1.])}


In [25]:
# Check the keys in the batch
print(tmp_msn.keys())

# Specifically check if 'batch_ptr' is present
if 'batch_ptr' in tmp_msn:
    print("batch_ptr is present:", tmp_msn['batch_ptr'])
else:
    print("batch_ptr is missing")

dict_keys(['spec', 'mol', 'precursor_mz', 'adduct', 'identifier', 'mol_freq'])
batch_ptr is missing


In [27]:
# Spectrum entry of the MSn batch (a DataBatch with x, edge_index, batch and
# ptr fields in the recorded output, rather than a plain tensor).
tmp_msn['spec']

DataBatch(x=[134, 1039], edge_index=[2, 122], batch=[134], ptr=[13])

In [28]:
# Precursor m/z values of the sampled MSn batch (12 entries in the recorded output).
tmp_msn['precursor_mz']

tensor([ 435.1485, 1065.3082,  466.1205,  620.1131,  303.0499,  865.4427,
         319.1666,  405.1113,  291.1856,  449.1078,  488.2310,  418.2350])

In [29]:
# Adduct labels of the sampled MSn batch (a plain list of strings in the recorded output).
tmp_msn['adduct']

['[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+',
 '[M+H]+']

In [30]:
# 'mol_freq' entry of the sampled MSn batch.
# NOTE(review): presumably how often each molecule occurs in the dataset — confirm.
tmp_msn['mol_freq']

tensor([1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1.])

In [31]:
# Identifier strings of the entries in the sampled MSn batch.
tmp_msn['identifier']

['0024427_0000000',
 '0003980_0000000',
 '0084088_0000000',
 '0096920_0000000',
 '0004159_0000000',
 '0001584_0000000',
 '0039720_0000000',
 '0044544_0000000',
 '0080851_0000000',
 '0059397_0000000',
 '0062342_0000000',
 '0091534_0000000']