In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import logging
import wandb
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import os
import numpy as np
import nibabel as nib
import pandas as pd
from torch import Tensor
from pathlib import Path
import logging
import wandb


In [2]:
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)


In [3]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

In [4]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset

class BrainSegmentationDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        """
        Args:
            csv_path (str): Path to the CSV file with data details.
            transform (callable, optional): Optional transforms to be applied on a sample.
        """
        self.data_summary = pd.read_csv(csv_path)
        self.subjects = self.data_summary['Subject ID'].unique()
        self.transform = transform
        self.slice_info = self._create_slice_index()

    def _create_slice_index(self):
        """Create a list of (subject_id, slice_idx) pairs."""
        slice_info = []
        for subject_id in self.subjects:
            subject_data = self.data_summary[self.data_summary['Subject ID'] == subject_id]
            flair_path = subject_data[subject_data['Scan Type'] == 'flair']['File Path'].values[0]
            nii = nib.load(flair_path)
            depth = nii.shape[2]  # Assume all modalities have the same depth
            slice_info.extend([(subject_id, z) for z in range(depth)])
        return slice_info

    def __len__(self):
        return len(self.slice_info)

    def __getitem__(self, idx):
        subject_id, slice_idx = self.slice_info[idx]
        subject_data = self.data_summary[self.data_summary['Subject ID'] == subject_id]

        # Load input modalities
        modalities = ['flair', 't1', 't1ce', 't2']
        slices = []
        for modality in modalities:
            file_path = subject_data[subject_data['Scan Type'] == modality]['File Path'].values[0]
            nii = nib.load(file_path)
            image = nii.get_fdata().astype(np.float32)
            image = (image - np.mean(image)) / np.std(image)  # Normalize
            slices.append(image[:, :, slice_idx])  # Extract 2D slice

        # Stack modalities into a single tensor (C x H x W)
        images = np.stack(slices, axis=0)
        # Check for missing segmentation mask
        seg_data = subject_data[subject_data['Scan Type'] == 'seg']
        if seg_data.empty:
            raise ValueError(f"Missing segmentation mask for subject {subject_id}")
        seg_path = seg_data['File Path'].values[0]
        seg_nii = nib.load(seg_path)
        seg_mask = seg_nii.get_fdata().astype(np.uint8)
        
        # Remap labels: 4 -> 3
        seg_mask[seg_mask == 4] = 3
        seg_slice = seg_mask[:, :, slice_idx]  # Extract 2D slice

        # Apply transforms if specified
        if self.transform:
            images, seg_slice = self.transform(images, seg_slice)

        return torch.tensor(images, dtype=torch.float32), torch.tensor(seg_slice, dtype=torch.long)


In [5]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader

'''
class BrainSegmentationDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        """
        Args:
            csv_path (str): Path to the CSV file with data details.
            transform (callable, optional): Optional transforms to be applied on a sample.
        """
        self.data_summary = pd.read_csv(csv_path)
        self.subjects = self.data_summary['Subject ID'].unique()
        self.transform = transform

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, idx):
        subject_id = self.subjects[idx]
        subject_data = self.data_summary[self.data_summary['Subject ID'] == subject_id]

        # Load input modalities
        modalities = ['flair', 't1', 't1ce', 't2']
        images = []
        for modality in modalities:
            file_path = subject_data[subject_data['Scan Type'] == modality]['File Path'].values[0]
            nii = nib.load(file_path)
            image = nii.get_fdata().astype(np.float32)
            image = (image - np.mean(image)) / np.std(image)  # Normalize
            images.append(image)

        # Stack modalities into a single tensor (C x H x W)
        images = np.stack(images, axis=0)

        # Load segmentation mask
        seg_path = subject_data[subject_data['Scan Type'] == 'seg']['File Path'].values[0]
        seg_nii = nib.load(seg_path)
        seg_mask = seg_nii.get_fdata().astype(np.uint8)

        # Apply transforms if specified
        if self.transform:
            images, seg_mask = self.transform(images, seg_mask)

        return torch.tensor(images), torch.tensor(seg_mask)
'''
# Initialize the dataset and dataloader
csv_path = '../data/training_detailed_summary_2020.csv'
dataset = BrainSegmentationDataset(csv_path)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Example: Inspect a batch
for batch_idx, (images, masks) in enumerate(dataloader):
    print(f"Batch {batch_idx+1}")
    print(f"Images shape: {images.shape}")
    print(f"Masks shape: {masks.shape}")
    break


Batch 1
Images shape: torch.Size([4, 4, 240, 240])
Masks shape: torch.Size([4, 240, 240])


In [11]:
def train_model(
    model,
    dataset,
    device,
    epochs: int = 5,
    batch_size: int = 1,
    learning_rate: float = 1e-5,
    val_percent: float = 0.1,
    save_checkpoint: bool = True,
    amp: bool = False,
    weight_decay: float = 1e-8,
    momentum: float = 0.999,
    gradient_clipping: float = 1.0,
    pin_memory=False,
    checkpoint_dir: str = "./checkpoints",
    test_mode: bool = False
):
    
    # 1. Split dataset into training and validation sets
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 2. Create DataLoader objects
    loader_args = dict(batch_size=batch_size, num_workers=0, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # 3. Initialize optimizer, loss function, and AMP scaler
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)

    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()

    # 4. Training loop
    logging.info(f"Starting training for {epochs} epochs...")
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch}/{epochs}", unit="batch") as pbar:
            for images, masks in train_loader:
                images = images.to(device, dtype=torch.float32)
                masks = masks.to(device, dtype=torch.long)

                # Forward pass
                with torch.cuda.amp.autocast(enabled=amp):
                    predictions = model(images)
                    if model.n_classes == 1:
                        loss = criterion(predictions.squeeze(1), masks.float())
                        loss += dice_loss(torch.sigmoid(predictions.squeeze(1)), masks.float(), multiclass=False)
                    else:
                        loss = criterion(predictions, masks)
                        loss += dice_loss(
                            torch.softmax(predictions, dim=1),
                            torch.nn.functional.one_hot(masks, num_classes=model.n_classes)
                                .permute(0, 3, 1, 2)
                                .float(),
                            multiclass=True
                        )

                # Backward pass
                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(1)
                epoch_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

        logging.info(f"Epoch {epoch} - Training loss: {epoch_loss:.4f}")

        # Validation loop
        val_score = evaluate(model, val_loader, device, amp)
        logging.info(f"Epoch {epoch} - Validation Dice Score: {val_score:.4f}")

        # Save checkpoint
        if save_checkpoint:
            checkpoint_path = Path(checkpoint_dir)
            checkpoint_path.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path / f"checkpoint_epoch{epoch}.pth")
            logging.info(f"Checkpoint saved at epoch {epoch}")

        if test_mode:
        # Use a small subset of the dataset for testing
            train_set, val_set = random_split(dataset, [10, 10], generator=torch.Generator().manual_seed(0))
        else:
            # Full dataset split
            n_val = int(len(dataset) * val_percent)
            n_train = len(dataset) - n_val
            train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    return dice_score / max(num_val_batches, 1)

def test_dataset(dataset, batch_size=1):
    print("Testing dataset...")
    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        for idx, (images, masks) in enumerate(dataloader):
            print(f"Batch {idx+1}:")
            print(f"  Images shape: {images.shape}")
            print(f"  Masks shape: {masks.shape}")
            if idx == 2:  # Stop after a few iterations
                break
        print("Dataset test passed!")
    except Exception as e:
        print("Dataset test failed!")
        print(e)
        return

In [14]:
device = torch.device('cuda')
model = UNet(n_channels=4, n_classes=4)
train_model(
    model=model,
    dataset=dataset,
    device=device,
    epochs=1,  # Run for only 1 epoch
    batch_size=2,  # Smaller batch size
    learning_rate=1e-4,
    test_mode=True  # Enable test mode
)
test_dataset(dataset)

  grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):
Epoch 1/1:   0%|                                                                          | 0/25738 [00:00<?, ?batch/s]


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [8]:
# Test the evaluation function
val_loader = DataLoader(dataset, batch_size=2, shuffle=False)
try:
    dice_score = evaluate(model, val_loader, device, amp=False)
    print(f"Validation test passed! Dice Score: {dice_score}")
except Exception as e:
    print("Validation test failed!")
    print(e)


Validation test failed!
name 'model' is not defined


In [7]:
device = torch.device('cuda')
dataset = BrainSegmentationDataset(csv_path="../data/training_detailed_summary_2020.csv")

model = UNet(n_channels=4, n_classes=4)  # Adjust based on your data
model.to(device)

train_model(
    model=model,
    dataset=dataset,
    device=device,
    epochs=50,
    batch_size=8,
    learning_rate=1e-4,
    checkpoint_dir="./checkpoints"
)


  grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):
Epoch 1/50:   1%|▌                                                  | 75/6435 [03:12<4:32:37,  2.57s/batch, loss=0.835]


ValueError: Missing segmentation mask for subject 355

In [None]:
if subject_data[subject_data['Scan Type'] == 'seg'].empty:
    raise ValueError(f"Missing segmentation mask for subject {subject_id}")
