In [1]:
# Cell: Dataset for BraTS 2D slice segmentation
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


brats_dir = "/home/hiranmoy/Downloads/Sameer/Brats/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

class BraTSSliceDataset(Dataset):
    """
    Dataset for 2D slice segmentation from BraTS. Binary segmentation: tumor vs background.
    Modalities: flair, t1, t1ce, t2. 
    modality_mode: 
       'all'      -> use all 4 modalities stacked as 4-channel input,
       or a list like ['flair'], ['t1ce','flair'] for simultaneous subset.
    filter_empty: if True, exclude slices where mask is all-zero.
    transforms: a function that takes (image: np.ndarray shape [C,H,W], mask: np.ndarray [H,W]) and returns augmented versions.
    """
    def __init__(self, brats_dir, patient_list=None, modality_mode='all', filter_empty=True, transforms=None):
        """
        brats_dir: root directory containing BraTS20_Training_xxx folders.
        patient_list: list of folder names or full paths to include; if None, list all.
        modality_mode: 'all' or list of modality strings among ['flair','t1','t1ce','t2'].
        filter_empty: whether to drop slices where mask has zero tumor pixels.
        transforms: callable(image, mask) -> (image, mask), for data augmentation.
        """
        self.brats_dir = brats_dir
        # Determine modalities to load
        all_mods = ['flair','t1','t1ce','t2']
        if modality_mode == 'all':
            self.modalities = all_mods
        else:
            # ensure list
            assert isinstance(modality_mode, (list,tuple)), "modality_mode must be 'all' or list"
            for m in modality_mode:
                assert m in all_mods, f"Unknown modality {m}"
            self.modalities = modality_mode
        
        # List patients
        if patient_list is None:
            # list all directories starting with 'BraTS20_Training'
            dirs = sorted([d for d in os.listdir(brats_dir) 
                           if os.path.isdir(os.path.join(brats_dir,d)) and d.startswith("BraTS20_Training")])
            self.patients = [os.path.join(brats_dir, d) for d in dirs]
        else:
            # user-provided list of full paths or folder names
            tmp = []
            for p in patient_list:
                # if just folder name:
                if os.path.basename(p).startswith("BraTS20_Training") and not os.path.isabs(p):
                    tmp.append(os.path.join(brats_dir, p))
                else:
                    tmp.append(p)
            self.patients = tmp
        
        self.filter_empty = filter_empty
        self.transforms = transforms
        
        # Build index: list of (patient_path, slice_idx)
        self.index = []
        for p in self.patients:
            # Load mask once to know shape
            mask_nii = nib.load(os.path.join(p, os.path.basename(p) + '_seg.nii'))
            mask_vol = mask_nii.get_fdata()  # shape [H, W, D]
            _, _, D = mask_vol.shape
            for z in range(D):
                if filter_empty:
                    sl = mask_vol[..., z]
                    if np.all(sl == 0):
                        continue
                self.index.append((p, z))
        print(f"BraTSSliceDataset: {len(self.index)} slices (from {len(self.patients)} patients), modalities={self.modalities}")
    
    def __len__(self):
        return len(self.index)
    
    def __getitem__(self, idx):
        p, z = self.index[idx]
        base = os.path.basename(p)
        # Load modalities for slice z
        imgs = []
        for m in self.modalities:
            nii_path = os.path.join(p, f"{base}_{m}.nii")
            arr = nib.load(nii_path).get_fdata()  # [H, W, D]
            sl = arr[..., z]  # [H, W]
            # Normalize per-slice: z-score
            mu, sd = sl.mean(), sl.std()
            if sd > 0:
                sl = (sl - mu) / sd
            else:
                sl = sl - mu
            imgs.append(sl.astype(np.float32))
        image = np.stack(imgs, axis=0)  # [C, H, W]
        # Load mask slice, binary
        mask_nii = nib.load(os.path.join(p, f"{base}_seg.nii"))
        mask_vol = mask_nii.get_fdata()
        msl = mask_vol[..., z]
        # Binary: tumor if label>0
        mask = (msl > 0).astype(np.float32)  # [H, W]
        
        # Apply transforms if any (e.g. random flip/rotate). They should handle image [C,H,W] and mask [H,W]
        if self.transforms is not None:
            image, mask = self.transforms(image, mask)
        
        # To tensor
        image_t = torch.from_numpy(image)  # [C,H,W]
        mask_t = torch.from_numpy(mask).unsqueeze(0)  # [1,H,W], as channel for BCE or similar
        return image_t, mask_t


Using device: cuda


In [2]:
import random
import torch.nn.functional as F

def random_flip_rotate(image, mask):
    # image: np [C,H,W], mask: np [H,W]
    # Random horizontal flip
    if random.random() < 0.5:
        image = image[:, :, ::-1]
        mask = mask[:, ::-1]
    # Random vertical flip
    if random.random() < 0.5:
        image = image[:, ::-1, :]
        mask = mask[:, ::-1]
    # Random rotation by 90 deg multiples
    k = random.choice([0,1,2,3])
    if k>0:
        image = np.rot90(image, k, axes=(1,2))
        mask = np.rot90(mask, k, axes=(0,1))
    return image.copy(), mask.copy()


In [15]:
from torch.utils.data import DataLoader

# Cell: Subject-wise split for BraTS 2D experiments
import os, random

# Root directory containing BraTS patient folders
brats_dir = "/home/hiranmoy/Downloads/Sameer/Brats/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

# List all patient directories (folder names start with "BraTS20_Training")
all_dirs = sorted([d for d in os.listdir(brats_dir)
                   if os.path.isdir(os.path.join(brats_dir, d)) and d.startswith("BraTS20_Training")])
all_patients = [os.path.join(brats_dir, d) for d in all_dirs]
print("Total subjects found:", len(all_patients))  # e.g. ~369

# Shuffle and split 70% train / 30% val
random.seed(42)
random.shuffle(all_patients)
split_idx = int(0.7 * len(all_patients))
train_subjects = all_patients[:split_idx]
val_subjects   = all_patients[split_idx:]
print(f"Train subjects: {len(train_subjects)}, Val subjects: {len(val_subjects)}")



Total subjects found: 369
Train subjects: 258, Val subjects: 111


In [4]:
# Cell: U-Net definition (2D) - from previous
import torch.nn as nn
import torch

class DoubleConv(nn.Sequential):
    """(Conv -> ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

class Down(nn.Sequential):
    """Downscale: MaxPool then DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

class Up(nn.Module):
    """Upsample then DoubleConv, concatenating with skip feature (or zeros)."""
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_channels, out_channels)

    def forward(self, x, skip_feat=None):
        x = self.up(x)
        if skip_feat is not None:
            # center-crop skip if needed
            if x.size(2) != skip_feat.size(2) or x.size(3) != skip_feat.size(3):
                diffY = skip_feat.size(2) - x.size(2)
                diffX = skip_feat.size(3) - x.size(3)
                skip_feat = skip_feat[:, :,
                                      diffY//2 : diffY//2 + x.size(2),
                                      diffX//2 : diffX//2 + x.size(3)]
            x = torch.cat([x, skip_feat], dim=1)
        else:
            zeros = torch.zeros_like(x)
            x = torch.cat([x, zeros], dim=1)
        x = self.conv(x)
        return x

class UNet2D(nn.Module):
    def __init__(self, in_channels=4, out_classes=1):
        super(UNet2D, self).__init__()
        self.conv1 = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up3 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up1 = Up(256, 128)
        self.conv_out = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_classes, kernel_size=1)
        )

    def forward(self, x):
        # x: [B, in_channels, H, W]
        c1 = self.conv1(x)        # [B,64,H,W]
        d1 = self.down1(c1)       # [B,128,H/2,W/2]
        d2 = self.down2(d1)       # [B,256,H/4,W/4]
        d3 = self.down3(d2)       # [B,512,H/8,W/8]
        b  = self.bottleneck(d3)  # [B,1024,H/8,W/8]
        u3 = self.up3(b, d2)      # [B,512,H/4,W/4]
        u2 = self.up2(u3, d1)     # [B,256,H/2,W/2]
        u1 = self.up1(u2, c1)     # [B,128,H,W]
        out = self.conv_out(u1)   # [B,1,H,W]
        return out


In [5]:
# Cell: Simplified SegNet for 2D binary segmentation
# We implement an encoder-decoder with pooling indices.

class SegNet2D(nn.Module):
    def __init__(self, in_channels=4, out_classes=1):
        super(SegNet2D, self).__init__()
        # Encoder layers: list of conv blocks; store indices
        # For brevity, define 3 encoder stages (like U-Net depth 3)
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
        
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True)
        
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True)
        
        # Bottleneck conv
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Decoder: unpool + conv
        self.unpool3 = nn.MaxUnpool2d(2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        self.unpool2 = nn.MaxUnpool2d(2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Final conv
        self.final_conv = nn.Conv2d(64, out_classes, kernel_size=1)
        
    def forward(self, x):
        # Encoder stage 1
        e1 = self.enc1(x)                  # [B,64,H,W]
        p1, idx1 = self.pool1(e1)         # [B,64,H/2,W/2]
        # Encoder stage 2
        e2 = self.enc2(p1)                # [B,128,H/2,W/2]
        p2, idx2 = self.pool2(e2)         # [B,128,H/4,W/4]
        # Encoder stage 3
        e3 = self.enc3(p2)                # [B,256,H/4,W/4]
        p3, idx3 = self.pool3(e3)         # [B,256,H/8,W/8]
        # Bottleneck
        b = self.bottleneck(p3)           # [B,512,H/8,W/8]
        # Decoder stage 3
        up3 = self.unpool3(b, idx3, output_size=e3.size())  # [B,512,H/4,W/4]
        d3 = self.dec3(up3)               # [B,256,H/4,W/4]
        # Decoder stage 2
        up2 = self.unpool2(d3, idx2, output_size=e2.size()) # [B,256,H/2,W/2]
        d2 = self.dec2(up2)               # [B,128,H/2,W/2]
        # Decoder stage 1
        up1 = self.unpool1(d2, idx1, output_size=e1.size()) # [B,128,H,W]
        d1 = self.dec1(up1)               # [B,64,H,W]
        out = self.final_conv(d1)         # [B,1,H,W]
        return out


In [6]:
# BCE Loss
bce_loss = nn.BCEWithLogitsLoss()

# Dice Loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    def forward(self, logits, target):
        # logits: [B,1,H,W], target: [B,1,H,W], values {0,1}
        probs = torch.sigmoid(logits)
        B = target.size(0)
        probs_flat = probs.view(B, -1)
        target_flat = target.view(B, -1)
        intersection = (probs_flat * target_flat).sum(dim=1)
        union = probs_flat.sum(dim=1) + target_flat.sum(dim=1)
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        loss = 1 - dice
        return loss.mean()

# Generalized Dice Loss (accounting for class imbalance via weights):
class GeneralizedDiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super(GeneralizedDiceLoss, self).__init__()
        self.smooth = smooth
    def forward(self, logits, target):
        # Here binary: treat background vs tumor; weight inversely by volume
        probs = torch.sigmoid(logits)
        B = target.size(0)
        losses = []
        for i in range(B):
            p = probs[i].view(-1)
            g = target[i].view(-1)
            # weights: w_c = 1 / (sum(g==c)^2) but binary: two classes
            # compute for class 1 (tumor) and class 0 (bg)
            # sum over target
            # ground truth volumes
            vol_pos = g.sum()
            vol_neg = (1 - g).sum()
            w_pos = 1.0 / (vol_pos * vol_pos + self.smooth)
            w_neg = 1.0 / (vol_neg * vol_neg + self.smooth)
            # dice for each class
            # for tumor class:
            inter_pos = (p * g).sum()
            dice_pos = (2 * inter_pos + self.smooth) / (p.sum() + g.sum() + self.smooth)
            # for background:
            p_neg = 1 - p
            g_neg = 1 - g
            inter_neg = (p_neg * g_neg).sum()
            dice_neg = (2 * inter_neg + self.smooth) / (p_neg.sum() + g_neg.sum() + self.smooth)
            # weighted sum
            loss_i = 1 - (w_pos * dice_pos + w_neg * dice_neg) / (w_pos + w_neg)
            losses.append(loss_i)
        return torch.stack(losses).mean()

# Focal Loss

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        # logits: [B,1,H,W], target: [B,1,H,W]
        probs = torch.sigmoid(logits)
        pt = torch.where(target == 1, probs, 1 - probs)  # [B,1,H,W]
        w = self.alpha * (1 - pt) ** self.gamma
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        loss = w * bce
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss



# Tversky Loss 


class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, smooth=1e-5):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
    def forward(self, logits, target):
        probs = torch.sigmoid(logits)
        B = target.size(0)
        losses = []
        for i in range(B):
            p = probs[i].view(-1)
            g = target[i].view(-1)
            tp = (p * g).sum()
            fn = ((1 - p) * g).sum()
            fp = (p * (1 - g)).sum()
            tversky = (tp + self.smooth) / (tp + self.alpha * fn + self.beta * fp + self.smooth)
            losses.append(1 - tversky)
        return torch.stack(losses).mean()

# Focal Traversky Loss 

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, smooth=1e-5):
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
    def forward(self, logits, target):
        probs = torch.sigmoid(logits)
        B = target.size(0)
        losses = []
        for i in range(B):
            p = probs[i].view(-1)
            g = target[i].view(-1)
            tp = (p * g).sum()
            fn = ((1 - p) * g).sum()
            fp = (p * (1 - g)).sum()
            tversky = (tp + self.smooth) / (tp + self.alpha * fn + self.beta * fp + self.smooth)
            loss_i = (1 - tversky) ** self.gamma
            losses.append(loss_i)
        return torch.stack(losses).mean()

# Robust Dice Loss 

class RobustDiceLoss(nn.Module):
    def __init__(self, lam=2.0, smooth=1e-5):
        super(RobustDiceLoss, self).__init__()
        self.lam = lam
        self.smooth = smooth
    def forward(self, logits, target):
        # p^lam in numerator and denom
        probs = torch.sigmoid(logits)
        B = target.size(0)
        losses = []
        for i in range(B):
            p = probs[i].view(-1)
            g = target[i].view(-1)
            p_l = p ** self.lam
            inter = (p_l * g).sum()
            denom = (p_l * p_l).sum() + (g * g).sum()  # note: original formula uses p^lam squared? or p^(2*lam)? In paper: 2∑ p^λ g / (∑ p^{2λ} + ∑g^2). We'll use p_l^2 = p^(2λ).
            # Actually the formula given: Robust Dice = 1 - (2 ∑ p^λ g) / (∑ p^{2λ} + ∑ g^2)
            denom = (p ** (2*self.lam)).sum() + (g * g).sum()
            dice = (2 * inter + self.smooth) / (denom + self.smooth)
            losses.append(1 - dice)
        return torch.stack(losses).mean()

# Adaptive Robust Loss 

class AdaptiveRobustLoss(nn.Module):
    def __init__(self, alpha=2.0, c=1.0, reduction='mean'):
        super(AdaptiveRobustLoss, self).__init__()
        self.alpha = alpha
        self.c = c
        self.reduction = reduction
    def forward(self, logits, target):
        # Here, input x = difference between prediction probability and target? 
        # Alternatively, apply robust loss on residual p - g.
        probs = torch.sigmoid(logits)
        # residual
        x = probs - target  # [B,1,H,W]
        # compute AR loss elementwise
        # AR Loss(x, alpha, c) = |alpha-2|/alpha * (( (x/c)^2 / |alpha-2| + 1)^(alpha/2) - 1)
        # but careful with alpha=2: return 0? Might fallback to L2.
        eps = 1e-6
        a = self.alpha
        c = self.c
        term = (x / c) ** 2
        if abs(a - 2.0) < eps:
            # L2 loss scaled
            loss_elem = 0.5 * term
        else:
            loss_elem = torch.abs(a - 2.0)/a * ( (term / torch.abs(a - 2.0) + 1.0).pow(a/2.0) - 1.0 )
        if self.reduction == 'mean':
            return loss_elem.mean()
        elif self.reduction == 'sum':
            return loss_elem.sum()
        else:
            return loss_elem
# Boundary Loss 

import scipy.ndimage as ndi

class BoundaryLoss(nn.Module):
    def __init__(self):
        super(BoundaryLoss, self).__init__()
    def forward(self, logits, target):
        # Compute distance map of ground truth boundary
        # This is expensive per batch; precompute distance transforms for each mask in dataset?
        # For demonstration: assume single batch size 1.
        probs = torch.sigmoid(logits)  # [B,1,H,W]
        total_loss = 0.0
        B = target.size(0)
        for i in range(B):
            g = target[i,0].cpu().numpy().astype(np.uint8)  # [H,W]
            # Compute distance transform: distance to nearest boundary pixel
            # boundary: edges of mask
            boundary = np.logical_xor(g, ndi.binary_erosion(g)).astype(np.uint8)
            # distance: for each pixel, distance to nearest boundary pixel
            # distance transform of inverted boundary: zero at boundary, increases away
            dist_map = ndi.distance_transform_edt(1 - boundary)
            dist_map = torch.from_numpy(dist_map).to(logits.device).float()  # [H,W]
            p = probs[i,0]
            # Boundary loss: sum |p - g| * dist_map
            total_loss += torch.mean((p - torch.from_numpy(g).to(logits.device).float()).abs() * dist_map)
        return total_loss / B

# Hausdorff Loss

class HausdorffLoss(nn.Module):
    def __init__(self):
        super(HausdorffLoss, self).__init__()
        # Possibly precompute distance transforms of gt for entire dataset if memory allows.
    def forward(self, logits, target):
        # Similar to boundary loss but penalize max distance. Use a surrogate: 
        # For each pixel in predicted boundary, measure distance to GT boundary via precomputed map.
        # Here we define a simple surrogate: mean distance over predicted edges.
        probs = torch.sigmoid(logits)
        total = 0.0
        B = target.size(0)
        for i in range(B):
            p_np = (probs[i,0].cpu().detach().numpy() > 0.5).astype(np.uint8)
            g_np = target[i,0].cpu().numpy().astype(np.uint8)
            # boundaries
            p_b = np.logical_xor(p_np, ndi.binary_erosion(p_np)).astype(np.uint8)
            g_b = np.logical_xor(g_np, ndi.binary_erosion(g_np)).astype(np.uint8)
            # distance transform of g boundary
            dist_map = ndi.distance_transform_edt(1 - g_b)
            # For each predicted boundary pixel, get its distance
            if p_b.sum() > 0:
                d = dist_map[p_b==1]
                total += d.mean()
            else:
                # if no predicted boundary, penalize heavily
                total += dist_map.max()
        return torch.tensor(total / B, device=logits.device)



# Perceptual Loss 

class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.fe = feature_extractor.eval()  # freeze
        for p in self.fe.parameters():
            p.requires_grad = False
        self.l1 = nn.L1Loss()
    def forward(self, logits, target, input_image):
        """
        logits: [B,1,H,W], target: [B,1,H,W], input_image: [B,1,H,W] or multi-channel
        We compute perceptual loss between predicted mask (or masked image) and ground truth mask?
        One approach: treat predicted mask and GT mask as single-channel images, feed through fe.
        """
        pred = torch.sigmoid(logits)
        # expand to 3 channels if fe expects 3?
        # For simplicity, tile to 3 channels or adapt feature_extractor to 1-channel.
        feat_p = self.fe(pred)    # feature maps
        feat_g = self.fe(target)  # feature maps
        return self.l1(feat_p, feat_g)



In [7]:
# Novel Loss - Bottleneck Diversity

class BottleneckDiversityLoss(nn.Module):
    def __init__(self, weight=1e-3):
        super(BottleneckDiversityLoss, self).__init__()
        self.weight = weight
    def forward(self, bottleneck_feat):
        """
        bottleneck_feat: [B, C, h, w]
        We compute channel-wise covariance and encourage off-diagonal elements to be small.
        For efficiency, we approximate by sampling spatial positions or batch.
        """
        B, C, h, w = bottleneck_feat.shape
        # Flatten spatial dims: [B, C, h*w]
        x = bottleneck_feat.view(B, C, -1)  # [B, C, N]
        # Compute per-sample covariance; average over batch
        loss = 0.0
        for i in range(B):
            xi = x[i]  # [C, N]
            # zero-mean per channel
            xi = xi - xi.mean(dim=1, keepdim=True)
            # compute covariance matrix approx: (C x C) but expensive if C large (1024).
            # Instead, sample M spatial positions randomly:
            N = xi.shape[1]
            M = min(100, N)  # sample 100 positions
            idx = torch.randperm(N)[:M]
            xs = xi[:, idx]  # [C, M]
            # compute covariance: cov = xs @ xs.T / (M-1); shape [C,C]
            cov = xs @ xs.T / (M - 1.0)
            # zero diagonal
            cov_off = cov.clone()
            cov_off.fill_diagonal_(0)
            loss += cov_off.abs().mean()
        return self.weight * (loss / B)


In [8]:
def dice_coeff(pred, target, smooth=1e-5):
    """
    pred: tensor [B,1,H,W], after sigmoid thresholded or use probabilities?
    For metric, use binarized prediction at 0.5 threshold or use soft dice?
    Here implement soft Dice for evaluation.
    """
    probs = torch.sigmoid(pred)
    B = target.size(0)
    dices = []
    for i in range(B):
        p = probs[i].view(-1)
        g = target[i].view(-1)
        inter = (p * g).sum()
        union = p.sum() + g.sum()
        dice = (2 * inter + smooth) / (union + smooth)
        dices.append(dice.item())
    return np.mean(dices)


In [9]:
# Cell: generic train/eval function (modified for stability and OOM handling)
import torch.optim as optim
from tqdm import tqdm
import torch

def train_model(model, train_loader, val_loader, loss_fn, 
                num_epochs=50, lr=1e-3, weight_decay=1e-5, 
                scheduler=None, device=torch.device('cuda'), 
                save_path=None, 
                hook_bottleneck=False):
    """
    model: nn.Module
    train_loader, val_loader: DataLoader
    loss_fn: loss function instance; if CombinedLoss with PerceptualLoss, need input_image arg
    num_epochs, lr, etc.
    hook_bottleneck: if True, register a forward hook on bottleneck to capture features during validation.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if scheduler is None:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    best_val_dice = 0.0

    # Hook for bottleneck activations if needed
    bottleneck_feats = []
    if hook_bottleneck:
        def hook_fn(module, inp, outp):
            # outp shape [B, C, h, w]; detach and move to CPU immediately
            # OPTION A: Store full outp.detach().cpu().clone() — might be large
            #    feats = outp.detach().cpu().clone()
            #    bottleneck_feats.append(feats)
            #
            # OPTION B (recommended): store only spatial/channel summary, e.g., global mean per channel:
            #    feats_mean = outp.detach().cpu().mean(dim=(0,2,3), keepdim=False).clone()  # shape [C]
            #    bottleneck_feats.append(feats_mean)
            #
            # OPTION C: store small random subset of spatial locations per channel
            #    B, C, h, w = outp.shape
            #    arr = outp.detach().cpu()
            #    # Randomly sample e.g. 100 pixels per channel over the batch:
            #    num_samples = min(100, h*w*B)
            #    flat = arr.view(-1, C)  # [B*h*w, C]
            #    idx = torch.randperm(flat.size(0))[:num_samples]
            #    sample = flat[idx].clone()  # [num_samples, C]
            #    bottleneck_feats.append(sample)
            #
            # Choose one option. Here we demonstrate OPTION B to greatly reduce memory:
            feats_mean = outp.detach().cpu().mean(dim=(0,2,3), keepdim=False).clone()  # [C]
            # if len(bottleneck_feats<20):
                # bottleneck_feats.append(feats_mean)

        # Register hook on the bottleneck layer of model
        # For UNet2D: assume attribute is model.bottleneck; for other: adjust accordingly
        handle = model.bottleneck.register_forward_hook(hook_fn)
    else:
        handle = None

    for epoch in range(1, num_epochs+1):
        model.train()
        train_losses = []
        train_dices = []
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]"):
            images = images.to(device, non_blocking=True)  # [B, C, H, W]
            masks  = masks.to(device, non_blocking=True)   # [B, 1, H, W]
            optimizer.zero_grad()
            try:
                logits = model(images)
                # Some loss_fns need input_image, e.g. PerceptualLoss
                if isinstance(loss_fn, PerceptualLoss):
                    # modify input_image slice selection as needed
                    loss = loss_fn(logits, masks, input_image=images[:,0:1,...])
                else:
                    loss = loss_fn(logits, masks)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
                with torch.no_grad():
                    dice = dice_coeff(logits, masks)
                    train_dices.append(dice)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"WARNING: OOM on batch, skipping batch. Clearing cache. Error: {e}")
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
                    # optionally: break the loop or continue
                    continue
                else:
                    # re-raise if other error
                    raise

        avg_train_loss = np.mean(train_losses) if train_losses else float('nan')
        avg_train_dice = np.mean(train_dices) if train_dices else float('nan')

        # Validation
        model.eval()
        val_losses = []
        val_dices = []
        # Clear bottleneck_feats before validation
        if hook_bottleneck:
            bottleneck_feats.clear()
        for images, masks in tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]"):
            images = images.to(device, non_blocking=True)
            masks  = masks.to(device, non_blocking=True)
            with torch.no_grad():
                try:
                    logits = model(images)
                    if isinstance(loss_fn, PerceptualLoss):
                        loss = loss_fn(logits, masks, input_image=images[:,0:1,...])
                    else:
                        loss = loss_fn(logits, masks)
                    val_losses.append(loss.item())
                    dice = dice_coeff(logits, masks)
                    val_dices.append(dice)
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print(f"WARNING: OOM during validation batch, skipping. Error: {e}")
                        torch.cuda.empty_cache()
                        continue
                    else:
                        raise
        avg_val_loss = np.mean(val_losses) if val_losses else float('nan')
        avg_val_dice = np.mean(val_dices) if val_dices else float('nan')
        print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Train Dice={avg_train_dice:.4f} | Val Loss={avg_val_loss:.4f}, Val Dice={avg_val_dice:.4f}")

        # Scheduler step
        # Some schedulers expect a metric; ReduceLROnPlateau uses avg_val_loss
        scheduler.step(avg_val_loss)

        # Save best
        if save_path is not None and not np.isnan(avg_val_dice) and avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model at epoch {epoch}, Val Dice={avg_val_dice:.4f}")
        
        # After each epoch, free any cached GPU memory
        torch.cuda.empty_cache()

        # Optionally: after each epoch or at checkpoints, analyze bottleneck_feats captured for val set
        if hook_bottleneck:
            # bottleneck_feats: list of tensors, each shape [C] if OPTION B used
            # You can store these per-epoch or process them here:
            # For example, stack means: epoch_bott_means = torch.stack(bottleneck_feats)  # [num_batches, C]
            pass

    # Remove hook handle if registered
    if handle is not None:
        handle.remove()

    return model, best_val_dice, bottleneck_feats

In [10]:
import numpy as np
from numpy.linalg import eigh
import matplotlib.pyplot as plt
import torch

def analyze_bottleneck_spectrum(bottleneck_feats, num_components=50):
    """
    bottleneck_feats: list of 1D tensors [C] collected over validation batches.
    We stack them into shape [N, C], compute covariance eigenvalues.
    """
    if len(bottleneck_feats) == 0:
        print("No bottleneck features collected.")
        return None

    # Stack: shape [num_batches, C]
    data = torch.stack(bottleneck_feats, dim=0).numpy()  # [N, C]
    # Zero-mean across batches
    data_mean = data.mean(axis=0, keepdims=True)
    data_centered = data - data_mean  # [N, C]
    # Compute covariance: shape [C,C] might be large. But since N ~ num_batches (e.g. <100), covariance unstable.
    # Instead, we can compute covariance of the batch-means: gives insight into channel variation across val batches.
    cov = np.cov(data_centered, rowvar=False)  # [C, C]
    # Eigenvalues
    vals, _ = eigh(cov)
    vals = np.sort(vals)[::-1]
    vals_norm = vals / (vals.sum() + 1e-12)
    # Plot top components
    plt.figure(figsize=(5,3))
    plt.plot(vals_norm[:num_components], marker='o')
    plt.title("Normalized eigenvalues of bottleneck covariance (channel-means)")
    plt.xlabel("Component index")
    plt.ylabel("Normalized eigenvalue")
    plt.show()
    return vals_norm


In [11]:
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def analyze_pixel_separation(bottleneck_feats, val_dataset, num_samples=10000):
    """
    Evaluate whether bottleneck embeddings separate tumor vs background.
    Approach: randomly sample pixel embeddings from bottleneck_feats and corresponding downsampled mask.
    val_dataset: dataset to get masks for downsampling; ensure index alignment with bottleneck_feats collection.
    NOTE: This requires storing which slices correspond to which features; simpler: run on a small subset manually.
    For demonstration, do for a single slice:
    """
    # For simplicity, assume we'll analyze one slice: manually pick patient, slice index:
    # E.g., from validation: patient_dir, z; run model forward to get bottleneck_feat [1,C,h,w] and mask_slice.
    # Then run:
    pass


In [12]:
def channel_importance(bottleneck_feat, mask_down):
    """
    bottleneck_feat: [1,C,h,w] tensor on CPU
    mask_down: np array [h,w] binary 0/1
    """
    C, h, w = bottleneck_feat.shape[1:]
    feat = bottleneck_feat[0].view(C, -1).numpy()  # [C, h*w]
    labels = mask_down.reshape(-1)
    mean_in = feat[:, labels==1].mean(axis=1)
    mean_out = feat[:, labels==0].mean(axis=1)
    diff = mean_in - mean_out
    idx_sorted = np.argsort(-np.abs(diff))
    # Return top channels
    return idx_sorted, diff


In [13]:
#-------------------------------Experiments Begin Here On Out-----------------------------------#

In [17]:
# Cell: Experiment 1 setup   
# Baseline Single-Modality U-Net with Standard Losses
modalities = ['flair']
train_ds = BraTSSliceDataset(
    brats_dir,
    patient_list=train_subjects,
    modality_mode=modalities,
    filter_empty=True,
    transforms=random_flip_rotate
)
val_ds   = BraTSSliceDataset(
    brats_dir,
    patient_list=val_subjects,
    modality_mode=modalities,
    filter_empty=False,
    transforms=None
)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)

loss_fns = {
    'BCE': nn.BCEWithLogitsLoss(),
    'Dice': DiceLoss(),
}

results_exp1 = {}
for name, loss_fn in loss_fns.items():
    print(f"\n=== Experiment 1: U-Net single-modality FLAIR with loss {name} ===")
    model = UNet2D(in_channels=1, out_classes=1)
    save_path = f"unet_flair_{name}.pth"
    trained_model, best_dice, bott_feats = train_model(
        model, train_loader, val_loader, loss_fn,
        num_epochs=30, lr=1e-3, device=device,
        save_path=save_path, hook_bottleneck=True
    )
    print(f"Loss {name}: Best Val Dice = {best_dice:.4f}")
    # Analyze bottleneck representation
    vals_norm = analyze_bottleneck_spectrum(bott_feats, num_components=30)
    results_exp1[name] = {'best_dice': best_dice, 'eigvals': vals_norm}


BraTSSliceDataset: 17178 slices (from 258 patients), modalities=['flair']
BraTSSliceDataset: 17205 slices (from 111 patients), modalities=['flair']

=== Experiment 1: U-Net single-modality FLAIR with loss BCE ===


Epoch 1/30 [Train]: 100%|███████████████████| 2148/2148 [12:10<00:00,  2.94it/s]
Epoch 1/30 [Val]: 100%|█████████████████████| 2151/2151 [06:05<00:00,  5.88it/s]


Epoch 1: Train Loss=0.0857, Train Dice=0.1350 | Val Loss=0.0397, Val Dice=0.0794
Saved best model at epoch 1, Val Dice=0.0794


Epoch 2/30 [Train]: 100%|███████████████████| 2148/2148 [12:04<00:00,  2.96it/s]
Epoch 2/30 [Val]: 100%|█████████████████████| 2151/2151 [05:58<00:00,  6.00it/s]


Epoch 2: Train Loss=0.0817, Train Dice=0.1475 | Val Loss=0.0362, Val Dice=0.0972
Saved best model at epoch 2, Val Dice=0.0972


Epoch 3/30 [Train]: 100%|███████████████████| 2148/2148 [11:56<00:00,  3.00it/s]
Epoch 3/30 [Val]: 100%|█████████████████████| 2151/2151 [06:05<00:00,  5.88it/s]


Epoch 3: Train Loss=0.0809, Train Dice=0.1518 | Val Loss=0.0382, Val Dice=0.1078
Saved best model at epoch 3, Val Dice=0.1078


Epoch 4/30 [Train]: 100%|███████████████████| 2148/2148 [11:57<00:00,  2.99it/s]
Epoch 4/30 [Val]: 100%|█████████████████████| 2151/2151 [06:25<00:00,  5.58it/s]


Epoch 4: Train Loss=0.0805, Train Dice=0.1534 | Val Loss=0.0353, Val Dice=0.0733


Epoch 5/30 [Train]: 100%|███████████████████| 2148/2148 [11:54<00:00,  3.01it/s]
Epoch 5/30 [Val]: 100%|█████████████████████| 2151/2151 [06:04<00:00,  5.91it/s]


Epoch 5: Train Loss=0.0803, Train Dice=0.1539 | Val Loss=0.0325, Val Dice=0.1014


Epoch 6/30 [Train]: 100%|███████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 6/30 [Val]: 100%|█████████████████████| 2151/2151 [05:58<00:00,  6.00it/s]


Epoch 6: Train Loss=0.0800, Train Dice=0.1555 | Val Loss=0.0330, Val Dice=0.1073


Epoch 7/30 [Train]: 100%|███████████████████| 2148/2148 [11:55<00:00,  3.00it/s]
Epoch 7/30 [Val]: 100%|█████████████████████| 2151/2151 [06:03<00:00,  5.92it/s]


Epoch 7: Train Loss=0.0790, Train Dice=0.1613 | Val Loss=0.0314, Val Dice=0.1050


Epoch 8/30 [Train]: 100%|███████████████████| 2148/2148 [11:54<00:00,  3.01it/s]
Epoch 8/30 [Val]: 100%|█████████████████████| 2151/2151 [05:54<00:00,  6.07it/s]


Epoch 8: Train Loss=0.0802, Train Dice=0.1550 | Val Loss=0.0340, Val Dice=0.1029


Epoch 9/30 [Train]: 100%|███████████████████| 2148/2148 [11:48<00:00,  3.03it/s]
Epoch 9/30 [Val]: 100%|█████████████████████| 2151/2151 [05:52<00:00,  6.10it/s]


Epoch 9: Train Loss=0.0798, Train Dice=0.1555 | Val Loss=0.0323, Val Dice=0.1015


Epoch 10/30 [Train]: 100%|██████████████████| 2148/2148 [11:47<00:00,  3.04it/s]
Epoch 10/30 [Val]: 100%|████████████████████| 2151/2151 [05:52<00:00,  6.10it/s]


Epoch 10: Train Loss=0.0797, Train Dice=0.1569 | Val Loss=0.0313, Val Dice=0.1116
Saved best model at epoch 10, Val Dice=0.1116


Epoch 11/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 11/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.22it/s]


Epoch 11: Train Loss=0.0797, Train Dice=0.1576 | Val Loss=0.0374, Val Dice=0.1097


Epoch 12/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 12/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.22it/s]


Epoch 12: Train Loss=0.0792, Train Dice=0.1600 | Val Loss=0.0342, Val Dice=0.1082


Epoch 13/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 13/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.21it/s]


Epoch 13: Train Loss=0.0793, Train Dice=0.1594 | Val Loss=0.0317, Val Dice=0.1000


Epoch 14/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 14/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.23it/s]


Epoch 14: Train Loss=0.0793, Train Dice=0.1594 | Val Loss=0.0321, Val Dice=0.1083


Epoch 15/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 15/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 15: Train Loss=0.0794, Train Dice=0.1596 | Val Loss=0.0342, Val Dice=0.0973


Epoch 16/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 16/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.23it/s]


Epoch 16: Train Loss=0.0797, Train Dice=0.1564 | Val Loss=0.0358, Val Dice=0.1123
Saved best model at epoch 16, Val Dice=0.1123


Epoch 17/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 17/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.21it/s]


Epoch 17: Train Loss=0.0790, Train Dice=0.1596 | Val Loss=0.0334, Val Dice=0.1142
Saved best model at epoch 17, Val Dice=0.1142


Epoch 18/30 [Train]: 100%|██████████████████| 2148/2148 [11:41<00:00,  3.06it/s]
Epoch 18/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 18: Train Loss=0.0792, Train Dice=0.1592 | Val Loss=0.0328, Val Dice=0.1076


Epoch 19/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 19/30 [Val]: 100%|████████████████████| 2151/2151 [05:44<00:00,  6.25it/s]


Epoch 19: Train Loss=0.0789, Train Dice=0.1615 | Val Loss=0.0326, Val Dice=0.1135


Epoch 20/30 [Train]: 100%|██████████████████| 2148/2148 [11:41<00:00,  3.06it/s]
Epoch 20/30 [Val]: 100%|████████████████████| 2151/2151 [05:43<00:00,  6.26it/s]


Epoch 20: Train Loss=0.0791, Train Dice=0.1609 | Val Loss=0.0303, Val Dice=0.1177
Saved best model at epoch 20, Val Dice=0.1177


Epoch 21/30 [Train]: 100%|██████████████████| 2148/2148 [11:42<00:00,  3.06it/s]
Epoch 21/30 [Val]: 100%|████████████████████| 2151/2151 [05:44<00:00,  6.24it/s]


Epoch 21: Train Loss=0.0789, Train Dice=0.1618 | Val Loss=0.0314, Val Dice=0.1032


Epoch 22/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 22/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.23it/s]


Epoch 22: Train Loss=0.0790, Train Dice=0.1600 | Val Loss=0.0326, Val Dice=0.1040


Epoch 23/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 23/30 [Val]: 100%|████████████████████| 2151/2151 [05:47<00:00,  6.20it/s]


Epoch 23: Train Loss=0.0794, Train Dice=0.1587 | Val Loss=0.0355, Val Dice=0.0936


Epoch 24/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.07it/s]
Epoch 24/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.20it/s]


Epoch 24: Train Loss=0.0791, Train Dice=0.1604 | Val Loss=0.0333, Val Dice=0.1025


Epoch 25/30 [Train]: 100%|██████████████████| 2148/2148 [11:41<00:00,  3.06it/s]
Epoch 25/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 25: Train Loss=0.0785, Train Dice=0.1631 | Val Loss=0.0332, Val Dice=0.1065


Epoch 26/30 [Train]: 100%|██████████████████| 2148/2148 [11:40<00:00,  3.06it/s]
Epoch 26/30 [Val]: 100%|████████████████████| 2151/2151 [05:44<00:00,  6.24it/s]


Epoch 26: Train Loss=0.0785, Train Dice=0.1637 | Val Loss=0.0314, Val Dice=0.1092


Epoch 27/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 27/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 27: Train Loss=0.0782, Train Dice=0.1650 | Val Loss=0.0336, Val Dice=0.1149


Epoch 28/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 28/30 [Val]: 100%|████████████████████| 2151/2151 [05:46<00:00,  6.20it/s]


Epoch 28: Train Loss=0.0784, Train Dice=0.1656 | Val Loss=0.0319, Val Dice=0.1080


Epoch 29/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 29/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 29: Train Loss=0.0780, Train Dice=0.1671 | Val Loss=0.0314, Val Dice=0.1149


Epoch 30/30 [Train]: 100%|██████████████████| 2148/2148 [11:39<00:00,  3.07it/s]
Epoch 30/30 [Val]: 100%|████████████████████| 2151/2151 [05:44<00:00,  6.24it/s]


Epoch 30: Train Loss=0.0783, Train Dice=0.1654 | Val Loss=0.0340, Val Dice=0.1119
Loss BCE: Best Val Dice = 0.1177
No bottleneck features collected.

=== Experiment 1: U-Net single-modality FLAIR with loss Dice ===


Epoch 1/30 [Train]: 100%|███████████████████| 2148/2148 [12:00<00:00,  2.98it/s]
Epoch 1/30 [Val]: 100%|█████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 1: Train Loss=0.9991, Train Dice=0.0009 | Val Loss=0.4229, Val Dice=0.5771
Saved best model at epoch 1, Val Dice=0.5771


Epoch 2/30 [Train]: 100%|███████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 2/30 [Val]: 100%|█████████████████████| 2151/2151 [05:47<00:00,  6.19it/s]


Epoch 2: Train Loss=0.9989, Train Dice=0.0011 | Val Loss=0.4210, Val Dice=0.5790
Saved best model at epoch 2, Val Dice=0.5790


Epoch 3/30 [Train]: 100%|███████████████████| 2148/2148 [11:48<00:00,  3.03it/s]
Epoch 3/30 [Val]: 100%|█████████████████████| 2151/2151 [05:46<00:00,  6.20it/s]


Epoch 3: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=0.4210, Val Dice=0.5790


Epoch 4/30 [Train]: 100%|███████████████████| 2148/2148 [11:47<00:00,  3.04it/s]
Epoch 4/30 [Val]: 100%|█████████████████████| 2151/2151 [05:47<00:00,  6.19it/s]


Epoch 4: Train Loss=0.9991, Train Dice=0.0009 | Val Loss=0.4210, Val Dice=0.5790


Epoch 5/30 [Train]: 100%|███████████████████| 2148/2148 [11:50<00:00,  3.03it/s]
Epoch 5/30 [Val]: 100%|█████████████████████| 2151/2151 [05:47<00:00,  6.20it/s]


Epoch 5: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=0.8999, Val Dice=0.1001


Epoch 6/30 [Train]: 100%|███████████████████| 2148/2148 [11:47<00:00,  3.03it/s]
Epoch 6/30 [Val]: 100%|█████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 6: Train Loss=0.9986, Train Dice=0.0014 | Val Loss=0.4210, Val Dice=0.5790


Epoch 7/30 [Train]: 100%|███████████████████| 2148/2148 [11:49<00:00,  3.03it/s]
Epoch 7/30 [Val]: 100%|█████████████████████| 2151/2151 [05:47<00:00,  6.19it/s]


Epoch 7: Train Loss=0.9950, Train Dice=0.0050 | Val Loss=0.9787, Val Dice=0.0213


Epoch 8/30 [Train]: 100%|███████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 8/30 [Val]: 100%|█████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 8: Train Loss=0.9955, Train Dice=0.0045 | Val Loss=0.4210, Val Dice=0.5790


Epoch 9/30 [Train]: 100%|███████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 9/30 [Val]: 100%|█████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 9: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=0.9133, Val Dice=0.0867


Epoch 10/30 [Train]: 100%|██████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 10/30 [Val]: 100%|████████████████████| 2151/2151 [05:47<00:00,  6.19it/s]


Epoch 10: Train Loss=0.9962, Train Dice=0.0038 | Val Loss=0.9998, Val Dice=0.0002


Epoch 11/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 11/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.18it/s]


Epoch 11: Train Loss=0.9855, Train Dice=0.0145 | Val Loss=0.9406, Val Dice=0.0594


Epoch 12/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.02it/s]
Epoch 12/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 12: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=1.0000, Val Dice=0.0000


Epoch 13/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.02it/s]
Epoch 13/30 [Val]: 100%|████████████████████| 2151/2151 [05:51<00:00,  6.13it/s]


Epoch 13: Train Loss=0.9998, Train Dice=0.0002 | Val Loss=1.0000, Val Dice=0.0000


Epoch 14/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 14/30 [Val]: 100%|████████████████████| 2151/2151 [05:44<00:00,  6.25it/s]


Epoch 14: Train Loss=0.9989, Train Dice=0.0011 | Val Loss=0.7450, Val Dice=0.2550


Epoch 15/30 [Train]: 100%|██████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 15/30 [Val]: 100%|████████████████████| 2151/2151 [05:50<00:00,  6.15it/s]


Epoch 15: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=1.0000, Val Dice=0.0000


Epoch 16/30 [Train]: 100%|██████████████████| 2148/2148 [11:49<00:00,  3.03it/s]
Epoch 16/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 16: Train Loss=0.9982, Train Dice=0.0018 | Val Loss=0.9794, Val Dice=0.0206


Epoch 17/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 17/30 [Val]: 100%|████████████████████| 2151/2151 [05:47<00:00,  6.20it/s]


Epoch 17: Train Loss=1.0000, Train Dice=0.0000 | Val Loss=0.9998, Val Dice=0.0002


Epoch 18/30 [Train]: 100%|██████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 18/30 [Val]: 100%|████████████████████| 2151/2151 [05:49<00:00,  6.16it/s]


Epoch 18: Train Loss=0.7899, Train Dice=0.2101 | Val Loss=0.6962, Val Dice=0.3038


Epoch 19/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.01it/s]
Epoch 19/30 [Val]: 100%|████████████████████| 2151/2151 [05:47<00:00,  6.19it/s]


Epoch 19: Train Loss=0.6937, Train Dice=0.3063 | Val Loss=0.7527, Val Dice=0.2473


Epoch 20/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.01it/s]
Epoch 20/30 [Val]: 100%|████████████████████| 2151/2151 [05:49<00:00,  6.16it/s]


Epoch 20: Train Loss=0.6877, Train Dice=0.3123 | Val Loss=0.7496, Val Dice=0.2504


Epoch 21/30 [Train]: 100%|██████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 21/30 [Val]: 100%|████████████████████| 2151/2151 [05:50<00:00,  6.14it/s]


Epoch 21: Train Loss=0.6850, Train Dice=0.3150 | Val Loss=0.7489, Val Dice=0.2511


Epoch 22/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.02it/s]
Epoch 22/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.16it/s]


Epoch 22: Train Loss=0.6827, Train Dice=0.3173 | Val Loss=0.7444, Val Dice=0.2556


Epoch 23/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 23/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.16it/s]


Epoch 23: Train Loss=0.6834, Train Dice=0.3166 | Val Loss=0.7464, Val Dice=0.2536


Epoch 24/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 24/30 [Val]: 100%|████████████████████| 2151/2151 [05:45<00:00,  6.22it/s]


Epoch 24: Train Loss=0.6776, Train Dice=0.3224 | Val Loss=0.7471, Val Dice=0.2529


Epoch 25/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 25/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 25: Train Loss=0.6741, Train Dice=0.3259 | Val Loss=0.7421, Val Dice=0.2579


Epoch 26/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.01it/s]
Epoch 26/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.18it/s]


Epoch 26: Train Loss=0.6795, Train Dice=0.3205 | Val Loss=0.7412, Val Dice=0.2588


Epoch 27/30 [Train]: 100%|██████████████████| 2148/2148 [11:51<00:00,  3.02it/s]
Epoch 27/30 [Val]: 100%|████████████████████| 2151/2151 [05:49<00:00,  6.15it/s]


Epoch 27: Train Loss=0.6825, Train Dice=0.3175 | Val Loss=0.7411, Val Dice=0.2589


Epoch 28/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.01it/s]
Epoch 28/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 28: Train Loss=0.6743, Train Dice=0.3257 | Val Loss=0.7416, Val Dice=0.2584


Epoch 29/30 [Train]: 100%|██████████████████| 2148/2148 [11:52<00:00,  3.01it/s]
Epoch 29/30 [Val]: 100%|████████████████████| 2151/2151 [05:49<00:00,  6.16it/s]


Epoch 29: Train Loss=0.6778, Train Dice=0.3222 | Val Loss=0.7414, Val Dice=0.2586


Epoch 30/30 [Train]: 100%|██████████████████| 2148/2148 [11:50<00:00,  3.02it/s]
Epoch 30/30 [Val]: 100%|████████████████████| 2151/2151 [05:48<00:00,  6.17it/s]


Epoch 30: Train Loss=0.6820, Train Dice=0.3180 | Val Loss=0.7400, Val Dice=0.2600
Loss Dice: Best Val Dice = 0.5790
No bottleneck features collected.
