<a href="https://colab.research.google.com/github/SoudeepGhoshal/TResNet/blob/main/TResNet_PAD-UFES-20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PAD-UFES-20

## ResNet-18

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./pad-ufes-20/",
    "dataset_url": "https://www.kaggle.com/api/v1/datasets/download/maxjen/pad-ufes-20",  # Kaggle API URL
    "batch_size": 64,  # Reduced due to smaller dataset
    "num_epochs": 100,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,
    "num_classes": 6,  # PAD-UFES-20 has 6 diagnostic classes
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model_pad_ufes.pth",
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base callback class"""
    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_training_end(self):
        pass

class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving"""
    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.optimizer.param_groups[0]['lr']
                new_lr = max(old_lr * self.factor, self.min_lr)
                if new_lr != old_lr:
                    self.optimizer.param_groups[0]['lr'] = new_lr
                    if self.verbose:
                        print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")
                    self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving"""
    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            self.best = current
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                if self.verbose:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                return True
        return False

    def on_training_end(self, model=None):
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.stopped_epoch + 1 - self.patience}")
        if model is not None and self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

class ModelCheckpoint(Callback):
    """Save the model after every epoch"""
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            filepath = self.filepath.replace('.pth', f'_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), filepath)
            if self.verbose:
                print(f"\nModel saved to {filepath}")
            return

        if self.best is None:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_extract_pad_ufes(url, dest_path):
    """Downloads and extracts PAD-UFES-20 dataset from Kaggle with progress and robust error handling."""
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading PAD-UFES-20 dataset from Kaggle...")
    try:
        # Add headers to mimic browser request
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        with requests.get(url, stream=True, timeout=60, headers=headers) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc="PAD-UFES-20"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to_dir)

        # Check if extraction created nested folders and fix if needed
        extracted_folders = [f for f in os.listdir(extract_to_dir) if os.path.isdir(os.path.join(extract_to_dir, f))]
        if 'pad-ufes-20' in extracted_folders:
            # Dataset extracted correctly
            pass
        else:
            # Look for the actual dataset folder
            for folder in extracted_folders:
                folder_path = os.path.join(extract_to_dir, folder)
                if os.path.exists(os.path.join(folder_path, 'metadata.csv')):
                    # This is our dataset folder, rename it
                    os.rename(folder_path, dest_path)
                    break

        return True

    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please download manually from: https://www.kaggle.com/datasets/maxjen/pad-ufes-20", file=sys.stderr)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

class PADUFESDataset(Dataset):
    """Custom dataset class for PAD-UFES-20."""
    def __init__(self, data_dir, metadata_df, transform=None):
        self.data_dir = data_dir
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.transform = transform

        # Create label mapping
        self.classes = sorted(metadata_df['diagnostic'].unique())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        print(f"Found {len(self.classes)} classes: {self.classes}")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]

        # Handle different possible image path formats
        img_name = row['img_id']
        if not img_name.endswith('.png'):
            img_name += '.png'

        # Try different possible folder structures
        possible_paths = [
            os.path.join(self.data_dir, 'images', img_name),
            os.path.join(self.data_dir, 'imgs', img_name),
            os.path.join(self.data_dir, img_name),
        ]

        img_path = None
        for path in possible_paths:
            if os.path.exists(path):
                img_path = path
                break

        if img_path is None:
            # Last resort: search for the file
            for root, dirs, files in os.walk(self.data_dir):
                if img_name in files:
                    img_path = os.path.join(root, img_name)
                    break

        if img_path is None:
            print(f"Warning: Image {img_name} not found, using black placeholder")
            image = Image.new('RGB', (224, 224), color='black')
        else:
            try:
                image = Image.open(img_path).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                image = Image.new('RGB', (224, 224), color='black')

        label = self.class_to_idx[row['diagnostic']]

        if self.transform:
            image = self.transform(image)

        return image, label

def prepare_pad_ufes_data(data_path):
    """Prepare PAD-UFES-20 data by reading metadata and organizing structure."""

    # Try different possible metadata file locations
    possible_metadata_paths = [
        os.path.join(data_path, 'metadata.csv'),
        os.path.join(data_path, 'PAD-UFES-20_metadata.csv'),
    ]

    metadata_path = None
    for path in possible_metadata_paths:
        if os.path.exists(path):
            metadata_path = path
            break

    if metadata_path is None:
        # Search for any CSV file that might be the metadata
        for root, dirs, files in os.walk(data_path):
            for file in files:
                if file.endswith('.csv') and 'metadata' in file.lower():
                    metadata_path = os.path.join(root, file)
                    break
            if metadata_path:
                break

    if metadata_path is None:
        print(f"Metadata file not found in {data_path}")
        return None

    print(f"Loading metadata from: {metadata_path}")
    try:
        metadata_df = pd.read_csv(metadata_path)
    except Exception as e:
        print(f"Error reading metadata file: {e}")
        return None

    # Print dataset info
    print(f"Total samples: {len(metadata_df)}")
    print(f"Diagnostic classes: {metadata_df['diagnostic'].value_counts()}")

    # Verify dataset integrity
    expected_samples = 2298
    expected_classes = 6

    actual_classes = len(metadata_df['diagnostic'].unique())
    if len(metadata_df) != expected_samples:
        print(f"⚠️  Warning: Expected {expected_samples} samples, got {len(metadata_df)}")
    if actual_classes != expected_classes:
        print(f"⚠️  Warning: Expected {expected_classes} classes, got {actual_classes}")

    print("✅ Dataset verification complete")
    return metadata_df

def get_dataloaders(config):
    """Downloads, prepares, and splits the PAD-UFES-20 data, returning DataLoaders."""

    # Try to download if URL is provided and dataset doesn't exist
    if config.get("dataset_url") and not os.path.exists(config["data_path"]):
        if not download_and_extract_pad_ufes(config["dataset_url"], config["data_path"]):
            print("\nAutomatic download failed. Please download manually:")
            print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
            print("2. Download the dataset")
            print("3. Extract to:", config["data_path"])
            return None

    # Check if dataset exists
    if not os.path.exists(config["data_path"]):
        print(f"\nDataset not found at {config['data_path']}")
        print("Please download the PAD-UFES-20 dataset:")
        print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
        print("2. Download the dataset")
        print("3. Extract to:", config["data_path"])
        print("4. Ensure the structure includes metadata.csv and image files")
        return None

    # Prepare metadata
    metadata_df = prepare_pad_ufes_data(config["data_path"])
    if metadata_df is None:
        return None

    # Define transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Create full dataset
    full_dataset = PADUFESDataset(config["data_path"], metadata_df, transform=None)

    # Ensure reproducible splits
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # Adjust sizes to ensure they sum to total_size
    actual_total = train_size + val_size + test_size
    if actual_total != total_size:
        test_size += (total_size - actual_total)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Create indices for splitting
    indices = list(range(total_size))
    np.random.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    # Create subset datasets with appropriate transforms
    train_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[train_indices],
        transform=train_transform
    )
    val_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[val_indices],
        transform=val_test_transform
    )
    test_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[test_indices],
        transform=val_test_transform
    )

    # Set worker init function for reproducible DataLoader behavior
    def worker_init_fn(worker_id):
        np.random.seed(config["split_seed"] + worker_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(config["split_seed"])
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class ResidualBlock(nn.Module):
    """A residual block, the fundamental building block of ResNet."""
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path (for matching dimensions)
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    """A modular ResNet implementation."""
    def __init__(self, block, layers, num_classes=200):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def ResNet18(num_classes=200):
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes)

def evaluate_model(model, data_loader, criterion, device):
    """Evaluates the model on a given dataset."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

def train_and_validate(model, train_loader, val_loader, config):
    """Main training loop with callbacks."""
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["split_seed"])
        torch.cuda.manual_seed_all(config["split_seed"])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # Initialize callbacks
    callbacks = [
        ReduceLROnPlateau(
            optimizer=optimizer,
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=config["checkpoint_path"],
            save_best_only=True,
            monitor='val_accuracy',
            verbose=1
        )
    ]

    print("\n--- Starting Training ---")
    early_stop = False

    for epoch in range(config["num_epochs"]):
        if early_stop:
            break

        model.train()
        start_time = time.time()
        running_loss, train_correct, train_total = 0.0, 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_loss / train_total
        train_acc = 100 * train_correct / train_total
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Prepare logs for callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Execute callbacks
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                if callback.on_epoch_end(epoch, logs, model):
                    early_stop = True
                    break
            elif isinstance(callback, ModelCheckpoint):
                callback.on_epoch_end(epoch, logs, model)
            else:
                callback.on_epoch_end(epoch, logs)

    # Execute callback cleanup
    for callback in callbacks:
        if isinstance(callback, EarlyStopping):
            callback.on_training_end(model)
        else:
            callback.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=[1, 3, 5]):
    """Calculate top-k accuracy for given k values."""
    batch_size = labels.size(0)
    _, pred = outputs.topk(max(k_values), 1, True, True)
    pred = pred.t()
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    top_k_accuracies = {}
    for k in k_values:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        top_k_accuracies[k] = correct_k.item() / batch_size

    return top_k_accuracies

def calculate_entropy(probs):
    """Calculate entropy of probability distributions."""
    # Add small epsilon to avoid log(0)
    epsilon = 1e-8
    probs = probs + epsilon
    entropy = -torch.sum(probs * torch.log(probs), dim=1)
    return entropy

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate Expected Calibration Error."""
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower.item()) & (confidences <= bin_upper.item())
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Comprehensive evaluation with all requested metrics."""
    model.eval()

    all_predictions = []
    all_labels = []
    all_confidences = []
    all_entropies = []
    all_top_k_results = {1: [], 3: [], 5: []}
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Convert outputs to probabilities
            probs = torch.softmax(outputs, dim=1)

            # Get predictions and confidences
            confidences, predictions = torch.max(probs, dim=1)

            # Calculate entropy
            entropies = calculate_entropy(probs)

            # Calculate top-k accuracies
            top_k_accs = calculate_top_k_accuracy(outputs, labels, [1, 3, 5])

            # Store results
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_confidences.extend(confidences.cpu().numpy())
            all_entropies.extend(entropies.cpu().numpy())

            for k in [1, 3, 5]:
                all_top_k_results[k].extend([top_k_accs[k]] * labels.size(0))

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_confidences = np.array(all_confidences)
    all_entropies = np.array(all_entropies)

    # Calculate basic metrics
    avg_loss = total_loss / total_samples

    # Calculate top-k accuracies
    top1 = np.mean([pred == label for pred, label in zip(all_predictions, all_labels)])
    top3 = np.mean(all_top_k_results[3])
    top5 = np.mean(all_top_k_results[5])

    # Calculate precision, recall, f1-score
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calculate confidence and calibration metrics
    avg_confidence = np.mean(all_confidences)
    std_confidence = np.std(all_confidences)
    avg_entropy = np.mean(all_entropies)

    # Calculate ECE
    accuracies = (all_predictions == all_labels).astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(accuracies))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': avg_confidence,
        'std_confidence': std_confidence,
        'avg_entropy': avg_entropy,
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1) # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    model = ResNet18(num_classes=CONFIG["num_classes"])
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Downloading PAD-UFES-20 dataset from Kaggle...


PAD-UFES-20: 100%|██████████| 3.60G/3.60G [00:44<00:00, 80.7MiB/s]


Extracting...
Loading metadata from: ./pad-ufes-20/metadata.csv
Total samples: 2298
Diagnostic classes: diagnostic
BCC    845
ACK    730
NEV    244
SEK    235
SCC    192
MEL     52
Name: count, dtype: int64
✅ Dataset verification complete
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Dataset splits - Train: 1608, Val: 344, Test: 346, Total: 2298
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']

--- Starting Training ---




Epoch 1/100 | Time: 78.61s | Train Loss: 1.7392, Train Acc: 32.77% | Val Loss: 2.4749, Val Acc: 35.47%

Model saved to ./best_model_pad_ufes.pth




Epoch 2/100 | Time: 77.79s | Train Loss: 1.5052, Train Acc: 37.81% | Val Loss: 1.4946, Val Acc: 35.47%




Epoch 3/100 | Time: 78.41s | Train Loss: 1.5331, Train Acc: 34.95% | Val Loss: 1.5019, Val Acc: 37.21%

Model saved to ./best_model_pad_ufes.pth




Epoch 4/100 | Time: 76.71s | Train Loss: 1.4889, Train Acc: 37.06% | Val Loss: 1.5013, Val Acc: 40.99%

Model saved to ./best_model_pad_ufes.pth




Epoch 5/100 | Time: 77.71s | Train Loss: 1.4685, Train Acc: 38.81% | Val Loss: 1.4709, Val Acc: 37.79%




Epoch 6/100 | Time: 76.86s | Train Loss: 1.4701, Train Acc: 38.56% | Val Loss: 1.5206, Val Acc: 41.28%

Model saved to ./best_model_pad_ufes.pth




Epoch 7/100 | Time: 77.19s | Train Loss: 1.4269, Train Acc: 39.80% | Val Loss: 1.4787, Val Acc: 38.37%




Epoch 8/100 | Time: 77.21s | Train Loss: 1.4134, Train Acc: 41.11% | Val Loss: 1.4918, Val Acc: 39.24%




Epoch 9/100 | Time: 75.88s | Train Loss: 1.4073, Train Acc: 40.92% | Val Loss: 1.4598, Val Acc: 42.44%

Model saved to ./best_model_pad_ufes.pth




Epoch 10/100 | Time: 76.66s | Train Loss: 1.3778, Train Acc: 42.60% | Val Loss: 1.3838, Val Acc: 40.41%




Epoch 11/100 | Time: 76.41s | Train Loss: 1.3562, Train Acc: 43.16% | Val Loss: 1.3721, Val Acc: 44.48%

Model saved to ./best_model_pad_ufes.pth




Epoch 12/100 | Time: 75.95s | Train Loss: 1.3746, Train Acc: 42.35% | Val Loss: 1.3887, Val Acc: 44.19%




Epoch 13/100 | Time: 76.50s | Train Loss: 1.3521, Train Acc: 41.98% | Val Loss: 1.5303, Val Acc: 38.66%




Epoch 14/100 | Time: 76.29s | Train Loss: 1.3398, Train Acc: 43.28% | Val Loss: 1.3895, Val Acc: 41.86%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 15/100 | Time: 76.41s | Train Loss: 1.2834, Train Acc: 48.94% | Val Loss: 1.2824, Val Acc: 48.84%

Model saved to ./best_model_pad_ufes.pth




Epoch 16/100 | Time: 77.15s | Train Loss: 1.2588, Train Acc: 48.32% | Val Loss: 1.2722, Val Acc: 51.45%

Model saved to ./best_model_pad_ufes.pth




Epoch 17/100 | Time: 77.08s | Train Loss: 1.2217, Train Acc: 49.50% | Val Loss: 1.2180, Val Acc: 53.20%

Model saved to ./best_model_pad_ufes.pth




Epoch 18/100 | Time: 76.85s | Train Loss: 1.2270, Train Acc: 50.87% | Val Loss: 1.2189, Val Acc: 54.94%

Model saved to ./best_model_pad_ufes.pth




Epoch 19/100 | Time: 76.84s | Train Loss: 1.2208, Train Acc: 50.19% | Val Loss: 1.2153, Val Acc: 52.33%




Epoch 20/100 | Time: 77.12s | Train Loss: 1.2163, Train Acc: 49.69% | Val Loss: 1.2395, Val Acc: 51.74%




Epoch 21/100 | Time: 77.38s | Train Loss: 1.2013, Train Acc: 51.74% | Val Loss: 1.2139, Val Acc: 53.49%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 22/100 | Time: 76.97s | Train Loss: 1.1879, Train Acc: 53.11% | Val Loss: 1.1553, Val Acc: 56.40%

Model saved to ./best_model_pad_ufes.pth




Epoch 23/100 | Time: 76.44s | Train Loss: 1.1599, Train Acc: 54.04% | Val Loss: 1.1562, Val Acc: 56.98%

Model saved to ./best_model_pad_ufes.pth




Epoch 24/100 | Time: 77.15s | Train Loss: 1.1596, Train Acc: 54.04% | Val Loss: 1.1472, Val Acc: 54.65%




Epoch 25/100 | Time: 77.47s | Train Loss: 1.1661, Train Acc: 53.54% | Val Loss: 1.1380, Val Acc: 56.10%




Epoch 26/100 | Time: 76.53s | Train Loss: 1.1705, Train Acc: 52.18% | Val Loss: 1.1537, Val Acc: 56.98%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 27/100 | Time: 77.98s | Train Loss: 1.1396, Train Acc: 54.10% | Val Loss: 1.1452, Val Acc: 57.56%

Model saved to ./best_model_pad_ufes.pth




Epoch 28/100 | Time: 76.96s | Train Loss: 1.1350, Train Acc: 54.91% | Val Loss: 1.1503, Val Acc: 56.98%




Epoch 29/100 | Time: 76.75s | Train Loss: 1.1458, Train Acc: 54.42% | Val Loss: 1.1473, Val Acc: 55.81%




Epoch 30/100 | Time: 76.25s | Train Loss: 1.1393, Train Acc: 54.29% | Val Loss: 1.1449, Val Acc: 56.10%

Reducing learning rate from 8.00e-06 to 1.60e-06




Epoch 31/100 | Time: 75.88s | Train Loss: 1.1460, Train Acc: 54.42% | Val Loss: 1.1409, Val Acc: 57.27%




Epoch 32/100 | Time: 75.63s | Train Loss: 1.1388, Train Acc: 54.91% | Val Loss: 1.1435, Val Acc: 56.40%




Epoch 33/100 | Time: 74.95s | Train Loss: 1.1477, Train Acc: 54.73% | Val Loss: 1.1453, Val Acc: 56.10%

Reducing learning rate from 1.60e-06 to 3.20e-07




Epoch 34/100 | Time: 75.64s | Train Loss: 1.1326, Train Acc: 54.29% | Val Loss: 1.1407, Val Acc: 55.81%

Early stopping at epoch 34
Restored model weights from the end of the best epoch: 27
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 6/6 [00:12<00:00,  2.09s/it]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 55.49%
Top-3 Accuracy: 89.60%
Top-5 Accuracy: 99.13%

--- Additional Performance Metrics ---
Macro Average Precision: 0.4061
Macro Average Recall: 0.3559
Macro Average F1-Score: 0.3600
Weighted Average Precision: 0.4882
Weighted Average Recall: 0.5549
Weighted Average F1-Score: 0.5086

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.5558
Confidence Standard Deviation: 0.1389
Average Prediction Entropy: 1.1394
Expected Calibration Error: 0.0239

Final Test Loss: 1.2002





## Hybrid Transformer on ResNet-18 (No PE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./pad-ufes-20/",
    "dataset_url": "https://www.kaggle.com/api/v1/datasets/download/maxjen/pad-ufes-20",  # Kaggle API URL
    "batch_size": 64,  # Reduced due to smaller dataset
    "num_epochs": 100,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,
    "num_classes": 6,  # PAD-UFES-20 has 6 diagnostic classes
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model_pad_ufes.pth",
}

# --- Transformer Configuration ---
TRANSFORMER_CONFIG = {
    "embedding_dim": 32,
    "nhead": 16,
    "num_encoder_layers": 3,
    "dim_feedforward": 2,
    "dropout": 0.1,
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base callback class"""
    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_training_end(self):
        pass

class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving"""
    def __init__(self, optimizer, monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.optimizer.param_groups[0]['lr']
                new_lr = max(old_lr * self.factor, self.min_lr)
                if new_lr != old_lr:
                    self.optimizer.param_groups[0]['lr'] = new_lr
                    if self.verbose:
                        print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")
                    self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving"""
    def __init__(self, monitor='val_loss', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            self.best = current
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                if self.verbose:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                return True
        return False

    def on_training_end(self, model=None):
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.stopped_epoch + 1 - self.patience}")
        if model is not None and self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

class ModelCheckpoint(Callback):
    """Save the model after every epoch"""
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            filepath = self.filepath.replace('.pth', f'_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), filepath)
            if self.verbose:
                print(f"\nModel saved to {filepath}")
            return

        if self.best is None:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_extract_pad_ufes(url, dest_path):
    """Downloads and extracts PAD-UFES-20 dataset from Kaggle with progress and robust error handling."""
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading PAD-UFES-20 dataset from Kaggle...")
    try:
        # Add headers to mimic browser request
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        with requests.get(url, stream=True, timeout=60, headers=headers) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc="PAD-UFES-20"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to_dir)

        # Check if extraction created nested folders and fix if needed
        extracted_folders = [f for f in os.listdir(extract_to_dir) if os.path.isdir(os.path.join(extract_to_dir, f))]
        if 'pad-ufes-20' in extracted_folders:
            # Dataset extracted correctly
            pass
        else:
            # Look for the actual dataset folder
            for folder in extracted_folders:
                folder_path = os.path.join(extract_to_dir, folder)
                if os.path.exists(os.path.join(folder_path, 'metadata.csv')):
                    # This is our dataset folder, rename it
                    os.rename(folder_path, dest_path)
                    break

        return True

    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please download manually from: https://www.kaggle.com/datasets/maxjen/pad-ufes-20", file=sys.stderr)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

class PADUFESDataset(Dataset):
    """Custom dataset class for PAD-UFES-20."""
    def __init__(self, data_dir, metadata_df, transform=None):
        self.data_dir = data_dir
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.transform = transform

        # Create label mapping
        self.classes = sorted(metadata_df['diagnostic'].unique())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        print(f"Found {len(self.classes)} classes: {self.classes}")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]

        # Handle different possible image path formats
        img_name = row['img_id']
        if not img_name.endswith('.png'):
            img_name += '.png'

        # Try different possible folder structures
        possible_paths = [
            os.path.join(self.data_dir, 'images', img_name),
            os.path.join(self.data_dir, 'imgs', img_name),
            os.path.join(self.data_dir, img_name),
        ]

        img_path = None
        for path in possible_paths:
            if os.path.exists(path):
                img_path = path
                break

        if img_path is None:
            # Last resort: search for the file
            for root, dirs, files in os.walk(self.data_dir):
                if img_name in files:
                    img_path = os.path.join(root, img_name)
                    break

        if img_path is None:
            print(f"Warning: Image {img_name} not found, using black placeholder")
            image = Image.new('RGB', (224, 224), color='black')
        else:
            try:
                image = Image.open(img_path).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                image = Image.new('RGB', (224, 224), color='black')

        label = self.class_to_idx[row['diagnostic']]

        if self.transform:
            image = self.transform(image)

        return image, label

def prepare_pad_ufes_data(data_path):
    """Prepare PAD-UFES-20 data by reading metadata and organizing structure."""

    # Try different possible metadata file locations
    possible_metadata_paths = [
        os.path.join(data_path, 'metadata.csv'),
        os.path.join(data_path, 'PAD-UFES-20_metadata.csv'),
    ]

    metadata_path = None
    for path in possible_metadata_paths:
        if os.path.exists(path):
            metadata_path = path
            break

    if metadata_path is None:
        # Search for any CSV file that might be the metadata
        for root, dirs, files in os.walk(data_path):
            for file in files:
                if file.endswith('.csv') and 'metadata' in file.lower():
                    metadata_path = os.path.join(root, file)
                    break
            if metadata_path:
                break

    if metadata_path is None:
        print(f"Metadata file not found in {data_path}")
        return None

    print(f"Loading metadata from: {metadata_path}")
    try:
        metadata_df = pd.read_csv(metadata_path)
    except Exception as e:
        print(f"Error reading metadata file: {e}")
        return None

    # Print dataset info
    print(f"Total samples: {len(metadata_df)}")
    print(f"Diagnostic classes: {metadata_df['diagnostic'].value_counts()}")

    # Verify dataset integrity
    expected_samples = 2298
    expected_classes = 6

    actual_classes = len(metadata_df['diagnostic'].unique())
    if len(metadata_df) != expected_samples:
        print(f"⚠️  Warning: Expected {expected_samples} samples, got {len(metadata_df)}")
    if actual_classes != expected_classes:
        print(f"⚠️  Warning: Expected {expected_classes} classes, got {actual_classes}")

    print("✅ Dataset verification complete")
    return metadata_df

def get_dataloaders(config):
    """Downloads, prepares, and splits the PAD-UFES-20 data, returning DataLoaders."""

    # Try to download if URL is provided and dataset doesn't exist
    if config.get("dataset_url") and not os.path.exists(config["data_path"]):
        if not download_and_extract_pad_ufes(config["dataset_url"], config["data_path"]):
            print("\nAutomatic download failed. Please download manually:")
            print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
            print("2. Download the dataset")
            print("3. Extract to:", config["data_path"])
            return None

    # Check if dataset exists
    if not os.path.exists(config["data_path"]):
        print(f"\nDataset not found at {config['data_path']}")
        print("Please download the PAD-UFES-20 dataset:")
        print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
        print("2. Download the dataset")
        print("3. Extract to:", config["data_path"])
        print("4. Ensure the structure includes metadata.csv and image files")
        return None

    # Prepare metadata
    metadata_df = prepare_pad_ufes_data(config["data_path"])
    if metadata_df is None:
        return None

    # Define transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Create full dataset
    full_dataset = PADUFESDataset(config["data_path"], metadata_df, transform=None)

    # Ensure reproducible splits
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # Adjust sizes to ensure they sum to total_size
    actual_total = train_size + val_size + test_size
    if actual_total != total_size:
        test_size += (total_size - actual_total)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Create indices for splitting
    indices = list(range(total_size))
    np.random.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    # Create subset datasets with appropriate transforms
    train_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[train_indices],
        transform=train_transform
    )
    val_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[val_indices],
        transform=val_test_transform
    )
    test_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[test_indices],
        transform=val_test_transform
    )

    # Set worker init function for reproducible DataLoader behavior
    def worker_init_fn(worker_id):
        np.random.seed(config["split_seed"] + worker_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(config["split_seed"])
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class NonResidualBlock(nn.Module):
    """
    This block is a standard convolutional block WITHOUT the residual connection.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(NonResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

class ResNetTransformer(nn.Module):
    """
    A hybrid architecture combining a non-residual CNN backbone with a Transformer encoder.
    """
    def __init__(self, block, layers, num_classes, t_config):
        super(ResNetTransformer, self).__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]

        # 1. CNN Backbone (Feature Extractor)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 2. Projection heads to create tokens from feature maps
        self.projections = nn.ModuleList([
            self._create_projection(64, self.embedding_dim),   # From initial maxpool
            self._create_projection(64, self.embedding_dim),   # From layer1
            self._create_projection(128, self.embedding_dim),  # From layer2
            self._create_projection(256, self.embedding_dim),  # From layer3
            self._create_projection(512, self.embedding_dim)   # From layer4
        ])

        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=t_config["nhead"],
            dim_feedforward=t_config["dim_feedforward"],
            dropout=t_config["dropout"],
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=t_config["num_encoder_layers"]
        )

        # 4. Final Classifier
        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_features, out_features)
        )

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        current_in_channels = self.in_channels
        for s in strides:
            layers.append(block(current_in_channels, out_channels, s))
            current_in_channels = out_channels
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Pass through CNN backbone and capture features
        features = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        features.append(x)

        x = self.layer1(x); features.append(x)
        x = self.layer2(x); features.append(x)
        x = self.layer3(x); features.append(x)
        x = self.layer4(x); features.append(x)

        # 2. Project features to tokens
        tokens = []
        for i, feature_map in enumerate(features):
            tokens.append(self.projections[i](feature_map))

        # 3. Stack tokens and pass through Transformer
        token_sequence = torch.stack(tokens, dim=1)
        transformer_out = self.transformer_encoder(token_sequence)

        # 4. Aggregate and classify
        aggregated_vector = transformer_out.mean(dim=1)
        logits = self.classifier(aggregated_vector)

        return logits

def ResNetTransformer18(num_classes=200, t_config=TRANSFORMER_CONFIG):
    return ResNetTransformer(NonResidualBlock, [2, 2, 2, 2], num_classes, t_config)

def evaluate_model(model, data_loader, criterion, device):
    """Evaluates the model on a given dataset."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

def train_and_validate(model, train_loader, val_loader, config):
    """Main training loop with callbacks."""
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["split_seed"])
        torch.cuda.manual_seed_all(config["split_seed"])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # Initialize callbacks
    callbacks = [
        ReduceLROnPlateau(
            optimizer=optimizer,
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=config["checkpoint_path"],
            save_best_only=True,
            monitor='val_accuracy',
            verbose=1
        )
    ]

    print("\n--- Starting Training ---")
    early_stop = False

    for epoch in range(config["num_epochs"]):
        if early_stop:
            break

        model.train()
        start_time = time.time()
        running_loss, train_correct, train_total = 0.0, 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_loss / train_total
        train_acc = 100 * train_correct / train_total
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Prepare logs for callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Execute callbacks
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                if callback.on_epoch_end(epoch, logs, model):
                    early_stop = True
                    break
            elif isinstance(callback, ModelCheckpoint):
                callback.on_epoch_end(epoch, logs, model)
            else:
                callback.on_epoch_end(epoch, logs)

    # Execute callback cleanup
    for callback in callbacks:
        if isinstance(callback, EarlyStopping):
            callback.on_training_end(model)
        else:
            callback.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=[1, 3, 5]):
    """Calculate top-k accuracy for given k values."""
    batch_size = labels.size(0)
    _, pred = outputs.topk(max(k_values), 1, True, True)
    pred = pred.t()
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    top_k_accuracies = {}
    for k in k_values:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        top_k_accuracies[k] = correct_k.item() / batch_size

    return top_k_accuracies

def calculate_entropy(probs):
    """Calculate entropy of probability distributions."""
    # Add small epsilon to avoid log(0)
    epsilon = 1e-8
    probs = probs + epsilon
    entropy = -torch.sum(probs * torch.log(probs), dim=1)
    return entropy

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate Expected Calibration Error."""
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower.item()) & (confidences <= bin_upper.item())
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Comprehensive evaluation with all requested metrics."""
    model.eval()

    all_predictions = []
    all_labels = []
    all_confidences = []
    all_entropies = []
    all_top_k_results = {1: [], 3: [], 5: []}
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Convert outputs to probabilities
            probs = torch.softmax(outputs, dim=1)

            # Get predictions and confidences
            confidences, predictions = torch.max(probs, dim=1)

            # Calculate entropy
            entropies = calculate_entropy(probs)

            # Calculate top-k accuracies
            top_k_accs = calculate_top_k_accuracy(outputs, labels, [1, 3, 5])

            # Store results
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_confidences.extend(confidences.cpu().numpy())
            all_entropies.extend(entropies.cpu().numpy())

            for k in [1, 3, 5]:
                all_top_k_results[k].extend([top_k_accs[k]] * labels.size(0))

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_confidences = np.array(all_confidences)
    all_entropies = np.array(all_entropies)

    # Calculate basic metrics
    avg_loss = total_loss / total_samples

    # Calculate top-k accuracies
    top1 = np.mean([pred == label for pred, label in zip(all_predictions, all_labels)])
    top3 = np.mean(all_top_k_results[3])
    top5 = np.mean(all_top_k_results[5])

    # Calculate precision, recall, f1-score
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calculate confidence and calibration metrics
    avg_confidence = np.mean(all_confidences)
    std_confidence = np.std(all_confidences)
    avg_entropy = np.mean(all_entropies)

    # Calculate ECE
    accuracies = (all_predictions == all_labels).astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(accuracies))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': avg_confidence,
        'std_confidence': std_confidence,
        'avg_entropy': avg_entropy,
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)  # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    model = ResNetTransformer18(num_classes=CONFIG["num_classes"], t_config=TRANSFORMER_CONFIG)
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Loading metadata from: ./pad-ufes-20/metadata.csv
Total samples: 2298
Diagnostic classes: diagnostic
BCC    845
ACK    730
NEV    244
SEK    235
SCC    192
MEL     52
Name: count, dtype: int64
✅ Dataset verification complete
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Dataset splits - Train: 1608, Val: 344, Test: 346, Total: 2298
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']

--- Starting Training ---




Epoch 1/100 | Time: 75.83s | Train Loss: 1.5568, Train Acc: 37.69% | Val Loss: 1.4853, Val Acc: 39.83%

Model saved to ./best_model_pad_ufes.pth




Epoch 2/100 | Time: 76.14s | Train Loss: 1.4967, Train Acc: 37.31% | Val Loss: 1.4784, Val Acc: 36.34%




Epoch 3/100 | Time: 76.21s | Train Loss: 1.5075, Train Acc: 33.46% | Val Loss: 1.4731, Val Acc: 36.34%




Epoch 4/100 | Time: 75.53s | Train Loss: 1.4852, Train Acc: 37.44% | Val Loss: 1.4105, Val Acc: 47.97%

Model saved to ./best_model_pad_ufes.pth




Epoch 5/100 | Time: 76.02s | Train Loss: 1.4840, Train Acc: 38.18% | Val Loss: 1.4612, Val Acc: 35.47%




Epoch 6/100 | Time: 75.26s | Train Loss: 1.4579, Train Acc: 37.75% | Val Loss: 1.4297, Val Acc: 40.41%




Epoch 7/100 | Time: 75.02s | Train Loss: 1.4435, Train Acc: 38.50% | Val Loss: 1.4247, Val Acc: 42.44%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 8/100 | Time: 76.09s | Train Loss: 1.4559, Train Acc: 38.81% | Val Loss: 1.4201, Val Acc: 42.73%




Epoch 9/100 | Time: 75.82s | Train Loss: 1.4322, Train Acc: 39.18% | Val Loss: 1.4250, Val Acc: 40.70%




Epoch 10/100 | Time: 76.25s | Train Loss: 1.4129, Train Acc: 38.74% | Val Loss: 1.4322, Val Acc: 43.90%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 11/100 | Time: 76.60s | Train Loss: 1.4064, Train Acc: 41.54% | Val Loss: 1.3808, Val Acc: 44.19%

Early stopping at epoch 11
Restored model weights from the end of the best epoch: 4
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 6/6 [00:12<00:00,  2.06s/it]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 44.51%
Top-3 Accuracy: 84.97%
Top-5 Accuracy: 98.84%

--- Additional Performance Metrics ---
Macro Average Precision: 0.1954
Macro Average Recall: 0.2172
Macro Average F1-Score: 0.1941
Weighted Average Precision: 0.3527
Weighted Average Recall: 0.4451
Weighted Average F1-Score: 0.3787

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.3865
Confidence Standard Deviation: 0.0523
Average Prediction Entropy: 1.4136
Expected Calibration Error: 0.0586

Final Test Loss: 1.3770





## Hybrid Transformer on ResNet-18 (Learnable PE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./pad-ufes-20/",
    "dataset_url": "https://www.kaggle.com/api/v1/datasets/download/maxjen/pad-ufes-20",  # Kaggle API URL
    "batch_size": 64,  # Reduced due to smaller dataset
    "num_epochs": 100,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,
    "num_classes": 6,  # PAD-UFES-20 has 6 diagnostic classes
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model_pad_ufes.pth",
}

# --- Transformer Configuration ---
TRANSFORMER_CONFIG = {
    "embedding_dim": 256,       # Dimension of the tokens fed to the transformer
    "nhead": 8,                 # Number of attention heads
    "num_encoder_layers": 3,    # Number of transformer encoder layers
    "dim_feedforward": 512,     # Hidden dimension in the feed-forward network
    "dropout": 0.1,
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base callback class"""
    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_training_end(self):
        pass

class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving"""
    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.optimizer.param_groups[0]['lr']
                new_lr = max(old_lr * self.factor, self.min_lr)
                if new_lr != old_lr:
                    self.optimizer.param_groups[0]['lr'] = new_lr
                    if self.verbose:
                        print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")
                    self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving"""
    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            self.best = current
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                if self.verbose:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                return True
        return False

    def on_training_end(self, model=None):
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.stopped_epoch + 1 - self.patience}")
        if model is not None and self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

class ModelCheckpoint(Callback):
    """Save the model after every epoch"""
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            filepath = self.filepath.replace('.pth', f'_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), filepath)
            if self.verbose:
                print(f"\nModel saved to {filepath}")
            return

        if self.best is None:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_extract_pad_ufes(url, dest_path):
    """Downloads and extracts PAD-UFES-20 dataset from Kaggle with progress and robust error handling."""
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading PAD-UFES-20 dataset from Kaggle...")
    try:
        # Add headers to mimic browser request
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        with requests.get(url, stream=True, timeout=60, headers=headers) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc="PAD-UFES-20"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to_dir)

        # Check if extraction created nested folders and fix if needed
        extracted_folders = [f for f in os.listdir(extract_to_dir) if os.path.isdir(os.path.join(extract_to_dir, f))]
        if 'pad-ufes-20' in extracted_folders:
            # Dataset extracted correctly
            pass
        else:
            # Look for the actual dataset folder
            for folder in extracted_folders:
                folder_path = os.path.join(extract_to_dir, folder)
                if os.path.exists(os.path.join(folder_path, 'metadata.csv')):
                    # This is our dataset folder, rename it
                    os.rename(folder_path, dest_path)
                    break

        return True

    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please download manually from: https://www.kaggle.com/datasets/maxjen/pad-ufes-20", file=sys.stderr)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

class PADUFESDataset(Dataset):
    """Custom dataset class for PAD-UFES-20."""
    def __init__(self, data_dir, metadata_df, transform=None):
        self.data_dir = data_dir
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.transform = transform

        # Create label mapping
        self.classes = sorted(metadata_df['diagnostic'].unique())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        print(f"Found {len(self.classes)} classes: {self.classes}")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]

        # Handle different possible image path formats
        img_name = row['img_id']
        if not img_name.endswith('.png'):
            img_name += '.png'

        # Try different possible folder structures
        possible_paths = [
            os.path.join(self.data_dir, 'images', img_name),
            os.path.join(self.data_dir, 'imgs', img_name),
            os.path.join(self.data_dir, img_name),
        ]

        img_path = None
        for path in possible_paths:
            if os.path.exists(path):
                img_path = path
                break

        if img_path is None:
            # Last resort: search for the file
            for root, dirs, files in os.walk(self.data_dir):
                if img_name in files:
                    img_path = os.path.join(root, img_name)
                    break

        if img_path is None:
            print(f"Warning: Image {img_name} not found, using black placeholder")
            image = Image.new('RGB', (224, 224), color='black')
        else:
            try:
                image = Image.open(img_path).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                image = Image.new('RGB', (224, 224), color='black')

        label = self.class_to_idx[row['diagnostic']]

        if self.transform:
            image = self.transform(image)

        return image, label

def prepare_pad_ufes_data(data_path):
    """Prepare PAD-UFES-20 data by reading metadata and organizing structure."""

    # Try different possible metadata file locations
    possible_metadata_paths = [
        os.path.join(data_path, 'metadata.csv'),
        os.path.join(data_path, 'PAD-UFES-20_metadata.csv'),
    ]

    metadata_path = None
    for path in possible_metadata_paths:
        if os.path.exists(path):
            metadata_path = path
            break

    if metadata_path is None:
        # Search for any CSV file that might be the metadata
        for root, dirs, files in os.walk(data_path):
            for file in files:
                if file.endswith('.csv') and 'metadata' in file.lower():
                    metadata_path = os.path.join(root, file)
                    break
            if metadata_path:
                break

    if metadata_path is None:
        print(f"Metadata file not found in {data_path}")
        return None

    print(f"Loading metadata from: {metadata_path}")
    try:
        metadata_df = pd.read_csv(metadata_path)
    except Exception as e:
        print(f"Error reading metadata file: {e}")
        return None

    # Print dataset info
    print(f"Total samples: {len(metadata_df)}")
    print(f"Diagnostic classes: {metadata_df['diagnostic'].value_counts()}")

    # Verify dataset integrity
    expected_samples = 2298
    expected_classes = 6

    actual_classes = len(metadata_df['diagnostic'].unique())
    if len(metadata_df) != expected_samples:
        print(f"⚠️  Warning: Expected {expected_samples} samples, got {len(metadata_df)}")
    if actual_classes != expected_classes:
        print(f"⚠️  Warning: Expected {expected_classes} classes, got {actual_classes}")

    print("✅ Dataset verification complete")
    return metadata_df

def get_dataloaders(config):
    """Downloads, prepares, and splits the PAD-UFES-20 data, returning DataLoaders."""

    # Try to download if URL is provided and dataset doesn't exist
    if config.get("dataset_url") and not os.path.exists(config["data_path"]):
        if not download_and_extract_pad_ufes(config["dataset_url"], config["data_path"]):
            print("\nAutomatic download failed. Please download manually:")
            print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
            print("2. Download the dataset")
            print("3. Extract to:", config["data_path"])
            return None

    # Check if dataset exists
    if not os.path.exists(config["data_path"]):
        print(f"\nDataset not found at {config['data_path']}")
        print("Please download the PAD-UFES-20 dataset:")
        print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
        print("2. Download the dataset")
        print("3. Extract to:", config["data_path"])
        print("4. Ensure the structure includes metadata.csv and image files")
        return None

    # Prepare metadata
    metadata_df = prepare_pad_ufes_data(config["data_path"])
    if metadata_df is None:
        return None

    # Define transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Create full dataset
    full_dataset = PADUFESDataset(config["data_path"], metadata_df, transform=None)

    # Ensure reproducible splits
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # Adjust sizes to ensure they sum to total_size
    actual_total = train_size + val_size + test_size
    if actual_total != total_size:
        test_size += (total_size - actual_total)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Create indices for splitting
    indices = list(range(total_size))
    np.random.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    # Create subset datasets with appropriate transforms
    train_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[train_indices],
        transform=train_transform
    )
    val_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[val_indices],
        transform=val_test_transform
    )
    test_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[test_indices],
        transform=val_test_transform
    )

    # Set worker init function for reproducible DataLoader behavior
    def worker_init_fn(worker_id):
        np.random.seed(config["split_seed"] + worker_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(config["split_seed"])
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class NonResidualBlock(nn.Module):
    """A standard convolutional block WITHOUT the residual connection."""
    def __init__(self, in_channels, out_channels, stride=1):
        super(NonResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

class ResNetTransformer(nn.Module):
    """
    A hybrid architecture combining a non-residual CNN backbone with a Transformer encoder,
    including positional embeddings.
    """
    def __init__(self, block, layers, num_classes, t_config):
        super(ResNetTransformer, self).__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]
        self.num_tokens = len(layers) + 1  # 4 layers + 1 initial capture

        # 1. CNN Backbone (Feature Extractor)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 2. Projection heads to create tokens from feature maps
        self.projections = nn.ModuleList([
            self._create_projection(64, self.embedding_dim),
            self._create_projection(64, self.embedding_dim),
            self._create_projection(128, self.embedding_dim),
            self._create_projection(256, self.embedding_dim),
            self._create_projection(512, self.embedding_dim)
        ])

        # 3. Learnable Positional Embedding
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_tokens, self.embedding_dim))

        # 4. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=t_config["nhead"],
            dim_feedforward=t_config["dim_feedforward"],
            dropout=t_config["dropout"],
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=t_config["num_encoder_layers"]
        )

        # 5. Final Classifier
        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_features, out_features)
        )

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        current_in_channels = self.in_channels
        for s in strides:
            layers.append(block(current_in_channels, out_channels, s))
            current_in_channels = out_channels
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Pass through CNN backbone and capture features
        features = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        features.append(x)  # Capture 1: After initial maxpool

        x = self.layer1(x); features.append(x)  # Capture 2: After layer1
        x = self.layer2(x); features.append(x)  # Capture 3: After layer2
        x = self.layer3(x); features.append(x)  # Capture 4: After layer3
        x = self.layer4(x); features.append(x)  # Capture 5: After layer4

        # 2. Project features to tokens
        tokens = [self.projections[i](feature_map) for i, feature_map in enumerate(features)]

        # 3. Stack tokens into a sequence
        token_sequence = torch.stack(tokens, dim=1)

        # 4. Add positional embedding
        token_sequence += self.positional_embedding

        # 5. Pass through Transformer
        transformer_out = self.transformer_encoder(token_sequence)

        # 6. Aggregate and classify
        aggregated_vector = transformer_out.mean(dim=1)
        logits = self.classifier(aggregated_vector)

        return logits

def evaluate_model(model, data_loader, criterion, device):
    """Evaluates the model on a given dataset."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

def train_and_validate(model, train_loader, val_loader, config):
    """Main training loop with callbacks."""
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["split_seed"])
        torch.cuda.manual_seed_all(config["split_seed"])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # Initialize callbacks
    callbacks = [
        ReduceLROnPlateau(
            optimizer=optimizer,
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=config["checkpoint_path"],
            save_best_only=True,
            monitor='val_accuracy',
            verbose=1
        )
    ]

    print("\n--- Starting Training ---")
    early_stop = False

    for epoch in range(config["num_epochs"]):
        if early_stop:
            break

        model.train()
        start_time = time.time()
        running_loss, train_correct, train_total = 0.0, 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_loss / train_total
        train_acc = 100 * train_correct / train_total
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Prepare logs for callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Execute callbacks
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                if callback.on_epoch_end(epoch, logs, model):
                    early_stop = True
                    break
            elif isinstance(callback, ModelCheckpoint):
                callback.on_epoch_end(epoch, logs, model)
            else:
                callback.on_epoch_end(epoch, logs)

    # Execute callback cleanup
    for callback in callbacks:
        if isinstance(callback, EarlyStopping):
            callback.on_training_end(model)
        else:
            callback.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=[1, 3, 5]):
    """Calculate top-k accuracy for given k values."""
    batch_size = labels.size(0)
    _, pred = outputs.topk(max(k_values), 1, True, True)
    pred = pred.t()
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    top_k_accuracies = {}
    for k in k_values:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        top_k_accuracies[k] = correct_k.item() / batch_size

    return top_k_accuracies

def calculate_entropy(probs):
    """Calculate entropy of probability distributions."""
    # Add small epsilon to avoid log(0)
    epsilon = 1e-8
    probs = probs + epsilon
    entropy = -torch.sum(probs * torch.log(probs), dim=1)
    return entropy

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate Expected Calibration Error."""
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower.item()) & (confidences <= bin_upper.item())
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Comprehensive evaluation with all requested metrics."""
    model.eval()

    all_predictions = []
    all_labels = []
    all_confidences = []
    all_entropies = []
    all_top_k_results = {1: [], 3: [], 5: []}
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Convert outputs to probabilities
            probs = torch.softmax(outputs, dim=1)

            # Get predictions and confidences
            confidences, predictions = torch.max(probs, dim=1)

            # Calculate entropy
            entropies = calculate_entropy(probs)

            # Calculate top-k accuracies
            top_k_accs = calculate_top_k_accuracy(outputs, labels, [1, 3, 5])

            # Store results
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_confidences.extend(confidences.cpu().numpy())
            all_entropies.extend(entropies.cpu().numpy())

            for k in [1, 3, 5]:
                all_top_k_results[k].extend([top_k_accs[k]] * labels.size(0))

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_confidences = np.array(all_confidences)
    all_entropies = np.array(all_entropies)

    # Calculate basic metrics
    avg_loss = total_loss / total_samples

    # Calculate top-k accuracies
    top1 = np.mean([pred == label for pred, label in zip(all_predictions, all_labels)])
    top3 = np.mean(all_top_k_results[3])
    top5 = np.mean(all_top_k_results[5])

    # Calculate precision, recall, f1-score
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calculate confidence and calibration metrics
    avg_confidence = np.mean(all_confidences)
    std_confidence = np.std(all_confidences)
    avg_entropy = np.mean(all_entropies)

    # Calculate ECE
    accuracies = (all_predictions == all_labels).astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(accuracies))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': avg_confidence,
        'std_confidence': std_confidence,
        'avg_entropy': avg_entropy,
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)  # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model with ResNetTransformer
    model = ResNetTransformer(
        block=NonResidualBlock,
        layers=[2, 2, 2, 2],
        num_classes=CONFIG["num_classes"],
        t_config=TRANSFORMER_CONFIG
    )

    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Loading metadata from: ./pad-ufes-20/metadata.csv
Total samples: 2298
Diagnostic classes: diagnostic
BCC    845
ACK    730
NEV    244
SEK    235
SCC    192
MEL     52
Name: count, dtype: int64
✅ Dataset verification complete
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Dataset splits - Train: 1608, Val: 344, Test: 346, Total: 2298
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Model has 12850246 parameters
Trainable parameters: 12850246

--- Starting Training ---




Epoch 1/100 | Time: 77.47s | Train Loss: 1.6639, Train Acc: 33.27% | Val Loss: 1.5029, Val Acc: 35.47%

Model saved to ./best_model_pad_ufes.pth




Epoch 2/100 | Time: 77.45s | Train Loss: 1.5394, Train Acc: 35.63% | Val Loss: 1.4876, Val Acc: 33.72%




Epoch 3/100 | Time: 77.08s | Train Loss: 1.5424, Train Acc: 34.45% | Val Loss: 1.5223, Val Acc: 35.17%




Epoch 4/100 | Time: 76.42s | Train Loss: 1.5133, Train Acc: 36.26% | Val Loss: 1.4681, Val Acc: 34.30%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 5/100 | Time: 76.30s | Train Loss: 1.4997, Train Acc: 36.07% | Val Loss: 1.4497, Val Acc: 35.47%




Epoch 6/100 | Time: 75.07s | Train Loss: 1.4695, Train Acc: 36.13% | Val Loss: 1.4708, Val Acc: 35.17%




Epoch 7/100 | Time: 75.34s | Train Loss: 1.4640, Train Acc: 36.38% | Val Loss: 1.4472, Val Acc: 37.21%

Model saved to ./best_model_pad_ufes.pth




Epoch 8/100 | Time: 82.09s | Train Loss: 1.4509, Train Acc: 37.56% | Val Loss: 1.3754, Val Acc: 40.70%

Model saved to ./best_model_pad_ufes.pth




Epoch 9/100 | Time: 80.45s | Train Loss: 1.4388, Train Acc: 39.43% | Val Loss: 1.4078, Val Acc: 40.99%

Model saved to ./best_model_pad_ufes.pth




Epoch 10/100 | Time: 79.48s | Train Loss: 1.4272, Train Acc: 38.68% | Val Loss: 1.3685, Val Acc: 43.31%

Model saved to ./best_model_pad_ufes.pth




Epoch 11/100 | Time: 77.37s | Train Loss: 1.4319, Train Acc: 40.30% | Val Loss: 1.3670, Val Acc: 44.48%

Model saved to ./best_model_pad_ufes.pth




Epoch 12/100 | Time: 76.66s | Train Loss: 1.4207, Train Acc: 39.37% | Val Loss: 1.5521, Val Acc: 32.85%




Epoch 13/100 | Time: 76.76s | Train Loss: 1.4090, Train Acc: 40.42% | Val Loss: 1.3502, Val Acc: 45.64%

Model saved to ./best_model_pad_ufes.pth




Epoch 14/100 | Time: 76.74s | Train Loss: 1.3926, Train Acc: 42.85% | Val Loss: 1.3164, Val Acc: 45.93%

Model saved to ./best_model_pad_ufes.pth




Epoch 15/100 | Time: 77.18s | Train Loss: 1.4157, Train Acc: 40.86% | Val Loss: 1.3898, Val Acc: 41.86%




Epoch 16/100 | Time: 76.65s | Train Loss: 1.3981, Train Acc: 43.16% | Val Loss: 1.3656, Val Acc: 49.71%

Model saved to ./best_model_pad_ufes.pth




Epoch 17/100 | Time: 76.13s | Train Loss: 1.3736, Train Acc: 44.53% | Val Loss: 1.2802, Val Acc: 48.84%




Epoch 18/100 | Time: 75.74s | Train Loss: 1.3732, Train Acc: 43.53% | Val Loss: 1.3200, Val Acc: 44.19%




Epoch 19/100 | Time: 76.29s | Train Loss: 1.3735, Train Acc: 42.54% | Val Loss: 1.3070, Val Acc: 49.13%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 20/100 | Time: 77.47s | Train Loss: 1.3436, Train Acc: 44.59% | Val Loss: 1.2741, Val Acc: 52.03%

Model saved to ./best_model_pad_ufes.pth




Epoch 21/100 | Time: 76.09s | Train Loss: 1.2882, Train Acc: 48.07% | Val Loss: 1.2549, Val Acc: 50.87%




Epoch 22/100 | Time: 75.68s | Train Loss: 1.2946, Train Acc: 46.39% | Val Loss: 1.2661, Val Acc: 52.62%

Model saved to ./best_model_pad_ufes.pth




Epoch 23/100 | Time: 76.08s | Train Loss: 1.2710, Train Acc: 48.76% | Val Loss: 1.2534, Val Acc: 51.45%




Epoch 24/100 | Time: 77.15s | Train Loss: 1.2740, Train Acc: 49.38% | Val Loss: 1.2565, Val Acc: 50.87%




Epoch 25/100 | Time: 77.39s | Train Loss: 1.2646, Train Acc: 50.50% | Val Loss: 1.2367, Val Acc: 51.16%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 26/100 | Time: 76.63s | Train Loss: 1.2577, Train Acc: 48.38% | Val Loss: 1.2429, Val Acc: 51.45%




Epoch 27/100 | Time: 77.18s | Train Loss: 1.2381, Train Acc: 48.88% | Val Loss: 1.2415, Val Acc: 52.91%

Model saved to ./best_model_pad_ufes.pth




Epoch 28/100 | Time: 76.95s | Train Loss: 1.2594, Train Acc: 48.26% | Val Loss: 1.2357, Val Acc: 52.03%




Epoch 29/100 | Time: 76.70s | Train Loss: 1.2377, Train Acc: 50.25% | Val Loss: 1.2221, Val Acc: 53.20%

Model saved to ./best_model_pad_ufes.pth




Epoch 30/100 | Time: 76.82s | Train Loss: 1.2497, Train Acc: 50.25% | Val Loss: 1.2293, Val Acc: 52.03%




Epoch 31/100 | Time: 76.87s | Train Loss: 1.2426, Train Acc: 49.13% | Val Loss: 1.2174, Val Acc: 53.20%




Epoch 32/100 | Time: 77.01s | Train Loss: 1.2424, Train Acc: 49.63% | Val Loss: 1.2154, Val Acc: 52.91%

Reducing learning rate from 8.00e-06 to 1.60e-06




Epoch 33/100 | Time: 76.47s | Train Loss: 1.2404, Train Acc: 50.37% | Val Loss: 1.2179, Val Acc: 52.33%




Epoch 34/100 | Time: 76.06s | Train Loss: 1.2498, Train Acc: 49.63% | Val Loss: 1.2157, Val Acc: 52.33%




Epoch 35/100 | Time: 76.38s | Train Loss: 1.2399, Train Acc: 50.12% | Val Loss: 1.2205, Val Acc: 51.74%

Reducing learning rate from 1.60e-06 to 3.20e-07




Epoch 36/100 | Time: 74.73s | Train Loss: 1.2189, Train Acc: 50.56% | Val Loss: 1.2184, Val Acc: 53.20%

Early stopping at epoch 36
Restored model weights from the end of the best epoch: 29
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 6/6 [00:12<00:00,  2.02s/it]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 50.58%
Top-3 Accuracy: 88.44%
Top-5 Accuracy: 98.84%

--- Additional Performance Metrics ---
Macro Average Precision: 0.3497
Macro Average Recall: 0.3002
Macro Average F1-Score: 0.3041
Weighted Average Precision: 0.4719
Weighted Average Recall: 0.5058
Weighted Average F1-Score: 0.4731

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.5087
Confidence Standard Deviation: 0.0957
Average Prediction Entropy: 1.2184
Expected Calibration Error: 0.0320

Final Test Loss: 1.2696





## Hybrid Transformer on ResNet-18 (RoPE)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import math
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./pad-ufes-20/",
    "dataset_url": "https://www.kaggle.com/api/v1/datasets/download/maxjen/pad-ufes-20",  # Kaggle API URL
    "batch_size": 64,  # Reduced due to smaller dataset
    "num_epochs": 100,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,
    "num_classes": 6,  # PAD-UFES-20 has 6 diagnostic classes
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model_pad_ufes.pth",
}

# --- Transformer Configuration ---
TRANSFORMER_CONFIG = {
    "embedding_dim": 256,
    "nhead": 8,
    "num_encoder_layers": 3,
    "dim_feedforward": 512,
    "dropout": 0.1,
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base callback class"""
    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_training_end(self):
        pass

class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving"""
    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.optimizer.param_groups[0]['lr']
                new_lr = max(old_lr * self.factor, self.min_lr)
                if new_lr != old_lr:
                    self.optimizer.param_groups[0]['lr'] = new_lr
                    if self.verbose:
                        print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")
                    self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving"""
    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            self.best = current
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                if self.verbose:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                return True
        return False

    def on_training_end(self, model=None):
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.stopped_epoch + 1 - self.patience}")
        if model is not None and self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

class ModelCheckpoint(Callback):
    """Save the model after every epoch"""
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            filepath = self.filepath.replace('.pth', f'_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), filepath)
            if self.verbose:
                print(f"\nModel saved to {filepath}")
            return

        if self.best is None:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            torch.save(model.state_dict(), self.filepath)
            if self.verbose:
                print(f"\nModel saved to {self.filepath}")

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_extract_pad_ufes(url, dest_path):
    """Downloads and extracts PAD-UFES-20 dataset from Kaggle with progress and robust error handling."""
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading PAD-UFES-20 dataset from Kaggle...")
    try:
        # Add headers to mimic browser request
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        with requests.get(url, stream=True, timeout=60, headers=headers) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc="PAD-UFES-20"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to_dir)

        # Check if extraction created nested folders and fix if needed
        extracted_folders = [f for f in os.listdir(extract_to_dir) if os.path.isdir(os.path.join(extract_to_dir, f))]
        if 'pad-ufes-20' in extracted_folders:
            # Dataset extracted correctly
            pass
        else:
            # Look for the actual dataset folder
            for folder in extracted_folders:
                folder_path = os.path.join(extract_to_dir, folder)
                if os.path.exists(os.path.join(folder_path, 'metadata.csv')):
                    # This is our dataset folder, rename it
                    os.rename(folder_path, dest_path)
                    break

        return True

    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please download manually from: https://www.kaggle.com/datasets/maxjen/pad-ufes-20", file=sys.stderr)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

class PADUFESDataset(Dataset):
    """Custom dataset class for PAD-UFES-20."""
    def __init__(self, data_dir, metadata_df, transform=None):
        self.data_dir = data_dir
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.transform = transform

        # Create label mapping
        self.classes = sorted(metadata_df['diagnostic'].unique())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        print(f"Found {len(self.classes)} classes: {self.classes}")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]

        # Handle different possible image path formats
        img_name = row['img_id']
        if not img_name.endswith('.png'):
            img_name += '.png'

        # Try different possible folder structures
        possible_paths = [
            os.path.join(self.data_dir, 'images', img_name),
            os.path.join(self.data_dir, 'imgs', img_name),
            os.path.join(self.data_dir, img_name),
        ]

        img_path = None
        for path in possible_paths:
            if os.path.exists(path):
                img_path = path
                break

        if img_path is None:
            # Last resort: search for the file
            for root, dirs, files in os.walk(self.data_dir):
                if img_name in files:
                    img_path = os.path.join(root, img_name)
                    break

        if img_path is None:
            print(f"Warning: Image {img_name} not found, using black placeholder")
            image = Image.new('RGB', (224, 224), color='black')
        else:
            try:
                image = Image.open(img_path).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                image = Image.new('RGB', (224, 224), color='black')

        label = self.class_to_idx[row['diagnostic']]

        if self.transform:
            image = self.transform(image)

        return image, label

def prepare_pad_ufes_data(data_path):
    """Prepare PAD-UFES-20 data by reading metadata and organizing structure."""

    # Try different possible metadata file locations
    possible_metadata_paths = [
        os.path.join(data_path, 'metadata.csv'),
        os.path.join(data_path, 'PAD-UFES-20_metadata.csv'),
    ]

    metadata_path = None
    for path in possible_metadata_paths:
        if os.path.exists(path):
            metadata_path = path
            break

    if metadata_path is None:
        # Search for any CSV file that might be the metadata
        for root, dirs, files in os.walk(data_path):
            for file in files:
                if file.endswith('.csv') and 'metadata' in file.lower():
                    metadata_path = os.path.join(root, file)
                    break
            if metadata_path:
                break

    if metadata_path is None:
        print(f"Metadata file not found in {data_path}")
        return None

    print(f"Loading metadata from: {metadata_path}")
    try:
        metadata_df = pd.read_csv(metadata_path)
    except Exception as e:
        print(f"Error reading metadata file: {e}")
        return None

    # Print dataset info
    print(f"Total samples: {len(metadata_df)}")
    print(f"Diagnostic classes: {metadata_df['diagnostic'].value_counts()}")

    # Verify dataset integrity
    expected_samples = 2298
    expected_classes = 6

    actual_classes = len(metadata_df['diagnostic'].unique())
    if len(metadata_df) != expected_samples:
        print(f"⚠️  Warning: Expected {expected_samples} samples, got {len(metadata_df)}")
    if actual_classes != expected_classes:
        print(f"⚠️  Warning: Expected {expected_classes} classes, got {actual_classes}")

    print("✅ Dataset verification complete")
    return metadata_df

def get_dataloaders(config):
    """Downloads, prepares, and splits the PAD-UFES-20 data, returning DataLoaders."""

    # Try to download if URL is provided and dataset doesn't exist
    if config.get("dataset_url") and not os.path.exists(config["data_path"]):
        if not download_and_extract_pad_ufes(config["dataset_url"], config["data_path"]):
            print("\nAutomatic download failed. Please download manually:")
            print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
            print("2. Download the dataset")
            print("3. Extract to:", config["data_path"])
            return None

    # Check if dataset exists
    if not os.path.exists(config["data_path"]):
        print(f"\nDataset not found at {config['data_path']}")
        print("Please download the PAD-UFES-20 dataset:")
        print("1. Visit: https://www.kaggle.com/datasets/maxjen/pad-ufes-20")
        print("2. Download the dataset")
        print("3. Extract to:", config["data_path"])
        print("4. Ensure the structure includes metadata.csv and image files")
        return None

    # Prepare metadata
    metadata_df = prepare_pad_ufes_data(config["data_path"])
    if metadata_df is None:
        return None

    # Define transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Create full dataset
    full_dataset = PADUFESDataset(config["data_path"], metadata_df, transform=None)

    # Ensure reproducible splits
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # Adjust sizes to ensure they sum to total_size
    actual_total = train_size + val_size + test_size
    if actual_total != total_size:
        test_size += (total_size - actual_total)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Create indices for splitting
    indices = list(range(total_size))
    np.random.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    # Create subset datasets with appropriate transforms
    train_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[train_indices],
        transform=train_transform
    )
    val_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[val_indices],
        transform=val_test_transform
    )
    test_dataset = PADUFESDataset(
        config["data_path"],
        metadata_df.iloc[test_indices],
        transform=val_test_transform
    )

    # Set worker init function for reproducible DataLoader behavior
    def worker_init_fn(worker_id):
        np.random.seed(config["split_seed"] + worker_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(config["split_seed"])
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class RotaryEmbedding(nn.Module):
    """
    The Rotary Positional Embedding (RoPE) module.
    This implementation is based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # Create inverse frequencies and register as a buffer
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        # Check if we need to recompute the cache
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, :]
            self.sin_cached = emb.sin()[:, None, :]

        # Apply the rotation
        return self.cos_cached, self.sin_cached

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MultiHeadAttentionWithRoPE(nn.Module):
    """Custom Multi-Head Attention with RoPE support"""
    def __init__(self, d_model, nhead, dropout=0.1, batch_first=True):
        super().__init__()
        assert d_model % nhead == 0

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.batch_first = batch_first  # Add this attribute

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.rotary_emb = RotaryEmbedding(self.d_k)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, is_causal=False):
        batch_size, seq_len, _ = query.size()

        # Linear transformations
        Q = self.w_q(query)  # [batch_size, seq_len, d_model]
        K = self.w_k(key)    # [batch_size, seq_len, d_model]
        V = self.w_v(value)  # [batch_size, seq_len, d_model]

        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.nhead, self.d_k)  # [batch, seq, heads, d_k]
        K = K.view(batch_size, seq_len, self.nhead, self.d_k)
        V = V.view(batch_size, seq_len, self.nhead, self.d_k)

        # Apply RoPE to Q and K
        cos, sin = self.rotary_emb(query)
        Q, K = apply_rotary_pos_emb(Q, K, cos, sin)

        # Transpose for attention computation: [batch, heads, seq, d_k]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply masks if provided
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, -1e9)

        if key_padding_mask is not None:
            # key_padding_mask: [batch_size, seq_len], True for padding positions
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq]
            scores = scores.masked_fill(key_padding_mask, -1e9)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attention_output = torch.matmul(attention_weights, V)  # [batch, heads, seq, d_k]

        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # Final linear transformation
        output = self.w_o(attention_output)

        if need_weights:
            return output, attention_weights.mean(dim=1)  # Average over heads for compatibility
        return output

class TransformerEncoderLayerWithRoPE(nn.Module):
    """
    A custom Transformer Encoder Layer that incorporates RoPE.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first=True):
        super().__init__()
        self.self_attn = MultiHeadAttentionWithRoPE(d_model, nhead, dropout, batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)
        self.batch_first = batch_first  # Add this attribute

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        # Self-attention with RoPE
        src2 = self.self_attn(src, src, src, key_padding_mask=src_key_padding_mask, attn_mask=src_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed Forward
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class NonResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(NonResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

class ResNetTransformer(nn.Module):
    def __init__(self, block, layers, num_classes, t_config):
        super(ResNetTransformer, self).__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.projections = nn.ModuleList([
            self._create_projection(64, self.embedding_dim),
            self._create_projection(64, self.embedding_dim),
            self._create_projection(128, self.embedding_dim),
            self._create_projection(256, self.embedding_dim),
            self._create_projection(512, self.embedding_dim)
        ])

        # Create transformer encoder layers manually
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayerWithRoPE(
                d_model=self.embedding_dim,
                nhead=t_config["nhead"],
                dim_feedforward=t_config["dim_feedforward"],
                dropout=t_config["dropout"],
                batch_first=True
            ) for _ in range(t_config["num_encoder_layers"])
        ])

        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_features, out_features)
        )

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        current_in_channels = self.in_channels
        for s in strides:
            layers.append(block(current_in_channels, out_channels, s))
            current_in_channels = out_channels
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        features = []
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features.append(x)
        x = self.layer1(x)
        features.append(x)
        x = self.layer2(x)
        features.append(x)
        x = self.layer3(x)
        features.append(x)
        x = self.layer4(x)
        features.append(x)

        tokens = [self.projections[i](feat) for i, feat in enumerate(features)]
        token_sequence = torch.stack(tokens, dim=1)

        # Apply transformer layers manually
        for layer in self.transformer_layers:
            token_sequence = layer(token_sequence)

        aggregated_vector = token_sequence.mean(dim=1)
        logits = self.classifier(aggregated_vector)
        return logits

def evaluate_model(model, data_loader, criterion, device):
    """Evaluates the model on a given dataset."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

def train_and_validate(model, train_loader, val_loader, config):
    """Main training loop with callbacks."""
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["split_seed"])
        torch.cuda.manual_seed_all(config["split_seed"])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # Initialize callbacks
    callbacks = [
        ReduceLROnPlateau(
            optimizer=optimizer,
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=config["checkpoint_path"],
            save_best_only=True,
            monitor='val_accuracy',
            verbose=1
        )
    ]

    print("\n--- Starting Training ---")
    early_stop = False

    for epoch in range(config["num_epochs"]):
        if early_stop:
            break

        model.train()
        start_time = time.time()
        running_loss, train_correct, train_total = 0.0, 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_loss / train_total
        train_acc = 100 * train_correct / train_total
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Prepare logs for callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Execute callbacks
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                if callback.on_epoch_end(epoch, logs, model):
                    early_stop = True
                    break
            elif isinstance(callback, ModelCheckpoint):
                callback.on_epoch_end(epoch, logs, model)
            else:
                callback.on_epoch_end(epoch, logs)

    # Execute callback cleanup
    for callback in callbacks:
        if isinstance(callback, EarlyStopping):
            callback.on_training_end(model)
        else:
            callback.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=[1, 3, 5]):
    """Calculate top-k accuracy for given k values."""
    batch_size = labels.size(0)
    _, pred = outputs.topk(max(k_values), 1, True, True)
    pred = pred.t()
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    top_k_accuracies = {}
    for k in k_values:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        top_k_accuracies[k] = correct_k.item() / batch_size

    return top_k_accuracies

def calculate_entropy(probs):
    """Calculate entropy of probability distributions."""
    # Add small epsilon to avoid log(0)
    epsilon = 1e-8
    probs = probs + epsilon
    entropy = -torch.sum(probs * torch.log(probs), dim=1)
    return entropy

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate Expected Calibration Error."""
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower.item()) & (confidences <= bin_upper.item())
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Comprehensive evaluation with all requested metrics."""
    model.eval()

    all_predictions = []
    all_labels = []
    all_confidences = []
    all_entropies = []
    all_top_k_results = {1: [], 3: [], 5: []}
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Convert outputs to probabilities
            probs = torch.softmax(outputs, dim=1)

            # Get predictions and confidences
            confidences, predictions = torch.max(probs, dim=1)

            # Calculate entropy
            entropies = calculate_entropy(probs)

            # Calculate top-k accuracies
            top_k_accs = calculate_top_k_accuracy(outputs, labels, [1, 3, 5])

            # Store results
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_confidences.extend(confidences.cpu().numpy())
            all_entropies.extend(entropies.cpu().numpy())

            for k in [1, 3, 5]:
                all_top_k_results[k].extend([top_k_accs[k]] * labels.size(0))

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_confidences = np.array(all_confidences)
    all_entropies = np.array(all_entropies)

    # Calculate basic metrics
    avg_loss = total_loss / total_samples

    # Calculate top-k accuracies
    top1 = np.mean([pred == label for pred, label in zip(all_predictions, all_labels)])
    top3 = np.mean(all_top_k_results[3])
    top5 = np.mean(all_top_k_results[5])

    # Calculate precision, recall, f1-score
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calculate confidence and calibration metrics
    avg_confidence = np.mean(all_confidences)
    std_confidence = np.std(all_confidences)
    avg_entropy = np.mean(all_entropies)

    # Calculate ECE
    accuracies = (all_predictions == all_labels).astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(accuracies))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': avg_confidence,
        'std_confidence': std_confidence,
        'avg_entropy': avg_entropy,
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)  # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    model = ResNetTransformer(
        block=NonResidualBlock,
        layers=[2, 2, 2, 2],
        num_classes=CONFIG["num_classes"],
        t_config=TRANSFORMER_CONFIG
    )
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Downloading PAD-UFES-20 dataset from Kaggle...


PAD-UFES-20: 100%|██████████| 3.60G/3.60G [02:46<00:00, 21.6MiB/s]


Extracting...
Loading metadata from: ./pad-ufes-20/metadata.csv
Total samples: 2298
Diagnostic classes: diagnostic
BCC    845
ACK    730
NEV    244
SEK    235
SCC    192
MEL     52
Name: count, dtype: int64
✅ Dataset verification complete
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Dataset splits - Train: 1608, Val: 344, Test: 346, Total: 2298
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
Found 6 classes: ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']

--- Starting Training ---




Epoch 1/100 | Time: 80.81s | Train Loss: 1.6814, Train Acc: 34.02% | Val Loss: 1.5097, Val Acc: 35.47%

Model saved to ./best_model_pad_ufes.pth




Epoch 2/100 | Time: 80.23s | Train Loss: 1.5323, Train Acc: 36.07% | Val Loss: 1.4779, Val Acc: 33.72%




Epoch 3/100 | Time: 80.09s | Train Loss: 1.5507, Train Acc: 34.08% | Val Loss: 1.4856, Val Acc: 34.30%




Epoch 4/100 | Time: 78.54s | Train Loss: 1.4979, Train Acc: 38.62% | Val Loss: 1.4487, Val Acc: 34.30%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 5/100 | Time: 78.14s | Train Loss: 1.4818, Train Acc: 36.32% | Val Loss: 1.4341, Val Acc: 36.63%

Model saved to ./best_model_pad_ufes.pth




Epoch 6/100 | Time: 78.20s | Train Loss: 1.4515, Train Acc: 39.12% | Val Loss: 1.4280, Val Acc: 40.99%

Model saved to ./best_model_pad_ufes.pth




Epoch 7/100 | Time: 78.39s | Train Loss: 1.4384, Train Acc: 37.56% | Val Loss: 1.4225, Val Acc: 39.83%




Epoch 8/100 | Time: 78.41s | Train Loss: 1.4338, Train Acc: 39.93% | Val Loss: 1.3758, Val Acc: 40.99%




Epoch 9/100 | Time: 77.00s | Train Loss: 1.4260, Train Acc: 40.24% | Val Loss: 1.3882, Val Acc: 43.02%

Model saved to ./best_model_pad_ufes.pth




Epoch 10/100 | Time: 77.29s | Train Loss: 1.4018, Train Acc: 38.93% | Val Loss: 1.4060, Val Acc: 40.70%




Epoch 11/100 | Time: 77.19s | Train Loss: 1.3765, Train Acc: 41.67% | Val Loss: 1.3756, Val Acc: 39.83%




Epoch 12/100 | Time: 77.94s | Train Loss: 1.3962, Train Acc: 40.49% | Val Loss: 1.3220, Val Acc: 41.86%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 13/100 | Time: 78.84s | Train Loss: 1.3507, Train Acc: 45.02% | Val Loss: 1.3064, Val Acc: 44.77%

Model saved to ./best_model_pad_ufes.pth




Epoch 14/100 | Time: 80.43s | Train Loss: 1.3374, Train Acc: 44.78% | Val Loss: 1.2968, Val Acc: 45.93%

Model saved to ./best_model_pad_ufes.pth




Epoch 15/100 | Time: 79.68s | Train Loss: 1.3259, Train Acc: 46.02% | Val Loss: 1.3125, Val Acc: 43.60%




Epoch 16/100 | Time: 79.21s | Train Loss: 1.3239, Train Acc: 46.02% | Val Loss: 1.3170, Val Acc: 45.64%




Epoch 17/100 | Time: 78.74s | Train Loss: 1.3124, Train Acc: 45.83% | Val Loss: 1.2766, Val Acc: 47.09%

Model saved to ./best_model_pad_ufes.pth




Epoch 18/100 | Time: 78.53s | Train Loss: 1.3027, Train Acc: 46.95% | Val Loss: 1.2874, Val Acc: 47.97%

Model saved to ./best_model_pad_ufes.pth




Epoch 19/100 | Time: 78.44s | Train Loss: 1.2991, Train Acc: 45.77% | Val Loss: 1.2783, Val Acc: 47.38%




Epoch 20/100 | Time: 79.51s | Train Loss: 1.3193, Train Acc: 45.77% | Val Loss: 1.2740, Val Acc: 49.71%

Model saved to ./best_model_pad_ufes.pth




Epoch 21/100 | Time: 77.97s | Train Loss: 1.2928, Train Acc: 46.83% | Val Loss: 1.2615, Val Acc: 49.71%




Epoch 22/100 | Time: 80.13s | Train Loss: 1.2897, Train Acc: 45.40% | Val Loss: 1.2727, Val Acc: 48.26%




Epoch 23/100 | Time: 78.33s | Train Loss: 1.3003, Train Acc: 45.71% | Val Loss: 1.2943, Val Acc: 45.35%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 24/100 | Time: 78.25s | Train Loss: 1.2738, Train Acc: 49.19% | Val Loss: 1.2720, Val Acc: 49.71%




Epoch 25/100 | Time: 78.02s | Train Loss: 1.2863, Train Acc: 47.82% | Val Loss: 1.2654, Val Acc: 52.03%

Model saved to ./best_model_pad_ufes.pth




Epoch 26/100 | Time: 77.78s | Train Loss: 1.2813, Train Acc: 48.38% | Val Loss: 1.2646, Val Acc: 52.03%




Epoch 27/100 | Time: 77.75s | Train Loss: 1.2676, Train Acc: 47.70% | Val Loss: 1.2543, Val Acc: 49.42%




Epoch 28/100 | Time: 77.89s | Train Loss: 1.2707, Train Acc: 47.01% | Val Loss: 1.2594, Val Acc: 51.74%

Reducing learning rate from 8.00e-06 to 1.60e-06




Epoch 29/100 | Time: 77.59s | Train Loss: 1.2683, Train Acc: 47.57% | Val Loss: 1.2549, Val Acc: 50.58%




Epoch 30/100 | Time: 77.28s | Train Loss: 1.2792, Train Acc: 48.45% | Val Loss: 1.2587, Val Acc: 51.45%




Epoch 31/100 | Time: 78.06s | Train Loss: 1.2711, Train Acc: 47.70% | Val Loss: 1.2522, Val Acc: 50.87%

Reducing learning rate from 1.60e-06 to 3.20e-07




Epoch 32/100 | Time: 78.72s | Train Loss: 1.2750, Train Acc: 47.76% | Val Loss: 1.2552, Val Acc: 50.58%

Early stopping at epoch 32
Restored model weights from the end of the best epoch: 25
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 6/6 [00:12<00:00,  2.01s/it]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 47.69%
Top-3 Accuracy: 86.42%
Top-5 Accuracy: 98.55%

--- Additional Performance Metrics ---
Macro Average Precision: 0.4672
Macro Average Recall: 0.3132
Macro Average F1-Score: 0.3102
Weighted Average Precision: 0.4833
Weighted Average Recall: 0.4769
Weighted Average F1-Score: 0.4309

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.4818
Confidence Standard Deviation: 0.0839
Average Prediction Entropy: 1.2624
Expected Calibration Error: 0.0373

Final Test Loss: 1.3184



