# HRF Segmentation U-Net Model Evaluation

This notebook evaluates the trained U-Net and Attention U-Net models on HRF (Hyperreflective Foci) segmentation in retinal OCT images.

**Dataset**: 435 OCT images with expert-annotated HRF masks

**Models**: U-Net and Attention U-Net (PyTorch)

**Authors**: Pavithra Kodiyalbail Chakrapani, Preetham Kumar, Sulatha V Bhandary, Geetha Maiya, Shailaja S, Steven Fernandes, Prakhar Choudhary

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install torch torchvision tqdm scikit-learn matplotlib seaborn opencv-python-headless pillow

In [None]:
# Import libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score, recall_score,
    f1_score, jaccard_score, roc_curve, auc, roc_auc_score
)
from PIL import Image

# Report the PyTorch build and whether a CUDA device is visible to it.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    # Only query the device name when CUDA is actually usable.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive; BASE_DIR in the configuration cell
# points below this mount point. Requires the Colab runtime (`google.colab`).
from google.colab import drive
drive.mount('/content/drive')

## 3. Configuration

**Update these paths to match your Google Drive folder structure before running the cells below.**

In [None]:
# IMPORTANT: Update these paths to match your Google Drive structure
# Every dataset and checkpoint location below is derived from BASE_DIR.
BASE_DIR = '/content/drive/MyDrive/HRF-DATASET'
# Folder holding the input images and folder holding their masks.
IMAGES_DIR = os.path.join(BASE_DIR, 'HRF_IMAGES')
MASKS_DIR = os.path.join(BASE_DIR, 'HRF_MASKS')
# Checkpoint files for the two models evaluated in this notebook.
UNET_MODEL_PATH = os.path.join(BASE_DIR, 'unet_best_model.pth')
AUNET_MODEL_PATH = os.path.join(BASE_DIR, 'aunet_best_model.pth')

# Output directory for results
# Created on the local filesystem (outside Drive); existing contents are kept.
OUTPUT_DIR = '/content/evaluation_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Device configuration
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

## 4. Model Definitions

In [None]:
# U-Net Building Blocks
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden_channels = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and the layer order define the state_dict keys,
        # so both must stay stable for saved checkpoints to load.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv-BN-ReLU stages; padding=1 keeps the spatial size."""
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convolutions = DoubleConv(in_channels, out_channels)
        # Index 0 = pooling, index 1 = DoubleConv (order fixes the checkpoint keys).
        self.maxpool_conv = nn.Sequential(pooling, convolutions)

    def forward(self, x):
        """Downsample ``x`` and convolve it to ``out_channels`` channels."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder stage: upsample the deeper features, concatenate the skip, DoubleConv.

    Args:
        in_channels_deeper: Channels of the feature map coming from below.
        in_channels_skip: Channels of the encoder skip connection.
        out_channels: Channels produced by this stage.
        bilinear: If True, upsample with bilinear interpolation (channel count
            unchanged); otherwise use a transposed convolution that halves
            the channel count.
    """

    def __init__(self, in_channels_deeper, in_channels_skip, out_channels, bilinear=False):
        super().__init__()
        if not bilinear:
            upsampled_channels = in_channels_deeper // 2
            self.up = nn.ConvTranspose2d(in_channels_deeper, upsampled_channels, kernel_size=2, stride=2)
        else:
            upsampled_channels = in_channels_deeper
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels_skip + upsampled_channels, out_channels)

    def forward(self, x_deeper, x_skip):
        """Fuse ``x_deeper`` (after 2x upsampling) with ``x_skip``."""
        upsampled = self.up(x_deeper)
        # Pad the upsampled map so its height/width match the skip connection.
        gap_h = x_skip.size(2) - upsampled.size(2)
        gap_w = x_skip.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top])
        merged = torch.cat([x_skip, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution that maps features to per-class output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the 1x1-convolved map; no activation is applied here."""
        projected = self.conv(x)
        return projected

In [None]:
# U-Net Model
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling stages.

    Args:
        n_channels: Channels of the input image.
        n_classes: Output channels of the final 1x1 convolution.
        bilinear: Passed to every ``Up`` stage to select the upsampling mode.
        base_filters: Channel width of the first level; it doubles at each
            deeper level, up to ``16 * base_filters``.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=False, base_filters=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.base_filters = base_filters

        # Channel widths of the five resolution levels.
        c1, c2, c3, c4, c5 = (base_filters * factor for factor in (1, 2, 4, 8, 16))

        # Encoder
        self.inc = DoubleConv(n_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)

        # Decoder (each stage consumes the skip of the matching encoder level)
        self.up4 = Up(c5, c4, c4, bilinear)
        self.up3 = Up(c4, c3, c3, bilinear)
        self.up2 = Up(c3, c2, c2, bilinear)
        self.up1 = Up(c2, c1, c1, bilinear)

        self.outc = OutConv(c1, n_classes)

    def forward(self, x):
        """Return raw logits; no sigmoid is applied inside the model."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up4(bottleneck, skip4)
        decoded = self.up3(decoded, skip3)
        decoded = self.up2(decoded, skip2)
        decoded = self.up1(decoded, skip1)

        return self.outc(decoded)

In [None]:
# Attention U-Net Components
class AttentionBlock(nn.Module):
    """Attention gate that rescales skip features ``x`` using a gating signal ``g``.

    Args:
        F_g: Channels of the gating signal.
        F_l: Channels of the skip-connection features.
        F_int: Channels of the intermediate projection shared by both inputs.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch, *tail):
            # 1x1 convolution + BatchNorm, optionally followed by extra layers.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
                *tail,
            )

        # Creation order and attribute names are kept so checkpoints still load.
        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = pointwise(F_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by a single-channel attention map in (0, 1)."""
        gate = self.W_g(g)
        skip = self.W_x(x)
        target_hw = skip.size()[2:]
        if gate.size()[2:] != target_hw:
            # Bring the projected gating signal to the skip resolution.
            gate = F.interpolate(gate, size=target_hw, mode='bilinear', align_corners=True)
        attention = self.psi(self.relu(gate + skip))
        return x * attention

class UpAttention(nn.Module):
    """Decoder stage like ``Up`` whose skip connection passes through an attention gate.

    Args:
        in_channels_deeper: Channels of the feature map coming from below.
        in_channels_skip: Channels of the encoder skip connection.
        out_channels: Channels produced by this stage.
        bilinear: If True, upsample with bilinear interpolation; otherwise use
            a transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels_deeper, in_channels_skip, out_channels, bilinear=False):
        super().__init__()
        self.attention = AttentionBlock(
            F_g=in_channels_deeper,
            F_l=in_channels_skip,
            F_int=in_channels_skip // 2,
        )
        if not bilinear:
            upsampled_channels = in_channels_deeper // 2
            self.up = nn.ConvTranspose2d(in_channels_deeper, upsampled_channels, kernel_size=2, stride=2)
        else:
            upsampled_channels = in_channels_deeper
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels_skip + upsampled_channels, out_channels)

    def forward(self, x_deeper, x_skip):
        """Upsample ``x_deeper``, gate ``x_skip`` with it, then fuse both."""
        upsampled = self.up(x_deeper)
        # Pad the upsampled map so its height/width match the skip connection.
        gap_h = x_skip.size(2) - upsampled.size(2)
        gap_w = x_skip.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top])
        # The gate uses the deeper features before upsampling as its signal.
        gated_skip = self.attention(g=x_deeper, x=x_skip)
        merged = torch.cat([gated_skip, upsampled], dim=1)
        return self.conv(merged)

class AttentionUNet(nn.Module):
    """U-Net variant whose decoder stages gate their skip connections (``UpAttention``).

    Args:
        n_channels: Channels of the input image.
        n_classes: Output channels of the final 1x1 convolution.
        bilinear: Passed to every ``UpAttention`` stage to select the upsampling mode.
        base_filters: Channel width of the first level; it doubles at each
            deeper level, up to ``16 * base_filters``.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=False, base_filters=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.base_filters = base_filters

        # Channel widths of the five resolution levels.
        c1, c2, c3, c4, c5 = (base_filters * factor for factor in (1, 2, 4, 8, 16))

        # Encoder
        self.inc = DoubleConv(n_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)

        # Decoder with attention-gated skip connections
        self.up4 = UpAttention(c5, c4, c4, bilinear)
        self.up3 = UpAttention(c4, c3, c3, bilinear)
        self.up2 = UpAttention(c3, c2, c2, bilinear)
        self.up1 = UpAttention(c2, c1, c1, bilinear)

        self.outc = OutConv(c1, n_classes)

    def forward(self, x):
        """Return raw logits; no sigmoid is applied inside the model."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up4(bottleneck, skip4)
        decoded = self.up3(decoded, skip3)
        decoded = self.up2(decoded, skip2)
        decoded = self.up1(decoded, skip1)

        return self.outc(decoded)

## 5. Load Models

In [None]:
def load_model(model_type, checkpoint_path, device):
    """Build a model, load its trained weights and switch it to eval mode.

    Args:
        model_type: ``'unet'`` or ``'attention_unet'``.
        checkpoint_path: Path to a file readable by ``torch.load``. It may hold
            either a dict with a ``'model_state_dict'`` entry or a bare state_dict.
        device: Device used for ``map_location`` and for the returned model.

    Returns:
        The model with weights loaded, moved to ``device``, in eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'unet':
        model = UNet(n_channels=3, n_classes=1, bilinear=False, base_filters=64)
    elif model_type == 'attention_unet':
        model = AttentionUNet(n_channels=3, n_classes=1, bilinear=False, base_filters=64)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    print(f"Loading {model_type} from {checkpoint_path}...")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source. Left unchanged because the
    # checkpoint may contain non-tensor training metadata.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Accept both a training checkpoint wrapper and a bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print(f"{model_type} loaded successfully!")

    return model

# Load both models
# Both checkpoints are read from the paths set in the configuration cell and
# placed on DEVICE; load_model() leaves each model in eval mode.
print("="*60)
unet = load_model('unet', UNET_MODEL_PATH, DEVICE)
print("="*60)
aunet = load_model('attention_unet', AUNET_MODEL_PATH, DEVICE)
print("="*60)

## 6. Data Loading Functions

In [None]:
def load_image(image_path):
    """Read an image from disk and return it as an RGB array.

    Args:
        image_path: Path of the image file.

    Returns:
        The image as returned by ``cv2.imread``, converted from BGR to RGB.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable or
            unsupported format).
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise OSError(f"Could not read image (missing or unsupported file): {image_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def load_mask(mask_path):
    """Read a mask file and return it as a 2-D NumPy array.

    Args:
        mask_path: Path of the mask file (the notebook uses ``.ome.tiff`` masks).

    Returns:
        The mask as a 2-D array; for a 3-D array only the first channel is kept.
    """
    # Uses the notebook-level `from PIL import Image`. The context manager
    # closes the file handle once the pixel data has been copied.
    with Image.open(mask_path) as mask_file:
        mask = np.array(mask_file)
    if mask.ndim == 3:
        mask = mask[:, :, 0]  # Take first channel if RGB
    return mask

def apply_clahe(image: np.ndarray) -> np.ndarray:
    """Apply CLAHE to the lightness channel of an RGB image (for Attention U-Net).

    The image is converted to LAB, the L channel is equalized with
    ``clipLimit=2.0`` and an 8x8 tile grid, and the result is converted back
    to RGB. The A and B channels are left unchanged.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized_lightness = equalizer.apply(lightness)

    enhanced_lab = cv2.merge((equalized_lightness, chan_a, chan_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

def preprocess_for_model(image, mask, use_clahe=False):
    """
    Convert image and mask to tensors

    Args:
        image: Input image (RGB numpy array, values in 0-255)
        mask: Ground truth mask (numpy array of any numeric or boolean dtype)
        use_clahe: If True, apply CLAHE + Z-score normalization (for Attention U-Net)
                  If False, only scale pixel values from 0-255 to [0, 1] (for U-Net)

    Returns:
        (image_tensor, mask_tensor): a float32 image tensor of shape
        (1, C, H, W) and a float32 mask tensor where every pixel > 0 is 1.0
        and every other pixel is 0.0.
    """
    # Apply CLAHE for Attention U-Net
    if use_clahe:
        image = apply_clahe(image)

    # Normalize to [0, 1]
    image = image.astype(np.float32) / 255.0

    # Apply Z-Score normalization for Attention U-Net (ImageNet stats)
    if use_clahe:
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image = (image - mean) / std

    # HWC -> CHW, add a batch dimension and ensure float32
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()

    # Binarize in NumPy before building the tensor: torch.from_numpy cannot
    # wrap every NumPy dtype (e.g. uint16 masks on older PyTorch versions),
    # while float32 is always supported.
    mask_binary = (mask > 0).astype(np.float32)
    mask_tensor = torch.from_numpy(mask_binary)

    return image_tensor, mask_tensor

# Get all image and mask paths.
# Only images that have a mask are kept, so image_paths[i] always belongs to
# mask_paths[i]. Otherwise a single missing mask would shift every later pair
# and the test slices below would match images with the wrong masks.
all_image_paths = sorted(glob(os.path.join(IMAGES_DIR, '*.jpeg')))
image_paths = []
mask_paths = []
for img_path in all_image_paths:
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    mask_path = os.path.join(MASKS_DIR, f"{img_name}_HRF.ome.tiff")
    if os.path.exists(mask_path):
        image_paths.append(img_path)
        mask_paths.append(mask_path)
    else:
        print(f"Warning: Mask not found for {img_name}")

print(f"Found {len(all_image_paths)} images")
print(f"Found {len(mask_paths)} masks")
assert len(image_paths) == len(mask_paths), "image/mask lists must stay index-aligned"

# Use 15% as test set (matching the training split)
# NOTE(review): the split is taken over images that have a mask; confirm the
# training script selected its test set the same way.
test_split = int(0.15 * len(image_paths))
# Slice from an explicit start index: `paths[-0:]` would return the WHOLE list
# when test_split is 0.
test_start = len(image_paths) - test_split
test_image_paths = image_paths[test_start:]
test_mask_paths = mask_paths[test_start:]

print(f"\nUsing {len(test_image_paths)} images for testing")

## 7. Run Inference

In [None]:
def run_inference(model, image_tensor, device):
    """Run one forward pass and return sigmoid probabilities as a NumPy array.

    Args:
        model: Callable returning logits for ``image_tensor``.
        image_tensor: Input tensor; it is moved to ``device`` before the call.
        device: Device on which the forward pass runs.

    Returns:
        Probabilities on the CPU with all size-1 dimensions squeezed out.
    """
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = torch.sigmoid(logits)
        prob = probabilities.cpu().squeeze().numpy()
    return prob

# Store predictions and ground truth
unet_predictions = []
aunet_predictions = []
ground_truth = []
test_images = []

print("Running inference on test set...")
for img_path, mask_path in tqdm(zip(test_image_paths, test_mask_paths), total=len(test_image_paths)):
    # Load data
    image = load_image(img_path)
    mask = load_mask(mask_path)

    # Preprocess once per model. preprocess_for_model documents CLAHE +
    # Z-score normalization as the Attention U-Net pipeline and plain [0, 1]
    # scaling as the U-Net pipeline, so the two models must not share a tensor.
    # NOTE(review): confirm this matches the preprocessing used in training.
    unet_tensor, mask_tensor = preprocess_for_model(image, mask, use_clahe=False)
    aunet_tensor = preprocess_for_model(image, mask, use_clahe=True)[0]

    # Run inference
    unet_pred = run_inference(unet, unet_tensor, DEVICE)
    aunet_pred = run_inference(aunet, aunet_tensor, DEVICE)

    # Store results (test_images keeps the unprocessed RGB image for plotting)
    unet_predictions.append(unet_pred)
    aunet_predictions.append(aunet_pred)
    ground_truth.append(mask_tensor.numpy())
    test_images.append(image)

print("\nInference completed!")

## 8. Calculate Metrics

In [None]:
def calculate_metrics(y_true_list, y_pred_list, threshold=0.5):
    """Calculate pooled pixel-level segmentation metrics.

    All masks and predictions are flattened and concatenated, so every metric
    is computed over the pixels of the whole set rather than averaged per image.

    Args:
        y_true_list: Sequence of binary ground-truth arrays (values 0/1).
        y_pred_list: Sequence of predicted probability arrays, matching
            ``y_true_list`` element by element.
        threshold: Probabilities strictly above this value count as positive.

    Returns:
        (metrics, cm): ``metrics`` is a dict with 'Accuracy', 'Dice', 'IoU',
        'Precision', 'Recall', 'F1', 'Specificity', 'Jaccard', 'AUC' and the
        integer counts 'TP', 'TN', 'FP', 'FN'; ``cm`` is the 2x2 confusion
        matrix with rows = true label and columns = predicted label.

    Raises:
        ValueError: If either input list is empty.
    """
    if len(y_true_list) == 0 or len(y_pred_list) == 0:
        raise ValueError("calculate_metrics needs at least one ground-truth mask and one prediction")

    # Flatten all predictions and ground truth
    y_true_flat = np.concatenate([y.flatten() for y in y_true_list])
    y_pred_flat = np.concatenate([y.flatten() for y in y_pred_list])
    y_pred_binary_flat = (y_pred_flat > threshold).astype(np.float32)

    # Calculate metrics
    accuracy = accuracy_score(y_true_flat, y_pred_binary_flat)
    precision = precision_score(y_true_flat, y_pred_binary_flat, zero_division=0)
    recall = recall_score(y_true_flat, y_pred_binary_flat, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_binary_flat, zero_division=0)
    iou = jaccard_score(y_true_flat, y_pred_binary_flat, zero_division=0)

    # Dice coefficient (for binary labels this is the same quantity as F1)
    dice = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Specificity
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs;
    # without it the four-way unpack below fails on such inputs.
    cm = confusion_matrix(y_true_flat, y_pred_binary_flat, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # AUC-ROC
    try:
        auc_score = roc_auc_score(y_true_flat, y_pred_flat)
    except ValueError as exc:
        # AUC is undefined when the ground truth holds a single class; keep
        # the 0.0 fallback but report why instead of hiding the error.
        print(f"Warning: AUC could not be computed ({exc}); reporting 0.0")
        auc_score = 0.0

    metrics = {
        'Accuracy': accuracy,
        'Dice': dice,
        'IoU': iou,
        'Precision': precision,
        'Recall': recall,
        'F1': f1,
        'Specificity': specificity,
        'Jaccard': iou,
        'AUC': auc_score,
        'TP': int(tp),
        'TN': int(tn),
        'FP': int(fp),
        'FN': int(fn)
    }

    return metrics, cm

# Calculate metrics for both models
print("Calculating metrics for U-Net...")
unet_metrics, unet_cm = calculate_metrics(ground_truth, unet_predictions)

print("Calculating metrics for Attention U-Net...")
aunet_metrics, aunet_cm = calculate_metrics(ground_truth, aunet_predictions)

# Display results as a side-by-side table (one row per metric)
_heavy_rule = "="*70
_light_rule = "-"*70
_table_rows = [
    f"{key:<20} {unet_metrics[key]:<20.4f} {aunet_metrics[key]:<20.4f}"
    for key in ['Dice', 'IoU', 'Precision', 'Recall', 'F1', 'Specificity', 'AUC']
]

print("\n" + _heavy_rule)
print("EVALUATION RESULTS")
print(_heavy_rule)
print(f"{'Metric':<20} {'U-Net':<20} {'Attention U-Net':<20}")
print(_light_rule)
print("\n".join(_table_rows))
print(_light_rule)
# Raw pixel counts behind the metrics above
print(f"\nU-Net Confusion Matrix: TP={unet_metrics['TP']}, TN={unet_metrics['TN']}, FP={unet_metrics['FP']}, FN={unet_metrics['FN']}")
print(f"Attention U-Net Confusion Matrix: TP={aunet_metrics['TP']}, TN={aunet_metrics['TN']}, FP={aunet_metrics['FP']}, FN={aunet_metrics['FN']}")
print(_heavy_rule)

## 9. Confusion Matrix Visualization

In [None]:
# Plot confusion matrices (one heatmap per model, side by side)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

_class_names = ['Background', 'HRF']
_panels = [
    # (axes, confusion matrix, colormap, title)
    (axes[0], unet_cm, 'Blues',
     f'U-Net Confusion Matrix\n(AUC = {unet_metrics["AUC"]:.4f})'),
    (axes[1], aunet_cm, 'Greens',
     f'Attention U-Net Confusion Matrix\n(AUC = {aunet_metrics["AUC"]:.4f})'),
]

for _panel_ax, _panel_cm, _panel_cmap, _panel_title in _panels:
    sns.heatmap(_panel_cm, annot=True, fmt='d', cmap=_panel_cmap, ax=_panel_ax, cbar=True)
    _panel_ax.set_title(_panel_title, fontsize=14, fontweight='bold')
    _panel_ax.set_ylabel('True Label', fontsize=12)
    _panel_ax.set_xlabel('Predicted Label', fontsize=12)
    _panel_ax.set_xticklabels(_class_names)
    _panel_ax.set_yticklabels(_class_names)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')
plt.show()

## 10. ROC Curve

In [None]:
# Calculate ROC curves from the pooled pixels of the whole test set
y_true_flat = np.concatenate([y.flatten() for y in ground_truth])
unet_pred_flat = np.concatenate([y.flatten() for y in unet_predictions])
aunet_pred_flat = np.concatenate([y.flatten() for y in aunet_predictions])

unet_fpr, unet_tpr, _ = roc_curve(y_true_flat, unet_pred_flat)
aunet_fpr, aunet_tpr, _ = roc_curve(y_true_flat, aunet_pred_flat)

unet_auc = auc(unet_fpr, unet_tpr)
aunet_auc = auc(aunet_fpr, aunet_tpr)

# Plot ROC curves on an explicit figure/axes pair
roc_fig, roc_ax = plt.subplots(figsize=(12, 8))
roc_ax.plot(unet_fpr, unet_tpr, color='blue', lw=2,
            label=f'U-Net (AUC = {unet_auc:.4f})')
roc_ax.plot(aunet_fpr, aunet_tpr, color='green', lw=2,
            label=f'Attention U-Net (AUC = {aunet_auc:.4f})')
# Diagonal = performance of a random classifier
roc_ax.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', label='Random')

roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate', fontsize=14)
roc_ax.set_ylabel('True Positive Rate', fontsize=14)
roc_ax.set_title('ROC Curves - HRF Segmentation', fontsize=16, fontweight='bold')
roc_ax.legend(loc="lower right", fontsize=12)
roc_ax.grid(True, alpha=0.3)
roc_fig.tight_layout()
roc_fig.savefig(os.path.join(OUTPUT_DIR, 'roc_curves.png'), dpi=300, bbox_inches='tight')
plt.show()

## 11. Metrics Comparison Bar Plot

In [None]:
# Metrics comparison: grouped bars, one pair of bars per metric
metrics_to_plot = ['Dice', 'IoU', 'Precision', 'Recall', 'F1', 'Specificity', 'AUC']
unet_values = [unet_metrics[name] for name in metrics_to_plot]
aunet_values = [aunet_metrics[name] for name in metrics_to_plot]

x = np.arange(len(metrics_to_plot))
width = 0.35

fig, ax = plt.subplots(figsize=(14, 6))
bars1 = ax.bar(x - width/2, unet_values, width, label='U-Net', color='steelblue', edgecolor='black')
bars2 = ax.bar(x + width/2, aunet_values, width, label='Attention U-Net', color='mediumseagreen', edgecolor='black')

# Add value labels just above every bar
for bar in list(bars1) + list(bars2):
    height = bar.get_height()
    bar_center = bar.get_x() + bar.get_width()/2.
    ax.text(bar_center, height, f'{height:.3f}', ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Metrics', fontsize=12, fontweight='bold')
ax.set_ylabel('Score', fontsize=12, fontweight='bold')
ax.set_title('Model Performance Comparison', fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metrics_to_plot)
# Leave headroom above 1.0 for the value labels
ax.set_ylim([0, 1.1])
ax.legend(fontsize=12)
ax.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'metrics_comparison.png'), dpi=300, bbox_inches='tight')
plt.show()

## 12. Prediction Visualizations

In [None]:
# Visualize predictions on sample images
num_samples = min(5, len(test_images))
# squeeze=False keeps `axes` two-dimensional even for a single sample, so the
# axes[i, j] indexing below always works.
fig, axes = plt.subplots(num_samples, 5, figsize=(20, 4*num_samples), squeeze=False)

for i in range(num_samples):
    # Original image
    axes[i, 0].imshow(test_images[i])
    axes[i, 0].set_title('Original Image', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')

    # Ground truth
    axes[i, 1].imshow(ground_truth[i], cmap='gray')
    axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')

    # U-Net prediction (probability map)
    axes[i, 2].imshow(unet_predictions[i], cmap='jet', vmin=0, vmax=1)
    axes[i, 2].set_title('U-Net Prediction', fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')

    # Attention U-Net prediction (probability map)
    axes[i, 3].imshow(aunet_predictions[i], cmap='jet', vmin=0, vmax=1)
    axes[i, 3].set_title('Attention U-Net Prediction', fontsize=12, fontweight='bold')
    axes[i, 3].axis('off')

    # Overlay (U-Net): brighten the red channel where the prediction is positive
    overlay = test_images[i].copy()
    mask_binary = (unet_predictions[i] > 0.5).astype(np.uint8)
    # Widen to int16 before adding: uint8 + 100 wraps around (e.g. 200 -> 44),
    # so the np.minimum cap at 255 would never take effect on bright pixels.
    boosted_red = overlay[mask_binary > 0, 0].astype(np.int16) + 100
    overlay[mask_binary > 0, 0] = np.minimum(boosted_red, 255).astype(np.uint8)
    axes[i, 4].imshow(overlay)
    axes[i, 4].set_title('U-Net Overlay', fontsize=12, fontweight='bold')
    axes[i, 4].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'predictions_visualization.png'), dpi=300, bbox_inches='tight')
plt.show()

## 13. Summary Report

In [None]:
# Generate and save summary report
# The report is a single f-string with one metrics section per model, filled
# from the unet_metrics / aunet_metrics dicts returned by calculate_metrics().
# Counts use the ',' format spec to get thousands separators.
summary = f"""
{'='*80}
HRF SEGMENTATION U-NET EVALUATION SUMMARY
{'='*80}

Dataset: {len(test_image_paths)} test images
Models: U-Net and Attention U-Net

U-NET METRICS:
{'-'*80}
Dice Coefficient:  {unet_metrics['Dice']:.4f}
IoU (Jaccard):     {unet_metrics['IoU']:.4f}
Precision:         {unet_metrics['Precision']:.4f}
Recall:            {unet_metrics['Recall']:.4f}
F1-Score:          {unet_metrics['F1']:.4f}
Specificity:       {unet_metrics['Specificity']:.4f}
AUC:               {unet_metrics['AUC']:.4f}

Confusion Matrix:
  True Positives:  {unet_metrics['TP']:,}
  True Negatives:  {unet_metrics['TN']:,}
  False Positives: {unet_metrics['FP']:,}
  False Negatives: {unet_metrics['FN']:,}

ATTENTION U-NET METRICS:
{'-'*80}
Dice Coefficient:  {aunet_metrics['Dice']:.4f}
IoU (Jaccard):     {aunet_metrics['IoU']:.4f}
Precision:         {aunet_metrics['Precision']:.4f}
Recall:            {aunet_metrics['Recall']:.4f}
F1-Score:          {aunet_metrics['F1']:.4f}
Specificity:       {aunet_metrics['Specificity']:.4f}
AUC:               {aunet_metrics['AUC']:.4f}

Confusion Matrix:
  True Positives:  {aunet_metrics['TP']:,}
  True Negatives:  {aunet_metrics['TN']:,}
  False Positives: {aunet_metrics['FP']:,}
  False Negatives: {aunet_metrics['FN']:,}

{'='*80}
Generated Files:
  - confusion_matrices.png
  - roc_curves.png
  - metrics_comparison.png
  - predictions_visualization.png
  - evaluation_summary.txt
{'='*80}
"""

print(summary)

# Save to file
# Written next to the figures in OUTPUT_DIR; an existing report is overwritten.
with open(os.path.join(OUTPUT_DIR, 'evaluation_summary.txt'), 'w') as f:
    f.write(summary)

print(f"\nAll results saved to: {OUTPUT_DIR}")

## 14. Download Results

In [None]:
# Zip results for download
# Requires the Colab runtime: `google.colab.files` is used for the download.
import shutil
from google.colab import files

# make_archive receives the path WITHOUT extension; the archive itself is
# referenced as f'{zip_path}.zip' below.
zip_path = '/content/hrf_evaluation_results'
shutil.make_archive(zip_path, 'zip', OUTPUT_DIR)
print(f"Results packaged as: {zip_path}.zip")

# Download
files.download(f'{zip_path}.zip')
print("Download started!")