In [None]:
pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension


In [1]:
# ============================================
# BLOCK 1 ‚Äî Setup, Imports, Device
# ============================================

import os
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T, models

import timm  # if missing: pip install timm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    balanced_accuracy_score,
)

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Using device: cuda
GPU: NVIDIA GeForce MX450


In [2]:
# ============================================
# BLOCK 2 ‚Äî Load meta.csv + Image paths
# ============================================

ROOT_DIR = r"C:\Users\anama\Documents\Group_8"
DATASET_DIR = os.path.join(ROOT_DIR, "Dataset", "DERM7PT")

META_CSV = os.path.join(DATASET_DIR, "meta", "meta.csv")
IMAGES_DIR = os.path.join(DATASET_DIR, "images")

df = pd.read_csv(META_CSV)
df = df.drop(columns=["case_num", "case_id", "notes"], errors="ignore")

df["derm_fullpath"] = df["derm"].apply(lambda x: os.path.join(IMAGES_DIR, x))
df["exists"] = df["derm_fullpath"].apply(os.path.exists)

missing = (~df["exists"]).sum()
print("Missing images:", missing)
df = df[df["exists"]].reset_index(drop=True)


Missing images: 0


In [3]:
# ============================================
# BLOCK 3 ‚Äî Label Encoding
# ============================================

counts = df["diagnosis"].value_counts()

# Remove "melanoma" class if it has only 1 sample
if "melanoma" in counts.index and counts["melanoma"] == 1:
    df = df[df["diagnosis"] != "melanoma"].reset_index(drop=True)

le = LabelEncoder()
df["label"] = le.fit_transform(df["diagnosis"])
class_names = list(le.classes_)

print("Classes:", class_names)


Classes: ['basal cell carcinoma', 'blue nevus', 'clark nevus', 'combined nevus', 'congenital nevus', 'dermal nevus', 'dermatofibroma', 'lentigo', 'melanoma (0.76 to 1.5 mm)', 'melanoma (in situ)', 'melanoma (less than 0.76 mm)', 'melanoma (more than 1.5 mm)', 'melanoma metastasis', 'melanosis', 'miscellaneous', 'recurrent nevus', 'reed or spitz nevus', 'seborrheic keratosis', 'vascular lesion']


In [1]:
#!/usr/bin/env python
# coding: utf-8

# ============================================
# MEDICAL VISION TRANSFORMER - Advanced Dermatology Classifier
# ============================================

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import torchvision
from torchvision import transforms, models
import timm

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, balanced_accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# ============================================
# BLOCK 1 ‚Äî Setup & Diagnostics
# ============================================

# Paths
ROOT_DIR = r"C:\Users\anama\Documents\Group_8"
DATASET_DIR = os.path.join(ROOT_DIR, "Dataset", "DERM7PT")
META_CSV = os.path.join(DATASET_DIR, "meta", "meta.csv")
IMAGES_FOLDER = os.path.join(DATASET_DIR, "images")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ============================================
# BLOCK 2 ‚Äî Data Loading with Robust Validation
# ============================================

def load_and_validate_data():
    """Robust data loading with comprehensive validation"""
    print("=== DATA LOADING & VALIDATION ===")
    
    # Load metadata
    df = pd.read_csv(META_CSV)
    print(f"Loaded meta.csv ‚Äî shape: {df.shape}")
    
    # Clean data
    df = df.drop(columns=["case_num", "case_id", "notes"], errors="ignore")
    df["derm_fullpath"] = df["derm"].apply(lambda x: os.path.join(IMAGES_FOLDER, x))
    
    # Validate paths
    df["derm_exists"] = df["derm_fullpath"].apply(os.path.exists)
    missing_count = (~df["derm_exists"]).sum()
    print(f"Missing images: {missing_count}")
    
    df = df[df["derm_exists"]].reset_index(drop=True)
    
    # Filter classes with >= 2 samples
    class_counts = df["diagnosis"].value_counts()
    valid_classes = class_counts[class_counts >= 2].index
    df = df[df["diagnosis"].isin(valid_classes)].reset_index(drop=True)
    
    print(f"Final dataset size: {len(df)}")
    print(f"Classes: {len(valid_classes)}")
    print("Class distribution:")
    print(class_counts[valid_classes])
    
    return df, valid_classes

df, valid_classes = load_and_validate_data()

# Label encoding
le = LabelEncoder()
df["label"] = le.fit_transform(df["diagnosis"])
NUM_CLASSES = len(le.classes_)
class_names = le.classes_

print(f"\nEncoded {NUM_CLASSES} classes:")
for i, cls in enumerate(class_names):
    print(f"  {i}: {cls}")

# ============================================
# BLOCK 3 ‚Äî Robust Data Split
# ============================================

def create_robust_splits(df):
    """Create splits with special handling for rare classes"""
    # Handle melanoma metastasis separately
    meta_mask = df["diagnosis"] == "melanoma metastasis"
    df_meta = df[meta_mask].copy()
    df_main = df[~meta_mask].copy()
    
    print(f"Main samples: {len(df_main)}, Metastasis: {len(df_meta)}")
    
    # Main stratified split
    train_main, temp_main = train_test_split(
        df_main, test_size=0.30, stratify=df_main["label"], random_state=42
    )
    val_main, test_main = train_test_split(
        temp_main, test_size=0.50, stratify=temp_main["label"], random_state=42
    )
    
    # Distribute metastasis samples
    if len(df_meta) > 0:
        df_meta_shuffled = df_meta.sample(frac=1, random_state=42).reset_index(drop=True)
        train_meta = df_meta_shuffled.iloc[:2] if len(df_meta) >= 2 else df_meta_shuffled.iloc[:1]
        val_meta = df_meta_shuffled.iloc[2:3] if len(df_meta) >= 3 else pd.DataFrame()
        test_meta = df_meta_shuffled.iloc[3:4] if len(df_meta) >= 4 else pd.DataFrame()
    else:
        train_meta, val_meta, test_meta = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    
    # Final splits
    train_df = pd.concat([train_main, train_meta]).reset_index(drop=True) if len(train_meta) > 0 else train_main
    val_df = pd.concat([val_main, val_meta]).reset_index(drop=True) if len(val_meta) > 0 else val_main
    test_df = pd.concat([test_main, test_meta]).reset_index(drop=True) if len(test_meta) > 0 else test_main
    
    print(f"\nFinal splits:")
    print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    
    return train_df, val_df, test_df

train_df, val_df, test_df = create_robust_splits(df)

# ============================================
# BLOCK 4 ‚Äî Medical-Optimized Augmentations
# ============================================

class MedicalAugmentation:
    """Medical-specific augmentations that preserve lesion characteristics"""
    
    @staticmethod
    def get_train_transform():
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # Less aggressive cropping
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(10),  # Conservative rotation
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),  # Subtle color changes
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def get_vit_transform():
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    
    @staticmethod
    def get_test_transform():
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def get_tta_transforms():
        """Test Time Augmentation transforms"""
        return [
            transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
            transforms.Compose([transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor()]),
            transforms.Compose([transforms.Resize((224, 224)), transforms.RandomVerticalFlip(p=1.0), transforms.ToTensor()]),
        ]

# Initialize transforms
train_transform = MedicalAugmentation.get_train_transform()
vit_transform = MedicalAugmentation.get_vit_transform()
test_transform = MedicalAugmentation.get_test_transform()
tta_transforms = MedicalAugmentation.get_tta_transforms()

# ============================================
# BLOCK 5 ‚Äî Robust Dataset with Debugging
# ============================================

class RobustDermDataset(Dataset):
    def __init__(self, df, transform=None, debug=False):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.debug = debug
        
        if debug:
            self._validate_dataset()
    
    def _validate_dataset(self):
        """Comprehensive dataset validation"""
        print("=== DATASET VALIDATION ===")
        print(f"Total samples: {len(self.df)}")
        print(f"Label range: {self.df['label'].min()} to {self.df['label'].max()}")
        print(f"Num classes: {NUM_CLASSES}")
        
        # Check for invalid labels
        invalid_labels = self.df[self.df['label'] >= NUM_CLASSES]
        if len(invalid_labels) > 0:
            print(f"üö® ERROR: {len(invalid_labels)} invalid labels found!")
            print(invalid_labels[['diagnosis', 'label']].head())
            raise ValueError("Invalid labels detected")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            img_path = str(row["derm_fullpath"]).replace("/", "\\")
            label = int(row["label"])
            
            # Validate label range
            if label >= NUM_CLASSES or label < 0:
                raise ValueError(f"Invalid label {label} for sample {idx}")
            
            img = Image.open(img_path).convert("RGB")
            
            if self.transform:
                img = self.transform(img)
            
            return img, label
            
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            print(f"Path: {img_path}, Label: {label}")
            raise e

# Create datasets with validation
print("\nCreating datasets...")
train_dataset = RobustDermDataset(train_df, transform=train_transform, debug=True)
val_dataset = RobustDermDataset(val_df, transform=test_transform, debug=True)
test_dataset = RobustDermDataset(test_df, transform=test_transform, debug=True)

vit_train_dataset = RobustDermDataset(train_df, transform=vit_transform)
vit_val_dataset = RobustDermDataset(val_df, transform=test_transform)

# ============================================
# BLOCK 6 ‚Äî Advanced DataLoaders
# ============================================

def create_data_loaders():
    """Create optimized data loaders"""
    # Weighted sampling for class imbalance
    class_counts = train_df["label"].value_counts().sort_index()
    class_weights = 1.0 / class_counts
    sample_weights = train_df["label"].map(class_weights).values
    sample_weights = torch.tensor(sample_weights, dtype=torch.float32)
    
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    batch_size = 16  # Reduced for stability
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    vit_train_loader = DataLoader(vit_train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)
    vit_val_loader = DataLoader(vit_val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    print(f"DataLoader batch size: {batch_size}")
    print(f"Train batches: {len(train_loader)}")
    
    return train_loader, val_loader, test_loader, vit_train_loader, vit_val_loader

train_loader, val_loader, test_loader, vit_train_loader, vit_val_loader = create_data_loaders()

# ============================================
# BLOCK 7 ‚Äî Hybrid CNN-Transformer Model
# ============================================

class HybridModel(nn.Module):
    """Combines CNN and Transformer features"""
    def __init__(self, num_classes, cnn_dropout=0.2, transformer_dropout=0.1):
        super().__init__()
        
        # CNN backbone (ResNet50)
        self.cnn_backbone = models.resnet50(weights="IMAGENET1K_V2")
        self.cnn_backbone.fc = nn.Identity()
        cnn_features = 2048
        
        # Transformer backbone (ViT)
        self.transformer_backbone = timm.create_model(
            "vit_base_patch16_224", 
            pretrained=True, 
            num_classes=0  # Remove classification head
        )
        transformer_features = 768
        
        # Feature fusion
        total_features = cnn_features + transformer_features
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(total_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
        
        # Initialize weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Proper weight initialization"""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # CNN features
        cnn_features = self.cnn_backbone(x)
        
        # Transformer features
        transformer_features = self.transformer_backbone(x)
        
        # Feature fusion
        combined_features = torch.cat([cnn_features, transformer_features], dim=1)
        
        # Classification
        output = self.classifier(combined_features)
        
        return output

# ============================================
# BLOCK 8 ‚Äî Advanced Loss Functions
# ============================================

class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing for better calibration"""
    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction
    
    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

class FocalLoss(nn.Module):
    """Focal loss for imbalanced datasets"""
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# ============================================
# BLOCK 9 ‚Äî Advanced Training Components
# ============================================

class StochasticWeightAveraging:
    """SWA for better generalization"""
    def __init__(self, model, swa_start=10, swa_lr=1e-5):
        self.model = model
        self.swa_start = swa_start
        self.swa_lr = swa_lr
        self.swa_model = torch.optim.swa_utils.AveragedModel(model)
        self.swa_scheduler = torch.optim.swa_utils.SWALR(
            optimizers, swa_lr=swa_lr
        )
        self.swa_enabled = False
    
    def update(self, epoch):
        if epoch >= self.swa_start and not self.swa_enabled:
            self.swa_enabled = True
            print(f"üöÄ SWA activated at epoch {epoch}")
        
        if self.swa_enabled:
            self.swa_model.update_parameters(self.model)
    
    def swap_weights(self):
        if self.swa_enabled:
            self.swa_model.swap_swa_sgd()

class GradientAccumulator:
    """Gradient accumulation for effective larger batch sizes"""
    def __init__(self, accumulation_steps=4):
        self.accumulation_steps = accumulation_steps
        self.counter = 0
    
    def should_step(self):
        self.counter += 1
        return self.counter % self.accumulation_steps == 0
    
    def reset(self):
        self.counter = 0

# ============================================
# BLOCK 10 ‚Äî Robust Training Loop
# ============================================

def compute_accuracy(logits, labels):
    """Safe accuracy computation"""
    preds = torch.argmax(logits, dim=1)
    correct = (preds == labels).sum().item()
    return correct / len(labels)

def train_epoch_safe(model, loader, optimizer, criterion, device, gradient_accumulator):
    """Safe training with gradient checking"""
    model.train()
    total_loss, total_acc = 0, 0
    gradient_accumulator.reset()
    
    for batch_idx, (imgs, labels) in enumerate(tqdm(loader, desc="Training", leave=False)):
        imgs, labels = imgs.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(imgs)
        loss = criterion(outputs, labels) / gradient_accumulator.accumulation_steps
        
        # Backward pass
        loss.backward()
        
        # Gradient accumulation step
        if gradient_accumulator.should_step():
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        total_loss += loss.item() * gradient_accumulator.accumulation_steps
        total_acc += compute_accuracy(outputs, labels)
    
    return total_loss / len(loader), total_acc / len(loader)

def validate_epoch_safe(model, loader, criterion, device):
    """Safe validation"""
    model.eval()
    total_loss, total_acc = 0, 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Validation", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            total_acc += compute_accuracy(outputs, labels)
    
    return total_loss / len(loader), total_acc / len(loader)

# ============================================
# BLOCK 11 ‚Äî Comprehensive Training Manager
# ============================================

def train_comprehensive(model, train_loader, val_loader, model_name, num_classes, 
                       epochs=50, lr=1e-4, use_swa=True):
    """Comprehensive training with all advanced techniques"""
    
    # Loss function with label smoothing
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=lr, 
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-7
    )
    
    # Advanced components
    gradient_accumulator = GradientAccumulator(accumulation_steps=4)
    swa = StochasticWeightAveraging(model, swa_start=epochs//2) if use_swa else None
    
    history = {
        "train_loss": [], "val_loss": [], 
        "train_acc": [], "val_acc": [],
        "learning_rates": []
    }
    
    best_val_acc = 0.0
    patience, patience_counter = 10, 0
    
    print(f"\n{'='*60}")
    print(f"TRAINING {model_name}")
    print(f"{'='*60}")
    print(f"Epochs: {epochs}, LR: {lr}, Classes: {num_classes}")
    
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        
        # Training
        train_loss, train_acc = train_epoch_safe(
            model, train_loader, optimizer, criterion, device, gradient_accumulator
        )
        
        # Validation
        val_loss, val_acc = validate_epoch_safe(model, val_loader, criterion, device)
        
        # Learning rate scheduling
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        
        # SWA update
        if swa:
            swa.update(epoch)
        
        # History tracking
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        history["learning_rates"].append(current_lr)
        
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")
        print(f"LR: {current_lr:.2e}")
        
        # Early stopping and model saving
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"{model_name}_best.pth")
            print(f"  ‚úÖ New best model saved (val_acc: {best_val_acc:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"üõë Early stopping at epoch {epoch}")
            break
    
    # Apply SWA at the end
    if swa and swa.swa_enabled:
        print("Applying SWA...")
        swa.swap_weights()
        torch.save(model.state_dict(), f"{model_name}_swa_final.pth")
    
    return history, best_val_acc

# ============================================
# BLOCK 12 ‚Äî Test Time Augmentation (TTA)
# ============================================

def predict_with_tta(model, test_loader, device, num_transforms=5):
    """Test Time Augmentation for robust predictions"""
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    
    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="TTA Prediction"):
            imgs, labels = imgs.to(device), labels.to(device)
            batch_probs = []
            
            # Original image
            outputs = model(imgs)
            batch_probs.append(F.softmax(outputs, dim=1))
            
            # Augmented predictions
            for i in range(num_transforms - 1):
                if i < len(tta_transforms):
                    augmented_imgs = tta_transforms[i](imgs.cpu()).to(device)
                else:
                    # Random augmentation for remaining transforms
                    aug_transform = transforms.Compose([
                        transforms.RandomHorizontalFlip(p=0.5),
                        transforms.RandomVerticalFlip(p=0.3),
                        transforms.ToTensor()
                    ])
                    augmented_imgs = torch.stack([aug_transform(img.cpu()) for img in imgs]).to(device)
                
                outputs = model(augmented_imgs)
                batch_probs.append(F.softmax(outputs, dim=1))
            
            # Average probabilities
            avg_probs = torch.stack(batch_probs).mean(0)
            preds = torch.argmax(avg_probs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(avg_probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return np.array(all_preds), np.array(all_probs), np.array(all_labels)

# ============================================
# BLOCK 13 ‚Äî Smart Ensemble
# ============================================

class SmartEnsemble:
    def __init__(self, models, val_loader, device):
        self.models = models
        self.device = device
        self.weights = self._compute_optimal_weights(val_loader)
        print(f"Ensemble weights: {self.weights}")
    
    def _compute_optimal_weights(self, val_loader):
        """Compute optimal weights based on validation performance"""
        model_accuracies = []
        
        for model in self.models:
            model.eval()
            correct, total = 0, 0
            
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs, labels = imgs.to(self.device), labels.to(self.device)
                    outputs = model(imgs)
                    preds = torch.argmax(outputs, dim=1)
                    correct += (preds == labels).sum().item()
                    total += len(labels)
            
            accuracy = correct / total
            model_accuracies.append(accuracy)
        
        # Convert accuracies to weights using softmax
        acc_tensor = torch.tensor(model_accuracies)
        weights = F.softmax(acc_tensor * 5, dim=0)  # Temperature scaling
        return weights.numpy()
    
    def predict(self, test_loader, use_tta=True):
        """Ensemble prediction with optional TTA"""
        all_preds, all_probs, all_labels = [], [], []
        
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                batch_probs = None
                
                for i, model in enumerate(self.models):
                    if use_tta:
                        # TTA for each model
                        model_probs = []
                        for transform in tta_transforms:
                            augmented_imgs = transform(imgs.cpu()).to(self.device)
                            outputs = model(augmented_imgs)
                            model_probs.append(F.softmax(outputs, dim=1))
                        avg_model_probs = torch.stack(model_probs).mean(0)
                    else:
                        outputs = model(imgs)
                        avg_model_probs = F.softmax(outputs, dim=1)
                    
                    weighted_probs = avg_model_probs * self.weights[i]
                    
                    if batch_probs is None:
                        batch_probs = weighted_probs
                    else:
                        batch_probs += weighted_probs
                
                preds = torch.argmax(batch_probs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(batch_probs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        return np.array(all_preds), np.array(all_probs), np.array(all_labels)

# ============================================
# BLOCK 14 ‚Äî Model Training Execution
# ============================================

print("\n" + "="*80)
print("INITIALIZING ADVANCED TRAINING PIPELINE")
print("="*80)

# Initialize Hybrid Model
print("\nCreating Hybrid CNN-Transformer Model...")
hybrid_model = HybridModel(num_classes=NUM_CLASSES).to(device)
print(f"Hybrid model parameters: {sum(p.numel() for p in hybrid_model.parameters()):,}")

# Train Hybrid Model
print("\n" + "="*60)
print("PHASE 1: TRAINING HYBRID MODEL")
print("="*60)

hybrid_history, hybrid_best_acc = train_comprehensive(
    hybrid_model, 
    train_loader, 
    val_loader,
    model_name="Hybrid_CNN_Transformer",
    num_classes=NUM_CLASSES,
    epochs=60,  # More epochs for convergence
    lr=5e-5,    # Conservative learning rate
    use_swa=True
)

print(f"\nüéØ Hybrid Model Best Validation Accuracy: {hybrid_best_acc:.4f}")

# ============================================
# BLOCK 15 ‚Äî Comprehensive Evaluation
# ============================================

def comprehensive_evaluation(model, test_loader, model_name, use_tta=True):
    """Comprehensive model evaluation with metrics"""
    print(f"\n{'='*60}")
    print(f"COMPREHENSIVE EVALUATION: {model_name}")
    print(f"{'='*60}")
    
    if use_tta:
        print("Using Test Time Augmentation...")
        preds, probs, true_labels = predict_with_tta(model, test_loader, device)
    else:
        model.eval()
        preds, probs, true_labels = [], [], []
        
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader, desc="Evaluation"):
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                batch_probs = F.softmax(outputs, dim=1)
                batch_preds = torch.argmax(outputs, dim=1)
                
                preds.extend(batch_preds.cpu().numpy())
                probs.extend(batch_probs.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        
        preds, probs, true_labels = np.array(preds), np.array(probs), np.array(true_labels)
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, preds)
    balanced_accuracy = balanced_accuracy_score(true_labels, preds)
    
    print(f"Accuracy:           {accuracy:.4f}")
    print(f"Balanced Accuracy:  {balanced_accuracy:.4f}")
    
    # Detailed classification report
    print("\nüìä Classification Report:")
    print(classification_report(true_labels, preds, target_names=class_names))
    
    # Per-class accuracy
    print("\nüéØ Per-Class Accuracy:")
    for i, cls in enumerate(class_names):
        class_mask = (true_labels == i)
        if class_mask.sum() > 0:
            class_acc = (preds[class_mask] == i).mean()
            print(f"  {cls:25s}: {class_acc:.4f} ({class_mask.sum()} samples)")
        else:
            print(f"  {cls:25s}: No test samples")
    
    # Confusion Matrix
    plt.figure(figsize=(12, 10))
    cm = confusion_matrix(true_labels, preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
    
    return accuracy, balanced_accuracy, probs

# Evaluate Hybrid Model
print("\n" + "="*60)
print("EVALUATING HYBRID MODEL")
print("="*60)

hybrid_accuracy, hybrid_balanced, hybrid_probs = comprehensive_evaluation(
    hybrid_model, test_loader, "Hybrid CNN-Transformer", use_tta=True
)

# ============================================
# BLOCK 16 ‚Äî Training Curves Visualization
# ============================================

def plot_comprehensive_history(history, title):
    """Plot training history with multiple subplots"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # Loss curves
    ax1.plot(history['train_loss'], label='Train Loss', linewidth=2)
    ax1.plot(history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_title(f'{title} - Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Accuracy curves
    ax2.plot(history['train_acc'], label='Train Acc', linewidth=2)
    ax2.plot(history['val_acc'], label='Val Acc', linewidth=2)
    ax2.set_title(f'{title} - Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Learning rate
    ax3.plot(history['learning_rates'], color='purple', linewidth=2)
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3)
    
    # Loss vs Accuracy
    ax4.scatter(history['train_loss'], history['train_acc'], alpha=0.6, label='Train')
    ax4.scatter(history['val_loss'], history['val_acc'], alpha=0.6, label='Val')
    ax4.set_title('Loss vs Accuracy')
    ax4.set_xlabel('Loss')
    ax4.set_ylabel('Accuracy')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Plot training history
print("\nüìà Plotting Training History...")
plot_comprehensive_history(hybrid_history, "Hybrid CNN-Transformer")

# ============================================
# BLOCK 17 ‚Äî Final Results and Summary
# ============================================

print("\n" + "="*80)
print("FINAL RESULTS SUMMARY")
print("="*80)

print(f"\nüéØ HYBRID CNN-TRANSFORMER PERFORMANCE:")
print(f"   Test Accuracy:        {hybrid_accuracy:.4f}")
print(f"   Balanced Accuracy:    {hybrid_balanced:.4f}")

# Compare with baseline (0.389 from previous)
baseline_accuracy = 0.389
improvement = ((hybrid_accuracy - baseline_accuracy) / baseline_accuracy) * 100

print(f"\nüìà IMPROVEMENT OVER BASELINE:")
print(f"   Baseline Accuracy:    {baseline_accuracy:.4f}")
print(f"   Current Accuracy:     {hybrid_accuracy:.4f}")
print(f"   Improvement:          +{improvement:+.1f}%")

if hybrid_accuracy > 0.50:
    print(f"\nüéâ EXCELLENT! Significant improvement achieved!")
elif hybrid_accuracy > 0.45:
    print(f"\n‚úÖ GOOD! Moderate improvement achieved!")
else:
    print(f"\n‚ö†Ô∏è  Needs further optimization")

# Save final results
results_summary = {
    'Model': ['Hybrid_CNN_Transformer'],
    'Accuracy': [hybrid_accuracy],
    'Balanced_Accuracy': [hybrid_balanced],
    'Improvement_Over_Baseline': [improvement]
}

results_df = pd.DataFrame(results_summary)
results_df.to_csv('advanced_model_results.csv', index=False)
print(f"\nüíæ Results saved to 'advanced_model_results.csv'")

print("\n" + "="*80)
print("TRAINING COMPLETED SUCCESSFULLY! üéâ")
print("="*80)

Using device: cuda
=== DATA LOADING & VALIDATION ===
Loaded meta.csv ‚Äî shape: (1011, 19)
Missing images: 0
Final dataset size: 1010
Classes: 19
Class distribution:
diagnosis
clark nevus                     399
melanoma (less than 0.76 mm)    102
reed or spitz nevus              79
melanoma (in situ)               64
melanoma (0.76 to 1.5 mm)        53
seborrheic keratosis             45
basal cell carcinoma             42
dermal nevus                     33
vascular lesion                  29
blue nevus                       28
melanoma (more than 1.5 mm)      28
lentigo                          24
dermatofibroma                   20
congenital nevus                 17
melanosis                        16
combined nevus                   13
miscellaneous                     8
recurrent nevus                   6
melanoma metastasis               4
Name: count, dtype: int64

Encoded 19 classes:
  0: basal cell carcinoma
  1: blue nevus
  2: clark nevus
  3: combined nevus
  4: congenita

NameError: name 'optimizers' is not defined

In [None]:
#!/usr/bin/env python
# coding: utf-8

# ============================================
# MEDICAL VISION TRANSFORMER - Advanced Dermatology Classifier
# ============================================

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.swa_utils import AveragedModel, SWALR

import torchvision
from torchvision import transforms, models
import timm

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, balanced_accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# ============================================
# BLOCK 1 ‚Äî Setup & Diagnostics
# ============================================

# Paths
ROOT_DIR = r"C:\Users\anama\Documents\Group_8"
DATASET_DIR = os.path.join(ROOT_DIR, "Dataset", "DERM7PT")
META_CSV = os.path.join(DATASET_DIR, "meta", "meta.csv")
IMAGES_FOLDER = os.path.join(DATASET_DIR, "images")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ============================================
# BLOCK 2 ‚Äî Data Loading with Robust Validation
# ============================================

def load_and_validate_data():
    """Robust data loading with comprehensive validation"""
    print("=== DATA LOADING & VALIDATION ===")
    
    # Load metadata
    df = pd.read_csv(META_CSV)
    print(f"Loaded meta.csv ‚Äî shape: {df.shape}")
    
    # Clean data
    df = df.drop(columns=["case_num", "case_id", "notes"], errors="ignore")
    df["derm_fullpath"] = df["derm"].apply(lambda x: os.path.join(IMAGES_FOLDER, x))
    
    # Validate paths
    df["derm_exists"] = df["derm_fullpath"].apply(os.path.exists)
    missing_count = (~df["derm_exists"]).sum()
    print(f"Missing images: {missing_count}")
    
    df = df[df["derm_exists"]].reset_index(drop=True)
    
    # Filter classes with >= 2 samples
    class_counts = df["diagnosis"].value_counts()
    valid_classes = class_counts[class_counts >= 2].index
    df = df[df["diagnosis"].isin(valid_classes)].reset_index(drop=True)
    
    print(f"Final dataset size: {len(df)}")
    print(f"Classes: {len(valid_classes)}")
    print("Class distribution:")
    print(class_counts[valid_classes])
    
    return df, valid_classes

df, valid_classes = load_and_validate_data()

# Label encoding
le = LabelEncoder()
df["label"] = le.fit_transform(df["diagnosis"])
NUM_CLASSES = len(le.classes_)
class_names = le.classes_

print(f"\nEncoded {NUM_CLASSES} classes:")
for i, cls in enumerate(class_names):
    print(f"  {i}: {cls}")

# ============================================
# BLOCK 3 ‚Äî Robust Data Split
# ============================================

def create_robust_splits(df):
    """Create splits with special handling for rare classes"""
    # Handle melanoma metastasis separately
    meta_mask = df["diagnosis"] == "melanoma metastasis"
    df_meta = df[meta_mask].copy()
    df_main = df[~meta_mask].copy()
    
    print(f"Main samples: {len(df_main)}, Metastasis: {len(df_meta)}")
    
    # Main stratified split
    train_main, temp_main = train_test_split(
        df_main, test_size=0.30, stratify=df_main["label"], random_state=42
    )
    val_main, test_main = train_test_split(
        temp_main, test_size=0.50, stratify=temp_main["label"], random_state=42
    )
    
    # Distribute metastasis samples
    if len(df_meta) > 0:
        df_meta_shuffled = df_meta.sample(frac=1, random_state=42).reset_index(drop=True)
        train_meta = df_meta_shuffled.iloc[:2] if len(df_meta) >= 2 else df_meta_shuffled.iloc[:1]
        val_meta = df_meta_shuffled.iloc[2:3] if len(df_meta) >= 3 else pd.DataFrame()
        test_meta = df_meta_shuffled.iloc[3:4] if len(df_meta) >= 4 else pd.DataFrame()
    else:
        train_meta, val_meta, test_meta = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    
    # Final splits
    train_df = pd.concat([train_main, train_meta]).reset_index(drop=True) if len(train_meta) > 0 else train_main
    val_df = pd.concat([val_main, val_meta]).reset_index(drop=True) if len(val_meta) > 0 else val_main
    test_df = pd.concat([test_main, test_meta]).reset_index(drop=True) if len(test_meta) > 0 else test_main
    
    print(f"\nFinal splits:")
    print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    
    return train_df, val_df, test_df

train_df, val_df, test_df = create_robust_splits(df)

# ============================================
# BLOCK 4 ‚Äî Medical-Optimized Augmentations
# ============================================

class MedicalAugmentation:
    """Medical-specific augmentations that preserve lesion characteristics"""
    
    @staticmethod
    def get_train_transform():
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def get_vit_transform():
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    
    @staticmethod
    def get_test_transform():
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def get_tta_transforms():
        """Test Time Augmentation transforms"""
        return [
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomVerticalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
        ]

# Initialize transforms
train_transform = MedicalAugmentation.get_train_transform()
vit_transform = MedicalAugmentation.get_vit_transform()
test_transform = MedicalAugmentation.get_test_transform()
tta_transforms = MedicalAugmentation.get_tta_transforms()

# ============================================
# BLOCK 5 ‚Äî Robust Dataset with Debugging
# ============================================

class RobustDermDataset(Dataset):
    def __init__(self, df, transform=None, debug=False):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.debug = debug
        
        if debug:
            self._validate_dataset()
    
    def _validate_dataset(self):
        """Comprehensive dataset validation"""
        print("=== DATASET VALIDATION ===")
        print(f"Total samples: {len(self.df)}")
        print(f"Label range: {self.df['label'].min()} to {self.df['label'].max()}")
        print(f"Num classes: {NUM_CLASSES}")
        
        # Check for invalid labels
        invalid_labels = self.df[self.df['label'] >= NUM_CLASSES]
        if len(invalid_labels) > 0:
            print(f"üö® ERROR: {len(invalid_labels)} invalid labels found!")
            print(invalid_labels[['diagnosis', 'label']].head())
            raise ValueError("Invalid labels detected")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            img_path = str(row["derm_fullpath"]).replace("/", "\\")
            label = int(row["label"])
            
            # Validate label range
            if label >= NUM_CLASSES or label < 0:
                raise ValueError(f"Invalid label {label} for sample {idx}")
            
            img = Image.open(img_path).convert("RGB")
            
            if self.transform:
                img = self.transform(img)
            
            return img, label
            
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            print(f"Path: {img_path}, Label: {label}")
            raise e

# Create datasets with validation
print("\nCreating datasets...")
train_dataset = RobustDermDataset(train_df, transform=train_transform, debug=True)
val_dataset = RobustDermDataset(val_df, transform=test_transform, debug=True)
test_dataset = RobustDermDataset(test_df, transform=test_transform, debug=True)

# ============================================
# BLOCK 6 ‚Äî Advanced DataLoaders
# ============================================

def create_data_loaders():
    """Create optimized data loaders"""
    # Weighted sampling for class imbalance
    class_counts = train_df["label"].value_counts().sort_index()
    class_weights = 1.0 / class_counts
    sample_weights = train_df["label"].map(class_weights).values
    sample_weights = torch.tensor(sample_weights, dtype=torch.float32)
    
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    batch_size = 16  # Reduced for stability
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    print(f"DataLoader batch size: {batch_size}")
    print(f"Train batches: {len(train_loader)}")
    
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = create_data_loaders()

# ============================================
# BLOCK 7 ‚Äî Enhanced Vision Transformer Model
# ============================================

class EnhancedViT(nn.Module):
    """Enhanced Vision Transformer with better regularization"""
    def __init__(self, num_classes, model_name='vit_base_patch16_224', dropout=0.1):
        super().__init__()
        
        # Load pre-trained ViT
        self.vit = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,  # Remove classification head
            drop_rate=dropout
        )
        
        # Get feature dimension
        feature_dim = self.vit.num_features
        
        # Enhanced classification head
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
        # Initialize classifier weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Proper weight initialization"""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        features = self.vit(x)
        return self.classifier(features)

# ============================================
# BLOCK 8 ‚Äî Advanced Loss Functions
# ============================================

class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing for better calibration"""
    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction
    
    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

# ============================================
# BLOCK 9 ‚Äî Advanced Training Components
# ============================================

class GradientAccumulator:
    """Gradient accumulation for effective larger batch sizes"""
    def __init__(self, accumulation_steps=4):
        self.accumulation_steps = accumulation_steps
        self.counter = 0
    
    def should_step(self):
        self.counter += 1
        return self.counter % self.accumulation_steps == 0
    
    def reset(self):
        self.counter = 0

# ============================================
# BLOCK 10 ‚Äî Robust Training Loop
# ============================================

def compute_accuracy(logits, labels):
    """Safe accuracy computation"""
    preds = torch.argmax(logits, dim=1)
    correct = (preds == labels).sum().item()
    return correct / len(labels)

def train_epoch_safe(model, loader, optimizer, criterion, device, gradient_accumulator):
    """Safe training with gradient checking"""
    model.train()
    total_loss, total_acc = 0, 0
    gradient_accumulator.reset()
    
    for batch_idx, (imgs, labels) in enumerate(tqdm(loader, desc="Training", leave=False)):
        imgs, labels = imgs.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(imgs)
        loss = criterion(outputs, labels) / gradient_accumulator.accumulation_steps
        
        # Backward pass
        loss.backward()
        
        # Gradient accumulation step
        if gradient_accumulator.should_step():
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        total_loss += loss.item() * gradient_accumulator.accumulation_steps
        total_acc += compute_accuracy(outputs, labels)
    
    return total_loss / len(loader), total_acc / len(loader)

def validate_epoch_safe(model, loader, criterion, device):
    """Safe validation"""
    model.eval()
    total_loss, total_acc = 0, 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Validation", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            total_acc += compute_accuracy(outputs, labels)
    
    return total_loss / len(loader), total_acc / len(loader)

# ============================================
# BLOCK 11 ‚Äî Comprehensive Training Manager
# ============================================

def train_comprehensive(model, train_loader, val_loader, model_name, num_classes, 
                       epochs=50, lr=1e-4, use_swa=True):
    """Comprehensive training with all advanced techniques"""
    
    # Loss function with label smoothing
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=lr, 
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-7
    )
    
    # SWA setup
    if use_swa:
        swa_model = AveragedModel(model)
        swa_scheduler = SWALR(optimizer, swa_lr=1e-5)
        swa_start = epochs // 2
    else:
        swa_model = None
    
    # Gradient accumulation
    gradient_accumulator = GradientAccumulator(accumulation_steps=4)
    
    history = {
        "train_loss": [], "val_loss": [], 
        "train_acc": [], "val_acc": [],
        "learning_rates": []
    }
    
    best_val_acc = 0.0
    patience, patience_counter = 10, 0
    
    print(f"\n{'='*60}")
    print(f"TRAINING {model_name}")
    print(f"{'='*60}")
    print(f"Epochs: {epochs}, LR: {lr}, Classes: {num_classes}")
    
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        
        # Training
        train_loss, train_acc = train_epoch_safe(
            model, train_loader, optimizer, criterion, device, gradient_accumulator
        )
        
        # Validation
        val_loss, val_acc = validate_epoch_safe(model, val_loader, criterion, device)
        
        # Learning rate scheduling
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        
        # SWA update
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        
        # History tracking
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        history["learning_rates"].append(current_lr)
        
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")
        print(f"LR: {current_lr:.2e}")
        
        # Early stopping and model saving
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"{model_name}_best.pth")
            print(f"  ‚úÖ New best model saved (val_acc: {best_val_acc:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"üõë Early stopping at epoch {epoch}")
            break
    
    # Apply SWA at the end
    if use_swa and swa_model is not None:
        print("Applying SWA...")
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        torch.save(swa_model.state_dict(), f"{model_name}_swa_final.pth")
    
    return history, best_val_acc

# ============================================
# BLOCK 12 ‚Äî Test Time Augmentation (TTA)
# ============================================

def predict_with_tta(model, test_loader, device, num_transforms=3):
    """Test Time Augmentation for robust predictions"""
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    
    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="TTA Prediction"):
            imgs, labels = imgs.to(device), labels.to(device)
            batch_probs = []
            
            # Original image + augmented predictions
            for i in range(num_transforms):
                if i == 0:
                    # Original image
                    augmented_imgs = imgs
                elif i == 1:
                    # Horizontal flip
                    augmented_imgs = torch.flip(imgs, [3])
                else:
                    # Vertical flip
                    augmented_imgs = torch.flip(imgs, [2])
                
                outputs = model(augmented_imgs)
                batch_probs.append(F.softmax(outputs, dim=1))
            
            # Average probabilities
            avg_probs = torch.stack(batch_probs).mean(0)
            preds = torch.argmax(avg_probs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(avg_probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return np.array(all_preds), np.array(all_probs), np.array(all_labels)

# ============================================
# BLOCK 13 ‚Äî Model Training Execution
# ============================================

print("\n" + "="*80)
print("INITIALIZING ADVANCED TRAINING PIPELINE")
print("="*80)

# Initialize Enhanced ViT Model
print("\nCreating Enhanced Vision Transformer Model...")
vit_model = EnhancedViT(num_classes=NUM_CLASSES).to(device)
print(f"Enhanced ViT parameters: {sum(p.numel() for p in vit_model.parameters()):,}")

# Train Enhanced ViT Model
print("\n" + "="*60)
print("PHASE 1: TRAINING ENHANCED VISION TRANSFORMER")
print("="*60)

vit_history, vit_best_acc = train_comprehensive(
    vit_model, 
    train_loader, 
    val_loader,
    model_name="Enhanced_ViT",
    num_classes=NUM_CLASSES,
    epochs=60,
    lr=2e-5,    # Lower LR for ViT
    use_swa=True
)

print(f"\nüéØ Enhanced ViT Best Validation Accuracy: {vit_best_acc:.4f}")

# ============================================
# BLOCK 14 ‚Äî Comprehensive Evaluation
# ============================================

def comprehensive_evaluation(model, test_loader, model_name, use_tta=True):
    """Comprehensive model evaluation with metrics"""
    print(f"\n{'='*60}")
    print(f"COMPREHENSIVE EVALUATION: {model_name}")
    print(f"{'='*60}")
    
    if use_tta:
        print("Using Test Time Augmentation...")
        preds, probs, true_labels = predict_with_tta(model, test_loader, device)
    else:
        model.eval()
        preds, probs, true_labels = [], [], []
        
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader, desc="Evaluation"):
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                batch_probs = F.softmax(outputs, dim=1)
                batch_preds = torch.argmax(outputs, dim=1)
                
                preds.extend(batch_preds.cpu().numpy())
                probs.extend(batch_probs.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        
        preds, probs, true_labels = np.array(preds), np.array(probs), np.array(true_labels)
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, preds)
    balanced_accuracy = balanced_accuracy_score(true_labels, preds)
    
    print(f"Accuracy:           {accuracy:.4f}")
    print(f"Balanced Accuracy:  {balanced_accuracy:.4f}")
    
    # Detailed classification report
    print("\nüìä Classification Report:")
    print(classification_report(true_labels, preds, target_names=class_names))
    
    # Per-class accuracy
    print("\nüéØ Per-Class Accuracy:")
    for i, cls in enumerate(class_names):
        class_mask = (true_labels == i)
        if class_mask.sum() > 0:
            class_acc = (preds[class_mask] == i).mean()
            print(f"  {cls:25s}: {class_acc:.4f} ({class_mask.sum()} samples)")
        else:
            print(f"  {cls:25s}: No test samples")
    
    # Confusion Matrix
    plt.figure(figsize=(12, 10))
    cm = confusion_matrix(true_labels, preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
    
    return accuracy, balanced_accuracy, probs

# Evaluate Enhanced ViT Model
print("\n" + "="*60)
print("EVALUATING ENHANCED VISION TRANSFORMER")
print("="*60)

vit_accuracy, vit_balanced, vit_probs = comprehensive_evaluation(
    vit_model, test_loader, "Enhanced Vision Transformer", use_tta=True
)

# ============================================
# BLOCK 15 ‚Äî Training Curves Visualization
# ============================================

def plot_comprehensive_history(history, title):
    """Plot training history with multiple subplots"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # Loss curves
    ax1.plot(history['train_loss'], label='Train Loss', linewidth=2)
    ax1.plot(history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_title(f'{title} - Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Accuracy curves
    ax2.plot(history['train_acc'], label='Train Acc', linewidth=2)
    ax2.plot(history['val_acc'], label='Val Acc', linewidth=2)
    ax2.set_title(f'{title} - Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Learning rate
    ax3.plot(history['learning_rates'], color='purple', linewidth=2)
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3)
    
    # Loss vs Accuracy
    ax4.scatter(history['train_loss'], history['train_acc'], alpha=0.6, label='Train')
    ax4.scatter(history['val_loss'], history['val_acc'], alpha=0.6, label='Val')
    ax4.set_title('Loss vs Accuracy')
    ax4.set_xlabel('Loss')
    ax4.set_ylabel('Accuracy')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Plot training history
print("\nüìà Plotting Training History...")
plot_comprehensive_history(vit_history, "Enhanced Vision Transformer")

# ============================================
# BLOCK 16 ‚Äî Final Results and Summary
# ============================================

print("\n" + "="*80)
print("FINAL RESULTS SUMMARY")
print("="*80)

print(f"\nüéØ ENHANCED VISION TRANSFORMER PERFORMANCE:")
print(f"   Test Accuracy:        {vit_accuracy:.4f}")
print(f"   Balanced Accuracy:    {vit_balanced:.4f}")

# Compare with baseline (0.389 from previous)
baseline_accuracy = 0.389
improvement = ((vit_accuracy - baseline_accuracy) / baseline_accuracy) * 100

print(f"\nüìà IMPROVEMENT OVER BASELINE:")
print(f"   Baseline Accuracy:    {baseline_accuracy:.4f}")
print(f"   Current Accuracy:     {vit_accuracy:.4f}")
print(f"   Improvement:          +{improvement:+.1f}%")

if vit_accuracy > 0.50:
    print(f"\nüéâ EXCELLENT! Significant improvement achieved!")
elif vit_accuracy > 0.45:
    print(f"\n‚úÖ GOOD! Moderate improvement achieved!")
else:
    print(f"\n‚ö†Ô∏è  Needs further optimization")

# Save final results
results_summary = {
    'Model': ['Enhanced_Vision_Transformer'],
    'Accuracy': [vit_accuracy],
    'Balanced_Accuracy': [vit_balanced],
    'Improvement_Over_Baseline': [improvement]
}

results_df = pd.DataFrame(results_summary)
results_df.to_csv('enhanced_vit_results.csv', index=False)
print(f"\nüíæ Results saved to 'enhanced_vit_results.csv'")

print("\n" + "="*80)
print("TRAINING COMPLETED SUCCESSFULLY! üéâ")
print("="*80)

Using device: cuda
=== DATA LOADING & VALIDATION ===
Loaded meta.csv ‚Äî shape: (1011, 19)
Missing images: 0
Final dataset size: 1010
Classes: 19
Class distribution:
diagnosis
clark nevus                     399
melanoma (less than 0.76 mm)    102
reed or spitz nevus              79
melanoma (in situ)               64
melanoma (0.76 to 1.5 mm)        53
seborrheic keratosis             45
basal cell carcinoma             42
dermal nevus                     33
vascular lesion                  29
blue nevus                       28
melanoma (more than 1.5 mm)      28
lentigo                          24
dermatofibroma                   20
congenital nevus                 17
melanosis                        16
combined nevus                   13
miscellaneous                     8
recurrent nevus                   6
melanoma metastasis               4
Name: count, dtype: int64

Encoded 19 classes:
  0: basal cell carcinoma
  1: blue nevus
  2: clark nevus
  3: combined nevus
  4: congenita

Training:   0%|          | 0/45 [00:00<?, ?it/s]