# Diabetic Retinopathy Detection: SOTA Transformer Edition (v2)

**Goal**: >90% Accuracy & High Precision across ALL classes.

### The "SOTA" Strategy:
1.  **Model**: **Swin Transformer V2 Small** (Larger capacity than Tiny).
2.  **Preprocessing**: **Ben Graham's Method**. This is the "Secret Sauce" used by Kaggle winners: it crops the black border and subtracts a Gaussian-blurred copy of the image, which removes lighting variation and standardizes the appearance of each retina.
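3.  **Loss Function**: **Focal Loss**. It down-weights easy, well-classified examples so the model focuses on "hard" ones (like Mild DR). Note: because every training batch is mixed with Mixup/CutMix (soft labels), the training loss is soft-target cross-entropy; focal loss is for hard-label (unmixed) batches.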
4.  **Optimization**: WeightedRandomSampler + Mixup + CutMix + TTA (test-time augmentation).

### Instructions:
1.  **Add Data**: Search for `diabetic-retinopathy-2015-data-colored-resized`.
2.  **Accelerator**: GPU P100 or T4 x2.

In [None]:
!pip install "numpy<2.0" --upgrade timm torchmetrics grad-cam scipy scikit-learn

import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
import timm

def seed_everything(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), and switches cuDNN to deterministic, non-autotuned kernels.

    Args:
        seed: value applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import it

    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one (the notebook targets "T4 x2").
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels run-to-run; keep it off.
    torch.backends.cudnn.benchmark = False

seed_everything()  # seed all RNGs with the default seed before anything random happens
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 1. Ben Graham's Preprocessing (The Secret Weapon)
This function crops the black border, resizes the image to 256×256, and subtracts a Gaussian-blurred copy of the image (Ben Graham's method) to normalize lighting.

In [None]:
def _mask_bbox(mask):
    """Return the tightest ``(row_slice, col_slice)`` box around all True pixels of a 2-D mask.

    Stop indices are exclusive, so the last foreground row/column is kept.
    Returns None when the mask has no True pixel.
    """
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0:
        return None
    return (slice(int(rows[0]), int(rows[-1]) + 1),
            slice(int(cols[0]), int(cols[-1]) + 1))


def load_ben_color(path, sigmaX=10, size=256, threshold=7):
    """Load a fundus image, crop its dark border, resize it and apply Ben Graham's filter.

    Args:
        path: image file path passed to ``cv2.imread``.
        sigmaX: sigma of the Gaussian blur subtracted by Ben Graham's method.
        size: side length of the square output (default 256, the input size
            of the ``swinv2_*_256`` backbone used below).
        threshold: grayscale intensity above which a pixel counts as retina
            when locating the black border.

    Returns:
        A ``(size, size, 3)`` uint8 RGB array. An unreadable file yields an
        all-zero image; an image with no pixel above ``threshold`` is only
        resized (the Ben Graham filter is skipped).
    """
    image = cv2.imread(path)
    if image is None:
        # Best-effort: keep the DataLoader running on a missing/corrupt file.
        return np.zeros((size, size, 3), dtype=np.uint8)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Crop the black border using a grayscale foreground mask.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    box = _mask_bbox(gray > threshold)
    if box is None:
        return cv2.resize(image, (size, size))
    image = image[box]

    image = cv2.resize(image, (size, size))

    # Ben Graham's method: 4*img - 4*blur(img) + 128 removes low-frequency
    # lighting variation; flat regions map to mid-grey (128).
    return cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)

In [None]:
# Load Data
DATA_DIR = '/kaggle/input/diabetic-retinopathy-2015-data-colored-resized/colored_images/colored_images'
data = []
mapping = {'No_DR': 0, 'Mild': 1, 'Moderate': 2, 'Severe': 3, 'Proliferate_DR': 4}


def _inverse_frequency_weights(labels):
    """Return one weight per entry of ``labels``: 1 / (number of occurrences of that label).

    Counts are looked up by label value, not by array position, so the result
    stays correct when some class is absent from ``labels``.
    """
    return labels.map(1.0 / labels.value_counts()).tolist()


if os.path.exists(DATA_DIR):
    for class_name, label in mapping.items():
        class_dir = os.path.join(DATA_DIR, class_name)
        if os.path.exists(class_dir):
            # os.listdir order is filesystem-dependent; sorting fixes the row
            # order so the seeded train/val split below is reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    data.append([os.path.join(class_dir, img_name), label])

df = pd.DataFrame(data, columns=['id_code', 'label'])
print(f"Loaded {len(df)} images.")

if len(df) > 0:
    train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)

    # Weighted Sampler: weight each training sample by the inverse of its
    # class frequency in train_df so every class is drawn equally often in
    # expectation.
    class_counts = df['label'].value_counts().sort_index().values  # retained as a notebook global
    sample_weights = _inverse_frequency_weights(train_df['label'])
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
else:
    print("Dataset not found!")

In [None]:
class RetinopathyDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a dataframe.

    Each row must provide an image path in ``id_code`` and an integer class in
    ``label``. Images are loaded through ``load_ben_color`` (border crop +
    Ben Graham filter) and then passed to ``transform`` when one is set.
    Labels are returned as ``torch.long`` scalars.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup, so a non-contiguous index (e.g. after a split) is fine.
        record = self.df.iloc[idx]
        target = torch.tensor(record['label'], dtype=torch.long)
        picture = load_ben_color(record['id_code'])
        if not self.transform:
            return picture, target
        return self.transform(picture), target

# ImageNet statistics: the Swin backbone is pretrained on ImageNet-1k, so its
# inputs are normalized the same way even after Ben Graham preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Ben Graham's filter (4*img - 4*blur + 128) maps flat regions, including the
# black background, to ~128 grey. Rotations therefore fill exposed corners with
# that grey instead of black, avoiding train-only artifacts.
BEN_BACKGROUND = 128

train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(360, fill=BEN_BACKGROUND),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Build datasets/loaders only when the data cell produced a split.
if 'train_df' in locals():
    train_dataset = RetinopathyDataset(train_df, transform=train_transforms)
    val_dataset = RetinopathyDataset(val_df, transform=val_transforms)

    # drop_last=True: timm's batch-mode Mixup needs an even batch size, and the
    # classifier head's BatchNorm1d cannot train on a single sample, so a ragged
    # final training batch would crash the epoch. The sampler defines the order,
    # so shuffle must stay False (DataLoader rejects shuffle=True with a sampler).
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, sampler=sampler,
                              num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, pin_memory=True)

## 2. Focal Loss
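With standard cross-entropy, the many easy examples can dominate the loss. Focal loss scales each example's cross-entropy by $(1 - p_t)^\gamma$, where $p_t$ is the predicted probability of the true class, so well-classified examples are down-weighted and training focuses on hard ones.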

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - p_t) ** gamma * CE``.

    ``p_t`` is the softmax probability of the true class, recovered from the
    per-example cross-entropy as ``exp(-CE)``. With ``alpha=1, gamma=0`` this
    equals plain cross-entropy; a larger ``gamma`` shrinks the loss of
    confidently-correct examples.

    Args:
        alpha: scalar multiplier applied to every example's loss.
        gamma: focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-example losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        per_example_ce = self.ce(inputs, targets)
        p_true = torch.exp(-per_example_ce)
        losses = self.alpha * (1 - p_true) ** self.gamma * per_example_ce
        if self.reduction == 'sum':
            return losses.sum()
        if self.reduction == 'mean':
            return losses.mean()
        return losses

## 3. Swin Transformer V2 (Small)
Upgrading from Tiny to Small for better feature extraction.

In [None]:
class SwinTransformerSOTA(nn.Module):
    """Swin Transformer V2 Small backbone followed by a dropout/BatchNorm MLP head.

    Args:
        num_classes: number of output logits.
        pretrained: load timm's pretrained weights for the backbone
            (the ``ms_in1k`` tag).

    The backbone is created with ``num_classes=0`` so it returns pooled
    features, which the head maps to ``num_classes`` logits. Input resolution
    is assumed to be 256x256 (the ``_256`` variant; the loader resizes to 256);
    confirm before changing the image size.
    """

    def __init__(self, num_classes=5, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            'swinv2_small_window8_256.ms_in1k',
            pretrained=pretrained,
            num_classes=0,
        )
        feat_dim = self.backbone.num_features
        hidden = 1024
        # Layer order defines the state_dict keys (head.0 ... head.5); keep it
        # stable so saved checkpoints such as best_model_sota.pth still load.
        layers = [
            nn.Dropout(0.5),
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.Hardswish(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

# Build the 5-class DR grader and move its parameters to the selected device
# (Module.to returns the module itself).
model = SwinTransformerSOTA(num_classes=5).to(device)

In [None]:
# Training Setup
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy

# Mixed (soft) targets need a soft-target loss; the focal loss expects hard
# labels and is kept for clean, unmixed batches.
criterion_train = SoftTargetCrossEntropy()
criterion_val = FocalLoss(gamma=2)

# Batch-level Mixup/CutMix on every batch (prob=1.0): CutMix is chosen over
# Mixup with probability 0.5, and targets are label-smoothed (0.1) one-hots.
mixup_fn = Mixup(
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=5,
)

# Low LR for the larger Swin-Small model, cosine-decayed over 40 epochs
# (T_max should match `epochs` in the training loop).
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40, eta_min=1e-7)
scaler = GradScaler()  # loss scaling for mixed-precision (autocast) training

def train_one_epoch(model, loader):
    """Train ``model`` for one pass over ``loader`` with Mixup/CutMix and mixed precision.

    Relies on the notebook globals ``device``, ``mixup_fn``, ``optimizer``,
    ``criterion_train`` and ``scaler``.

    Args:
        model: network to optimize (switched to train mode).
        loader: iterable of ``(images, labels)`` batches with integer labels.

    Returns:
        Mean training loss over the batches actually used (0.0 if none were).
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    pbar = tqdm(loader, desc="Training")
    for images, labels in pbar:
        # timm's batch-mode Mixup requires an even batch size: trim a ragged
        # final batch. A batch of one is skipped entirely, which also keeps the
        # head's BatchNorm1d from failing on a single sample.
        if images.size(0) % 2:
            images, labels = images[:-1], labels[:-1]
        if images.size(0) == 0:
            continue
        images, labels = images.to(device), labels.to(device)
        images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = criterion_train(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        n_batches += 1
        # Show the running mean so far, not the partial sum over the full epoch length.
        pbar.set_postfix({'loss': running_loss / n_batches})
    return running_loss / n_batches if n_batches else 0.0

def validate_tta(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` using flip test-time augmentation.

    Each batch is scored as-is and mirrored along dim 3 (image width); the two
    logit tensors are averaged before taking the arg-max. Uses the
    notebook-global ``device``. An empty loader raises ZeroDivisionError.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images = images.to(device)
            labels = labels.to(device)
            logits = (model(images) + model(torch.flip(images, [3]))) / 2.0
            predicted = logits.max(1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return n_correct / n_seen

# Run Training
# Each epoch runs one Mixup/AMP training pass, flip-TTA validation, then one
# cosine LR step. Weights with the highest validation accuracy so far are
# written to best_model_sota.pth.
epochs = 40  # should match the scheduler's T_max
best_acc = 0.0

# Only train when the data cell produced loaders (train_loader is a notebook global).
if 'train_loader' in locals():
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        loss = train_one_epoch(model, train_loader)
        acc = validate_tta(model, val_loader)
        # CosineAnnealingLR is stepped once per epoch, after validation.
        scheduler.step()
        print(f"Loss: {loss:.4f} | Val Acc: {acc:.4f}")
        
        # Strict '>' keeps the earliest epoch when accuracies tie.
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_model_sota.pth')
            print("Saved Best Model!")
else:
    print("No data loaded.")