In [2]:
import pytorch_lightning as pl
from phantoms.models.retrieval.gnn_retrieval_model import GNNRetrievalModel
from massspecgym.data import RetrievalDataset, MassSpecDataModule
from massspecgym.data.transforms import MolFingerprinter
from massspecgym.featurize import SpectrumFeaturizer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from massspecgym.data.datasets import MSnRetrievalDataset
import yaml
import os

In [17]:
# Optional: Set random seed for reproducibility
pl.seed_everything(42)

# Load configuration
config_path = '/Users/macbook/CODE/PhantoMS/phantoms/models/retrieval/configs/config_retrieval.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

experiment_name = config.get('experiment_name', 'default_experiment')

Seed set to 42


In [4]:
# Initialize Featurizer
featurizer = SpectrumFeaturizer(config['featurizer'], mode='torch')

In [5]:
# Initialize Dataset
dataset_msn = MSnRetrievalDataset(
    pth=config['data']['file_mgf'],
    candidates_pth=config['data']['file_json'],
    featurizer=featurizer,
    mol_transform=MolFingerprinter(fp_size=config['model']['fp_size']),
    max_allowed_deviation=config['data']['max_allowed_deviation']
)

Total valid indices: 15674
Dataset length: 15674


In [6]:
# Initialize DataModule
data_module_msn = MassSpecDataModule(
    dataset=dataset_msn,
    batch_size=config['data']['batch_size'],
    split_pth=config['data']['split_file'],
    num_workers=config['data']['num_workers']
)

In [7]:
# Initialize Model
model = GNNRetrievalModel(
    hidden_channels=config['model']['hidden_channels'],
    out_channels=config['model']['fp_size'],
    node_feature_dim=config['model']['node_feature_dim'],
)

In [18]:
# Initialize Logger
logger = TensorBoardLogger(
    save_dir=config['logs']['dir'],
    name=config['logs']['name'],
    version=experiment_name  # Use experiment_name as the version
)

In [19]:
wandb_logger = WandbLogger(
    project=config['wandb']['project'],
    entity=config['wandb']['entity'],
    name=experiment_name,  # Name the run with experiment_name
    log_model="all"  # Log all models (optional)
)

In [24]:
# Define experiment-specific checkpoint directory
checkpoint_dir = os.path.join(config['trainer']['checkpoint_dir'], experiment_name)
os.makedirs(checkpoint_dir, exist_ok=True)

In [25]:
# Initialize Trainer
trainer = pl.Trainer(
    accelerator=config['trainer']['accelerator'],
    devices=config['trainer']['devices'],
    max_epochs=config['trainer']['max_epochs'],
    logger=[logger, wandb_logger],
    log_every_n_steps=config['trainer']['log_every_n_steps'],
    limit_train_batches=config['trainer']['limit_train_batches'],
    limit_val_batches=config['trainer']['limit_val_batches'],
    limit_test_batches=config['trainer']['limit_test_batches'],
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            monitor=config['trainer']['checkpoint_monitor'],
            save_top_k=config['trainer']['save_top_k'],
            mode=config['trainer']['checkpoint_mode'],
            dirpath=checkpoint_dir,
            filename='gnn_retrieval-{epoch:02d}-{val_loss:.2f}'
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [26]:
# Train the model
trainer.fit(model, datamodule=data_module_msn)

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/macbook/CODE/PhantoMS/notebooks/checkpoints/experiment_1 exists and is not empty.

   | Name              | Type          | Params | Mode 
-------------------------------------------------------------
0  | gcn1              | GCNLayer      | 133 K  | eval 
1  | gcn2              | GCNLayer      | 16.5 K | eval 
2  | head              | RetrievalHead | 544 K  | eval 
3  | loss_fn           | MSELoss       | 0      | eval 
4  | val_hit_rate@1    | MeanMetric    | 0      | train
5  | val_hit_rate@5    | MeanMetric    | 0      | train
6  | val_hit_rate@20   | MeanMetric    | 0      | train
7  | val_mces@1        | MeanMetric    | 0      | train
8  | train_hit_rate@1  | MeanMetric    | 0      | train
9  | train_hit_rate@5  | MeanMetric    | 0      | train
10 | train_hit_rate@20 | MeanMetric    | 0      | train
11 | train_mces@1      |

Train dataset size: 11883
Val dataset size: 1899
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


in collate_fn
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  7.53it/s]in collate_fn
                                                                           

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s] in collate_fn
Epoch 0:  50%|█████     | 1/2 [00:07<00:07,  0.13it/s, v_num=fosx]in collate_fn
Epoch 0: 100%|██████████| 2/2 [00:18<00:00,  0.11it/s, v_num=fosx, train_loss=0.295]
Validation: |          | 0/? [00:00<?, ?it/s][Ain collate_fn

Validation:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 11.42it/s][Ain collate_fn

Validation DataLoader 0: 100%|██████████| 2/2 [00:10<00:00,  0.18it/s][A
Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s, v_num=fosx, train_loss=0.290, val_loss=0.289]        in collate_fn
in collate_fn                                                              ss=0.290, val_loss=0.289]
Epoch 1: 100%|██████████| 2/2 [00:24<00:00,  0.08it/s, v_num=fosx, train_loss=0.279, val_loss=0.289]
Validation: |          | 0/? [00:00<?, ?it/s][Ain collate_fn

Validation:   0%|          | 0/2 [00:00<?, ?it/s][

`Trainer.fit` stopped: `max_epochs=7` reached.


Epoch 6: 100%|██████████| 2/2 [00:43<00:00,  0.05it/s, v_num=fosx, train_loss=0.227, val_loss=0.227]


In [27]:
# Test the model
trainer.test(model, datamodule=data_module_msn)

/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Test dataset size: 1892
Testing: |          | 0/? [00:00<?, ?it/s]in collate_fn
Testing DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 11.94it/s]in collate_fn
Testing DataLoader 0: 100%|██████████| 2/2 [00:11<00:00,  0.17it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_hit_rate@1                0.0
    test_hit_rate@20       0.009999999776482582
     test_hit_rate@5                0.0
        test_loss           0.22103449702262878
       test_mces@1          25.790000915527344
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.22103449702262878,
  'test_hit_rate@1': 0.0,
  'test_hit_rate@5': 0.0,
  'test_hit_rate@20': 0.009999999776482582,
  'test_mces@1': 25.790000915527344}]

In [1]:
# Load the model from a specific checkpoint
from phantoms.models.retrieval.gnn_retrieval_model import GNNRetrievalModel

checkpoint_path = "/Users/macbook/CODE/PhantoMS/notebooks/checkpoints/gnn_retrieval-epoch=04-val_loss=0.31.ckpt"
model = GNNRetrievalModel.load_from_checkpoint(checkpoint_path)

  from .autonotebook import tqdm as notebook_tqdm
