Import Libraries & Setup Device

In [13]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 seaborn nibabel tqdm matplotlib

Looking in indexes: https://download.pytorch.org/whl/cu118
Note: you may need to restart the kernel to use updated packages.


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
# Silence every warning category so the notebook output stays compact
warnings.filterwarnings('ignore')

# Fixed seeds (PyTorch, NumPy and, when present, CUDA) for repeatable runs
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    # Disable cuDNN autotuning and request deterministic kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Environment banner
print("="*70, "WEEK-3: SWIN TRANSFORMER 3D - MODEL SETUP", "="*70, sep="\n")
print(f"PyTorch Version: {torch.__version__}", f"Device: {device}", sep="\n")
if torch.cuda.is_available():
    print(
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB",
        sep="\n",
    )
print("="*70)

WEEK-3: SWIN TRANSFORMER 3D - MODEL SETUP
PyTorch Version: 2.7.1+cu118
Device: cuda
GPU: NVIDIA GeForce RTX 4060 Laptop GPU
GPU Memory: 8.00 GB


Configuration

In [15]:
class Config:
    """Single namespace holding the paths and hyper-parameters used by this notebook."""

    # ---- Filesystem ----
    # Folder with the preprocessed data, relative to the notebook's working directory
    DATA_DIR = os.path.join('..', 'processed_data')

    # ---- Data ----
    # Label set: Background (0), Necrotic (1), Edema (2), Enhancing (4)
    NUM_CLASSES = 4
    # One input channel per MRI modality: FLAIR, T1, T1CE, T2
    IN_CHANNELS = 4
    # Spatial size of one preprocessed sample; adjust to match the patches on disk
    PATCH_SIZE = (64, 64, 64)

    # ---- Loading ----
    # Kept small for the sanity check (raise to 2 if memory allows)
    BATCH_SIZE = 1
    # 0 for Windows, increase for Linux
    NUM_WORKERS = 0

    # ---- Model ----
    # Channel width of the first stage
    EMBED_DIM = 48
    # Swin blocks per stage
    DEPTHS = [2, 2, 2, 2]
    # Attention heads per stage
    NUM_HEADS = [3, 6, 12, 24]
    # Local attention window (depth, height, width)
    WINDOW_SIZE = (4, 4, 4)

    # ---- Optimisation ----
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    # Reserved for the future full training run
    MAX_EPOCHS = 100

    # ---- Loss mixing ----
    DICE_WEIGHT = 0.5
    CE_WEIGHT = 0.5

    # ---- Sanity check ----
    # Number of training batches pushed through the model during the sanity check
    SANITY_CHECK_BATCHES = 2

# Settings object shared by every later cell
config = Config()

# Echo the settings that matter most for this notebook
print(
    "Configuration:",
    f"  Data directory: {config.DATA_DIR}",
    f"  Patch size: {config.PATCH_SIZE}",
    f"  Input channels: {config.IN_CHANNELS}",
    f"  Number of classes: {config.NUM_CLASSES}",
    f"  Batch size: {config.BATCH_SIZE}",
    f"  Learning rate: {config.LEARNING_RATE}",
    sep="\n",
)

Configuration:
  Data directory: ..\processed_data
  Patch size: (64, 64, 64)
  Input channels: 4
  Number of classes: 4
  Batch size: 1
  Learning rate: 0.0001


BRATS DATASET CLASS

In [16]:
class BraTSDataset(Dataset):
    """
    Dataset over preprocessed BraTS samples stored as ``.npy`` files.

    Expected layout: ``<data_dir>/<split>/images/*.npy`` and
    ``<data_dir>/<split>/masks/*.npy``. Both folders are listed in sorted order
    and paired by position, so the i-th image goes with the i-th mask.

    Each item is a pair ``(image, mask)``:
        image: float32 tensor, expected shape (4, D, H, W) - one channel per modality
        mask: int64 tensor, expected shape (D, H, W) - labels with 4 remapped to 3
    """

    def __init__(self, data_dir, split='train', transform=None):
        """
        Args:
            data_dir: Root directory that contains one sub-folder per split
            split: Name of the split sub-folder ('train' or 'val')
            transform: Optional callable applied to the image tensor only
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Images and masks sit side by side under the split folder
        split_root = os.path.join(data_dir, split)
        self.image_dir = os.path.join(split_root, 'images')
        self.mask_dir = os.path.join(split_root, 'masks')

        # Sorted listings keep the image/mask pairing stable between runs
        self.image_files = sorted(
            name for name in os.listdir(self.image_dir) if name.endswith('.npy')
        )
        self.mask_files = sorted(
            name for name in os.listdir(self.mask_dir) if name.endswith('.npy')
        )

        n_images, n_masks = len(self.image_files), len(self.mask_files)
        assert n_images == n_masks, \
            f"Mismatch: {n_images} images vs {n_masks} masks"

        print(f"✓ {split.upper()} Dataset initialized: {n_images} samples")

    def __len__(self):
        """Number of image files found for this split."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the idx-th image/mask pair from disk and return it as tensors."""
        image_array = np.load(os.path.join(self.image_dir, self.image_files[idx]))
        mask_array = np.load(os.path.join(self.mask_dir, self.mask_files[idx]))

        image = torch.from_numpy(image_array).float()
        mask = torch.from_numpy(mask_array).long()

        # BraTS labels are {0, 1, 2, 4}; CrossEntropyLoss needs contiguous
        # class indices, so label 4 becomes 3 (in place).
        mask.masked_fill_(mask == 4, 3)

        if self.transform:
            image = self.transform(image)

        return image, mask

# Build both splits and report what was found on disk
print("\n" + "="*70, "LOADING DATASETS", "="*70, sep="\n")

train_dataset = BraTSDataset(config.DATA_DIR, split='train')
val_dataset = BraTSDataset(config.DATA_DIR, split='val')

print(
    "\nDataset Statistics:",
    f"  Training samples: {len(train_dataset)}",
    f"  Validation samples: {len(val_dataset)}",
    sep="\n",
)

# Pull the first training sample to inspect shapes, dtypes and label values
sample_image, sample_mask = train_dataset[0]
print(
    "\nSample Shapes:",
    f"  Image: {sample_image.shape} (4 modalities × D × H × W)",
    f"  Mask: {sample_mask.shape} (D × H × W)",
    f"  Image dtype: {sample_image.dtype}",
    f"  Mask dtype: {sample_mask.dtype}",
    f"  Mask unique values: {torch.unique(sample_mask)}",
    "="*70,
    sep="\n",
)


LOADING DATASETS
✓ TRAIN Dataset initialized: 21 samples
✓ VAL Dataset initialized: 5 samples

Dataset Statistics:
  Training samples: 21
  Validation samples: 5

Sample Shapes:
  Image: torch.Size([4, 64, 64, 64]) (4 modalities × D × H × W)
  Mask: torch.Size([64, 64, 64]) (D × H × W)
  Image dtype: torch.float32
  Mask dtype: torch.int64
  Mask unique values: tensor([0, 1, 2, 3])


Create DataLoaders

In [17]:
# Training loader: reshuffled on every pass; pinned host memory only when CUDA is present
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

# Validation loader: fixed order
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

print(
    "✓ DataLoaders created:",
    f"  Train batches: {len(train_loader)}",
    f"  Val batches: {len(val_loader)}",
    f"  Batch size: {config.BATCH_SIZE}",
    sep="\n",
)

# Smoke test: fetch a single training batch and show its shapes
print("\nTesting DataLoader...")
for batch_idx, (images, masks) in enumerate(train_loader):
    print(
        f"  Batch {batch_idx+1}:",
        f"    Images shape: {images.shape}",
        f"    Masks shape: {masks.shape}",
        sep="\n",
    )
    break  # only the first batch is inspected

✓ DataLoaders created:
  Train batches: 21
  Val batches: 5
  Batch size: 1

Testing DataLoader...
  Batch 1:
    Images shape: torch.Size([1, 4, 64, 64, 64])
    Masks shape: torch.Size([1, 64, 64, 64])


Swin Transformer 3D Building Blocks

In [24]:
class PatchEmbed3D(nn.Module):
    """
    Patch embedding for volumetric input.

    A strided 3D convolution (kernel == stride == ``patch_size``) turns every
    non-overlapping patch into one ``embed_dim``-dimensional vector; LayerNorm
    is then applied over the channel axis of every patch vector.

    Input:  (B, in_chans, D, H, W)
    Output: (B, embed_dim, D/p, H/p, W/p), channel-first
    """
    def __init__(self, patch_size=4, in_chans=4, embed_dim=96):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Non-overlapping patches: kernel size and stride are both patch_size
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        feats = self.proj(x)  # (B, embed_dim, D/p, H/p, W/p)
        B, C, D, H, W = feats.shape

        # LayerNorm acts on the last axis, so move channels there: (B, D*H*W, C)
        tokens = feats.reshape(B, C, D * H * W).permute(0, 2, 1)
        tokens = self.norm(tokens)

        # Restore the channel-first 3D grid
        return tokens.permute(0, 2, 1).view(B, C, D, H, W)


def window_partition(x, window_size):
    """
    Cut a channel-last volume into non-overlapping windows.

    Each spatial axis that is not a multiple of the window size is zero-padded
    at its far end up to the next multiple before the volume is cut.

    Args:
        x: (B, D, H, W, C)
        window_size: (Wd, Wh, Ww)

    Returns:
        A pair ``(windows, padded_size)``:
            windows: (B*num_windows, Wd*Wh*Ww, C)
            padded_size: (D, H, W) of the volume after padding
    """
    Wd, Wh, Ww = window_size
    _, D, H, W, _ = x.shape

    # Amount needed to reach the next multiple of the window size on each axis
    pad_d, pad_h, pad_w = (-D) % Wd, (-H) % Wh, (-W) % Ww
    if any((pad_d, pad_h, pad_w)):
        # F.pad lists (before, after) pairs starting from the last axis: C, W, H, D
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_d))

    B, Dp, Hp, Wp, C = x.shape

    # Split each spatial axis into (window index, offset inside window) ...
    grid = x.view(B, Dp // Wd, Wd, Hp // Wh, Wh, Wp // Ww, Ww, C)
    # ... then bring the three window indices in front of the three offsets
    windows = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, Wd * Wh * Ww, C)

    return windows, (Dp, Hp, Wp)


def window_reverse(windows, window_size, original_size):
    """
    Reassemble windows produced by ``window_partition`` into a volume.

    Args:
        windows: (B*num_windows, Wd*Wh*Ww, C)
        window_size: (Wd, Wh, Ww)
        original_size: (D, H, W) of the *padded* volume, i.e. the size returned
            by ``window_partition``; every entry must be a multiple of the
            matching window size

    Returns:
        x: (B, D, H, W, C)

    Raises:
        ValueError: if ``original_size`` is not a multiple of ``window_size``
            or the number of windows does not correspond to a whole number of
            volumes.
    """
    Wd, Wh, Ww = window_size
    D, H, W = original_size
    C = windows.shape[-1]

    if D % Wd or H % Wh or W % Ww:
        raise ValueError(
            f"original_size {tuple(original_size)} must be a multiple of "
            f"window_size {tuple(window_size)}"
        )

    # Number of windows along each axis and per volume; integer arithmetic
    # throughout, so the batch size is never derived from a float.
    grid_d, grid_h, grid_w = D // Wd, H // Wh, W // Ww
    windows_per_volume = grid_d * grid_h * grid_w
    if windows_per_volume == 0 or windows.shape[0] % windows_per_volume != 0:
        raise ValueError(
            f"{windows.shape[0]} windows cannot be split into volumes of "
            f"{windows_per_volume} windows each"
        )
    B = windows.shape[0] // windows_per_volume

    # Undo the partition: (window indices, in-window offsets) -> interleaved axes
    x = windows.view(B, grid_d, grid_h, grid_w, Wd, Wh, Ww, C)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    x = x.view(B, D, H, W, C)

    return x


class WindowAttention3D(nn.Module):
    """
    Multi-head self-attention applied independently inside each 3D window.

    Every token attends to every other token of the same window with plain
    scaled dot-product attention.

    Input / output: (B*num_windows, tokens_per_window, dim)
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wd, Wh, Ww)
        self.num_heads = num_heads
        # 1 / sqrt(head_dim), applied to the attention scores
        self.scale = (dim // num_heads) ** -0.5

        # One fused projection yields Q, K and V; `proj` mixes the heads afterwards
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        num_windows, num_tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B_, N, 3*C) -> (3, B_, heads, N, head_dim), then split along axis 0
        fused = self.qkv(x).reshape(num_windows, num_tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4)

        # Scaled dot-product scores, normalised over the key axis
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, heads merged back into the channel axis
        context = torch.matmul(weights, v).transpose(1, 2).reshape(num_windows, num_tokens, channels)
        return self.proj(context)


class SwinTransformerBlock3D(nn.Module):
    """
    One transformer block: pre-norm window attention followed by a pre-norm
    MLP, each wrapped in a residual connection.

    Input / output: (B, D, H, W, C), channel-last
    """
    def __init__(self, dim, num_heads, window_size, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # Attention branch
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention3D(dim, window_size, num_heads)

        # Feed-forward branch: dim -> dim * mlp_ratio -> dim
        self.norm2 = nn.LayerNorm(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, dim),
        )

    def forward(self, x):
        _, D, H, W, _ = x.shape

        # Normalise, cut into (possibly padded) windows and attend within each
        windows, padded_size = window_partition(self.norm1(x), self.window_size)
        attended = window_reverse(self.attn(windows), self.window_size, padded_size)

        # Drop any padding added by window_partition, then add the residual
        x = x + attended[:, :D, :H, :W, :].contiguous()

        # Feed-forward branch with its own residual
        return x + self.mlp(self.norm2(x))


# List what this cell defined
print(
    "✓ Swin Transformer 3D building blocks defined",
    "  • PatchEmbed3D",
    "  • WindowAttention3D",
    "  • SwinTransformerBlock3D",
    "  • window_partition (with automatic padding)",
    "  • window_reverse",
    sep="\n",
)

✓ Swin Transformer 3D building blocks defined
  • PatchEmbed3D
  • WindowAttention3D
  • SwinTransformerBlock3D
  • window_partition (with automatic padding)
  • window_reverse


Swin Transformer 3D Model

In [19]:
class SwinTransformer3D(nn.Module):
    """
    Swin Transformer 3D for medical image segmentation.

    Architecture:
        1. Patch embedding (strided conv, downsamples by ``patch_size``)
        2. ``len(depths)`` stages of transformer blocks; a stride-2 convolution
           doubles the channels and halves the resolution between stages
        3. Decoder of stride-2 transposed convolutions mirroring the encoder
           (no skip connections)
        4. Transposed convolution back to input resolution + 1x1x1 conv head

    Input:  (B, in_chans, D, H, W)
    Output: (B, num_classes, D, H, W) raw logits

    Note:
        The output only has the input's spatial size when D, H and W are
        multiples of ``patch_size * 2 ** (len(depths) - 1)``; the strided
        convolutions floor any remainder, which yields a smaller output.
    """
    def __init__(
        self,
        in_chans=4,
        num_classes=4,
        embed_dim=48,
        depths=[2, 2, 2, 2],  # defaults are only read, never mutated
        num_heads=[3, 6, 12, 24],
        window_size=(4, 4, 4),
        patch_size=4
    ):
        """
        Args:
            in_chans: Number of input channels
            num_classes: Number of output logit channels
            embed_dim: Channel width of the first stage; stage i uses embed_dim * 2**i
            depths: Number of transformer blocks per stage
            num_heads: Attention heads per stage (indexed like ``depths``)
            window_size: (Wd, Wh, Ww) attention window
            patch_size: Kernel/stride of the patch embedding and of the final upsampling
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = PatchEmbed3D(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )

        # Swin Transformer stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = nn.ModuleList([
                SwinTransformerBlock3D(
                    dim=int(embed_dim * 2 ** i_layer),
                    num_heads=num_heads[i_layer],
                    window_size=window_size
                )
                for _ in range(depths[i_layer])
            ])
            self.layers.append(layer)

        # Downsampling between stages: channels x2, resolution /2
        self.downsample_layers = nn.ModuleList()
        for i_layer in range(self.num_layers - 1):
            downsample_layer = nn.Conv3d(
                int(embed_dim * 2 ** i_layer),
                int(embed_dim * 2 ** (i_layer + 1)),
                kernel_size=2,
                stride=2
            )
            self.downsample_layers.append(downsample_layer)

        # Decoder (upsampling path), deepest stage first: channels /2, resolution x2
        self.upsample_layers = nn.ModuleList()
        for i_layer in range(self.num_layers - 1, 0, -1):
            upsample_layer = nn.Sequential(
                nn.ConvTranspose3d(
                    int(embed_dim * 2 ** i_layer),
                    int(embed_dim * 2 ** (i_layer - 1)),
                    kernel_size=2,
                    stride=2
                ),
                nn.BatchNorm3d(int(embed_dim * 2 ** (i_layer - 1))),
                nn.ReLU(inplace=True)
            )
            self.upsample_layers.append(upsample_layer)

        # Final upsampling undoes the patch embedding's downsampling
        self.final_upsample = nn.Sequential(
            nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.BatchNorm3d(embed_dim),
            nn.ReLU(inplace=True)
        )

        # Segmentation head: per-voxel projection to class logits
        self.segmentation_head = nn.Conv3d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Map a (B, C, D, H, W) volume to (B, num_classes, ...) logits."""
        # Same failure mode as the former shape unpacking: non-5D input is rejected
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5D input (B, C, D, H, W), got shape {tuple(x.shape)}"
            )

        # Patch embedding
        x = self.patch_embed(x)  # (B, embed_dim, D/p, H/p, W/p)

        # Encoder path. Stage outputs are not retained: the decoder has no skip
        # connections, so keeping them would only hold memory until return.
        for i, layer_blocks in enumerate(self.layers):
            # Transformer blocks work channel-last
            x = x.permute(0, 2, 3, 4, 1).contiguous()  # (B, D, H, W, C)

            for block in layer_blocks:
                x = block(x)

            # Convolutions work channel-first
            x = x.permute(0, 4, 1, 2, 3).contiguous()  # (B, C, D, H, W)

            # Downsample (except after the last stage)
            if i < self.num_layers - 1:
                x = self.downsample_layers[i](x)

        # Decoder path (plain upsampling, no skip connections)
        for upsample_layer in self.upsample_layers:
            x = upsample_layer(x)

        # Back to input resolution, then per-voxel class logits
        x = self.final_upsample(x)
        return self.segmentation_head(x)


# Build the network from the notebook configuration and move it to the device
print("\n" + "="*70, "INITIALIZING MODEL", "="*70, sep="\n")

model = SwinTransformer3D(
    in_chans=config.IN_CHANNELS,
    num_classes=config.NUM_CLASSES,
    embed_dim=config.EMBED_DIM,
    depths=config.DEPTHS,
    num_heads=config.NUM_HEADS,
    window_size=config.WINDOW_SIZE,
    patch_size=4,
).to(device)

print("✓ Model initialized")

# Parameter counts: all weights, and the subset that receives gradients
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    map(torch.numel, filter(lambda weight: weight.requires_grad, model.parameters()))
)

print(
    "\nModel Statistics:",
    f"  Total parameters: {total_params:,}",
    f"  Trainable parameters: {trainable_params:,}",
    # 4 bytes per float32 weight
    f"  Model size: ~{total_params * 4 / 1024**2:.2f} MB (float32)",
    "="*70,
    sep="\n",
)


INITIALIZING MODEL
✓ Model initialized

Model Statistics:
  Total parameters: 6,429,076
  Trainable parameters: 6,429,076
  Model size: ~24.52 MB (float32)


Loss Functions

In [20]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss for multi-class segmentation.

    Dice = 2 * |X ∩ Y| / (|X| + |Y|)
    Loss = 1 - mean over classes of Dice

    The per-class sums run over the batch and all spatial axes together, and
    every class (including class 0) contributes equally to the mean.
    """
    def __init__(self, smooth=1e-6):
        """
        Args:
            smooth: Constant added to numerator and denominator so a class that
                is absent from both prediction and target yields Dice 1, not 0/0
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets, num_classes=None):
        """
        Args:
            predictions: (B, C, *spatial) - raw logits
            targets: (B, *spatial) - integer class indices
            num_classes: Number of classes to score. Defaults to the channel
                count of ``predictions``. A smaller value restricts the loss to
                the first ``num_classes`` channels.

        Returns:
            Scalar tensor: 1 - mean per-class Dice.
        """
        if num_classes is None:
            num_classes = predictions.shape[1]

        # Probabilities over the full channel axis, then the scored channels
        probs = F.softmax(predictions, dim=1)[:, :num_classes]

        # One-hot targets moved to channel-first to line up with `probs`
        one_hot = F.one_hot(targets, num_classes=num_classes)
        one_hot = one_hot.movedim(-1, 1).to(probs.dtype)

        # Reduce over everything except the class axis
        reduce_dims = tuple(d for d in range(probs.dim()) if d != 1)
        intersection = (probs * one_hot).sum(dim=reduce_dims)
        cardinality = probs.sum(dim=reduce_dims) + one_hot.sum(dim=reduce_dims)

        dice_per_class = (2. * intersection + self.smooth) / (cardinality + self.smooth)

        return 1 - dice_per_class.mean()


class CombinedLoss(nn.Module):
    """
    Weighted sum of Dice loss and cross-entropy loss.

    total = dice_weight * Dice + ce_weight * CrossEntropy
    """
    def __init__(self, dice_weight=0.5, ce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight

        # Both terms consume raw logits and integer class targets
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, predictions, targets):
        """
        Args:
            predictions: (B, C, D, H, W) - raw logits
            targets: (B, D, H, W) - class indices

        Returns:
            Tuple ``(combined, dice, ce)``: the weighted total followed by the
            two unweighted terms.
        """
        dice_term = self.dice_loss(predictions, targets)
        ce_term = self.ce_loss(predictions, targets)

        weighted_total = self.dice_weight * dice_term + self.ce_weight * ce_term
        return weighted_total, dice_term, ce_term


# Training criterion, weighted as configured
criterion = CombinedLoss(dice_weight=config.DICE_WEIGHT, ce_weight=config.CE_WEIGHT).to(device)

print(
    "✓ Loss functions initialized:",
    f"  Combined Loss = {config.DICE_WEIGHT} × Dice + {config.CE_WEIGHT} × CrossEntropy",
    sep="\n",
)

✓ Loss functions initialized:
  Combined Loss = 0.5 × Dice + 0.5 × CrossEntropy


OPTIMIZER & SCHEDULER

In [21]:
# AdamW over every model parameter
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

# Cosine decay of the learning rate down to 1e-6 over MAX_EPOCHS scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=config.MAX_EPOCHS,
    eta_min=1e-6,
)

print(
    "✓ Optimizer & Scheduler initialized:",
    "  Optimizer: AdamW",
    f"  Learning rate: {config.LEARNING_RATE}",
    f"  Weight decay: {config.WEIGHT_DECAY}",
    "  Scheduler: CosineAnnealingLR",
    f"  T_max: {config.MAX_EPOCHS} epochs",
    sep="\n",
)

✓ Optimizer & Scheduler initialized:
  Optimizer: AdamW
  Learning rate: 0.0001
  Weight decay: 1e-05
  Scheduler: CosineAnnealingLR
  T_max: 100 epochs


TRAINING & VALIDATION FUNCTIONS

In [22]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train for one epoch.

    Args:
        model: Network to optimise; switched to train mode here
        dataloader: Iterable yielding ``(images, masks)`` batches
        criterion: Callable returning ``(combined_loss, dice_loss, ce_loss)``
        optimizer: Optimizer stepped once per batch
        device: Device the batches are moved to
        epoch: Epoch number, used only in the progress-bar label

    Returns:
        Tuple ``(avg_loss, avg_dice, avg_ce)``: per-batch means of the combined
        loss, the Dice term and the cross-entropy term.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_dice = 0.0
    total_ce = 0.0
    # Count batches as they arrive instead of relying on len(dataloader),
    # which is undefined for some iterables.
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch} [Train]")

    for images, masks in progress_bar:
        # Move to device
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss, dice_loss, ce_loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read each scalar once; every .item() is a device-to-host sync
        loss_value = loss.item()
        dice_value = dice_loss.item()
        ce_value = ce_loss.item()

        total_loss += loss_value
        total_dice += dice_value
        total_ce += ce_value
        num_batches += 1

        progress_bar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'dice': f'{dice_value:.4f}',
            'ce': f'{ce_value:.4f}'
        })

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot average the epoch losses")

    return total_loss / num_batches, total_dice / num_batches, total_ce / num_batches


def validate(model, dataloader, criterion, device, epoch):
    """
    Evaluate the model on a dataloader without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here
        dataloader: Iterable yielding ``(images, masks)`` batches
        criterion: Callable returning ``(combined_loss, dice_loss, ce_loss)``
        device: Device the batches are moved to
        epoch: Epoch number, used only in the progress-bar label

    Returns:
        Tuple ``(avg_loss, avg_dice, avg_ce)``: per-batch means of the combined
        loss, the Dice term and the cross-entropy term.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    total_ce = 0.0
    # Count batches as they arrive instead of relying on len(dataloader),
    # which is undefined for some iterables.
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch} [Val]")

    # No autograd graph is needed for evaluation
    with torch.no_grad():
        for images, masks in progress_bar:
            # Move to device
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss, dice_loss, ce_loss = criterion(outputs, masks)

            # Read each scalar once; every .item() is a device-to-host sync
            loss_value = loss.item()
            dice_value = dice_loss.item()
            ce_value = ce_loss.item()

            total_loss += loss_value
            total_dice += dice_value
            total_ce += ce_value
            num_batches += 1

            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'dice': f'{dice_value:.4f}',
                'ce': f'{ce_value:.4f}'
            })

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot average the validation losses")

    return total_loss / num_batches, total_dice / num_batches, total_ce / num_batches


# Confirmation only: reaching this line means the definitions above executed without error
print("✓ Training & validation functions defined")

✓ Training & validation functions defined


SANITY CHECK - RUN 2 BATCHES ONLY

In [23]:
# Sanity check: push a few training batches through forward, loss, backward and
# optimizer step. Each claim in the final summary is backed by an assertion, so
# the "PASSED" banner is only reached when the checks actually hold.
print("\n" + "="*70)
print(f"RUNNING SANITY CHECK ({config.SANITY_CHECK_BATCHES} BATCHES)")
print("="*70)

model.train()

sanity_results = []

print("\nTraining Pass:")
print("-"*70)

for batch_idx, (images, masks) in enumerate(train_loader):
    if batch_idx >= config.SANITY_CHECK_BATCHES:
        break

    print(f"\nBatch {batch_idx + 1}/{config.SANITY_CHECK_BATCHES}")
    print(f"  Input shape: {images.shape}")
    print(f"  Mask shape: {masks.shape}")

    # Move to device
    images = images.to(device)
    masks = masks.to(device)

    # Forward pass
    print("  → Forward pass...")
    outputs = model(images)
    print(f"  Output shape: {outputs.shape}")

    # Logits must carry one channel per class at the mask's resolution
    expected_shape = (masks.shape[0], config.NUM_CLASSES, *masks.shape[1:])
    assert tuple(outputs.shape) == expected_shape, \
        f"Output shape {tuple(outputs.shape)} does not match expected {expected_shape}"

    # Calculate loss
    print("  → Calculating loss...")
    loss, dice_loss, ce_loss = criterion(outputs, masks)
    assert torch.isfinite(loss).item(), f"Loss is not finite: {loss.item()}"

    print(f"  ✓ Combined Loss: {loss.item():.4f}")
    print(f"    • Dice Loss: {dice_loss.item():.4f}")
    print(f"    • CE Loss: {ce_loss.item():.4f}")

    # Backward pass
    print("  → Backward pass...")
    # Snapshot of the weights, compared against after the optimizer step
    weights_before = [p.detach().clone() for p in model.parameters()]

    optimizer.zero_grad()
    loss.backward()

    # Every trainable parameter should have received a finite gradient
    assert all(
        p.grad is not None and bool(torch.isfinite(p.grad).all())
        for p in model.parameters() if p.requires_grad
    ), "Some trainable parameters have a missing or non-finite gradient"

    optimizer.step()

    # At least one weight tensor must differ from its snapshot
    assert any(
        not torch.equal(before, p.detach())
        for before, p in zip(weights_before, model.parameters())
    ), "Optimizer step did not change any weights"
    print("  ✓ Gradients computed and weights updated")

    sanity_results.append({
        'batch': batch_idx + 1,
        'loss': loss.item(),
        'dice': dice_loss.item(),
        'ce': ce_loss.item()
    })

# An empty loader would otherwise "pass" without running anything
assert sanity_results, "train_loader yielded no batches; nothing was checked"

print("\n" + "="*70)
print("SANITY CHECK RESULTS")
print("="*70)
for result in sanity_results:
    print(f"Batch {result['batch']}:")
    print(f"  Loss: {result['loss']:.4f} | Dice: {result['dice']:.4f} | CE: {result['ce']:.4f}")

print("\n" + "="*70)
print("✅ SANITY CHECK PASSED!")
print("="*70)
print("\nVerifications Complete:")
print("  ✓ Model can process input batches")
print("  ✓ Output shapes are correct")
print("  ✓ Loss functions work properly")
print("  ✓ Forward pass successful")
print("  ✓ Backward pass successful")
print("  ✓ Gradient computation working")
print("  ✓ Optimizer updates weights")
print("\n🚀 Ready for Week-4: Full Training!")
print("="*70)


RUNNING SANITY CHECK (2 BATCHES)

Training Pass:
----------------------------------------------------------------------

Batch 1/2
  Input shape: torch.Size([1, 4, 64, 64, 64])
  Mask shape: torch.Size([1, 64, 64, 64])
  → Forward pass...
  Output shape: torch.Size([1, 4, 64, 64, 64])
  → Calculating loss...
  ✓ Combined Loss: 1.1421
    • Dice Loss: 0.8207
    • CE Loss: 1.4634
  → Backward pass...
  ✓ Gradients computed and weights updated

Batch 2/2
  Input shape: torch.Size([1, 4, 64, 64, 64])
  Mask shape: torch.Size([1, 64, 64, 64])
  → Forward pass...
  Output shape: torch.Size([1, 4, 64, 64, 64])
  → Calculating loss...
  ✓ Combined Loss: 1.1580
    • Dice Loss: 0.8284
    • CE Loss: 1.4876
  → Backward pass...
  ✓ Gradients computed and weights updated

SANITY CHECK RESULTS
Batch 1:
  Loss: 1.1421 | Dice: 0.8207 | CE: 1.4634
Batch 2:
  Loss: 1.1580 | Dice: 0.8284 | CE: 1.4876

✅ SANITY CHECK PASSED!

Verifications Complete:
  ✓ Model can process input batches
  ✓ Output shape