In [6]:
import os
import torch
import typing as T
import numpy as np
import pandas as pd
from pandarallel import pandarallel
import seaborn as sns
from tqdm import tqdm
tqdm.pandas()
import matplotlib.pyplot as plt
from rdkit import Chem
from torch.utils.data.dataloader import default_collate
from rdkit.Chem import DataStructs
from torch_geometric.data import Data
from massspecgym.data.datasets import MassSpecDataset, MSnDataset
from massspecgym.data.data_module import MassSpecDataModule
from massspecgym.data.transforms import MolFingerprinter
from massspecgym.utils import (
    morgan_fp, init_plotting, parse_paths_from_df, find_duplicate_smiles, add_identifiers, 
    visualize_MSn_tree, smiles_to_scaffold, train_val_test_split, create_split_file
    )
%reload_ext autoreload
%autoreload 2

In [116]:
import torch
from torch.utils.data.dataloader import default_collate
import typing as T
import numpy as np
from torch_geometric.data import Data
import torch.nn.functional as F


class MSnDataset(MassSpecDataset):
    def __init__(self, pth=None, dtype=torch.float32, mol_transform=None):
        # load dataset using the parent class
        super().__init__(pth=pth)
        self.metadata = self.metadata[self.metadata["spectype"] == "ALL_ENERGIES"]

        # TODO: add identifiers (and split?) to the mgf file
        # add identifiers to the metadata
        self.metadata = add_identifiers(self.metadata)
        self.identifiers = np.unique(self.metadata["identifier"])

        self.mol_transform = mol_transform

        # get paths from the metadata
        dataset_all_tree_paths = parse_paths_from_df(self.metadata)

        # generate trees from paths and convert their paths to tensors
        self.trees = []
        self.tree_tensors = []
        self.smiles = []

        for smi, root, paths in dataset_all_tree_paths:
            tree = Tree(root)
            for path in paths:
                tree.add_path(path)

            self.trees.append(tree)
            self.tree_tensors.append(torch.tensor(tree.get_edges()))
            self.smiles.append(smi)
        
        # Find the maximum length of the trees
        max_length = max(tensor.size(0) for tensor in self.tree_tensors)

        # Pad all tree tensors to the maximum length
        self.padded_tree_tensors = [pad_tensor(tensor, max_length) for tensor in self.tree_tensors]

    def __len__(self):
        return len(self.padded_tree_tensors)

    def __getitem__(self, idx, transform_mol=True):
        spec_tree_tensor = self.padded_tree_tensors[idx]
        smi = self.smiles[idx]

        mol = self.mol_transform(smi) if transform_mol and self.mol_transform else smi
        if isinstance(mol, np.ndarray):
            mol = torch.as_tensor(mol, dtype=self.dtype)

        item = {"spec_tree": spec_tree_tensor, "mol": mol}
        return item

    @staticmethod
    def pad_tensor(tensor: torch.Tensor, length: int, pad_value: float = 0) -> torch.Tensor:
        """
        Pad a tensor to a specified length with a given pad value using torch.nn.functional.pad.
        """
        # Calculate the amount of padding needed
        pad_amount = length - tensor.size(0)

        # Create the padding configuration
        pad = (0, 0) * (tensor.dim() - 1) + (0, pad_amount)

        # Apply padding
        padded_tensor = F.pad(tensor, pad, value=pad_value)
        return padded_tensor

    @staticmethod
    def collate_fn(batch):
        """
        Custom collate function to handle the outputs of __getitem__.
        """
        spec_tree_tensors = [item['spec_tree'] for item in batch]
        mols = [item['mol'] for item in batch]

        # Stack the padded spec_trees and mols into tensors
        batch_spec_trees = torch.stack(spec_tree_tensors)
        batch_mols = default_collate(mols)

        return {'spec_tree': batch_spec_trees, 'mol': batch_mols}

In [128]:
fingerprinter = MolFingerprinter()
msn_dataset = MSnDataset(pth="debug.mgf", mol_transform=fingerprinter.from_smiles)
scaffolds = [smiles_to_scaffold(smi) for smi in msn_dataset.smiles]
train, validation, test = train_val_test_split(msn_dataset.smiles, scaffolds)
split_tsv_path = "debug_splits.tsv"
create_split_file(msn_dataset, train, validation, test, split_tsv_path)
len(msn_dataset)

split tsv file already exists at debug_splits.tsv


285

In [132]:
mspec_data_module = MassSpecDataModule(batch_size=1, dataset=msn_dataset, split_pth="debug_splits.tsv")
mspec_data_module.prepare_data()
mspec_data_module.setup()
mspec_data_module.train_dataloader()
mspec_data_module.val_dataloader()

foo = mspec_data_module.val_dataloader()
for i in foo:
    print(i["spec_tree"].shape)
    print(i["mol"].shape)

torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1, 22, 2])
torch.Size([1, 2048])
torch.Size([1

IndexError: list index out of range

In [30]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.data import RetrievalDataset, MassSpecDataModule
from massspecgym.data.transforms import SpecTokenizer, MolFingerprinter
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel

In [118]:
class MyDeepSetsRetrievalModel(RetrievalMassSpecGymModel):
    def __init__(
        self,
        hidden_channels: int = 512,
        out_channels: int = 4096,  # fingerprint size
        *args,
        **kwargs
    ):
        """Implement your architecture."""
        super().__init__(*args, **kwargs)

        self.phi = nn.Sequential(
            nn.Linear(2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Implement your prediction logic."""
        x = self.phi(x)
        x = x.sum(dim=-2)  # sum over peaks
        x = self.rho(x)
        return x

    def step(
        self, batch: dict, stage: Stage
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Implement your custom logic of using predictions for training and inference."""
        # Unpack inputs
        x = batch["spec"]  # input spectra
        fp_true = batch["mol"]  # true fingerprints
        # cands = batch["candidates"]  # candidate fingerprints concatenated for a batch
        # batch_ptr = batch["batch_ptr"]  # number of candidates per sample in a batch

        # Predict fingerprint
        fp_pred = self.forward(x)

        # Calculate loss
        loss = nn.functional.mse_loss(fp_true, fp_pred)

        # Calculate final similarity scores between predicted fingerprints and retrieval candidates
        # fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
        # scores = nn.functional.cosine_similarity(fp_pred_repeated, cands)

        # return dict(loss=loss, scores=scores)
        return dict(loss=loss, scores=None)

In [141]:
# Init hyperparameters
fp_size = 2048

# Init data module
data_module = mspec_data_module

# Init model
model = MyDeepSetsRetrievalModel(out_channels=fp_size)

# Init trainer
trainer = Trainer(accelerator="cpu", devices=1)

# Train
trainer.fit(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\Lukas\anaconda3\envs\massspecgym\Lib\site-packages\pytorch_lightning\loops\utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | phi  | Sequential | 264 K  | train
1 | rho  | Sequential | 1.3 M  | train
--------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.310     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\Lukas\anaconda3\envs\massspecgym\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


TypeError: RetrievalMassSpecGymModel.evaluate_retrieval_step() missing 2 required positional arguments: 'labels' and 'batch_ptr'