# Diabetic Retinopathy: Fast-Track Improvement (90%+ Goal)

**Objective**: Maximize accuracy in < 4 hours.
**Key Strategy**: 
1. **Physical Oversampling**: Explicitly duplicate minority classes (Severity 1, 3, 4).
2. **OneCycleLR**: "Super-convergence" scheduler.
3. **Model EMA**: Exponential Moving Average for weight stabilization.
4. **Full Visualization**: confusion matrix (CM), ROC curves, and Grad-CAM included.

**Instructions**:
- Enable **Internet** in the notebook Settings (required to download the pretrained weights).
- Select **GPU T4 x2** or **P100**.

In [None]:
# === IMPORTS & SETUP ===
!pip install -q "numpy<2.0" timm==0.9.16 grad-cam scikit-learn scipy seaborn

import os
import cv2
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score, accuracy_score, confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms

import timm
from timm.utils import ModelEmaV2

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Check Internet
# Pretrained weights are downloaded at model creation, so probe connectivity up front.
import socket
try:
    # The context manager closes the probe socket instead of leaking it.
    with socket.create_connection(("huggingface.co", 443), timeout=5):
        pass
except OSError:
    # OSError covers DNS failures, refused connections and timeouts, while
    # letting KeyboardInterrupt / SystemExit propagate (a bare except would not).
    print("❌ Internet OFF. Please enable it in Settings.")
else:
    print("✓ Internet OK")

# Seed
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU and all CUDA devices).
    """
    # NOTE: the hash seed is read at interpreter start-up, so this only
    # affects processes launched after this point, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)  # Python-level RNG was previously left unseeded
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (multi-GPU sessions).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuned kernel selection can differ between runs; keep it disabled.
    torch.backends.cudnn.benchmark = False
seed_everything(42)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 1. Aggressive Data Balancing & Loader

In [None]:
DATA_DIR = '/kaggle/input/aptos2019-blindness-detection'

# Read the label table and expose the target column under the name 'label'.
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
# Full image path for every row, built from its id_code.
df['path'] = [
    os.path.join(DATA_DIR, 'train_images', code + '.png')
    for code in df['id_code']
]
df = df.rename(columns={'diagnosis': 'label'})

# === OVERSAMPLING STRATEGY ===
# Slice each minority class out once so it can be appended several times.
df_1, df_3, df_4 = (df[df['label'] == severity] for severity in (1, 3, 4))

# Original rows first, followed by the duplicates:
# 4 extra copies of class 3, 3 extra copies of class 4, 2 extra copies of class 1.
df_balanced = pd.concat(
    [df] + [df_3] * 4 + [df_4] * 3 + [df_1] * 2,
    ignore_index=True,
)

print("Balanced Distribution:")
print(df_balanced['label'].value_counts().sort_index())

# Image Loader (Ben Graham)
def load_ben_color(path, sigmaX=10, target_size=256):
    """Read an image, crop away the dark border, resize and contrast-enhance it.

    Args:
        path: Location of the image file on disk.
        sigmaX: Gaussian blur sigma used for the local-average subtraction.
        target_size: Side length (pixels) of the square output.

    Returns:
        An RGB ``uint8`` array of shape ``(target_size, target_size, 3)``.
        If the file cannot be read, an all-zero array of that shape is returned.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # Unreadable file: hand back a black placeholder of the expected shape.
        return np.zeros((target_size, target_size, 3), dtype=np.uint8)

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Crop to the bounding box of all pixels whose gray value exceeds 7.
    foreground = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > 7
    if foreground.any():
        ys = np.flatnonzero(foreground.any(axis=1))
        xs = np.flatnonzero(foreground.any(axis=0))
        rgb = rgb[ys[0]:ys[-1] + 1, xs[0]:xs[-1] + 1]

    resized = cv2.resize(rgb, (target_size, target_size))
    # 4 * image - 4 * blur + 128: subtract the local average and re-centre at mid-gray.
    blurred = cv2.GaussianBlur(resized, (0, 0), sigmaX)
    return cv2.addWeighted(resized, 4, blurred, -4, 128)

class DRDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs from a table of samples.

    Each row of ``df`` must provide a ``'path'`` (image file) and a ``'label'``
    entry; the image is loaded with ``load_ben_color`` on access.
    """

    def __init__(self, df, transform=None):
        """Store the sample table and the optional per-image transform."""
        self.df = df
        self.transform = transform

    def __len__(self):
        """Number of samples in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the preprocessed image and its label as a ``torch.long`` tensor."""
        record = self.df.iloc[idx]
        image = load_ben_color(record['path'])
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(record['label'], dtype=torch.long)
        return image, label

# Per-channel mean / std used to normalise tensors in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _to_normalized_tensor():
    """Build the shared tail of both pipelines: tensor conversion + normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


# Training: random flips, rotation, colour jitter and affine distortion.
train_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
    *_to_normalized_tensor(),
])

# Validation: no augmentation, only conversion and normalisation.
val_tf = transforms.Compose([
    transforms.ToPILImage(),
    *_to_normalized_tensor(),
])

## 2. Model & Training Logic

In [None]:
class CustomSwin(nn.Module):
    """SwinV2 feature extractor followed by a small MLP classification head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its own classifier.
        self.backbone = timm.create_model(
            'swinv2_base_window8_256.ms_in1k',
            pretrained=True,
            num_classes=0,
            img_size=256,
        )
        feature_dim = self.backbone.num_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to class logits."""
        features = self.backbone(x)
        return self.head(features)

class FocalLoss(nn.Module):
    """Focal loss: cross-entropy down-weighted for confidently correct samples.

    Args:
        alpha: Constant factor applied to every sample's loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating term.
    """

    def __init__(self, alpha=0.9, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Per-sample cross-entropy; the reduction is done in forward().
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, target):
        """Return the mean focal loss of logits ``x`` against class indices ``target``."""
        per_sample_ce = self.ce(x, target)
        # Cross-entropy is -log(p_t), so exp(-CE) recovers p_t.
        pt = torch.exp(-per_sample_ce)
        modulator = (1 - pt) ** self.gamma
        return (self.alpha * modulator * per_sample_ce).mean()

def train_one_epoch(model, ema, loader, opt, sched, crit, scaler):
    """Run one mixed-precision training pass over ``loader``.

    Args:
        model: Network being optimised.
        ema: Weight-averaging wrapper; updated from ``model`` after every step.
        loader: Iterable of ``(images, labels)`` batches.
        opt: Optimizer stepped through ``scaler``.
        sched: LR scheduler, stepped once per batch.
        crit: Loss function taking ``(logits, labels)``.
        scaler: Gradient scaler used for the backward pass and optimizer step.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    for images, labels in tqdm(loader, leave=False):
        images = images.to(device)
        labels = labels.to(device)

        opt.zero_grad()
        with autocast():
            logits = model(images)
            batch_loss = crit(logits, labels)

        # Scaled backward, optimizer step, then refresh the scale factor.
        scaler.scale(batch_loss).backward()
        scaler.step(opt)
        scaler.update()

        sched.step()
        ema.update(model)
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` with flip test-time augmentation.

    Logits for each batch and for the same batch flipped along dim 3 are
    averaged before taking the softmax / argmax.

    Returns:
        Tuple ``(accuracy, quadratic-weighted kappa, probabilities, targets)``
        where the last two are NumPy arrays collected over the whole loader.
    """
    model.eval()
    all_probs, all_preds, all_targets = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            logits = (model(images) + model(images.flip(3))) / 2
            all_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            all_preds.extend(logits.argmax(1).cpu().numpy())
            all_targets.extend(labels.numpy())

    accuracy = accuracy_score(all_targets, all_preds)
    kappa = cohen_kappa_score(all_targets, all_preds, weights='quadratic')
    return accuracy, kappa, np.array(all_probs), np.array(all_targets)

In [None]:
# === CONFIG ===
FOLDS = 5
EPOCHS = 15
BATCH = 16
LR = 1e-4

# Extra duplicates appended per class (4 of class 3, 3 of class 4, 2 of class 1).
# They are applied to the TRAINING part of each fold only.
_EXTRA_COPIES = {3: 4, 4: 3, 1: 2}


def _oversample_train_fold(frame, extra_copies=_EXTRA_COPIES):
    """Return ``frame`` followed by ``n`` duplicated copies of each listed class."""
    parts = [frame]
    for cls, n_copies in extra_copies.items():
        parts.extend([frame[frame['label'] == cls]] * n_copies)
    return pd.concat(parts, ignore_index=True)


skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=42)
oof_probs = []
oof_targets = []

print("Starting Improvement Run...\n")

# Split the ORIGINAL table, not the oversampled one: splitting after duplication
# puts copies of the same image into both train and validation, which leaks
# labels and inflates every validation / out-of-fold metric.
for fold, (train_idx, val_idx) in enumerate(skf.split(df['path'], df['label'])):
    print(f"\n{'='*20} FOLD {fold+1}/{FOLDS} {'='*20}")

    train_df = _oversample_train_fold(df.iloc[train_idx])
    val_df = df.iloc[val_idx].reset_index(drop=True)  # each image exactly once
    print(f"Train samples (oversampled): {len(train_df)} | Val samples: {len(val_df)}")

    train_ds = DRDataset(train_df, train_tf)
    val_ds = DRDataset(val_df, val_tf)

    train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=4)

    model = CustomSwin().to(device)
    ema = ModelEmaV2(model, decay=0.999)
    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    # The scheduler is stepped once per batch, hence steps_per_epoch=len(train_loader).
    sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=LR, steps_per_epoch=len(train_loader), epochs=EPOCHS)
    crit = FocalLoss()
    scaler = GradScaler()

    # Start below any reachable accuracy so the first epoch always records
    # predictions and a checkpoint, even if its accuracy is 0.
    best_acc = -1.0
    best_probs = None
    best_targets = None

    for epoch in range(EPOCHS):
        loss = train_one_epoch(model, ema, train_loader, opt, sched, crit, scaler)
        # Validation and checkpointing both use the EMA weights.
        acc, kappa, probs, targets = validate(ema.module, val_loader)

        print(f"Epoch {epoch+1} | Loss: {loss:.4f} | EMA Acc: {acc:.4f} | Kappa: {kappa:.4f}", end="")

        if acc > best_acc:
            best_acc = acc
            best_probs = probs
            best_targets = targets
            torch.save(ema.module.state_dict(), f'fold_{fold+1}_best_ema.pth')
            print(" [Saved]")
        else:
            print()

    # Store OOF predictions (from the best epoch of this fold)
    oof_probs.append(best_probs)
    oof_targets.append(best_targets)

    # Release this fold's objects before the next model is built.
    del model, ema, opt, sched, scaler
    torch.cuda.empty_cache()

## 3. Visualization & Ensembling

In [None]:
# === ENSEMBLING RESULTS ===
# Stack the per-fold predictions and score them together.
all_probs = np.concatenate(oof_probs)
all_targets = np.concatenate(oof_targets)
all_preds = all_probs.argmax(axis=1)

final_acc = accuracy_score(all_targets, all_preds)
final_kappa = cohen_kappa_score(all_targets, all_preds, weights='quadratic')

print("\n".join([
    "\n" + "=" * 40,
    "FINAL ENSEMBLE RESULTS",
    f"Accuracy: {final_acc:.4f}",
    f"Kappa:    {final_kappa:.4f}",
    "=" * 40,
]))

# === CONFUSION MATRIX ===
# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(all_targets, all_preds)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
ax.set_title(f'Confusion Matrix (Kappa: {final_kappa:.4f})', fontsize=14)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
fig.savefig('confusion_matrix_final.png')
plt.show()

In [None]:
# === GRAD-CAM VISUALIZATION ===
# Visualize the LAST fold model for demonstration
model = CustomSwin(num_classes=5).to(device)
# map_location keeps the checkpoint loadable when the current device differs
# from the one it was saved on (e.g. saved on GPU, reloaded on CPU).
model.load_state_dict(torch.load(f'fold_{FOLDS}_best_ema.pth', map_location=device))
model.eval()


def swin_reshape_transform(tensor):
    """Convert Swin activations to the (B, C, H, W) layout Grad-CAM expects.

    Handles channels-last maps (B, H, W, C) and flattened token sequences
    (B, L, C) with a square token grid.
    NOTE(review): the layout emitted by the target layer depends on the timm
    version; confirm against the installed release.
    """
    if tensor.dim() == 3:
        batch, tokens, channels = tensor.shape
        side = int(round(tokens ** 0.5))
        tensor = tensor.reshape(batch, side, side, channels)
    return tensor.permute(0, 3, 1, 2)


target_layers = [model.backbone.layers[-1].blocks[-1].norm1]
# Without reshape_transform the channels-last activations would be read as
# (B, C, H, W) and the resulting heatmaps would be meaningless.
cam = GradCAM(model=model, target_layers=target_layers, reshape_transform=swin_reshape_transform)

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

# Pick some random samples from class 3 and 4 (severe).
# Sample from the original table: the oversampled one holds several copies of
# each of these images, so the same picture could fill more than one panel.
samples = df[df['label'].isin([3, 4])].sample(8)

for idx, (_, row) in enumerate(samples.iterrows()):
    img = load_ben_color(row['path'])
    input_tensor = val_tf(img).unsqueeze(0).to(device)
    # show_cam_on_image is given the image as float32 scaled to [0, 1].
    rgb_img = img.astype(np.float32) / 255.0

    target_category = int(row['label'])
    with torch.no_grad():
        out = model(input_tensor)
        pred = out.argmax(1).item()

    # Explain the TRUE class, so wrong predictions show where evidence was expected.
    grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(target_category)])[0]
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    axes[idx].imshow(visualization)
    axes[idx].set_title(f"True: {target_category} | Pred: {pred}",
                        color='green' if target_category==pred else 'red', fontweight='bold')
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('gradcam_final.png')
plt.show()