In [1]:
!pip install -q segmentation_models_pytorch

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m153.6/154.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import glob
import numpy as np
import nibabel as nib
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import segmentation_models_pytorch as smp

In [3]:
os.environ['KAGGLE_USERNAME'] = "olenamikhailova"

os.environ['KAGGLE_KEY'] = "KGAT_e8b5a4ccb83e496885c5e98432eb4846"

print("Attempting to download BraTS 2020...")
!kaggle datasets download -d awsaf49/brats20-dataset-training-validation --force

print("Unzipping...")
!unzip -q brats20-dataset-training-validation.zip -d ./brats2020
print("Dataset is ready.")

Attempting to download BraTS 2020...
Dataset URL: https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation
License(s): CC0-1.0
Downloading brats20-dataset-training-validation.zip to /content
100% 4.16G/4.16G [01:12<00:00, 49.5MB/s]
100% 4.16G/4.16G [01:12<00:00, 62.0MB/s]
Unzipping...
Dataset is ready.


In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import glob
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np

root_search = glob.glob("./brats2020/**/MICCAI_BraTS2020_TrainingData", recursive=True)

if len(root_search) > 0:
    data_root = root_search[0]
    print(f"Dataset root found at: {data_root}")
else:
    data_root = "./brats2020"
    print(f"Assuming root at: {data_root}")

patient_dirs = sorted([d for d in glob.glob(os.path.join(data_root, "*")) if os.path.isdir(d)])
print(f"Total Patient Folders Found: {len(patient_dirs)}")

integrity_report = []

for p_dir in patient_dirs:
    pid = os.path.basename(p_dir)
    flair = glob.glob(os.path.join(p_dir, "*flair.nii"))
    seg = glob.glob(os.path.join(p_dir, "*seg.nii"))

    if not flair or not seg:
        integrity_report.append(pid)

if len(integrity_report) == 0:
    print("INTEGRITY CHECK PASSED: All patients have FLAIR and SEG files.")
else:
    print(f"WARNING: {len(integrity_report)} patients are missing files!")
    print(integrity_report[:5])

csv_files = glob.glob("./brats2020/**/*.csv", recursive=True)

if csv_files:
    print("\n METADATA FOUND ")
    meta_df = pd.read_csv(csv_files[0])
    print(f"Loaded metadata from: {os.path.basename(csv_files[0])}")
    print(meta_df.head())

    if 'Age' in meta_df.columns:
        plt.figure(figsize=(10, 4))
        plt.hist(meta_df['Age'], bins=20, color='teal', edgecolor='black')
        plt.title("Age Distribution in BraTS 2020")
        plt.xlabel("Age")
        plt.ylabel("Count")
        plt.show()
else:
    print("\nNo metadata CSV found (Demographics might be missing in this download).")

Dataset root found at: ./brats2020/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData
Total Patient Folders Found: 369
['BraTS20_Training_355']

 METADATA FOUND 
Loaded metadata from: name_mapping.csv
  Grade BraTS_2017_subject_ID BraTS_2018_subject_ID TCGA_TCIA_subject_ID  \
0   HGG   Brats17_CBICA_AAB_1   Brats18_CBICA_AAB_1                  NaN   
1   HGG   Brats17_CBICA_AAG_1   Brats18_CBICA_AAG_1                  NaN   
2   HGG   Brats17_CBICA_AAL_1   Brats18_CBICA_AAL_1                  NaN   
3   HGG   Brats17_CBICA_AAP_1   Brats18_CBICA_AAP_1                  NaN   
4   HGG   Brats17_CBICA_ABB_1   Brats18_CBICA_ABB_1                  NaN   

  BraTS_2019_subject_ID BraTS_2020_subject_ID  
0   BraTS19_CBICA_AAB_1  BraTS20_Training_001  
1   BraTS19_CBICA_AAG_1  BraTS20_Training_002  
2   BraTS19_CBICA_AAL_1  BraTS20_Training_003  
3   BraTS19_CBICA_AAP_1  BraTS20_Training_004  
4   BraTS19_CBICA_ABB_1  BraTS20_Training_005  


In [7]:
BATCH_SIZE = 16
EPOCHS = 20
PATIENCE = 3             # Stop if no improvement after 3 epochs
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
class EarlyStopping:
    def __init__(self, patience=3, delta=0, path='utility_model_unet.pth'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        torch.save(model.state_dict(), self.path)
        print(f'Validation loss decreased ({val_loss:.6f}). Saving model...')

In [9]:
class BraTSDataset(Dataset):
    def __init__(self, root_dir, image_size=256):
        self.root_dir = root_dir
        self.image_size = image_size
        self.patient_dirs = sorted(glob.glob(os.path.join(root_dir, "**", "BraTS20_Training_*"), recursive=True))
        self.patient_dirs = [d for d in self.patient_dirs if os.path.isdir(d)]

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        p_dir = self.patient_dirs[idx]
        p_id = os.path.basename(p_dir)

        flair_path = os.path.join(p_dir, f"{p_id}_flair.nii")
        seg_path = os.path.join(p_dir, f"{p_id}_seg.nii")

        try:
            flair_vol = nib.load(flair_path).get_fdata()
            seg_vol = nib.load(seg_path).get_fdata()

            # Smart Slicing: Pick slice with most tumor
            tumor_counts = np.sum(seg_vol > 0, axis=(0, 1))
            if np.max(tumor_counts) > 0:
                slice_idx = np.argmax(tumor_counts)
            else:
                slice_idx = flair_vol.shape[2] // 2

            img = flair_vol[:, :, slice_idx]
            mask = seg_vol[:, :, slice_idx]

            # Normalization (0-1)
            img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-6)

            # Resizing
            img = cv2.resize(img, (self.image_size, self.image_size))
            mask = np.where(mask > 0, 1.0, 0.0) # Binarize mask
            mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

            img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).repeat(3, 1, 1)
            mask_tensor = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

            return img_tensor, mask_tensor

        except Exception as e:
            return torch.zeros(3, 256, 256), torch.zeros(1, 256, 256)

In [10]:
print("Initializing Data and Model...")
dataset = BraTSDataset(root_dir="./brats2020")

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1
).to(DEVICE)

criterion = smp.losses.DiceLoss(mode='binary', from_logits=True)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

early_stopper = EarlyStopping(patience=PATIENCE, path='utility_model_unet.pth')

print(f"\nSTARTING TRAINING on {DEVICE}")
print(f"Max Epochs: {EPOCHS} | Patience: {PATIENCE}")

train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for imgs, masks in loop:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    model.eval()
    val_running_loss = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            preds = model(imgs)
            val_loss = criterion(preds, masks)
            val_running_loss += val_loss.item()

    avg_train = running_loss / len(train_loader)
    avg_val = val_running_loss / len(val_loader)

    train_losses.append(avg_train)
    val_losses.append(avg_val)

    print(f"\tTrain Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

    early_stopper(avg_val, model)

    if early_stopper.early_stop:
        print("\nEarly stopping triggered! Model has stopped improving.")
        break

print("\nTraining Pipeline Complete.")

Initializing Data and Model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]


STARTING TRAINING on cuda
Max Epochs: 20 | Patience: 3


Epoch 1/20: 100%|██████████| 21/21 [00:46<00:00,  2.23s/it, loss=0.857]


	Train Loss: 0.8450 | Val Loss: 0.8633
Validation loss decreased (0.863278). Saving model...


Epoch 2/20: 100%|██████████| 21/21 [00:45<00:00,  2.16s/it, loss=0.834]


	Train Loss: 0.7978 | Val Loss: 0.8112
Validation loss decreased (0.811180). Saving model...


Epoch 3/20: 100%|██████████| 21/21 [00:41<00:00,  2.00s/it, loss=0.738]


	Train Loss: 0.7623 | Val Loss: 0.7718
Validation loss decreased (0.771837). Saving model...


Epoch 4/20: 100%|██████████| 21/21 [00:39<00:00,  1.88s/it, loss=0.742]


	Train Loss: 0.7329 | Val Loss: 0.7415
Validation loss decreased (0.741473). Saving model...


Epoch 5/20: 100%|██████████| 21/21 [00:37<00:00,  1.80s/it, loss=0.75]


	Train Loss: 0.7091 | Val Loss: 0.7194
Validation loss decreased (0.719433). Saving model...


Epoch 6/20: 100%|██████████| 21/21 [00:38<00:00,  1.85s/it, loss=0.674]


	Train Loss: 0.6828 | Val Loss: 0.6939
Validation loss decreased (0.693899). Saving model...


Epoch 7/20: 100%|██████████| 21/21 [00:39<00:00,  1.90s/it, loss=0.644]


	Train Loss: 0.6526 | Val Loss: 0.6728
Validation loss decreased (0.672796). Saving model...


Epoch 8/20: 100%|██████████| 21/21 [00:41<00:00,  2.00s/it, loss=0.627]


	Train Loss: 0.6199 | Val Loss: 0.6412
Validation loss decreased (0.641202). Saving model...


Epoch 9/20: 100%|██████████| 21/21 [00:42<00:00,  2.01s/it, loss=0.629]


	Train Loss: 0.5824 | Val Loss: 0.6075
Validation loss decreased (0.607450). Saving model...


Epoch 10/20: 100%|██████████| 21/21 [00:44<00:00,  2.10s/it, loss=0.506]


	Train Loss: 0.5406 | Val Loss: 0.5574
Validation loss decreased (0.557423). Saving model...


Epoch 11/20: 100%|██████████| 21/21 [00:44<00:00,  2.13s/it, loss=0.504]


	Train Loss: 0.4989 | Val Loss: 0.5171
Validation loss decreased (0.517141). Saving model...


Epoch 12/20: 100%|██████████| 21/21 [00:44<00:00,  2.14s/it, loss=0.439]


	Train Loss: 0.4594 | Val Loss: 0.4872
Validation loss decreased (0.487170). Saving model...


Epoch 13/20: 100%|██████████| 21/21 [00:46<00:00,  2.21s/it, loss=0.385]


	Train Loss: 0.4228 | Val Loss: 0.4503
Validation loss decreased (0.450308). Saving model...


Epoch 14/20: 100%|██████████| 21/21 [00:47<00:00,  2.25s/it, loss=0.379]


	Train Loss: 0.3890 | Val Loss: 0.4152
Validation loss decreased (0.415201). Saving model...


Epoch 15/20: 100%|██████████| 21/21 [00:47<00:00,  2.26s/it, loss=0.324]


	Train Loss: 0.3579 | Val Loss: 0.3896
Validation loss decreased (0.389597). Saving model...


Epoch 16/20: 100%|██████████| 21/21 [00:46<00:00,  2.21s/it, loss=0.277]


	Train Loss: 0.3279 | Val Loss: 0.3794
Validation loss decreased (0.379385). Saving model...


Epoch 17/20: 100%|██████████| 21/21 [00:47<00:00,  2.26s/it, loss=0.293]


	Train Loss: 0.3005 | Val Loss: 0.3382
Validation loss decreased (0.338231). Saving model...


Epoch 18/20: 100%|██████████| 21/21 [00:48<00:00,  2.29s/it, loss=0.312]


	Train Loss: 0.2786 | Val Loss: 0.3343
Validation loss decreased (0.334299). Saving model...


Epoch 19/20: 100%|██████████| 21/21 [00:47<00:00,  2.25s/it, loss=0.256]


	Train Loss: 0.2564 | Val Loss: 0.2984
Validation loss decreased (0.298429). Saving model...


Epoch 20/20: 100%|██████████| 21/21 [00:48<00:00,  2.32s/it, loss=0.236]


	Train Loss: 0.2343 | Val Loss: 0.2750
Validation loss decreased (0.274972). Saving model...

Training Pipeline Complete.
