In [25]:
working_dir = "../."
dataset_h5_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/dataset.h5"
hospitaldict_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/hospitals-patients-dict.pkl"
libraries_dir = working_dir + "/libraries"


import warnings
import sys
import os
import glob
import pickle
import lightning as pl
from tabulate import tabulate
from torch.utils.data import Subset

from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, DeviceStatsMonitor, ModelCheckpoint
from lightning.pytorch.tuner import Tuner


sys.path.append(working_dir)
from data_setup import HDF5Dataset, FrameTargetDataset
from lightning_modules.ViTLightningModule import ViTLightningModule
from lightning_modules.ResNet18LightningModule import ResNet18LightningModule
from lightning_modules.BEiTLightningModule import BEiTLightningModule

os.chdir(working_dir)
os.getcwd()


'/Users/andry/Documents'

# Dataset

In [13]:
import pickle
from torch.utils.data import Subset

train_ratio = 0.2
rseed = 21

In [14]:
dataset = HDF5Dataset(dataset_h5_path)

train_indices_path = os.path.dirname(dataset_h5_path) + f"/train_indices_{train_ratio}.pkl"
test_indices_path = os.path.dirname(dataset_h5_path) + f"/test_indices_{train_ratio}.pkl"

Serialized frame index map FOUND.

Loaded serialized data.


277 videos (58924 frames) loaded.


In [15]:
if os.path.exists(train_indices_path) and os.path.exists(test_indices_path):
    print("Loading pickled indices")
    with open(train_indices_path, 'rb') as train_pickle_file:
        train_indices = pickle.load(train_pickle_file)
    with open(test_indices_path, 'rb') as test_pickle_file:
        test_indices = pickle.load(test_pickle_file)
    # Create training and test subsets
    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)  
else:
    train_subset, test_subset, split_info, train_indices, test_indices = dataset.split_dataset(hospitaldict_path, 
                                                              rseed, 
                                                              train_ratio)
    print("Pickling sets...")
    
    # Pickle the indices
    with open(train_indices_path, 'wb') as train_pickle_file:
        pickle.dump(train_indices, train_pickle_file)
    with open(test_indices_path, 'wb') as test_pickle_file:
        pickle.dump(test_indices, test_pickle_file)

Loading pickled indices


In [16]:
test_subset_size = train_ratio/2
test_subset = Subset(test_subset, range(int(test_subset_size * len(test_indices))))
test_subset

<torch.utils.data.dataset.Subset at 0x17f313610>

In [17]:
train_dataset = FrameTargetDataset(train_subset)
test_dataset = FrameTargetDataset(test_subset)

print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 11053
Test size: 4787


# Models

## ViT

In [33]:
# Model class ------------------------------------------------------------
from transformers import ViTForImageClassification
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.pytorch as pl
import torch
import torchvision
from transformers import ViTImageProcessor
from torchmetrics.classification import MulticlassF1Score
from kornia import tensor_to_image
import matplotlib.pyplot as plt
from data_setup import DataAugmentation
from transformers import ViTConfig


def collate_fn(examples):
    frames = torch.stack([example[0] for example in examples])  # Extract the preprocessed frames
    scores = torch.tensor([example[1] for example in examples])  # Extract the scores
    return (frames, scores)
  
id2label = {0: 'no', 1: 'yellow', 2: 'orange', 3: 'red'}
label2id = {"no": 0, "yellow": 1, "orange": 2, "red": 3}

class ViTLightningModule(pl.LightningModule):
    def __init__(self, 
                 train_dataset,
                 test_dataset,
                 batch_size,
                 num_workers,
                 optimizer,
                 num_classes=4,
                 lr=1e-3,
                 pretrained=True,
                 configuration=None):
        
        super(ViTLightningModule, self).__init__()
        
        if pretrained == True:
            self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                                  num_labels=4,
                                                                  id2label=id2label,
                                                                  label2id=label2id)
        else:
            self.config = ViTConfig(**configuration)
            self.vit = ViTForImageClassification(config=self.config)
        
        self.preprocess = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
        self.transform = DataAugmentation()
        
        self.train_dataset = train_dataset
        self.train_dataset.set_transform(self.preprocess)
        self.test_dataset = test_dataset
        self.test_dataset.set_transform(self.preprocess)
        
        
        self.num_classes = num_classes
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers\
        self.optimizer_name = str(optimizer).lower()
        self.optimizer = None
        self.f1_score_metric = MulticlassF1Score(num_classes=num_classes)
        
        
    # def on_after_batch_transfer(self, batch, dataloader_idx):
    #     pixel_values, labels = batch
    #     if self.trainer.training:
    #         x = self.transform(pixel_values)  # => we perform GPU/Batched data augmentation
    #     return x, labels
      
      
    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits
      
    
    def common_step(self, batch, batch_idx):
      
        pixel_values, labels = batch
        
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]
        #accuracy = torchmetrics.functional.accuracy(predictions, labels, task="multiclass", num_classes=4)
        f1 = self.f1_score_metric(logits, labels)

        return loss, accuracy, f1
      
    def training_step(self, batch, batch_idx):
        loss, accuracy, f1 = self.common_step(batch, batch_idx)
        self.log("training_loss", loss, on_epoch=True, prog_bar=True)
        self.log("training_accuracy", accuracy, on_epoch=True, prog_bar=True)
        self.log("training_f1", f1, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy, f1 = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_accuracy", accuracy, on_epoch=True, prog_bar=True)
        self.log("test_f1", f1, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        if self.optimizer_name == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=0.05)
        elif self.optimizer_name == "sgd":
            self.optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        else:
            raise ValueError("Invalid optimizer name. Please choose either 'adam' or 'sgd'.")

        return self.optimizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          pin_memory=True,
                          collate_fn=collate_fn, shuffle=False)

    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          pin_memory=True,
                          collate_fn=collate_fn)


    def show_batch(self, win_size=(10, 10)):
        def _to_vis(data):
            # Ensure that pixel values are in the valid range [0, 1]
            data = torch.clamp(data, 0, 1)
            return tensor_to_image(torchvision.utils.make_grid(data, nrow=8))

        # Get a batch from the training set
        imgs, labels = next(iter(self.train_dataloader()))

        # Apply data augmentation to the batch
        imgs_aug = self.transform(imgs)

        # Use matplotlib to visualize the original and augmented images
        plt.figure(figsize=win_size)
        plt.imshow(_to_vis(imgs))
        plt.title("Original Images")

        plt.figure(figsize=win_size)
        plt.imshow(_to_vis(imgs_aug))
        plt.title("Augmented Images")

# Experiment

## Model configuration

In [36]:
selected_model="google_vit"


configuration = {
    "num_labels": 4,
    "num_attention_heads": 4,
    "num_hidden_layers":4
}

hyperparameters = {
  "train_dataset": train_dataset,
  "test_dataset": test_dataset,
  "batch_size": 4,
  "lr": 0.0001,
  "optimizer": "sgd",
  "num_workers": 0,
  "pretrained": True,
  "configuration": configuration
}
# Instantiate lightning model
if selected_model == "google_vit":
  model = ViTLightningModule(**hyperparameters)
elif selected_model == "resnet18":
  model =  ResNet18LightningModule(**hyperparameters)
elif selected_model == "beit": 
  model =  BEiTLightningModule(**hyperparameters)
else:
  raise ValueError("Invalid model name. Please choose either 'google_vit' or 'resnet18'.")

table_data = []
table_data.append(["MODEL HYPERPARAMETERS"])
table_data.append(["model", selected_model])
for key, value in hyperparameters.items():
    if key not in ["train_dataset", "test_dataset"]:
      table_data.append([key, value])

table = tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")
print(table)

╒═══════════════╤═════════════════════════════════════════════════════════════════════╕
│               │ MODEL HYPERPARAMETERS                                               │
╞═══════════════╪═════════════════════════════════════════════════════════════════════╡
│ model         │ google_vit                                                          │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ batch_size    │ 4                                                                   │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ lr            │ 0.0001                                                              │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ optimizer     │ sgd                                                                 │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ num_workers   │ 0             

In [37]:
model.config

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 4,
  "num_channels": 3,
  "num_hidden_layers": 4,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.33.3"
}

## Trainer configuration

In [38]:
# Logger configuration
name_trained = "pretrained_" if hyperparameters["pretrained"]==True else ""
model_name = f"{name_trained}{selected_model}/{hyperparameters['optimizer']}/{hyperparameters['lr']}_{hyperparameters['batch_size']}"
logger = TensorBoardLogger("tb_logs", name=model_name)

callbacks = []

checkpoint_dir = f"{working_dir}/checkpoints/{model_name}"
checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir, 
                                      save_top_k=3,
                                      mode="min",
                                      monitor="training_loss",
                                      save_last=True,
                                      verbose=True)
callbacks.append(checkpoint_callback)


In [39]:

trainer_args = {
    "accelerator": "mps",
    "max_epochs": 5,
    "callbacks": callbacks,
    "precision": 16,
    "accumulate_grad_batches": 16,
    "logger": logger
}
table_data = []
table_data.append(["TRAINER ARGUMENTS"])
for key, value in trainer_args.items():
    if key not in ["callbacks", "logger"]:
        table_data.append([key, value])

table = tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")
print("\n\n" + table)
print(f"Model checkpoints directory is {checkpoint_dir}")
print("\n\n")
trainer = Trainer(**trainer_args,
                  default_root_dir = checkpoint_dir)

print("\n\n" + "-" * 20)
print("Trainer Callbacks:")
print("-" * 20 + "\n\n")
for callback in trainer.callbacks:
    print(f"- {type(callback).__name__}")

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs




╒═════════════════════════╤═════════════════════╕
│                         │ TRAINER ARGUMENTS   │
╞═════════════════════════╪═════════════════════╡
│ accelerator             │ mps                 │
├─────────────────────────┼─────────────────────┤
│ max_epochs              │ 5                   │
├─────────────────────────┼─────────────────────┤
│ precision               │ 16                  │
├─────────────────────────┼─────────────────────┤
│ accumulate_grad_batches │ 16                  │
╘═════════════════════════╧═════════════════════╛
Model checkpoints directory is .././checkpoints/pretrained_google_vit/sgd/0.0001_4





--------------------
Trainer Callbacks:
--------------------


- TQDMProgressBar
- ModelSummary
- ModelCheckpoint


In [40]:
# Checkpointing
# Checkpoints directory

checkpoint_path = ''

# Check if checkpoint path is provided
if checkpoint_path:
    print(f"Loading checkpoint from PATH: '{checkpoint_path}'...\n")
    trainer.fit(model, ckpt_path=checkpoint_path)
else:
    # Instantiate trainer without checkpoint
    print("Instantiating trainer without checkpoint...")
    trainer.fit(model)

Instantiating trainer without checkpoint...



  | Name            | Type                      | Params
--------------------------------------------------------------
0 | vit             | ViTForImageClassification | 29.1 M
1 | transform       | DataAugmentation          | 0     
2 | f1_score_metric | MulticlassF1Score         | 0     
--------------------------------------------------------------
29.1 M    Trainable params
0         Non-trainable params
29.1 M    Total params
116.395   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



                                                                           

  rank_zero_warn(


                                    [12:14<00:00,  3.76it/s, v_num=1, training_loss_step=0.0308, training_accuracy_step=1.000, training_f1_step=1.000]

RuntimeError: MPS backend out of memory (MPS allocated: 801.91 MB, other allocations: 8.28 GB, max allowed: 9.07 GB). Tried to allocate 9.23 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [41]:
%load_ext tensorboard
%tensorboard --logdir tb_logs/

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
TensorFlow installation not found - running with reduced feature set.
Address already in use
Port 6006 is in use by another program. Either identify and stop that program, or start the server with a different port.

: 