We trained the final versions of the models with the best hyperparameters we observed.

In [None]:
# Working U-Net code (final version)
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from skimage import io
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb  # For logging
from sklearn.metrics import jaccard_score  # For IoU in evaluation

# WandB setup: authenticate beforehand (`wandb login` or the WANDB_API_KEY
# environment variable), then start one run for this script.
#wandb.login(key='')  # Or export WANDB_API_KEY
wandb.init(project='Zuliani1_Marchetto2', name='efficientnet-b2_U-Net_MRI_Segmentation')  # Customize name

# 1. Data Preparation
# Every sub-folder of `data_dir` is one case holding `<slice>.tif` images and
# their `<slice>_mask.tif` masks side by side.
data_dir = '/home/stud/fmarchetto/SegmentationTests/Data/kaggle_3m/'
image_paths = []
mask_paths = []

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the row order of `data` -- and therefore any seeded split of it --
# reproducible across machines.
for case_folder in sorted(os.listdir(data_dir)):
    case_path = os.path.join(data_dir, case_folder)
    if not os.path.isdir(case_path):
        continue
    for file in sorted(os.listdir(case_path)):
        if file.endswith('.tif') and not file.endswith('_mask.tif'):
            image_paths.append(os.path.join(case_path, file))
            # Swap only the extension: str.replace would also rewrite any
            # other '.tif' occurring inside the file name.
            mask_file = os.path.splitext(file)[0] + '_mask.tif'
            mask_paths.append(os.path.join(case_path, mask_file))

print(f"Total Images: {len(image_paths)}")
print(f"Total Masks: {len(mask_paths)}")

# Create DataFrame
data = pd.DataFrame({
    "image_path": image_paths,
    "mask_path": mask_paths
})

# Add tumor status: a slice is "Tumor" when its mask has any non-zero pixel.
# Masks OpenCV cannot read (imread returns None) are counted as "Non-Tumor".
tumor_status = []
for mask_path in tqdm(data["mask_path"], desc="Checking masks"):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None or np.max(mask) == 0:
        tumor_status.append("Non-Tumor")
    else:
        tumor_status.append("Tumor")

data["status"] = tumor_status
print(data["status"].value_counts())

# Plot tumor vs non-tumor count
counts = data["status"].value_counts()
plt.figure(figsize=(6,4))
# value_counts orders bars by frequency, so choose each bar's colour from its
# label instead of its position (a fixed list would colour by rank).
bars = plt.bar(
    counts.index,
    counts.values,
    color=["red" if label == "Tumor" else "gray" for label in counts.index],
)
plt.title("Tumor vs Non-Tumor Image Count", fontsize=14)
plt.xlabel("Category")
plt.ylabel("Number of Images")
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.bar_label(bars, padding=3)
plt.show()

# Preprocessing function (resize to IMG_SIZE x IMG_SIZE, normalize)
IMG_SIZE = 128
def process_image(image_path, mask_path, img_size=IMG_SIZE):
    """Load one slice and its mask, resize both and scale them for training.

    Args:
        image_path: path of the slice image. It is read with ``cv2.imread``'s
            default colour mode, so channels come back in BGR order.
        mask_path: path of the mask, read as single-channel grayscale.
        img_size: side length in pixels of the square outputs.

    Returns:
        ``(image, mask)`` as float32 arrays: ``image`` scaled to [0, 1], and
        ``mask`` binarised to {0, 1} and reshaped to (img_size, img_size, 1).

    Raises:
        FileNotFoundError: if OpenCV cannot read either file (``cv2.imread``
            signals this by returning ``None`` instead of raising).
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    image = cv2.resize(image, (img_size, img_size))
    mask = cv2.resize(mask, (img_size, img_size))

    image = image / 255.0  # Normalize to [0, 1]
    mask = mask / 255.0
    # Resizing interpolates mask values; re-binarise at the midpoint.
    mask = np.where(mask > 0.5, 1, 0)

    return image.astype(np.float32), mask.astype(np.float32).reshape(img_size, img_size, 1)

# Load every (image, mask) pair listed in `data` and stack them into arrays.
processed_pairs = [
    process_image(image_file, mask_file)
    for image_file, mask_file in zip(data["image_path"], data["mask_path"])
]
images = np.array([pair[0] for pair in processed_pairs])
masks = np.array([pair[1] for pair in processed_pairs])
del processed_pairs  # the stacked arrays now hold the data

print("Images shape:", images.shape)
print("Masks shape:", masks.shape)

# Split data 80/10/10: carve off 20% as a hold-out, then halve it into val/test.
# NOTE(review): slices are split individually, so slices of the same patient
# (case folder) can end up in both train and test; a patient-level split
# grouped by case folder would give a less optimistic test estimate.
X_train, X_test, y_train, y_test = train_test_split(images, masks, random_state=42, test_size=0.2)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, random_state=42, test_size=0.5)

print("Train:", X_train.shape, y_train.shape)
print("Val:", X_val.shape, y_val.shape)
print("Test:", X_test.shape, y_test.shape)

# Visualize samples (adapted from Kaggle)
def visualize_samples(images, masks, n_samples=4):
    """Show random tumor (top row) and non-tumor (bottom row) slices with their mask overlaid.

    Args:
        images: array of slices shaped (N, H, W, 3), assumed to be in the BGR
            channel order produced by ``cv2.imread``.
        masks: array of masks shaped (N, H, W, 1); a slice counts as "tumor"
            when any mask pixel is > 0.
        n_samples: maximum number of slices drawn per row.
    """
    has_tumor = np.array([np.any(m > 0) for m in masks], dtype=bool)
    tumor_indices = np.flatnonzero(has_tumor)
    non_tumor_indices = np.flatnonzero(~has_tumor)

    tumor_samples = np.random.choice(tumor_indices, min(n_samples, len(tumor_indices)), replace=False)
    non_tumor_samples = np.random.choice(non_tumor_indices, min(n_samples, len(non_tumor_indices)), replace=False)

    # squeeze=False keeps `axs` 2-D even when n_samples == 1.
    fig, axs = plt.subplots(2, n_samples, figsize=(12, 8), squeeze=False)
    rows = (
        (0, tumor_samples, "Tumor", 0.4),
        (1, non_tumor_samples, "No Tumor", 0.3),
    )
    for row, sample_indices, title, alpha in rows:
        for col, idx in enumerate(sample_indices):
            ax = axs[row, col]
            # Reverse BGR -> RGB so matplotlib displays the channels as loaded.
            ax.imshow(images[idx][..., ::-1])
            ax.imshow(masks[idx].squeeze(), cmap="Reds", alpha=alpha)
            ax.title.set_text(title)
            ax.axis("off")
        # Blank out leftover axes when a class has fewer than n_samples slices.
        for col in range(len(sample_indices), n_samples):
            axs[row, col].axis("off")

    plt.tight_layout()
    plt.show()

visualize_samples(X_test, y_test, n_samples=5)

# Training-time augmentation with Albumentations (the PyTorch-side counterpart
# of Keras' ImageDataGenerator); only the training dataset uses it.
data_gen_args = A.Compose([
    A.Rotate(limit=10, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=0, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    # transpose_mask=True turns the (H, W, 1) mask into (1, H, W), the same
    # layout the un-augmented validation/test samples have.
    ToTensorV2(transpose_mask=True),
])

# Dataset wrapper consumed by the PyTorch DataLoaders
class MRI_Dataset(Dataset):
    """In-memory segmentation dataset yielding channel-first ``(image, mask)`` tensors.

    Args:
        images: indexable collection of numpy arrays shaped (H, W, C).
        masks: indexable collection of numpy arrays shaped (H, W, 1).
        transform: optional Albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries (tensors when the pipeline ends
            with ``ToTensorV2``).
    """

    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, mask)`` as float tensors, with ``mask`` shaped (1, H, W)."""
        img = self.images[idx]
        mask = self.masks[idx]

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = torch.as_tensor(augmented['image'])
            mask = torch.as_tensor(augmented['mask'])
        else:
            # Manual conversion: numpy (H, W, C) -> tensor (C, H, W).
            img = torch.from_numpy(img.transpose(2, 0, 1))
            mask = torch.from_numpy(mask.transpose(2, 0, 1))

        return img.float(), self._mask_to_channel_first(mask).float()

    @staticmethod
    def _mask_to_channel_first(mask):
        """Coerce a mask tensor to (1, H, W) regardless of the layout it arrives in.

        ToTensorV2 with its default ``transpose_mask=False`` leaves an (H, W, 1)
        mask as (H, W, 1), while the no-transform branch yields (1, H, W);
        normalising here keeps training and evaluation batches in one layout.
        """
        if mask.ndim == 2:
            return mask.unsqueeze(0)
        if mask.ndim == 3 and mask.shape[0] != 1 and mask.shape[-1] == 1:
            return mask.permute(2, 0, 1)
        return mask

# Create loaders: augmentation is applied to the training split only.
batch_size = 16
train_dataset = MRI_Dataset(X_train, y_train, transform=data_gen_args)
val_dataset = MRI_Dataset(X_val, y_val)    # no augmentation
test_dataset = MRI_Dataset(X_test, y_test)  # no augmentation

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation loaders are iterated in a fixed (unshuffled) order.
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

# 3. Model definition: U-Net with an ImageNet-pretrained EfficientNet-B2 encoder
# from segmentation_models_pytorch ("resnet34" was the earlier encoder choice).
# activation=None makes the network emit raw logits, which is what
# BCEWithLogitsLoss and the sigmoid-based metrics expect.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = smp.Unet(
    encoder_name="efficientnet-b2",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None,
).to(device)

# Loss (BCE + Dice, like Kaggle)
class BCE_Dice_Loss(nn.Module):
    """Binary cross-entropy on logits plus a soft Dice loss.

    ``forward(pred, target)``: ``pred`` holds raw logits shaped (B, 1, H, W);
    ``target`` is a binary mask shaped (B, 1, H, W), (B, H, W, 1) or (B, H, W)
    and is rearranged to (B, 1, H, W) and cast to ``pred``'s dtype first.
    The Dice term is computed per sample/channel over (H, W) and averaged.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    @staticmethod
    def _match_layout(target):
        """Bring ``target`` to (B, 1, H, W)."""
        # Test the rank first: a 3-D (B, H, W) target also has shape[1] != 1,
        # and permuting it with four axes would raise.
        if target.ndim == 3:
            return target.unsqueeze(1)
        if target.ndim == 4 and target.shape[1] != 1:  # (B, H, W, 1)
            return target.permute(0, 3, 1, 2)
        return target

    def forward(self, pred, target):
        # BCEWithLogitsLoss requires a floating-point target of pred's dtype.
        target = self._match_layout(target).to(pred.dtype)

        bce = self.bce(pred, target)
        pred_sig = torch.sigmoid(pred)
        intersection = (pred_sig * target).sum(dim=(2, 3))
        dice = (2 * intersection + 1e-6) / (pred_sig.sum(dim=(2, 3)) + target.sum(dim=(2, 3)) + 1e-6)
        dice_loss = 1 - dice.mean()
        return bce + dice_loss

# Metrics (Dice and IoU, like Kaggle)
def dice_coef(y_true, y_pred, threshold=0.5, smooth=1):
    """Dice coefficient between a binary mask and thresholded model predictions.

    Both tensors are flattened and compared element-wise, so they must hold
    the same number of elements listed in the same pixel order (e.g.
    (B, 1, H, W) and (B, H, W, 1) both flatten in b, h, w order).

    Args:
        y_true: ground-truth binary mask.
        y_pred: raw logits; a sigmoid is applied and the result binarised at ``threshold``.
        threshold: probability cut-off.
        smooth: additive smoothing in numerator and denominator (keeps 0/0 at 1).

    Returns:
        0-dim tensor with the Dice score pooled over every element given.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted masks.
    y_true = y_true.reshape(-1).float()
    y_pred = (torch.sigmoid(y_pred) > threshold).float().reshape(-1)
    intersection = (y_true * y_pred).sum()
    return (2. * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)

def iou_coef(y_true, y_pred, threshold=0.5, smooth=1):
    """Intersection-over-Union between a binary mask and thresholded model predictions.

    Both tensors are flattened and compared element-wise, so they must hold
    the same number of elements listed in the same pixel order.

    Args:
        y_true: ground-truth binary mask.
        y_pred: raw logits; a sigmoid is applied and the result binarised at ``threshold``.
        threshold: probability cut-off.
        smooth: additive smoothing in numerator and denominator (keeps 0/0 at 1).

    Returns:
        0-dim tensor with the IoU pooled over every element given.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted masks.
    y_true = y_true.reshape(-1).float()
    y_pred = (torch.sigmoid(y_pred) > threshold).float().reshape(-1)
    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + smooth) / (union + smooth)

# Accuracy metric (pixel-wise)
def accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of elements where the thresholded prediction equals the mask.

    Both tensors are flattened and compared element-wise, so they must hold
    the same number of elements listed in the same pixel order.

    Args:
        y_true: ground-truth binary mask.
        y_pred: raw logits; a sigmoid is applied and the result binarised at ``threshold``.
        threshold: probability cut-off.

    Returns:
        0-dim tensor with the pixel accuracy in [0, 1].
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted masks.
    y_true = y_true.reshape(-1)
    y_pred = (torch.sigmoid(y_pred) > threshold).float().reshape(-1)
    correct = (y_pred == y_true).float()
    return correct.mean()

# Objective and optimiser: BCE + Dice loss, Adam at lr=1e-4 over all model weights.
loss_fn = BCE_Dice_Loss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Callbacks (PyTorch equivalents of the Keras ones)
class EarlyStopping:
    """Signal a stop once the monitored loss has stalled for ``patience`` calls.

    A call counts as an improvement only when the loss drops below the best
    value seen so far minus ``min_delta``; improvements reset the counter.
    """

    def __init__(self, patience=8, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = np.inf

    def __call__(self, val_loss):
        """Record ``val_loss`` and return True when training should stop."""
        if not val_loss < self.best_loss - self.min_delta:
            self.counter += 1
            return self.counter >= self.patience
        self.best_loss = val_loss
        self.counter = 0
        return False

# Stop after 8 epochs without val-loss improvement; halve the LR after 3.
early_stop = EarlyStopping(patience=8)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-7)

checkpoint_path = "best_unetEfficentNet2_model.pth"  # Save best model

# Training Loop with WandB Logging
epochs = 50
best_val_loss = np.inf
# Per-epoch metrics: one entry per completed epoch, plotted after training.
history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'val_iou': [], 'train_acc': [], 'val_acc': []}

for epoch in range(epochs):
    # Training pass: one optimiser step per batch; metrics are batch means.
    model.train()
    train_loss = 0
    train_acc = 0
    for img, mask in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Train"):
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        pred = model(img)
        loss = loss_fn(pred, mask)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += accuracy(mask, pred).item()

    train_loss /= len(train_loader)
    train_acc /= len(train_loader)

    # Validation pass: no gradients; metrics are averages of per-batch values.
    model.eval()
    val_loss = 0
    val_dice = 0
    val_iou = 0
    val_acc = 0
    with torch.no_grad():
        for img, mask in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Val"):
            img, mask = img.to(device), mask.to(device)
            pred = model(img)
            val_loss += loss_fn(pred, mask).item()
            val_dice += dice_coef(mask, pred).item()
            val_iou += iou_coef(mask, pred).item()
            val_acc += accuracy(mask, pred).item()

    val_loss /= len(val_loader)
    val_dice /= len(val_loader)
    val_iou /= len(val_loader)
    val_acc /= len(val_loader)

    # Record the epoch so the history plots after training have data.
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    # Log to WandB
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_acc': train_acc,
        'val_acc': val_acc,
        'val_dice': val_dice,
        'val_iou': val_iou
    })

    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Dice = {val_dice:.4f}, Val IoU = {val_iou:.4f}, Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")

    # Scheduler, checkpoint of the lowest-val-loss weights, early stopping
    scheduler.step(val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)
        wandb.save(checkpoint_path)  # Log artifact to WandB
        print("Saved best model!")

    if early_stop(val_loss):
        print("Early stopping!")
        break

# Finish WandB run
wandb.finish()

# Reload the weights that achieved the lowest validation loss, then evaluate.
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

# Test Evaluation: each metric is the mean of its per-batch values.
test_loss, test_dice, test_iou, test_acc = 0, 0, 0, 0
with torch.no_grad():
    for img, mask in tqdm(test_loader, desc="Test"):
        img = img.to(device)
        mask = mask.to(device)
        pred = model(img)
        test_loss += loss_fn(pred, mask).item()
        test_dice += dice_coef(mask, pred).item()
        test_iou += iou_coef(mask, pred).item()
        test_acc += accuracy(mask, pred).item()

test_loss, test_dice, test_iou, test_acc = (
    running_total / len(test_loader)
    for running_total in (test_loss, test_dice, test_iou, test_acc)
)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Dice Coefficient: {test_dice:.4f}")
print(f"Test IoU Coefficient: {test_iou:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Visualization (adapted from Kaggle)
def visualize_predictions(model, images, masks, num_samples=5, threshold=0.5):
    """Plot random slices next to their true and predicted masks, one slice per row.

    Args:
        model: segmentation network returning one logit map per input image.
        images: float32 array shaped (N, H, W, 3), assumed to be in the BGR
            channel order produced by ``cv2.imread``.
        masks: ground-truth masks shaped (N, H, W, 1).
        num_samples: number of rows to draw (capped at ``len(images)``).
        threshold: probability cut-off applied after the sigmoid.
    """
    model.eval()
    num_samples = min(num_samples, len(images))
    if num_samples == 0:
        return
    # Without replacement so the same slice is never drawn twice.
    indices = np.random.choice(len(images), num_samples, replace=False)
    # squeeze=False keeps `axs` 2-D even when num_samples == 1.
    fig, axs = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3), squeeze=False)

    with torch.no_grad():
        for row, idx in enumerate(indices):
            img = torch.from_numpy(images[idx]).permute(2, 0, 1).unsqueeze(0).to(device)  # [1, 3, H, W]
            true_mask = masks[idx].squeeze()
            # The network outputs logits; threshold probabilities, not raw logits.
            prob = torch.sigmoid(model(img)).squeeze().cpu().numpy()
            pred_mask = (prob > threshold).astype(np.uint8)

            panels = (
                (images[idx][..., ::-1], None, "MRI"),  # BGR -> RGB for display
                (true_mask, 'gray', "True Mask"),
                (pred_mask, 'gray', "Predicted Mask"),
            )
            for col, (panel, cmap, title) in enumerate(panels):
                axs[row, col].imshow(panel, cmap=cmap)
                axs[row, col].set_title(title)
                axs[row, col].axis("off")

    plt.tight_layout()
    plt.show()

visualize_predictions(model, X_test, y_test, num_samples=6)

# Optional: slice-level tumor detection derived from the segmentation output
def tumor_detection_accuracy(model, images, masks, threshold=0.5, device=None):
    """Percentage of slices whose "contains tumor" label the model gets right.

    A slice is predicted positive when any pixel's sigmoid probability exceeds
    ``threshold``; it is truly positive when its mask has any value > 0.

    Args:
        model: network mapping a (1, C, H, W) float tensor to per-pixel logits.
        images: sequence of float32 numpy arrays shaped (H, W, C).
        masks: sequence of ground-truth masks, one per image.
        threshold: probability cut-off applied after the sigmoid.
        device: device to run inference on; defaults to the device of the
            model's first parameter (CPU when the model has none).

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``images`` is empty or ``masks`` has a different length.
    """
    total = len(images)
    if total == 0:
        raise ValueError("tumor_detection_accuracy needs at least one image")
    if len(masks) != total:
        raise ValueError("images and masks must have the same length")
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    model.eval()
    with torch.no_grad():
        for image, mask in zip(images, masks):
            img = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)
            # The model emits logits: convert to probabilities before thresholding.
            pred_has_tumor = bool((torch.sigmoid(model(img)) > threshold).any())
            true_has_tumor = bool(np.any(mask > 0))
            correct += pred_has_tumor == true_has_tumor
    return correct / total * 100

detection_acc = tumor_detection_accuracy(model, X_test, y_test)
print(f"Tumor Detection Accuracy: {detection_acc:.2f}%")

# Plot Training History (one row per recorded epoch)
ef = pd.DataFrame(history)
if ef.empty:
    # No epochs were recorded; DataFrame.plot fails on a frame with no data.
    print("Training history is empty; skipping history plots.")
else:
    ef[['train_loss', 'val_loss']].plot(title="Loss")
    ef[['val_dice', 'val_iou', 'val_acc']].plot(title="Metrics")
    plt.show()