In [None]:
import os
import random
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import MobileNet_V3_Small_Weights
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage import io, exposure
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation channel attention.

    The input is average-pooled to one value per channel, squeezed through a
    1x1-conv bottleneck of ``channels // reduction`` units with ReLU, expanded
    back to ``channels`` and passed through a sigmoid. Each input channel is
    then scaled by its gate in (0, 1).
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = channels // reduction
        # Layers are created in the same order and kept in a single Sequential named
        # `fc`, so parameter initialisation and state_dict keys (fc.1.*, fc.3.*) stay stable.
        squeeze = nn.AdaptiveAvgPool2d(1)
        excite = [
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(squeeze, *excite)

    def forward(self, x):
        gate = self.fc(x)  # one weight per channel, broadcast over H and W
        return x * gate

class SpatialAttention(nn.Module):
    """Spatial attention gate.

    Builds a 2-channel descriptor from the per-pixel channel maximum and channel
    mean, convolves it to one map, applies a sigmoid, and rescales every channel
    of the input by that per-pixel weight.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # (k - 1) // 2 keeps H and W unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_max, _ = x.max(dim=1, keepdim=True)
        channel_avg = x.mean(dim=1, keepdim=True)
        # Channel order (max, mean) must match what the conv weights were trained on.
        descriptor = torch.cat((channel_max, channel_avg), dim=1)
        weights = self.sigmoid(self.conv(descriptor))
        return x * weights

class DecoderBlock(nn.Module):
    """One U-Net decoder stage: upsample, fuse with a skip, refine, then attend.

    The incoming map is bilinearly resized to the skip's spatial size and
    concatenated with it along channels. The result goes through two 3x3
    conv-BN-ReLU layers, then an optional SE (channel) block and an optional
    spatial-attention block.

    Args:
        in_ch: channels of the incoming (deeper) feature map.
        skip_ch: channels of the skip connection.
        out_ch: channels produced by this stage.
        use_se / use_sa: enable channel / spatial attention (identity otherwise).
    """

    def __init__(self, in_ch, skip_ch, out_ch, use_se=True, use_sa=True):
        super().__init__()
        fused_ch = in_ch + skip_ch
        # Creation order (conv1, bn1, conv2, bn2, relu, se, sa) is unchanged so parameter
        # initialisation and state_dict keys match existing checkpoints.
        self.conv1 = nn.Conv2d(fused_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.se = nn.Identity() if not use_se else SEBlock(out_ch)
        self.sa = nn.Identity() if not use_sa else SpatialAttention()

    def forward(self, x, skip):
        upsampled = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
        y = torch.cat([upsampled, skip], dim=1)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(bn(conv(y)))
        # channel attention first, then spatial attention
        return self.sa(self.se(y))

class LMA_UNet(nn.Module):
    """U-Net-style segmentation network with a MobileNetV3-Small encoder.

    Outputs of encoder stages ``features[0, 1, 2, 8]`` are kept as skip
    connections. Four ``DecoderBlock`` stages (with SE and spatial attention)
    then upsample from the final encoder output back towards the input
    resolution, and a 1x1 conv produces per-pixel class logits.

    Args:
        n_classes: output channels of the 1x1 head (2 = background / tumour).
        in_channels: input image channels. For anything other than 3 the
            encoder's first conv is replaced by a freshly initialised one, so
            its pretrained weights are not reused.
    """

    def __init__(self, n_classes=2, in_channels=4):  # Changed to n_classes=2 for binary
        super().__init__()
        self.n_classes = n_classes
        mobilenet = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        if in_channels != 3:
            # The pretrained stem expects 3 input channels; swap in a new conv.
            mobilenet.features[0][0] = nn.Conv2d(in_channels, 16, 3, 2, 1, bias=False)
        self.encoder = mobilenet.features
        # Encoder stage indices whose outputs feed the decoder (shallowest first).
        self.skip_idx = [0, 1, 2, 8]
        # DecoderBlock(in_ch, skip_ch, out_ch)
        self.dec4 = DecoderBlock(576, 48, 256, use_se=True, use_sa=True)
        self.dec3 = DecoderBlock(256, 24, 128, use_se=True, use_sa=True)
        self.dec2 = DecoderBlock(128, 16, 64, use_se=True, use_sa=True)
        self.dec1 = DecoderBlock(64, 16, 32, use_se=True, use_sa=True)
        self.head = nn.Conv2d(32, n_classes, 1)

    def forward(self, x):
        """Return class logits with the same spatial size (H, W) as ``x``."""
        input_size = x.shape[-2:]
        skips = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i in self.skip_idx:
                skips.append(x)
        # Decode from the deepest skip to the shallowest.
        x = self.dec4(x, skips[3])
        x = self.dec3(x, skips[2])
        x = self.dec2(x, skips[1])
        x = self.dec1(x, skips[0])
        # Resize to the exact input resolution. A fixed 2x upsample overshoots by one
        # pixel when H or W is odd (the strided stem rounds up), which would break the
        # loss against the full-size mask. For even sizes this is identical to 2x.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        x = self.head(x)
        return x

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Binary Dice Coefficient
def dice_coefficient(pred, target, smooth=1e-6):
    """Hard Dice score of the tumour class (label 1), pooled over the whole batch.

    Args:
        pred: raw logits with classes on dim 1. A pixel counts as tumour when its
            class-1 softmax probability is strictly greater than 0.5.
        target: integer label map; pixels equal to 1 are tumour.
        smooth: additive smoothing (two empty masks score 1).

    Returns:
        0-dim tensor ``(2 * |P & T| + smooth) / (|P| + |T| + smooth)``.
    """
    tumour_prob = pred.softmax(dim=1)[:, 1]
    pred_mask = tumour_prob.gt(0.5).float()
    true_mask = target.eq(1).float()

    overlap = (pred_mask * true_mask).sum()
    total = pred_mask.sum() + true_mask.sum()
    return (2 * overlap + smooth) / (total + smooth)

# Binary IoU
def iou_score(pred, target, smooth=1e-6):
    """Hard IoU (Jaccard) of the tumour class (label 1), pooled over the batch.

    Args:
        pred: raw logits with classes on dim 1. Class-1 softmax probability
            strictly greater than 0.5 is counted as tumour.
        target: integer label map; pixels equal to 1 are tumour.
        smooth: additive smoothing (two empty masks score 1).

    Returns:
        0-dim tensor ``(|P & T| + smooth) / (|P or T| + smooth)``.
    """
    predicted = (pred.softmax(dim=1)[:, 1] > 0.5).float()
    actual = target.eq(1).float()

    overlap = (predicted * actual).sum()
    combined = predicted.sum() + actual.sum() - overlap
    return (overlap + smooth) / (combined + smooth)

# Binary Sensitivity (Recall)
def sensitivity_score(pred, target, smooth=1e-6):
    """Recall of the tumour class (label 1), pooled over the batch.

    Args:
        pred: raw logits with classes on dim 1. Class-1 softmax probability
            strictly greater than 0.5 is counted as tumour.
        target: integer label map; pixels equal to 1 are tumour.
        smooth: additive smoothing constant.

    Returns:
        0-dim tensor ``(TP + smooth) / (TP + FN + smooth)``. The numerator is
        smoothed like the Dice/IoU metrics in this notebook. A batch with no
        tumour pixels has nothing to miss and scores 1, rather than 0, which
        would drag the epoch average down.
    """
    pred = torch.softmax(pred, dim=1)[:, 1]
    pred = (pred > 0.5).float()
    target = (target == 1).float()

    tp = (pred * target).sum()
    fn = target.sum() - tp
    sens = (tp + smooth) / (tp + fn + smooth)
    return sens

# Binary loss: Dice + BCE
class DiceBCELoss(nn.Module):
    """Weighted sum of a soft Dice loss and a cross-entropy loss.

    Despite the name, the "BCE" term is ``nn.CrossEntropyLoss`` applied to the
    raw logits (classes on dim 1) and the integer target. The Dice term uses the
    class-1 softmax probability against the ``target == 1`` mask, pooled over
    the whole batch.

    Args:
        dice_weight: multiplier on ``1 - soft_dice``.
        bce_weight: multiplier on the cross-entropy term.
        smooth: additive smoothing in the Dice ratio.
    """

    def __init__(self, dice_weight=1.0, bce_weight=1.0, smooth=1e-6):
        super().__init__()
        self.dice_weight, self.bce_weight, self.smooth = dice_weight, bce_weight, smooth
        self.bce = nn.CrossEntropyLoss()

    def forward(self, pred, target):
        ce_term = self.bce(pred, target)

        tumour_prob = pred.softmax(dim=1)[:, 1]
        tumour_true = target.eq(1).float()
        overlap = (tumour_prob * tumour_true).sum()
        denom = tumour_prob.sum() + tumour_true.sum()
        soft_dice = (2 * overlap + self.smooth) / (denom + self.smooth)

        return self.dice_weight * (1 - soft_dice) + self.bce_weight * ce_term

In [None]:
class BRISCDataset(Dataset):
    """Binary tumour-segmentation dataset of image/mask pairs.

    Every file in ``image_dir`` is paired with ``<same stem>.png`` in
    ``mask_dir``. Masks are binarised (any non-zero value -> 1 = tumour).
    Images are returned with three channels: grayscale is replicated and an
    RGBA alpha channel is dropped.

    Args:
        image_dir: directory of input images (sorted listing defines the order).
        mask_dir: directory of ``.png`` masks.
        transform: optional albumentations-style callable taking
            ``image=``/``mask=`` and returning a dict with tensor outputs.

    Returns (per item):
        ``(image, mask)``; ``mask`` is an int64 tensor of 0/1 labels.

    Raises:
        FileNotFoundError: if an image has no matching mask file.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _to_binary_mask(mask):
        """Collapse a 2-D or channel-last 3-D mask array to a uint8 0/1 mask of shape (H, W)."""
        mask = np.asarray(mask)
        if mask.ndim == 3:
            if mask.shape[-1] in (2, 4):
                mask = mask[..., :-1]  # drop alpha (gray+alpha or RGBA)
            mask = mask.max(axis=-1)   # any non-zero colour channel marks tumour
        # Threshold before any dtype cast so 16-bit or float masks are not truncated/wrapped.
        return (mask > 0).astype(np.uint8)

    @staticmethod
    def _to_three_channels(image):
        """Return a channel-last 3-channel image: gray is replicated, RGBA alpha dropped."""
        image = np.asarray(image)
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        elif image.ndim == 3 and image.shape[-1] == 4:
            image = image[..., :3]
        return image

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # Masks share the image's stem but are always stored as .png.
        mask_filename = os.path.splitext(img_name)[0] + '.png'
        mask_path = os.path.join(self.mask_dir, mask_filename)

        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found: {mask_path} (for image: {img_path})")

        # 3 channels are needed for CLAHE and the 3-value Normalize in the transforms.
        image = self._to_three_channels(io.imread(img_path))
        # Read the mask as stored. With as_gray=True, colour masks come back as floats
        # in [0, 1], and a uint8 cast truncates them to 0, erasing the tumour labels.
        mask = self._to_binary_mask(io.imread(mask_path))

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask'].long()  # ToTensorV2 yields a tensor; CE loss needs int64
        else:
            # Manual conversion if no transform
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
            mask = torch.from_numpy(mask).long()

        return image, mask

In [None]:
# Both splits are resized to IMG_SIZE x IMG_SIZE before entering the network.
IMG_SIZE = 256

# Resize, CLAHE and Normalize are shared by both pipelines so validation images are
# preprocessed exactly like training images; only the random augmentations differ.
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
    # --- random augmentations (training only) ---
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # --- shared normalisation ---
    A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    # CLAHE runs with p=1.0 on every training image, so omitting it here would feed the
    # model differently preprocessed images at validation/inference time.
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
    A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ToTensorV2()
])


In [None]:
# Kaggle dataset locations.
# NOTE(review): the "val" directories point at the dataset's *test* split, which the
# training loop uses for checkpoint selection and early stopping.
train_image_dir, train_mask_dir = (
    '/kaggle/input/brisc-cleaned/train/images',
    '/kaggle/input/brisc-cleaned/train/masks',
)
val_image_dir, val_mask_dir = (
    '/kaggle/input/brisc-cleaned/test/images',
    '/kaggle/input/brisc-cleaned/test/masks',
)

In [None]:
train_dataset = BRISCDataset(
    train_image_dir, train_mask_dir, transform=train_transform
)
val_dataset = BRISCDataset(
    val_image_dir, val_mask_dir, transform=val_transform
)

# Only the training loader shuffles; the validation order is fixed across epochs.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,   # page-locked host buffers, intended for faster .to(device) copies
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Split sizes; the "Test" split is the one used as the validation set in this notebook.
print(f"Train dataset: {len(train_dataset)}", f"Test dataset: {len(val_dataset)}", sep="\n")

In [None]:
# in_channels=3 keeps the encoder's pretrained first conv (LMA_UNet only replaces it
# for other channel counts).
model = LMA_UNet(in_channels=3, n_classes=2).to(device)
# Dice term weighted 1.5x against the cross-entropy term, intended to counter class imbalance.
criterion = DiceBCELoss(bce_weight=1.0, dice_weight=1.5)
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=1e-3)
# mode='max' because the training loop steps the scheduler with validation Dice.
# NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases;
# confirm the installed version still accepts it.
scheduler = ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, verbose=True
)

In [None]:
# Training schedule and early-stopping settings.
num_epochs = 100   # upper bound on epochs
patience = 15      # epochs without a new best validation Dice before stopping
best_dice, early_stop_counter = 0.0, 0

In [None]:
# Per-epoch history, appended to by the training loop and plotted afterwards.
train_losses, train_dices = [], []
val_losses, val_dices, val_ious, val_sens_list = [], [], [], []

# Reset best-model tracking so re-running this cell starts a fresh run.
best_dice, early_stop_counter = 0.0, 0

In [None]:
# Train with per-epoch validation. The best checkpoint (by validation Dice) is written to
# 'best_lma_unet_binary.pth'; training stops after `patience` epochs without improvement.
# NOTE(review): `val_loader` reads the dataset's *test* split, so checkpoint selection and
# early stopping are tuned on test data. Hold out a separate validation split if these
# scores are to be reported as test results.


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean loss, mean Dice), each averaged over batches.
    """
    model.train()
    loss_sum = 0.0
    dice_sum = 0.0
    pbar = tqdm(loader, desc=desc)
    for images, masks in pbar:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Computed once per batch (and on detached logits) and reused for the progress bar.
        batch_dice = dice_coefficient(outputs.detach(), masks).item()
        loss_sum += batch_loss
        dice_sum += batch_dice
        pbar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Dice': f'{batch_dice:.4f}'})

    n_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    return loss_sum / n_batches, dice_sum / n_batches


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns:
        (mean loss, mean Dice, mean IoU, mean sensitivity), each averaged over batches.
    """
    model.eval()
    loss_sum = dice_sum = iou_sum = sens_sum = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss_sum += criterion(outputs, masks).item()
        dice_sum += dice_coefficient(outputs, masks).item()
        iou_sum += iou_score(outputs, masks).item()
        sens_sum += sensitivity_score(outputs, masks).item()
    n_batches = max(len(loader), 1)
    return (loss_sum / n_batches, dice_sum / n_batches,
            iou_sum / n_batches, sens_sum / n_batches)


for epoch in range(num_epochs):
    # ----------------- TRAIN / VALIDATE ----------------- #
    train_loss_epoch, train_dice_epoch = _train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f'Epoch {epoch+1}/{num_epochs}',
    )
    val_loss_epoch, val_dice_epoch, val_iou_epoch, val_sens_epoch = _evaluate(
        model, val_loader, criterion, device
    )

    # The LR scheduler monitors validation Dice.
    scheduler.step(val_dice_epoch)

    # ----------------- STORE METRICS ----------------- #
    train_losses.append(train_loss_epoch)
    train_dices.append(train_dice_epoch)
    val_losses.append(val_loss_epoch)
    val_dices.append(val_dice_epoch)
    val_ious.append(val_iou_epoch)
    val_sens_list.append(val_sens_epoch)

    # ----------------- PRINT ----------------- #
    print(f'Epoch {epoch+1}: Train Loss: {train_loss_epoch:.4f}, Train Dice: {train_dice_epoch:.4f}')
    print(f'Val Loss: {val_loss_epoch:.4f}, Val Dice: {val_dice_epoch:.4f}, Val IoU: {val_iou_epoch:.4f}, Val Sens: {val_sens_epoch:.4f}')

    # ----------------- SAVE BEST MODEL ----------------- #
    if val_dice_epoch > best_dice:
        best_dice = val_dice_epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_dice': best_dice,
        }, 'best_lma_unet_binary.pth')
        early_stop_counter = 0
        print(f'New best Dice: {best_dice:.4f} - Model saved!')
    else:
        early_stop_counter += 1

    # ----------------- EARLY STOPPING ----------------- #
    if early_stop_counter >= patience:
        print('Early stopping triggered!')
        break

print(f'Training completed. Best validation Dice: {best_dice:.4f}')

In [None]:
# Per-epoch learning curves recorded by the training loop; saved to disk and shown inline.
epochs = range(1, len(train_losses) + 1)


def _plot_curves(ax, x, series, title, ylabel):
    """Draw each ``(values, fmt, label)`` entry of `series` against `x` on `ax`."""
    for values, fmt, label in series:
        ax.plot(x, values, fmt, label=label)
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


fig, axs = plt.subplots(2, 3, figsize=(18, 12))

_plot_curves(axs[0, 0], epochs,
             [(train_losses, 'b-', 'Train Loss'), (val_losses, 'r-', 'Val Loss')],
             'Loss Curves', 'Loss')
_plot_curves(axs[0, 1], epochs,
             [(train_dices, 'b-', 'Train Dice'), (val_dices, 'r-', 'Val Dice')],
             'Dice Coefficient Curves', 'Dice')
_plot_curves(axs[0, 2], epochs, [(val_ious, 'g-', 'Val IoU')], 'IoU Curve', 'IoU')
_plot_curves(axs[1, 0], epochs, [(val_sens_list, 'm-', 'Val Sensitivity')],
             'Sensitivity Curve', 'Sensitivity')
_plot_curves(axs[1, 1], epochs,
             [(val_dices, 'r-', 'Val Dice'),
              (val_ious, 'g-', 'Val IoU'),
              (val_sens_list, 'm-', 'Val Sensitivity')],
             'Validation Metrics Comparison', 'Score')
# Only five panels are used; hide the sixth instead of drawing an empty frame.
axs[1, 2].axis('off')

plt.tight_layout()
plt.savefig('training_metrics_binary.png')
plt.show()

In [None]:
# Qualitative check: predictions of the model as it stands after the training loop,
# on a few random validation images.
num_samples = 4
indices = random.sample(range(len(val_dataset)), min(num_samples, len(val_dataset)))

# mean/std passed to A.Normalize in the transforms. They are inverted inline so this
# cell does not depend on `denormalize_image`, which is only defined in a later cell.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

model.eval()
with torch.no_grad():
    for k, idx in enumerate(indices):
        image, mask = val_dataset[idx]

        output = model(image.unsqueeze(0).to(device))  # add batch dimension
        pred = torch.argmax(output, dim=1)[0].cpu().numpy()

        # Undo normalisation on the first three channels and clamp to a displayable range.
        orig = torch.clamp(image[:3].cpu() * norm_std + norm_mean, 0, 1).permute(1, 2, 0).numpy()

        # plot
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(orig)
        axs[0].set_title(f"Original Image #{k+1}")
        axs[0].axis('off')

        axs[1].imshow(mask.cpu().numpy(), cmap='gray')
        axs[1].set_title("Ground Truth")
        axs[1].axis('off')

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Prediction")
        axs[2].axis('off')

        plt.show()

In [None]:
# -----------------------
# DICE FUNCTION
# -----------------------
def dice_score(pred, target, eps=1e-6):
    """Dice coefficient between two binary masks.

    Args:
        pred, target: NumPy arrays or torch tensors of the same shape holding 0/1 values.
        eps: smoothing term; two empty masks score 1.0.

    Returns:
        Dice as a Python float.
    """
    # as_tensor avoids an extra copy for NumPy input and, unlike torch.tensor, does not
    # warn when it is handed an existing tensor.
    pred = torch.as_tensor(pred).float()
    target = torch.as_tensor(target).float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()

    dice = (2 * intersection + eps) / (union + eps)
    return dice.item()


# -----------------------
# LOAD MODEL
# -----------------------
# map_location lets a checkpoint written on GPU be restored on a CPU-only runtime.
checkpoint = torch.load('best_lma_unet_binary.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inverse normalization (same mean/std as A.Normalize in the transforms)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

def denormalize_image(img_tensor):
    """Undo the normalisation of a channel-first image tensor for display.

    Only the first three channels are de-normalised. Returns a detached CPU copy
    clamped to [0, 1]; the input is not modified.
    """
    img = img_tensor.detach().cpu().clone()
    img[:3] = img[:3] * std + mean
    return torch.clamp(img, 0, 1)


# -----------------------
# VISUALIZATION + DICE
# -----------------------
num_samples = 4
indices = random.sample(range(len(val_dataset)), min(num_samples, len(val_dataset)))

with torch.no_grad():
    for k, idx in enumerate(indices):
        image, mask = val_dataset[idx]   # no batch dimension, CPU tensors

        output = model(image.unsqueeze(0).to(device))
        pred = torch.argmax(output, dim=1)[0].cpu().numpy()

        # Ground truth stays on the CPU; no device round-trip needed.
        gt = mask.numpy()

        # Per-image Dice between the argmax prediction and the ground truth
        dice = dice_score(pred, gt)

        # Denormalized original
        orig = denormalize_image(image)[:3].permute(1,2,0).numpy()

        # ---- PLOT ----
        fig, axs = plt.subplots(1, 3, figsize=(15,5))

        axs[0].imshow(orig)
        axs[0].set_title(f"Original Image #{k+1}")
        axs[0].axis("off")

        axs[1].imshow(gt, cmap="gray")
        axs[1].set_title("Ground Truth")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap="gray")
        axs[2].set_title(f"Prediction\nDice: {dice:.4f}")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()


In [None]:
# Rough inference-latency benchmark on one validation batch.
# Warm-up passes run first so one-time start-up costs are not timed, several timed passes
# are averaged, and the per-image figure uses the real batch size. CUDA events are used on
# GPU; a wall-clock timer is the fallback on CPU, where torch.cuda.Event is unavailable.
N_WARMUP, N_TIMED = 3, 10

model.eval()
with torch.no_grad():
    sample_images = next(iter(val_loader))[0].to(device)
    batch_size = sample_images.shape[0]

    for _ in range(N_WARMUP):
        _ = model(sample_images)

    if device.type == 'cuda':
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(N_TIMED):
            _ = model(sample_images)
        end.record()
        torch.cuda.synchronize()  # events must complete before elapsed_time is read
        inference_time_ms = start.elapsed_time(end) / N_TIMED
    else:
        t0 = time.perf_counter()
        for _ in range(N_TIMED):
            _ = model(sample_images)
        inference_time_ms = (time.perf_counter() - t0) * 1000.0 / N_TIMED

    # inference_time_ms is the mean time for one forward pass over the whole batch.
    per_image_ms = inference_time_ms / batch_size
    print(f'Average inference time per image (batch={batch_size}): {per_image_ms:.2f} ms')
    print(f'Measured throughput: {1000.0 / max(per_image_ms, 1e-9):.1f} images/s on {device.type}')