In [None]:
# 1 - Install required libraries
import sys
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install timm matplotlib scikit-learn xgboost optuna

In [None]:
# 2 - Import libraries and configure CUDA
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import os, copy, time, gc
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder, StandardScaler
from imblearn.over_sampling import SMOTE
import xgboost as xgb
import optuna
import pickle
from collections import Counter

# Configure CUDA
import random

# cudnn.benchmark autotunes convolution algorithms for fixed input shapes; TF32 trades a little
# matmul/conv precision for speed on GPUs that support it.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous. It is a debugging aid for
# locating the op behind an asynchronous CUDA error and slows training considerably, so it is
# opt-in instead of always on. It must be set before CUDA is initialised to take effect.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Allocator options are read when CUDA is first initialised, i.e. before the device query below.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'

# Set random seeds for every RNG the notebook draws from: torch (weight init, randperm),
# NumPy (MixUp / CutMix sample np.random) and Python's `random` module.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU')

In [None]:
# 3 - Verify dataset structure
# Expected layout: DATA_ROOT/RGB/<class>/... and DATA_ROOT/Vessel/<class>/...
DATA_ROOT = '/workspace/VDMDR'
if all(os.path.exists(os.path.join(DATA_ROOT, sub)) for sub in ('RGB', 'Vessel')):
    print("Dataset structure OK")
    print("RGB classes:", os.listdir(os.path.join(DATA_ROOT, 'RGB')))
    print("Vessel classes:", os.listdir(os.path.join(DATA_ROOT, 'Vessel')))
else:
    print("Dataset structure NOT found")

In [None]:
# 4 - SMOTE-based dataset balancing using ResNet features
class SMOTEImageDataset:
    """Balance a paired ``(rgb, vessel, label)`` dataset with SMOTE in ResNet18 feature space.

    ResNet18 is fetched through ``torch.hub`` (network access or a populated hub cache is
    required). SMOTE oversamples the feature vectors. Each synthetic vector is then turned into
    an image pair by randomly augmenting the nearest original sample *of the same class*.
    """

    def __init__(self, dataset, target_size=(224, 224)):
        self.dataset = dataset
        self.target_size = target_size

    @staticmethod
    def _closest_original(original_features, original_labels, feature, label):
        """Return the index of the same-class original whose feature is L2-nearest to ``feature``.

        Restricting the search to ``label`` keeps a synthetic sample from inheriting an image of
        another class. Falls back to all originals if no original carries ``label``.
        """
        candidates = np.flatnonzero(original_labels == label)
        if candidates.size == 0:
            candidates = np.arange(len(original_features))
        distances = np.linalg.norm(original_features[candidates] - feature, axis=1)
        return int(candidates[np.argmin(distances)])

    def apply_smote(self, random_state=42, k_neighbors=5):
        """Extract ResNet18 features for every RGB image and oversample them with SMOTE.

        Returns:
            ``(X_res, y_res)`` as produced by ``SMOTE.fit_resample``.
        """
        features, labels = [], []
        feature_extractor = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        feature_extractor.fc = nn.Identity()  # keep pooled features, drop the classifier
        feature_extractor.eval()
        transform = transforms.Compose([
            transforms.Resize(self.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        to_pil = transforms.ToPILImage()
        with torch.no_grad():
            for rgb, _, label in self.dataset:
                if isinstance(rgb, torch.Tensor):
                    rgb = to_pil(rgb)
                img_tensor = transform(rgb).unsqueeze(0)
                features.append(feature_extractor(img_tensor).squeeze().numpy())
                labels.append(label)
        features = np.array(features)
        labels = np.array(labels)
        print(f"Original distribution: {Counter(labels)}")
        smote = SMOTE(random_state=random_state, k_neighbors=k_neighbors)
        X_res, y_res = smote.fit_resample(features, labels)
        print(f"Balanced distribution: {Counter(y_res)}")
        return X_res, y_res

    def create_synthetic_images(self, features_balanced, labels_balanced):
        """Materialise SMOTE output as a list of ``(rgb, vessel, label)`` tuples.

        The first ``len(self.dataset)`` entries are the original items. This relies on SMOTE
        returning the original samples first, in input order. Every later entry is a random
        augmentation of the nearest original of the same class.
        """
        n_original = len(self.dataset)
        original_features = np.asarray(features_balanced[:n_original])
        original_labels = np.asarray(labels_balanced[:n_original])
        to_pil = transforms.ToPILImage()
        # Built once; the random parameters are still re-drawn on every call.
        rgb_aug = transforms.Compose([
            transforms.RandomRotation(15),
            transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(0, translate=(0.1, 0.1))
        ])
        vessel_aug = transforms.Compose([
            transforms.RandomRotation(15),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(0, translate=(0.1, 0.1))
        ])
        synthetic_dataset = []
        for i, (feature, label) in enumerate(zip(features_balanced, labels_balanced)):
            if i < n_original:
                synthetic_dataset.append(self.dataset[i])
                continue
            closest_idx = self._closest_original(original_features, original_labels, feature, label)
            rgb, vessel, _ = self.dataset[closest_idx]
            if isinstance(rgb, torch.Tensor):
                rgb = to_pil(rgb)
            if isinstance(vessel, torch.Tensor):
                vessel = to_pil(vessel)
            synthetic_dataset.append((rgb_aug(rgb), vessel_aug(vessel), int(label)))
        return synthetic_dataset

In [None]:
# 4 - Data Loading & Preprocessing (with separate transforms for RGB and Vessel)
DATA_ROOT = '/workspace/VDMDR'

class VDMDRDataset(Dataset):
    """Paired RGB image / vessel map dataset indexed from two parallel folder trees.

    Expects ``rgb_dir/<class>/<file>`` and ``vessel_dir/<class>/<file>`` with identical class
    folder and file names. Class folder names must be integers and are used directly as labels.
    Other directories (e.g. Jupyter's ``.ipynb_checkpoints``) are skipped with a message.
    Only files present in both trees are indexed.

    Items are ``(rgb PIL image in 'RGB' mode, vessel PIL image in 'L' mode, int label)``.
    No transforms are applied here.
    """

    def __init__(self, rgb_dir, vessel_dir):
        self.rgb_paths = []
        self.vessel_paths = []
        self.labels = []
        for class_folder in sorted(os.listdir(rgb_dir)):
            rgb_class_dir = os.path.join(rgb_dir, class_folder)
            if not os.path.isdir(rgb_class_dir):
                continue
            try:
                label = int(class_folder)
            except ValueError:
                print(f"Skipping non-class folder: {rgb_class_dir}")
                continue
            vessel_class_dir = os.path.join(vessel_dir, class_folder)
            for fname in sorted(os.listdir(rgb_class_dir)):
                rgb_path = os.path.join(rgb_class_dir, fname)
                vessel_path = os.path.join(vessel_class_dir, fname)
                if os.path.exists(rgb_path) and os.path.exists(vessel_path):
                    self.rgb_paths.append(rgb_path)
                    self.vessel_paths.append(vessel_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.rgb_paths)

    def __getitem__(self, idx):
        # convert() loads the pixel data, so the files can be closed deterministically here.
        with Image.open(self.rgb_paths[idx]) as rgb_file:
            rgb = rgb_file.convert('RGB')
        with Image.open(self.vessel_paths[idx]) as vessel_file:
            vessel = vessel_file.convert('L')
        return rgb, vessel, self.labels[idx]

# Augmentation pipelines. RGB images (3-channel) use ImageNet mean/std; vessel maps (1-channel)
# are normalised with mean=std=0.5. Training pipelines resize to 256, take a random 224 crop
# and add geometric (and, for RGB, colour) jitter. Validation pipelines only resize to 224
# and normalise.
# NOTE(review): the RGB and vessel pipelines draw their random parameters independently, so a
# paired image and vessel map usually receive different crops/flips/rotations.
rgb_train_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # RandomErasing operates on tensors, hence its place after ToTensor/Normalize.
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
])

rgb_val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

vessel_train_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=45),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # single grayscale channel
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
])
vessel_val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # single grayscale channel
])

# Paths
RGB_DIR = os.path.join(DATA_ROOT, 'RGB')
VESSEL_DIR = os.path.join(DATA_ROOT, 'Vessel')

# Load full dataset (un-transformed PIL pairs; split-specific transforms are attached below)
full_dataset = VDMDRDataset(RGB_DIR, VESSEL_DIR)
if len(full_dataset) == 0:
    raise RuntimeError(f"No paired RGB/Vessel images found under {DATA_ROOT}")

# Stratified K-Fold Cross-Validation Split
from sklearn.model_selection import StratifiedKFold
import numpy as np

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
# StratifiedKFold only needs X for its length, so an index array stands in for the images.
folds = list(skf.split(np.arange(len(full_dataset.labels)), full_dataset.labels))
fold_num = 0  # which of the n_splits folds is held out for validation
train_idx, val_idx = folds[fold_num]
print(f"Fold {fold_num}: {len(train_idx)} train, {len(val_idx)} val")
print(f"Number of classes: {len(set(full_dataset.labels))}")

# Subset wrapper with separate transforms for RGB and Vessel
class SubsetWithTransform(Dataset):
    """View of ``dataset`` restricted to ``indices`` with per-modality transforms.

    Base items are ``(rgb, vessel, label)`` triples. ``rgb_transform`` and ``vessel_transform``
    are applied only when truthy; the label is passed through unchanged.
    """

    def __init__(self, dataset, indices, rgb_transform, vessel_transform):
        self.dataset = dataset
        self.indices = indices
        self.rgb_transform = rgb_transform
        self.vessel_transform = vessel_transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        rgb_img, vessel_img, label = self.dataset[self.indices[idx]]
        rgb_out = self.rgb_transform(rgb_img) if self.rgb_transform else rgb_img
        vessel_out = self.vessel_transform(vessel_img) if self.vessel_transform else vessel_img
        return rgb_out, vessel_out, label

train_dataset = SubsetWithTransform(full_dataset, train_idx, rgb_train_transform, vessel_train_transform)
val_dataset = SubsetWithTransform(full_dataset, val_idx, rgb_val_transform, vessel_val_transform)


BATCH_SIZE = 32
# drop_last=True on the training loader: a trailing batch of a single sample would make the
# model's BatchNorm1d layers raise in train mode ("Expected more than 1 value per channel").
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
print(f"Number of classes: {len(set(full_dataset.labels))}")

In [None]:
# 3.5. Check Class Distribution and Add SMOTE Integration with Local Models
import torch.utils.data as data
import torchvision.models as models

# Check class distribution (per-label sample counts of the full, un-split dataset)
class_counts = Counter(full_dataset.labels)
print("Original class distribution:")
for class_id in sorted(class_counts):
    print(f"Class {class_id}: {class_counts[class_id]} samples")

# Modified SMOTE processor with local model loading
class SMOTEImageDatasetLocal:
    """Balance a paired ``(rgb, vessel, label)`` dataset with SMOTE in ResNet18 feature space.

    SMOTE runs on penultimate-layer ResNet18 features of the RGB images (``fc`` replaced by
    ``Identity``). Each synthetic feature vector is then turned into an image pair by taking
    the nearest original sample *of the same class* and applying a random augmentation.
    """

    def __init__(self, dataset, target_size=(224, 224)):
        self.dataset = dataset
        self.target_size = target_size

    @staticmethod
    def _effective_k_neighbors(labels, k_neighbors):
        """Return a ``k_neighbors`` value SMOTE can use for ``labels``.

        With the default sampling strategy SMOTE oversamples every class smaller than the
        largest one, and needs more than ``k_neighbors`` samples in each of them. The value is
        clamped to one less than the smallest such class.

        Raises:
            ValueError: if ``labels`` is empty or a minority class has fewer than 2 samples.
        """
        counts = Counter(int(label) for label in labels)
        if not counts:
            raise ValueError("Cannot apply SMOTE to an empty dataset")
        largest = max(counts.values())
        minority_sizes = [count for count in counts.values() if count < largest]
        if not minority_sizes:
            return k_neighbors  # already balanced: SMOTE generates nothing
        smallest = min(minority_sizes)
        if smallest < 2:
            raise ValueError(f"SMOTE needs at least 2 samples per minority class; smallest has {smallest}")
        return min(k_neighbors, smallest - 1)

    @staticmethod
    def _closest_original(original_features, original_labels, feature, label):
        """Return the index of the same-class original whose feature is L2-nearest to ``feature``.

        Restricting the search to ``label`` keeps a synthetic sample from inheriting an image of
        another class. Falls back to all originals if no original carries ``label``.
        """
        candidates = np.flatnonzero(original_labels == label)
        if candidates.size == 0:
            candidates = np.arange(len(original_features))
        distances = np.linalg.norm(original_features[candidates] - feature, axis=1)
        return int(candidates[np.argmin(distances)])

    def apply_smote(self, random_state=42, k_neighbors=5):
        """Apply SMOTE to ResNet18 features of every RGB image in ``self.dataset``.

        Args:
            random_state: seed passed to SMOTE.
            k_neighbors: requested SMOTE neighbourhood size; reduced automatically if the
                smallest minority class is too small for it.

        Returns:
            ``(features_balanced, labels_balanced)`` from ``SMOTE.fit_resample``.
        """
        from imblearn.over_sampling import SMOTE

        print("Extracting features for SMOTE...")

        # NOTE: torchvision's `pretrained=True` downloads the ImageNet weights unless they are
        # already in the local torch cache; it is not offline by itself.
        print("Loading ResNet18 model locally...")
        feature_extractor = models.resnet18(pretrained=True)
        feature_extractor.fc = torch.nn.Identity()
        feature_extractor.eval()

        transform = transforms.Compose([
            transforms.Resize(self.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        to_pil = transforms.ToPILImage()

        features = []
        labels = []
        with torch.no_grad():
            for i in range(len(self.dataset)):
                rgb, _vessel, label = self.dataset[i]
                if isinstance(rgb, torch.Tensor):
                    rgb = to_pil(rgb)
                img_tensor = transform(rgb).unsqueeze(0)
                features.append(feature_extractor(img_tensor).squeeze().numpy())
                labels.append(label)
                if (i + 1) % 100 == 0:
                    print(f"Processed {i + 1}/{len(self.dataset)} samples")

        features = np.array(features)
        labels = np.array(labels)
        print(f"Original distribution: {Counter(labels)}")

        k = self._effective_k_neighbors(labels, k_neighbors)
        if k != k_neighbors:
            print(f"Reducing SMOTE k_neighbors from {k_neighbors} to {k} (smallest minority class)")
        smote = SMOTE(random_state=random_state, k_neighbors=k)
        features_balanced, labels_balanced = smote.fit_resample(features, labels)

        print(f"Balanced distribution: {Counter(labels_balanced)}")
        return features_balanced, labels_balanced

    def create_synthetic_images(self, features_balanced, labels_balanced):
        """Materialise SMOTE output as a list of ``(rgb, vessel, label)`` tuples.

        The first ``len(self.dataset)`` entries are the original items. This relies on SMOTE
        returning the original samples first, in input order. Each later entry copies the
        nearest same-class original and applies a random geometric (and, for RGB, colour)
        augmentation.
        """
        n_original = len(self.dataset)
        original_features = np.asarray(features_balanced[:n_original])
        original_labels = np.asarray(labels_balanced[:n_original])
        to_pil = transforms.ToPILImage()

        # Built once; their random parameters are still re-drawn on every call.
        aug_transform = transforms.Compose([
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        ])
        vessel_aug_transform = transforms.Compose([
            transforms.RandomRotation(15),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        ])

        synthetic_dataset = []
        for i, (feature, label) in enumerate(zip(features_balanced, labels_balanced)):
            if i < n_original:
                synthetic_dataset.append(self.dataset[i])
                continue
            closest_idx = self._closest_original(original_features, original_labels, feature, label)
            closest_rgb, closest_vessel, _ = self.dataset[closest_idx]
            if isinstance(closest_rgb, torch.Tensor):
                closest_rgb = to_pil(closest_rgb)
            if isinstance(closest_vessel, torch.Tensor):
                closest_vessel = to_pil(closest_vessel)
            synthetic_dataset.append(
                (aug_transform(closest_rgb), vessel_aug_transform(closest_vessel), int(label))
            )
        return synthetic_dataset

# Apply SMOTE preprocessing with local model -- to the TRAINING fold only.
# Oversampling before splitting leaks validation data. Every synthetic sample is an augmented
# copy of its nearest original image, so splitting the balanced set afterwards puts
# near-duplicates of validation images into training and synthetic images into validation.
# That inflates validation metrics. Instead, the stratified fold computed when the data was
# loaded (`train_idx` / `val_idx`) is reused: SMOTE only sees `train_idx`, and validation keeps
# the untouched original images.
print("\n=== Applying SMOTE preprocessing (Local) ===")
smote_train_source = torch.utils.data.Subset(full_dataset, train_idx)
smote_processor = SMOTEImageDatasetLocal(smote_train_source)  # Use local version
features_balanced, labels_balanced = smote_processor.apply_smote()
balanced_dataset = smote_processor.create_synthetic_images(features_balanced, labels_balanced)

# Check balanced distribution (training fold only)
balanced_labels = [item[2] for item in balanced_dataset]
balanced_counts = Counter(balanced_labels)
print("SMOTE-balanced class distribution:")
for class_id, count in sorted(balanced_counts.items()):
    print(f"Class {class_id}: {count} samples")

# Helper class for SMOTE dataset
class VDMDRDatasetWrapper(Dataset):
    """Expose an in-memory list of ``(rgb, vessel, label)`` tuples as a torch ``Dataset``."""

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]

# Create wrapped dataset
balanced_wrapped_dataset = VDMDRDatasetWrapper(balanced_dataset)

from sklearn.model_selection import train_test_split  # kept in case later cells use it

# NOTE: train_indices index `balanced_wrapped_dataset`; val_indices index `full_dataset`.
train_indices = list(range(len(balanced_dataset)))
val_indices = list(val_idx)

print(f"\nSMOTE dataset size: {len(balanced_dataset)}")
print(f"Train: {len(train_indices)}, Val: {len(val_indices)}")

# Training uses the SMOTE-balanced fold; validation uses the original held-out images.
train_dataset = SubsetWithTransform(balanced_wrapped_dataset, train_indices, rgb_train_transform, vessel_train_transform)
val_dataset = SubsetWithTransform(full_dataset, val_indices, rgb_val_transform, vessel_val_transform)

# Update data loaders
BATCH_SIZE = 20  # Reduced for 384x384 training
# drop_last on the training loader: a trailing batch of one sample makes BatchNorm1d raise in train mode.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

print(f"Updated Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
print("✅ SMOTE integration completed with local model loading!")


In [None]:
# 4. Model Components
import torch.nn.functional as F
import math

# Cross Attention Block
class CrossAttentionBlock(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys and values from ``context``.

    The parameter layout is ``to_qkv`` (dim -> 3*dim, no bias) plus ``to_out`` (dim -> dim).
    Rows ``[0:dim)``, ``[dim:2*dim)`` and ``[2*dim:3*dim)`` of ``to_qkv.weight`` are the
    query, key and value projections. Queries are projected from ``x``; keys and values are
    projected from ``context``.

    Args:
        dim: feature size of both ``x`` and ``context``; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        # Scaled dot-product attention scales logits by 1/sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context):
        """Attend from ``x`` (batch, n, dim) over ``context`` (batch, m, dim); returns (batch, n, dim)."""
        b, n, _ = x.shape
        m = context.shape[1]
        h = self.heads

        w_q, w_k, w_v = self.to_qkv.weight.chunk(3, dim=0)
        q = F.linear(x, w_q).view(b, n, h, -1).transpose(1, 2)
        k = F.linear(context, w_k).view(b, m, h, -1).transpose(1, 2)
        v = F.linear(context, w_v).view(b, m, h, -1).transpose(1, 2)

        attn = ((q @ k.transpose(-1, -2)) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# Graph Attention Pooling
class GraphAttentionPooling(nn.Module):
    """Multi-head self-attention over a token sequence, residual + LayerNorm, mean-pool, project.

    Input ``(batch, seq_len, input_dim)`` -> output ``(batch, output_dim)``. Despite the name,
    no graph structure is used: every token attends to every other token.
    ``input_dim`` must be divisible by ``heads`` (an ``nn.MultiheadAttention`` requirement).
    """

    def __init__(self, input_dim, output_dim, heads=4):
        super().__init__()
        self.heads = heads
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.attention = nn.MultiheadAttention(input_dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(input_dim)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        attn_out = self.attention(x, x, x)[0]
        tokens = self.norm(x + attn_out)
        # Average over the sequence dimension, then project.
        return self.fc(tokens.mean(dim=1))

# Contrastive Head - FIXED VERSION
class ContrastiveHead(nn.Module):
    """Linear -> ReLU -> Linear projection followed by L2 normalisation along dim 1.

    Args:
        input_dim: size of the incoming feature vector (also the hidden width).
        proj_dim: size of the normalised output embedding.
    """

    def __init__(self, input_dim, proj_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, proj_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        # Unit-length rows, so dot products between outputs are cosine similarities.
        return F.normalize(self.projection(x), p=2, dim=1)

# Focal Loss Implementation
class FocalLoss(nn.Module):
    """Focal loss on logits: ``alpha * (1 - p_t) ** gamma * CE``, with ``p_t = exp(-CE)``.

    Args:
        alpha: scalar weight applied to every sample.
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha * CE``.
        reduction: ``'mean'``, ``'sum'``; any other value returns per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        p_true = torch.exp(-per_sample_ce)  # probability assigned to the target class
        weighted = self.alpha * (1 - p_true) ** self.gamma * per_sample_ce
        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted

# Improved Contrastive Loss
class ImprovedContrastiveLoss(nn.Module):
    """Supervised contrastive loss over a batch of embeddings.

    For each anchor ``i`` with at least one other same-label sample in the batch:
    ``-log( sum_{p: y_p = y_i, p != i} exp(s_ip / T) / sum_{j != i} exp(s_ij / T) )``,
    where ``s`` is cosine similarity. Anchors without any positive are excluded from the mean;
    they would otherwise contribute ``-log(0) = inf``. The ratio is computed in log space with
    ``logsumexp`` for numerical stability. If no anchor has a positive, the loss is 0.

    Args:
        temperature: softmax temperature ``T``.
        margin: stored for backward compatibility; not used by the loss.
    """

    def __init__(self, temperature=0.1, margin=0.5):
        super(ImprovedContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.margin = margin

    def forward(self, features, labels):
        # Cosine similarities scaled by the temperature.
        features = F.normalize(features, dim=1)
        logits = torch.matmul(features, features.T) / self.temperature

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1).to(features.device)
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        pos_mask = torch.eq(labels, labels.T) & ~self_mask

        has_pos = pos_mask.any(dim=1)
        if not has_pos.any():
            # Zero that stays attached to the graph (and keeps dtype/device).
            return features.sum() * 0.0

        # Keep only valid anchors *before* logsumexp: rows that are entirely -inf would yield
        # NaN gradients even if masked out afterwards.
        logits = logits[has_pos]
        pos_mask = pos_mask[has_pos]
        self_mask = self_mask[has_pos]

        neg_inf = float('-inf')
        log_denominator = torch.logsumexp(logits.masked_fill(self_mask, neg_inf), dim=1)
        log_numerator = torch.logsumexp(logits.masked_fill(~pos_mask, neg_inf), dim=1)
        return (log_denominator - log_numerator).mean()

print("✅ Model components and loss functions defined successfully (L2Norm fixed)!")


In [None]:
# 5. Fixed 384x384 VDMDR Model with Swin Transformer
import timm
import torch.nn.functional as F

class VDMDRModelEnhanced(nn.Module):
    """Dual-stream Swin-B (window12, 384px) classifier for paired RGB images and vessel maps.

    Each stream yields a pooled feature vector (timm ``num_classes=0``) that is projected to
    ``embed_dim // 2``. The two projections cross-attend in both directions (as length-1 token
    sequences). They are then concatenated and fused, passed through ``GraphAttentionPooling``,
    and fed to a classifier and a contrastive projection head.

    Args:
        num_classes: number of output logits.
        pretrained: load pretrained weights for both backbones; the vessel backbone is built
            with ``in_chans=1``.

    ``forward(rgb, vessel)`` returns ``(logits, contrastive_vec)``.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        backbone_name = 'swin_base_patch4_window12_384'

        # Backbones (created first, in the same order as before, so seeded init is unchanged).
        self.rgb_backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        self.vessel_backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0, in_chans=1)

        embed_dim = self.rgb_backbone.num_features
        half = embed_dim // 2
        quarter = embed_dim // 4

        # Per-stream projections to half width, for memory efficiency.
        self.rgb_proj = nn.Linear(embed_dim, half)
        self.vessel_proj = nn.Linear(embed_dim, half)

        # rgb -> vessel and vessel -> rgb attention.
        self.cross_attn_rv = CrossAttentionBlock(half, heads=8)
        self.cross_attn_vr = CrossAttentionBlock(half, heads=8)

        # Concatenated (2 * half == embed_dim) -> half.
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim, half),
            nn.BatchNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

        self.graph_pool = GraphAttentionPooling(half, half, heads=4)
        self.contrastive_head = ContrastiveHead(half, proj_dim=128)

        self.classifier = nn.Sequential(
            nn.Linear(half, quarter),
            nn.BatchNorm1d(quarter),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(quarter, num_classes),
        )

    def forward(self, rgb, vessel):
        # Every input is resized to the backbone's native 384x384; smaller inputs (such as
        # 224px crops) are upsampled.
        native = (384, 384)
        rgb = F.interpolate(rgb, size=native, mode='bilinear', align_corners=False)
        vessel = F.interpolate(vessel, size=native, mode='bilinear', align_corners=False)

        rgb_pooled = self.rgb_backbone(rgb)
        vessel_pooled = self.vessel_backbone(vessel)

        rgb_tokens = self.rgb_proj(rgb_pooled).unsqueeze(1)        # (batch, 1, half)
        vessel_tokens = self.vessel_proj(vessel_pooled).unsqueeze(1)

        rgb_attended = self.cross_attn_rv(rgb_tokens, vessel_tokens).squeeze(1)
        vessel_attended = self.cross_attn_vr(vessel_tokens, rgb_tokens).squeeze(1)

        fused = self.fusion(torch.cat([rgb_attended, vessel_attended], dim=1))
        pooled = self.graph_pool(fused.unsqueeze(1))

        contrastive_vec = self.contrastive_head(pooled)
        logits = self.classifier(pooled)
        return logits, contrastive_vec

print("✅ Fixed 384x384 Swin-based VDMDRModelEnhanced class defined!")


In [None]:
# 6. Model Configuration and Initialization for 384x384
import torch.optim as optim
from torch.optim import lr_scheduler

# Model initialization with Enhanced Model
num_classes = len(set(full_dataset.labels))
model = VDMDRModelEnhanced(num_classes=num_classes).to(device)
print(f"Enhanced model parameters: {sum(param.numel() for param in model.parameters()):,}")
print(f"Number of classes: {num_classes}")

# AdamW (decoupled weight decay).
# NOTE(review): the warmup cell builds WarmupScheduler(base_lr=1e-4), which overwrites this
# lr=1e-6 during warmup; confirm which learning rate is intended.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-6,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# Cosine decay towards eta_min over T_max scheduler steps (train_epoch steps it once per epoch).
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=45, eta_min=1e-7)

# Loss functions.
# NOTE(review): train_epoch as written builds its own nn.CrossEntropyLoss; these two objects
# are not used by it.
focal_loss = FocalLoss(gamma=2.0)
contrastive_loss = ImprovedContrastiveLoss()

print("✅ Enhanced model and optimizers initialized for 384x384 training!")


In [None]:
# 6.2. MixUp Augmentation Utility
import torch
import numpy as np

def mixup_data(x1, x2, y, alpha=0.4):
    '''MixUp for two aligned modalities sharing one permutation and one mixing weight.

    ``lam ~ Beta(alpha, alpha)`` (``lam = 1`` when ``alpha <= 0``). Each input is mixed as
    ``lam * x + (1 - lam) * x[perm]``.

    Returns:
        (mixed_x1, mixed_x2, y_a, y_b, lam), where ``y_a = y`` and ``y_b = y[perm]``.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x1.size(0)).to(x1.device)
    mixed_x1 = lam * x1 + (1 - lam) * x1[perm, :]
    mixed_x2 = lam * x2 + (1 - lam) * x2[perm, :]
    return mixed_x1, mixed_x2, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Interpolate ``criterion`` between the two MixUp targets with weight ``lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
# 6.3. Label Smoothing Loss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The target class gets probability ``1 - smoothing``; every other class gets
    ``smoothing / (classes - 1)``. Returns the batch mean.

    Args:
        classes: number of classes (must be > 1; 1 raises ZeroDivisionError in forward).
        smoothing: probability mass spread over the non-target classes.
    """

    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, target):
        log_probs = self.log_softmax(x)
        with torch.no_grad():
            off_value = self.smoothing / (self.cls - 1)
            smoothed = torch.full_like(log_probs, off_value)
            smoothed.scatter_(1, target.unsqueeze(1), self.confidence)
        per_sample = (-smoothed * log_probs).sum(dim=1)
        return per_sample.mean()

In [None]:
# 6.5. Learning Rate Warmup and Advanced Scheduling
class WarmupScheduler:
    """Linear learning-rate warmup driven by explicit ``step()`` calls.

    On step ``t`` (1-based) with ``t <= warmup_steps``, every param group's lr is set to
    ``target * t / warmup_steps``. Afterwards the lr is left untouched, so another scheduler
    can take over.

    Args:
        optimizer: optimizer whose param groups are updated.
        warmup_steps: number of warmup steps.
        base_lr: target lr reached at the end of warmup, applied to every group. ``None``
            (new, optional) warms each group up to the lr it had when first seen, which
            preserves per-group learning rates.
    """

    def __init__(self, optimizer, warmup_steps, base_lr=None):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.current_step = 0
        # Per-group targets, only used when base_lr is None.
        self._initial_lrs = [group['lr'] for group in optimizer.param_groups]

    def _targets(self):
        """Target lr for each current param group (records groups added after construction)."""
        groups = self.optimizer.param_groups
        if self.base_lr is not None:
            return [self.base_lr] * len(groups)
        while len(self._initial_lrs) < len(groups):
            self._initial_lrs.append(groups[len(self._initial_lrs)]['lr'])
        return self._initial_lrs

    def step(self):
        """Advance one step; returns True while still inside the warmup window."""
        self.current_step += 1
        in_warmup = self.current_step <= self.warmup_steps
        if in_warmup:
            fraction = self.current_step / self.warmup_steps
            for group, target in zip(self.optimizer.param_groups, self._targets()):
                group['lr'] = target * fraction
        return in_warmup

# Initialize warmup: ramp the lr linearly to base_lr over 5 * len(train_loader) steps
# (five epochs' worth, assuming step() is called once per batch).
# NOTE(review): base_lr=1e-4 differs from the lr=1e-6 the optimizer was created with, and
# warmup overwrites the optimizer's lr; confirm which value is intended.
warmup_scheduler = WarmupScheduler(
    optimizer,
    warmup_steps=5 * len(train_loader),
    base_lr=1e-4,
)

In [None]:
# Progressive Resizing Training Strategy for RTX 5090
class ProgressiveResizeTrainer:
    """Progressive-resizing schedule plus factories for size-specific transforms and loaders.

    ``resize_schedule`` lists stages of increasing image size with shrinking batch sizes.
    NOTE(review): VDMDRModelEnhanced.forward interpolates every input to 384x384, so the size
    chosen here changes the data resolution but not the backbone's input size.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

        # Memory-optimized progressive sizes for RTX 5090: (size, epochs, batch_size) per stage.
        stages = ((224, 15, 32), (288, 20, 24), (384, 15, 16))
        self.resize_schedule = [
            dict(size=size, epochs=epochs, batch_size=batch_size)
            for size, epochs, batch_size in stages
        ]

    def update_transforms(self, size):
        """Return ``(rgb_train, rgb_val, vessel_train, vessel_val)`` transforms for ``size``x``size``."""
        square = (size, size)
        rgb_train_transform = transforms.Compose([
            transforms.Resize(size=square),
            transforms.RandomHorizontalFlip(p=0.6),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(degrees=45),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
        ])
        rgb_val_transform = transforms.Compose([
            transforms.Resize(size=square),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        vessel_train_transform = transforms.Compose([
            transforms.Resize(size=square),
            transforms.RandomHorizontalFlip(p=0.6),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(degrees=45),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
            transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
        ])
        vessel_val_transform = transforms.Compose([
            transforms.Resize(size=square),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        return rgb_train_transform, rgb_val_transform, vessel_train_transform, vessel_val_transform

    def create_data_loaders(self, train_dataset, val_dataset, config):
        """Build (shuffled train, ordered val) loaders using ``config['batch_size']``."""
        shared = dict(batch_size=config['batch_size'], num_workers=8, pin_memory=True)
        train_loader = DataLoader(train_dataset, shuffle=True, **shared)
        val_loader = DataLoader(val_dataset, shuffle=False, **shared)
        return train_loader, val_loader

print("Progressive Resizing trainer ready!")


In [None]:
# CutMix Implementation for Medical Imaging
import numpy as np
import torch
import torch.nn.functional as F

class CutMixAugmentation:
    """CutMix for paired RGB / vessel batches: the same box from the same partner is pasted into both.

    Args:
        alpha: Beta(alpha, alpha) concentration for the mixing ratio. ``cutmix_data`` uses it
            whenever it is called without an explicit ``alpha``.
        prob: probability of applying CutMix to a batch. ``cutmix_data`` always mixes; call
            ``should_apply()`` to make the per-batch decision.
    """

    def __init__(self, alpha=1.0, prob=0.5):
        self.alpha = alpha
        self.prob = prob

    def should_apply(self):
        """Return True with probability ``self.prob`` (draws from NumPy's global RNG)."""
        return bool(np.random.rand() < self.prob)

    def cutmix_data(self, rgb, vessel, target, alpha=None):
        """Apply CutMix to both RGB and vessel images.

        Args:
            rgb: batch indexed as (batch, channels, H, W). Modified IN PLACE and returned.
            vessel: batch assumed to share rgb's H and W — TODO confirm. Modified in place.
            target: labels, indexed along the batch dimension.
            alpha: Beta concentration; ``None`` uses ``self.alpha``. ``alpha <= 0`` gives
                ``lam = 1`` (an empty box, i.e. no mixing).

        Returns:
            ``(rgb, vessel, target_a, target_b, lam)``. ``lam`` is the fraction of each image
            still from its own sample, recomputed from the clipped box area.
        """
        if alpha is None:
            alpha = self.alpha
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

        batch_size = rgb.size(0)
        index = torch.randperm(batch_size).to(rgb.device)
        target_a = target
        target_b = target[index]

        # Box sides scale with sqrt(1 - lam), so its area fraction is about 1 - lam.
        H, W = rgb.size(2), rgb.size(3)
        cut_rat = np.sqrt(1. - lam)
        cut_h = int(H * cut_rat)
        cut_w = int(W * cut_rat)

        # Uniformly sampled box centre (row, then column), clipped to the image.
        cy = np.random.randint(H)
        cx = np.random.randint(W)
        y1 = np.clip(cy - cut_h // 2, 0, H)
        y2 = np.clip(cy + cut_h // 2, 0, H)
        x1 = np.clip(cx - cut_w // 2, 0, W)
        x2 = np.clip(cx + cut_w // 2, 0, W)

        # Advanced indexing on the right-hand side copies first, so in-place pasting is safe.
        rgb[:, :, y1:y2, x1:x2] = rgb[index, :, y1:y2, x1:x2]
        vessel[:, :, y1:y2, x1:x2] = vessel[index, :, y1:y2, x1:x2]

        # Adjust lambda to match pixel ratio
        lam = 1 - ((y2 - y1) * (x2 - x1) / (H * W))
        return rgb, vessel, target_a, target_b, lam

def cutmix_criterion(criterion, pred, y_a, y_b, lam):
    """CutMix loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

print("CutMix augmentation ready!")


In [None]:
# 6.8. Memory Management for RTX 5090
import gc

def clear_gpu_memory():
    """Release unreferenced cached GPU memory and report free device memory. No-op without CUDA."""
    if torch.cuda.is_available():
        # Collect first: tensors kept alive only by reference cycles must be freed before
        # empty_cache() can return their cached blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()
        # mem_get_info reports what the driver actually has free (accounting for blocks the
        # caching allocator still reserves and for other processes), unlike
        # total_memory - memory_allocated().
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        print(f"GPU Memory cleared. Available: {free_bytes / 1024**3:.2f}GB of {total_bytes / 1024**3:.2f}GB")

def get_gpu_memory_info():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        reserved = torch.cuda.memory_reserved() / 1024**3    # GB
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Total: {total:.2f}GB")

# Clear memory before starting
clear_gpu_memory()
get_gpu_memory_info()

# Update CUDA memory settings
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512,expandable_segments:True'
# The caching allocator reads PYTORCH_CUDA_ALLOC_CONF when CUDA is first initialised, so this
# assignment does not reconfigure an allocator that is already running in this kernel.
if torch.cuda.is_available() and torch.cuda.is_initialized():
    print("Note: CUDA is already initialized; the new PYTORCH_CUDA_ALLOC_CONF takes effect only after a kernel restart (set it before any CUDA call).")


In [None]:
# 7. Training Loop: Mixed Precision, Gradient Accumulation, MixUp
def train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch, use_mixup=True):
    """Train ``model`` for one epoch with AMP, gradient accumulation and optional MixUp.

    Args:
        model: network whose forward takes ``(rgb, vessel)`` and returns
            ``(logits, contrastive_vec)``.
        train_loader: iterable of ``(rgb, vessel, labels)`` batches.
        optimizer: stepped once every ``ACCUMULATION_STEPS`` batches, plus once more at
            the end of the epoch if a partial accumulation window is left over.
        scheduler: LR scheduler, stepped once per epoch.
        scaler: AMP gradient scaler (``scale`` / ``step`` / ``update``).
        epoch: epoch index, used only for the progress-bar label.
        use_mixup: if True, batches go through ``mixup_data`` and the loss is
            ``mixup_criterion``. Otherwise it is cross-entropy against ``labels``.

    Returns:
        ``(mean_loss_per_batch, accuracy_percent)``. Accuracy is measured against
        ``y_a``. Both values are 0.0 for an empty loader.
    """
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    model.train()
    # Start from clean gradients. Stale .grad values from a previous epoch or cell
    # would otherwise be folded into this epoch's first optimizer step.
    optimizer.zero_grad()

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch_idx, (rgb, vessel, labels) in enumerate(pbar):
        num_batches = batch_idx + 1
        rgb, vessel, labels = rgb.to(device), vessel.to(device), labels.to(device)

        # MixUp augmentation
        if use_mixup:
            rgb, vessel, y_a, y_b, lam = mixup_data(rgb, vessel, labels)
        else:
            y_a, y_b, lam = labels, labels, 1.0

        with autocast():
            logits, contrastive_vec = model(rgb, vessel)
            # Classification loss only; the contrastive head is not trained by this loop.
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            # Scale down so the accumulated gradient averages over the window.
            loss = loss / ACCUMULATION_STEPS

        scaler.scale(loss).backward()

        if num_batches % ACCUMULATION_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * ACCUMULATION_STEPS
        _, predicted = logits.max(1)
        total += labels.size(0)
        # For accuracy, use y_a (the un-mixed target order)
        correct += predicted.eq(y_a).sum().item()

        pbar.set_postfix({
            'Loss': f'{total_loss/num_batches:.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    # Flush a trailing partial accumulation window. Without this, its gradients were
    # never applied and leaked into the next epoch's first step.
    if num_batches % ACCUMULATION_STEPS != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    scheduler.step()
    if num_batches == 0:
        return 0.0, 0.0
    return total_loss / num_batches, 100. * correct / total

def validate(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` in eval mode, without gradients, under autocast.

    Returns:
        ``(mean_loss, accuracy_percent, predictions, labels)``.
        - ``mean_loss`` is ``combined_loss`` averaged over batches.
        - The two lists hold per-sample predicted and true class indices as NumPy scalars.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    preds_out = []
    labels_out = []

    with torch.no_grad():
        for rgb, vessel, labels in tqdm(val_loader, desc='Validation'):
            rgb = rgb.to(device)
            vessel = vessel.to(device)
            labels = labels.to(device)

            with autocast():
                logits, contrastive_vec = model(rgb, vessel)
                loss = combined_loss(logits, contrastive_vec, labels)

            running_loss += loss.item()
            predicted = logits.max(1).indices
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            preds_out.extend(predicted.cpu().numpy())
            labels_out.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(val_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy, preds_out, labels_out

In [None]:
def train_stable_384(net=None, loader=None, num_epochs=45, lr=1e-6,
                     weight_decay=0.01, max_grad_norm=0.5, logit_clamp=10.0):
    """Ultra-conservative, full-precision training loop for the 384x384 pipeline.

    Safeguards:
    - logits are clamped to ``[-logit_clamp, logit_clamp]`` before the cross-entropy;
    - gradients are clipped to ``max_grad_norm``;
    - the learning rate is tiny;
    - any batch whose loss is NaN/inf is skipped entirely (no backward, no optimizer
      step, not counted in the epoch metrics).

    Args:
        net: model trained in place. Defaults to the notebook-global ``model``. Its
            forward takes ``(rgb, vessel)`` and returns ``(logits, contrastive_vec)``.
        loader: iterable of ``(rgb, vessel, labels)`` batches. Defaults to the global
            ``train_loader_384``.
        num_epochs, lr, weight_decay, max_grad_norm, logit_clamp: training knobs. The
            defaults are the previously hard-coded values.

    Returns:
        Best per-epoch *training* accuracy in percent. No validation is run here.

    NOTE(review): the trained weights are not saved by this function.
    """
    net = model if net is None else net
    loader = train_loader_384 if loader is None else loader
    # Send batches to wherever the model lives, so inputs and weights always match.
    dev = next(net.parameters()).device

    def stable_loss_fn(logits, labels):
        # Clamp logits for stability, then plain cross-entropy (the most stable option).
        # A non-finite result is returned unchanged so the caller can skip the batch.
        # The old "dummy loss" substitute made that skip unreachable and polluted metrics.
        clamped = torch.clamp(logits, min=-logit_clamp, max=logit_clamp)
        return torch.nn.functional.cross_entropy(clamped, labels)

    # Adam with coupled (L2) weight decay, as before.
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    best_acc = 0.0

    for epoch in range(num_epochs):
        net.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        steps = 0
        skipped = 0

        for batch_idx, (rgb, vessel, labels) in enumerate(loader):
            rgb, vessel, labels = rgb.to(dev), vessel.to(dev), labels.to(dev)
            optimizer.zero_grad()

            # Forward pass (no mixed precision)
            logits, _ = net(rgb, vessel)
            loss = stable_loss_fn(logits, labels)

            if not torch.isfinite(loss):
                print(f"⚠️ Skipping batch {batch_idx} due to unstable loss")
                skipped += 1
                continue

            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            steps += 1
            train_loss += loss.item()
            _, predicted = logits.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        train_acc = 100. * train_correct / train_total if train_total > 0 else 0.0
        avg_loss = train_loss / steps if steps else float('nan')
        print(f'Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, '
              f'Train Loss: {avg_loss:.4f}, Skipped batches: {skipped}')
        best_acc = max(best_acc, train_acc)

    return best_acc

# Run stable training
# NOTE(review): train_stable_384 returns the best *training* accuracy (it runs no
# validation) and does not save the weights it trains. The next cell rebuilds `model`
# from 'best_vdmdr_model.pth', so this run's weights are discarded unless they are
# saved before that cell runs.
final_accuracy = train_stable_384()


In [None]:
# Reload the best model weights (before ensemble training)
model = VDMDRModel(num_classes=num_classes).to(device)
# The checkpoint is a plain state_dict, so weights_only=True is enough. It restricts
# unpickling to tensors and plain containers instead of allowing arbitrary objects
# from the file.
model.load_state_dict(torch.load('best_vdmdr_model.pth', map_location=device, weights_only=True))
model.eval()

In [None]:
# 8.5. Model Ensemble for Higher Accuracy (memory-safe, reduced batch size)

import torch.nn as nn
import gc

# Ensemble members are trained with smaller batches to keep peak GPU memory low.
# The training split is shuffled; the validation split keeps its order.
ENSEMBLE_BATCH_SIZE = 8
ensemble_train_loader, ensemble_val_loader = (
    DataLoader(split, batch_size=ENSEMBLE_BATCH_SIZE, shuffle=shuffle_split,
               num_workers=4, pin_memory=True)
    for split, shuffle_split in ((train_dataset, True), (val_dataset, False))
)

class EnsembleModel(nn.Module):
    """Average the logits of several ``(rgb, vessel) -> (logits, aux)`` models.

    Members are stored in an ``nn.ModuleList`` so they are registered as submodules.
    As a result, ``.to(device)``, ``.eval()``, ``.parameters()`` and ``state_dict()``
    on the ensemble reach every member; a plain Python list is invisible to
    ``nn.Module``, so e.g. ``ensemble.to(device)`` would not move the members.
    """

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, rgb, vessel):
        """Return the element-wise mean of the members' logits; auxiliary outputs are dropped.

        Each member is switched to eval mode before its forward pass (inference-only use).
        All members must be on the same device as ``rgb`` / ``vessel``.
        """
        logits_list = []
        for member in self.models:
            member.eval()
            logits, _ = member(rgb, vessel)
            logits_list.append(logits)
        return torch.stack(logits_list).mean(0)

# Member 0 is the reloaded best model. Park it on the CPU so the GPU is free for
# training the other members. Module.to moves in place, so the global `model` is on
# the CPU after this line.
ensemble_models = [model.to('cpu')]

ENSEMBLE_SEEDS = [123, 456, 789]
ENSEMBLE_EPOCHS = 20

for seed in ENSEMBLE_SEEDS:
    print(f"Training ensemble model with seed {seed}...")
    torch.manual_seed(seed)
    # Seed NumPy too: the CutMix helper above draws from np.random, and mixup_data
    # presumably does as well (TODO confirm). This makes each member reproducible.
    np.random.seed(seed)
    ensemble_model = VDMDRModel(num_classes=num_classes).to(device)
    ensemble_optimizer = optim.AdamW(ensemble_model.parameters(), lr=2e-4, weight_decay=0.01)
    # T_max equals the number of epochs so the cosine schedule actually anneals to
    # eta_min. With T_max=30 over 20 epochs, the last epoch still ran at ~30% of peak LR.
    ensemble_scheduler = lr_scheduler.CosineAnnealingLR(ensemble_optimizer, T_max=ENSEMBLE_EPOCHS, eta_min=1e-6)
    for epoch in range(ENSEMBLE_EPOCHS):
        train_loss, train_acc = train_epoch(ensemble_model, ensemble_train_loader, ensemble_optimizer, ensemble_scheduler, scaler, epoch)
        val_loss, val_acc, _, _ = validate(ensemble_model, ensemble_val_loader)
        if epoch % 5 == 0 or epoch == ENSEMBLE_EPOCHS - 1:  # also report the final epoch
            print(f'Seed {seed} - Epoch {epoch}: Val Acc: {val_acc:.2f}%')
    # Move the trained member to the CPU, then drop every other reference to its GPU
    # state. The optimizer's moment buffers stay on the GPU even after the parameters
    # move, so the optimizer (and the scheduler that references it) must go as well.
    ensemble_model_cpu = ensemble_model.to('cpu')
    del ensemble_model, ensemble_optimizer, ensemble_scheduler
    gc.collect()              # free cyclic garbage first ...
    torch.cuda.empty_cache()  # ... then return the released blocks to the driver
    ensemble_models.append(ensemble_model_cpu)

ensemble = EnsembleModel(ensemble_models)
print(f"Created ensemble with {len(ensemble_models)} models")

In [None]:
# The ensemble cell moved `model` to the CPU (`model.to('cpu')`); bring it back to
# `device` for the feature-extraction and evaluation cells below.
model = model.to(device)

In [None]:
# 9. Feature Extraction & XGBoost Ensemble

def extract_features(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and collect one feature row per sample.

    Each row is ``[logits | contrastive_vec]``, concatenated along dim 1. No autocast
    is used, and outputs are cast to float32 before conversion to NumPy.

    Args:
        model: network with ``forward(rgb, vessel) -> (logits, contrastive_vec)``. Both
            outputs must be 2-D ``(batch, dim)`` tensors for the ``dim=1`` concat.
        dataloader: iterable of ``(rgb, vessel, labels)`` batches.
        device: where inputs are sent. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model), so inputs always match the model.

    Returns:
        ``(features, labels)``: an ``(N, D)`` float32 array and an ``(N,)`` label array.
        For an empty loader, a ``(0, 0)`` float32 array and a ``(0,)`` int64 array.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for rgb, vessel, batch_labels in tqdm(dataloader, desc='Extracting features'):
            rgb, vessel = rgb.to(device), vessel.to(device)
            # No autocast for feature extraction
            logits, contrastive_vec = model(rgb, vessel)
            combined_features = torch.cat([logits, contrastive_vec], dim=1).float()
            feature_batches.append(combined_features.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    if not feature_batches:
        # np.vstack([]) would raise; return well-typed empty arrays instead.
        return np.empty((0, 0), dtype=np.float32), np.empty(0, dtype=np.int64)
    return np.vstack(feature_batches), np.concatenate(label_batches)

# Extract features
# Each row = the model's logits concatenated with its contrastive vector (see extract_features).
print("Extracting training features...")
train_features, train_labels = extract_features(model, train_loader)
print("Extracting validation features...")
val_features, val_labels = extract_features(model, val_loader)

# XGBoost ensemble
# Gradient-boosted trees stacked on the CNN outputs. Hyper-parameters are fixed here
# and tuned with Optuna in the next cell.
# NOTE(review): train_loader is presumably the data the CNN was trained on, so its
# logits there are likely more confident than on unseen data. Confirm this stacking
# setup does not overfit.
print("Training XGBoost ensemble...")
xgb_model = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    random_state=42,
    n_jobs=-1
)

xgb_model.fit(train_features, train_labels)
xgb_preds = xgb_model.predict(val_features)
# Validation accuracy in percent.
xgb_acc = 100. * (xgb_preds == val_labels).sum() / len(val_labels)
print(f"XGBoost ensemble accuracy: {xgb_acc:.2f}%")

# Save XGBoost model
# Pickle files execute code on load; only load this back from a trusted location.
import pickle
with open('xgb_ensemble.pkl', 'wb') as f:
    pickle.dump(xgb_model, f)

In [None]:
# 9.5. Optuna Hyperparameter Tuning for XGBoost Ensemble

import optuna
from sklearn.metrics import accuracy_score

def objective(trial):
    """Optuna objective: fit an XGBoost classifier on the CNN features and score it.

    Returns validation accuracy, a fraction in [0, 1], to be maximised.

    NOTE(review): hyper-parameters are selected on the same validation split that is
    later reported as the final accuracy, so that number is optimistically biased. A
    separate hold-out, or CV on ``train_features``, would give an unbiased estimate.
    """
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 50, 300),
        'max_depth': trial.suggest_int('max_depth', 3, 12),
        # suggest_float(..., log=True) / suggest_float(...) sample the same distributions
        # as the deprecated suggest_loguniform / suggest_uniform.
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.3, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'random_state': 42,
        'n_jobs': -1,
        # 'use_label_encoder' was dropped: it no longer exists in XGBoost 2.x (only
        # triggers an "unused parameter" warning) and labels are already integer ids.
        'eval_metric': 'mlogloss'
    }
    clf = xgb.XGBClassifier(**params)
    clf.fit(train_features, train_labels)
    preds = clf.predict(val_features)
    return accuracy_score(val_labels, preds)

# Seeded TPE sampler (Optuna's default algorithm) so the search is reproducible.
study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=30)

print("Best XGBoost params:", study.best_params)
print("Best XGBoost validation accuracy: {:.2f}%".format(100 * study.best_value))

# Train and save the best XGBoost model.
# study.best_params holds only the *searched* values. Re-add the fixed settings the
# objective used (random_state, n_jobs, eval_metric) so the refit reproduces the best
# trial instead of silently falling back to XGBoost's defaults.
best_xgb_params = {**study.best_params, 'random_state': 42, 'n_jobs': -1, 'eval_metric': 'mlogloss'}
best_xgb = xgb.XGBClassifier(**best_xgb_params)
best_xgb.fit(train_features, train_labels)
optuna_preds = best_xgb.predict(val_features)
optuna_acc = accuracy_score(val_labels, optuna_preds)
print(f"Optuna-tuned XGBoost accuracy: {optuna_acc*100:.2f}%")

with open('xgb_ensemble_optuna.pkl', 'wb') as f:
    pickle.dump(best_xgb, f)

In [None]:
def tta_predict(model, rgb, vessel, n_augments=5):
    """Test-time augmentation: average softmax probabilities over several views.

    Views, in order:
    1. the original input;
    2. horizontal flip;
    3. a random rotation in [-10, 10] degrees;
    4. a fixed -10 degree rotation;
    5. a random brightness/contrast jitter (factors in [0.9, 1.1]).

    Every view applies the SAME parameters to ``rgb`` and ``vessel``, so the pair stays
    aligned. Previously, separate random draws rotated or jittered the two inputs
    differently.

    Args:
        model: network with ``forward(rgb, vessel) -> (logits, aux)``.
        rgb, vessel: input tensors, presumably ``(B, C, H, W)`` batches already on the
            model's device (TODO confirm).
        n_augments: total number of views averaged, including the original. The
            default 5 means the original plus all four augmentations; values < 1
            behave like 1. This parameter used to be ignored.

    Returns:
        Averaged class probabilities, same shape as ``softmax(logits, dim=1)``.

    NOTE(review): torchvision's adjust_brightness/adjust_contrast clamp float tensors to
    [0, 1]. If the inputs are mean/std-normalized, the jitter view distorts them; confirm.
    """
    TF = transforms.functional

    def _hflip(r, v):
        return TF.hflip(r), TF.hflip(v)

    def _random_rotation(r, v):
        # Draw ONE angle (as RandomRotation(10) would) and use it for both inputs.
        angle = transforms.RandomRotation.get_params([-10.0, 10.0])
        return TF.rotate(r, angle), TF.rotate(v, angle)

    def _rotate_neg10(r, v):
        return TF.rotate(r, -10), TF.rotate(v, -10)

    def _color_jitter(r, v):
        # Equivalent of ColorJitter(brightness=0.1, contrast=0.1). Factors and their
        # application order are sampled once, then applied identically to both inputs.
        order, brightness, contrast, _, _ = transforms.ColorJitter.get_params(
            [0.9, 1.1], [0.9, 1.1], None, None)
        for fn_id in order.tolist():
            if fn_id == 0:
                r, v = TF.adjust_brightness(r, brightness), TF.adjust_brightness(v, brightness)
            elif fn_id == 1:
                r, v = TF.adjust_contrast(r, contrast), TF.adjust_contrast(v, contrast)
        return r, v

    paired_augments = [_hflip, _random_rotation, _rotate_neg10, _color_jitter]

    model.eval()
    with torch.no_grad():
        logits, _ = model(rgb, vessel)
        predictions = [torch.softmax(logits, dim=1)]
        for augment in paired_augments[:max(n_augments - 1, 0)]:
            rgb_aug, vessel_aug = augment(rgb, vessel)
            logits, _ = model(rgb_aug, vessel_aug)
            predictions.append(torch.softmax(logits, dim=1))

    # Average predictions
    return torch.stack(predictions).mean(0)

In [None]:
# Install seaborn (used for the confusion-matrix heatmap in the evaluation cell).
# NOTE(review): consider moving this into the install cell at the top so all
# dependencies are installed in one place.
import sys
!{sys.executable} -m pip install seaborn

In [None]:
# 10. Evaluation & Reporting

from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Final evaluation
# One pass over the validation set with `model` (not the `ensemble` or the XGBoost
# stack). Collects predicted class indices and true labels for the report below.
model.eval()
final_preds = []
final_labels = []

with torch.no_grad():
    for rgb, vessel, labels in tqdm(val_loader, desc='Final evaluation'):
        rgb, vessel, labels = rgb.to(device), vessel.to(device), labels.to(device)

        with autocast():
            logits, _ = model(rgb, vessel)

        # Predicted class = index of the largest logit.
        _, predicted = logits.max(1)
        final_preds.extend(predicted.cpu().numpy())
        final_labels.extend(labels.cpu().numpy())

# Classification report
# Per-class precision / recall / F1 plus overall averages.
print("Classification Report:")
print(classification_report(final_labels, final_preds))

# Confusion matrix
# Rows = true label, columns = predicted label (sklearn convention).
cm = confusion_matrix(final_labels, final_preds)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Training curves
# The history lists are presumably recorded by the main training loop (defined
# outside this section). Explicit figure/axes API instead of the pyplot state machine.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(val_losses, label='Val Loss')
ax_loss.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend()

ax_acc.plot(train_accs, label='Train Acc')
ax_acc.plot(val_accs, label='Val Acc')
ax_acc.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
ax_acc.legend()

fig.tight_layout()
plt.show()

print("Final Results:")
print(f"Best Validation Accuracy: {best_acc:.2f}%")
print(f"XGBoost Ensemble Accuracy: {xgb_acc:.2f}%")
# optuna_acc is a fraction (accuracy_score), so convert to percent for the summary.
print(f"Optuna-tuned XGBoost Accuracy: {optuna_acc*100:.2f}%")