In [None]:
# Import Necessary Libraries

import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, Model, backend as K

In [None]:
# Data Loading

# Folder containing the BraTS2020 .h5 files. Defaults to the Kaggle input path;
# set the BRATS_DATA_PATH environment variable to run the notebook elsewhere.
DATA_PATH = os.environ.get(
    "BRATS_DATA_PATH",
    "/kaggle/input/brats2020-training-data/BraTS2020_training_data/content/data",
)
# Maximum number of files (in sorted filename order) to load; None loads all of them.
MAX_FILES = 2000

def safe_load_h5(path):
    """Read the 'image' and 'mask' datasets of one HDF5 file as float32 arrays.

    Args:
        path: Path to the .h5 file.

    Returns:
        ``(image, mask)`` on success. If the file cannot be opened or either
        dataset is missing/unreadable, the problem is printed and
        ``(None, None)`` is returned so the caller can skip the file.
    """
    try:
        with h5py.File(path, 'r') as f:
            img = np.array(f['image']).astype(np.float32)
            msk = np.array(f['mask']).astype(np.float32)
    except Exception as e:
        # Best-effort loading: one corrupt file must not abort the whole dataset
        # load, but the failure is reported instead of being silently swallowed.
        print(f"Skipping unreadable file {path}: {type(e).__name__}: {e}")
        return None, None
    return img, msk

print("Loading dataset...")
files = sorted(f for f in os.listdir(DATA_PATH) if f.endswith('.h5'))[:MAX_FILES]

images = []
masks = []
skipped_files = []  # names of files safe_load_h5 could not read

for fname in tqdm(files, desc="Loading files"):
    fpath = os.path.join(DATA_PATH, fname)
    img, msk = safe_load_h5(fpath)
    # Unreadable files come back as (None, None); check both so a half-loaded
    # pair can never reach the arrays below.
    if img is None or msk is None:
        skipped_files.append(fname)
        continue

    # 2-D arrays gain a trailing channel axis of size 1.
    if img.ndim == 2:
        img = np.expand_dims(img, -1)
    if msk.ndim == 2:
        msk = np.expand_dims(msk, -1)

    # Normalize image: divide by the 99th percentile (all channels pooled) and
    # clip to [0, 1] so a few very bright voxels don't compress the range.
    # If p99 <= 0 the image is only clipped.
    # NOTE(review): if the channels are different MRI modalities, a per-channel
    # percentile may be preferable to one pooled value -- TODO evaluate.
    p99 = np.percentile(img, 99)
    img_norm = img / p99 if p99 > 0 else img
    img_norm = np.clip(img_norm, 0.0, 1.0)

    # Mask processing: clamp values into [0, 1].
    msk = np.clip(msk, 0.0, 1.0)

    images.append(img_norm)
    masks.append(msk)

if skipped_files:
    print(f"Skipped {len(skipped_files)} unreadable file(s), e.g. {skipped_files[:3]}")
if not images:
    # Fail here with a clear message instead of a cryptic shape error in the split.
    raise RuntimeError(f"No readable .h5 files were loaded from {DATA_PATH}")

X = np.array(images, dtype=np.float32)
Y = np.array(masks, dtype=np.float32)
print(f"Loaded shapes -> X: {X.shape}, Y: {Y.shape}")

In [None]:
# Loss Functions

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Dice coefficient pooled over every element of the batch.

    Both inputs are flattened, so all samples and channels are scored together
    (no per-sample averaging):
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Args:
        y_true: Ground-truth mask (tensor or array); cast to ``y_pred``'s dtype.
        y_pred: Predictions, either probabilities (soft Dice) or 0/1 masks.
        smooth: Small constant that keeps the ratio defined; when both inputs
            are all zero the result is 1.0.

    Returns:
        Scalar tensor.
    """
    # Plain TensorFlow ops instead of keras.backend helpers (K.flatten / K.sum),
    # which may be missing from the Keras 3 backend module that tf.keras maps to
    # in newer TensorFlow releases. These ops work with Keras 2 and 3.
    y_pred = tf.convert_to_tensor(y_pred)
    # Match dtypes so a float64 or integer mask cannot break the product below.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    )

def dice_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef``, so perfect overlap gives 0."""
    overlap = dice_coef(y_true, y_pred)
    return 1 - overlap

def combined_loss(y_true, y_pred):
    """Training loss: mean binary cross-entropy plus Dice loss.

    Args:
        y_true: Ground-truth masks.
        y_pred: Sigmoid probabilities from the model.

    Returns:
        Scalar tensor ``mean(BCE) + (1 - dice_coef)``.
    """
    # binary_crossentropy averages over the last axis. reduce_mean then collapses
    # the remaining axes to a scalar. It replaces K.mean for Keras 3 compatibility.
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + dice_loss(y_true, y_pred)

In [None]:
# 3-Channel U-Net

def build_simple_multi_unet(input_shape=(240, 240, 4), num_classes=3):
    """Build a compact two-level U-Net with one sigmoid output per class.

    Encoder: two stages of (two 3x3 ReLU convs, 2x2 max-pool) with 32 and 64
    filters, then a 128-filter bottleneck. The decoder mirrors it with
    upsampling and skip concatenations. The head is a 1x1 conv with
    ``num_classes`` sigmoid channels, so each channel is predicted
    independently.

    Because there are two 2x poolings followed by concatenation with the
    skips, the spatial dimensions of ``input_shape`` should be divisible by 4.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
        num_classes: Number of output channels.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    def double_conv(tensor, filters):
        # Two 3x3 ReLU convolutions; 'same' padding keeps the spatial size.
        tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    inputs = tf.keras.Input(shape=input_shape)

    # Encoder
    skip_hi = double_conv(inputs, 32)
    skip_lo = double_conv(layers.MaxPooling2D(2)(skip_hi), 64)

    # Bottleneck
    bottom = double_conv(layers.MaxPooling2D(2)(skip_lo), 128)

    # Decoder, concatenating the matching encoder features at each scale
    merged_lo = layers.concatenate([layers.UpSampling2D(2)(bottom), skip_lo])
    decoded_lo = double_conv(merged_lo, 64)

    merged_hi = layers.concatenate([layers.UpSampling2D(2)(decoded_lo), skip_hi])
    decoded_hi = double_conv(merged_hi, 32)

    # Per-class sigmoid head (independent, not mutually exclusive channels)
    outputs = layers.Conv2D(num_classes, 1, activation='sigmoid')(decoded_hi)

    return tf.keras.Model(inputs, outputs)

In [None]:
# Train/Val Split

# Fixed random_state so the same samples land in validation on every run.
# NOTE(review): the split is per sample. If several files are slices of the same
# patient volume, near-identical slices can end up in both sets and inflate
# validation scores -- TODO consider a patient-level (grouped) split.
X_train, X_val, y_train, y_val = train_test_split(
    X, Y, test_size=0.18, random_state=42
)
for split_name, (split_x, split_y) in (("Train", (X_train, y_train)), ("Val", (X_val, y_val))):
    print(f"{split_name}: {split_x.shape}, {split_y.shape}")

In [None]:
# Build & Compile Model

print("Building 3-channel U-Net...")
# Seed Python, NumPy and TensorFlow in one call so weight initialisation and
# per-epoch shuffling repeat across runs (some GPU kernels may still be
# nondeterministic).
tf.keras.utils.set_random_seed(42)
model = build_simple_multi_unet(input_shape=X_train.shape[1:], num_classes=3)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[
        # Per-element accuracy at a 0.5 threshold. Keras resolves the bare
        # string 'accuracy' from the output shape. With a custom loss and a
        # 3-channel output it becomes categorical (argmax-over-channels)
        # accuracy, which does not fit independent sigmoid channels.
        # Naming the metric 'accuracy' keeps the logged keys
        # ('accuracy' / 'val_accuracy') unchanged.
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        dice_coef,
    ],
)

model.summary()

In [None]:
# Callbacks

# All three callbacks watch validation Dice, where higher is better.
_dice_monitor = dict(monitor='val_dice_coef', mode='max')

callbacks = [
    # Halve the learning rate after 10 epochs without improvement (floor 1e-7).
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=10, min_lr=1e-7, **_dice_monitor),
    # Stop after 20 stagnant epochs and roll back to the best weights seen.
    tf.keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True, **_dice_monitor),
    # Write the model to disk only when the monitored score improves.
    tf.keras.callbacks.ModelCheckpoint('best_multi_model.h5', save_best_only=True, **_dice_monitor),
]

In [None]:
# Training

print("Starting training with 3-channel model...")
# Up to 100 epochs of mini-batches of 8; the callbacks may end training earlier.
fit_settings = dict(
    batch_size=8,
    epochs=100,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    shuffle=True,
    verbose=1,
)
history = model.fit(X_train, y_train, **fit_settings)

In [None]:
# Evaluation Metrics

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Enhanced metrics functions
def iou_coef(y_true, y_pred, smooth=1e-6):
    """Intersection over Union (Jaccard index), averaged over the batch.

    Sums are taken over axes 1-3, giving one IoU per item along axis 0, so
    inputs are expected to be rank 4. An item whose true and predicted masks
    are both empty scores 1.0 because of ``smooth``.
    """
    per_item_axes = [1, 2, 3]
    overlap = tf.reduce_sum(tf.abs(y_true * y_pred), axis=per_item_axes)
    total = tf.reduce_sum(y_true, axis=per_item_axes) + tf.reduce_sum(y_pred, axis=per_item_axes)
    return tf.reduce_mean((overlap + smooth) / (total - overlap + smooth))

def precision_metric(y_true, y_pred):
    """Precision TP / (TP + FP), pooled over the whole batch.

    Values are clipped to [0, 1] and rounded to 0/1 before counting.
    """
    eps = tf.keras.backend.epsilon()
    hit_count = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0, 1)))
    flagged_count = tf.reduce_sum(tf.round(tf.clip_by_value(y_pred, 0, 1)))
    return hit_count / (flagged_count + eps)

def recall_metric(y_true, y_pred):
    """Recall (sensitivity) TP / (TP + FN), pooled over the whole batch.

    Values are clipped to [0, 1] and rounded to 0/1 before counting.
    """
    eps = tf.keras.backend.epsilon()
    hit_count = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0, 1)))
    actual_count = tf.reduce_sum(tf.round(tf.clip_by_value(y_true, 0, 1)))
    return hit_count / (actual_count + eps)

def specificity_metric(y_true, y_pred):
    """Specificity TN / (TN + FP), pooled over the whole batch.

    Negatives are taken as ``1 - y``; values are clipped to [0, 1] and rounded
    to 0/1 before counting.
    """
    not_true = 1 - y_true
    correct_neg = tf.reduce_sum(tf.round(tf.clip_by_value(not_true * (1 - y_pred), 0, 1)))
    all_neg = tf.reduce_sum(tf.round(tf.clip_by_value(not_true, 0, 1)))
    return correct_neg / (all_neg + tf.keras.backend.epsilon())

def f1_score(y_true, y_pred):
    """F1 score: harmonic mean of ``precision_metric`` and ``recall_metric``."""
    p = precision_metric(y_true, y_pred)
    r = recall_metric(y_true, y_pred)
    return 2 * ((p * r) / (p + r + tf.keras.backend.epsilon()))

# Enhanced evaluation function
def _pixel_confusion_counts(y_true_bin, y_pred_bin):
    """Return (tn, fp, fn, tp) counted over every element of two same-shaped masks."""
    truth = np.asarray(y_true_bin, dtype=bool).ravel()
    pred = np.asarray(y_pred_bin, dtype=bool).ravel()
    tp = int(np.count_nonzero(truth & pred))
    fp = int(np.count_nonzero(~truth & pred))
    fn = int(np.count_nonzero(truth & ~pred))
    tn = int(truth.size) - tp - fp - fn
    return tn, fp, fn, tp


def _safe_ratio(num, den):
    """Return num / den, or 0 when the denominator is not positive (metric undefined)."""
    return num / den if den > 0 else 0


def comprehensive_evaluation(model, X_val, y_val, threshold=0.5):
    """Evaluate a segmentation model on a validation set and print a report.

    Predictions are thresholded at ``threshold``. Dice, IoU, F1 and the pixel
    confusion counts are all computed from those thresholded masks, so the
    numbers are mutually consistent. The soft Dice on raw probabilities is
    also reported.

    Args:
        model: Compiled Keras model with ``predict`` / ``evaluate``.
        X_val: Validation inputs.
        y_val: Validation masks; binarized at 0.5 for the confusion counts.
        threshold: Probability cut-off for a positive prediction.

    Returns:
        Dict with ``loss``, ``dice`` (thresholded), ``soft_dice``,
        ``per_class_dice`` (one value per output channel), ``iou``,
        ``f1_score``, ``accuracy``, ``precision``, ``recall`` and
        ``specificity``.
    """
    print("=" * 70)
    print("COMPREHENSIVE MODEL EVALUATION - MEDICAL SEGMENTATION")
    print("=" * 70)

    # Probabilities, then hard predictions at the chosen threshold.
    y_pred = model.predict(X_val, verbose=0)
    y_pred_binary = (y_pred > threshold).astype(np.float32)
    # Binarize ground truth too: stray non-0/1 mask values would otherwise act
    # as extra classes in the confusion counts.
    y_true_binary = y_val > 0.5

    # evaluate() returns a list when metrics are compiled, a bare scalar otherwise.
    eval_out = model.evaluate(X_val, y_val, verbose=0)
    val_loss = float(eval_out[0] if isinstance(eval_out, (list, tuple)) else eval_out)

    y_true_t = tf.convert_to_tensor(y_val)
    # Dice on the thresholded masks (same predictions as IoU / F1) ...
    overall_dice = dice_coef(y_true_t, tf.convert_to_tensor(y_pred_binary)).numpy()
    # ... and soft Dice on the probabilities, as the training metric computes it.
    soft_dice = dice_coef(y_true_t, tf.convert_to_tensor(y_pred)).numpy()
    overall_iou = iou_coef(y_true_t, tf.convert_to_tensor(y_pred_binary)).numpy()

    # Thresholded Dice per output channel (last axis), pooled over all other axes.
    other_axes = tuple(range(y_pred_binary.ndim - 1))
    inter_c = np.sum(y_true_binary * y_pred_binary, axis=other_axes)
    denom_c = np.sum(y_true_binary, axis=other_axes) + np.sum(y_pred_binary, axis=other_axes)
    per_class_dice = (2.0 * inter_c + 1e-6) / (denom_c + 1e-6)

    # Counting with NumPy works even when one class is absent; the original
    # confusion_matrix(...).ravel() unpacking failed in that case.
    tn, fp, fn, tp = _pixel_confusion_counts(y_true_binary, y_pred_binary)

    accuracy = _safe_ratio(tp + tn, tp + tn + fp + fn)
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)

    print(f"\nOVERALL PERFORMANCE METRICS:")
    print(f"   Loss: {val_loss:.4f}")
    print(f"   Dice Coefficient (thresholded): {overall_dice:.4f}")
    print(f"   Soft Dice (probabilities): {soft_dice:.4f}")
    for channel, channel_dice in enumerate(per_class_dice):
        print(f"   Dice (channel {channel}): {channel_dice:.4f}")
    print(f"   IoU (Jaccard): {overall_iou:.4f}")
    print(f"   F1-Score: {f1:.4f}")

    print(f"\nCLASSIFICATION METRICS:")
    print(f"   Accuracy: {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall (Sensitivity): {recall:.4f}")
    print(f"   Specificity: {specificity:.4f}")

    print(f"\nCONFUSION MATRIX (Pixel-wise):")
    print(f"   True Negatives: {tn:,}")
    print(f"   False Positives: {fp:,}")
    print(f"   False Negatives: {fn:,}")
    print(f"   True Positives: {tp:,}")

    return {
        'loss': val_loss,
        'dice': overall_dice,
        'soft_dice': soft_dice,
        'per_class_dice': per_class_dice,
        'iou': overall_iou,
        'f1_score': f1,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity
    }


print("PERFORMING COMPREHENSIVE EVALUATION...")

# Every custom callable the saved model may reference, keyed by its function name.
_custom_fns = (
    combined_loss, dice_coef, dice_loss, iou_coef,
    precision_metric, recall_metric, specificity_metric, f1_score,
)

# Load the checkpoint written by ModelCheckpoint during training.
print("Loading best model with enhanced metrics...")
best_model = tf.keras.models.load_model(
    'best_multi_model.h5',
    custom_objects={fn.__name__: fn for fn in _custom_fns},
)

# Print the full report and keep the metrics dict.
results = comprehensive_evaluation(best_model, X_val, y_val)

In [None]:
# ==================== SIMPLE SIDE-BY-SIDE VISUALIZATION ====================

def visualize_comparison(model, X_data, y_true, num_samples=3, threshold=0.5, seed=None):
    """Plot random samples side by side: input MRI, ground truth, prediction.

    Each figure shows one input channel as a grayscale background. The union of
    all mask channels is overlaid for the ground truth and for the thresholded
    prediction. The Dice score in the title is computed on the same thresholded
    prediction that is displayed.

    Args:
        model: Keras model whose ``predict`` returns per-channel probabilities.
        X_data: Inputs, indexable by sample.
        y_true: Ground-truth masks aligned with ``X_data``.
        num_samples: How many samples to draw; capped at ``len(X_data)``.
        threshold: Probability cut-off for the predicted mask.
        seed: Optional seed for sample selection. If None, NumPy's global RNG
            is used, so a global ``np.random.seed`` still applies.
    """
    n_show = min(num_samples, len(X_data))
    if n_show <= 0:
        print("No samples to visualize.")
        return

    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.choice(len(X_data), n_show, replace=False)

    # One batched forward pass instead of one predict() call per sample.
    preds = model.predict(np.stack([X_data[j] for j in indices]), verbose=0)

    for i, (idx, y_pred) in enumerate(zip(indices, preds)):
        X_sample = X_data[idx]
        y_true_sample = y_true[idx]
        y_pred_binary = (y_pred > threshold).astype(np.float32)

        # Channel 0 is the background image (earlier notes call it FLAIR;
        # the channel order is not verifiable here -- TODO confirm).
        background = X_sample[:, :, 0]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # 1. Original MRI
        axes[0].imshow(background, cmap='gray')
        axes[0].set_title('Input MRI')

        # 2. Ground truth: union of all mask channels
        gt_combined = np.any(y_true_sample > 0, axis=-1)
        axes[1].imshow(background, cmap='gray')
        axes[1].imshow(gt_combined, cmap='Reds', alpha=0.5)
        axes[1].set_title('Ground Truth')

        # 3. Prediction: union of all thresholded channels
        pred_combined = np.any(y_pred_binary > 0, axis=-1)
        axes[2].imshow(background, cmap='gray')
        axes[2].imshow(pred_combined, cmap='Reds', alpha=0.5)

        # Dice of the displayed (thresholded) mask, so title and overlay agree.
        dice = dice_coef(
            tf.convert_to_tensor(y_true_sample[np.newaxis, ...]),
            tf.convert_to_tensor(y_pred_binary[np.newaxis, ...])
        ).numpy()
        axes[2].set_title(f'Prediction\nDice: {dice:.3f}')

        for ax in axes:
            ax.axis('off')

        fig.suptitle(f'Sample {i+1}', fontsize=14, fontweight='bold')
        fig.tight_layout()
        plt.show()
        # Free the figure so memory doesn't grow when plotting many samples.
        plt.close(fig)

# Execute
# Render three random validation samples with the reloaded best model.
print("Generating side-by-side visualizations...")
visualize_comparison(model=best_model, X_data=X_val, y_true=y_val, num_samples=3)