<a href="https://colab.research.google.com/github/PietroCaforio/research-biocv-proj/blob/dev/unimodal_ct_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train unimodal CT

In [1]:
!git clone https://github.com/PietroCaforio/research-biocv-proj
!cd research-biocv-proj && git switch dev

Cloning into 'research-biocv-proj'...
remote: Enumerating objects: 135, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 135 (delta 64), reused 94 (delta 27), pack-reused 0 (from 0)[K
Receiving objects: 100% (135/135), 3.34 MiB | 4.06 MiB/s, done.
Resolving deltas: 100% (64/64), done.
Branch 'dev' set up to track remote branch 'dev' from 'origin'.
Switched to a new branch 'dev'


In [2]:
!cd research-biocv-proj && git pull

Already up to date.


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!unzip -o /content/drive/MyDrive/processed57PatientsCPTAC-PDA.zip -d research-biocv-proj/data/

Archive:  /content/drive/MyDrive/processed57PatientsCPTAC-PDA.zip
   creating: research-biocv-proj/data/processed/CT/
   creating: research-biocv-proj/data/processed/CT/C3L-00189/
  inflating: research-biocv-proj/data/processed/CT/C3L-00189/0.npy  
  inflating: research-biocv-proj/data/processed/CT/C3L-00189/1.npy  
   creating: research-biocv-proj/data/processed/CT/C3L-00599/
  inflating: research-biocv-proj/data/processed/CT/C3L-00599/0.npy  
  inflating: research-biocv-proj/data/processed/CT/C3L-00599/1.npy  
  inflating: research-biocv-proj/data/processed/CT/C3L-00599/2.npy  
   creating: research-biocv-proj/data/processed/CT/C3L-00622/
  inflating: research-biocv-proj/data/processed/CT/C3L-00622/0.npy  
  inflating: research-biocv-proj/data/processed/CT/C3L-00622/1.npy  
  inflating: research-biocv-proj/data/processed/CT/C3L-00622/2.npy  
   creating: research-biocv-proj/data/processed/CT/C3L-00625/
  inflating: research-biocv-proj/data/processed/CT/C3L-00625/0.npy  
  inflating: 

In [5]:
import sys
from pathlib import Path

# Add the 'data' directory to sys.path
sys.path.append(str(Path('research-biocv-proj').resolve()))
from data.unimodal import *
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

In [6]:
# Instantiate the dataset
dataset = UnimodalCTDataset(split='train', dataset_path = "research-biocv-proj/data/processed/")

# Check the length of the dataset
print(f"Dataset length: {len(dataset)}")

# Check the first few items in the dataset
for i in range(3):
    item = dataset[i]
    print(f"Item {i}:")
    print(f"  Patient ID: {item['patient_id']}")
    print(f"  Frame shape: {item['frame'].shape}")
    print(f"  Label: {item['label']}")

# Check if DataLoader works with the dataset
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Get a batch of data
batch = next(iter(dataloader))

# Check the batch
print(f"Batch patient IDs: {batch['patient_id']}")
print(f"Batch frame shape: {batch['frame'].shape}")
print(f"Batch labels: {batch['label']}")

# Move batch to device (e.g., GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset.move_batch_to_device(batch, device)
print(f"Batch moved to device: {device}")

Dataset length: 2329
Item 0:
  Patient ID: C3L-03632
  Frame shape: (3, 224, 224)
  Label: 0
Item 1:
  Patient ID: C3L-03632
  Frame shape: (3, 224, 224)
  Label: 0
Item 2:
  Patient ID: C3L-03632
  Frame shape: (3, 224, 224)
  Label: 0
Batch patient IDs: ['C3L-02613', 'C3N-00511', 'C3N-03430', 'C3N-03430']
Batch frame shape: torch.Size([4, 3, 224, 224])
Batch labels: tensor([2, 1, 1, 1])
Batch moved to device: cuda


### Train ResNet model

In [7]:
import torch.nn as nn
import torch.optim as optim
from transformers import ResNetForImageClassification

In [9]:
train_dataset = UnimodalCTDataset(split='train',dataset_path = "research-biocv-proj/data/processed/" )
val_dataset = UnimodalCTDataset(split='val',dataset_path = "research-biocv-proj/data/processed/")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [10]:
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, UnimodalCTDataset.num_classes) #Adjusting the final layer to the unimodal number of classes

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNetForImageClassification(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64

In [12]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        frames = batch['frame'].float().to(device)
        labels = batch['label'].long().to(device)

        optimizer.zero_grad()
        outputs = model(frames)
        loss = criterion(outputs.logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            frames = batch['frame'].float().to(device)
            labels = batch['label'].long().to(device)

            outputs = model(frames)
            loss = criterion(outputs.logits, labels)

            val_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Validation Loss: {val_loss/len(val_loader)}, Accuracy: {100 * correct / total}")

Epoch 1, Loss: 0.7835296763949198
Validation Loss: 28.55037311911583, Accuracy: 55.70469798657718
Epoch 2, Loss: 0.35914525471321523
Validation Loss: 24.09751783311367, Accuracy: 45.63758389261745
Epoch 3, Loss: 0.1415451267249372
Validation Loss: 17.654741038754583, Accuracy: 53.691275167785236
Epoch 4, Loss: 0.07011035559316205
Validation Loss: 13.263962956517934, Accuracy: 52.348993288590606
Epoch 5, Loss: 0.035886044295070925
Validation Loss: 9.628280287235976, Accuracy: 53.02013422818792
Epoch 6, Loss: 0.01423217496541265
Validation Loss: 6.547758491057903, Accuracy: 55.033557046979865
Epoch 7, Loss: 0.006666411789236207
Validation Loss: 7.722487875446677, Accuracy: 53.02013422818792
Epoch 8, Loss: 0.0041951010855272005
Validation Loss: 5.034614668088034, Accuracy: 55.369127516778526
Epoch 9, Loss: 0.002656789166431823
Validation Loss: 4.78961096489802, Accuracy: 54.69798657718121
Epoch 10, Loss: 0.0021610848728427024
Validation Loss: 6.605505459709093, Accuracy: 52.34899328859060

KeyboardInterrupt: 