In [2]:
# Reload edited local modules (e.g. dataset_pcam, pcam.backbone) automatically,
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# The Metastatic Tissue Classification problem -- supervised learning

---

## 1. Introduction

This notebook trains and evaluates supervised models to solve the "Metastatic Tissue Classification" task on the PCam (PatchCamelyon) dataset.


The following code block contains the main parameters for this notebook.

In [None]:
# --- Data --------------------------------------------------------------------
data_dir = "./data"  # where PCamDataModule stores / looks for the dataset

# --- DataLoader settings -----------------------------------------------------
DL_VAL_SAMPLES_PER_CLASS = None  # None -> use every validation sample
DL_NUM_WORKERS = 4

# --- Training settings -------------------------------------------------------
BATCH_SIZE = 512  # == 2 ** 9
# Training-set sizes (samples per class) to sweep; each size gets its own models.
SAMPLES_PER_CLASS_LIST = [12_800, 25_600, 51_200, 128_000]

n_versions = 3     # repetitions of every configuration (tagged v_0, v_1, ...)
max_steps = 2_000  # == 100 * 20; Trainer max_steps used for every model
# NOTE(review): the saved outputs further down report `max_steps=15000` and fewer
# models, so they were produced with different settings than the ones above.

# --- Which backbone initializations to include -------------------------------
add_from_scratch_models = True  # backbone built by generate_backbone() with no weights argument
add_pretrained_models = True    # backbone built with ImageNet (DenseNet201) weights

---
## 2. Setting up the dataset

We will use the PCam data module to automatically download and handle the dataset.

In [None]:
from dataset_pcam import PCamDataModule

# One data module handles downloading PCam and building the train/val/test loaders.
datamodule = PCamDataModule(
    data_dir=data_dir,
    batch_size=BATCH_SIZE,
    num_workers=DL_NUM_WORKERS,
    val_samples_per_class=DL_VAL_SAMPLES_PER_CLASS,
)

# Label names from the underlying dataset; their count sets the classifier's output size.
class_names = datamodule.full_dataset.classes

print(datamodule)

  from .autonotebook import tqdm as notebook_tqdm


{Train dataloader: size=262144}
{Validation dataloader: size=32768}
{Test dataloader: size=32768}
{Predict dataloader: None}


---
## 3. Setting up the models



Let's start by writing helpers that build the prediction head, the learning-rate scheduler, and the supervised model that wraps a given backbone (the backbone itself is created later with `generate_backbone`).

In [None]:
import torch
from minerva.models.nets.base import SimpleSupervisedModel
from torchmetrics import Accuracy

def generate_pred_head(backbone_out_dim=1920, hidden_dim=512, num_classes=None):
    """Build the MLP classification head placed on top of the backbone features.

    Args:
        backbone_out_dim: Size of the feature vector produced by the backbone.
            The default 1920 presumably matches DenseNet-201's feature width;
            confirm it when using another backbone.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits. When None (the default), it falls
            back to ``len(class_names)``, the notebook-level list of dataset classes.

    Returns:
        torch.nn.Sequential: Linear -> ReLU -> Linear, mapping
        ``backbone_out_dim`` features to ``num_classes`` logits.
    """
    if num_classes is None:
        # Resolved lazily, so the head can also be built without the global class list.
        num_classes = len(class_names)
    return torch.nn.Sequential(
        torch.nn.Linear(backbone_out_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, num_classes),
    )

def build_scheduler(optimizer):
    """Return a Lightning lr-scheduler config wrapping ReduceLROnPlateau.

    The wrapped scheduler multiplies the learning rate by 0.1 after 3 checks
    without improvement of the monitored ``val_loss``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.

    Returns:
        dict: Lightning scheduler config (scheduler, interval, frequency, monitor, strict).
    """
    plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.1)
    return dict(
        scheduler=plateau,
        interval="epoch",    # step the scheduler once per epoch...
        frequency=1,         # ...every single epoch
        monitor="val_loss",  # metric fed to ReduceLROnPlateau.step()
        strict=True,         # require the monitored metric to be available
    )

def build_SimpleSupervisedModel(backbone, learning_rate=1e-3, hidden_dim=1024):
    """Wrap ``backbone`` in a SimpleSupervisedModel with an MLP head for classification.

    Args:
        backbone: Feature extractor whose output feeds the prediction head.
        learning_rate: Initial learning rate passed to SimpleSupervisedModel
            (default 1e-3, the value previously hard-coded).
        hidden_dim: Hidden width of the prediction head (default 1024, the value
            previously hard-coded).

    Returns:
        SimpleSupervisedModel trained with cross-entropy, the plateau scheduler from
        ``build_scheduler``, and a multiclass accuracy metric for each split.
    """
    num_classes = len(class_names)

    def accuracy_metrics():
        # A fresh metric object per split; torchmetrics metrics accumulate internal state.
        return {"accuracy": Accuracy("multiclass", num_classes=num_classes)}

    return SimpleSupervisedModel(
        backbone=backbone,
        fc=generate_pred_head(hidden_dim=hidden_dim),
        loss_fn=torch.nn.CrossEntropyLoss(),
        learning_rate=learning_rate,
        lr_scheduler=build_scheduler,
        train_metrics=accuracy_metrics(),
        val_metrics=accuracy_metrics(),
        test_metrics=accuracy_metrics(),
    )

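Let's also create a transform pipeline that augments the training samples (random flips, random grayscale conversion and color jitter) and normalizes them with precomputed dataset statistics.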

In [7]:
import torch
from torchvision.transforms.v2 import (
    CenterCrop,
    ColorJitter,
    Compose,
    Normalize,
    RandomGrayscale,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    ToDtype,
    ToImage,
)

# Per-channel mean/std, presumably computed on the PCam training images (3 channels,
# assumed RGB order — TODO confirm); used to standardize every input.
precomputed_dataset_stats = {'mean': torch.tensor([0.6982, 0.5344, 0.6907]), 'std': torch.tensor([0.2343, 0.2761, 0.2113])}

# Training-set image transformation pipeline.
# NOTE(review): the val/test loaders in the training cell are created without this
# pipeline; confirm they apply the same CenterCrop(42) + Normalize, otherwise the
# evaluation inputs differ from what the model saw during training.
train_transform_pipeline = Compose([
    # Convert to a tensor image; scale=True rescales integer pixels to floats in [0, 1].
    ToImage(),
    ToDtype(torch.float32, scale=True),
    # Keep only the central 42x42 patch.
    CenterCrop(42),
    # Random augmentations.
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomGrayscale(),
    ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    # Standardize with the dataset statistics above.
    Normalize(precomputed_dataset_stats["mean"], precomputed_dataset_stats["std"]),
])

Now, let's create models with different configurations (e.g., random initialization vs. ImageNet-pretrained backbone weights), and models to be trained with different numbers of samples per class.

In [None]:
models = {}

from torchvision.models import DenseNet201_Weights
import lightning
from pcam.backbone import generate_backbone

# Let's set the seeds for reproducibility. Models are built below in a fixed order
# after seeding, so their random initial weights are reproducible across runs.
lightning.seed_everything(1969)

# (id, pipeline) pairs for the training transforms. Append ("notr", None) to also
# train without augmentation.
# NOTE: the loop variable is `train_transform`, deliberately NOT the global name
# `train_transform_pipeline`. Reusing that name would rebind the global (e.g. to None)
# and silently change the "aug" pipeline for every later version.
train_transform_options = [("aug", train_transform_pipeline)]

# (model-name prefix, generate_backbone kwargs) for every enabled initialization.
backbone_options = []
if add_from_scratch_models:
    backbone_options.append(("From_Scratch", {}))
if add_pretrained_models:
    backbone_options.append(("Pretrained_ImageNet", {"weights": DenseNet201_Weights.IMAGENET1K_V1}))

for version in range(n_versions):
    for train_transform_id, train_transform in train_transform_options:
        for samples_per_class in SAMPLES_PER_CLASS_LIST:
            # Same creation order as before (scratch, then pretrained) to keep RNG usage identical.
            for model_prefix, backbone_kwargs in backbone_options:
                backbone = generate_backbone(**backbone_kwargs)
                model_name = f"{model_prefix}-{train_transform_id}/{samples_per_class}_spc/{max_steps}_steps/v_{version}"
                models[model_name] = {
                    "backbone": backbone,
                    "model": build_SimpleSupervisedModel(backbone),
                    "max_steps": max_steps,
                    "samples per class": samples_per_class,
                    "train_transform": train_transform,
                    "version": version,
                }

print("== The following models were included ==")
for i, k in enumerate(models.keys()):
    print(f"{i:3d} {k}")

Seed set to 1969


== The following models were included ==
  0 From_Scratch-aug/12800_spc/15000_steps/v_0
  1 Pretrained_ImageNet-aug/12800_spc/15000_steps/v_0
  2 From_Scratch-aug/25600_spc/15000_steps/v_0
  3 Pretrained_ImageNet-aug/25600_spc/15000_steps/v_0
  4 From_Scratch-aug/51200_spc/15000_steps/v_0
  5 Pretrained_ImageNet-aug/51200_spc/15000_steps/v_0
  6 From_Scratch-aug/128000_spc/15000_steps/v_0
  7 Pretrained_ImageNet-aug/128000_spc/15000_steps/v_0


--- 

## 4. Training the models

In [None]:
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from timeit import default_timer as timer

# Train every configuration in `models`, test its lowest-val_loss checkpoint on the
# test set, and print a running time estimate after each model.

LOG_DIR = "logs/PCam/Downstream/"


def _accuracy_metrics():
    """Return a fresh {"accuracy": Accuracy} dict sized to the dataset's classes."""
    return {"accuracy": Accuracy("multiclass", num_classes=len(class_names))}


def _print_training_stats(elapsed, n_trained, n_total):
    """Print the average time per model and a linear estimate of the remaining time."""
    avg = elapsed / n_trained
    est_total = avg * n_total
    est_remaining = est_total - elapsed
    print("-----------------------------------")
    print("Training stats")
    print(f"  - Avg time to train models: {avg:.2f} seconds ")
    print(f"  - Total # models  : {n_total} model(s)")
    print(f"  - Models trained  : {n_trained} model(s) in {elapsed:.2f} seconds")
    print(f"  - Remaining models: {n_total - n_trained} model(s). {est_remaining:.2f} s remaining (Estimative)")
    print(f"  - Total time      : {est_total:.2f} seconds (estimate: avg * # models)")


# Register stats
n_configs = len(models)
n_configs_trained = 0
start_time = timer()

for model_name, model_info in models.items():
    torch.cuda.empty_cache()
    print("***********************************")
    print(f" Training model {model_name}")
    print("***********************************")
    loggers = [TensorBoardLogger(save_dir=LOG_DIR, name=model_name),
               CSVLogger(save_dir=LOG_DIR, name=model_name)]
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    trainer = Trainer(max_steps=model_info["max_steps"], benchmark=True,
                      log_every_n_steps=8, logger=loggers,
                      callbacks=[checkpoint_callback])

    # Train the model
    trainer.fit(model_info["model"],
                train_dataloaders=datamodule.train_dataloader(samples_per_class=model_info["samples per class"],
                                                              transform=model_info["train_transform"]),
                val_dataloaders=datamodule.val_dataloader())

    # Load parameters from the checkpoint with the lowest val_loss. If no checkpoint was
    # written (best_model_path == "", e.g. training stopped before the first validation
    # run), test the final in-memory weights instead of crashing after a long training run.
    if checkpoint_callback.best_model_path:
        print(f"Loading weights from {checkpoint_callback.best_model_path}")
        best_model = SimpleSupervisedModel.load_from_checkpoint(checkpoint_callback.best_model_path,
                                                                backbone=model_info["model"].backbone,
                                                                fc=model_info["model"].fc,
                                                                loss_fn=torch.nn.CrossEntropyLoss(),
                                                                train_metrics=_accuracy_metrics(),
                                                                val_metrics=_accuracy_metrics(),
                                                                test_metrics=_accuracy_metrics())
    else:
        print("No checkpoint was saved; testing the weights from the last training step")
        best_model = model_info["model"]

    # Test the model
    trainer.test(best_model, dataloaders=datamodule.test_dataloader())

    # Compute and display training statistics
    n_configs_trained += 1
    _print_training_stats(timer() - start_time, n_configs_trained, n_configs)

***********************************
 Training model From_Scratch-aug/128000_spc/15000_steps/v_0
***********************************


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params | Mode 
------------------------------------------------------
0 | backbone | DenseNet         | 18.1 M | train
1 | fc       | Sequential       | 2.0 M  | train
2 | loss_fn  | CrossEntropyLoss | 0      | train
------------------------------------------------------
20.1 M    Trainable params
0         Non-trainable params
20.1 M    Total params
80.248    Total estimated model params size (MB)
718       Modules in train mode
0         Modules in eval mode


Epoch 29: 100%|██████████| 500/500 [05:46<00:00,  1.44it/s, v_num=0_1, val_loss=0.421, val_accuracy=0.820, train_loss=0.249, train_accuracy=0.897]

`Trainer.fit` stopped: `max_steps=15000` reached.


Epoch 29: 100%|██████████| 500/500 [05:46<00:00,  1.44it/s, v_num=0_1, val_loss=0.421, val_accuracy=0.820, train_loss=0.249, train_accuracy=0.897]
Loading weights from logs/PCam/Downstream/From_Scratch-aug/128000_spc/15000_steps/v_0/version_0/checkpoints/epoch=6-step=3500.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 64/64 [00:23<00:00,  2.78it/s]


-----------------------------------
Training stats
  - Avg time to train models: 10508.49 seconds 
  - Total # models  : 2 model(s)
  - Models trained  : 1 model(s) in 10508.49 seconds
  - Remaining models: 1 model(s). 10508.486473959056 s remaining (Estimative)
  - Total time      : 21016.97294791811 seconds (estimate: avg * # models)
***********************************
 Training model Pretrained_ImageNet-aug/128000_spc/15000_steps/v_0
***********************************


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params | Mode 
------------------------------------------------------
0 | backbone | DenseNet         | 18.1 M | train
1 | fc       | Sequential       | 2.0 M  | train
2 | loss_fn  | CrossEntropyLoss | 0      | train
------------------------------------------------------
20.1 M    Trainable params
0         Non-trainable params
20.1 M    Total params
80.248    Total estimated model params size (MB)
718       Modules in train mode
0         Modules in eval mode


Epoch 29: 100%|██████████| 500/500 [05:46<00:00,  1.44it/s, v_num=0_1, val_loss=0.438, val_accuracy=0.821, train_loss=0.191, train_accuracy=0.925]

`Trainer.fit` stopped: `max_steps=15000` reached.


Epoch 29: 100%|██████████| 500/500 [05:46<00:00,  1.44it/s, v_num=0_1, val_loss=0.438, val_accuracy=0.821, train_loss=0.191, train_accuracy=0.925]
Loading weights from logs/PCam/Downstream/Pretrained_ImageNet-aug/128000_spc/15000_steps/v_0/version_0/checkpoints/epoch=2-step=1500.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 64/64 [00:23<00:00,  2.78it/s]


-----------------------------------
Training stats
  - Avg time to train models: 10473.67 seconds 
  - Total # models  : 2 model(s)
  - Models trained  : 2 model(s) in 20947.34 seconds
  - Remaining models: 0 model(s). 0.0 s remaining (Estimative)
  - Total time      : 20947.343069495982 seconds (estimate: avg * # models)


: 