In [1]:
from massspecgym.data.transforms import MolFingerprinter, MolToInChIKey, MolToFormulaVector
from massspecgym.data.datasets import MSnDataset
from massspecgym.featurize import SpectrumFeaturizer
from massspecgym.data.data_module import MassSpecDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    'features': ['collision_energy', 'ionmode', 'adduct', 'spectrum_stats', 'atom_counts', 'value', "retention_time", 'ion_source', 'binned_peaks'],
    'feature_attributes': {
        'atom_counts': {
            'top_n_atoms': 12,
            'include_other': True,
        },
    },
}

In [3]:
featurizer = SpectrumFeaturizer(config, mode='torch')

In [4]:
mol_transform = MolFingerprinter()
msn_dataset = MSnDataset(pth="/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_msn_library_pos_all_lib_MSn.mgf",
                         mol_transform=mol_transform,
                         featurizer=featurizer,
                         max_allowed_deviation=0.005)
print(len(msn_dataset))

16476


In [5]:
msn_dataset[0]

{'spec': Data(x=[14, 1039], edge_index=[2, 13]),
 'mol': tensor([0., 0., 0.,  ..., 0., 0., 0.])}

In [6]:
BATCH_SIZE = 12
data_module = MassSpecDataModule(
    dataset=msn_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    split_pth="/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20240929_split.tsv"
)

In [8]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
from torch_geometric.nn import GCNConv, global_mean_pool

# Define your custom model
class MyGNNRetrievalModel(RetrievalMassSpecGymModel):
    def __init__(
        self,
        hidden_channels: int = 128,
        out_channels: int = 2048,  # fingerprint size
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        # Assuming node features have the dimension specified in your data (e.g., 1039)
        node_feature_dim = 1039  # Adjust based on your actual node feature size

        # GNN layers
        self.conv1 = GCNConv(in_channels=node_feature_dim, out_channels=hidden_channels)
        self.conv2 = GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)

        # Readout and prediction layers
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, data):
        # data is a batch from PyG DataLoader, containing 'x', 'edge_index', 'batch', etc.
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        # Global pooling to get graph-level representation
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return x

    def step(self, batch: dict, stage: Stage) -> dict:
        # Unpack inputs
        data = batch['spec']  # PyG DataBatch
        fp_true = batch['mol']     # True fingerprints

        # Predict fingerprint
        fp_pred = self.forward(data)

        # Calculate loss
        loss = nn.functional.mse_loss(fp_pred, fp_true)

        # Log loss
        self.log(f"{stage.to_pref()}loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return {'loss': loss}

# Instantiate the data module (already done in your code)
# Assuming 'data_module' is your MassSpecDataModule instance

# Instantiate the model
model = MyGNNRetrievalModel(out_channels=2048)

# Initialize the Trainer
trainer = pl.Trainer(accelerator="cpu", devices=1, max_epochs=1, log_every_n_steps=1)

# Train and validate the model
trainer.fit(model, datamodule=data_module)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | conv1 | GCNConv    | 133 K  | train
1 | conv2 | GCNConv    | 16.5 K | train
2 | fc    | Sequential | 280 K  | train
---------------------------------------------
430 K     Trainable params
0         Non-trainable params
430 K     Total params
1.721     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 93.09it/s]

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 114. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


KeyError: 'scores'

In [17]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.models.base import MassSpecGymModel, Stage
from torch_geometric.nn import GCNConv, global_mean_pool

class MyGNNModel(MassSpecGymModel):  # Inherit from MassSpecGymModel
    def __init__(
        self,
        hidden_channels: int = 128,
        out_channels: int = 2048,  # fingerprint size
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        # Adjust node_feature_dim based on your actual node feature size
        node_feature_dim = 1039  # Replace with your node feature size

        # GNN layers
        self.conv1 = GCNConv(in_channels=node_feature_dim, out_channels=hidden_channels)
        self.conv2 = GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)

        # Readout and prediction layers
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        print(f"x shape: {x.shape}")             # Node features
        print(f"edge_index shape: {edge_index.shape}")  # Edge indices
        print(f"batch shape: {batch.shape}")     # Batch indices
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x = global_mean_pool(x, batch)
        print(f"Pooled x shape: {x.shape}")      # Graph-level representation
        x = self.fc(x)
        print(f"Output x shape: {x.shape}")      # Predicted fingerprint
        return x

    def step(self, batch: dict, stage: Stage) -> dict:
        # Unpack inputs
        data = batch['spec']  # PyG DataBatch
        fp_true = batch['mol']     # True fingerprints

        # Predict fingerprint
        fp_pred = self.forward(data)

        # Calculate loss
        loss = nn.functional.mse_loss(fp_pred, fp_true)

        return {'loss': loss}

    def on_batch_end(self, outputs: dict, batch: dict, batch_idx: int, stage: Stage) -> None:
        # Log the loss
        self.log(
            f"{stage.to_pref()}loss",
            outputs['loss'],
            batch_size=batch['mol'].size(0),
            sync_dist=True,
            prog_bar=True,
        )
            # Calculate and log cosine similarity between predicted and true fingerprints
        fp_pred = self.forward(batch['spec'])
        fp_true = batch['mol']
        cos_sim = nn.functional.cosine_similarity(fp_pred, fp_true).mean()
        self.log(
            f"{stage.to_pref()}cos_sim",
            cos_sim,
            batch_size=batch['mol'].size(0),
            sync_dist=True,
            prog_bar=True,
        )
        # You can add custom evaluation metrics here if needed

In [16]:
# Instantiate the model
model = MyGNNModel(out_channels=2048)

# Initialize the Trainer
trainer = pl.Trainer(accelerator="cpu", devices=1, max_epochs=2, log_every_n_steps=1)

# Train and validate the model
trainer.fit(model, datamodule=data_module)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | conv1 | GCNConv    | 133 K  | train
1 | conv2 | GCNConv    | 16.5 K | train
2 | fc    | Sequential | 280 K  | train
---------------------------------------------
430 K     Trainable params
0         Non-trainable params
430 K     Total params
1.721     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


                                                                            

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1046/1046 [00:09<00:00, 115.79it/s, v_num=5, train_loss=0.0182, train_cos_sim=0.393]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/162 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/162 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/162 [00:00<00:00, 233.97it/s][A
Validation DataLoader 0:   1%|          | 2/162 [00:00<00:00, 219.93it/s][A
Validation DataLoader 0:   2%|▏         | 3/162 [00:00<00:00, 199.19it/s][A
Validation DataLoader 0:   2%|▏         | 4/162 [00:00<00:00, 196.32it/s][A
Validation DataLoader 0:   3%|▎         | 5/162 [00:00<00:00, 194.43it/s][A
Validation DataLoader 0:   4%|▎         | 6/162 [00:00<00:00, 195.82it/s][A
Validation DataLoader 0:   4%|▍         | 7/162 [00:00<00:00, 198.33it/s][A
Validation DataLoader 0:   5%|▍         | 8/162 [00:00<00:00, 193.92it/s][A
Validation DataLoader 0:   6%|▌         | 9/162 [00:00<00:00, 192.16it/s][A
Validation DataLoader 0:

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 1046/1046 [00:09<00:00, 107.33it/s, v_num=5, train_loss=0.0183, train_cos_sim=0.461, val_loss=0.0183, val_cos_sim=0.481]
