# Project for the Cognitive Computing Systems course
## Visual Recognition for brain tumor detection
**Academic year 2022/2023**  
Authors: Ferinando Simone D'Agostino, Simone D'Orta.  
Instructor: Prof. Paolo Maresca.  
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/BiomedSciAI/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/BiomedSciAI/fuse-med-ml)


------------
## **Installation Details - Google Colab**
The install_fuse and use_gpu flags let you install FuseMedML in the Google Colab environment and run training on a GPU (the Colab runtime type must also be set to GPU).

In [None]:
# @title 1. Install FuseMedML

# @markdown Please choose whether or not to install FuseMedML and execute this cell by pressing the *Play* button on the left.


install_fuse = True  # @param {type:"boolean"}
use_gpu = True  # @param {type:"boolean"}

# @markdown ### **Warning!**
# @markdown If you wish to install FuseMedML -- as a workaround for
# @markdown [this](https://stackoverflow.com/questions/57831187/need-to-restart-runtime-before-import-an-installed-package-in-colab)
# @markdown issue, please follow these steps:   <br>
# @markdown 1. Execute this cell by pressing the ▶️ button on the left.
# @markdown 2. Restart runtime
# @markdown 3. Execute it once again
# @markdown 4. Enjoy
if install_fuse:
    !git clone https://github.com/BiomedSciAI/fuse-med-ml.git
    %cd fuse-med-ml
    %pip install -e .[all,examples]


## **Setup environment**
The following cells set up the project environment. First, the required libraries (FuseMedML, PyTorch and torchvision) are imported.

In [None]:
# @title 2. Imports

# @markdown Please execute this cell by pressing the *Play* button on the left.

import os
import copy
from collections import OrderedDict

import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader

from fuse.eval.evaluator import EvaluatorDefault
from fuse.dl.losses.loss_default import LossDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve, MetricConfusion
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds
from fuse.dl.models.model_wrapper import ModelWrapSeqToDict
from fuse.data.utils.samplers import BatchSamplerDefault
from fuse.data.utils.collates import CollateDefault
from fuse.dl.lightning.pl_module import LightningModuleDefault
from fuse.dl.lightning.pl_funcs import convert_predictions_to_dataframe
from fuse.utils.file_io.file_io import create_dir, save_dataframe
from fuseimg.datasets.mnist import MNIST

from fuse_examples.imaging.hello_world.hello_world_utils import LeNet, perform_softmax

from fuse.data.datasets.dataset_wrap_seq_to_dict import DatasetWrapSeqToDict
from torchvision import transforms, datasets

##### **Output paths**
These are the paths the script needs in order to run correctly.
- VOL:
	- cache_dir: not relevant for this project;
	- model_dir: contains the model and the per-epoch summary file;
	- infer_dir: contains the file infer_file.gz (which holds the predictions) and the image of the ROC curve;
	- eval_dir: contains the file results.txt, which stores the results of the model evaluation phase (in particular the four metrics accuracy, recall, precision and F1-score).

In [None]:
ROOT = '/content/drive/MyDrive/PROGETTO_CCS'
model_dir = os.path.join(ROOT, "VOL/model_dir")
PATHS = {
    "model_dir": model_dir,
    "cache_dir": os.path.join(ROOT, "VOL/cache_dir"),
    "inference_dir": os.path.join(ROOT, "VOL/infer_dir"),
    "eval_dir": os.path.join(ROOT, "VOL/eval_dir"),
}

paths = PATHS
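The cell above only defines the output paths; nothing in it creates the folders. A minimal sketch for creating them up front, assuming it is acceptable to create any missing directory under ROOT; ensure_dirs is a hypothetical helper, not a FuseMedML function:

```python
import os
from typing import Dict


def ensure_dirs(paths: Dict[str, str]) -> None:
    """Create every directory listed in `paths` (hypothetical helper), skipping existing ones."""
    for name, path in paths.items():
        # exist_ok=True makes the call safe to repeat when the notebook is re-run
        os.makedirs(path, exist_ok=True)
        print(f"{name}: {path}")
```

In the notebook it would be called as ensure_dirs(paths) right after the paths are defined.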

##### **Training Parameters**
Here the hyperparameters and the other training parameters are set.

In [None]:
TRAIN_COMMON_PARAMS = {}

### Data ###
TRAIN_COMMON_PARAMS["data.batch_size"] = 64
TRAIN_COMMON_PARAMS["data.train_num_workers"] = 2
TRAIN_COMMON_PARAMS["data.validation_num_workers"] = 2

### PL Trainer ###
TRAIN_COMMON_PARAMS["trainer.num_epochs"] = 50
TRAIN_COMMON_PARAMS["trainer.num_devices"] = 1
TRAIN_COMMON_PARAMS["trainer.accelerator"] = "gpu" if use_gpu else "cpu"
TRAIN_COMMON_PARAMS["trainer.ckpt_path"] = None  # path to a checkpoint to resume training from (None = start from scratch)

### Optimizer ###
TRAIN_COMMON_PARAMS["opt.lr"] = 1e-4
TRAIN_COMMON_PARAMS["opt.weight_decay"] = 0.001

train_params = TRAIN_COMMON_PARAMS
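The accelerator entry above depends only on the use_gpu flag, so if the flag is set while the Colab runtime has no GPU attached, Lightning will typically stop with an error when the trainer is created. A small sketch of a fallback; pick_accelerator is a hypothetical helper:

```python
def pick_accelerator(use_gpu: bool, cuda_available: bool) -> str:
    """Return the Lightning accelerator name, falling back to CPU when no GPU is present."""
    if use_gpu and not cuda_available:
        # The flag asks for a GPU but the runtime has none
        print("use_gpu=True but no CUDA device was found: falling back to 'cpu'")
        return "cpu"
    return "gpu" if use_gpu else "cpu"
```

In the notebook (with `import torch`) it could replace the accelerator line as `TRAIN_COMMON_PARAMS["trainer.accelerator"] = pick_accelerator(use_gpu, torch.cuda.is_available())`. Passing the CUDA check in as an argument keeps the helper testable on machines without a GPU.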

## **Training the model**
This section preprocesses the data and trains the model.

##### **Data**
First, Google Drive (where the data is stored) is mounted.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
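datasets.ImageFolder, used in the next cell, expects each split folder (training_VOL, validation_VOL, testing_VOL) to contain one subfolder per class, and it assigns integer labels to the class folders in sorted order. A quick sketch for checking that layout before training; count_images is a hypothetical helper, and the extension list is assumed to match ImageFolder's defaults:

```python
import os
from typing import Dict

# Extensions assumed to match torchvision's default ImageFolder filter
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def count_images(split_dir: str) -> Dict[str, int]:
    """Count image files per class subfolder, in the sorted order ImageFolder uses for labels."""
    classes = sorted(e.name for e in os.scandir(split_dir) if e.is_dir())
    counts = {}
    for label, cls in enumerate(classes):
        n = 0
        # Walk recursively: files in nested folders also count towards the class
        for _, _, files in os.walk(os.path.join(split_dir, cls)):
            n += sum(f.lower().endswith(IMG_EXTENSIONS) for f in files)
        counts[cls] = n
        print(f"label {label} -> {cls}: {n} images")
    return counts
```

Usage: `for split in ['training_VOL', 'validation_VOL', 'testing_VOL']: count_images(os.path.join('/content/drive/MyDrive/PROGETTO_CCS/DATASET', split))`. A strongly unbalanced count is one reason to keep the class-balanced sampler used below.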


Then the data is loaded and preprocessed by applying the transforms to the images.

In [None]:
## Training Data
# Create the datasets (one ImageFolder per split)
data_dir = '/content/drive/MyDrive/PROGETTO_CCS/DATASET'
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the input size the network expects (224 x 224 for ResNet)
    transforms.ToTensor(),
    # MNIST mean/std: a single value, broadcast to all three RGB channels loaded by ImageFolder
    transforms.Normalize((0.1307,), (0.3081,))])
    # ImageNet mean/std, i.e. the statistics expected by torchvision's pretrained weights:
    #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
torch_train_dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), transform) for x in ['training_VOL', 'testing_VOL', 'validation_VOL']}
train_dataset = DatasetWrapSeqToDict(name='training_VOL', dataset=torch_train_dataset['training_VOL'], sample_keys=('data.image', 'data.label'))
train_dataset.create()
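transforms.Normalize standardizes each channel $c$ of the tensor produced by ToTensor (values in $[0, 1]$) with a mean $\mu_c$ and a standard deviation $\sigma_c$. With the single-value statistics used above, the same pair is broadcast to all three RGB channels:

```latex
x'_c = \frac{x_c - \mu_c}{\sigma_c}, \qquad
\mu_c = 0.1307,\quad \sigma_c = 0.3081 \qquad \forall c \in \{R, G, B\}
```

For example, a pixel value of 0 maps to $-0.1307/0.3081 \approx -0.424$ and a value of 1 maps to $0.8693/0.3081 \approx 2.822$, so the network sees inputs in roughly $[-0.424, 2.822]$ instead of $[0, 1]$.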

In [None]:
# Create Fuse's custom sampler
sampler = BatchSamplerDefault(
    dataset=train_dataset,
    balanced_class_name="data.label",
    num_balanced_classes=2,
    batch_size=train_params["data.batch_size"],
    balanced_class_weights=None,
)

# Create dataloader
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_sampler=sampler,
    collate_fn=CollateDefault(),
    num_workers=train_params["data.train_num_workers"],
)

## Validation data
# Create dataset
validation_dataset = DatasetWrapSeqToDict(name='validation_VOL', dataset=torch_train_dataset['validation_VOL'], sample_keys=('data.image', 'data.label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    batch_size=train_params["data.batch_size"],
    collate_fn=CollateDefault(),
    num_workers=train_params["data.validation_num_workers"],
)

multiprocess pool created with 10 workers.


batch_sampler: 100%|██████████| 1111/1111 [00:23<00:00, 46.85it/s]
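BatchSamplerDefault, configured with balanced_class_name="data.label" and num_balanced_classes=2, builds class-balanced batches: with a batch size of 64, each batch should hold about 32 samples per class however unbalanced the dataset is. The simplified pure-Python sketch below illustrates that idea only; it is not FuseMedML's implementation, and the function name balanced_batches and its refill policy (minority classes are revisited more often) are assumptions:

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


def balanced_batches(labels: Sequence[int], batch_size: int, num_batches: int,
                     seed: int = 0) -> Iterator[List[int]]:
    """Yield batches of sample indices with the same number of samples per class (hypothetical sketch)."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    classes = sorted(by_class)
    per_class = batch_size // len(classes)  # e.g. 64 // 2 = 32 samples per class
    pools: Dict[int, List[int]] = {c: [] for c in classes}
    for _ in range(num_batches):
        batch = []
        for c in classes:
            for _ in range(per_class):
                if not pools[c]:
                    # Refill and reshuffle: a small class is oversampled rather than running out
                    pools[c] = by_class[c][:]
                    rng.shuffle(pools[c])
                batch.append(pools[c].pop())
        rng.shuffle(batch)  # mix the classes inside the batch
        yield batch
```

With 90 samples of class 0 and 10 of class 1, every batch of 64 still contains 32 samples of each class; the price is that each class-1 sample is seen many times per epoch.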


##### **Model**

The pretrained model is defined here. To test other models, change the model constructor in this cell.

In [None]:
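import torch.nn as nn
import torchvision.models as models

# Number of classes, read from the class subfolders found by ImageFolder
NUM_CLASSES = len(torch_train_dataset["training_VOL"].classes)

def create_model():
    # Backbone pretrained on ImageNet: change the constructor here to test other models
    torch_model = models.resnet152(pretrained=True)
    # Replace the 1000-class ImageNet head with one output per dataset class
    # (ResNet-specific: other families, e.g. DenseNet or VGG, call their head "classifier")
    torch_model.fc = nn.Linear(torch_model.fc.in_features, NUM_CLASSES)
    model = ModelWrapSeqToDict(
        model=torch_model,
        model_inputs=["data.image"],
        post_forward_processing_function=perform_softmax,
        model_outputs=["model.logits.classification", "model.output.classification"],
    )
    return model

model = create_model()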
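ModelWrapSeqToDict stores the raw network output under model.logits.classification and the post-processed output under model.output.classification. Assuming perform_softmax (from FuseMedML's hello-world utilities) applies a standard softmax over the class dimension, the conversion for a single sample looks like this pure-Python sketch; the function name softmax is illustrative:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

Subtracting the maximum does not change the result (it cancels in the ratio) but avoids overflow for large logits. For two classes, softmax([2.0, 0.0])[0] equals the sigmoid of the logit difference, 1 / (1 + e^-2).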



##### **Loss function**
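The loss function compares the predictions with the target labels and therefore measures how well the network performs. F.cross_entropy is applied to the raw logits (model.logits.classification) rather than to the softmax output, because it applies log-softmax internally.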

In [None]:
losses = {
    "cls_loss": LossDefault(
        pred="model.logits.classification", target="data.label", callable=F.cross_entropy, weight=1.0
    ),
}
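For one sample with logits $z$ and true class $y$, F.cross_entropy computes $-\log \mathrm{softmax}(z)_y$ and, with the default reduction, averages it over the batch. A pure-Python sketch of the same quantity using the log-sum-exp trick; the function names are illustrative:

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], computed stably as logsumexp(logits) - logits[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def batch_cross_entropy(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean over the batch, matching F.cross_entropy's default reduction='mean'."""
    return sum(cross_entropy(z, y) for z, y in zip(batch_logits, targets)) / len(targets)
```

A model that outputs equal logits for two classes gets a loss of ln 2 ≈ 0.693, while a confident correct prediction drives the loss towards 0.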

##### **Metrics**
The metrics used to evaluate training and validation are defined here.

In [None]:
train_metrics = OrderedDict(
    [
        ("operation_point", MetricApplyThresholds(pred="model.output.classification")),
        ("accuracy", MetricAccuracy(pred="results:metrics.operation_point.cls_pred", target="data.label")),
        ("confusion", MetricConfusion(pred="model.output.classification", target="data.label"))
    ]
)
validation_metrics = copy.deepcopy(train_metrics)
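To make the metric chain concrete: assuming that MetricApplyThresholds, with no operation point given, takes the argmax of the class probabilities, accuracy is the fraction of samples whose argmax matches data.label. The pure-Python sketch below computes that together with a plain confusion matrix; it does not reproduce the output format of FuseMedML's MetricConfusion, and its function names are illustrative:

```python
from typing import List, Sequence, Tuple


def argmax(probs: Sequence[float]) -> int:
    """Index of the largest probability (the first one on ties)."""
    return max(range(len(probs)), key=lambda i: probs[i])


def accuracy_and_confusion(probs: Sequence[Sequence[float]], labels: Sequence[int],
                           num_classes: int) -> Tuple[float, List[List[int]]]:
    """Accuracy of argmax predictions plus a confusion matrix (rows = true class, columns = predicted)."""
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for p, y in zip(probs, labels):
        confusion[y][argmax(p)] += 1
    correct = sum(confusion[c][c] for c in range(num_classes))
    return correct / len(labels), confusion
```

For a binary task the off-diagonal cells are the false negatives and false positives from which recall and precision are derived.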

##### **Best Epoch Source**
This sets the criterion for saving the best epoch: the accuracy on the validation set, where higher is better.

In [None]:
best_epoch_source = dict(monitor="validation.metrics.accuracy", mode="max")
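With mode="max", the best epoch is the one with the highest validation.metrics.accuracy observed so far. A sketch of that bookkeeping, assuming that only a strict improvement replaces the current best (how FuseMedML and Lightning break ties is not verified here); best_epoch is a hypothetical helper:

```python
from typing import Optional, Sequence


def best_epoch(values: Sequence[float], mode: str = "max") -> Optional[int]:
    """Index of the best epoch; a later epoch wins only if it strictly improves on the best so far."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best_idx, best_val = None, None
    for epoch, v in enumerate(values):
        improved = best_val is None or (v > best_val if mode == "max" else v < best_val)
        if improved:
            best_idx, best_val = epoch, v
    return best_idx
```

Under this rule an epoch that only matches the best accuracy does not replace it, which keeps the earliest of several equally good checkpoints.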

##### **Training**
Finally, the model is trained.

In [None]:
# create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params["opt.lr"], weight_decay=train_params["opt.weight_decay"])

# create scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
lr_sch_config = dict(scheduler=lr_scheduler, monitor="validation.losses.total_loss")

# optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options
optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)

# create instance of PL module - FuseMedML generic version
pl_module = LightningModuleDefault(
    model_dir=paths["model_dir"],
    model=model,
    losses=losses,
    train_metrics=train_metrics,
    validation_metrics=validation_metrics,
    best_epoch_source=best_epoch_source,
    optimizers_and_lr_schs=optimizers_and_lr_schs,
)

# create lightning trainer
pl_trainer = pl.Trainer(
    default_root_dir=paths["model_dir"],
    max_epochs=train_params["trainer.num_epochs"],
    accelerator=train_params["trainer.accelerator"],
    devices=train_params["trainer.num_devices"],
)

# train
pl_trainer.fit(pl_module, train_dataloader, validation_dataloader, ckpt_path=train_params["trainer.ckpt_path"])
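ReduceLROnPlateau(optimizer) is created with PyTorch's defaults (mode='min', factor=0.1, patience=10, threshold=1e-4 in relative mode, cooldown=0), and it monitors validation.losses.total_loss. The rule it applies is sketched below in pure Python; it mirrors those defaults but is a simplified illustration (no cooldown or min_lr), not the PyTorch source, and simulate_plateau is a hypothetical name:

```python
from typing import List, Sequence


def simulate_plateau(losses: Sequence[float], lr: float, factor: float = 0.1,
                     patience: int = 10, threshold: float = 1e-4) -> List[float]:
    """Learning rate in use after each epoch under a simplified ReduceLROnPlateau (mode='min')."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best * (1 - threshold):  # relative improvement over the best loss so far
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history
```

With lr=1e-4 and a validation loss that stops improving, the learning rate drops to 1e-5 after 11 consecutive epochs without improvement; with 50 epochs this can happen a few times at most.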

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name   | Type               | Params
----------------------------------------------
0 | _model | ModelWrapSeqToDict | 11.7 M
----------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.758    Total estimated 

Stats for epoch: 0 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (0)                          | Current Epoch (0)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 4.7487                                  | 4.7487                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 4.7487                                  | 4.7487                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 1 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (1)                          | Current Epoch (1)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 1.1178                                  | 1.1178                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 1.1178                                  | 1.1178                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 2 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (2)                          | Current Epoch (2)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.1833                                  | 0.1833                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.1833                                  | 0.1833                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 3 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (3)                          | Current Epoch (3)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0813                                  | 0.0813                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0813                                  | 0.0813                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 4 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (4)                          | Current Epoch (4)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0361                                  | 0.0361                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0361                                  | 0.0361                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 5 (Best epoch is 4 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (4)                          | Current Epoch (5)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0361                                  | 0.0224                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0361                                  | 0.0224                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy       



Stats for epoch: 6 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (6)                          | Current Epoch (6)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0202                                  | 0.0202                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0202                                  | 0.0202                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 7 (Best epoch is 6 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (6)                          | Current Epoch (7)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0202                                  | 0.0185                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0202                                  | 0.0185                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy       



Stats for epoch: 8 (Currently the best epoch for source validation.metrics.accuracy!)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (8)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0137                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0137                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accur



Stats for epoch: 9 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (9)                       |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0117                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0117                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy       



Stats for epoch: 10 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (10)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0129                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0129                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 11 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (11)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0123                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0123                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 12 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (12)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0046                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0046                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 13 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (13)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0107                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0107                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 14 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (14)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0156                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0156                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 15 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (15)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0107                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0107                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 16 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (16)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0069                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0069                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 17 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (17)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0082                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0082                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 18 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (18)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0072                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0072                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 19 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (19)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0050                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0050                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 20 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (20)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0081                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0081                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 21 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (21)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0048                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0048                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 22 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (22)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0066                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0066                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 23 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (23)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0056                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0056                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 24 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (24)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0068                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0068                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 25 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (25)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 26 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (26)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0074                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0074                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 27 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (27)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0054                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0054                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 28 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (28)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0059                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0059                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 29 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (29)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 30 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (30)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 31 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (31)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0061                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0061                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 32 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (32)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 33 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (33)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0061                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0061                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 34 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (34)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0063                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0063                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 35 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (35)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 36 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (36)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 37 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (37)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 38 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (38)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0056                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0056                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 39 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (39)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 40 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (40)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0047                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0047                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 41 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (41)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0059                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0059                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 42 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (42)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0063                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0063                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 43 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (43)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0058                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 44 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (44)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 45 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (45)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0055                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0055                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 46 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (46)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0060                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0060                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 47 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (47)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0065                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 48 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (48)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0052                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      



Stats for epoch: 49 (Best epoch is 8 for source validation.metrics.accuracy)

------------------------------------------------------------------------------------------------------------------------------
|                                         | Best Epoch (8)                          | Current Epoch (49)                      |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.cls_loss                   | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.losses.total_loss                 | 0.0137                                  | 0.0067                                  |
------------------------------------------------------------------------------------------------------------------------------
| train.metrics.accuracy      

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


## **Infer**
Here we run the prediction phase and save the results.

##### **Define Infer Common Params**
Here we define the parameters used to generate predictions, for example where to save the predictions file and the path of the checkpoint file containing the model to load.

In [None]:
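INFER_COMMON_PARAMS = {}
# name of the file (inside paths["inference_dir"]) where the predictions are saved
INFER_COMMON_PARAMS["infer_filename"] = "infer_file.gz"
# name of the checkpoint file (inside paths["model_dir"]) holding the model to load
INFER_COMMON_PARAMS["checkpoint"] = "best_epoch.ckpt"
# reuse the same hardware settings used for training
INFER_COMMON_PARAMS["trainer.num_devices"] = TRAIN_COMMON_PARAMS["trainer.num_devices"]
INFER_COMMON_PARAMS["trainer.accelerator"] = TRAIN_COMMON_PARAMS["trainer.accelerator"]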

infer_common_params = INFER_COMMON_PARAMS

##### **Infer**
Generate the predictions.
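As a rough illustration of what the `# convert list of batch outputs into a dataframe` step in the cell below does, the sketch flattens per-batch outputs into one record per sample. It assumes each batch output is a dict mapping every key passed to `set_predictions_keys` to one value per sample in that batch. `flatten_batch_predictions` is a hypothetical helper, not FuseMedML's `convert_predictions_to_dataframe`, and the values are made up.

```python
from typing import Any, Dict, List, Sequence


def flatten_batch_predictions(batches: Sequence[Dict[str, Sequence[Any]]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn a list of per-batch dicts into one dict per sample."""
    rows: List[Dict[str, Any]] = []
    for batch in batches:
        keys = list(batch.keys())
        batch_size = len(batch[keys[0]]) if keys else 0
        # every key must carry exactly one value per sample
        for key in keys:
            if len(batch[key]) != batch_size:
                raise ValueError(f"key {key!r} has {len(batch[key])} entries, expected {batch_size}")
        for i in range(batch_size):
            rows.append({key: batch[key][i] for key in keys})
    return rows


# Two hypothetical batches of size 2 (the cell below uses batch_size=2)
batches = [
    {"model.output.classification": [[0.9, 0.1], [0.2, 0.8]], "data.label": [0, 1]},
    {"model.output.classification": [[0.6, 0.4], [0.3, 0.7]], "data.label": [1, 1]},
]
rows = flatten_batch_predictions(batches)
print(len(rows))              # 4
print(rows[1]["data.label"])  # 1
```

Each resulting record holds the same keys as the batch dicts, which is the shape a per-sample table needs.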

In [None]:
# setting dir and paths
create_dir(paths["inference_dir"])
infer_file = os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])
checkpoint_file = os.path.join(paths["model_dir"], infer_common_params["checkpoint"])

# create the test dataset and dataloader
# alternative: run inference on the validation set instead
#validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=CollateDefault(), batch_size=2, num_workers=2)
# wrap the (image, label) samples into dict samples with keys 'data.image' and 'data.label'
testing_dataset = DatasetWrapSeqToDict(name='testing_VOL', dataset=torch_train_dataset['testing_VOL'], sample_keys=('data.image', 'data.label'))
testing_dataset.create()
testing_dataloader = DataLoader(dataset=testing_dataset, collate_fn=CollateDefault(), batch_size=2, num_workers=2)

# load pytorch lightning module
model = create_model()
pl_module = LightningModuleDefault.load_from_checkpoint(
    checkpoint_file, model_dir=paths["model_dir"], model=model, map_location="cpu", strict=True
)

# set the prediction keys to extract (the ones used by the evaluation function)
pl_module.set_predictions_keys(
    ["model.output.classification", "data.label"]
)  # which keys to extract and dump into file

# create a trainer instance
pl_trainer = pl.Trainer(
    default_root_dir=paths["model_dir"],
    accelerator=infer_common_params["trainer.accelerator"],
    devices=infer_common_params["trainer.num_devices"],
)

# predict
predictions = pl_trainer.predict(pl_module, testing_dataloader, return_predictions=True)

# convert list of batch outputs into a dataframe
infer_df = convert_predictions_to_dataframe(predictions)
save_dataframe(infer_df, infer_file)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
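
`pl_trainer.predict` returns one output per batch (here, batches of two samples), and `convert_predictions_to_dataframe` turns that list into a table, presumably with one row per sample. The plain-Python sketch below illustrates only this reshaping step, under stated assumptions: batch outputs are modeled as dicts of plain lists (the real ones hold tensors), `flatten_batches` is a hypothetical helper, and this is not FuseMedML's implementation.

```python
from typing import Any, Dict, List


def flatten_batches(batches: List[Dict[str, List[Any]]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn per-batch outputs (key -> one value per sample)
    into one dict per sample, i.e. one future table row per sample."""
    rows: List[Dict[str, Any]] = []
    for batch in batches:
        if not batch:
            continue  # nothing to unpack
        batch_size = len(next(iter(batch.values())))
        # every key must hold exactly one value per sample of the batch
        if any(len(values) != batch_size for values in batch.values()):
            raise ValueError("all keys must have the same number of samples")
        for i in range(batch_size):
            rows.append({key: values[i] for key, values in batch.items()})
    return rows


# toy example: batch_size=2 as in the notebook, with a smaller last batch
batches = [
    {"model.output.classification": [[0.9, 0.1], [0.2, 0.8]], "data.label": [0, 1]},
    {"model.output.classification": [[0.6, 0.4]], "data.label": [0]},
]
for row in flatten_batches(batches):
    print(row)
```

The last batch may be smaller than `batch_size`, which is why the sketch reads the size from each batch instead of assuming it.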



## **Evaluation**
Evaluate the model by loading the previously saved predictions and comparing them with the ground-truth labels.


##### **Define EVAL Common Params**
Define the evaluation parameters, i.e. the name of the file that contains the predictions (the same file written during inference).


In [None]:
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS["infer_filename"] = INFER_COMMON_PARAMS["infer_filename"]

eval_common_params = EVAL_COMMON_PARAMS

##### **Define metrics**
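Define the metrics used to evaluate the model: accuracy, computed on the argmax predictions, and precision, sensitivity (recall) and F1-score, computed by `MetricConfusion`. The ROC curves (saved to `roc_curve.png`) and the area under the ROC curve (AUC) are computed as well.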
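
As a rough illustration of what these metrics measure, the sketch below computes them in plain Python for the two classes: the operation point picks the class with the highest score (argmax, as the comment in the metrics cell states), then accuracy, precision, sensitivity (recall) and F1 follow from the counts of true positives, false positives and false negatives. The helper names, the class indices (`GBM` = 0, `MET` = 1) and the choice of the positive class are assumptions for illustration; FuseMedML's own implementation may differ in details such as how per-class values are reported.

```python
from typing import Dict, List, Sequence


def argmax_predictions(probs: Sequence[Sequence[float]]) -> List[int]:
    """Operation point: predicted class = index of the highest score (first one on ties)."""
    return [max(range(len(p)), key=lambda c: p[c]) for p in probs]


def binary_metrics(pred: Sequence[int], target: Sequence[int], positive: int) -> Dict[str, float]:
    """Accuracy plus precision, sensitivity (recall) and F1 for one positive class."""
    pairs = list(zip(pred, target))
    tp = sum(1 for p, t in pairs if p == positive and t == positive)
    fp = sum(1 for p, t in pairs if p == positive and t != positive)
    fn = sum(1 for p, t in pairs if p != positive and t == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if precision + sensitivity else 0.0
    accuracy = sum(1 for p, t in pairs if p == t) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "sensitivity": sensitivity, "f1": f1}


# toy scores for [GBM, MET]; labels: 0 = GBM, 1 = MET (assumed mapping)
probs = [[0.9, 0.1], [0.3, 0.7], [0.4, 0.6], [0.8, 0.2]]
labels = [0, 1, 0, 1]
pred = argmax_predictions(probs)
print(pred)  # [0, 1, 1, 0]
print(binary_metrics(pred, labels, positive=1))
# {'accuracy': 0.5, 'precision': 0.5, 'sensitivity': 0.5, 'f1': 0.5}
```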

In [None]:
class_names = ['GBM', 'MET']

# metrics
metrics = OrderedDict(
    [
        ("operation_point", MetricApplyThresholds(pred="model.output.classification")),  # will apply argmax
        ("accuracy", MetricAccuracy(pred="results:metrics.operation_point.cls_pred", target="data.label")),
        (
            "roc",
            MetricROCCurve(
                pred="model.output.classification",
                target="data.label",
                class_names=class_names,
                output_filename=os.path.join(paths["inference_dir"], "roc_curve.png"),
            ),
        ),
        ("auc", MetricAUCROC(pred="model.output.classification", target="data.label", class_names=class_names)),
        ("confusion", MetricConfusion(pred="model.output.classification", target="data.label", metrics=("precision", "sensitivity", "f1"))),
    ]
)
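
`MetricROCCurve` and `MetricAUCROC` report one curve per class (`roc.GBM.*`, `roc.MET.*`), i.e. each class is evaluated one-vs-rest: the class score is the corresponding column of `model.output.classification`, and a sample counts as positive if its label is that class. The sketch below is a minimal plain-Python version of this idea, assuming the scores are per-class probabilities: lowering the decision threshold over the sorted scores yields the (FPR, TPR) points, and the trapezoidal rule gives the area under the curve. It is not the library's code, and `roc_points` and `trapezoid_auc` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def roc_points(scores: Sequence[float], labels: Sequence[int]) -> Tuple[List[float], List[float]]:
    """(FPR, TPR) points obtained by lowering the threshold over the distinct scores.

    labels: 1 if the sample belongs to the class under evaluation, 0 otherwise.
    """
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both positive and negative samples are required")
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    fpr, tpr = [0.0], [0.0]  # threshold above every score: nothing predicted positive
    tp = fp = 0
    i = 0
    while i < len(order):
        threshold = scores[order[i]]
        # samples with the same score cross the threshold together
        while i < len(order) and scores[order[i]] == threshold:
            if labels[order[i]] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid_auc(fpr: Sequence[float], tpr: Sequence[float]) -> float:
    """Area under the ROC curve with the trapezoidal rule."""
    return sum((fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2 for k in range(1, len(fpr)))


# one-vs-rest for GBM (assumed class 0): score = P(GBM), positive = label 0
probs = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]]
targets = [0, 1, 0, 1]
gbm_scores = [p[0] for p in probs]
gbm_labels = [1 if t == 0 else 0 for t in targets]
fpr, tpr = roc_points(gbm_scores, gbm_labels)
print(fpr, tpr)
print(trapezoid_auc(fpr, tpr))  # 0.75
```

On the toy data the AUC equals the fraction of (positive, negative) pairs that are ranked correctly (3 of 4). The arrays printed by the evaluator may contain fewer points than this sketch produces, since ROC implementations such as scikit-learn's `roc_curve` drop collinear intermediate points by default, which does not change the area.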

##### **Evaluate**
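Run the evaluation: the evaluator loads the predictions file, computes the metrics defined above and stores the results in the evaluation directory.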

In [None]:
# create evaluator
evaluator = EvaluatorDefault()

# run eval
results = evaluator.eval(
    ids=None,
    data=os.path.join(paths["inference_dir"], eval_common_params["infer_filename"]),
    metrics=metrics,
    output_dir=paths["eval_dir"],
    silent=False,
)

print("Done!")

Results:

Metric operation_point.cls_pred:
------------------------------------------------
<fuse.eval.metrics.utils.PerSampleData object at 0x7f2f68621840>

Metric accuracy:
------------------------------------------------
0.96

Metric roc.GBM.fpr:
------------------------------------------------
[0.         0.         0.         0.01923077 0.01923077 0.03846154
 0.03846154 0.05769231 0.05769231 0.07692308 0.07692308 0.13461538
 0.13461538 1.        ]

Metric roc.GBM.tpr:
------------------------------------------------
[0.         0.01020408 0.78571429 0.78571429 0.91836735 0.91836735
 0.92857143 0.92857143 0.93877551 0.93877551 0.98979592 0.98979592
 1.         1.        ]

Metric roc.GBM.auc:
------------------------------------------------
0.9911695447409734

Metric roc.MET.fpr:
------------------------------------------------
[0.         0.         0.         0.02040816 0.02040816 0.05102041
 0.05102041 0.06122449 0.06122449 0.07142857 0.07142857 0.08163265
 0.08163265 0.15306122