<a href="https://www.kaggle.com/code/pietrocaforio/unimodal-ct-training-kaggle?scriptVersionId=199214851" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Train unimodal CT

In [None]:
# Clone the project repository and switch to its `dev` branch; a later cell puts
# this checkout on sys.path and imports `data.unimodal` from it.
!git clone https://github.com/PietroCaforio/research-biocv-proj
!cd research-biocv-proj && git switch dev

In [None]:
# Fetch the latest commits for the currently checked-out branch of the clone.
!cd research-biocv-proj && git pull

In [None]:
# Weights & Biases client, used by train() for experiment logging.
!pip install wandb

In [None]:
from kaggle_secrets import UserSecretsClient
# Read the Weights & Biases API key from Kaggle's secret store (secret name
# "wandb_api_key") so the key itself is never written into the notebook.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api_key")

In [None]:
import wandb
# Authenticate the W&B client with the key fetched from Kaggle secrets.
wandb.login(key=secret_value_0)

In [None]:
import sys
from pathlib import Path

# Put the cloned repository's root on sys.path so its top-level `data` package
# (e.g. `data.unimodal`) can be imported. Guarded so that re-running this cell
# does not keep appending duplicate entries.
_repo_root = str(Path('research-biocv-proj').resolve())
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
from data.unimodal import *
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from sklearn.utils.class_weight import compute_class_weight

### Train ResNet model

In [None]:
# Focal loss implementation, later imported as `focal_loss.focal_loss.FocalLoss`.
# Source: https://github.com/mathiaszinnen/focal_loss_torch/tree/main
!pip install focal_loss_torch

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with that epoch's validation loss. When the
    loss improves (by more than ``delta``) the model's ``state_dict`` is saved to
    ``path`` and the counter resets; otherwise the counter grows, and once it
    reaches ``patience`` the ``early_stop`` flag is set to True.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Update the stopping state with this epoch's validation loss.

        Args:
            val_loss (float): Validation loss for the epoch (lower is better).
            model (torch.nn.Module): Model whose weights are checkpointed on improvement.
        """
        # Negate the loss so that a higher score means a better model.
        score = -val_loss

        if self.best_score is None:
            # First call: there is no baseline yet, so always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`: count towards patience.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and record ``val_loss`` as the new minimum."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
from focal_loss.focal_loss import FocalLoss

def _compute_loss(criterion, logits, labels, use_focal):
    """Return the batch loss: FocalLoss is fed softmax probabilities, cross-entropy raw logits."""
    if use_focal:
        return criterion(torch.softmax(logits, dim=-1), labels)
    return criterion(logits, labels)


def _run_epoch(model, loader, criterion, use_focal, device, optimizer=None, num_classes=3):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Batches are expected to be dicts with ``'frame'`` and ``'label'`` entries, and
    ``model(frames)`` must return an object exposing ``.logits``.

    Returns:
        tuple: (mean batch loss, overall accuracy in %, list of per-class accuracies in %).
        Classes with no samples get accuracy 0; an empty loader yields zeros.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    correct_per_class = [0] * num_classes
    total_per_class = [0] * num_classes

    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(training):
        for batch in loader:
            frames = batch['frame'].float().to(device)
            labels = batch['label'].long().to(device)

            if training:
                optimizer.zero_grad()
            logits = model(frames).logits
            loss = _compute_loss(criterion, logits, labels, use_focal)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            predicted = logits.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            for cls in range(num_classes):
                is_cls = labels == cls
                total_per_class[cls] += is_cls.sum().item()
                correct_per_class[cls] += (is_cls & (predicted == cls)).sum().item()

    mean_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    class_accuracy = [
        (100 * hits / count) if count > 0 else 0
        for hits, count in zip(correct_per_class, total_per_class)
    ]
    return mean_loss, accuracy, class_accuracy


def train(model,config, run_name=None):
    """Train ``model`` on ``train_loader``, validating on ``val_loader`` after every epoch.

    Per-epoch metrics are printed and logged to Weights & Biases. A
    ``ReduceLROnPlateau`` scheduler and :class:`EarlyStopping` (which also
    checkpoints the best weights) both monitor the mean validation loss.

    Relies on notebook globals defined in other cells: ``train_loader``,
    ``val_loader``, ``device``, ``nn``, ``optim``, ``wandb``, ``FocalLoss`` and
    ``EarlyStopping``.

    Args:
        model: Classifier whose forward pass returns an object with a ``.logits``
            tensor (as ``ResNetForImageClassification`` does).
        config (dict): Hyperparameters. Keys read: ``learning_rate``,
            ``weight_decay``, ``class_weights`` (None or per-class weights, used
            only by cross-entropy), ``focal_loss`` (FocalLoss gamma, or None for
            cross-entropy), ``reduce_lr_factor``, ``patience`` (LR scheduler),
            ``early_stop_patience`` and ``epochs``. The dict is not modified, so
            the same config can safely be reused across folds.
        run_name (str, optional): Name of the W&B run.
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project="unimodal_ct_training",
        name = run_name,
        # track hyperparameters and run metadata
        config=config
    )
    try:
        # Convert locally instead of writing the tensor back into `config`.
        class_weights = config["class_weights"]
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float, device=device)

        optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
        # One flag drives both the criterion choice and the softmax in the loss,
        # so e.g. focal_loss=0 (gamma 0) still receives probabilities.
        use_focal = config["focal_loss"] is not None
        if use_focal:
            criterion = FocalLoss(gamma = config["focal_loss"])
        else:
            criterion = nn.CrossEntropyLoss(weight = class_weights)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',factor = config["reduce_lr_factor"], patience = config["patience"])
        early_stopping = EarlyStopping(patience=config["early_stop_patience"], verbose=True)

        for epoch in range(config["epochs"]):
            train_loss, train_accuracy, train_class_acc = _run_epoch(
                model, train_loader, criterion, use_focal, device, optimizer=optimizer
            )
            print(f"Epoch {epoch+1}, Loss: {train_loss}")
            wandb.log({"Train Accuracy": train_accuracy, "Train loss": train_loss, "G1_TrainAcc":train_class_acc[0], "G2_TrainAcc":train_class_acc[1], "G3_TrainAcc":train_class_acc[2]})

            val_loss, total_accuracy, class_accuracy = _run_epoch(
                model, val_loader, criterion, use_focal, device
            )
            # The scheduler and early stopping see the same mean loss that is logged.
            scheduler.step(val_loss)
            print(f"Validation Loss: {val_loss}, Total Accuracy: {total_accuracy:.2f}%")
            print(f"Accuracy per class - G1: {class_accuracy[0]:.2f}%, G2: {class_accuracy[1]:.2f}%, G3: {class_accuracy[2]:.2f}%")
            wandb.log({"Total Accuracy": total_accuracy, "Validation Loss": val_loss, "G1_Acc":class_accuracy[0], "G2_Acc":class_accuracy[1], "G3_Acc":class_accuracy[2]})

            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
    finally:
        # Close the W&B run even if training raises, so the next fold starts cleanly.
        wandb.finish()

In [None]:
import torch.nn as nn
import torch.optim as optim
from transformers import ResNetForImageClassification
from sklearn.model_selection import StratifiedGroupKFold #For crossvalidation

In [None]:
# Load every preprocessed patient (split='all'); the cross-validation cell draws
# the per-fold train/validation subsets from this single dataset.
train_dataset = UnimodalCTDataset(split='all',dataset_path = "/kaggle/input/preprocessed57patientscptacpda/processed/" )

In [None]:
#print(f"Training set stats:{train_dataset.stats()}")
#print(f"Validation set stats:{val_dataset.stats()}")

In [None]:
"""
labels = []
for sample in train_dataset:
    labels.append(sample["label"])
labels = np.array(labels)
class_weights = compute_class_weight("balanced", classes=np.unique(labels), y=labels)
"""

In [None]:
#print(class_weights)

In [None]:
#train_dataset = UnimodalCTDataset(split='train',dataset_path = "/kaggle/input/oversampling57patientscptacpda/processed_oversampling/" )
#val_dataset = UnimodalCTDataset(split='val',dataset_path = "/kaggle/input/oversampling57patientscptacpda/processed_oversampling/")

#train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [None]:
#print(f"Training set stats:{train_dataset.stats()}")
#print(f"Validation set stats:{val_dataset.stats()}")

In [None]:
#total = train_dataset.stats()["length"]
#most_frequent = max(train_dataset.stats()["class_frequency"].values())
#freq_dict = train_dataset.stats()["class_frequency"]
#target_volume_depth= {}
#for index in freq_dict.keys():
#    target_volume_depth[index] = int((total/3 ) * most_frequent / freq_dict[index])
#print(target_volume_depth)

### Resnet-50

In [None]:
#model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')
#model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, UnimodalCTDataset.num_classes) #Adjusting the final layer to the unimodal number of classes

In [None]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

In [None]:
"""
config={
    "learning_rate": 1e-4,
    "architecture": "microsoft/resnet-50 new1",
    "epochs": 100,
    "weight_decay": 1e-4,
    "reduce_lr_factor": 0.2,
    "patience": 10,
    "class_weights": class_weights
    }
train(model, config, run_name = config["architecture"])
"""

### Resnet-18

In [None]:
"""
config={
    "learning_rate": 1e-7,
    "architecture": "microsoft/resnet-18",
    "run_name": "microsoft/resnet-18 NOOVERSAMPLING FOCALLOSS",
    "epochs": 800,
    "weight_decay": 1e-6,
    "reduce_lr_factor": 0.25,
    "patience": 20,
    "early_stop_patience": 40,
    "class_weights": None,
    "focal_loss": 2
    }

"""




### Resnet-34

In [None]:
#model = ResNetForImageClassification.from_pretrained('microsoft/resnet-34')
#model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, UnimodalCTDataset.num_classes) #Adjusting the final layer to the unimodal number of classes

In [None]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

In [None]:

# Hyperparameters for the ResNet-34 cross-validation runs; every key is read by train().
config = dict(
    learning_rate=5e-8,  # Adam learning rate
    architecture="microsoft/resnet-34",  # pretrained checkpoint loaded for each fold
    run_name="microsoft/resnet-34 NOOVERSAMPLING FOCALLOSS",  # base W&B run name
    epochs=800,  # maximum number of epochs (early stopping may end sooner)
    weight_decay=1e-5,  # Adam weight decay
    reduce_lr_factor=0.25,  # ReduceLROnPlateau multiplicative factor
    patience=20,  # ReduceLROnPlateau patience, in epochs
    early_stop_patience=40,  # EarlyStopping patience, in epochs
    class_weights=None,  # no per-class weighting
    focal_loss=5,  # FocalLoss gamma; None would select cross-entropy
)




In [None]:
from torch.utils.data import DataLoader, Subset
from collections import Counter 

k_folds = 3
batch_size = 32
# Group-aware stratified split: frames are grouped by patient_id so that frames of
# the same patient never appear in both the training and validation subsets of a
# fold, while stratifying on the patient label.
gkf = StratifiedGroupKFold(n_splits=k_folds)
# Patient id = item path up to the first "/", then up to the first "_".
patient_ids = [info.split("/")[0].split("_")[0] for info in train_dataset.items]
labels = [train_dataset.labels[patient_id] for patient_id in patient_ids]
indices = list(range(len(patient_ids)))
# Keep the configured run name intact; each fold's W&B run gets its own suffix.
base_run_name = config["run_name"]
# train() reads this module-level `device`; it is the same for every fold.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for fold, (train_idx, test_idx) in enumerate(gkf.split(indices, labels, groups=patient_ids)):
    print(f"Fold {fold + 1}")
    print("-------")
    # Define the data loaders for the current fold
    train_subset = Subset(train_dataset, train_idx)
    val_subset = Subset(train_dataset, test_idx)
    train_loader = DataLoader(
        dataset=train_subset,
        batch_size=batch_size,
        shuffle = True
    )
    val_loader = DataLoader(
        dataset=val_subset,
        batch_size=batch_size,
        shuffle = False,
    )

    # Sanity-check the split through the loaders themselves (this reads every
    # sample once per fold): report label counts and verify that no patient
    # appears in both subsets.
    train_ids = []
    train_labels = []
    for batch in train_loader:
        train_ids.extend(batch["patient_id"])
        train_labels.extend(batch["label"].tolist())
    print(f"TRAIN LABELS IN FOLD {fold + 1 }: {Counter(train_labels)}")
    val_ids = []
    val_labels = []
    for batch in val_loader:
        val_ids.extend(batch["patient_id"])
        val_labels.extend(batch["label"].tolist())
    print(f"VAL LABELS IN FOLD {fold + 1}: {Counter(val_labels)}")
    unique_train_ids = set(train_ids)
    leaked_ids = unique_train_ids & set(val_ids)
    print(f"Train patients: {len(unique_train_ids)}, also in validation: {len(leaked_ids)}")
    if leaked_ids:
        raise RuntimeError(
            f"Fold {fold + 1}: patients present in both train and validation: {leaked_ids}"
        )

    # Per-fold copy: the shared `config` (and its base run name) stays untouched.
    fold_config = {**config, "run_name": f"{base_run_name} fold {fold + 1}"}

    # Fresh pretrained model for every fold, with the final layer resized to the
    # unimodal number of classes.
    model = ResNetForImageClassification.from_pretrained(fold_config["architecture"])
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, UnimodalCTDataset.num_classes)
    model.to(device)
    #Train model
    train(model, fold_config, run_name = fold_config["run_name"])
