In [1]:
import os

from massspecgym.data.transforms import MolFingerprinter, MolToInChIKey, MolToFormulaVector
from massspecgym.data.datasets import MSnDataset
from massspecgym.featurize import SpectrumFeaturizer
from massspecgym.data.data_module import MassSpecDataModule

In [2]:
# Featurizer configuration: the list of feature names to compute, plus
# per-feature options keyed by feature name.
config = dict(
    features=[
        'collision_energy',
        'ionmode',
        'adduct',
        'spectrum_stats',
        'atom_counts',
        'value',
        'retention_time',
        'ion_source',
        'binned_peaks',
    ],
    feature_attributes=dict(
        # Options for the 'atom_counts' feature only.
        atom_counts=dict(
            top_n_atoms=12,
            include_other=True,
        ),
    ),
)

In [3]:
# Build the spectrum featurizer from `config`.
# NOTE(review): mode='torch' presumably selects torch tensors as output — confirm.
featurizer = SpectrumFeaturizer(
    config,
    mode='torch',
)

In [4]:
# Location of the MSn spectral library (.mgf). The default is the path used when
# this notebook was authored; set the MSN_MGF_PTH environment variable to point
# the notebook at another copy without editing this cell.
MSN_MGF_PTH = os.environ.get(
    "MSN_MGF_PTH",
    "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_msn_library_pos_all_lib_MSn.mgf",
)

# Molecule-side transform applied by the dataset. MolToFormulaVector() and
# MolToInChIKey() (both imported at the top of the notebook) were previously
# tried in this spot as alternatives.
mol_transform = MolFingerprinter()

msn_dataset = MSnDataset(
    pth=MSN_MGF_PTH,
    mol_transform=mol_transform,
    featurizer=featurizer,
    max_allowed_deviation=0.005,
)

# Number of samples in the loaded dataset.
print(len(msn_dataset))

16476


In [7]:
# Display the first sample: a dict with a 'spec_tree' entry and a 'mol' entry.
msn_dataset[0]

{'spec_tree': Data(x=[14, 1039], edge_index=[2, 13]),
 'mol': tensor([0., 0., 0.,  ..., 0., 0., 0.])}

## Test MassSpecDataModule

In [5]:
# Batch size handed to MassSpecDataModule below.
BATCH_SIZE: int = 12

In [6]:
# Location of the split file (.tsv). The default is the path used when this
# notebook was authored; set the MSN_SPLIT_PTH environment variable to use a
# different file without editing this cell.
MSN_SPLIT_PTH = os.environ.get(
    "MSN_SPLIT_PTH",
    "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_split.tsv",
)

# Wrap the dataset in a data module; num_workers=0 is passed through unchanged.
data_module = MassSpecDataModule(
    dataset=msn_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    split_pth=MSN_SPLIT_PTH,
)

In [7]:
# Run the data module's preparation hook, then its setup hook without a stage.
# NOTE(review): the stage-less setup() presumably builds the train/validation
# splits only, since a separate setup("test") call follows — confirm.
data_module.prepare_data()
data_module.setup()

In [8]:
# Explicitly run setup for the "test" stage.
# NOTE(review): presumably required before test_dataset / test_dataloader()
# are usable — confirm against MassSpecDataModule.setup.
data_module.setup(stage="test")

In [9]:
# One DataLoader per split, created in train / validation / test order.
train_loader, val_loader, test_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

# Report the sample count per split, a blank line, then the batch count per split.
print(
    f"Number of training samples: {len(data_module.train_dataset)}",
    f"Number of validation samples: {len(data_module.val_dataset)}",
    f"Number of test samples: {len(data_module.test_dataset)}",
    "",
    f"Number of training batches: {len(train_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    f"Number of test batches: {len(test_loader)}",
    sep="\n",
)

Number of training samples: 12550
Number of validation samples: 1938
Number of test samples: 1988

Number of training batches: 1046
Number of validation batches: 162
Number of test batches: 166


In [10]:
# Peek at the first test batch only: show the batched spectrum trees and the
# shape of the stacked 'mol' entry, then stop iterating.
for batch in test_loader:
    print(batch['spec_tree'], batch['mol'].shape, sep="\n")
    break

DataBatch(x=[104, 1039], edge_index=[2, 92], batch=[104], ptr=[13])
torch.Size([12, 2048])


In [11]:
import torch

In [12]:
# Inspect how the 'mol' entry was collated in the first test batch: report its
# type, then details that depend on whether it is a list or a tensor.
for batch in test_loader:
    print(batch['spec_tree'])
    print(f"Type of batch['mol']: {type(batch['mol'])}")
    if isinstance(batch['mol'], list):
        # List case: report its length and the type of the first item.
        print(f"batch['mol'] is a list with length {len(batch['mol'])}")
        print(f"First element type: {type(batch['mol'][0])}")
    elif isinstance(batch['mol'], torch.Tensor):
        # Tensor case: report its shape.
        print(f"Shape of batch['mol']: {batch['mol'].shape}")
    else:
        print("batch['mol'] is of an unexpected type.")
    # Only the first batch is needed.
    break

DataBatch(x=[104, 1039], edge_index=[2, 92], batch=[104], ptr=[13])
Type of batch['mol']: <class 'torch.Tensor'>
Shape of batch['mol']: torch.Size([12, 2048])


In [13]:
# Display the 'mol' entry of `batch`, the loop variable left bound by the
# for/break loop in the previous cell (i.e. the first test batch).
batch["mol"]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]])