In [1]:
#Load the packages
#import torch
#import torch.nn as nn
import os
from lightning.pytorch import Trainer, seed_everything #https://lightning.ai/docs/pytorch/stable/common/trainer.html
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping #3
from lightning.pytorch.loggers import TensorBoardLogger #3
from maldi_zsl_edit.data import MALDITOFDataModule #1
from maldi_zsl_edit.models import ZSLClassifier #2
import random
from torch import manual_seed as torch_manual_seed
from numpy.random import seed as np_random_seed
from torch.cuda import is_available as torch_cuda_is_available
from torch.cuda import manual_seed as torch_cuda_manual_seed
from torch.cuda import manual_seed_all as torch_cuda_manual_seed_all
#import h5py
#import numpy as np

In [2]:
# Make the run reproducible
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Calls Lightning's ``seed_everything`` (with ``workers=True``), exports
    ``PYTHONHASHSEED``, then explicitly seeds Python's ``random``, NumPy and
    torch. The CUDA generators are seeded only when CUDA is available.

    Args:
        seed: Integer seed shared by all generators.
    """
    seed_everything(seed, workers=True)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators, seeded in the same order as before.
    for cpu_seeder in (random.seed, np_random_seed, torch_manual_seed):
        cpu_seeder(seed)

    # Nothing more to do on machines without CUDA.
    if not torch_cuda_is_available():
        return

    for cuda_seeder in (torch_cuda_manual_seed, torch_cuda_manual_seed_all):
        cuda_seeder(seed)


set_seed(4)

Global seed set to 4


In [3]:
# Load the data set
data_path = "Data/zsl_SINAwasabi.h5t"  # the older file had problems with the split

# Settings for the project's personalized Lightning data module.
dm_options = {
    "zsl_mode": True,     # False: multi-class CLF, True: ZSL
    "split_index": 1,     # 0: non-general eval, 1: general eval, 2: general with no seen classes at validation
    "batch_size": 512,    # important hyperparameter
    "n_workers": 2,       # can stay as-is unless you are CPU limited
    "in_memory": True,    # can stay as-is unless memory is a problem
    "general": True,      # False: regular ZSL (only val species), True: general ZSL (val + train species)
}
dm = MALDITOFDataModule(data_path, **dm_options)
dm.setup(None)

#Check?
#if True:
#    batch = next(iter(dm.train_dataloader()))
#    print(batch.keys())
#    print(batch['seq_ohe'].shape)


# Training

In [22]:
# The batch should now expose a "seq_ohe" entry; use it instead of batch["seq"] in the models file.
# Number of sequences considered during training: 623 in total
# (463 from training and 160 from validation; the remaining 165 are in the test set).
n_species = 623  # alternative: batch['strain'].shape[0]
t_species = 463

# Parameters used to build the sequence CNN (the limit is [64, 128]).
cnn_settings = {
    'conv_sizes': [64, 128],   # e.g. [32, 64, 128]; output channels of the convolutions (in NLP mode the first is an embedding dimension)
    'hidden_sizes': [0],       # MLP, e.g. [512, 256]; [0] goes directly from the conv stack to the embedding layer
    'blocks_per_stage': 2,     # number of residual blocks applied before the pooling
    'kernel_size': 7,
    'dropout': 0.2,
    'mode': "max",             # "max" or "mean"
}

model = ZSLClassifier(
    embed_dim=1024,        # previously tried: 520
    cnn_kwargs=cnn_settings,
    n_classes=n_species,
    t_classes=t_species,
    lr=1e-3,               # important to tune
    weight_decay=0,        # can be kept constant
    lr_decay_factor=1.00,  # can be kept constant
    warmup_steps=250,      # can be kept constant
)

In [23]:
model

ZSLClassifier(
  (spectrum_embedder): MLPEmbedding(
    (net): Sequential(
      (0): Linear(in_features=6000, out_features=512, bias=True)
      (1): GELU(approximate='none')
      (2): Dropout(p=0.2, inplace=False)
      (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=512, out_features=256, bias=True)
      (5): GELU(approximate='none')
      (6): Dropout(p=0.2, inplace=False)
      (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (8): Linear(in_features=256, out_features=1024, bias=True)
      (9): GELU(approximate='none')
      (10): Dropout(p=0.2, inplace=False)
      (11): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (12): Linear(in_features=1024, out_features=1024, bias=True)
      (13): GELU(approximate='none')
      (14): Dropout(p=0.2, inplace=False)
      (15): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (16): Linear(in_features=1024, out_features=1024, bias=True)
    )
  )
  (seq_embed

In [24]:
# Save and monitor training with TensorBoard
from datetime import datetime

# Timestamp of this run; it names the logger version and is reused by later cells.
timenow = datetime.now()
strtime = f"{timenow:%Y-%m-%d_%H-%M-%S}"

# Keep the checkpoint with the best validation accuracy and stop once it stalls.
val_ckpt = ModelCheckpoint(mode="max", monitor="val_acc")
early_stop = EarlyStopping(mode="max", monitor="val_acc", patience=25)
callbacks = [val_ckpt, early_stop]

# "logs" is the main folder where the training is saved; "zsl_train" is the run name.
logger = TensorBoardLogger("logs", name="zsl_train", version=strtime)

# Training specification: epochs and devices are configurable (see the Trainer documentation).
trainer = Trainer(
    accelerator='gpu',
    devices=[0],
    strategy='auto',
    min_epochs=50,
    max_epochs=125,
    logger=logger,
    callbacks=callbacks,
)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
# Optionally warm-start the model from a saved model object and a Lightning checkpoint.
# Flip `sure` to True to enable this cell.
sure = False
if sure:
    from torch import load as torch_load

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load files you trust.
    model = torch_load('Models/Model_1')
    checkpath = 'logs/2024-08-07_13-29-32/checkpoints/epoch=77-step=2340.ckpt'
    checkpoint = torch_load(checkpath)

    # List both sets of parameters so shape mismatches can be spotted by eye.
    for name, param in checkpoint['state_dict'].items():
        print(f"Key: {name}, Shape: {param.shape}")
    for name, param in model.state_dict().items():
        print(f"Key: {name}, Shape: {param.shape}")

    # Report the comparison explicitly; a bare expression inside `if` is silently discarded.
    keys_match = model.state_dict().keys() == checkpoint['state_dict'].keys()
    print(f"State-dict keys match: {keys_match}")

    model.load_state_dict(checkpoint['state_dict'])

    # The `trainer` configured earlier is kept on purpose:
    # `Trainer(resume_from_checkpoint=...)` is no longer accepted by Lightning 2.x
    # and would also discard the callbacks, logger and epoch limits set above.
    # To restore optimizer/epoch state as well as the weights, pass
    # `ckpt_path=checkpath` to `trainer.fit(...)`.


In [26]:
# Start training.
# Important: normally passing only `dm` is enough, but the loaders are given explicitly here
# because the dimensions differ between train and val.
# Note: the model's training-step method specifies which values are treated as inputs
# and which as input/output values during training.
trainer.fit(
    model,
    train_dataloaders=dm.train_dataloader(),
    val_dataloaders=dm.val_dataloader(),
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name              | Type               | Params
---------------------------------------------------------
0 | spectrum_embedder | MLPEmbedding       | 5.6 M 
1 | seq_embedder      | CNNEmbedding       | 437 K 
2 | accuracy          | MulticlassAccuracy | 0     
3 | accuracy2         | MulticlassAccuracy | 0     
4 | top5_accuracy     | MulticlassAccuracy | 0     
---------------------------------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params
24.038    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  3.10it/s]
Epoch 0 - Train loss: NA, Train accu: 0.0
Epoch 0 - Val loss: 13.60734748840332, Val accu: 0.0

                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0: 100%|██████████| 30/30 [00:15<00:00,  1.96it/s, v_num=1-06]
Epoch 0 - Train loss: 17.094280242919922, Train accu: 0.0020833334419876337
Epoch 0 - Val loss: 6.6367411613464355, Val accu: 0.0

Epoch 1: 100%|██████████| 30/30 [00:15<00:00,  1.94it/s, v_num=1-06]
Epoch 1 - Train loss: 7.635072231292725, Train accu: 0.00264359381981194
Epoch 1 - Val loss: 6.4396467208862305, Val accu: 0.004464285913854837

Epoch 2: 100%|██████████| 30/30 [00:15<00:00,  1.92it/s, v_num=1-06]
Epoch 2 - Train loss: 6.661075115203857, Train accu: 0.002057656180113554
Epoch 2 - Val loss: 6.437726974487305, Val accu: 0.0029296877328306437

Epoch 3: 100%|██████████| 30/30 [00:16<00:00,  1.87it/s, v_num=1-06]
Epoch 3 - Train loss: 6.491706371307373, Train accu: 0.0023831771686673164
Epoch 3 - Val loss: 6.445499897003174, Val accu: 0.004464285913854837

Epoch 4: 100%|██████████| 30/30 [00:16<00:00,  1.84it/s, v_num=1-06]
Epoch 4 - Train loss: 6.421058177947998, Train accu: 0.002298359526321292
Epoch 4 - Val

Trainer was signaled to stop but the required `min_epochs=50` or `min_steps=None` has not been met. Training will continue...


Epoch 34: 100%|██████████| 30/30 [00:16<00:00,  1.86it/s, v_num=1-06]
Epoch 34 - Train loss: 6.153400421142578, Train accu: 0.001992552075535059
Epoch 34 - Val loss: 6.435069561004639, Val accu: 0.004464285913854837

Epoch 35: 100%|██████████| 30/30 [00:16<00:00,  1.82it/s, v_num=1-06]
Epoch 35 - Train loss: 6.15416955947876, Train accu: 0.0022135418839752674
Epoch 35 - Val loss: 6.432463645935059, Val accu: 0.0005580357392318547

Epoch 36: 100%|██████████| 30/30 [00:16<00:00,  1.83it/s, v_num=1-06]
Epoch 36 - Train loss: 6.150712013244629, Train accu: 0.001927447970956564
Epoch 36 - Val loss: 6.431375026702881, Val accu: 0.004464285913854837

Epoch 37: 100%|██████████| 30/30 [00:16<00:00,  1.85it/s, v_num=1-06]
Epoch 37 - Train loss: 6.1511077880859375, Train accu: 0.0022135418839752674
Epoch 37 - Val loss: 6.431115627288818, Val accu: 0.00041852681897580624

Epoch 38: 100%|██████████| 30/30 [00:16<00:00,  1.84it/s, v_num=1-06]
Epoch 38 - Train loss: 6.148756504058838, Train accu: 0.0

In [27]:
# Save the full model object at the end of training as well.
# (The elapsed training time is reported in a separate cell.)
sure = True
if sure:
    from torch import save as torch_save

    # Create the target folder first so a missing directory cannot make the
    # save fail after a complete training run.
    model_dir = 'Models'
    os.makedirs(model_dir, exist_ok=True)

    torch_save(model, f'{model_dir}/ZSLmodel{strtime}.pth')
    print(f"Saved as ZSLmodel{strtime}.pth")


Saved as ZSLmodel2024-08-19_20-01-06.pth


In [28]:
# Time elapsed since `timenow` was recorded, reported as whole hours and minutes.
end = datetime.now() - timenow
total_seconds = end.total_seconds()

# Split into hours and the seconds left over, then floor the remainder to minutes.
hours, leftover_seconds = divmod(total_seconds, 3600)
minutes = leftover_seconds // 60

print(f"{int(hours)} hours and {int(minutes)} minutes")

0 hours and 17 minutes


In [29]:
from torch import load as torch_load
#model = torch_load('../SavedModels/ZSLmodel2024-08-18_17-27-39.pth')

In [30]:
from maldi_zsl_edit.utils import ZSL_levels_metrics

data_path = "Data/zsl_SINAwasabi.h5t"
levels = ["Family", "Genus", "Species", "Strain"]

# Multi-level evaluation on the "Test" set; mind the split and the new labels.
# NOTE(review): the training data module used split_index=1 and general=True,
# while this call uses split_index=2 and general=False. Confirm this is intended.
(
    accug,
    f1g,
    gen,
    unacc,
    snacc,
    hmean,
    ev_species,
    labels,
) = ZSL_levels_metrics(
    data_path,
    model,
    levels,
    "Test",
    split_index=2,
    general=False,
)

--- Getting predictions ---
Working with test set

For Family there are 77 different labels
For Genus there are 154 different labels
For Species there are 568 different labels
For Strain there are 628 different labels

--- Multi level evaluation general ---

--- Calculating Accuracy ---
For the level Family the accu score is: 0.19635498523712158
For the level Genus the accu score is: 0.05118858814239502
For the level Species the accu score is: 0.006022186949849129
For the level Strain the accu score is: 0.005705229938030243

--- Calculating F1 scores ---
For the level Family the F1 score is: 0.026128826662898064
For the level Genus the F1 score is: 0.008865450508892536
For the level Species the F1 score is: 0.0009803102584555745
For the level Strain the F1 score is: 0.0008283915231004357

--- Multi level evaluation test_geni ---

--- Calculating Accuracy ---
For the level Family the accu score is: 0.18151001632213593
For the level Genus the accu score is: 0.051463790237903595
For the l