# Diabetic Retinopathy Detection with Custom Hybrid Swin Transformer

This notebook implements a state-of-the-art Diabetic Retinopathy detection system using a **Custom Hybrid Architecture**.

### Architecture Highlights:
1.  **Backbone**: Swin Transformer v2 (Pre-trained on ImageNet) for powerful feature extraction.
2.  **Custom Neck**: **Coordinate Attention (CA)** module to enhance feature representation.
3.  **Head**: Custom classification head for 5-class DR grading.

### Optimization:
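1.  **WeightedRandomSampler**: Oversamples minority classes (e.g. Severe/Proliferative) by drawing training images with probability inversely proportional to their class frequency, counteracting class imbalance.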
2.  **Mixup & Cutmix**: Augmentations that blend pairs of images together with their labels, which pushes the model to learn more robust features.
3.  **Test Time Augmentation (TTA)**: Averages the model's predictions for each image and its horizontally flipped copy during validation.



## 1. Setup and Dependencies

In [None]:
# Force numpy<2.0 to avoid compatibility issues with scipy/sklearn
!pip install "numpy<2.0" --upgrade timm torchmetrics grad-cam scipy scikit-learn

import os
import gc
import cv2
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
import timm

# Seed everything for reproducibility
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False


seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Data Preparation
We assume the dataset is attached to the Kaggle Kernel at `/kaggle/input`.

In [None]:
# Define Paths
# Dataset: sovitrath/diabetic-retinopathy-2015-data-colored-resized
DATA_DIR = '/kaggle/input/diabetic-retinopathy-2015-data-colored-resized/colored_images/colored_images'

# Check if path exists
if not os.path.exists(DATA_DIR):
    print(f"Path not found: {DATA_DIR}")
    print("Listing /kaggle/input to help debug:")
    # Only the top level is inspected (hence the break); print the attached
    # dataset folders as well as the root so the correct path can be spotted.
    for root, dirs, files in os.walk('/kaggle/input'):
        print(root)
        for dir_name in sorted(dirs):
            print(os.path.join(root, dir_name))
        break

# Create DataFrame by crawling the folder structure
# The dataset is organized into folders: No_DR, Mild, Moderate, Severe, Proliferate_DR
data = []
# Folder name -> integer DR grade used as the class label.
mapping = {
    'No_DR': 0,
    'Mild': 1,
    'Moderate': 2,
    'Severe': 3,
    'Proliferate_DR': 4
}

print("Scanning dataset folders...")
for class_name, label in mapping.items():
    class_dir = os.path.join(DATA_DIR, class_name)
    if not os.path.isdir(class_dir):
        print(f"WARNING: Directory not found: {class_dir}")
        continue

    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order (and therefore the seeded train/val split) stable across runs.
    for img_name in sorted(os.listdir(class_dir)):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(class_dir, img_name)
            data.append([img_path, label])

# 'id_code' holds the full image path, 'label' the integer class.
df = pd.DataFrame(data, columns=['id_code', 'label'])
print(f"Loaded dataset with {len(df)} images.")

# Split Data
# Stratified 80/20 hold-out; nothing is split when the scan found no images.
if df.empty:
    print("ERROR: No images found. Please check the dataset path.")
else:
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['label'],
        random_state=42,
    )
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}")

## 3. Custom Dataset & Transforms

In [None]:
class RetinopathyDataset(Dataset):
    """Map-style dataset that reads fundus images listed in a DataFrame.

    Args:
        df: DataFrame with an 'id_code' column (image file path) and a
            'label' column (integer class).
        transform: Optional callable applied to the RGB uint8 image array.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        An image that cannot be read is replaced by a black 256x256 frame so
        that iteration keeps going; a warning is emitted so the bad file is
        not silently trained on.
        """
        import warnings  # local import keeps this block self-contained

        row = self.df.iloc[idx]
        img_path = row['id_code']
        label = row['label']

        image = cv2.imread(img_path)
        if image is None:
            warnings.warn(
                f"Could not read image '{img_path}'; using a black placeholder."
            )
            image = np.zeros((256, 256, 3), dtype=np.uint8)
        else:
            # OpenCV loads BGR; the transforms expect RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# Transforms
# Both pipelines resize to 256x256 and normalise with the ImageNet statistics;
# only the training pipeline adds random flips, rotation and colour jitter.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_train_steps = [
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
train_transforms = transforms.Compose(_train_steps)

_val_steps = [
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
val_transforms = transforms.Compose(_val_steps)

# --- ADVANCED OPTIMIZATION: WeightedRandomSampler ---
# Draw training samples with probability inversely proportional to their class
# frequency so that every class is seen roughly equally often.
if len(df) > 0:
    # Count classes on the training split (the rows actually sampled) and keep
    # the counts keyed by label: positional indexing into a sorted array picks
    # the wrong count, or goes out of bounds, when a class has no images.
    class_counts = train_df['label'].value_counts()
    sample_weights = 1.0 / train_df['label'].map(class_counts)
    sample_weights = torch.as_tensor(sample_weights.to_numpy(), dtype=torch.double)

    # Create Sampler
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    print("WeightedRandomSampler initialized!")
else:
    sampler = None

if 'train_df' in locals():
    train_dataset = RetinopathyDataset(train_df, transform=train_transforms)
    val_dataset = RetinopathyDataset(val_df, transform=val_transforms)

    # Settings shared by both loaders. shuffle stays False because the training
    # loader draws its order from the sampler, and num_workers=0 is kept for
    # stability on Kaggle.
    _loader_kwargs = dict(batch_size=32, shuffle=False, num_workers=0, pin_memory=True)

    train_loader = DataLoader(train_dataset, sampler=sampler, **_loader_kwargs)
    val_loader = DataLoader(val_dataset, **_loader_kwargs)

## 4. Custom Hybrid Model (Swin + Coordinate Attention)
This is the core innovation of the project.

In [None]:
class CoordinateAttention(nn.Module):
    """Coordinate Attention gate for (N, C, H, W) feature maps.

    The input is average-pooled along each spatial axis, the two pooled strips
    are encoded together by a shared 1x1 conv + BatchNorm + Hardswish
    bottleneck, and two 1x1 convs turn the result into per-row and per-column
    sigmoid gates that rescale the input.

    Args:
        inp: Number of input (and output) channels.
        reduction: Channel reduction ratio of the bottleneck; the bottleneck
            never goes below 8 channels.
    """

    def __init__(self, inp, reduction=32):
        super().__init__()
        # One descriptor per row (H x 1) and one per column (1 x W).
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        # Shared bottleneck applied to the concatenated row/column strips.
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.Hardswish()

        # Separate projections back to `inp` channels for each axis.
        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` scaled by its row-wise and column-wise attention gates."""
        _, _, height, width = x.shape

        row_strip = self.pool_h(x)                   # (N, C, H, 1)
        col_strip = self.pool_w(x).transpose(2, 3)   # (N, C, W, 1)

        # Encode both strips jointly along the spatial axis.
        encoded = torch.cat((row_strip, col_strip), dim=2)
        encoded = self.act(self.bn1(self.conv1(encoded)))

        row_feat, col_feat = encoded.split([height, width], dim=2)
        col_feat = col_feat.transpose(2, 3)          # back to (N, mip, 1, W)

        gate_h = torch.sigmoid(self.conv_h(row_feat))
        gate_w = torch.sigmoid(self.conv_w(col_feat))

        return x * gate_h * gate_w

class SwinTransformerCA(nn.Module):
    """Swin Transformer v2 backbone + Coordinate Attention neck + MLP head.

    Args:
        num_classes: Number of output logits.
        pretrained: Forwarded to ``timm.create_model`` for the backbone.
    """

    def __init__(self, num_classes=5, pretrained=True):
        super().__init__()
        # Backbone is created headless (num_classes=0); if timm does not know
        # the model name, print the available Swin V2 variants and re-raise.
        try:
            self.backbone = timm.create_model(
                'swinv2_tiny_window8_256.ms_in1k',
                pretrained=pretrained,
                num_classes=0,
            )
        except RuntimeError:
            print("Model not found. Listing available Swin V2 models:")
            print(timm.list_models('*swin*v2*'))
            raise

        num_features = self.backbone.num_features

        self.ca = CoordinateAttention(num_features)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # forward_features is expected to yield channels-last (B, H, W, C)
        # maps; move channels first for the convolutional attention block.
        features = self.backbone.forward_features(x).permute(0, 3, 1, 2)

        attended = self.ca(features)

        # Global average pool to (B, C), then classify.
        pooled = torch.flatten(self.avg_pool(attended), 1)
        return self.head(pooled)

# Build the 5-class network and place it on the selected device.
model = SwinTransformerCA(num_classes=5).to(device)
print("Model initialized with Coordinate Attention!")

## 5. Training Loop with Mixup & Cutmix

In [None]:
# --- ADVANCED OPTIMIZATION: Mixup & Cutmix ---
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy

# Configure Mixup
mixup_args = dict(
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    cutmix_minmax=None,
    prob=1.0,             # every batch is mixed
    switch_prob=0.5,      # chance of choosing cutmix over mixup
    mode='batch',
    label_smoothing=0.1,
    num_classes=5,
)
mixup_fn = Mixup(**mixup_args)
print("Mixup & Cutmix Augmentation Enabled!")

# Training targets are soft (mixed) label vectors, so training uses the
# soft-target loss; validation labels are class indices and use standard CE.
criterion = SoftTargetCrossEntropy()
val_criterion = nn.CrossEntropyLoss()

# AdamW with a relatively high weight decay for regularisation, a cosine
# learning-rate schedule, and a gradient scaler for mixed-precision training.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
scaler = GradScaler()

def train_one_epoch(model, loader, optimizer, criterion, scaler, mixup_fn):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    Args:
        model: Network to optimise; switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        mixup_fn: Optional callable ``(images, labels) -> (images, soft_labels)``;
            pass ``None`` to train on the labels as given.

    Returns:
        Tuple ``(mean batch loss, approximate accuracy)``. With mixed labels
        the accuracy compares predictions against the dominant class of each
        soft label, so it is only a rough indicator. Returns ``(0.0, 0.0)``
        if no batch was processed.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(loader, desc="Training")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Apply Mixup
        if mixup_fn is not None:
            # timm's Mixup asserts an even batch size, so drop the last sample
            # of an odd (typically final) batch rather than crashing.
            if images.size(0) % 2 == 1:
                images, labels = images[:-1], labels[:-1]
            if images.size(0) == 0:
                continue
            images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()

        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        num_batches += 1

        # Soft (2-D) labels are reduced to their dominant class for logging;
        # plain class-index labels are used as they are.
        predicted = outputs.argmax(dim=1)
        hard_labels = labels.argmax(dim=1) if labels.dim() > 1 else labels
        total += labels.size(0)
        correct += predicted.eq(hard_labels).sum().item()

        # loss.item() is already a batch mean, so average over batches.
        pbar.set_postfix({'loss': running_loss / num_batches})

    if num_batches == 0:
        return 0.0, 0.0
    return running_loss / num_batches, correct / total

# --- ADVANCED OPTIMIZATION: Test Time Augmentation (TTA) ---
# TTA averages predictions across multiple augmented versions of the image
def validate_tta(model, loader, criterion, tta_steps=5):
    """Evaluate ``model`` with horizontal-flip test-time augmentation.

    Each batch is run through the model twice, as-is and mirrored along the
    width axis, and the two sets of logits are averaged before computing the
    loss and the predictions.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches of tensors.
        criterion: Loss called as ``criterion(outputs, labels)``.
        tta_steps: Accepted for backward compatibility only; exactly two
            views (original + horizontal flip) are always used.

    Returns:
        Tuple ``(mean batch loss, accuracy)``; ``(0.0, 0.0)`` for an empty
        loader.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation (TTA)"):
            images, labels = images.to(device), labels.to(device)

            # Standard prediction
            outputs = model(images)

            # The loader already yields tensors, so flip them directly
            # along the last (width) dimension.
            outputs_flipped = model(torch.flip(images, [3]))

            # Average predictions
            outputs = (outputs + outputs_flipped) / 2.0

            loss = criterion(outputs, labels)

            running_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return running_loss / num_batches, correct / total

# Training
epochs = 50  # long schedule chosen so Mixup training has time to converge
best_acc = 0.0

if 'train_loader' not in locals():
    print("Data loaders not defined. Skipping training loop.")
else:
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, scaler, mixup_fn
        )
        # Validation goes through the TTA routine.
        val_loss, val_acc = validate_tta(model, val_loader, val_criterion)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Train Acc (Approx): {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc (TTA): {val_acc:.4f}")

        # Checkpoint whenever validation accuracy improves.
        improved = val_acc > best_acc
        if improved:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print("Saved Best Model!")
        print("-" * 30)

## 6. Evaluation & Visualization

In [None]:
# Load Best Model
if os.path.exists('best_model.pth'):
    # map_location keeps the checkpoint loadable when it was saved from a GPU
    # session but this session is running on a different device (e.g. CPU).
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()

    # Inference on Validation Set (single pass, no flip averaging)
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print(classification_report(all_labels, all_preds))
else:
    print("No model weights found. Train the model first.")

# --- GRAD-CAM VISUALIZATION ---
# Requires the grad-cam package (installed by the pip command at the top of
# the notebook). Any failure is reported and does not stop the notebook.

try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image

    print("Generating Grad-CAM Visualizations...")

    # Target the Coordinate Attention module to see its effect.
    # If that fails, fall back to the backbone norm: [model.backbone.norm]
    target_layers = [model.ca]

    cam = GradCAM(model=model, target_layers=target_layers)

    model.eval()

    # Get a batch of validation images
    images, labels = next(iter(val_loader))
    images = images.to(device)

    # Show up to 4 examples; a smaller batch must not raise IndexError.
    num_examples = min(4, images.size(0))

    # squeeze=False keeps `axes` 2-D even when only one row is drawn.
    fig, axes = plt.subplots(num_examples, 2, figsize=(10, 5 * num_examples), squeeze=False)

    for i in range(num_examples):
        input_tensor = images[i].unsqueeze(0)
        label = labels[i].item()

        # Generate CAM for the ground-truth class
        grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(label)])
        grayscale_cam = grayscale_cam[0, :]

        # Prepare image for visualization: CHW -> HWC, min-max scaled to 0-1
        img = input_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
        value_range = img.max() - img.min()
        if value_range > 0:
            img = (img - img.min()) / value_range
        else:
            # A constant image (e.g. a blank frame) has no range to scale by.
            img = np.zeros_like(img)

        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)

        # Plot Original
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"Original (Label: {label})")
        axes[i, 0].axis('off')

        # Plot Grad-CAM
        axes[i, 1].imshow(visualization)
        axes[i, 1].set_title("Grad-CAM (Focus)")
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

except ImportError:
    print("grad-cam library not installed. Please run: !pip install grad-cam")
except Exception as e:
    print(f"Grad-CAM Error: {e}")
    print("Try changing target_layers to [model.backbone.norm]")