<a href="https://www.kaggle.com/code/pietrocaforio/unimodal-ct-training-kaggle?scriptVersionId=197281783" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Train 3D unimodal CT

In [1]:
!git clone https://github.com/PietroCaforio/research-biocv-proj
!cd research-biocv-proj && git switch dev

Cloning into 'research-biocv-proj'...
remote: Enumerating objects: 233, done.
remote: Counting objects: 100% (25/25), done.
remote: Compressing objects: 100% (21/21), done.
remote: Total 233 (delta 13), reused 9 (delta 4), pack-reused 208 (from 1)
Receiving objects: 100% (233/233), 3.62 MiB | 20.73 MiB/s, done.
Resolving deltas: 100% (135/135), done.
Branch 'dev' set up to track remote branch 'dev' from 'origin'.
Switched to a new branch 'dev'


In [2]:
!cd research-biocv-proj && git pull

Already up to date.


In [3]:
!pip install wandb



In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("3d_wandb_key")
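
The `kaggle_secrets` module is only importable inside a Kaggle kernel. For running the notebook elsewhere, a minimal sketch of a fallback is shown below. The helper name `get_wandb_key` is hypothetical, and it assumes the key has been exported in the `WANDB_API_KEY` environment variable.

```python
import os

def get_wandb_key(secret_name="3d_wandb_key", env_var="WANDB_API_KEY"):
    """Return the W&B API key from Kaggle secrets, or from an environment variable elsewhere."""
    try:
        # Only importable inside a Kaggle kernel
        from kaggle_secrets import UserSecretsClient
    except ImportError:
        # Fallback outside Kaggle (assumption: the key was exported as env_var)
        return os.environ.get(env_var)
    return UserSecretsClient().get_secret(secret_name)
```

The returned value can be passed to `wandb.login(key=...)` in the same way as `secret_value_0`.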

In [5]:
import wandb
wandb.login(key=secret_value_0)

wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
import sys
from pathlib import Path

# Add the cloned repository root to sys.path so that its `data` package is importable
sys.path.append(str(Path('research-biocv-proj').resolve()))
from data.unimodal3D import *

import numpy as np
import torch
from torch.utils.data import DataLoader

from sklearn.utils.class_weight import compute_class_weight

### Train ResNet model

In [7]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
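
To see how the patience counter behaves without running a training loop, the sketch below replays the same comparison (`score < best_score + delta`) on a hand-made list of losses. The function name `epochs_until_stop` and the loss values are hypothetical, and checkpoint saving is omitted.

```python
def epochs_until_stop(val_losses, patience=7, delta=0.0):
    """Return the 1-based epoch at which early stopping would trigger, or None."""
    best_score = None
    counter = 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        score = -val_loss
        if best_score is None:
            best_score = score
        elif score < best_score + delta:
            # No improvement: count this epoch against the patience budget
            counter += 1
            if counter >= patience:
                return epoch
        else:
            # Improvement: remember the new best and reset the counter
            best_score = score
            counter = 0
    return None

losses = [1.0, 0.9, 0.95, 0.92, 0.91, 0.85, 0.9]
print(epochs_until_stop(losses, patience=3))  # stops at epoch 5
print(epochs_until_stop(losses, patience=4))  # never stops: the improvement at epoch 6 resets the counter
```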

In [8]:
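def train(model, config, run_name=None):
    """Train `model` with the settings in `config` and log metrics to W&B.

    Relies on the notebook globals `train_loader`, `val_loader`, `device`, `nn` and `optim`,
    which are defined in later cells and must exist before this function is called.
    """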
    wandb.init(
        # set the wandb project where this run will be logged
        project="3Dunimodal_ct_training",
        name = run_name,
        # track hyperparameters and run metadata
        config=config
    )
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    criterion = nn.CrossEntropyLoss(weight = torch.tensor(config["class_weights"], dtype=torch.float).to(device) )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',factor = config["reduce_lr_factor"], patience = config["patience"])
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=config["early_stop_patience"], verbose=True)
    
    # Training loop
    num_epochs = config["epochs"]
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_correct = 0
        total = 0
        correct_per_class = [0, 0, 0]  # For G1, G2, G3
        total_per_class = [0, 0, 0]  # For G1, G2, G3
        for batch in train_loader:
            volumes = batch['volume'].float().to(device)
            labels = batch['label'].long().to(device)
            #print(f"volumes batch dimensions: {volumes.size()}")
            optimizer.zero_grad()
            outputs = model(volumes.unsqueeze(1))
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            loss.backward()
            optimizer.step()
            total += labels.size(0)
            running_loss += loss.item()
            
            # Calculate accuracy per class
            for i in range(3):  # We have 3 classes: G1 (0), G2 (1), G3 (2)
                correct_per_class[i] += ((predicted == i) & (labels == i)).sum().item()
                total_per_class[i] += (labels == i).sum().item()

        train_accuracy = 100 * train_correct / total
        class_accuracy = [(100 * correct_per_class[i] / total_per_class[i]) if total_per_class[i] > 0 else 0 for i in range(3)]
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
        wandb.log({"Train Accuracy": train_accuracy, "Train loss": running_loss/len(train_loader), "G1_TrainAcc":class_accuracy[0], "G2_TrainAcc":class_accuracy[1], "G3_TrainAcc":class_accuracy[2]})

        # Validation loop
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        # Initialize counters for each class (G1, G2, G3)
        correct_per_class = [0, 0, 0]  # For G1, G2, G3
        total_per_class = [0, 0, 0]  # For G1, G2, G3

        with torch.no_grad():
            for batch in val_loader:
                volumes = batch['volume'].float().to(device)
                labels = batch['label'].long().to(device)
                
                outputs = model(volumes.unsqueeze(1))
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Calculate accuracy per class
                for i in range(3):  # We have 3 classes: G1 (0), G2 (1), G3 (2)
                    correct_per_class[i] += ((predicted == i) & (labels == i)).sum().item()
                    total_per_class[i] += (labels == i).sum().item()
        scheduler.step(val_loss)
        # Compute total accuracy and per-class accuracy
        total_accuracy = 100 * correct / total
        class_accuracy = [(100 * correct_per_class[i] / total_per_class[i]) if total_per_class[i] > 0 else 0 for i in range(3)]
        print(f"Validation Loss: {val_loss/len(val_loader)}, Total Accuracy: {total_accuracy:.2f}%")
        print(f"Accuracy per class - G1: {class_accuracy[0]:.2f}%, G2: {class_accuracy[1]:.2f}%, G3: {class_accuracy[2]:.2f}%")
        # log metrics to wandb
        wandb.log({"Total Accuracy": total_accuracy, "Validation Loss": val_loss/len(val_loader), "G1_Acc":class_accuracy[0], "G2_Acc":class_accuracy[1], "G3_Acc":class_accuracy[2]})
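        # Note: val_loss is the sum over validation batches (the mean is what gets printed and logged above),
        # so the value reported in the checkpoint message is len(val_loader) times larger.
        early_stopping(val_loss, model)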
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
    wandb.finish()  
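
The learning-rate schedule implied by the scheduler arguments can be worked out by hand. The sketch below assumes the documented defaults of `ReduceLROnPlateau` (`min_lr=0`, `eps=1e-8`, where a reduction smaller than `eps` is ignored) and plugs in the values used later in this notebook (`learning_rate=9e-8`, `reduce_lr_factor=0.25`). The helper `lr_after_reductions` is hypothetical and not part of PyTorch.

```python
def lr_after_reductions(lr, factor, n_reductions, min_lr=0.0, eps=1e-8):
    """Model successive plateau-triggered reductions of a single learning rate."""
    history = [lr]
    for _ in range(n_reductions):
        new_lr = max(lr * factor, min_lr)
        # An update is applied only if it lowers the rate by more than eps
        if lr - new_lr > eps:
            lr = new_lr
        history.append(lr)
    return history

for step, lr in enumerate(lr_after_reductions(9e-8, 0.25, 3)):
    print(step, lr)
```

Under these assumptions the rate stops shrinking after two reductions (9e-8, then 2.25e-8, then 5.625e-9), since a third reduction would lower it by less than `eps`.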

In [9]:
import torch.nn as nn
import torch.optim as optim
from transformers import ResNetForImageClassification

In [10]:
train_dataset = UnimodalCTDataset3D(split='train',dataset_path = "/kaggle/input/preprocessed57patientscptacpda/processed/" )
val_dataset = UnimodalCTDataset3D(split='val',dataset_path = "/kaggle/input/preprocessed57patientscptacpda/processed/")

train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
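
Because `train` accumulates `val_loss` over batches and divides by `len(val_loader)` only when printing, the number of validation batches determines the gap between the printed loss and the value handed to the scheduler and to `EarlyStopping`. A quick check, assuming the 19 validation samples reported by `val_dataset.stats()` and the default `drop_last=False`:

```python
import math

n_val_samples = 19   # from val_dataset.stats()
batch_size = 2       # val_loader batch size
n_batches = math.ceil(n_val_samples / batch_size)  # len(val_loader) when drop_last=False

mean_val_loss = 1.1735534846782685  # printed validation loss for epoch 1
summed_val_loss = mean_val_loss * n_batches

print(n_batches)
print(f"{summed_val_loss:.6f}")
```

The result, `11.735535`, matches the value shown in the first checkpoint message of the training log.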

In [11]:
print(f"Training set stats:{train_dataset.stats()}")
print(f"Validation set stats:{val_dataset.stats()}")

Training set stats:{'length': 135, 'class_frequency': {'G1': 11, 'G2': 91, 'G3': 33}}
Validation set stats:{'length': 19, 'class_frequency': {'G1': 5, 'G2': 8, 'G3': 6}}
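
With so few validation samples, overall accuracy moves in coarse steps and is dominated by the class counts. As an illustration (the helper `overall_accuracy` is hypothetical), the sketch below combines per-class correct counts into the overall figure, using the validation class frequencies above.

```python
val_counts = {"G1": 5, "G2": 8, "G3": 6}  # from val_dataset.stats()

def overall_accuracy(correct_per_class, counts):
    """Overall accuracy in percent from per-class correct counts."""
    return 100 * sum(correct_per_class.values()) / sum(counts.values())

# A classifier that predicts G1 for every sample gets only the 5 G1 cases right
only_g1 = {"G1": 5, "G2": 0, "G3": 0}
print(f"{overall_accuracy(only_g1, val_counts):.2f}%")

# All G1 cases plus a single G3 case correct
g1_plus_one_g3 = {"G1": 5, "G2": 0, "G3": 1}
print(f"{overall_accuracy(g1_plus_one_g3, val_counts):.2f}%")
```

The two values, 26.32% and 31.58%, are the ones that appear in the training log further down, which is consistent with a model that predicts G1 almost exclusively.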


In [12]:
labels = []
for sample in train_dataset:
    labels.append(sample["label"])
labels = np.array(labels)
class_weights = compute_class_weight("balanced", classes=np.unique(labels), y=labels)

In [13]:
print(class_weights)

[4.09090909 0.49450549 1.36363636]
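
The `"balanced"` heuristic in scikit-learn computes each weight as `n_samples / (n_classes * class_count)`. A plain-Python check with the training class frequencies printed earlier reproduces the array above, assuming labels 0, 1, 2 correspond to G1, G2, G3 as the comments in `train` indicate:

```python
class_counts = {"G1": 11, "G2": 91, "G3": 33}  # from train_dataset.stats()
n_samples = sum(class_counts.values())
n_classes = len(class_counts)

# Rare classes get weights above 1, frequent classes below 1
balanced_weights = {
    name: n_samples / (n_classes * count) for name, count in class_counts.items()
}
for name, weight in balanced_weights.items():
    print(f"{name}: {weight:.8f}")
```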


### 3D ResNet-18

In [14]:
import torchvision.models as models

In [15]:
class ResNet3D(nn.Module):
    def __init__(self, num_classes):
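        super().__init__()
        # Kinetics-400 pretrained 3D ResNet-18; `weights=` replaces the deprecated `pretrained=True`
        self.model = models.video.r3d_18(weights=models.video.R3D_18_Weights.KINETICS400_V1)
        # Replace the 3-channel stem convolution with a single-channel one for CT volumes.
        # This new layer is randomly initialised, so the pretrained stem weights are not reused.
        self.model.stem[0] = nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)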
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.model(x)
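
The replaced stem expects a single-channel 5D input `(batch, 1, depth, height, width)`, which is why `train` calls `volumes.unsqueeze(1)`. Its effect on spatial size follows the usual convolution formula `floor((n + 2*padding - kernel) / stride) + 1`. The sketch below applies it per axis. The input size `(66, 224, 224)` is a made-up example, since the actual volume shape is not stated here.

```python
def conv_out_size(n, kernel, stride, padding):
    """Output length along one axis of a convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

# Hypothetical volume size (depth, height, width); not taken from the dataset
input_shape = (66, 224, 224)
stem_out = tuple(conv_out_size(n, kernel=7, stride=2, padding=3) for n in input_shape)
print(stem_out)
```

For this assumed input, the stem roughly halves every axis, giving `(33, 112, 112)`.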

In [16]:
model = ResNet3D(num_classes = 3)

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /root/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:01<00:00, 114MB/s]


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet3D(
  (model): VideoResNet(
    (stem): BasicStem(
      (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (conv2): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3)

In [18]:
config={
    "learning_rate": 9e-8,
    "architecture": "ResNet3D",
    "epochs": 400,
    "weight_decay": 1e-5,
    "reduce_lr_factor": 0.25,
    "patience": 20,
    "class_weights": class_weights,
    "early_stop_patience": 30 
    }
train(model, config, run_name=config["architecture"])

wandb: Currently logged in as: pietro-caforio (pietro-caforio-politecnico-di-milano). Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.1 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.7
wandb: Run data is saved locally in /kaggle/working/wandb/run-20240923_091615-5wbhh7ug
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ResNet3D
wandb: ⭐️ View project at https://wandb.ai/pietro-caforio-politecnico-di-milano/3Dunimodal_ct_training
wandb: 🚀 View run at https://wandb.ai/pietro-caforio-politecnico-di-milano/3Dunimodal_ct_training/runs/5wbhh7ug


Epoch 1, Loss: 1.2906238834063213
Validation Loss: 1.1735534846782685, Total Accuracy: 31.58%
Accuracy per class - G1: 100.00%, G2: 0.00%, G3: 16.67%
Validation loss decreased (inf --> 11.735535).  Saving model ...
Epoch 2, Loss: 1.2899802415459245
Validation Loss: 1.2708132445812226, Total Accuracy: 26.32%
Accuracy per class - G1: 100.00%, G2: 0.00%, G3: 0.00%
EarlyStopping counter: 1 out of 30
Epoch 3, Loss: 1.2236692861274436
Validation Loss: 1.2446797370910645, Total Accuracy: 26.32%
Accuracy per class - G1: 100.00%, G2: 0.00%, G3: 0.00%
EarlyStopping counter: 2 out of 30
Epoch 4, Loss: 1.2792436877886455
Validation Loss: 1.23627108335495, Total Accuracy: 26.32%
Accuracy per class - G1: 100.00%, G2: 0.00%, G3: 0.00%
EarlyStopping counter: 3 out of 30
Epoch 5, Loss: 1.2529872059822083
Validation Loss: 1.24156054854393, Total Accuracy: 26.32%
Accuracy per class - G1: 100.00%, G2: 0.00%, G3: 0.00%
EarlyStopping counter: 4 out of 30
Epoch 6, Loss: 1.2623171673880682
Validation Loss: 1.

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:          G1_Acc ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:     G1_TrainAcc ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:          G2_Acc ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:     G2_TrainAcc ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:          G3_Acc █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:     G3_TrainAcc ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  Total Accuracy █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  Train Accuracy ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:      Train loss ██▇▇█▅▇▅▄▄▄▃▆▄▃▅▃▅▄▃▃▃▅▃▅▃▂▂▃▄▃▁▂▂▂▃▂▄▄▃
[34m[1mwandb[0m: Validation Loss ▃█▆▅▅▅▄▄▅▄▄▄▃▃▄▃▃▃▂▃▃▃▂▃▁▂▂▃▂▂▂▂▂▃▂▁▂▂▂▂
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:   