# Diabetic Retinopathy Detection Pipeline
## IET Codefest 2025 - Complete ML Pipeline

This notebook implements a comprehensive diabetic retinopathy detection system using PyTorch and ResNet50.

**Pipeline Overview:**
1. Dataset Understanding & Label Cleaning
2. Exploratory Data Analysis (EDA)
3. Preprocessing & Augmentation
4. Model Training (Two-phase ResNet50)
5. Explainability (Grad-CAM)
6. Model Export (ONNX for Next.js)
7. Comprehensive Evaluation

## 📦 Setup and Imports

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install timm albumentations opencv-python-headless
!pip install grad-cam onnx onnxruntime scikit-learn  # grad-cam provides the pytorch_grad_cam module
!pip install matplotlib seaborn plotly pandas numpy
!pip install efficientnet-pytorch

In [None]:
import os
import json
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import Counter
import re

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from torchvision import models
import timm

# Image processing
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ML utilities
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_curve, auc,
    precision_recall_curve, f1_score, accuracy_score
)
from sklearn.utils.class_weight import compute_class_weight

# Explainability
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# ONNX export
import onnx
import onnxruntime as ort

warnings.filterwarnings('ignore')
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## ⚙️ Configuration

In [None]:
# Configuration
CONFIG = {
    'DATA_PATH': '/kaggle/input',  # Update this path to your dataset location
    'LABELS_FILE': 'labels.csv',  # Update this to your labels file name
    'IMAGE_SIZE': 224,
    'BATCH_SIZE': 32,
    'NUM_EPOCHS': 50,
    'LEARNING_RATE': 1e-4,
    'WEIGHT_DECAY': 1e-5,
    'NUM_CLASSES': 5,
    'MODEL_NAME': 'resnet50',
    'PATIENCE': 10,
    'MIN_DELTA': 0.001,
    'SAVE_PATH': './models/',
    'RANDOM_SEED': 42
}

# Create directories
os.makedirs(CONFIG['SAVE_PATH'], exist_ok=True)
os.makedirs('./outputs', exist_ok=True)

print("Configuration:")
for key, value in CONFIG.items():
    print(f"{key}: {value}")

## 1️⃣ Dataset Understanding & Label Cleaning

First, we'll load the labels.csv file and clean the inconsistent labels into 5 standardized classes:
- 0 = No_DR (No Diabetic Retinopathy)
- 1 = Mild
- 2 = Moderate
- 3 = Severe
- 4 = Proliferative_DR
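
As a rough illustration of the cleaning step, the sketch below maps a handful of raw label strings onto the five class IDs listed above. The helper name `normalize_label` and the sample values are hypothetical and are not part of the notebook; the notebook's own `clean_labels` function works on a DataFrame instead.

```python
import re

# Hypothetical helper, not part of the notebook: maps one raw label string
# to a class ID in 0..4, or None when the label cannot be interpreted.
TEXT_TO_ID = {"no_dr": 0, "mild": 1, "moderate": 2, "severe": 3, "proliferative_dr": 4}

def normalize_label(raw):
    text = str(raw).strip().lower()
    # Numeric forms such as "2", "02", or "2.0"
    if re.fullmatch(r"\d+(\.0+)?", text):
        value = int(float(text))
        return value if 0 <= value <= 4 else None
    # Text forms: collapse spaces and hyphens into underscores
    key = re.sub(r"[\s\-]+", "_", text)
    return TEXT_TO_ID.get(key)

samples = ["0", "02", "3.0", " Mild ", "No DR", "PROLIFERATIVE_DR", "unknown", "7"]
cleaned = [normalize_label(s) for s in samples]
print(cleaned)
```

Labels that come back as `None` correspond to the rows that get dropped as invalid.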

In [None]:
def clean_labels(df):
    """
    Clean and normalize inconsistent labels into 5 standardized classes
    """
    # Create a copy to avoid modifying original
    df_clean = df.copy()
    
    # Convert to string and strip whitespace
    df_clean['label'] = df_clean['label'].astype(str).str.strip()
    
    # Define mapping for various label formats
    label_mapping = {
        # Numeric labels
        '0': 0, '00': 0, '0.0': 0,
        '1': 1, '01': 1, '1.0': 1,
        '2': 2, '02': 2, '2.0': 2,
        '3': 3, '03': 3, '3.0': 3,
        '4': 4, '04': 4, '4.0': 4,
        
        # Text labels (keys are lowercase; the lookup below normalizes case and spaces)
        'no_dr': 0,
        'mild': 1,
        'moderate': 2,
        'severe': 3,
        'proliferative_dr': 4
    }
    
    # Apply mapping on a normalized key so text matching is truly case-insensitive
    normalized = df_clean['label'].str.lower().str.replace(r'\s+', '_', regex=True)
    df_clean['label_clean'] = normalized.map(label_mapping)
    
    # Identify invalid labels
    invalid_mask = df_clean['label_clean'].isna()
    invalid_labels = df_clean[invalid_mask]['label'].unique()
    
    print(f"Found {invalid_mask.sum()} invalid labels: {invalid_labels}")
    
    # Remove invalid labels
    df_clean = df_clean[~invalid_mask].copy()
    
    # Convert to int
    df_clean['label_clean'] = df_clean['label_clean'].astype(int)
    
    return df_clean, invalid_labels

# Load and clean labels
print("Loading labels.csv...")
try:
    # Try to find labels file in various locations
    possible_paths = [
        os.path.join(CONFIG['DATA_PATH'], CONFIG['LABELS_FILE']),
        CONFIG['LABELS_FILE'],
        './labels.csv',
        '../input/labels.csv'
    ]
    
    labels_df = None
    for path in possible_paths:
        if os.path.exists(path):
            labels_df = pd.read_csv(path)
            print(f"Found labels file at: {path}")
            break
    
    if labels_df is None:
        raise FileNotFoundError("labels.csv not found. Please update CONFIG['DATA_PATH'] and CONFIG['LABELS_FILE']")
    
    print(f"Original dataset shape: {labels_df.shape}")
    print(f"Columns: {list(labels_df.columns)}")
    
    # Display first few rows
    print("\nFirst 10 rows:")
    print(labels_df.head(10))
    
    # Show unique labels before cleaning
    # Cast to str first so mixed types (e.g. NaN next to text) do not break sorting
    print(f"\nUnique labels before cleaning: {sorted(labels_df['label'].astype(str).unique())}")
    
    # Clean labels
    labels_clean, invalid_labels = clean_labels(labels_df)
    
    print(f"\nCleaned dataset shape: {labels_clean.shape}")
    print(f"Removed {len(labels_df) - len(labels_clean)} invalid entries")
    
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Creating sample dataset for demonstration...")
    
    # Create sample dataset for demo
    sample_data = {
        'image_id': [f'img_{i:04d}.jpg' for i in range(1000)],
        'label': np.random.choice(['0', '1', '2', '3', '4', 'No_DR', 'Mild', 'unknown'], 1000)
    }
    labels_df = pd.DataFrame(sample_data)
    labels_clean, invalid_labels = clean_labels(labels_df)
    print("Sample dataset created for demonstration.")

In [None]:
# Display class distribution
class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
class_counts = labels_clean['label_clean'].value_counts().sort_index()

print("\nClass Distribution:")
print("=" * 40)
for i, (class_id, count) in enumerate(class_counts.items()):
    percentage = count / len(labels_clean) * 100
    print(f"{class_id}: {class_names[class_id]:<15} {count:>6} ({percentage:>5.1f}%)")

print(f"\nTotal valid samples: {len(labels_clean)}")

# Visualize class distribution
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
class_counts.plot(kind='bar', color='skyblue', edgecolor='black')
plt.title('Class Distribution (Count)')
plt.xlabel('Class')
plt.ylabel('Count')
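# Derive tick and slice labels from the classes actually present, so a missing class cannot misalign them
plt.xticks(range(len(class_counts)), [class_names[i] for i in class_counts.index], rotation=45)
plt.grid(axis='y', alpha=0.3)

plt.subplot(1, 2, 2)
percentages = class_counts / len(labels_clean) * 100
plt.pie(percentages, labels=[class_names[i] for i in class_counts.index], autopct='%1.1f%%', startangle=90)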
plt.title('Class Distribution (Percentage)')

plt.tight_layout()
plt.show()

# Check for class imbalance
imbalance_ratio = class_counts.max() / class_counts.min()
print(f"\nClass imbalance ratio: {imbalance_ratio:.2f}")
if imbalance_ratio > 5:
    print("⚠️  Significant class imbalance detected. Will use weighted sampling/loss.")
else:
    print("✅ Class distribution is relatively balanced.")

### 🔍 Image File Verification

Let's check if all labeled images exist and identify any missing files.
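
The matching logic can be sketched without touching the filesystem. The example below is a simplified, hypothetical version (`split_existing_missing` is not a notebook function): it treats a labeled ID as present if a file with the same name, or the same stem and a known image extension, is available.

```python
import os

EXTENSIONS = (".jpg", ".jpeg", ".png", ".tiff", ".tif")

def split_existing_missing(image_ids, available_files):
    """Partition labeled IDs into those with a matching file and those without."""
    available = set(available_files)
    existing, missing = [], []
    for img_id in image_ids:
        stem = os.path.splitext(img_id)[0]
        # Accept the exact name, or the same stem with any known extension
        candidates = [img_id] + [stem + ext for ext in EXTENSIONS]
        if any(name in available for name in candidates):
            existing.append(img_id)
        else:
            missing.append(img_id)
    return existing, missing

existing, missing = split_existing_missing(
    ["a.jpg", "b.jpg", "c"], ["a.jpg", "b.png", "d.tif"]
)
print(existing, missing)
```

Here `b.jpg` counts as found because `b.png` exists, while `c` has no candidate file and is reported as missing.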

In [None]:
def find_images(labels_df, data_path):
    """
    Find image files and check for missing images
    """
    # Common image extensions
    extensions = ['.jpg', '.jpeg', '.png', '.tiff', '.tif']
    
    # Look for images in various subdirectories
    search_paths = [
        data_path,
        os.path.join(data_path, 'images'),
        os.path.join(data_path, 'train'),
        os.path.join(data_path, 'test'),
        './images',
        '../input/images'
    ]
    
    image_files = {}
    image_dir = None
    
    for search_path in search_paths:
        if os.path.exists(search_path):
            for ext in extensions:
                pattern = f"*{ext}"
                files = list(Path(search_path).glob(pattern))
                if files:
                    image_dir = search_path
                    for file in files:
                        image_files[file.name] = str(file)
                    print(f"Found {len(files)} {ext} files in {search_path}")
    
    if not image_files:
        print("No image files found. Please check your data path.")
        # Return four values so the caller's four-way unpacking does not fail
        return None, None, None, None
    
    # Check for missing images
    missing_images = []
    existing_images = []
    
    for img_id in labels_df['image_id']:
        if img_id in image_files:
            existing_images.append(img_id)
        else:
            # Try with different extensions
            base_name = os.path.splitext(img_id)[0]
            found = False
            for ext in extensions:
                if f"{base_name}{ext}" in image_files:
                    existing_images.append(img_id)
                    found = True
                    break
            if not found:
                missing_images.append(img_id)
    
    return image_files, existing_images, missing_images, image_dir

# Find images
print("Searching for image files...")
image_files, existing_images, missing_images, image_dir = find_images(labels_clean, CONFIG['DATA_PATH'])

if image_files:
    print(f"\nImage File Summary:")
    print(f"Total image files found: {len(image_files)}")
    print(f"Images with labels: {len(existing_images)}")
    print(f"Missing images: {len(missing_images)}")
    
    if missing_images:
        print(f"\nFirst 10 missing images: {missing_images[:10]}")
        
        # Filter out missing images
        labels_final = labels_clean[labels_clean['image_id'].isin(existing_images)].copy()
        print(f"Final dataset size after removing missing images: {len(labels_final)}")
    else:
        labels_final = labels_clean.copy()
        print("✅ All labeled images found!")
    
    # Update config with image directory
    CONFIG['IMAGE_DIR'] = image_dir
    print(f"Image directory: {image_dir}")
    
else:
    print("⚠️  No images found. Using labels only for demonstration.")
    labels_final = labels_clean.copy()
    CONFIG['IMAGE_DIR'] = None

### 📸 Sample Images Visualization

Let's display random samples from each class to understand the data better.

In [None]:
def load_and_preprocess_image(image_path, size=224):
    """
    Load and preprocess image for display
    """
    try:
        # Load image
        img = cv2.imread(image_path)
        if img is None:
            return None
        
        # Convert BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Resize
        img = cv2.resize(img, (size, size))
        
        return img
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

def crop_black_borders(image, threshold=10):
    """
    Crop black borders from retinal images
    """
    # Convert to grayscale for border detection
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    
    # Find non-black pixels
    coords = cv2.findNonZero((gray > threshold).astype(np.uint8))
    
    if coords is not None:
        # Get bounding box
        x, y, w, h = cv2.boundingRect(coords)
        
        # Add small padding
        pad = 5
        x = max(0, x - pad)
        y = max(0, y - pad)
        w = min(image.shape[1] - x, w + 2*pad)
        h = min(image.shape[0] - y, h + 2*pad)
        
        # Crop image
        cropped = image[y:y+h, x:x+w]
        return cropped
    
    return image

# Display sample images if available
if CONFIG['IMAGE_DIR'] and len(labels_final) > 0:
    print("Displaying sample images from each class...")
    
    fig, axes = plt.subplots(5, 3, figsize=(12, 20))
    fig.suptitle('Sample Images by Class (Original, Cropped, Resized)', fontsize=16)
    
    for class_id in range(5):
        # Get samples from this class
        class_samples = labels_final[labels_final['label_clean'] == class_id].sample(
            min(1, len(labels_final[labels_final['label_clean'] == class_id])), 
            random_state=42
        )
        
        if len(class_samples) > 0:
            img_id = class_samples.iloc[0]['image_id']
            
            # Find image path
            img_path = None
            if img_id in image_files:
                img_path = image_files[img_id]
            else:
                # Try different extensions
                base_name = os.path.splitext(img_id)[0]
                for ext in ['.jpg', '.jpeg', '.png']:
                    if f"{base_name}{ext}" in image_files:
                        img_path = image_files[f"{base_name}{ext}"]
                        break
            
            if img_path and os.path.exists(img_path):
                # Load original image
                original_img = load_and_preprocess_image(img_path, size=300)
                
                if original_img is not None:
                    # Crop borders
                    cropped_img = crop_black_borders(original_img)
                    
                    # Resize to final size
                    final_img = cv2.resize(cropped_img, (224, 224))
                    
                    # Display images
                    axes[class_id, 0].imshow(original_img)
                    axes[class_id, 0].set_title(f'{class_names[class_id]}\nOriginal')
                    axes[class_id, 0].axis('off')
                    
                    axes[class_id, 1].imshow(cropped_img)
                    axes[class_id, 1].set_title('Cropped')
                    axes[class_id, 1].axis('off')
                    
                    axes[class_id, 2].imshow(final_img)
                    axes[class_id, 2].set_title('Resized (224x224)')
                    axes[class_id, 2].axis('off')
                    
                    continue
        
        # If no image found, show placeholder
        for j in range(3):
            axes[class_id, j].text(0.5, 0.5, f'{class_names[class_id]}\nNo image', 
                                 ha='center', va='center', transform=axes[class_id, j].transAxes)
            axes[class_id, j].axis('off')
    
    plt.tight_layout()
    plt.show()
    
else:
    print("No images available for visualization.")
    print("The pipeline will continue with data loading and model training structure.")

## 2️⃣ Preprocessing & Augmentation

We'll create a custom dataset class with:
- Black border cropping
- Image resizing to 224×224
- ImageNet normalization
- Data augmentation for training
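
To make the border-cropping idea concrete, here is a minimal sketch on a toy grayscale image stored as nested lists, so it runs without OpenCV. The function name `crop_box` and the toy image are assumptions for illustration; the notebook itself uses `cv2.findNonZero` and `cv2.boundingRect` for the same purpose.

```python
def crop_box(gray, threshold=10, pad=1):
    """Return (top, bottom, left, right) bounds of pixels brighter than threshold.

    `gray` is a list of rows of intensities; bottom and right are exclusive.
    Returns None when every pixel is at or below the threshold.
    """
    rows = [r for r, row in enumerate(gray) if any(v > threshold for v in row)]
    cols = [c for c in range(len(gray[0])) if any(row[c] > threshold for row in gray)]
    if not rows or not cols:
        return None
    height, width = len(gray), len(gray[0])
    # Expand the box by `pad` pixels, clamped to the image bounds
    top = max(0, rows[0] - pad)
    bottom = min(height, rows[-1] + 1 + pad)
    left = max(0, cols[0] - pad)
    right = min(width, cols[-1] + 1 + pad)
    return top, bottom, left, right

# 6x8 toy image: a bright 2x3 patch surrounded by black
image = [[0] * 8 for _ in range(6)]
for r in (2, 3):
    for c in (3, 4, 5):
        image[r][c] = 200

box = crop_box(image)
top, bottom, left, right = box
cropped = [row[left:right] for row in image[top:bottom]]
print(box, len(cropped), len(cropped[0]))
```

An all-black input returns `None`, which mirrors the case where the original image is passed through unchanged.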

In [None]:
class DiabeticRetinopathyDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None, is_training=True):
        self.dataframe = dataframe.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.is_training = is_training
        
        # ImageNet statistics
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
    
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_id = row['image_id']
        label = row['label_clean']
        
        # Load image
        img_path = self._find_image_path(img_id)
        
        if img_path is None:
            # Return dummy image if not found
            image = np.zeros((CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE'], 3), dtype=np.uint8)
        else:
            image = self._load_image(img_path)
        
        # Apply transformations
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        
        return image, torch.tensor(label, dtype=torch.long)
    
    def _find_image_path(self, img_id):
        """Find the full path to an image"""
        if self.image_dir is None:
            return None
        
        # Try exact match first
        exact_path = os.path.join(self.image_dir, img_id)
        if os.path.exists(exact_path):
            return exact_path
        
        # Try different extensions
        base_name = os.path.splitext(img_id)[0]
        for ext in ['.jpg', '.jpeg', '.png', '.tiff', '.tif']:
            path = os.path.join(self.image_dir, f"{base_name}{ext}")
            if os.path.exists(path):
                return path
        
        return None
    
    def _load_image(self, img_path):
        """Load and preprocess image"""
        try:
            # Load image
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Could not load image: {img_path}")
            
            # Convert BGR to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            
            # Crop black borders
            img = self._crop_black_borders(img)
            
            # Resize
            img = cv2.resize(img, (CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE']))
            
            return img
            
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return black image as fallback
            return np.zeros((CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE'], 3), dtype=np.uint8)
    
    def _crop_black_borders(self, image, threshold=10):
        """Crop black borders from retinal images"""
        # Convert to grayscale for border detection
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        
        # Find non-black pixels
        coords = cv2.findNonZero((gray > threshold).astype(np.uint8))
        
        if coords is not None:
            # Get bounding box
            x, y, w, h = cv2.boundingRect(coords)
            
            # Add small padding
            pad = 5
            x = max(0, x - pad)
            y = max(0, y - pad)
            w = min(image.shape[1] - x, w + 2*pad)
            h = min(image.shape[0] - y, h + 2*pad)
            
            # Crop image
            cropped = image[y:y+h, x:x+w]
            return cropped
        
        return image

# Define transforms
def get_transforms(is_training=True):
    """Get image transforms for training/validation"""
    
    if is_training:
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=0.2, 
                contrast_limit=0.2, 
                p=0.5
            ),
            A.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=20,
                val_shift_limit=10,
                p=0.3
            ),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])
    else:
        transform = A.Compose([
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])
    
    return transform

print("Dataset class and transforms defined successfully!")
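
The normalization step is plain per-channel arithmetic: scale to [0, 1], subtract the ImageNet mean, divide by the ImageNet standard deviation. The sketch below applies it to single pixels; `normalize_pixel` is a hypothetical helper used only to show the value range to expect after `A.Normalize`.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, max_pixel_value=255.0):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / max_pixel_value - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )

black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

Normalized tensors therefore fall roughly between -2.12 and 2.64, which is a quick sanity check for the image range printed when a batch is loaded.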

In [None]:
# Split data into train, validation, and test sets
train_df, temp_df = train_test_split(
    labels_final, 
    test_size=0.3, 
    random_state=CONFIG['RANDOM_SEED'],
    stratify=labels_final['label_clean']
)

val_df, test_df = train_test_split(
    temp_df, 
    test_size=0.5, 
    random_state=CONFIG['RANDOM_SEED'],
    stratify=temp_df['label_clean']
)

print(f"Dataset splits:")
print(f"Train: {len(train_df)} samples")
print(f"Validation: {len(val_df)} samples")
print(f"Test: {len(test_df)} samples")

# Display class distribution in each split
splits_info = {
    'Train': train_df['label_clean'].value_counts().sort_index(),
    'Validation': val_df['label_clean'].value_counts().sort_index(),
    'Test': test_df['label_clean'].value_counts().sort_index()
}

print("\nClass distribution by split:")
for split_name, counts in splits_info.items():
    print(f"\n{split_name}:")
    for class_id, count in counts.items():
        print(f"  {class_names[class_id]}: {count}")

# Calculate class weights for handling imbalance
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(train_df['label_clean']),
    y=train_df['label_clean']
)
class_weights_tensor = torch.FloatTensor(class_weights).to(device)

print(f"\nClass weights for balanced loss:")
for i, weight in enumerate(class_weights):
    print(f"{class_names[i]}: {weight:.3f}")

# Create datasets
train_dataset = DiabeticRetinopathyDataset(
    train_df, 
    CONFIG['IMAGE_DIR'], 
    transform=get_transforms(is_training=True),
    is_training=True
)

val_dataset = DiabeticRetinopathyDataset(
    val_df, 
    CONFIG['IMAGE_DIR'], 
    transform=get_transforms(is_training=False),
    is_training=False
)

test_dataset = DiabeticRetinopathyDataset(
    test_df, 
    CONFIG['IMAGE_DIR'], 
    transform=get_transforms(is_training=False),
    is_training=False
)

print(f"\nDatasets created successfully!")
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")
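
For reference, scikit-learn's `'balanced'` option computes each class weight as `n_samples / (n_classes * count)`. The sketch below reproduces that formula in plain Python on a made-up label list; the function name and the 70/10/10/5/5 distribution are illustrative assumptions, not values from the dataset.

```python
from collections import Counter

def balanced_class_weights(labels):
    """weight_c = n_samples / (n_classes * count_c), for classes present in labels."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Hypothetical imbalanced label list: 70 / 10 / 10 / 5 / 5 samples per class
labels = [0] * 70 + [1] * 10 + [2] * 10 + [3] * 5 + [4] * 5
weights = balanced_class_weights(labels)
for class_id, weight in weights.items():
    print(class_id, round(weight, 3))
```

Rare classes receive proportionally larger weights, and a class that is absent from the training split gets no weight at all, which is worth checking before passing the weights to a five-class loss.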

In [None]:
# Create weighted sampler for handling class imbalance
def create_weighted_sampler(dataset, labels):
    """Create weighted sampler for imbalanced dataset"""
    class_counts = np.bincount(labels)
    class_weights = 1.0 / class_counts
    sample_weights = [class_weights[label] for label in labels]
    
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

# Create weighted sampler for training
train_sampler = create_weighted_sampler(train_dataset, train_df['label_clean'].values)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    sampler=train_sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print(f"Data loaders created:")
print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Test data loading
print("\nTesting data loading...")
try:
    sample_batch = next(iter(train_loader))
    images, labels = sample_batch
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Image range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Sample labels: {labels[:10].tolist()}")
    print("✅ Data loading successful!")
except Exception as e:
    print(f"❌ Data loading error: {e}")
    print("Note: This is expected if no images are available.")
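
The effect of inverse-frequency sample weights can be checked with a small calculation. The sketch below (hypothetical helper and toy labels) normalizes the `1 / class_count` weights into draw probabilities and sums them per class.

```python
from collections import Counter

def sampling_probabilities(labels):
    """Per-sample draw probabilities when each sample is weighted by 1 / class count."""
    counts = Counter(labels)
    raw = [1.0 / counts[label] for label in labels]
    total = sum(raw)
    return [w / total for w in raw]

labels = [0] * 8 + [1] * 2  # hypothetical 80/20 imbalance
probs = sampling_probabilities(labels)

# Total probability mass assigned to each class
class_mass = Counter()
for label, p in zip(labels, probs):
    class_mass[label] += p
print({c: round(m, 3) for c, m in sorted(class_mass.items())})
```

Each class ends up with the same total draw probability, so batches are balanced in expectation. Because this notebook also passes class weights to the loss, minority classes are emphasized twice; whether that combination helps is something to verify on the validation set.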

## 3️⃣ Model Architecture (ResNet50)

We'll use a two-phase training approach:
1. **Phase 1**: Freeze backbone, train classifier head only
2. **Phase 2**: Unfreeze backbone, fine-tune with smaller learning rate
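
The two phases differ only in a few hyperparameters, all derived from the shared configuration. The sketch below collects them in one place; `phase_settings` is a hypothetical helper, while the derivations (one tenth of the learning rate, half the epochs, half the patience, half the minimum delta for Phase 2) follow the values used in the training cells of this notebook.

```python
def phase_settings(config):
    """Derive per-phase hyperparameters from a single base configuration."""
    phase1 = {
        "backbone_trainable": False,
        "lr": config["LEARNING_RATE"],
        "max_epochs": config["NUM_EPOCHS"],
        "patience": config["PATIENCE"],
        "min_delta": config["MIN_DELTA"],
    }
    phase2 = {
        "backbone_trainable": True,
        "lr": config["LEARNING_RATE"] / 10,       # smaller steps for fine-tuning
        "max_epochs": config["NUM_EPOCHS"] // 2,  # fewer epochs
        "patience": config["PATIENCE"] // 2,      # stop sooner
        "min_delta": config["MIN_DELTA"] / 2,     # smaller gains still count
    }
    return phase1, phase2

base = {"LEARNING_RATE": 1e-4, "NUM_EPOCHS": 50, "PATIENCE": 10, "MIN_DELTA": 0.001}
phase1, phase2 = phase_settings(base)
print(phase1)
print(phase2)
```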

In [None]:
class DiabeticRetinopathyModel(nn.Module):
    def __init__(self, num_classes=5, model_name='resnet50', pretrained=True):
        super(DiabeticRetinopathyModel, self).__init__()
        
        # Load pretrained model
        if model_name == 'resnet50':
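            # `pretrained=` is deprecated in torchvision >= 0.13; IMAGENET1K_V1 matches its old behavior
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)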
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()  # Remove final layer
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        
        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
        self.num_classes = num_classes
        self.model_name = model_name
    
    def forward(self, x):
        # Extract features
        features = self.backbone(x)
        
        # Classify
        output = self.classifier(features)
        
        return output
    
    def freeze_backbone(self):
        """Freeze backbone parameters"""
        for param in self.backbone.parameters():
            param.requires_grad = False
    
    def unfreeze_backbone(self):
        """Unfreeze backbone parameters"""
        for param in self.backbone.parameters():
            param.requires_grad = True
    
    def get_trainable_params(self):
        """Get number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Create model
model = DiabeticRetinopathyModel(
    num_classes=CONFIG['NUM_CLASSES'],
    model_name=CONFIG['MODEL_NAME'],
    pretrained=True
).to(device)

print(f"Model created: {CONFIG['MODEL_NAME']}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {model.get_trainable_params():,}")

# Test model forward pass
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(2, 3, CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE']).to(device)
    dummy_output = model(dummy_input)
    print(f"\nModel output shape: {dummy_output.shape}")
    print(f"Output range: [{dummy_output.min():.3f}, {dummy_output.max():.3f}]")
    print("✅ Model forward pass successful!")

In [None]:
# Training utilities
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
    
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
        
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
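        # Clone the tensors: state_dict().copy() is shallow and would follow later weight updates
        self.best_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}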

def calculate_metrics(y_true, y_pred, y_prob=None):
    """Calculate comprehensive metrics"""
    from sklearn.metrics import (
        accuracy_score, precision_recall_fscore_support,
        roc_auc_score, confusion_matrix
    )
    
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    
    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
    
    # Calculate AUC if probabilities provided
    if y_prob is not None:
        try:
            auc_score = roc_auc_score(y_true, y_prob, multi_class='ovr', average='weighted')
            metrics['auc'] = auc_score
        except ValueError:  # e.g. a class is missing from y_true
            metrics['auc'] = 0.0
    
    return metrics

def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Get predictions
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
    
    epoch_loss = running_loss / len(loader)
    metrics = calculate_metrics(all_labels, all_preds, all_probs)
    
    return epoch_loss, metrics

def validate_epoch(model, loader, criterion, device):
    """Validate for one epoch"""
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            
            # Get predictions
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    epoch_loss = running_loss / len(loader)
    metrics = calculate_metrics(all_labels, all_preds, all_probs)
    
    return epoch_loss, metrics, all_labels, all_preds, all_probs

print("Training utilities defined successfully!")
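
The stopping rule used by `EarlyStopping` can be traced on a list of validation losses without a model. The sketch below is a simplified stand-in (`stopping_epoch` is hypothetical and keeps no weights); it only reports the epoch at which the patience counter would run out.

```python
def stopping_epoch(val_losses, patience=10, min_delta=0.001):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None or loss < best - min_delta:
            best = loss          # counts as an improvement
            bad_epochs = 0
        else:
            bad_epochs += 1      # no meaningful improvement this epoch
        if bad_epochs >= patience:
            return epoch
    return None

losses = [1.00, 0.80, 0.7995, 0.81, 0.79, 0.795, 0.80, 0.79]
print(stopping_epoch(losses, patience=3, min_delta=0.001))
```

Note that the drop from 0.80 to 0.7995 does not reset the counter, because it is smaller than `min_delta`.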

### 🎯 Phase 1: Train Classifier Head (Frozen Backbone)

First, we'll freeze the ResNet50 backbone and train only the classifier head.

In [None]:
# Phase 1: Freeze backbone and train classifier
print("=" * 60)
print("PHASE 1: Training Classifier Head (Frozen Backbone)")
print("=" * 60)

# Freeze backbone
model.freeze_backbone()
print(f"Trainable parameters: {model.get_trainable_params():,}")

# Setup training
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=CONFIG['WEIGHT_DECAY'])
# `verbose` is omitted because it is deprecated in recent PyTorch releases
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
early_stopping = EarlyStopping(patience=CONFIG['PATIENCE'], min_delta=CONFIG['MIN_DELTA'])

# Training history
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': [],
    'train_f1': [], 'val_f1': [],
    'val_auc': []
}

print(f"\nStarting Phase 1 training for up to {CONFIG['NUM_EPOCHS']} epochs...")
print(f"Learning rate: {CONFIG['LEARNING_RATE']}")
print(f"Batch size: {CONFIG['BATCH_SIZE']}")

best_val_loss = float('inf')
phase1_epochs = 0

for epoch in range(CONFIG['NUM_EPOCHS']):
    print(f"\nEpoch {epoch+1}/{CONFIG['NUM_EPOCHS']}")
    print("-" * 30)
    
    # Train
    train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_metrics, _, _, _ = validate_epoch(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['train_f1'].append(train_metrics['f1'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_auc'].append(val_metrics.get('auc', 0.0))
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_metrics['accuracy']:.4f} | Train F1: {train_metrics['f1']:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1:   {val_metrics['f1']:.4f} | Val AUC: {val_metrics.get('auc', 0.0):.4f}")
    
    # Count this epoch before any early exit so the total matches the history length
    phase1_epochs = epoch + 1
    
    # Early stopping
    if early_stopping(val_loss, model):
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_metrics': val_metrics
        }, os.path.join(CONFIG['SAVE_PATH'], 'best_phase1_model.pth'))
        print(f"✅ Best model saved (Val Loss: {val_loss:.4f})")

print(f"\nPhase 1 completed after {phase1_epochs} epochs")
print(f"Best validation loss: {best_val_loss:.4f}")

### 🔥 Phase 2: Fine-tune Entire Model (Unfrozen Backbone)

Now we'll unfreeze the backbone and fine-tune the entire model with a smaller learning rate.
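
To make the learning-rate behaviour concrete, here is a simplified, pure-Python sketch of a reduce-on-plateau rule starting from a ten-times-smaller rate. `plateau_lr_schedule` is a hypothetical helper, the base rate of `1e-3` is an assumed value for `CONFIG['LEARNING_RATE']`, and the losses are made up. PyTorch's `ReduceLROnPlateau` additionally applies an improvement threshold and optional cooldown, which this sketch leaves out.

```python
def plateau_lr_schedule(val_losses, initial_lr, factor=0.5, patience=3):
    """Return the learning rate in effect after each epoch's validation loss.

    Simplified rule: once the loss has failed to beat the best value for more
    than `patience` consecutive epochs, multiply the rate by `factor` and
    restart the count.
    """
    lr = initial_lr
    best = float("inf")
    bad_epochs = 0
    schedule = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        schedule.append(lr)
    return schedule

base_lr = 1e-3               # assumed value for CONFIG['LEARNING_RATE']
fine_tune_lr = base_lr / 10  # Phase 2 starts ten times lower
losses = [0.60, 0.55, 0.56, 0.57, 0.58, 0.59, 0.54]  # made-up validation losses
lrs = plateau_lr_schedule(losses, fine_tune_lr)
```

With `patience=3`, the rate is halved only on the fourth consecutive epoch without improvement, so brief fluctuations in validation loss do not shrink the step size.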

In [None]:
# Phase 2: Unfreeze and fine-tune entire model
print("\n" + "=" * 60)
print("PHASE 2: Fine-tuning Entire Model (Unfrozen Backbone)")
print("=" * 60)

# Unfreeze backbone
model.unfreeze_backbone()
print(f"Trainable parameters: {model.get_trainable_params():,}")

# Setup training with smaller learning rate
fine_tune_lr = CONFIG['LEARNING_RATE'] / 10  # 10x smaller LR for fine-tuning
optimizer = optim.AdamW(model.parameters(), lr=fine_tune_lr, weight_decay=CONFIG['WEIGHT_DECAY'])
# `verbose` is deprecated in recent PyTorch releases, so it is omitted here
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
early_stopping = EarlyStopping(patience=CONFIG['PATIENCE']//2, min_delta=CONFIG['MIN_DELTA']/2)  # More sensitive

print(f"\nStarting Phase 2 training for up to {CONFIG['NUM_EPOCHS']//2} epochs...")
print(f"Fine-tuning learning rate: {fine_tune_lr}")

phase2_epochs = 0
best_val_loss_phase2 = float('inf')

for epoch in range(CONFIG['NUM_EPOCHS']//2):  # Fewer epochs for fine-tuning
    print(f"\nPhase 2 - Epoch {epoch+1}/{CONFIG['NUM_EPOCHS']//2}")
    print("-" * 30)
    
    # Train
    train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_metrics, _, _, _ = validate_epoch(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['train_f1'].append(train_metrics['f1'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_auc'].append(val_metrics.get('auc', 0.0))
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_metrics['accuracy']:.4f} | Train F1: {train_metrics['f1']:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1:   {val_metrics['f1']:.4f} | Val AUC: {val_metrics.get('auc', 0.0):.4f}")
    
    # Count this epoch before any early exit so the total matches the history length
    phase2_epochs = epoch + 1
    
    # Early stopping
    if early_stopping(val_loss, model):
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break
    
    # Save best model
    if val_loss < best_val_loss_phase2:
        best_val_loss_phase2 = val_loss
        torch.save({
            'epoch': phase1_epochs + epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_metrics': val_metrics,
            'history': history
        }, os.path.join(CONFIG['SAVE_PATH'], 'best_final_model.pth'))
        print(f"✅ Best final model saved (Val Loss: {val_loss:.4f})")

print(f"\nPhase 2 completed after {phase2_epochs} epochs")
print(f"Best validation loss: {best_val_loss_phase2:.4f}")
print(f"\nTotal training epochs: {phase1_epochs + phase2_epochs}")

In [None]:
# Plot training history
def plot_training_history(history):
    """Plot training and validation metrics"""
    epochs = range(1, len(history['train_loss']) + 1)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History', fontsize=16)
    
    # Loss
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0, 0].set_title('Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # F1 Score
    axes[1, 0].plot(epochs, history['train_f1'], 'b-', label='Train F1', linewidth=2)
    axes[1, 0].plot(epochs, history['val_f1'], 'r-', label='Val F1', linewidth=2)
    axes[1, 0].set_title('F1 Score')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('F1 Score')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # AUC
    axes[1, 1].plot(epochs, history['val_auc'], 'g-', label='Val AUC', linewidth=2)
    axes[1, 1].set_title('AUC Score')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('AUC')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # Add phase separation line
    if phase1_epochs > 0:
        for ax in axes.flat:
            ax.axvline(x=phase1_epochs, color='orange', linestyle='--', alpha=0.7, label='Phase 2 Start')
            if ax == axes[0, 0]:  # Only add legend to first subplot
                ax.legend()
    
    plt.tight_layout()
    plt.savefig('./outputs/training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

# Plot training history if we have data
if len(history['train_loss']) > 0:
    plot_training_history(history)
    
    # Print final metrics
    print("\nFinal Training Metrics:")
    print(f"Best Validation Loss: {min(history['val_loss']):.4f}")
    print(f"Best Validation Accuracy: {max(history['val_acc']):.4f}")
    print(f"Best Validation F1: {max(history['val_f1']):.4f}")
    print(f"Best Validation AUC: {max(history['val_auc']):.4f}")
else:
    print("No training history available. Skipping visualization.")

## 6️⃣ Model Evaluation

Let's evaluate our trained model on the test set and generate comprehensive metrics.
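
Alongside accuracy and F1, per-class sensitivity and specificity are often reported for screening tasks. The sketch below derives them one-vs-rest from a confusion matrix with NumPy. `per_class_sensitivity_specificity` is a hypothetical helper and `example_cm` is invented data, not output from this notebook.

```python
import numpy as np

def per_class_sensitivity_specificity(cm):
    """Compute one-vs-rest sensitivity and specificity from a confusion matrix.

    `cm[i, j]` counts samples of true class i that were predicted as class j.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp   # true class i, predicted as something else
    fp = cm.sum(axis=0) - tp   # predicted class i, actually something else
    tn = cm.sum() - tp - fn - fp
    # Classes with an empty denominator keep the 0.0 placed in `out`
    sensitivity = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
    specificity = np.divide(tn, tn + fp, out=np.zeros_like(tn), where=(tn + fp) > 0)
    return sensitivity, specificity

# Invented 3-class confusion matrix, for illustration only
example_cm = [[50, 5, 0],
              [10, 30, 5],
              [0, 5, 20]]
sens, spec = per_class_sensitivity_specificity(example_cm)
```

The same function could be applied to `confusion_matrix(test_labels, test_preds)` once the test predictions are available.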

In [None]:
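# Load best model for evaluation
try:
    # weights_only=False because the checkpoint also stores metrics and history;
    # only load checkpoint files that you created yourself
    checkpoint = torch.load(
        os.path.join(CONFIG['SAVE_PATH'], 'best_final_model.pth'),
        map_location=device, weights_only=False
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    print("✅ Loaded best final model for evaluation")
    print(f"Model was saved at epoch {checkpoint['epoch']} with val loss {checkpoint['val_loss']:.4f}")
except FileNotFoundError:
    print("⚠️  No saved model found. Using current model state.")
except Exception as e:
    print(f"⚠️  Could not load saved model ({e}). Using current model state.")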

# Evaluate on test set
print("\nEvaluating on test set...")
test_loss, test_metrics, test_labels, test_preds, test_probs = validate_epoch(model, test_loader, criterion, device)

print(f"\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Test Precision: {test_metrics['precision']:.4f}")
print(f"Test Recall: {test_metrics['recall']:.4f}")
print(f"Test F1: {test_metrics['f1']:.4f}")
print(f"Test AUC: {test_metrics.get('auc', 0.0):.4f}")

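# Detailed classification report (classification_report is imported in the setup cell)
print("\nDetailed Classification Report:")
print(classification_report(
    test_labels, test_preds,
    labels=list(range(len(class_names))),  # keeps the report aligned if a class is absent
    target_names=class_names, zero_division=0
))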

In [None]:
# Plot confusion matrix and ROC curves
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.savefig('./outputs/confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Plot row-normalized confusion matrix (rows with no samples stay at 0)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = cm.astype('float') / np.maximum(row_sums, 1)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_normalized, annot=True, fmt='.3f', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Normalized Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.savefig('./outputs/confusion_matrix_normalized.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_roc_curves(y_true, y_prob, class_names):
    """Plot ROC curves for each class"""
    from sklearn.preprocessing import label_binarize
    
    # Binarize labels for one-vs-rest ROC curves (roc_curve and auc come from the setup cell)
    y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
    
    plt.figure(figsize=(12, 8))
    
    colors = ['blue', 'red', 'green', 'orange', 'purple']
    
    for i, (class_name, color) in enumerate(zip(class_names, colors)):
        if i < y_true_bin.shape[1] and i < len(y_prob[0]):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], [prob[i] for prob in y_prob])
            roc_auc = auc(fpr, tpr)
            
            plt.plot(fpr, tpr, color=color, lw=2,
                    label=f'{class_name} (AUC = {roc_auc:.3f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Each Class')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('./outputs/roc_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

# Generate plots if we have test results
if len(test_labels) > 0:
    plot_confusion_matrix(test_labels, test_preds, class_names)
    
    if len(test_probs) > 0 and len(test_probs[0]) == len(class_names):
        plot_roc_curves(test_labels, test_probs, class_names)
    else:
        print("Skipping ROC curves due to insufficient probability data")
else:
    print("No test results available for plotting.")

## 4️⃣ Explainability with Grad-CAM

Let's use Grad-CAM (Gradient-weighted Class Activation Mapping) to visualize which parts of the retinal images our model focuses on when making predictions.
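
The core computation behind Grad-CAM is short enough to sketch with NumPy: average the gradients over the spatial axes to get one weight per channel, take the weighted sum of the activation maps, and apply a ReLU. `grad_cam_from_arrays` is a hypothetical helper working on synthetic arrays. The `pytorch_grad_cam` library used in this notebook obtains the real activations and gradients through hooks and may normalize the map slightly differently.

```python
import numpy as np

def grad_cam_from_arrays(activations, gradients):
    """Compute a Grad-CAM heatmap from one layer's activations and gradients.

    Both arrays have shape (channels, height, width); `gradients` holds the
    derivative of the target class score with respect to `activations`.
    """
    # Channel weights: global-average-pool the gradients over the spatial axes
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of activation maps, then ReLU to keep positive evidence only
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Scale to [0, 1] for display; an all-zero map is returned unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam

rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))            # synthetic stand-in for layer activations
grads = rng.standard_normal((4, 7, 7))  # synthetic stand-in for gradients
heatmap = grad_cam_from_arrays(acts, grads)
```

The resulting map has the spatial size of the chosen layer (7x7 in this toy case) and is upsampled to the image size before being overlaid.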

In [None]:
# Grad-CAM implementation
def get_gradcam_visualization(model, image, target_class, device):
    """Generate Grad-CAM visualization for a given image and target class"""
    
    # Define target layer (last convolutional layer of ResNet50)
    target_layers = [model.backbone.layer4[-1]]
    
    # Initialize Grad-CAM
    cam = GradCAM(model=model, target_layers=target_layers)
    
    # Generate CAM
    targets = [ClassifierOutputTarget(target_class)]
    
    # Get gradcam output
    grayscale_cam = cam(input_tensor=image.unsqueeze(0), targets=targets)
    grayscale_cam = grayscale_cam[0, :]  # Remove batch dimension
    
    return grayscale_cam

def visualize_gradcam_samples(model, dataset, device, num_samples=10):
    """Visualize Grad-CAM for sample images from each class"""
    model.eval()
    
    # Read every label once so the dataset is not re-loaded for each class
    all_labels = [int(dataset[i][1]) for i in range(len(dataset))]
    
    fig, axes = plt.subplots(len(class_names), 3, figsize=(15, 3*len(class_names)))
    fig.suptitle('Grad-CAM Visualizations by Class', fontsize=16)
    
    for class_id in range(len(class_names)):
        # Find samples from this class
        class_indices = [i for i, label in enumerate(all_labels) if label == class_id]
        
        if len(class_indices) == 0:
            # No samples for this class
            for j in range(3):
                axes[class_id, j].text(0.5, 0.5, f'No {class_names[class_id]} samples', 
                                     ha='center', va='center', transform=axes[class_id, j].transAxes)
                axes[class_id, j].axis('off')
            continue
        
        # Get a random sample
        sample_idx = np.random.choice(class_indices)
        image, label = dataset[sample_idx]
        
        # Move to device
        image_tensor = image.to(device)
        
        # Get model prediction
        with torch.no_grad():
            output = model(image_tensor.unsqueeze(0))
            probabilities = torch.softmax(output, dim=1)
            predicted_class = torch.argmax(output, dim=1).item()
            confidence = probabilities[0, predicted_class].item()
        
        try:
            # Generate Grad-CAM
            gradcam = get_gradcam_visualization(model, image_tensor, predicted_class, device)
            
            # Convert image to numpy for visualization
            # Denormalize image
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            
            img_np = image.cpu().numpy().transpose(1, 2, 0)
            img_np = std * img_np + mean
            img_np = np.clip(img_np, 0, 1)
            
            # Create overlay
            visualization = show_cam_on_image(img_np, gradcam, use_rgb=True)
            
            # Plot original image
            axes[class_id, 0].imshow(img_np)
            axes[class_id, 0].set_title(f'{class_names[class_id]}\nOriginal')
            axes[class_id, 0].axis('off')
            
            # Plot Grad-CAM heatmap
            axes[class_id, 1].imshow(gradcam, cmap='jet')
            axes[class_id, 1].set_title('Grad-CAM Heatmap')
            axes[class_id, 1].axis('off')
            
            # Plot overlay
            axes[class_id, 2].imshow(visualization)
            axes[class_id, 2].set_title(f'Overlay\nPred: {class_names[predicted_class]} ({confidence:.2f})')
            axes[class_id, 2].axis('off')
            
        except Exception as e:
            print(f"Error generating Grad-CAM for class {class_names[class_id]}: {e}")
            # Show error message
            for j in range(3):
                axes[class_id, j].text(0.5, 0.5, f'Grad-CAM Error\n{class_names[class_id]}', 
                                     ha='center', va='center', transform=axes[class_id, j].transAxes)
                axes[class_id, j].axis('off')
    
    plt.tight_layout()
    plt.savefig('./outputs/gradcam_visualizations.png', dpi=300, bbox_inches='tight')
    plt.show()

# Generate Grad-CAM visualizations
print("Generating Grad-CAM visualizations...")
try:
    visualize_gradcam_samples(model, test_dataset, device, num_samples=len(class_names))
    print("✅ Grad-CAM visualizations generated successfully!")
except Exception as e:
    print(f"❌ Error generating Grad-CAM visualizations: {e}")
    print("This might be due to missing images or model issues.")

## 5️⃣ Model Export for Next.js Integration

We'll export our trained model to ONNX format for use in a Next.js application with `onnxruntime-node`.

In [None]:
# Export model to ONNX format
def export_to_onnx(model, save_path, input_size=(1, 3, 224, 224)):
    """Export PyTorch model to ONNX format"""
    model.eval()
    
    # Create dummy input
    dummy_input = torch.randn(input_size).to(device)
    
    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

def verify_onnx_model(onnx_path, pytorch_model, device):
    """Verify ONNX model produces same results as PyTorch model"""
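    # Load ONNX model (explicit CPU provider keeps the check independent of GPU availability)
    ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])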
    
    # Create test input
    test_input = torch.randn(1, 3, 224, 224).to(device)
    
    # PyTorch prediction
    pytorch_model.eval()
    with torch.no_grad():
        pytorch_output = pytorch_model(test_input).cpu().numpy()
    
    # ONNX prediction
    ort_inputs = {ort_session.get_inputs()[0].name: test_input.cpu().numpy()}
    ort_output = ort_session.run(None, ort_inputs)[0]
    
    # Compare outputs
    max_diff = np.max(np.abs(pytorch_output - ort_output))
    
    return max_diff < 1e-5, max_diff

# Export model
print("Exporting model to ONNX format...")
onnx_path = './outputs/diabetic_retinopathy_model.onnx'

try:
    export_to_onnx(model, onnx_path)
    print(f"✅ Model exported to: {onnx_path}")
    
    # Verify ONNX model
    is_valid, max_diff = verify_onnx_model(onnx_path, model, device)
    
    if is_valid:
        print(f"✅ ONNX model verification successful (max diff: {max_diff:.2e})")
    else:
        print(f"⚠️  ONNX model verification failed (max diff: {max_diff:.2e})")
    
    # Get model size
    model_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
    print(f"Model size: {model_size_mb:.2f} MB")
    
except Exception as e:
    print(f"❌ Error exporting to ONNX: {e}")

In [None]:
# Export preprocessing parameters for frontend
preprocessing_config = {
    'image_size': CONFIG['IMAGE_SIZE'],
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
    'class_names': class_names,
    'num_classes': CONFIG['NUM_CLASSES'],
    'model_architecture': CONFIG['MODEL_NAME'],
    'input_shape': [1, 3, CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE']],
    'output_shape': [1, CONFIG['NUM_CLASSES']]
}

# Save preprocessing config
with open('./outputs/preprocessing_config.json', 'w') as f:
    json.dump(preprocessing_config, f, indent=2)

print("Preprocessing configuration saved to: ./outputs/preprocessing_config.json")
print("\nPreprocessing Config:")
for key, value in preprocessing_config.items():
    print(f"{key}: {value}")

In [None]:
# Optional: Quantize ONNX model to reduce size
def quantize_onnx_model(input_path, output_path):
    """Quantize ONNX model to reduce size"""
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
        
        quantize_dynamic(
            input_path,
            output_path,
            weight_type=QuantType.QUInt8
        )
        
        return True
    except ImportError:
        print("ONNX quantization not available. It requires the onnx and onnxruntime packages.")
        return False
    except Exception as e:
        print(f"Quantization error: {e}")
        return False

# Quantize model
print("\nAttempting to quantize ONNX model...")
quantized_path = './outputs/diabetic_retinopathy_model_quantized.onnx'

if os.path.exists(onnx_path):
    success = quantize_onnx_model(onnx_path, quantized_path)
    
    if success and os.path.exists(quantized_path):
        original_size = os.path.getsize(onnx_path) / (1024 * 1024)
        quantized_size = os.path.getsize(quantized_path) / (1024 * 1024)
        compression_ratio = original_size / quantized_size
        
        print(f"✅ Quantized model saved to: {quantized_path}")
        print(f"Original size: {original_size:.2f} MB")
        print(f"Quantized size: {quantized_size:.2f} MB")
        print(f"Compression ratio: {compression_ratio:.2f}x")
    else:
        print("❌ Quantization failed or not available")
else:
    print("No ONNX model found for quantization")

## 7️⃣ Next.js Integration Guide

Here's how to integrate the exported ONNX model into your Next.js application:

In [None]:
# Generate Next.js integration code
nextjs_integration_code = '''
// Next.js API Route Example (Pages Router): /pages/api/predict.js
// An App Router route (/app/api/predict/route.js) would export a POST function instead.

import * as ort from 'onnxruntime-node';
import sharp from 'sharp';
import path from 'path';
import fs from 'fs';

// Load preprocessing config
const configPath = path.join(process.cwd(), 'models', 'preprocessing_config.json');
const config = JSON.parse(fs.readFileSync(configPath, 'utf8'));

// Load ONNX model
let session = null;

async function loadModel() {
  if (!session) {
    const modelPath = path.join(process.cwd(), 'models', 'diabetic_retinopathy_model.onnx');
    session = await ort.InferenceSession.create(modelPath);
  }
  return session;
}

// Preprocess image
async function preprocessImage(imageBuffer) {
  // Drop any alpha channel and resize so the raw buffer is 3-channel RGB
  const { data, info } = await sharp(imageBuffer)
    .removeAlpha()
    .resize(config.image_size, config.image_size)
    .raw()
    .toBuffer({ resolveWithObject: true });
  
  // Convert to float32 and normalize
  const pixels = new Float32Array(data.length);
  
  for (let i = 0; i < data.length; i += 3) {
    // Normalize RGB channels
    pixels[i] = (data[i] / 255.0 - config.mean[0]) / config.std[0];     // R
    pixels[i + 1] = (data[i + 1] / 255.0 - config.mean[1]) / config.std[1]; // G
    pixels[i + 2] = (data[i + 2] / 255.0 - config.mean[2]) / config.std[2]; // B
  }
  
  // Reorder interleaved HWC pixels into planar CHW layout ([1, 3, H, W])
  const tensor = new Float32Array(1 * 3 * config.image_size * config.image_size);
  
  for (let c = 0; c < 3; c++) {
    for (let h = 0; h < config.image_size; h++) {
      for (let w = 0; w < config.image_size; w++) {
        const pixelIndex = (h * config.image_size + w) * 3 + c;
        const tensorIndex = c * config.image_size * config.image_size + h * config.image_size + w;
        tensor[tensorIndex] = pixels[pixelIndex];
      }
    }
  }
  
  return tensor;
}

// API handler
export default async function handler(req, res) {
  if (req.method !== 'POST') {
    return res.status(405).json({ error: 'Method not allowed' });
  }
  
  try {
    // Get image from request
    const { image } = req.body; // Base64 encoded image
    const imageBuffer = Buffer.from(image, 'base64');
    
    // Preprocess image
    const inputTensor = await preprocessImage(imageBuffer);
    
    // Load model and run inference
    const session = await loadModel();
    const feeds = { input: new ort.Tensor('float32', inputTensor, [1, 3, config.image_size, config.image_size]) };
    const results = await session.run(feeds);
    
    // Get predictions
    const output = results.output.data;
    
    // Apply softmax to get probabilities
    const maxLogit = Math.max(...output);
    const expLogits = output.map(x => Math.exp(x - maxLogit));
    const sumExp = expLogits.reduce((a, b) => a + b, 0);
    const probabilities = expLogits.map(x => x / sumExp);
    
    // Get predicted class
    const predictedClass = probabilities.indexOf(Math.max(...probabilities));
    const confidence = probabilities[predictedClass];
    
    // Prepare response
    const response = {
      predicted_class: predictedClass,
      predicted_label: config.class_names[predictedClass],
      confidence: confidence,
      probabilities: probabilities.reduce((acc, prob, idx) => {
        acc[config.class_names[idx]] = prob;
        return acc;
      }, {})
    };
    
    res.status(200).json(response);
    
  } catch (error) {
    console.error('Prediction error:', error);
    res.status(500).json({ error: 'Prediction failed', details: error.message });
  }
}
'''

# Save Next.js integration code
with open('./outputs/nextjs_integration.js', 'w') as f:
    f.write(nextjs_integration_code.strip())

print("Next.js integration code saved to: ./outputs/nextjs_integration.js")
print("\nIntegration Steps:")
print("1. Copy the ONNX model and preprocessing config to your Next.js project")
print("2. Install dependencies: npm install onnxruntime-node sharp")
print("3. Create the API route using the provided code")
print("4. Test the API with retinal images")
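
The API route has to reproduce the training-time normalization and tensor layout exactly, so a Python-side reference is handy for comparing outputs. The sketch below assumes the image has already been resized and decoded to RGB. `preprocess_reference` and `softmax_reference` are hypothetical helpers, and the mean and standard deviation are the ImageNet values written to `preprocessing_config.json`.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_reference(rgb_uint8):
    """Convert an (H, W, 3) uint8 RGB image to a normalized (1, 3, H, W) float32 tensor."""
    scaled = rgb_uint8.astype(np.float32) / 255.0
    normalized = (scaled - MEAN) / STD
    # HWC -> CHW, then add the batch dimension
    return normalized.transpose(2, 0, 1)[np.newaxis, ...]

def softmax_reference(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()

dummy_image = np.full((224, 224, 3), 128, dtype=np.uint8)  # synthetic mid-gray image
tensor = preprocess_reference(dummy_image)
probs = softmax_reference(np.array([2.0, 1.0, 0.1]))
```

Feeding the same image through this reference and through the JavaScript `preprocessImage` function, then comparing a few tensor values, is one way to catch channel-order or scaling mismatches early.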

## 📊 Final Model Performance Summary

Let's summarize the final model performance and provide recommendations.

In [None]:
# Create comprehensive model report
def create_model_report():
    """Create a comprehensive model performance report"""
    
    report = {
        'model_info': {
            'architecture': CONFIG['MODEL_NAME'],
            'num_classes': CONFIG['NUM_CLASSES'],
            'input_size': f"{CONFIG['IMAGE_SIZE']}x{CONFIG['IMAGE_SIZE']}",
            'total_parameters': sum(p.numel() for p in model.parameters()),
            'trainable_parameters': model.get_trainable_params()
        },
        'training_info': {
            'total_epochs': len(history['train_loss']) if history['train_loss'] else 0,
            # These names live at notebook scope, so check globals();
            # locals() inside this function never contains them
            'phase1_epochs': phase1_epochs if 'phase1_epochs' in globals() else 0,
            'phase2_epochs': phase2_epochs if 'phase2_epochs' in globals() else 0,
            'batch_size': CONFIG['BATCH_SIZE'],
            'learning_rate': CONFIG['LEARNING_RATE']
        },
        'dataset_info': {
            'total_samples': len(labels_final) if 'labels_final' in globals() else 0,
            'train_samples': len(train_df) if 'train_df' in globals() else 0,
            'val_samples': len(val_df) if 'val_df' in globals() else 0,
            'test_samples': len(test_df) if 'test_df' in globals() else 0,
            # Cast NumPy integers to plain ints so json.dump can serialize them
            'class_distribution': (
                {int(k): int(v) for k, v in labels_final['label_clean'].value_counts().sort_index().items()}
                if 'labels_final' in globals() else {}
            )
        }
    }
    
    # Add performance metrics if available
    if 'test_metrics' in globals() and test_metrics:
        report['performance'] = {
            'test_accuracy': float(test_metrics['accuracy']),
            'test_precision': float(test_metrics['precision']),
            'test_recall': float(test_metrics['recall']),
            'test_f1': float(test_metrics['f1']),
            'test_auc': float(test_metrics.get('auc', 0.0))
        }
    
    # Add training history if available
    if history['train_loss']:
        report['training_history'] = {
            'best_val_loss': min(history['val_loss']),
            'best_val_accuracy': max(history['val_acc']),
            'best_val_f1': max(history['val_f1']),
            'best_val_auc': max(history['val_auc'])
        }
    
    return report

# Generate report
model_report = create_model_report()

# Save report
with open('./outputs/model_report.json', 'w') as f:
    json.dump(model_report, f, indent=2)

print("=" * 60)
print("DIABETIC RETINOPATHY DETECTION MODEL REPORT")
print("=" * 60)

print(f"\n📋 Model Information:")
for key, value in model_report['model_info'].items():
    print(f"  {key.replace('_', ' ').title()}: {value:,}" if isinstance(value, int) else f"  {key.replace('_', ' ').title()}: {value}")

print(f"\n🎯 Training Information:")
for key, value in model_report['training_info'].items():
    print(f"  {key.replace('_', ' ').title()}: {value}")

print(f"\n📊 Dataset Information:")
for key, value in model_report['dataset_info'].items():
    if key != 'class_distribution':
        print(f"  {key.replace('_', ' ').title()}: {value}")

if 'class_distribution' in model_report['dataset_info'] and model_report['dataset_info']['class_distribution']:
    print(f"\n  Class Distribution:")
    for class_id, count in model_report['dataset_info']['class_distribution'].items():
        if class_id < len(class_names):
            print(f"    {class_names[class_id]}: {count}")

if 'performance' in model_report:
    print(f"\n🏆 Test Performance:")
    for key, value in model_report['performance'].items():
        print(f"  {key.replace('_', ' ').title()}: {value:.4f}")

if 'training_history' in model_report:
    print(f"\n📈 Best Training Metrics:")
    for key, value in model_report['training_history'].items():
        print(f"  {key.replace('_', ' ').title()}: {value:.4f}")

print(f"\n💾 Exported Files:")
exported_files = [
    './outputs/diabetic_retinopathy_model.onnx',
    './outputs/preprocessing_config.json',
    './outputs/nextjs_integration.js',
    './outputs/model_report.json'
]

for file_path in exported_files:
    if os.path.exists(file_path):
        file_size = os.path.getsize(file_path) / 1024  # KB
        print(f"  ✅ {file_path} ({file_size:.1f} KB)")
    else:
        print(f"  ❌ {file_path} (not found)")

print(f"\n🚀 Ready for deployment!")
print("Model report saved to: ./outputs/model_report.json")

## 🚀 Deployment Recommendations

### Model Performance Considerations

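1. **Sensitivity vs Specificity**: In screening, a false negative (missed retinopathy) is usually costlier than a false positive (an unnecessary referral), so choose the operating point with that trade-off in mind
2. **Confidence Thresholds**: Route predictions whose top-class probability falls below a chosen threshold to manual review instead of reporting them automatically
3. **Edge Cases**: Detect and reject low-quality inputs (blurred, badly exposed, or non-retinal images) before inference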
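
The confidence-threshold idea above can be sketched in a few lines. `triage_prediction` is a hypothetical helper, the `0.8` cut-off is an arbitrary placeholder that would need to be tuned on validation data, and the class names are illustrative rather than taken from this notebook's `class_names`.

```python
def triage_prediction(probabilities, class_names, min_confidence=0.8):
    """Return the top class and flag the case for manual review when confidence is low."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    confidence = probabilities[best_index]
    return {
        "predicted_label": class_names[best_index],
        "confidence": confidence,
        "needs_review": confidence < min_confidence,
    }

names = ["No DR", "Mild", "Moderate", "Severe", "Proliferative"]  # illustrative labels
confident = triage_prediction([0.92, 0.04, 0.02, 0.01, 0.01], names)
uncertain = triage_prediction([0.40, 0.35, 0.15, 0.05, 0.05], names)
```

Here `confident` is reported directly, while `uncertain` is flagged because its top probability of 0.40 falls below the threshold.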

### Production Deployment

1. **Model Serving**: Use the exported ONNX model with onnxruntime-node
2. **Image Processing**: Ensure consistent preprocessing pipeline
3. **Error Handling**: Implement robust error handling for various image formats
4. **Monitoring**: Track prediction confidence and model performance over time

### Next Steps

1. **Clinical Validation**: Validate model performance with medical professionals
2. **Regulatory Compliance**: Ensure compliance with medical device regulations
3. **Continuous Learning**: Implement feedback loops for model improvement
4. **A/B Testing**: Compare model versions in production

### Files Generated

- `diabetic_retinopathy_model.onnx`: Main model for inference
- `preprocessing_config.json`: Preprocessing parameters
- `nextjs_integration.js`: Next.js API route example
- `model_report.json`: Comprehensive model report
- Various visualization images (training history, confusion matrix, ROC curves, Grad-CAM)

**🎉 Congratulations! Your diabetic retinopathy detection pipeline is complete and ready for integration!**