In [1]:
import sys
sys.path.append('..')

import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from src.dataset import BrainMRIDataset

# Set the data path (a relative path, so it is resolved against the kernel's working directory)
DATA_PATH = '../data/mri-segmentation/kaggle_3m/'

In [2]:
# Get all patient IDs (the name of the directory that contains each image).
# The IDs are sorted because iteration order of a set of strings changes
# between interpreter sessions (string hash randomisation); without a stable
# input order the seeded splits below would differ after every kernel restart.
dataset = BrainMRIDataset(data_path=DATA_PATH)
patient_ids = sorted({os.path.basename(os.path.dirname(path)) for path in dataset.image_paths})

# Split at patient level: hold out 15% of the patients for testing ...
train_patients, test_patients = train_test_split(patient_ids, test_size=0.15, random_state=42)
# ... then 17.6% of the remaining 85% (~15% of all patients) for validation
train_patients, val_patients = train_test_split(train_patients, test_size=0.176, random_state=42)

# Create filtered datasets (one per patient subset)
train_dataset = BrainMRIDataset(DATA_PATH, patient_list=train_patients)
val_dataset = BrainMRIDataset(DATA_PATH, patient_list=val_patients)
test_dataset = BrainMRIDataset(DATA_PATH, patient_list=test_patients)

print(f"Train dataset: {len(train_dataset)} slices")
print(f"Val dataset: {len(val_dataset)} slices")
print(f"Test dataset: {len(test_dataset)} slices")

Found 3929 image-mask pairs
Found 2720 image-mask pairs
Found 475 image-mask pairs
Found 734 image-mask pairs
Train dataset: 2720 slices
Val dataset: 475 slices
Test dataset: 734 slices


In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Channel statistics shared by both pipelines (the values commonly used with
# ImageNet-pretrained encoders)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random flips / 90-degree rotations / affine jitter,
# followed by normalisation and conversion to a tensor
train_steps = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Affine(scale=(0.9, 1.1), translate_percent=(-0.1, 0.1), rotate=(-15, 15), p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
]
train_transform = A.Compose(train_steps)

# Validation/Test pipeline: no random augmentation, only normalisation and
# tensor conversion
eval_steps = [
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
]
val_transform = A.Compose(eval_steps)

print("Transforms defined!")

Transforms defined!


In [4]:
# Rebuild the three splits, this time with a transform attached to each.
# Validation and test both use the augmentation-free pipeline.
train_dataset = BrainMRIDataset(
    DATA_PATH,
    patient_list=train_patients,
    transform=train_transform,
)
val_dataset = BrainMRIDataset(
    DATA_PATH,
    patient_list=val_patients,
    transform=val_transform,
)
test_dataset = BrainMRIDataset(
    DATA_PATH,
    patient_list=test_patients,
    transform=val_transform,
)

print("Datasets created with transforms!")
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

Found 2720 image-mask pairs
Found 475 image-mask pairs
Found 734 image-mask pairs
Datasets created with transforms!
Train: 2720 | Val: 475 | Test: 734


In [5]:
import torch

# Pull a single (image, mask) pair to sanity-check what the dataset yields
first_sample = train_dataset[0]
image, mask = first_sample

# Type / shape / dtype of the transformed sample
print(f"Image type: {type(image)}")
print(f"Image shape: {image.shape}")  # expected: (3, 256, 256), channels first
print(f"Mask shape: {mask.shape}")
print(f"Image dtype: {image.dtype}")

# After normalisation the values are no longer confined to [0, 255]
print(f"Image range: [{image.min():.2f}, {image.max():.2f}]")

Image type: <class 'torch.Tensor'>
Image shape: torch.Size([3, 256, 256])
Mask shape: torch.Size([256, 256])
Image dtype: torch.float32
Image range: [-2.12, 1.53]


In [6]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

# Settings common to all three loaders; num_workers=0 loads in the main process
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)

# Only the training loader shuffles; validation and test keep a fixed order
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

Train batches: 170
Val batches: 30
Test batches: 46


In [7]:
# Fetch one batch from the training loader and inspect its tensors
first_batch = next(iter(train_loader))
images, masks = first_batch

print(f"Batch images shape: {images.shape}")  # expected: (16, 3, 256, 256)
print(f"Batch masks shape: {masks.shape}")    # expected: (16, 256, 256)
print(f"Images dtype: {images.dtype}")
print(f"Masks dtype: {masks.dtype}")

Batch images shape: torch.Size([16, 3, 256, 256])
Batch masks shape: torch.Size([16, 256, 256])
Images dtype: torch.float32
Masks dtype: torch.float32


In [8]:
import segmentation_models_pytorch as smp

# U-Net configuration, kept in one place so it is easy to tweak
unet_config = {
    "encoder_name": "resnet50",      # ResNet50 backbone as the encoder
    "encoder_weights": "imagenet",   # start from ImageNet-pretrained weights
    "in_channels": 3,                # 3-channel input images
    "classes": 1,                    # single output channel: tumor vs background
}

model = smp.Unet(**unet_config)

In [9]:
from torchinfo import summary
# Summarise the network for one 3-channel 256x256 input (batch size 1)
summary_input_shape = (1, 3, 256, 256)
summary(model, input_size=summary_input_shape)

Layer (type:depth-idx)                        Output Shape              Param #
Unet                                          [1, 1, 256, 256]          --
├─ResNetEncoder: 1-1                          [1, 3, 256, 256]          --
│    └─Conv2d: 2-1                            [1, 64, 128, 128]         9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 128, 128]         128
│    └─ReLU: 2-3                              [1, 64, 128, 128]         --
│    └─MaxPool2d: 2-4                         [1, 64, 64, 64]           --
│    └─Sequential: 2-5                        [1, 256, 64, 64]          --
│    │    └─Bottleneck: 3-1                   [1, 256, 64, 64]          75,008
│    │    └─Bottleneck: 3-2                   [1, 256, 64, 64]          70,400
│    │    └─Bottleneck: 3-3                   [1, 256, 64, 64]          70,400
│    └─Sequential: 2-6                        [1, 512, 32, 32]          --
│    │    └─Bottleneck: 3-4                   [1, 512, 32, 32]          379,392

In [10]:
import torch.optim as optim

# Device - prefer CUDA, then Apple-Silicon MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Move the model before building the optimizer so the optimizer is constructed
# over parameters that already live on the training device
model = model.to(device)

# Loss function - Dice loss only (no BCE term is combined with it)
criterion = smp.losses.DiceLoss(mode='binary')

# Optimizer - Adam
optimizer = optim.Adam(model.parameters(), lr=1e-4)

print(f"Using device: {device}")
print("Loss: Dice Loss")
print("Optimizer: Adam (lr=1e-4)")

Using device: mps
Loss: Dice Loss
Optimizer: Adam (lr=1e-4)


In [11]:
def calculate_dice_score(outputs, masks, threshold=0.5, eps=1e-8):
    """Calculate the Dice score of a batch of binary segmentation predictions.

    The score is computed once over every element of the batch (all pixels of
    all samples are pooled), not averaged per sample.

    Args:
        outputs: Raw model logits. Must contain the same number of elements
            as ``masks``.
        masks: Ground-truth mask tensor; values are used as-is (after a cast
            to float), so they are expected to be 0/1.
        threshold: Probability above which a pixel counts as foreground.
        eps: Smoothing term added to numerator and denominator. It avoids a
            division by zero and makes the score 1.0 when both the prediction
            and the target are completely empty (a perfect match) instead of
            reporting 0.0.

    Returns:
        The Dice score as a Python float.
    """
    # Sigmoid turns logits into probabilities; thresholding gives a hard 0/1 mask
    preds = (torch.sigmoid(outputs) > threshold).float()

    # Flatten with reshape rather than view so non-contiguous tensors
    # (e.g. transposed or sliced) are accepted as well
    preds = preds.reshape(-1)
    masks = masks.reshape(-1).to(preds.dtype)

    # Dice = 2 * |X ∩ Y| / (|X| + |Y|), smoothed by eps on both sides
    intersection = (preds * masks).sum()
    dice = (2. * intersection + eps) / (preds.sum() + masks.sum() + eps)

    return dice.item()

In [12]:
from tqdm import tqdm

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_dice)``: the unweighted means of the per-batch
        loss and per-batch Dice score.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0
    # Count the batches actually processed instead of relying on
    # len(dataloader), which is unavailable for loaders without a length
    num_batches = 0

    pbar = tqdm(dataloader, desc='Training', leave=False)

    for images, masks in pbar:
        images = images.to(device)
        # Insert a channel axis at dim 1, e.g. (B, H, W) -> (B, 1, H, W)
        masks = masks.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Metrics are for reporting only: read the loss once (each .item()
        # is a device-to-host transfer) and keep the Dice computation out of
        # the autograd graph
        batch_loss = loss.item()
        dice = calculate_dice_score(outputs.detach(), masks)

        running_loss += batch_loss
        running_dice += dice
        num_batches += 1

        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'dice': f'{dice:.4f}'})

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute epoch metrics")

    epoch_loss = running_loss / num_batches
    epoch_dice = running_dice / num_batches

    return epoch_loss, epoch_dice


In [13]:
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    The model is switched to evaluation mode and is left in that mode when
    the function returns.

    Args:
        model: Network to evaluate.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(eval_loss, eval_dice)``: the unweighted means of the per-batch
        loss and per-batch Dice score.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    running_dice = 0.0
    # Count the batches actually processed instead of relying on
    # len(dataloader), which is unavailable for loaders without a length
    num_batches = 0

    pbar = tqdm(dataloader, desc='Evaluating', leave=False)

    # No gradients are needed for evaluation
    with torch.no_grad():
        for images, masks in pbar:
            images = images.to(device)
            # Insert a channel axis at dim 1, e.g. (B, H, W) -> (B, 1, H, W)
            masks = masks.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, masks)

            # Read the loss once; each .item() is a device-to-host transfer
            batch_loss = loss.item()
            dice = calculate_dice_score(outputs, masks)

            running_loss += batch_loss
            running_dice += dice
            num_batches += 1

            pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'dice': f'{dice:.4f}'})

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute evaluation metrics")

    eval_loss = running_loss / num_batches
    eval_dice = running_dice / num_batches

    return eval_loss, eval_dice

In [14]:
# Number of full passes over the training set
NUM_EPOCHS = 20

# Metric history: one value is appended to each list per epoch
train_losses = []
val_losses = []
train_dices = []
val_dices = []

print("Starting training...")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # Optimise on the training split and record its metrics right away, so
    # they are kept even if validation is interrupted
    train_loss, train_dice = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    train_losses.append(train_loss)
    train_dices.append(train_dice)

    # Score the validation split with the freshly updated weights
    val_loss, val_dice = evaluate(
        model, val_loader, criterion, device
    )
    val_losses.append(val_loss)
    val_dices.append(val_dice)

    # Per-epoch report
    print(f"Train Loss: {train_loss:.4f} | Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")

print("\nTraining complete!")

Starting training...

Epoch 1/20


                                                                                      

KeyboardInterrupt: 