In [None]:
# =========================================================
# DERM7PT DATA LOADER FOR TRAINING NOTEBOOKS
# =========================================================
# Run this code in your training notebooks to load the preprocessed Derm7pt dataset
# Make sure the preprocessing pipeline above has been executed first!

import os
import json
import joblib
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# =========================================================
# PATHS TO PREPROCESSED DATA
# =========================================================
# Output directory of the preprocessing pipeline (relative to the working directory).
PREPROCESSED_DIR = r"augmented"

TRAIN_CSV = os.path.join(PREPROCESSED_DIR, "train_metadata_final.csv")
VAL_CSV   = os.path.join(PREPROCESSED_DIR, "val_metadata_final.csv")
TEST_CSV  = os.path.join(PREPROCESSED_DIR, "test_metadata_final.csv")

INFO_PATH = os.path.join(PREPROCESSED_DIR, "preprocessing_info.json")

# =========================================================
# LOAD PREPROCESSED DATA
# =========================================================
print("Loading preprocessed Derm7pt data...")

# Fail early with one message listing every missing file, instead of a bare
# FileNotFoundError for whichever file happens to be read first.
_missing_files = [p for p in (TRAIN_CSV, VAL_CSV, TEST_CSV, INFO_PATH) if not os.path.isfile(p)]
if _missing_files:
    raise FileNotFoundError(
        "Preprocessed Derm7pt files not found (run the preprocessing pipeline first): "
        + ", ".join(_missing_files)
    )

# Load CSVs
train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)
test_df  = pd.read_csv(TEST_CSV)

# Load preprocessing info. The encoding is explicit: without it the platform
# default (e.g. cp1252 on Windows) is used and non-ASCII text may be mis-decoded.
with open(INFO_PATH, "r", encoding="utf-8") as f:
    preprocessing_info = json.load(f)

categorical_cols = preprocessing_info["categorical_cols"]
label_mapping = preprocessing_info["label_mapping"]

print(f"\n‚úÖ Training samples:   {len(train_df)}")
print(f"‚úÖ Validation samples: {len(val_df)}")
print(f"‚úÖ Test samples:       {len(test_df)}")
print(f"\nLabel mapping: {label_mapping}")

# =========================================================
# EXTRACT FEATURES AND LABELS
# =========================================================
def extract_features(df):
    """Split a preprocessed Derm7pt dataframe into the three model inputs.

    Args:
        df: DataFrame holding an ``ImagePath`` column, a ``label`` column and
            any number of further (metadata) columns.

    Returns:
        ``(img_paths, metadata, labels)``: the ``.values`` of ``ImagePath``,
        the ``.values`` of every remaining column (in the dataframe's column
        order), and the ``.values`` of ``label``.
    """
    excluded = {"ImagePath", "label"}
    feature_columns = [name for name in df.columns if name not in excluded]

    paths = df["ImagePath"].values
    targets = df["label"].values
    features = df[feature_columns].values
    return paths, features, targets

# Unpack every split into (image paths, metadata matrix, labels).
X_train_img, X_train_meta, y_train = extract_features(train_df)
X_val_img,   X_val_meta,   y_val   = extract_features(val_df)
X_test_img,  X_test_meta,  y_test  = extract_features(test_df)

# Count of entries in the preprocessing label mapping.
num_classes = len(label_mapping)
print(f"\nNumber of classes: {num_classes}")

# =========================================================
# PYTORCH DATASET CLASS
# =========================================================
class Derm7ptDataset(Dataset):
    """
    Custom Dataset for Derm7pt with images + metadata.

    Each item is a tuple ``(image, metadata, label)``:
      * image    -- the RGB image at ``img_paths[idx]``, passed through
                    ``transform`` when one is given (otherwise a PIL image);
      * metadata -- ``metadata[idx]`` cast with ``.astype(np.float32)``;
      * label    -- ``int(labels[idx])``.

    Args:
        img_paths: Indexable sequence of image file paths.
        metadata: Indexable of per-sample feature rows supporting ``.astype``
            (e.g. a 2-D numpy array).
        labels: Indexable of integer-convertible labels; its length is the
            dataset length.
        transform: Optional callable applied to the PIL image.
        placeholder_size: PIL ``(width, height)`` of the black image used when
            a file cannot be read. Defaults to 224x224.
    """
    def __init__(self, img_paths, metadata, labels, transform=None, placeholder_size=(224, 224)):
        self.img_paths = img_paths
        self.metadata = metadata
        self.labels = labels
        self.transform = transform
        self.placeholder_size = placeholder_size

    def __len__(self):
        return len(self.labels)

    def _load_image(self, img_path):
        """Return the image at ``img_path`` as RGB, or a black placeholder if unreadable."""
        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                return img.convert("RGB")
        except (OSError, ValueError) as exc:
            # OSError covers missing files and undecodable/truncated images
            # (PIL's UnidentifiedImageError is an OSError subclass). Report the
            # cause instead of silently swallowing it.
            print(f"Warning: Failed to load {img_path} ({exc}), using placeholder")
            return Image.new("RGB", self.placeholder_size, color="black")

    def __getitem__(self, idx):
        image = self._load_image(self.img_paths[idx])

        if self.transform:
            image = self.transform(image)

        # Get metadata and label
        metadata = self.metadata[idx].astype(np.float32)
        label = int(self.labels[idx])

        return image, metadata, label

import torch  # used below only to check for CUDA when configuring pin_memory

# =========================================================
# DATA TRANSFORMS
# =========================================================
# Training transforms (with augmentation) - REDUCED for small dataset
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),  # Reduced from 30
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Reduced
    transforms.ToTensor(),
    # ImageNet channel statistics
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test transforms (no augmentation)
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# =========================================================
# CREATE DATASETS
# =========================================================
train_dataset = Derm7ptDataset(X_train_img, X_train_meta, y_train, transform=train_transform)
val_dataset   = Derm7ptDataset(X_val_img, X_val_meta, y_val, transform=val_test_transform)
test_dataset  = Derm7ptDataset(X_test_img, X_test_meta, y_test, transform=val_test_transform)

print(f"\n‚úÖ Created PyTorch Datasets")
print(f"   - Train: {len(train_dataset)} samples")
print(f"   - Val:   {len(val_dataset)} samples")
print(f"   - Test:  {len(test_dataset)} samples")

# =========================================================
# CREATE DATALOADERS (EXAMPLE - ADJUST BATCH SIZE AS NEEDED)
# =========================================================
BATCH_SIZE = 16
NUM_WORKERS = 0  # Set to 0 for Windows, increase for Linux/Mac
# Pinned (page-locked) host memory only speeds up host->GPU copies; on a
# CPU-only machine PyTorch warns that it is unused, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

print(f"\n‚úÖ Created DataLoaders (batch_size={BATCH_SIZE})")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Val batches:   {len(val_loader)}")
print(f"   - Test batches:  {len(test_loader)}")

# =========================================================
# EXAMPLE: TEST LOADING A BATCH
# =========================================================
print("\nüîç Testing batch loading...")
for images, metadata, labels in train_loader:
    print(f"   - Image batch shape:    {images.shape}")
    print(f"   - Metadata batch shape: {metadata.shape}")
    print(f"   - Labels batch shape:   {labels.shape}")
    break

print("\n‚úÖ Derm7pt data loading complete! Ready for training.")
print("\nüí° Usage in your model:")
print("   for images, metadata, labels in train_loader:")
print("       # images: torch.Tensor of shape (batch_size, 3, 224, 224)")
print("       # metadata: torch.Tensor of shape (batch_size, num_metadata_features)")
print("       # labels: torch.Tensor of shape (batch_size,)")
print("       # Your training code here...")

In [None]:
# # =========================================================
# # BALANCED BATCH SAMPLER (Paper Implementation)
# # =========================================================
# # "Each mini-batch contains an equal number of samples for each label"
# from torch.utils.data import Sampler
# import random

# class BalancedBatchSampler(Sampler):
#     """
#     Samples k examples per class in each batch to ensure balanced mini-batches.
    
#     Args:
#         labels: Array of labels for the dataset
#         k_per_class: Number of samples per class in each batch (e.g., 8)
#         num_classes: Total number of classes
#         drop_last: Whether to drop the last incomplete batch
#     """
#     def __init__(self, labels, k_per_class=8, num_classes=None, drop_last=True):
#         self.labels = np.array(labels)
#         self.k_per_class = k_per_class
#         self.drop_last = drop_last
        
#         # Automatically determine number of classes if not provided
#         if num_classes is None:
#             self.num_classes = len(np.unique(labels))
#         else:
#             self.num_classes = num_classes
        
#         # Organize indices by class
#         self.class_indices = {}
#         for class_idx in range(self.num_classes):
#             self.class_indices[class_idx] = np.where(self.labels == class_idx)[0].tolist()
        
#         # Verify all classes have enough samples
#         for class_idx, indices in self.class_indices.items():
#             if len(indices) < self.k_per_class:
#                 print(f"Warning: Class {class_idx} has only {len(indices)} samples, less than k={k_per_class}")
    
#     def __iter__(self):
#         # Shuffle indices within each class
#         shuffled_indices = {}
#         for class_idx, indices in self.class_indices.items():
#             shuffled = indices.copy()
#             random.shuffle(shuffled)
#             shuffled_indices[class_idx] = shuffled
        
#         # Create balanced batches
#         batch = []
#         class_positions = {class_idx: 0 for class_idx in range(self.num_classes)}
        
#         while True:
#             # Check if we can form a complete batch
#             can_form_batch = all(
#                 class_positions[class_idx] + self.k_per_class <= len(shuffled_indices[class_idx])
#                 for class_idx in range(self.num_classes)
#             )
            
#             if not can_form_batch:
#                 if not self.drop_last and len(batch) > 0:
#                     yield batch
#                 break
            
#             # Sample k examples from each class
#             for class_idx in range(self.num_classes):
#                 start_pos = class_positions[class_idx]
#                 end_pos = start_pos + self.k_per_class
#                 batch.extend(shuffled_indices[class_idx][start_pos:end_pos])
#                 class_positions[class_idx] = end_pos
            
#             # Shuffle batch to mix classes
#             random.shuffle(batch)
#             yield batch
#             batch = []
    
#     def __len__(self):
#         # Calculate number of complete batches possible
#         min_batches = min(len(indices) // self.k_per_class 
#                          for indices in self.class_indices.values())
#         return min_batches


# # =========================================================
# # OPTION: USE BALANCED BATCH SAMPLER
# # =========================================================
# # Set this to True to use balanced batch sampling (paper approach)
# USE_BALANCED_SAMPLING = True

# if USE_BALANCED_SAMPLING:
#     print("\n‚öñÔ∏è Using Balanced Batch Sampling (k samples per class)")
    
#     # Determine k per class based on smallest class and desired batch size
#     # For Derm7pt with 5 classes: batch_size = num_classes * k_per_class
#     # Example: 5 classes * 8 samples/class = 40 total batch size
#     k_per_class = 8  # Adjust this value
#     effective_batch_size = num_classes * k_per_class
    
#     print(f"   - k per class: {k_per_class}")
#     print(f"   - Effective batch size: {effective_batch_size}")
    
#     # Create balanced sampler for training
#     train_sampler = BalancedBatchSampler(
#         labels=y_train,
#         k_per_class=k_per_class,
#         num_classes=num_classes,
#         drop_last=True
#     )
    
#     # Recreate train loader with balanced sampler
#     train_loader = DataLoader(
#         train_dataset,
#         batch_sampler=train_sampler,  # Use batch_sampler instead of batch_size
#         num_workers=0,
#         pin_memory=True
#     )
    
#     print(f"   - Train batches: {len(train_loader)}")
    
#     # Val/test loaders remain unchanged
#     # (they already exist from previous cell)

# else:
#     print("\nüì¶ Using standard DataLoader (original approach)")
#     # Standard loaders already created above
#     pass

# print("\n‚úÖ DataLoader configuration complete!")

In [None]:
# Number of classes = distinct labels present in the training split.
num_classes = train_df['label'].nunique()
print("Number of classes:", num_classes)

# nn.CrossEntropyLoss (used for training and validation) needs every target in
# [0, num_classes). Counting distinct training labels only guarantees that when
# the labels are the contiguous codes 0..num_classes-1, so check it here instead
# of failing mid-training with an opaque index / CUDA device-side assert.
_loss_labels = pd.concat([train_df['label'], val_df['label']])
if _loss_labels.min() < 0 or _loss_labels.max() >= num_classes:
    raise ValueError(
        f"Train/val labels span [{_loss_labels.min()}, {_loss_labels.max()}] but "
        f"num_classes={num_classes}; expected integer labels 0..num_classes-1 "
        "(is a class missing from the training split?)"
    )

non_feature_cols = ["label", "ImagePath"]
X_train_meta = train_df.drop(columns=non_feature_cols)
input_dim_meta = X_train_meta.shape[1]
print(f"Metadata input dimension: {input_dim_meta}")

In [None]:
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F

class EarlyFusionModel(nn.Module):
    """
    Early-fusion classifier that feeds metadata to a PVT-v2 backbone as a 4th image channel.

    Forward pipeline:
      1. ``meta_embed`` maps the metadata vector to ``META_GRID * META_GRID``
         values, reshaped into a single-channel ``META_GRID x META_GRID`` map.
      2. The map is bilinearly upsampled to the image's height and width.
      3. It is concatenated to the image along dim 1 and passed to timm's
         ``pvt_v2_b1`` (pretrained), whose patch-embedding convolution is
         widened from 3 to 4 input channels.

    Args:
        input_dim_meta: Number of metadata features per sample.
        num_classes: Number of output logits of the backbone's head.
    """

    # Side length of the square map the metadata vector is embedded into.
    META_GRID = 56

    def __init__(self, input_dim_meta, num_classes):
        super().__init__()
        grid_cells = self.META_GRID * self.META_GRID

        # Embed metadata to smaller spatial dimensions first
        self.meta_embed = nn.Sequential(
            nn.Linear(input_dim_meta, grid_cells),
            nn.ReLU(),
            nn.BatchNorm1d(grid_cells),
            nn.Dropout(0.3)
        )

        # Load PVT v2 model (pretrained weights, head sized to num_classes)
        self.pvt = timm.create_model("pvt_v2_b1", pretrained=True, num_classes=num_classes)

        # Replace the first convolution with a 4-channel version (RGB + metadata map)
        first_conv = self.pvt.patch_embed.proj
        fused_conv = nn.Conv2d(4, first_conv.out_channels,
                               kernel_size=first_conv.kernel_size,
                               stride=first_conv.stride,
                               padding=first_conv.padding,
                               dilation=first_conv.dilation,
                               bias=first_conv.bias is not None)

        with torch.no_grad():
            # Reuse the pretrained RGB filters unchanged ...
            fused_conv.weight[:, :3] = first_conv.weight
            # ... and start the metadata channel as a damped average of them.
            fused_conv.weight[:, 3:] = first_conv.weight.mean(dim=1, keepdim=True) * 0.1
            # Keep the pretrained bias too; otherwise it would be left at the
            # new conv's random initialisation.
            if first_conv.bias is not None:
                fused_conv.bias.copy_(first_conv.bias)

        self.pvt.patch_embed.proj = fused_conv

    def forward(self, img, meta):
        """Return backbone logits for an image batch ``img`` and metadata batch ``meta``.

        ``img`` is assumed to be (batch, 3, H, W) and ``meta`` (batch, input_dim_meta).
        """
        batch_size = img.shape[0]
        meta_map = self.meta_embed(meta).view(batch_size, 1, self.META_GRID, self.META_GRID)

        # Upsample to the actual image size rather than a hard-coded 224x224
        meta_map = F.interpolate(meta_map,
                                 size=tuple(img.shape[-2:]),
                                 mode='bilinear',
                                 align_corners=False)

        # Early fusion: metadata becomes the 4th input channel
        combined_input = torch.cat([img, meta_map], dim=1)

        # Process through modified PVT
        return self.pvt(combined_input)

# Number of metadata feature columns fed to the model's metadata branch.
input_dim_meta = X_train_meta.shape[1]
# Use the GPU when PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EarlyFusionModel(input_dim_meta, num_classes).to(device)

# Layer-by-layer summary for a dummy batch of 16 samples.
from torchinfo import summary
summary(model=model, 
        input_size=[(16, 3, 224, 224), (16, input_dim_meta)],  # (image batch, metadata batch) for the PVT-v2 early-fusion model
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
# Sanity check: number of metadata features the model was built for.
print(input_dim_meta)

In [None]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def test(model, loader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, metas, labels in loader:
            imgs, metas, labels = imgs.to(device), metas.to(device), labels.to(device)
            outputs = model(imgs, metas)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return all_labels, all_preds

class EarlyStopping:
    """Stop on a stalled validation loss; checkpoint on the best validation accuracy.

    Two criteria are tracked independently on every call:
      * Stopping: ``counter`` grows whenever ``val_loss`` is not lower than
        ``best_loss - delta`` and resets on an improvement; once it reaches
        ``patience``, ``early_stop`` is set to True.
      * Checkpointing: whenever ``val_acc`` exceeds ``best_accuracy``, the
        model's ``state_dict()`` is written to ``path`` with ``torch.save``.

    Args:
        patience: Consecutive non-improving calls tolerated before stopping.
        verbose: Print progress messages.
        delta: Minimum loss decrease that counts as an improvement.
        path: Checkpoint file path.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience, self.verbose, self.delta, self.path = patience, verbose, delta, path
        self.counter = 0
        self.best_loss = None          # lowest validation loss seen so far
        self.best_accuracy = -np.inf   # highest validation accuracy checkpointed
        self.early_stop = False

    def __call__(self, val_loss, val_acc, model):
        self._track_loss(val_loss)
        if val_acc > self.best_accuracy:
            self.save_checkpoint(val_acc, model)

    def _track_loss(self, val_loss):
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            self.counter = 0
            return

        stalled = val_loss >= self.best_loss - self.delta
        if not stalled:
            if self.verbose:
                print(f'‚úÖ Validation loss improved ({self.best_loss:.6f} --> {val_loss:.6f})')
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience} (val_loss: {val_loss:.4f}, best: {self.best_loss:.4f})')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_acc, model):
        if self.verbose:
            print(f'üíæ Validation accuracy improved ({self.best_accuracy:.6f} --> {val_acc:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.best_accuracy = val_acc

def train(model, train_loader, optimizer, criterion, device, max_grad_norm=1.0):
    """Run one training epoch.

    Args:
        model: Module called as ``model(images, metadata)``.
        train_loader: Iterable of ``(images, metadata, labels)`` batches with ``len``.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss function ``criterion(outputs, labels)``.
        device: Device every batch is moved to.
        max_grad_norm: Gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        ``(loss, accuracy)`` -- the mean of the per-batch losses and the
        fraction of correctly classified training samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for num_batches, (images, meta, labels) in enumerate(pbar, start=1):
        images, meta, labels = images.to(device), meta.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, meta)
        loss = criterion(outputs, labels)
        loss.backward()

        # Optional: Gradient clipping
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # running_loss accumulates per-batch *mean* losses, so average it over
        # batches (not samples) -- the same quantity this function returns.
        pbar.set_postfix({
            'loss': f'{running_loss / num_batches:.4f}',
            'acc': f'{100. * correct / total:.2f}%'
        })

    return running_loss / len(train_loader), correct / total

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Args:
        model: Module called as ``model(images, metadata)``.
        val_loader: Iterable of ``(images, metadata, labels)`` batches with ``len``.
        criterion: Loss function ``criterion(outputs, labels)``.
        device: Device every batch is moved to.

    Returns:
        ``(loss, accuracy)`` -- the sum of per-batch losses divided by
        ``len(val_loader)``, and the fraction of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for images, meta, labels in val_loader:
            images = images.to(device)
            meta = meta.to(device)
            labels = labels.to(device)

            logits = model(images, meta)
            loss_sum += criterion(logits, labels).item()

            predicted = logits.max(1)[1]
            hits += predicted.eq(labels).sum().item()
            seen += labels.size(0)

    return loss_sum / len(val_loader), hits / seen

def train_model_with_scheduler_and_checkpoint(
    model, train_loader, val_loader, optimizer, criterion, device,
    epochs=20, patience=5, scheduler_patience=5, checkpoint_dir='checkpoints',
    checkpoint_name='mobilevitdermp.pt'):
    """Train with LR-on-plateau scheduling, loss-based early stopping and
    accuracy-based checkpointing, then restore the best checkpoint.

    Each epoch runs ``train`` then ``validate``. ``ReduceLROnPlateau`` and
    ``EarlyStopping`` both watch the validation loss, while the checkpoint file
    is rewritten whenever validation accuracy reaches a new best. Training
    curves are plotted at the end.

    Args:
        model, train_loader, val_loader, optimizer, criterion, device:
            Passed through to ``train`` / ``validate``.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience (epochs without val-loss improvement).
        scheduler_patience: ``ReduceLROnPlateau`` patience.
        checkpoint_dir: Directory for the checkpoint (created if missing).
        checkpoint_name: Checkpoint file name inside ``checkpoint_dir``.

    Returns:
        ``(model, history)`` -- the model with the best-accuracy weights loaded
        (when at least one epoch ran) and a dict of per-epoch lists
        ``'train_loss'``, ``'val_loss'``, ``'train_acc'``, ``'val_acc'``, ``'lr'``.
    """
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    early_stopping = EarlyStopping(
        patience=patience,
        verbose=True,
        path=checkpoint_path
    )
    # The scheduler's `verbose` flag is deprecated in recent PyTorch releases;
    # learning-rate reductions are reported explicitly in the loop instead.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',  # Monitor loss (lower is better)
        patience=scheduler_patience,
        factor=0.1,
        min_lr=1e-6
    )

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'lr': []
    }

    best_model_epoch = None

    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')

        # Training phase
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)

        # Validation phase
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        current_lr = optimizer.param_groups[0]['lr']

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Update scheduler based on validation loss
        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            print(f'Reducing learning rate: {current_lr:.2e} -> {new_lr:.2e}')

        # EarlyStopping updates best_accuracy itself when it saves, so compare
        # against the value from *before* the call to detect a new checkpoint.
        previous_best_acc = early_stopping.best_accuracy
        early_stopping(val_loss, val_acc, model)
        if early_stopping.best_accuracy > previous_best_acc:
            best_model_epoch = epoch + 1

        if early_stopping.early_stop:
            print("üõë Early stopping triggered (validation loss not improving)")
            break

    # Load the best-accuracy checkpoint written during *this* run. Skipping the
    # load when nothing was saved avoids restoring a stale file from a prior run.
    if best_model_epoch is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; keeping current weights.")

    # Plot training curves
    plot_training_curves_with_checkpoint(history, best_model_epoch)

    return model, history

def plot_training_curves_with_checkpoint(history, best_model_epoch):
    """Draw loss, accuracy and learning-rate curves in three side-by-side panels.

    Args:
        history: Dict with per-epoch lists under ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'``, ``'val_acc'`` and ``'lr'``; the x axis has one
            tick per entry of ``'train_loss'``.
        best_model_epoch: 1-based epoch marked with a dashed red line in every
            panel; a falsy value (``None``/0) draws no marker.
    """
    x_values = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # (series to draw, y label, title, log-scale y axis?)
    panels = [
        ([('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')],
         'Loss', 'Training and Validation Loss', False),
        ([('train_acc', 'Training Accuracy'), ('val_acc', 'Validation Accuracy')],
         'Accuracy', 'Training and Validation Accuracy', False),
        ([('lr', 'Learning Rate')],
         'Learning Rate', 'Learning Rate Schedule', True),
    ]

    for ax, (series, y_label, title, log_y) in zip(axes, panels):
        for key, curve_label in series:
            ax.plot(x_values, history[key], label=curve_label)
        if best_model_epoch:
            ax.axvline(best_model_epoch, color='r', linestyle='--', label='Best Model')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        if log_y:
            ax.set_yscale('log')
        ax.legend()

    plt.tight_layout()
    plt.show()
    

In [None]:
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA devices when present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Function to evaluate test metrics
def evaluate_test_metrics(model, test_loader, device):
    """Score ``model`` on ``test_loader`` via ``test``.

    Returns:
        ``(accuracy, precision, recall, f1)``; precision, recall and F1 are
        macro-averaged scikit-learn scores.
    """
    y_true, y_pred = test(model, test_loader, device)
    macro = {'average': 'macro'}
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, **macro),
        recall_score(y_true, y_pred, **macro),
        f1_score(y_true, y_pred, **macro),
    )

# Placeholder for results
results = {
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1_score": []
}

CHECKPOINT_DIR = 'D:\\PAD-UFES\\checkpoints'
BEST_MODEL_PATH = 'D:\\Dermp7\\best_early_fusion_pvtv2smoteDA.pth'

best_accuracy = 0.0
best_model_state = None
model = None

# Run experiment for 3 random seeds
seeds = [42, 123, 569]  # Example random seeds
for seed in seeds:
    print(f"\nTraining with random seed: {seed}")
    set_random_seed(seed)

    # Reinitialize model, optimizer, and criterion
    model = EarlyFusionModel(input_dim_meta, num_classes).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    model, history = train_model_with_scheduler_and_checkpoint(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        epochs=100,
        patience=5,
        scheduler_patience=3,
        checkpoint_dir=CHECKPOINT_DIR
    )

    # Evaluate on test set
    acc, precision, recall, f1 = evaluate_test_metrics(model, test_loader, device)
    print(f"Seed {seed}: Accuracy={acc:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1 Score={f1:.4f}")

    # Save metrics
    results["accuracy"].append(acc)
    results["precision"].append(precision)
    results["recall"].append(recall)
    results["f1_score"].append(f1)

    # Update the best model. state_dict() returns references to the model's
    # live (possibly GPU) tensors; keep an independent CPU snapshot so the
    # stored weights don't depend on that model object and its GPU memory can
    # be released once `model` is rebound on the next seed.
    if best_model_state is None or acc > best_accuracy:
        best_accuracy = acc
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

# Compute average and standard deviation
metrics_summary = {}
for metric, values in results.items():
    avg = np.mean(values)
    std_dev = np.std(values)
    metrics_summary[metric] = (avg, std_dev)
    print(f"{metric.capitalize()}: Mean={avg:.4f}, StdDev={std_dev:.4f}")

if best_model_state is None:
    raise RuntimeError("No model was trained (the `seeds` list is empty); nothing to save.")

# Save the best model (create the target directory if needed)
print(f"Best model achieved an accuracy of {best_accuracy:.4f}")
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)
torch.save(best_model_state, BEST_MODEL_PATH)
print("Saved")
model = EarlyFusionModel(input_dim_meta, num_classes).to(device)
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
model.eval()
print("Best model loaded successfully!")

In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def test(model, loader, device, desc="Testing"):
    """Predict over ``loader`` in eval mode, showing running accuracy on a progress bar.

    Args:
        model: Module called as ``model(images, metadata)``; the predicted
            class is the arg-max over dim 1 of its output.
        loader: Iterable of ``(images, metadata, labels)`` batches with ``len``.
        device: Device every batch is moved to.
        desc: Progress-bar description.

    Returns:
        ``(all_labels, all_preds)`` -- lists with one numpy scalar per sample.
    """
    model.eval()
    all_preds = []
    all_labels = []
    running_correct = 0
    running_total = 0

    with torch.no_grad():
        pbar = tqdm(loader, total=len(loader), desc=desc, unit="batch")
        for imgs, metas, labels in pbar:
            imgs, metas, labels = imgs.to(device), metas.to(device), labels.to(device)
            outputs = model(imgs, metas)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            running_correct += (predicted == labels).sum().item()
            running_total += labels.size(0)
            # Cumulative accuracy over all batches so far (not per-batch),
            # shown on the bar rather than printing an extra line per batch.
            pbar.set_postfix(acc=f"{running_correct / running_total:.4f}")

    return all_labels, all_preds

# Run test with progress bar
true_labels, pred_labels = test(model, test_loader, device, desc="Evaluating on Test")

# Reports
unique_labels = sorted(train_df["label"].unique())

import matplotlib.pyplot as plt
import seaborn as sns

# Assumes label_mapping is {class name: integer label}; names ordered by label.
class_names = [k for k, v in sorted(label_mapping.items(), key=lambda x: x[1])]

# Pin the matrix to labels 0..len(class_names)-1. Without `labels=`, sklearn
# only uses labels occurring in y_true or y_pred, so an absent class would
# shrink the matrix and misalign every tick label after it.
conf_matrix = confusion_matrix(true_labels, pred_labels,
                               labels=list(range(len(class_names))),
                               normalize="true")

plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="gray", fmt=".2f",
            xticklabels=class_names,
            yticklabels=class_names)
plt.xlabel("Predicted Label", fontweight="bold")
plt.ylabel("True Label", fontweight="bold")
plt.title("Normalized Confusion Matrix", fontweight="bold")
plt.show()

In [None]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np

# --- 8 Color-Blind-Safe Colors ---
colors = [
    "#0072B2",  # Blue
    "#E69F00",  # Orange
    "#56B4E9",  # Sky Blue
    "#009E73",  # Bluish Green
    "#F0E442",  # Yellow
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish Purple
    "#000000"   # Black
]

# --- Number of classes ---
n_class = len(class_names)  # one curve per class name

# --- Convert labels to arrays ---
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)


def _predict_proba(net, loader, dev):
    """Return (softmax probabilities, labels) for every sample in ``loader``.

    Probabilities are the softmax over dim 1 of ``net(images, metadata)``, one
    row per sample; labels are concatenated in loader order.
    """
    net.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, metas, labels in loader:
            logits = net(imgs.to(dev), metas.to(dev))
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    return np.concatenate(prob_chunks, axis=0), np.concatenate(label_chunks, axis=0)


# ROC needs a continuous score per class. Thresholding hard arg-max predictions
# (pred_labels == i) yields a single operating point, so the "curve" is two
# straight segments and its AUC is not the ROC AUC. Use softmax probabilities.
y_score, y_true = _predict_proba(model, test_loader, device)

# --- Store curves ---
fpr = {}
tpr = {}
roc_auc = {}

for i in range(n_class):
    # One-vs-rest encoding
    is_class_i = (y_true == i).astype(int)
    if is_class_i.min() == is_class_i.max():
        # A ROC curve needs both positive and negative samples.
        print(f"Skipping ROC for '{class_names[i]}': test set lacks positive or negative samples")
        continue
    fpr[i], tpr[i], _ = roc_curve(is_class_i, y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# --- Plot ---
plt.figure(figsize=(10, 8))

for i in roc_auc:
    plt.plot(
        fpr[i], tpr[i],
        color=colors[i % len(colors)],  # wrap around if there are more than 8 classes
        lw=2,
        label=f"{class_names[i]} (AUC = {roc_auc[i]:.3f})"
    )

plt.plot([0, 1], [0, 1], color="gray", linestyle="--", linewidth=1)

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])

plt.xlabel("False Positive Rate", fontweight="bold", fontsize=12)
plt.ylabel("True Positive Rate", fontweight="bold", fontsize=12)
plt.title("ROC Curve (One-vs-Rest)", fontweight="bold", fontsize=14)

plt.legend(loc="lower right", fontsize=10)
plt.tight_layout()
plt.show()
