<a href="https://colab.research.google.com/github/SoudeepGhoshal/TResNet/blob/main/TResNet_Tiny-ImageNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tiny ImageNet

## ResNet-18

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./tiny-imagenet-200/",  # Directory the dataset lives in (created by the download step)
    "dataset_url": "http://cs231n.stanford.edu/tiny-imagenet-200.zip",  # Archive fetched when data_path is missing
    "batch_size": 128,  # Used for the train, validation and test loaders alike
    "num_epochs": 100,  # Upper bound; early stopping may end training sooner
    "learning_rate": 0.001,  # Initial Adam learning rate
    "weight_decay": 1e-4,  # Passed to Adam as weight_decay
    "split_seed": 42, # Seed for the dataset split; also reused to seed the training RNGs
    "train_split": 0.70,  # Fractions of the combined (train + val folder) dataset
    "val_split": 0.15,
    "test_split": 0.15,  # Explicit for clarity; rounding leftovers are assigned to the test split
    "num_classes": 200,  # Size of the classifier's output layer
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model.pth",  # Where ModelCheckpoint writes the best state_dict
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """No-op base class that defines the hooks a training callback may override."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook called after each epoch; the base implementation does nothing."""
        return None

    def on_training_end(self):
        """Hook called once training is over; the base implementation does nothing."""
        return None

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when a monitored metric has stopped improving.

    After ``patience`` consecutive epochs without improvement of
    ``logs[monitor]``, the learning rate of *every* optimizer parameter group
    is multiplied by ``factor`` (never going below ``min_lr``).

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        monitor: Key looked up in the ``logs`` dict. A name containing
            ``'loss'`` is minimised; anything else is maximised.
        factor: Multiplier applied to the learning rate on a plateau.
        patience: Number of non-improving epochs tolerated before reducing.
        min_lr: Lower bound for the learning rate.
        verbose: Truthy to print a message whenever the rate is reduced.
    """

    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0      # epochs since the last improvement (or last reduction)
        self.best = None   # best monitored value seen so far
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        """Update plateau bookkeeping and lower the learning rate if needed."""
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        # The first observation only establishes the baseline.
        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        # Apply the reduction to all parameter groups, not just the first one,
        # so optimizers built with several groups are scheduled consistently.
        reduced = False
        for group in self.optimizer.param_groups:
            old_lr = group['lr']
            new_lr = max(old_lr * self.factor, self.min_lr)
            # Only ever lower the rate: a group already below min_lr is left alone.
            if new_lr < old_lr:
                group['lr'] = new_lr
                reduced = True
                if self.verbose:
                    print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")

        # Restart the patience window only when something actually changed.
        if reduced:
            self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    Tracks the best value of ``logs[monitor]`` and asks the training loop to
    stop after ``patience`` consecutive epochs without improvement. When
    ``restore_best_weights`` is true, an independent copy of the model weights
    is kept for the best epoch and loaded back in ``on_training_end``.

    Args:
        monitor: Key looked up in the ``logs`` dict. A name containing
            ``'loss'`` is minimised; anything else is maximised.
        patience: Number of non-improving epochs tolerated before stopping.
        restore_best_weights: Whether to snapshot and later restore the
            weights of the best epoch.
        verbose: Truthy to print progress messages.
    """

    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0             # epochs since the last improvement
        self.best = None          # best monitored value seen so far
        self.best_weights = None  # deep copy of the state_dict at the best epoch
        self.best_epoch = 0       # 0-based index of the best epoch
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0    # 0-based epoch at which stopping was requested

    @staticmethod
    def _snapshot(model):
        """Return a copy of the model's state_dict that later training cannot alter."""
        import copy

        # state_dict() returns references to the live parameter tensors, so a
        # shallow dict.copy() would silently follow every optimizer step.
        return copy.deepcopy(model.state_dict())

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Record the epoch's metric.

        Returns:
            True when training should stop, False otherwise.
        """
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            improved = True  # the first observation is the baseline
        elif self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.best_epoch = epoch
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = self._snapshot(model)
            return False

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            if self.verbose:
                print(f"\nEarly stopping at epoch {epoch + 1}")
            return True
        return False

    def on_training_end(self, model=None):
        """Load the best-epoch weights back into ``model`` if they were kept."""
        if model is None or not self.restore_best_weights or self.best_weights is None:
            return

        model.load_state_dict(self.best_weights)
        # Report only after an early stop, and only when a restore really happened.
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.best_epoch + 1}")

class ModelCheckpoint(Callback):
    """Save the model's ``state_dict`` at the end of an epoch.

    With ``save_best_only=True`` the file at ``filepath`` is overwritten each
    time ``logs[monitor]`` improves. Otherwise a separate file is written
    every epoch, with ``_epoch_<n>`` inserted before the file extension.

    Args:
        filepath: Destination path for the checkpoint.
        save_best_only: Save only on improvement of the monitored metric.
        monitor: Key looked up in the ``logs`` dict. A name containing
            ``'loss'`` is minimised; anything else is maximised.
        verbose: Truthy to print a message after each save.
    """

    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None  # best monitored value seen so far
        self.mode = 'min' if 'loss' in monitor else 'max'

    def _save(self, model, filepath):
        """Write the model's state_dict to ``filepath`` and optionally report it."""
        torch.save(model.state_dict(), filepath)
        if self.verbose:
            print(f"\nModel saved to {filepath}")

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Save a checkpoint for this epoch if the configuration calls for one."""
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            # Split on the real extension so paths without ".pth" still get a
            # per-epoch name, and ".pth" elsewhere in the path is left untouched.
            root, ext = os.path.splitext(self.filepath)
            self._save(model, f"{root}_epoch_{epoch + 1}{ext}")
            return

        if self.best is None:
            improved = True  # the first observation is always saved
        elif self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self._save(model, self.filepath)

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_unzip(url, dest_path):
    """Download a zip archive and extract it next to ``dest_path``.

    The archive is expected to contain a top-level folder matching
    ``dest_path``, so it is extracted into ``dest_path``'s parent directory.
    The temporary zip file is always removed afterwards.

    Args:
        url: Location of the zip archive.
        dest_path: Directory the dataset should end up in. If it already
            exists, nothing is downloaded.

    Returns:
        True if the dataset directory already existed or was extracted
        successfully; False if the download or the extraction failed.
    """
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    # The parent directory where we will extract the zip file
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading dataset from {url}...")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc="Tiny ImageNet"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:  # skip keep-alive chunks
                        continue
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Unzipping...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Extract to the parent directory to avoid nested folders
            zip_ref.extractall(extract_to_dir)
        return True

    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please check your internet connection or the dataset URL.", file=sys.stderr)
        return False

    except zipfile.BadZipFile as e:
        print(f"\nError extracting archive: {e}", file=sys.stderr)
        # dest_path did not exist when we started, so anything there now is a
        # partial extraction; remove it so the next run does not mistake it
        # for a complete dataset.
        shutil.rmtree(dest_path, ignore_errors=True)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path) # Clean up the zip file

def prepare_val_folder(val_dir):
    """Rearrange the validation folder into the layout ImageFolder expects.

    Reads ``val_annotations.txt`` (tab-separated; first column is the image
    file name, second the class id), moves every listed image from
    ``<val_dir>/images`` into ``<val_dir>/<class_id>/``, then deletes the
    annotation file and the ``images`` directory. The function is a no-op
    when the annotation file is absent, which is taken to mean the folder
    has already been formatted.

    Args:
        val_dir: Path of the validation directory.
    """
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')
    if not os.path.exists(val_annotations_path):
        return # Already formatted

    print("Formatting validation folder...")
    val_img_dir = os.path.join(val_dir, 'images')
    val_img_dict = {}
    with open(val_annotations_path, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            # Blank or malformed lines carry no (image, class) pair.
            if len(parts) < 2:
                continue
            val_img_dict[parts[0]] = parts[1]

    for img_name, class_id in val_img_dict.items():
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)
        src_path = os.path.join(val_img_dir, img_name)
        # Tolerate listed images that are missing on disk.
        if os.path.exists(src_path):
            shutil.move(src_path, os.path.join(class_dir, img_name))

    # Removing the annotation file marks the folder as formatted.
    if os.path.exists(val_annotations_path):
        os.remove(val_annotations_path)
    if os.path.exists(val_img_dir):
        shutil.rmtree(val_img_dir)

def get_dataloaders(config):
    """Build train/validation/test DataLoaders for the dataset described by ``config``.

    The dataset is downloaded if necessary, the validation folder is put into
    ImageFolder layout, the original train and val folders are concatenated,
    and the result is split according to the ``*_split`` fractions using
    ``split_seed``.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or None if the download failed.
    """
    data_root = config["data_path"]
    if not download_and_unzip(config["dataset_url"], data_root):
        return None # Return None if download fails

    train_dir = os.path.join(data_root, 'train')
    val_dir = os.path.join(data_root, 'val')
    prepare_val_folder(val_dir)

    # Same preprocessing for every image: tensor conversion + channel normalisation.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    full_dataset = torch.utils.data.ConcatDataset([
        ImageFolder(train_dir, transform=preprocessing),
        ImageFolder(val_dir, transform=preprocessing),
    ])

    # Seed the global RNGs so the split is reproducible.
    seed = config["split_seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # Truncation can leave the parts short of (or past) the total; the test
    # split absorbs the difference (zero when the sizes already add up).
    test_size += total_size - (train_size + val_size + test_size)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    split_generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_generator
    )

    def worker_init_fn(worker_id):
        # Give each DataLoader worker its own deterministic NumPy seed.
        np.random.seed(config["split_seed"] + worker_id)

    # Options shared by all three loaders.
    shared_options = dict(
        batch_size=config["batch_size"],
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )

    train_loader = DataLoader(
        train_subset,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
        **shared_options,
    )
    val_loader = DataLoader(val_subset, shuffle=False, **shared_options)
    test_loader = DataLoader(test_subset, shuffle=False, **shared_options)

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class ResidualBlock(nn.Module):
    """Two 3x3 conv/batch-norm layers wrapped by a skip connection.

    When the stride is not 1 or the channel count changes, the skip path is a
    1x1 convolution followed by batch norm so that both paths have matching
    shapes; otherwise it passes the input through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: projection only when the shapes would not match.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        shortcut = self.downsample(x)
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        residual += shortcut
        return self.relu(residual)

class ResNet(nn.Module):
    """ResNet backbone assembled from a configurable block type.

    Args:
        block: Callable ``block(in_channels, out_channels, stride)`` returning a module.
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes=200):
        super().__init__()
        self.in_channels = 64  # running channel count, updated by _make_layer

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; all but the first halve the spatial resolution.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))

def ResNet18(num_classes=200):
    """Build a ResNet with four stages of two ResidualBlocks each."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(ResidualBlock, blocks_per_stage, num_classes)

def evaluate_model(model, data_loader, criterion, device):
    """Compute the sample-weighted mean loss and the accuracy (in percent) over a loader.

    The model is switched to eval mode and gradients are disabled while iterating.

    Returns:
        ``(avg_loss, accuracy)`` where accuracy lies in the range 0-100.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)

            # criterion returns a batch mean; re-weight by batch size so the
            # final average is per sample even when the last batch is smaller.
            batch_loss = criterion(logits, labels)
            loss_sum += batch_loss.item() * images.size(0)

            predicted = logits.max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen

def train_and_validate(model, train_loader, val_loader, config):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Three callbacks run after each epoch, in this order: learning-rate
    reduction on plateau, early stopping, and best-model checkpointing, all
    monitoring ``val_accuracy``. When early stopping fires, the checkpoint
    callback is skipped for that epoch and the loop ends.

    Returns:
        The model after the callbacks' end-of-training hooks have run.
    """
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    seed = config["split_seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    lr_reducer = ReduceLROnPlateau(
        optimizer=optimizer,
        monitor='val_accuracy',
        factor=0.2,
        patience=3,
        min_lr=1e-7,
        verbose=1,
    )
    early_stopper = EarlyStopping(
        monitor='val_accuracy',
        patience=7,
        restore_best_weights=True,
        verbose=1,
    )
    checkpointer = ModelCheckpoint(
        filepath=config["checkpoint_path"],
        save_best_only=True,
        monitor='val_accuracy',
        verbose=1,
    )

    print("\n--- Starting Training ---")

    for epoch in range(config["num_epochs"]):
        model.train()
        start_time = time.time()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size for a per-sample average.
            loss_sum += loss.item() * images.size(0)
            predicted = logits.detach().max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = loss_sum / n_seen
        train_acc = 100 * n_correct / n_seen
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Metrics handed to the callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Callbacks run in a fixed order; a stop request ends the epoch loop
        # before the checkpoint callback sees this epoch.
        lr_reducer.on_epoch_end(epoch, logs)
        if early_stopper.on_epoch_end(epoch, logs, model):
            break
        checkpointer.on_epoch_end(epoch, logs, model)

    # End-of-training hooks, in the same order as above.
    lr_reducer.on_training_end()
    early_stopper.on_training_end(model)
    checkpointer.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=(1, 3, 5)):
    """Calculate top-k accuracy for the given k values.

    Args:
        outputs: Score tensor indexed as (sample, class); higher means more likely.
        labels: Tensor of true class indices, one per sample.
        k_values: Iterable of k values to evaluate. A k larger than the number
            of classes is treated as "all classes".

    Returns:
        Dict mapping each k to the fraction of samples whose true label is
        among the k highest-scoring classes. Empty if ``k_values`` is empty.
    """
    if not k_values:
        return {}

    batch_size = labels.size(0)
    # topk() rejects k greater than the class count, so cap the request.
    max_k = min(max(k_values), outputs.size(1))
    _, pred = outputs.topk(max_k, 1, True, True)

    # correct[i, j] is True when the j-th best guess for sample i is its label.
    correct = pred.eq(labels.view(-1, 1))

    top_k_accuracies = {}
    for k in k_values:
        hits = correct[:, :k].any(dim=1).sum().item()
        top_k_accuracies[k] = hits / batch_size

    return top_k_accuracies

def calculate_entropy(probs):
    """Return the entropy (natural log) of each row of ``probs``.

    A small constant is added to every probability before taking the
    logarithm so that zero entries do not produce ``log(0)``.
    """
    shifted = probs + 1e-8
    return -(shifted * shifted.log()).sum(dim=1)

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate the Expected Calibration Error.

    Samples are grouped into ``n_bins`` equal-width confidence bins over
    (0, 1]. ECE is the sum over bins of
    ``|mean confidence - mean accuracy| * fraction of samples in the bin``.

    Args:
        confidences: 1-D tensor of predicted confidences.
        accuracies: 1-D tensor of the same length; 1 where the prediction
            was correct and 0 otherwise.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a Python float; 0.0 when no sample falls into any bin
        (for example, empty input).
    """
    if confidences.numel() == 0:
        return 0.0

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1].tolist()
    bin_uppers = bin_boundaries[1:].tolist()

    # Accumulate as a plain float so the result is well-defined even when
    # every bin turns out to be empty.
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        # Bins are open on the left and closed on the right.
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        if not in_bin.any():
            continue

        prop_in_bin = in_bin.float().mean().item()
        accuracy_in_bin = accuracies[in_bin].float().mean().item()
        avg_confidence_in_bin = confidences[in_bin].mean().item()
        ece += abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and collect a dictionary of metrics.

    Returns:
        Dict with the mean loss, top-1/3/5 accuracy (fractions in 0-1),
        macro and weighted precision/recall/F1, the mean and standard
        deviation of the prediction confidence, the mean prediction entropy,
        and the expected calibration error.
    """
    model.eval()

    k_values = [1, 3, 5]
    pred_buf, label_buf, conf_buf, entropy_buf = [], [], [], []
    per_sample_top_k = {k: [] for k in k_values}
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Class probabilities, then the winning class and its probability.
            probs = torch.softmax(logits, dim=1)
            confidences, predictions = probs.max(dim=1)

            entropies = calculate_entropy(probs)
            batch_top_k = calculate_top_k_accuracy(logits, labels, k_values)

            pred_buf.extend(predictions.cpu().numpy())
            label_buf.extend(labels.cpu().numpy())
            conf_buf.extend(confidences.cpu().numpy())
            entropy_buf.extend(entropies.cpu().numpy())

            # Repeat the batch-level accuracy once per sample so the final
            # mean is weighted by batch size.
            n_in_batch = labels.size(0)
            for k in k_values:
                per_sample_top_k[k].extend([batch_top_k[k]] * n_in_batch)

            loss_sum += batch_loss.item() * images.size(0)
            n_samples += images.size(0)

    all_predictions = np.array(pred_buf)
    all_labels = np.array(label_buf)
    all_confidences = np.array(conf_buf)
    all_entropies = np.array(entropy_buf)

    avg_loss = loss_sum / n_samples

    is_correct = all_predictions == all_labels

    # Precision / recall / F1 under both averaging schemes.
    prf = {
        averaging: precision_recall_fscore_support(
            all_labels, all_predictions, average=averaging, zero_division=0
        )[:3]
        for averaging in ('macro', 'weighted')
    }
    precision_macro, recall_macro, f1_macro = prf['macro']
    precision_weighted, recall_weighted, f1_weighted = prf['weighted']

    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(is_correct.astype(float)))

    return {
        'test_loss': avg_loss,
        'top1': np.mean(is_correct),
        'top3': np.mean(per_sample_top_k[3]),
        'top5': np.mean(per_sample_top_k[5]),
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': np.mean(all_confidences),
        'std_confidence': np.std(all_confidences),
        'avg_entropy': np.mean(all_entropies),
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

# Script entry point: seed everything, build the data loaders, train ResNet-18,
# then evaluate the trained model on the held-out test split and print a report.
if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders (get_dataloaders returns None when the download fails)
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1) # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    model = ResNet18(num_classes=CONFIG["num_classes"])
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics.
    # Top-k values are returned as fractions, hence the * 100 for display.
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    # Precision / recall / F1 are printed as returned (not scaled to percent).
    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Dataset directory already exists.
Dataset splits - Train: 77000, Val: 16500, Test: 16500, Total: 110000

--- Starting Training ---




Epoch 1/100 | Time: 65.55s | Train Loss: 4.4417, Train Acc: 8.14% | Val Loss: 4.0291, Val Acc: 12.84%

Model saved to ./best_model.pth




Epoch 2/100 | Time: 54.40s | Train Loss: 3.6114, Train Acc: 18.83% | Val Loss: 3.5136, Val Acc: 20.83%

Model saved to ./best_model.pth




Epoch 3/100 | Time: 52.70s | Train Loss: 3.2187, Train Acc: 25.73% | Val Loss: 3.2694, Val Acc: 25.09%

Model saved to ./best_model.pth




Epoch 4/100 | Time: 52.49s | Train Loss: 2.9188, Train Acc: 31.27% | Val Loss: 3.1764, Val Acc: 27.27%

Model saved to ./best_model.pth




Epoch 5/100 | Time: 54.17s | Train Loss: 2.6395, Train Acc: 36.50% | Val Loss: 3.0319, Val Acc: 29.72%

Model saved to ./best_model.pth




Epoch 6/100 | Time: 54.26s | Train Loss: 2.3533, Train Acc: 42.05% | Val Loss: 3.0845, Val Acc: 30.70%

Model saved to ./best_model.pth




Epoch 7/100 | Time: 54.16s | Train Loss: 2.0389, Train Acc: 48.57% | Val Loss: 3.0938, Val Acc: 30.76%

Model saved to ./best_model.pth




Epoch 8/100 | Time: 53.22s | Train Loss: 1.6773, Train Acc: 56.32% | Val Loss: 3.0481, Val Acc: 32.52%

Model saved to ./best_model.pth




Epoch 9/100 | Time: 53.22s | Train Loss: 1.2684, Train Acc: 66.01% | Val Loss: 3.2677, Val Acc: 31.64%




Epoch 10/100 | Time: 53.35s | Train Loss: 0.8732, Train Acc: 75.70% | Val Loss: 3.5970, Val Acc: 30.70%




Epoch 11/100 | Time: 53.27s | Train Loss: 0.5759, Train Acc: 83.69% | Val Loss: 3.9835, Val Acc: 29.50%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 12/100 | Time: 53.63s | Train Loss: 0.1756, Train Acc: 96.31% | Val Loss: 3.7164, Val Acc: 33.64%

Model saved to ./best_model.pth




Epoch 13/100 | Time: 52.71s | Train Loss: 0.0491, Train Acc: 99.75% | Val Loss: 3.8361, Val Acc: 33.84%

Model saved to ./best_model.pth




Epoch 14/100 | Time: 53.26s | Train Loss: 0.0264, Train Acc: 99.95% | Val Loss: 3.9447, Val Acc: 33.91%

Model saved to ./best_model.pth




Epoch 15/100 | Time: 53.66s | Train Loss: 0.0175, Train Acc: 99.97% | Val Loss: 4.0365, Val Acc: 33.58%




Epoch 16/100 | Time: 53.44s | Train Loss: 0.0138, Train Acc: 99.97% | Val Loss: 4.1036, Val Acc: 33.45%




Epoch 17/100 | Time: 53.65s | Train Loss: 0.0113, Train Acc: 99.96% | Val Loss: 4.1634, Val Acc: 33.17%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 18/100 | Time: 54.44s | Train Loss: 0.0081, Train Acc: 99.98% | Val Loss: 4.1814, Val Acc: 33.32%




Epoch 19/100 | Time: 54.38s | Train Loss: 0.0063, Train Acc: 99.98% | Val Loss: 4.1896, Val Acc: 33.41%




Epoch 20/100 | Time: 54.01s | Train Loss: 0.0057, Train Acc: 99.97% | Val Loss: 4.1959, Val Acc: 33.48%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 21/100 | Time: 54.01s | Train Loss: 0.0048, Train Acc: 99.98% | Val Loss: 4.1942, Val Acc: 33.46%

Early stopping at epoch 21
Restored model weights from the end of the best epoch: 14
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 129/129 [00:08<00:00, 14.97it/s]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 33.21%
Top-3 Accuracy: 50.75%
Top-5 Accuracy: 58.76%

--- Additional Performance Metrics ---
Macro Average Precision: 0.3285
Macro Average Recall: 0.3311
Macro Average F1-Score: 0.3279
Weighted Average Precision: 0.3323
Weighted Average Recall: 0.3321
Weighted Average F1-Score: 0.3304

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.6770
Confidence Standard Deviation: 0.2487
Average Prediction Entropy: 1.0597
Expected Calibration Error: 0.3449

Final Test Loss: 4.2113





## Hybrid Transformer on ResNet-18 (No PE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
CONFIG = {
    "data_path": "./tiny-imagenet-200/",  # Directory the dataset is expected in
    "dataset_url": "http://cs231n.stanford.edu/tiny-imagenet-200.zip",  # Source archive for the dataset
    "batch_size": 128,
    "num_epochs": 50,  # NOTE(review): half the 100 epochs configured in the ResNet-18 cell above
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,  # Seed for creating reproducible dataset splits
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,  # Explicitly define test split for clarity
    "num_classes": 200,
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # Prefer the GPU when one is available
    "checkpoint_path": "./best_model.pth",  # Same path as the ResNet-18 cell, so that checkpoint is overwritten if both cells save to it
}

# --- Transformer Configuration ---
# Hyperparameters for the transformer part of the hybrid model.
# NOTE(review): the code that consumes these is not visible in this part of the
# file; the key names suggest a standard transformer encoder - confirm there.
TRANSFORMER_CONFIG = {
    "embedding_dim": 32,  # presumably the token/feature width fed to the encoder - TODO confirm
    "nhead": 16,  # presumably the number of attention heads (32 / 16 = 2 dims per head) - TODO confirm
    "num_encoder_layers": 3,
    "dim_feedforward": 2,  # NOTE(review): unusually small next to embedding_dim=32 - confirm this is intended
    "dropout": 0.1,
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """No-op base class that defines the hooks a training callback may override."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook called after each epoch; the base implementation does nothing."""
        return None

    def on_training_end(self):
        """Hook called once training is over; the base implementation does nothing."""
        return None

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when a monitored metric has stopped improving.

    After ``patience`` consecutive epochs without improvement of
    ``logs[monitor]``, the learning rate of *every* optimizer parameter group
    is multiplied by ``factor`` (never going below ``min_lr``).

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        monitor: Key looked up in the ``logs`` dict. A name containing
            ``'loss'`` is minimised; anything else is maximised.
        factor: Multiplier applied to the learning rate on a plateau.
        patience: Number of non-improving epochs tolerated before reducing.
        min_lr: Lower bound for the learning rate.
        verbose: Truthy to print a message whenever the rate is reduced.
    """

    def __init__(self, optimizer, monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0      # epochs since the last improvement (or last reduction)
        self.best = None   # best monitored value seen so far
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        """Update plateau bookkeeping and lower the learning rate if needed."""
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        # The first observation only establishes the baseline.
        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        # Apply the reduction to all parameter groups, not just the first one,
        # so optimizers built with several groups are scheduled consistently.
        reduced = False
        for group in self.optimizer.param_groups:
            old_lr = group['lr']
            new_lr = max(old_lr * self.factor, self.min_lr)
            # Only ever lower the rate: a group already below min_lr is left alone.
            if new_lr < old_lr:
                group['lr'] = new_lr
                reduced = True
                if self.verbose:
                    print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")

        # Restart the patience window only when something actually changed.
        if reduced:
            self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    Metrics whose name contains ``'loss'`` are minimised; every other metric is
    maximised. When ``restore_best_weights`` is set, a snapshot of the model
    weights is kept for the best epoch and loaded back in ``on_training_end``.

    Args:
        monitor: Key of the metric to watch in ``logs``.
        patience: Number of consecutive non-improving epochs tolerated before stopping.
        restore_best_weights: Whether to keep and later reload the best epoch's weights.
        verbose: Truthy to print stop/restore messages.
    """
    def __init__(self, monitor='val_loss', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    @staticmethod
    def _snapshot(model):
        """Return an independent copy of the model's state dict.

        ``state_dict().copy()`` is only a shallow copy: its tensors share storage
        with the live parameters and would silently follow later optimizer steps.
        Cloning each tensor freezes the values as they are right now.
        """
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Record the metric for this epoch and report whether training should stop.

        Returns:
            True when the metric has not improved for ``patience`` consecutive
            epochs, False otherwise (including when the metric is unavailable).
        """
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            # First observation: it is the best by definition.
            self.best = current
            if model is not None and self.restore_best_weights:
                self.best_weights = self._snapshot(model)
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = self._snapshot(model)
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                if self.verbose:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                return True
        return False

    def on_training_end(self, model=None):
        """Load the best epoch's weights back into ``model`` if a snapshot exists."""
        if model is None or not self.restore_best_weights or self.best_weights is None:
            return

        model.load_state_dict(self.best_weights)
        # Only announce a restore that actually happened; the best epoch is the
        # one that preceded the `patience` non-improving epochs.
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.stopped_epoch + 1 - self.patience}")

class ModelCheckpoint(Callback):
    """Save the model's state dict at the end of epochs.

    With ``save_best_only=True`` the file at ``filepath`` is written on the first
    epoch and then overwritten only when the monitored metric improves. With
    ``save_best_only=False`` a separate file is written every epoch, named by
    inserting ``_epoch_<n>`` before the file extension of ``filepath``.

    Metrics whose name contains ``'loss'`` are minimised; every other metric is
    maximised.

    Args:
        filepath: Destination path for the checkpoint.
        save_best_only: Whether to keep only the best-scoring weights.
        monitor: Key of the metric to watch in ``logs``.
        verbose: Truthy to print a message after each save.
    """
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def _save(self, model, path):
        """Write the model's state dict to ``path`` and optionally report it."""
        torch.save(model.state_dict(), path)
        if self.verbose:
            print(f"\nModel saved to {path}")

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Save a checkpoint for this epoch according to the configured policy.

        Does nothing when ``logs`` or ``model`` is None, or when ``logs`` does not
        contain the monitored key.
        """
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            # Split on the real extension so that paths without ".pth" (or with
            # ".pth" appearing in a directory name) still get one file per epoch.
            root, ext = os.path.splitext(self.filepath)
            self._save(model, f"{root}_epoch_{epoch + 1}{ext}")
            return

        if self.best is None:
            # First observation: always persist it.
            self.best = current
            self._save(model, self.filepath)
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self._save(model, self.filepath)

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_unzip(url, dest_path):
    """Download the archive at `url` and extract it so that `dest_path` appears.

    Returns True when `dest_path` already exists or the download and extraction
    completed, and False when the HTTP request failed. The temporary ``.zip``
    file is removed in every case.
    """
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    archive_path = dest_path.rstrip('/') + ".zip"
    # Extract into the parent of dest_path to avoid ending up with nested folders.
    parent_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading dataset from {url}...")
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            expected_bytes = int(response.headers.get('content-length', 0))
            with open(archive_path, 'wb') as out_file, tqdm(
                total=expected_bytes, unit='iB', unit_scale=True, desc="Tiny ImageNet"
            ) as bar:
                for block in response.iter_content(chunk_size=8192):
                    out_file.write(block)
                    bar.update(len(block))

        print("Unzipping...")
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(parent_dir)
    except requests.exceptions.RequestException as err:
        print(f"\nError downloading file: {err}", file=sys.stderr)
        print("Please check your internet connection or the dataset URL.", file=sys.stderr)
        return False
    else:
        return True
    finally:
        # Remove the archive whether or not the download/extraction succeeded.
        if os.path.exists(archive_path):
            os.remove(archive_path)

def prepare_val_folder(val_dir):
    """Reorganise the validation folder into one sub-directory per class.

    Reads ``val_annotations.txt`` (tab-separated; the first field is the image
    file name and the second the class id), moves every listed image from
    ``<val_dir>/images`` into ``<val_dir>/<class_id>/``, then deletes the
    annotation file and the emptied ``images`` directory. The resulting layout
    is the one ``ImageFolder`` expects.

    The function is a no-op when the annotation file is absent, which is taken
    to mean the folder has already been formatted.

    Args:
        val_dir: Path of the validation directory.
    """
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')
    if not os.path.exists(val_annotations_path):
        return  # Already formatted

    print("Formatting validation folder...")
    val_img_dir = os.path.join(val_dir, 'images')
    val_img_dict = {}
    with open(val_annotations_path, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            # Skip blank or malformed lines (e.g. a trailing empty line)
            # instead of failing with IndexError on parts[1].
            if len(parts) < 2 or not parts[0] or not parts[1]:
                continue
            val_img_dict[parts[0]] = parts[1]

    for img_name, class_id in val_img_dict.items():
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)
        src_path = os.path.join(val_img_dir, img_name)
        # Images listed in the annotations but missing on disk are ignored.
        if os.path.exists(src_path):
            shutil.move(src_path, os.path.join(class_dir, img_name))

    if os.path.exists(val_annotations_path):
        os.remove(val_annotations_path)
    if os.path.exists(val_img_dir):
        shutil.rmtree(val_img_dir)

def get_dataloaders(config):
    """Build the train/val/test DataLoaders from the dataset described by `config`.

    The dataset is downloaded if needed, the validation folder is reformatted,
    the ``train`` and ``val`` image folders are concatenated and the result is
    re-split with a seeded generator using the configured fractions.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or None when the download failed.
    """
    if not download_and_unzip(config["dataset_url"], config["data_path"]):
        return None  # Nothing to load without the dataset

    data_root = config["data_path"]
    train_dir = os.path.join(data_root, 'train')
    val_dir = os.path.join(data_root, 'val')
    prepare_val_folder(val_dir)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Pool both folders so the split below is the only source of partitioning.
    full_dataset = torch.utils.data.ConcatDataset([
        ImageFolder(train_dir, transform=transform),
        ImageFolder(val_dir, transform=transform),
    ])

    # Seed the global RNGs so the split is reproducible
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size, val_size, test_size = (
        int(config[key] * total_size)
        for key in ("train_split", "val_split", "test_split")
    )
    # Truncation may leave the three parts short of (or past) the total;
    # the test split absorbs the difference.
    test_size += total_size - (train_size + val_size + test_size)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    split_generator = torch.Generator().manual_seed(config["split_seed"])
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_generator
    )

    def worker_init_fn(worker_id):
        # Give every worker process its own, deterministic NumPy seed.
        np.random.seed(config["split_seed"] + worker_id)

    shared_kwargs = dict(
        batch_size=config["batch_size"],
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )

    # Only the training loader shuffles, driven by its own seeded generator.
    train_loader = DataLoader(
        train_subset,
        shuffle=True,
        generator=torch.Generator().manual_seed(config["split_seed"]),
        **shared_kwargs,
    )
    val_loader = DataLoader(val_subset, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_subset, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class NonResidualBlock(nn.Module):
    """
    Two 3x3 conv + batch-norm + ReLU stages applied in sequence, with no skip
    connection around them. Only the first convolution uses `stride`.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # A single ReLU module is shared by both stages.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        first_stage = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(first_stage)))

class ResNetTransformer(nn.Module):
    """
    Hybrid model: a ResNet-shaped CNN backbone whose intermediate feature maps
    are each pooled and projected to one token; the resulting token sequence is
    run through a Transformer encoder, averaged, and classified.
    """
    def __init__(self, block, layers, num_classes, t_config):
        super().__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]

        # 1. CNN backbone (stem + four stages)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 2. One projection head per tapped feature map:
        #    stem output (64), then layer1..layer4 (64, 128, 256, 512 channels).
        tapped_channels = (64, 64, 128, 256, 512)
        self.projections = nn.ModuleList(
            [self._create_projection(channels, self.embedding_dim) for channels in tapped_channels]
        )

        # 3. Transformer encoder over the token sequence
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embedding_dim,
                nhead=t_config["nhead"],
                dim_feedforward=t_config["dim_feedforward"],
                dropout=t_config["dropout"],
                batch_first=True,
            ),
            num_layers=t_config["num_encoder_layers"],
        )

        # 4. Classification head
        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        """Global-average-pool a feature map and map it linearly to one token."""
        pool = nn.AdaptiveAvgPool2d((1, 1))
        flatten = nn.Flatten()
        linear = nn.Linear(in_features, out_features)
        return nn.Sequential(pool, flatten, linear)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one backbone stage; only its first block uses `stride`."""
        stage = []
        in_ch = self.in_channels
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            stage.append(block(in_ch, out_channels, block_stride))
            in_ch = out_channels
        # Remember the width so the next stage starts from it.
        self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        # 1. Run the backbone, keeping the stem output and every stage output
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = [x]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        # 2. One token per feature map
        tokens = [project(fmap) for project, fmap in zip(self.projections, features)]

        # 3. Stack the tokens along a new sequence dimension and encode them
        encoded = self.transformer_encoder(torch.stack(tokens, dim=1))

        # 4. Average over the sequence dimension and classify
        return self.classifier(encoded.mean(dim=1))

def ResNetTransformer18(num_classes=200, t_config=TRANSFORMER_CONFIG):
    """Build a ResNetTransformer with NonResidualBlock and two blocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNetTransformer(NonResidualBlock, blocks_per_stage, num_classes, t_config)

def evaluate_model(model, data_loader, criterion, device):
    """Run `model` over `data_loader` in eval mode without gradients.

    Returns:
        ``(avg_loss, accuracy)`` where the loss is averaged per sample and the
        accuracy is a percentage in the range 0-100.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            # Weight the batch loss by batch size so the final value is a per-sample mean.
            loss_sum += batch_loss.item() * images.size(0)

            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen

def train_and_validate(model, train_loader, val_loader, config):
    """Train `model` with Adam and cross-entropy, validating after every epoch.

    Three callbacks monitor ``val_accuracy``: learning-rate reduction on plateau,
    early stopping (with best-weight restoration) and best-model checkpointing.

    Returns:
        The same `model` object after training and callback cleanup.
    """
    device = config["device"]
    model.to(device)

    # Re-seed every RNG so the training run itself is reproducible
    seed = config["split_seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )

    lr_reducer = ReduceLROnPlateau(
        optimizer=optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1
    )
    early_stopper = EarlyStopping(
        monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1
    )
    checkpointer = ModelCheckpoint(
        filepath=config["checkpoint_path"], save_best_only=True, monitor='val_accuracy', verbose=1
    )

    print("\n--- Starting Training ---")

    for epoch in range(config["num_epochs"]):
        model.train()
        epoch_start = time.time()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in batches:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss * images.size(0)
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            batches.set_postfix(loss=batch_loss)

        train_loss = loss_sum / n_seen
        train_acc = 100 * n_correct / n_seen
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - epoch_start:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Callback order matters: LR schedule first, then the stop decision;
        # the checkpoint callback is skipped for the epoch that triggers the stop.
        lr_reducer.on_epoch_end(epoch, logs)
        if early_stopper.on_epoch_end(epoch, logs, model):
            break
        checkpointer.on_epoch_end(epoch, logs, model)

    # Cleanup hooks, in the same order as above; only early stopping needs the model.
    lr_reducer.on_training_end()
    early_stopper.on_training_end(model)
    checkpointer.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=(1, 3, 5)):
    """Calculate top-k accuracy for the given k values.

    Args:
        outputs: 2-D score tensor indexed as ``[sample, class]``.
        labels: 1-D tensor of target class indices, one per sample.
        k_values: Iterable of k values to evaluate.

    Returns:
        Dict mapping each k to the fraction (0-1) of samples whose label is among
        the k highest-scoring classes. A k larger than the number of classes is
        treated as "all classes" instead of raising inside ``topk``.
    """
    batch_size = labels.size(0)
    # topk() rejects k > number of classes, so clamp the request.
    max_k = min(max(k_values), outputs.size(1))
    _, top_indices = outputs.topk(max_k, dim=1, largest=True, sorted=True)

    # hits[i, j] is True when the j-th best guess for sample i is its label.
    hits = top_indices.eq(labels.view(-1, 1))

    return {k: hits[:, :k].sum().item() / batch_size for k in k_values}

def calculate_entropy(probs):
    """Return the entropy (natural log) of each row of `probs`, summed over dim 1."""
    # Shift by a tiny constant so zero probabilities never reach log(0).
    shifted = probs + 1e-8
    return -(shifted * torch.log(shifted)).sum(dim=1)

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate the Expected Calibration Error.

    Samples are grouped into `n_bins` equal-width confidence bins over (0, 1]
    (lower bound exclusive, upper bound inclusive). The ECE is the sum over
    bins of ``|mean confidence - mean accuracy|`` weighted by the fraction of
    samples in the bin.

    Args:
        confidences: 1-D tensor of predicted confidences.
        accuracies: 1-D tensor aligned with `confidences`, 1 for a correct
            prediction and 0 otherwise.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a Python float; 0.0 when no sample falls in any bin
        (e.g. empty input).
    """
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1].tolist()
    bin_uppers = bin_boundaries[1:].tolist()

    # Accumulate as a float so the result is well-defined even when every bin
    # is empty (an int accumulator has no .item()).
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        if not in_bin.any():
            continue

        prop_in_bin = in_bin.float().mean()
        accuracy_in_bin = accuracies[in_bin].float().mean()
        avg_confidence_in_bin = confidences[in_bin].mean()
        ece += (torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin).item()

    return ece

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Evaluate `model` on `test_loader` and collect accuracy, P/R/F1 and calibration metrics.

    Returns:
        Dict with the keys ``test_loss``, ``top1``/``top3``/``top5`` (fractions),
        macro and weighted ``precision_*``/``recall_*``/``f1_*``,
        ``avg_confidence``, ``std_confidence``, ``avg_entropy`` and ``ece``.
    """
    model.eval()

    k_values = [1, 3, 5]
    predictions_buf, labels_buf = [], []
    confidence_buf, entropy_buf = [], []
    # The per-batch top-k accuracy is repeated once per sample so that a plain
    # mean over the list weights every batch by its size.
    topk_per_sample = {k: [] for k in k_values}
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            batch_loss = criterion(outputs, labels)

            # Softmax probabilities; the largest one is the prediction's confidence
            probs = torch.softmax(outputs, dim=1)
            confidences, predictions = torch.max(probs, dim=1)
            entropies = calculate_entropy(probs)
            batch_topk = calculate_top_k_accuracy(outputs, labels, k_values)

            predictions_buf.extend(predictions.cpu().numpy())
            labels_buf.extend(labels.cpu().numpy())
            confidence_buf.extend(confidences.cpu().numpy())
            entropy_buf.extend(entropies.cpu().numpy())

            n_labels = labels.size(0)
            for k in k_values:
                topk_per_sample[k].extend([batch_topk[k]] * n_labels)

            n_images = images.size(0)
            loss_sum += batch_loss.item() * n_images
            n_samples += n_images

    all_predictions = np.array(predictions_buf)
    all_labels = np.array(labels_buf)
    all_confidences = np.array(confidence_buf)
    all_entropies = np.array(entropy_buf)

    avg_loss = loss_sum / n_samples

    # Top-1 comes straight from the stored predictions; top-3/5 from the per-sample lists
    is_correct = all_predictions == all_labels
    top1 = np.mean(is_correct)
    top3 = np.mean(topk_per_sample[3])
    top5 = np.mean(topk_per_sample[5])

    # Precision / recall / F1 under both averaging schemes
    prf = {}
    for averaging in ('macro', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_predictions, average=averaging, zero_division=0
        )
        prf[averaging] = (precision, recall, f1)

    # Confidence statistics and calibration error
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(is_correct.astype(float)))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': prf['macro'][0],
        'recall_macro': prf['macro'][1],
        'f1_macro': prf['macro'][2],
        'precision_weighted': prf['weighted'][0],
        'recall_weighted': prf['weighted'][1],
        'f1_weighted': prf['weighted'][2],
        'avg_confidence': np.mean(all_confidences),
        'std_confidence': np.std(all_confidences),
        'avg_entropy': np.mean(all_entropies),
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

# Script entry point: seed everything, build the data loaders, train the hybrid
# model, then report test-set metrics.
if __name__ == '__main__':
    # Set global seeds for full reproducibility
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders (get_dataloaders returns None when the download failed)
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)  # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    model = ResNetTransformer18(num_classes=CONFIG["num_classes"], t_config=TRANSFORMER_CONFIG)
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    # Top-k values are returned as fractions, hence the * 100 to show percentages.
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    # Precision/recall/F1 are printed as returned (no percentage scaling).
    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Dataset directory already exists.
Dataset splits - Train: 77000, Val: 16500, Test: 16500, Total: 110000

--- Starting Training ---




Epoch 1/50 | Time: 56.04s | Train Loss: 4.9006, Train Acc: 3.08% | Val Loss: 4.6179, Val Acc: 5.46%

Model saved to ./best_model.pth




Epoch 2/50 | Time: 55.80s | Train Loss: 4.3391, Train Acc: 7.92% | Val Loss: 4.1335, Val Acc: 10.38%

Model saved to ./best_model.pth




Epoch 3/50 | Time: 56.66s | Train Loss: 4.0144, Train Acc: 12.00% | Val Loss: 3.9501, Val Acc: 13.52%

Model saved to ./best_model.pth




Epoch 4/50 | Time: 55.65s | Train Loss: 3.8190, Train Acc: 14.70% | Val Loss: 3.8424, Val Acc: 14.93%

Model saved to ./best_model.pth




Epoch 5/50 | Time: 56.26s | Train Loss: 3.6643, Train Acc: 17.06% | Val Loss: 3.6914, Val Acc: 16.62%

Model saved to ./best_model.pth




Epoch 6/50 | Time: 56.26s | Train Loss: 3.5368, Train Acc: 19.23% | Val Loss: 3.5540, Val Acc: 19.01%

Model saved to ./best_model.pth




Epoch 7/50 | Time: 55.55s | Train Loss: 3.4375, Train Acc: 21.01% | Val Loss: 3.5422, Val Acc: 19.40%

Model saved to ./best_model.pth




Epoch 8/50 | Time: 54.55s | Train Loss: 3.3411, Train Acc: 22.82% | Val Loss: 3.3462, Val Acc: 23.19%

Model saved to ./best_model.pth




Epoch 9/50 | Time: 55.40s | Train Loss: 3.2473, Train Acc: 24.33% | Val Loss: 3.3073, Val Acc: 23.66%

Model saved to ./best_model.pth




Epoch 10/50 | Time: 54.45s | Train Loss: 3.1657, Train Acc: 25.84% | Val Loss: 3.2600, Val Acc: 24.76%

Model saved to ./best_model.pth




Epoch 11/50 | Time: 54.94s | Train Loss: 3.0931, Train Acc: 27.24% | Val Loss: 3.2149, Val Acc: 26.15%

Model saved to ./best_model.pth




Epoch 12/50 | Time: 54.98s | Train Loss: 3.0221, Train Acc: 28.54% | Val Loss: 3.1643, Val Acc: 27.07%

Model saved to ./best_model.pth




Epoch 13/50 | Time: 55.12s | Train Loss: 2.9577, Train Acc: 29.66% | Val Loss: 3.0988, Val Acc: 28.00%

Model saved to ./best_model.pth




Epoch 14/50 | Time: 54.81s | Train Loss: 2.8902, Train Acc: 31.24% | Val Loss: 3.0929, Val Acc: 28.45%

Model saved to ./best_model.pth




Epoch 15/50 | Time: 54.37s | Train Loss: 2.8249, Train Acc: 32.45% | Val Loss: 3.0697, Val Acc: 28.70%

Model saved to ./best_model.pth




Epoch 16/50 | Time: 54.84s | Train Loss: 2.7676, Train Acc: 33.61% | Val Loss: 3.0869, Val Acc: 28.75%

Model saved to ./best_model.pth




Epoch 17/50 | Time: 55.28s | Train Loss: 2.7072, Train Acc: 34.77% | Val Loss: 2.9764, Val Acc: 30.65%

Model saved to ./best_model.pth




Epoch 18/50 | Time: 54.85s | Train Loss: 2.6429, Train Acc: 36.25% | Val Loss: 3.0122, Val Acc: 30.44%




Epoch 19/50 | Time: 54.48s | Train Loss: 2.5888, Train Acc: 37.04% | Val Loss: 3.0357, Val Acc: 30.26%




Epoch 20/50 | Time: 55.72s | Train Loss: 2.5263, Train Acc: 38.35% | Val Loss: 2.9399, Val Acc: 31.93%

Model saved to ./best_model.pth




Epoch 21/50 | Time: 54.98s | Train Loss: 2.4685, Train Acc: 39.34% | Val Loss: 2.9420, Val Acc: 31.78%




Epoch 22/50 | Time: 54.81s | Train Loss: 2.4074, Train Acc: 40.72% | Val Loss: 2.9768, Val Acc: 31.58%




Epoch 23/50 | Time: 54.67s | Train Loss: 2.3498, Train Acc: 41.71% | Val Loss: 3.0017, Val Acc: 31.55%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 24/50 | Time: 55.57s | Train Loss: 2.0038, Train Acc: 49.34% | Val Loss: 2.8487, Val Acc: 35.18%

Model saved to ./best_model.pth




Epoch 25/50 | Time: 54.89s | Train Loss: 1.9012, Train Acc: 51.68% | Val Loss: 2.8873, Val Acc: 34.82%




Epoch 26/50 | Time: 54.41s | Train Loss: 1.8380, Train Acc: 52.93% | Val Loss: 2.9316, Val Acc: 34.28%




Epoch 27/50 | Time: 54.53s | Train Loss: 1.7778, Train Acc: 54.18% | Val Loss: 2.9733, Val Acc: 34.32%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 28/50 | Time: 55.60s | Train Loss: 1.6570, Train Acc: 57.07% | Val Loss: 2.9845, Val Acc: 34.29%




Epoch 29/50 | Time: 54.76s | Train Loss: 1.6284, Train Acc: 57.86% | Val Loss: 3.0074, Val Acc: 34.13%




Epoch 30/50 | Time: 54.44s | Train Loss: 1.6072, Train Acc: 58.25% | Val Loss: 3.0217, Val Acc: 34.10%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 31/50 | Time: 54.79s | Train Loss: 1.5796, Train Acc: 58.96% | Val Loss: 3.0313, Val Acc: 33.84%

Early stopping at epoch 31
Restored model weights from the end of the best epoch: 24
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 129/129 [00:08<00:00, 15.47it/s]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 33.39%
Top-3 Accuracy: 52.24%
Top-5 Accuracy: 60.82%

--- Additional Performance Metrics ---
Macro Average Precision: 0.3255
Macro Average Recall: 0.3341
Macro Average F1-Score: 0.3272
Weighted Average Precision: 0.3283
Weighted Average Recall: 0.3339
Weighted Average F1-Score: 0.3286

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.4573
Confidence Standard Deviation: 0.2411
Average Prediction Entropy: 1.9928
Expected Calibration Error: 0.1234

Final Test Loss: 3.0453





## Hybrid Transformer on ResNet-18 (Learnable PE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

# --- Configuration Parameters ---
# Run-wide settings for this cell (data location, optimisation and split parameters).
CONFIG = {
    "data_path": "./tiny-imagenet-200/",  # Local dataset directory
    "dataset_url": "http://cs231n.stanford.edu/tiny-imagenet-200.zip",  # Dataset archive URL
    "batch_size": 128,
    "num_epochs": 100,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "split_seed": 42,  # Seed for creating reproducible dataset splits
    # The three split fractions below sum to 1.0.
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,  # Explicitly define test split for clarity
    "num_classes": 200,
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # CUDA when available, otherwise CPU
    "checkpoint_path": "./best_model.pth",  # Added checkpoint path
}

# --- Transformer Configuration ---
TRANSFORMER_CONFIG = {
    "embedding_dim": 256,       # Dimension of the tokens fed to the transformer
    "nhead": 8,                 # Number of attention heads
    "num_encoder_layers": 3,    # Number of transformer encoder layers
    "dim_feedforward": 512,     # Hidden dimension in the feed-forward network
    "dropout": 0.1,             # Dropout rate (presumably for the encoder layers — confirm against the model code)
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base callback class: no-op hooks that concrete callbacks override."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook for the end of an epoch; the base implementation does nothing."""
        return None

    def on_training_end(self):
        """Hook for the end of training; the base implementation does nothing."""
        return None

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when a monitored metric has stopped improving.

    Args:
        optimizer: Optimizer whose ``param_groups`` learning rates are scaled.
        monitor: Key looked up in the ``logs`` dict each epoch. Names containing
            ``'loss'`` are minimised; every other metric is maximised.
        factor: Multiplier applied to the learning rate on each reduction.
        patience: Number of epochs without improvement before reducing.
        min_lr: Floor below which a reduction never pushes the learning rate.
        verbose: Truthy to print a line for every reduction.
    """
    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None):
        """Update the plateau counter and lower the learning rate when due.

        Does nothing when ``logs`` is ``None`` or lacks the monitored key. The
        first observed value only initialises the running best.
        """
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        # Scale every parameter group, not just the first one, so optimizers
        # built with several groups are handled consistently.
        reduced = False
        for group in self.optimizer.param_groups:
            old_lr = group['lr']
            new_lr = max(old_lr * self.factor, self.min_lr)
            # Only ever move downwards: a rate that is already below min_lr
            # must not be raised back up to it.
            if new_lr < old_lr:
                group['lr'] = new_lr
                reduced = True
                if self.verbose:
                    print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")

        # The counter restarts only after an actual reduction.
        if reduced:
            self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    The callback tracks the best value of ``monitor`` seen so far and, when
    ``restore_best_weights`` is set, keeps an independent copy of the model
    weights from that epoch so they can be reloaded after training.

    Args:
        monitor: Key looked up in the ``logs`` dict each epoch. Names containing
            ``'loss'`` are minimised; every other metric is maximised.
        patience: Number of epochs without improvement that triggers the stop.
        restore_best_weights: Whether to snapshot and later reload the best weights.
        verbose: Truthy to print messages when stopping and restoring.
    """
    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0
        self.best = None
        self.best_weights = None
        self.best_epoch = 0  # zero-based index of the epoch that produced ``best``
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    @staticmethod
    def _snapshot(model):
        """Return a deep copy of ``model.state_dict()``.

        ``state_dict()`` returns references to the live parameter tensors and
        ``dict.copy()`` is shallow, so without a deep copy the stored "best"
        weights would keep changing as training continues.
        """
        import copy  # local import keeps this class self-contained
        return copy.deepcopy(model.state_dict())

    def _record_best(self, epoch, current, model):
        """Remember ``current`` as the best value and snapshot the weights if requested."""
        self.best = current
        self.best_epoch = epoch
        if model is not None and self.restore_best_weights:
            self.best_weights = self._snapshot(model)

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Update the patience counter.

        Returns:
            ``True`` when training should stop, ``False`` otherwise (including
            when ``logs`` is ``None`` or lacks the monitored key).
        """
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            self._record_best(epoch, current, model)
            return False

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.wait = 0
            self._record_best(epoch, current, model)
            return False

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            if self.verbose:
                print(f"\nEarly stopping at epoch {epoch + 1}")
            return True
        return False

    def on_training_end(self, model=None):
        """Reload the best snapshot into ``model`` when one was recorded."""
        can_restore = (
            model is not None
            and self.restore_best_weights
            and self.best_weights is not None
        )
        if not can_restore:
            return

        model.load_state_dict(self.best_weights)
        # Report the restore only when it really happened after an early stop.
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.best_epoch + 1}")

class ModelCheckpoint(Callback):
    """Save the model's ``state_dict`` at the end of epochs.

    Args:
        filepath: Destination file. With ``save_best_only=False`` an
            ``_epoch_<n>`` suffix is inserted before the file extension.
        save_best_only: When true, write only if the monitored metric improved
            (the first observed value always counts as an improvement).
        monitor: Key looked up in the ``logs`` dict each epoch. Names containing
            ``'loss'`` are minimised; every other metric is maximised.
        verbose: Truthy to print a line for every file written.
    """
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None
        self.mode = 'min' if 'loss' in monitor else 'max'

    def _save(self, model, path):
        """Write ``model.state_dict()`` to ``path`` and optionally report it."""
        torch.save(model.state_dict(), path)
        if self.verbose:
            print(f"\nModel saved to {path}")

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Save a checkpoint for this epoch if the configured policy asks for one.

        Does nothing when ``logs`` or ``model`` is ``None`` or when ``logs``
        lacks the monitored key.
        """
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.save_best_only:
            # Split off the real extension instead of string-replacing '.pth':
            # this works for any extension and cannot touch '.pth' occurring
            # elsewhere in the path.
            root, ext = os.path.splitext(self.filepath)
            self._save(model, f"{root}_epoch_{epoch + 1}{ext}")
            return

        if self.best is None:
            improved = True
        elif self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self._save(model, self.filepath)

# ==========================================
# Data Preprocessing
# ==========================================

def download_and_unzip(url, dest_path):
    """Fetch a zip archive from ``url`` and unpack it next to ``dest_path``.

    Nothing is downloaded when ``dest_path`` already exists. The archive is
    streamed to ``<dest_path>.zip``, extracted into the parent directory of
    ``dest_path`` and the zip file is always deleted afterwards.

    Returns:
        ``True`` if the directory was already present or the download and
        extraction finished, ``False`` if ``requests`` reported an error.
        Other exceptions propagate to the caller.
    """
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    archive_path = dest_path.rstrip('/') + ".zip"
    # Extract into the parent so the archive's own top-level folder is not nested.
    parent_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading dataset from {url}...")
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            expected_bytes = int(response.headers.get('content-length', 0))
            with open(archive_path, 'wb') as sink, tqdm(
                total=expected_bytes, unit='iB', unit_scale=True, desc="Tiny ImageNet"
            ) as bar:
                for piece in response.iter_content(chunk_size=8192):
                    sink.write(piece)
                    bar.update(len(piece))

        print("Unzipping...")
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(parent_dir)
    except requests.exceptions.RequestException as err:
        print(f"\nError downloading file: {err}", file=sys.stderr)
        print("Please check your internet connection or the dataset URL.", file=sys.stderr)
        return False
    else:
        return True
    finally:
        # Remove the (possibly partial) archive on every exit path.
        if os.path.exists(archive_path):
            os.remove(archive_path)

def prepare_val_folder(val_dir):
    """Rearrange the validation folder into one sub-directory per class.

    Reads ``val_annotations.txt`` inside ``val_dir`` (tab-separated, image file
    name in the first column and class id in the second), moves each listed
    image from ``val_dir/images`` into ``val_dir/<class id>/`` and finally
    removes the annotation file and the ``images`` directory. If the annotation
    file is absent the folder is treated as already formatted and nothing is
    done.

    Blank or malformed annotation lines (fewer than two non-empty fields) are
    skipped instead of raising ``IndexError``.
    """
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')
    if not os.path.exists(val_annotations_path):
        return  # Already formatted

    print("Formatting validation folder...")
    val_img_dir = os.path.join(val_dir, 'images')

    # image file name -> class id; a later duplicate entry overrides an earlier one.
    class_of_image = {}
    with open(val_annotations_path, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 2 or not fields[0] or not fields[1]:
                continue  # blank line or incomplete record
            class_of_image[fields[0]] = fields[1]

    for img_name, class_id in class_of_image.items():
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)
        src_path = os.path.join(val_img_dir, img_name)
        # Images that are listed but missing on disk are silently ignored.
        if os.path.exists(src_path):
            shutil.move(src_path, os.path.join(class_dir, img_name))

    # Removing the annotation file marks the folder as formatted for later calls.
    if os.path.exists(val_annotations_path):
        os.remove(val_annotations_path)
    if os.path.exists(val_img_dir):
        shutil.rmtree(val_img_dir)

def _seed_worker(base_seed, worker_id):
    """Seed NumPy inside a DataLoader worker with ``base_seed + worker_id``.

    Defined at module level (rather than as a closure) so it can be pickled
    when worker processes are started with the ``spawn`` method.
    """
    np.random.seed(base_seed + worker_id)


def get_dataloaders(config):
    """Download, prepare and split the data, returning three DataLoaders.

    The ``train`` and ``val`` image folders under ``config["data_path"]`` are
    concatenated into one dataset, which is then re-split into train /
    validation / test subsets according to ``config["train_split"]``,
    ``config["val_split"]`` and ``config["test_split"]`` using a generator
    seeded with ``config["split_seed"]``.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or ``None`` when the
        dataset could not be downloaded.
    """
    from functools import partial  # local import: only needed by this function

    if not download_and_unzip(config["dataset_url"], config["data_path"]):
        return None  # Return None if download fails

    train_dir = os.path.join(config["data_path"], 'train')
    val_dir = os.path.join(config["data_path"], 'val')
    prepare_val_folder(val_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = ImageFolder(train_dir, transform=transform)
    val_dataset_original = ImageFolder(val_dir, transform=transform)
    full_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_original])

    # Ensure reproducible splits
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)

    # int() truncates, so hand any leftover samples to the test split to make
    # the three sizes sum to total_size (random_split requires that).
    test_size += total_size - (train_size + val_size + test_size)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Dedicated generator so the split does not depend on global RNG state.
    generator = torch.Generator().manual_seed(config["split_seed"])
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=generator
    )

    # Picklable per-worker seeding hook (a nested function would fail to
    # pickle when workers are spawned instead of forked).
    worker_init_fn = partial(_seed_worker, config["split_seed"])

    train_loader = DataLoader(
        train_subset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        # Seeded generator makes the shuffling order reproducible.
        generator=torch.Generator().manual_seed(config["split_seed"])
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    test_loader = DataLoader(
        test_subset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class NonResidualBlock(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages in sequence, with no skip connection.

    Only the first convolution applies ``stride``; neither convolution has a
    bias term.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class ResNetTransformer(nn.Module):
    """CNN backbone whose intermediate feature maps are fed to a Transformer encoder.

    One feature map is captured after the stem and one after each of the four
    stages. Each map is globally average-pooled and linearly projected to an
    ``embedding_dim`` token; a learnable positional embedding is added, the
    token sequence is encoded, mean-pooled and classified.

    Args:
        block: Callable ``block(in_channels, out_channels, stride)`` returning
            an ``nn.Module`` used to build every stage.
        layers: Number of blocks per stage; four entries are read.
        num_classes: Size of the output logit vector.
        t_config: Dict with ``embedding_dim``, ``nhead``, ``dim_feedforward``,
            ``dropout`` and ``num_encoder_layers``.
    """
    def __init__(self, block, layers, num_classes, t_config):
        super().__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]
        self.num_tokens = len(layers) + 1  # one token per stage plus one for the stem

        # --- CNN backbone -------------------------------------------------
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # --- One projection head per captured feature map -----------------
        # Channel widths in capture order: stem, layer1, layer2, layer3, layer4.
        self.projections = nn.ModuleList(
            [self._create_projection(width, self.embedding_dim)
             for width in (64, 64, 128, 256, 512)]
        )

        # --- Learnable positional embedding, broadcast over the batch -----
        self.positional_embedding = nn.Parameter(
            torch.randn(1, self.num_tokens, self.embedding_dim)
        )

        # --- Transformer encoder -------------------------------------------
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embedding_dim,
                nhead=t_config["nhead"],
                dim_feedforward=t_config["dim_feedforward"],
                dropout=t_config["dropout"],
                batch_first=True,
            ),
            num_layers=t_config["num_encoder_layers"],
        )

        # --- Classification head -------------------------------------------
        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        """Build the head that turns a feature map into one token vector."""
        pool = nn.AdaptiveAvgPool2d((1, 1))
        flatten = nn.Flatten()
        project = nn.Linear(in_features, out_features)
        return nn.Sequential(pool, flatten, project)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``.

        Updates ``self.in_channels`` so the next stage starts from this
        stage's output width.
        """
        stage = []
        in_ch = self.in_channels
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(in_ch, out_channels, block_stride))
            in_ch = out_channels
        self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # Stem, then collect the output of the stem and of every stage.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = [x]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        # One token per captured feature map, stacked along a sequence axis.
        tokens = [project(fmap) for project, fmap in zip(self.projections, features)]
        token_sequence = torch.stack(tokens, dim=1) + self.positional_embedding

        # Encode, average over the tokens and classify.
        encoded = self.transformer_encoder(token_sequence)
        return self.classifier(encoded.mean(dim=1))

def evaluate_model(model, data_loader, criterion, device):
    """Compute the average loss and top-1 accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean of the per-batch
        ``criterion`` values and the accuracy as a percentage.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Weight the batch loss by the batch size so the final division
            # yields a per-sample average even with a short last batch.
            loss_sum += criterion(outputs, labels).item() * images.size(0)
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return loss_sum / n_seen, 100 * n_correct / n_seen

def train_and_validate(model, train_loader, val_loader, config):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    After each epoch the validation accuracy is handed, in this order, to a
    learning-rate reducer, an early-stopping monitor and a best-model
    checkpointer. When early stopping fires, the checkpointer is skipped for
    that epoch and training ends.

    Args:
        model: Network to optimise; moved to ``config["device"]``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        config: Dict providing ``device``, ``split_seed``, ``learning_rate``,
            ``weight_decay``, ``num_epochs`` and ``checkpoint_path``.

    Returns:
        The same ``model`` object after training and callback clean-up.
    """
    device = config["device"]
    model.to(device)

    # Seed every RNG so repeated runs start from the same random state.
    seed = config["split_seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )

    # All three callbacks watch the validation accuracy.
    lr_reducer = ReduceLROnPlateau(
        optimizer=optimizer, monitor='val_accuracy', factor=0.2,
        patience=3, min_lr=1e-7, verbose=1,
    )
    early_stopper = EarlyStopping(
        monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1,
    )
    checkpointer = ModelCheckpoint(
        filepath=config["checkpoint_path"], save_best_only=True,
        monitor='val_accuracy', verbose=1,
    )

    print("\n--- Starting Training ---")
    num_epochs = config["num_epochs"]

    for epoch in range(num_epochs):
        model.train()
        epoch_start = time.time()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
        for images, labels in batches:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running statistics for the epoch summary.
            batch_loss = loss.item()
            loss_sum += batch_loss * images.size(0)
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            batches.set_postfix(loss=batch_loss)

        train_loss = loss_sum / n_seen
        train_acc = 100 * n_correct / n_seen
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs} | Time: {time.time() - epoch_start:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Callbacks run in a fixed order; a stop request ends training before
        # the checkpointer is consulted for this epoch.
        lr_reducer.on_epoch_end(epoch, logs)
        if early_stopper.on_epoch_end(epoch, logs, model):
            break
        checkpointer.on_epoch_end(epoch, logs, model)

    # End-of-training hooks; only the early stopper receives the model.
    lr_reducer.on_training_end()
    early_stopper.on_training_end(model)
    checkpointer.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=(1, 3, 5)):
    """Calculate top-k accuracy for the given k values.

    Args:
        outputs: 2-D score tensor with one row per sample and one column per class.
        labels: 1-D tensor of true class indices, one per row of ``outputs``.
        k_values: Iterable of k values to evaluate. A k larger than the number
            of classes is evaluated over all classes instead of raising.

    Returns:
        Dict mapping each k to the fraction of samples whose label is among the
        k highest-scoring classes. An empty ``k_values`` yields ``{}``; an
        empty batch yields ``0.0`` for every k.
    """
    k_values = list(k_values)
    if not k_values:
        return {}

    batch_size = labels.size(0)
    if batch_size == 0:
        return {k: 0.0 for k in k_values}

    # topk cannot return more entries than there are classes.
    max_k = min(max(k_values), outputs.size(1))
    _, pred = outputs.topk(max_k, dim=1, largest=True, sorted=True)

    # hits[i, j] is True when the j-th ranked class of sample i is its label.
    hits = pred.eq(labels.view(-1, 1))

    return {k: hits[:, :k].float().sum().item() / batch_size for k in k_values}

def calculate_entropy(probs):
    """Return the entropy (natural log) of each row of ``probs``.

    A constant ``1e-8`` is added to every probability before taking the
    logarithm so that zero entries do not produce ``log(0)``.
    """
    shifted = probs + 1e-8
    return -(shifted * torch.log(shifted)).sum(dim=1)

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate the Expected Calibration Error.

    Confidences are grouped into ``n_bins`` equal-width bins ``(lower, upper]``
    over ``[0, 1]``. For each populated bin the absolute gap between mean
    confidence and mean accuracy is weighted by the fraction of samples in
    that bin, and the weighted gaps are summed.

    Args:
        confidences: 1-D floating-point tensor of prediction confidences.
        accuracies: 1-D tensor of per-sample correctness (1 / 0), same length.
        n_bins: Number of equal-width confidence bins.

    Returns:
        The ECE as a Python float; ``0.0`` when the input is empty or no
        sample falls into any bin.
    """
    # An empty input would make every bin proportion NaN; define the ECE as 0.
    if confidences.numel() == 0:
        return 0.0

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    lowers = bin_boundaries[:-1].tolist()
    uppers = bin_boundaries[1:].tolist()

    # Accumulate as a plain float so the result is well defined even when no
    # bin is populated (a bare ``0`` has no ``.item()``).
    ece = 0.0
    for lower, upper in zip(lowers, uppers):
        in_bin = (confidences > lower) & (confidences <= upper)
        prop_in_bin = in_bin.float().mean().item()

        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin).item() * prop_in_bin

    return ece

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and collect a dictionary of metrics.

    Returns:
        Dict with the keys ``test_loss``, ``top1``, ``top3``, ``top5``,
        ``precision_macro``, ``recall_macro``, ``f1_macro``,
        ``precision_weighted``, ``recall_weighted``, ``f1_weighted``,
        ``avg_confidence``, ``std_confidence``, ``avg_entropy`` and ``ece``.
        Accuracies are fractions in ``[0, 1]``, not percentages.
    """
    model.eval()

    ks = (1, 3, 5)
    preds, targets, confs, ents = [], [], [], []
    # Each batch's top-k accuracy is stored once per sample so that a plain
    # mean over the list is weighted by batch size.
    per_sample_top_k = {k: [] for k in ks}
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Softmax probabilities; the arg-max class is the prediction and
            # its probability is the confidence.
            probs = torch.softmax(outputs, dim=1)
            batch_conf, batch_pred = torch.max(probs, dim=1)
            batch_entropy = calculate_entropy(probs)
            batch_top_k = calculate_top_k_accuracy(outputs, labels, [1, 3, 5])

            preds.extend(batch_pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            confs.extend(batch_conf.cpu().numpy())
            ents.extend(batch_entropy.cpu().numpy())

            for k in ks:
                per_sample_top_k[k].extend([batch_top_k[k]] * labels.size(0))

            loss_sum += loss.item() * images.size(0)
            n_samples += images.size(0)

    all_predictions = np.array(preds)
    all_labels = np.array(targets)
    all_confidences = np.array(confs)
    all_entropies = np.array(ents)

    # Per-class precision / recall / F1, averaged two ways.
    macro = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    weighted = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calibration: compare confidence with per-sample correctness.
    correctness = (all_predictions == all_labels).astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(correctness))

    return {
        'test_loss': loss_sum / n_samples,
        'top1': np.mean(all_predictions == all_labels),
        'top3': np.mean(per_sample_top_k[3]),
        'top5': np.mean(per_sample_top_k[5]),
        'precision_macro': macro[0],
        'recall_macro': macro[1],
        'f1_macro': macro[2],
        'precision_weighted': weighted[0],
        'recall_weighted': weighted[1],
        'f1_weighted': weighted[2],
        'avg_confidence': np.mean(all_confidences),
        'std_confidence': np.std(all_confidences),
        'avg_entropy': np.mean(all_entropies),
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

if __name__ == '__main__':
    # Seed every RNG in use and, on CUDA, ask cuDNN for deterministic kernels.
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Data: abort early when the dataset could not be prepared.
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)

    train_loader, val_loader, test_loader = dataloaders

    # Model: four stages of two non-residual blocks each, plus the Transformer head.
    model = ResNetTransformer(
        NonResidualBlock,
        [2, 2, 2, 2],
        CONFIG["num_classes"],
        TRANSFORMER_CONFIG,
    )

    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Final report on the held-out test split.
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    print("\n--- Top-K Accuracy Results ---")
    print("\n".join(
        f"Top-{k} Accuracy: {results[f'top{k}'] * 100:.2f}%" for k in (1, 3, 5)
    ))

    print("\n--- Additional Performance Metrics ---")
    print("\n".join(
        f"{averaging} Average {title}: {results[f'{metric}_{averaging.lower()}']:.4f}"
        for averaging in ("Macro", "Weighted")
        for title, metric in (("Precision", "precision"), ("Recall", "recall"), ("F1-Score", "f1"))
    ))

    print("\n--- Model Confidence & Calibration Metrics ---")
    print("\n".join(
        f"{title}: {results[key]:.4f}"
        for title, key in (
            ("Average Prediction Confidence", 'avg_confidence'),
            ("Confidence Standard Deviation", 'std_confidence'),
            ("Average Prediction Entropy", 'avg_entropy'),
            ("Expected Calibration Error", 'ece'),
        )
    ))

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Dataset directory already exists.
Dataset splits - Train: 77000, Val: 16500, Test: 16500, Total: 110000
Model has 12900104 parameters
Trainable parameters: 12900104

--- Starting Training ---




Epoch 1/100 | Time: 57.38s | Train Loss: 4.8858, Train Acc: 2.98% | Val Loss: 4.5156, Val Acc: 5.65%

Model saved to ./best_model.pth




Epoch 2/100 | Time: 56.96s | Train Loss: 4.3319, Train Acc: 7.71% | Val Loss: 4.2834, Val Acc: 8.63%

Model saved to ./best_model.pth




Epoch 3/100 | Time: 55.97s | Train Loss: 4.0589, Train Acc: 11.17% | Val Loss: 4.0407, Val Acc: 12.13%

Model saved to ./best_model.pth




Epoch 4/100 | Time: 56.58s | Train Loss: 3.8711, Train Acc: 13.77% | Val Loss: 3.8111, Val Acc: 15.13%

Model saved to ./best_model.pth




Epoch 5/100 | Time: 56.66s | Train Loss: 3.7228, Train Acc: 15.74% | Val Loss: 3.7984, Val Acc: 15.27%

Model saved to ./best_model.pth




Epoch 6/100 | Time: 56.88s | Train Loss: 3.6071, Train Acc: 17.61% | Val Loss: 3.6304, Val Acc: 17.72%

Model saved to ./best_model.pth




Epoch 7/100 | Time: 58.70s | Train Loss: 3.5116, Train Acc: 19.01% | Val Loss: 3.5444, Val Acc: 18.79%

Model saved to ./best_model.pth




Epoch 8/100 | Time: 57.01s | Train Loss: 3.4194, Train Acc: 20.63% | Val Loss: 3.4375, Val Acc: 20.79%

Model saved to ./best_model.pth




Epoch 9/100 | Time: 57.93s | Train Loss: 3.3387, Train Acc: 22.04% | Val Loss: 3.3985, Val Acc: 21.70%

Model saved to ./best_model.pth




Epoch 10/100 | Time: 57.24s | Train Loss: 3.2655, Train Acc: 23.35% | Val Loss: 3.3036, Val Acc: 23.50%

Model saved to ./best_model.pth




Epoch 11/100 | Time: 57.30s | Train Loss: 3.1896, Train Acc: 24.60% | Val Loss: 3.4155, Val Acc: 21.87%




Epoch 12/100 | Time: 54.88s | Train Loss: 3.1207, Train Acc: 25.86% | Val Loss: 3.2776, Val Acc: 24.47%

Model saved to ./best_model.pth




Epoch 13/100 | Time: 54.74s | Train Loss: 3.0619, Train Acc: 26.97% | Val Loss: 3.3179, Val Acc: 24.15%




Epoch 14/100 | Time: 54.83s | Train Loss: 3.0029, Train Acc: 28.11% | Val Loss: 3.1880, Val Acc: 26.23%

Model saved to ./best_model.pth




Epoch 15/100 | Time: 55.68s | Train Loss: 2.9450, Train Acc: 29.07% | Val Loss: 3.1501, Val Acc: 26.56%

Model saved to ./best_model.pth




Epoch 16/100 | Time: 54.99s | Train Loss: 2.8991, Train Acc: 29.82% | Val Loss: 3.1747, Val Acc: 26.26%




Epoch 17/100 | Time: 55.25s | Train Loss: 2.8561, Train Acc: 30.95% | Val Loss: 3.0685, Val Acc: 27.82%

Model saved to ./best_model.pth




Epoch 18/100 | Time: 55.52s | Train Loss: 2.8069, Train Acc: 31.97% | Val Loss: 3.1760, Val Acc: 26.43%




Epoch 19/100 | Time: 55.14s | Train Loss: 2.7681, Train Acc: 32.63% | Val Loss: 3.0172, Val Acc: 29.19%

Model saved to ./best_model.pth




Epoch 20/100 | Time: 54.87s | Train Loss: 2.7274, Train Acc: 33.13% | Val Loss: 3.0236, Val Acc: 29.43%

Model saved to ./best_model.pth




Epoch 21/100 | Time: 57.18s | Train Loss: 2.6895, Train Acc: 33.89% | Val Loss: 3.0847, Val Acc: 28.85%




Epoch 22/100 | Time: 55.97s | Train Loss: 2.6575, Train Acc: 34.45% | Val Loss: 2.9994, Val Acc: 30.16%

Model saved to ./best_model.pth




Epoch 23/100 | Time: 57.67s | Train Loss: 2.6177, Train Acc: 35.26% | Val Loss: 3.1152, Val Acc: 28.68%




Epoch 24/100 | Time: 55.73s | Train Loss: 2.5795, Train Acc: 36.06% | Val Loss: 3.0298, Val Acc: 29.99%




Epoch 25/100 | Time: 55.43s | Train Loss: 2.5471, Train Acc: 36.73% | Val Loss: 2.9672, Val Acc: 31.32%

Model saved to ./best_model.pth




Epoch 26/100 | Time: 57.04s | Train Loss: 2.5091, Train Acc: 37.59% | Val Loss: 3.0935, Val Acc: 29.93%




Epoch 27/100 | Time: 55.29s | Train Loss: 2.4812, Train Acc: 37.98% | Val Loss: 3.0095, Val Acc: 30.94%




Epoch 28/100 | Time: 55.21s | Train Loss: 2.4541, Train Acc: 38.50% | Val Loss: 3.0848, Val Acc: 30.09%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 29/100 | Time: 56.34s | Train Loss: 2.1303, Train Acc: 45.61% | Val Loss: 2.8647, Val Acc: 34.22%

Model saved to ./best_model.pth




Epoch 30/100 | Time: 55.55s | Train Loss: 2.0564, Train Acc: 47.21% | Val Loss: 2.9102, Val Acc: 34.19%




Epoch 31/100 | Time: 56.13s | Train Loss: 2.0247, Train Acc: 47.82% | Val Loss: 2.9045, Val Acc: 34.30%

Model saved to ./best_model.pth




Epoch 32/100 | Time: 55.28s | Train Loss: 1.9944, Train Acc: 48.44% | Val Loss: 2.9327, Val Acc: 33.94%




Epoch 33/100 | Time: 55.56s | Train Loss: 1.9684, Train Acc: 48.86% | Val Loss: 2.9664, Val Acc: 33.72%




Epoch 34/100 | Time: 56.70s | Train Loss: 1.9476, Train Acc: 49.39% | Val Loss: 2.9740, Val Acc: 34.08%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 35/100 | Time: 55.69s | Train Loss: 1.8398, Train Acc: 51.87% | Val Loss: 2.9664, Val Acc: 34.37%

Model saved to ./best_model.pth




Epoch 36/100 | Time: 55.95s | Train Loss: 1.8250, Train Acc: 52.30% | Val Loss: 2.9828, Val Acc: 34.33%




Epoch 37/100 | Time: 56.38s | Train Loss: 1.8183, Train Acc: 52.55% | Val Loss: 2.9748, Val Acc: 34.42%

Model saved to ./best_model.pth




Epoch 38/100 | Time: 55.31s | Train Loss: 1.8119, Train Acc: 52.46% | Val Loss: 2.9912, Val Acc: 34.42%

Model saved to ./best_model.pth




Epoch 39/100 | Time: 55.38s | Train Loss: 1.8051, Train Acc: 52.81% | Val Loss: 3.0024, Val Acc: 34.37%




Epoch 40/100 | Time: 56.09s | Train Loss: 1.7951, Train Acc: 53.08% | Val Loss: 3.0141, Val Acc: 34.36%




Epoch 41/100 | Time: 55.78s | Train Loss: 1.7916, Train Acc: 53.03% | Val Loss: 3.0204, Val Acc: 34.41%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 42/100 | Time: 55.95s | Train Loss: 1.7643, Train Acc: 53.75% | Val Loss: 3.0177, Val Acc: 34.45%

Model saved to ./best_model.pth




Epoch 43/100 | Time: 55.24s | Train Loss: 1.7602, Train Acc: 53.73% | Val Loss: 3.0195, Val Acc: 34.33%




Epoch 44/100 | Time: 54.95s | Train Loss: 1.7569, Train Acc: 54.06% | Val Loss: 3.0262, Val Acc: 34.49%

Model saved to ./best_model.pth




Epoch 45/100 | Time: 55.98s | Train Loss: 1.7577, Train Acc: 53.96% | Val Loss: 3.0325, Val Acc: 34.19%




Epoch 46/100 | Time: 56.00s | Train Loss: 1.7594, Train Acc: 53.85% | Val Loss: 3.0287, Val Acc: 34.29%




Epoch 47/100 | Time: 55.02s | Train Loss: 1.7579, Train Acc: 53.93% | Val Loss: 3.0331, Val Acc: 34.28%

Reducing learning rate from 8.00e-06 to 1.60e-06




Epoch 48/100 | Time: 55.04s | Train Loss: 1.7514, Train Acc: 54.12% | Val Loss: 3.0331, Val Acc: 34.35%




Epoch 49/100 | Time: 55.79s | Train Loss: 1.7485, Train Acc: 54.13% | Val Loss: 3.0332, Val Acc: 34.22%




Epoch 50/100 | Time: 55.01s | Train Loss: 1.7458, Train Acc: 54.09% | Val Loss: 3.0292, Val Acc: 34.19%

Reducing learning rate from 1.60e-06 to 3.20e-07




Epoch 51/100 | Time: 56.40s | Train Loss: 1.7511, Train Acc: 54.02% | Val Loss: 3.0348, Val Acc: 34.27%

Early stopping at epoch 51
Restored model weights from the end of the best epoch: 44
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 129/129 [00:08<00:00, 15.85it/s]



--- Top-K Accuracy Results ---
Top-1 Accuracy: 34.15%
Top-3 Accuracy: 52.64%
Top-5 Accuracy: 61.71%

--- Additional Performance Metrics ---
Macro Average Precision: 0.3345
Macro Average Recall: 0.3422
Macro Average F1-Score: 0.3352
Weighted Average Precision: 0.3371
Weighted Average Recall: 0.3415
Weighted Average F1-Score: 0.3362

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.5021
Confidence Standard Deviation: 0.2522
Average Prediction Entropy: 1.8057
Expected Calibration Error: 0.1606

Final Test Loss: 3.0428


## Hybrid Transformer on ResNet-18 (RoPE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
import requests
import zipfile
import shutil
from tqdm import tqdm
import time
import sys
import math
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# --- Configuration Parameters ---
# Run-wide settings for data preparation and training in this cell.
CONFIG = {
    # Directory get_dataloaders() expects to contain the 'train' and 'val' folders.
    "data_path": "./tiny-imagenet-200/",
    # Archive fetched by download_and_unzip() when data_path does not exist yet.
    "dataset_url": "http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    "batch_size": 128,  # Used for the train, val and test DataLoaders alike
    "num_epochs": 100,  # Upper bound; the EarlyStopping callback may end training sooner
    "learning_rate": 0.001,  # Initial Adam learning rate (lowered by ReduceLROnPlateau)
    "weight_decay": 1e-4,  # Passed to Adam as weight_decay
    "split_seed": 42,  # Seed for creating reproducible dataset splits (also reused to seed training)
    # Fractions of the concatenated train+val data assigned to each split;
    # any rounding remainder is added to the test split in get_dataloaders().
    "train_split": 0.70,
    "val_split": 0.15,
    "test_split": 0.15,  # Explicitly define test split for clarity
    "num_classes": 200,  # presumably the classifier output width -- the consumer is outside this section
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "checkpoint_path": "./best_model.pth",  # Where ModelCheckpoint writes the best state_dict
}

# --- Transformer Configuration ---
# Hyperparameters read by ResNetTransformer via its t_config argument.
TRANSFORMER_CONFIG = {
    "embedding_dim": 256,  # Token width; must be divisible by nhead (asserted in MultiHeadAttentionWithRoPE)
    "nhead": 8,  # Attention heads -> 256 / 8 = 32 dimensions per head
    "num_encoder_layers": 3,  # Number of stacked TransformerEncoderLayerWithRoPE modules
    "dim_feedforward": 512,  # Hidden width of each encoder layer's feed-forward sub-block
    "dropout": 0.1,  # Used for attention-weight, feed-forward and residual dropout
}

# ==========================================
# Callback Classes
# ==========================================

class Callback:
    """Base class for training callbacks; every hook is a no-op.

    Both hooks accept an optional ``model`` argument so that subclasses which
    need the model (to snapshot or save its weights) override a compatible
    signature and any callback can be invoked the same way.
    """
    def on_epoch_end(self, epoch, logs=None, model=None):
        """Hook called once after each epoch.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Optional dict of metric name -> value for that epoch.
            model: Optional model being trained; ignored by the base class.
        """
        pass

    def on_training_end(self, model=None):
        """Hook called once after the training loop has finished.

        Args:
            model: Optional model that was trained; ignored by the base class.
        """
        pass

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when a monitored metric has stopped improving.

    After ``patience`` consecutive epochs without improvement, the learning
    rate of every optimizer parameter group is multiplied by ``factor`` and
    clamped at ``min_lr``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            hold an ``'lr'`` entry.
        monitor: Key looked up in the ``logs`` dict passed to ``on_epoch_end``.
        factor: Multiplier applied to the learning rate on a plateau.
        patience: Non-improving epochs to wait before reducing.
        min_lr: Lower bound applied to the reduced learning rate.
        verbose: When truthy, print a line for every reduction.
    """
    def __init__(self, optimizer, monitor='val_accuracy', factor=0.2, patience=3, min_lr=1e-7, verbose=1):
        self.optimizer = optimizer
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.wait = 0      # consecutive epochs without improvement
        self.best = None   # best monitored value seen so far
        # Metrics whose name contains "loss" are minimised; everything else is maximised.
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Update the plateau counter and lower the learning rate when it expires.

        Args:
            epoch: Zero-based epoch index (unused).
            logs: Dict of metrics; the epoch is ignored when it is ``None`` or
                lacks the monitored key.
            model: Accepted for signature uniformity with other callbacks; unused.
        """
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        # The first observation only establishes the baseline.
        if self.best is None:
            self.best = current
            return

        if self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        # Scale every parameter group, not just the first one, so optimizers
        # built with several groups do not keep training part of the model
        # at the old learning rate.
        reduced = False
        for group in self.optimizer.param_groups:
            old_lr = group['lr']
            new_lr = max(old_lr * self.factor, self.min_lr)
            if new_lr == old_lr:
                continue  # this group already sits at the floor
            group['lr'] = new_lr
            reduced = True
            if self.verbose:
                print(f"\nReducing learning rate from {old_lr:.2e} to {new_lr:.2e}")

        # Restart the patience window only when something actually changed.
        if reduced:
            self.wait = 0

class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    Optionally keeps a copy of the model weights from the best epoch and
    loads them back into the model when training ends.

    Args:
        monitor: Key looked up in the ``logs`` dict passed to ``on_epoch_end``.
        patience: Non-improving epochs to tolerate before requesting a stop.
        restore_best_weights: Whether to snapshot and later restore the best weights.
        verbose: When truthy, print messages on stop / restore.
    """
    def __init__(self, monitor='val_accuracy', patience=7, restore_best_weights=True, verbose=1):
        self.monitor = monitor
        self.patience = patience
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.wait = 0             # consecutive epochs without improvement
        self.best = None          # best monitored value seen so far
        self.best_epoch = None    # zero-based epoch index at which `best` was observed
        self.best_weights = None  # independent copy of the state_dict from `best_epoch`
        # Metrics whose name contains "loss" are minimised; everything else is maximised.
        self.mode = 'min' if 'loss' in monitor else 'max'
        self.stopped_epoch = 0

    @staticmethod
    def _snapshot_weights(model):
        """Return a copy of ``model.state_dict()`` that owns its tensors.

        ``state_dict()`` hands out tensors that share storage with the live
        parameters, and ``dict.copy()`` is shallow, so a plain copy would be
        overwritten by later optimizer steps. Cloning each tensor freezes the
        values; replacing values in the returned mapping keeps its type and
        any metadata attached to it.
        """
        snapshot = model.state_dict()
        for name, value in list(snapshot.items()):
            if torch.is_tensor(value):
                snapshot[name] = value.detach().clone()
        return snapshot

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Record the epoch's metric and report whether training should stop.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Dict of metrics; the epoch is ignored when it is ``None`` or
                lacks the monitored key.
            model: Model to snapshot when the metric improves (optional).

        Returns:
            ``True`` when ``patience`` non-improving epochs have accumulated,
            otherwise ``False``.
        """
        if logs is None:
            return False

        current = logs.get(self.monitor)
        if current is None:
            return False

        if self.best is None:
            improved = True  # first observation establishes the baseline
        elif self.mode == 'min':
            improved = current < self.best
        else:
            improved = current > self.best

        if improved:
            self.best = current
            self.best_epoch = epoch
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = self._snapshot_weights(model)
            return False

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            if self.verbose:
                print(f"\nEarly stopping at epoch {epoch + 1}")
            return True
        return False

    def on_training_end(self, model=None):
        """Load the best snapshot back into ``model`` when one is available.

        Args:
            model: Model to restore; nothing happens when it is ``None``, when
                restoring is disabled, or when no snapshot was ever taken.
        """
        if model is None or not self.restore_best_weights or self.best_weights is None:
            return

        model.load_state_dict(self.best_weights)
        # Announce the restore only after an early stop, and name the epoch
        # that was actually recorded rather than deriving it from `patience`.
        if self.stopped_epoch > 0 and self.verbose:
            print(f"Restored model weights from the end of the best epoch: {self.best_epoch + 1}")

class ModelCheckpoint(Callback):
    """Write the model's ``state_dict`` to disk at the end of an epoch.

    With ``save_best_only=True`` a single file at ``filepath`` is overwritten
    whenever the monitored metric improves (the first observed epoch always
    counts). Otherwise every epoch is written to its own file whose name is
    derived from ``filepath`` by inserting ``_epoch_<n>`` before ``.pth``.
    """
    def __init__(self, filepath, save_best_only=True, monitor='val_accuracy', verbose=1):
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.monitor = monitor
        self.verbose = verbose
        self.best = None  # best monitored value that triggered a save so far
        # "loss"-like metrics are minimised; everything else is maximised.
        self.mode = 'min' if 'loss' in monitor else 'max'

    def on_epoch_end(self, epoch, logs=None, model=None):
        """Save the model for this epoch if the configured policy calls for it."""
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.save_best_only:
            if self.best is not None:
                if self.mode == 'min':
                    is_better = current < self.best
                else:
                    is_better = current > self.best
                if not is_better:
                    return  # nothing new worth keeping
            self.best = current
            target_path = self.filepath
        else:
            # One file per epoch, numbered from 1.
            target_path = self.filepath.replace('.pth', f'_epoch_{epoch + 1}.pth')

        torch.save(model.state_dict(), target_path)
        if self.verbose:
            print(f"\nModel saved to {target_path}")
# ==========================================
# Data Preprocessing
# ==========================================

def download_and_unzip(url, dest_path):
    """Download a zip archive and extract it next to ``dest_path``.

    The archive is saved as ``<dest_path>.zip``, extracted into the parent
    directory of ``dest_path`` and then deleted, whether or not the download
    and extraction succeeded.

    Args:
        url: Location of the zip archive.
        dest_path: Directory the archive is expected to create. If it already
            exists, nothing is downloaded.

    Returns:
        ``True`` if ``dest_path`` already existed or the archive was
        downloaded and extracted; ``False`` if the download, the write to
        disk, or the extraction failed.
    """
    if os.path.exists(dest_path):
        print("Dataset directory already exists.")
        return True

    zip_path = dest_path.rstrip('/') + ".zip"
    # The parent directory where we will extract the zip file
    extract_to_dir = os.path.abspath(os.path.join(dest_path, os.pardir))

    print(f"Downloading dataset from {url}...")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(zip_path, 'wb') as f, tqdm(
                # An unknown length is passed as None so tqdm shows a plain
                # counter instead of a bar with a total of zero.
                total=total_size or None, unit='iB', unit_scale=True, desc="Tiny ImageNet"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    progress_bar.update(len(chunk))

        print("Unzipping...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to_dir)
        return True

    # Must precede the OSError handler: requests' exceptions derive from IOError.
    except requests.exceptions.RequestException as e:
        print(f"\nError downloading file: {e}", file=sys.stderr)
        print("Please check your internet connection or the dataset URL.", file=sys.stderr)
        return False

    except (zipfile.BadZipFile, OSError) as e:
        print(f"\nError saving or extracting the archive: {e}", file=sys.stderr)
        # dest_path did not exist when this call started, so anything there
        # now is a partial extraction; remove it so the next call retries
        # instead of reporting that the dataset already exists.
        shutil.rmtree(dest_path, ignore_errors=True)
        return False

    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)  # Clean up the zip file

def prepare_val_folder(val_dir):
    """Rearrange the validation folder into one sub-directory per class.

    Reads ``val_annotations.txt`` (tab-separated; the first column is the
    image file name, the second the class id), moves every listed image from
    ``<val_dir>/images`` into ``<val_dir>/<class_id>/``, then deletes the
    annotation file and the ``images`` directory. If the annotation file is
    absent the folder is assumed to be formatted already and nothing is done.

    Args:
        val_dir: Path to the validation directory.
    """
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')
    if not os.path.exists(val_annotations_path):
        return  # Already formatted

    print("Formatting validation folder...")
    val_img_dir = os.path.join(val_dir, 'images')
    val_img_dict = {}
    with open(val_annotations_path, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                # Blank or malformed line: skip it instead of failing with
                # an IndexError halfway through the reorganisation.
                continue
            img_name, class_id = parts[0], parts[1]
            val_img_dict[img_name] = class_id

    for img_name, class_id in val_img_dict.items():
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)
        src_path = os.path.join(val_img_dir, img_name)
        if os.path.exists(src_path):
            shutil.move(src_path, os.path.join(class_dir, img_name))

    # Removing the annotation file is what marks the folder as formatted.
    if os.path.exists(val_annotations_path):
        os.remove(val_annotations_path)
    if os.path.exists(val_img_dir):
        shutil.rmtree(val_img_dir)

def get_dataloaders(config):
    """Build train / validation / test DataLoaders for the dataset.

    Makes sure the dataset is present and the validation folder is in
    ImageFolder layout, merges the original train and val folders into one
    dataset, and re-splits it with a seeded generator according to the
    ``train_split`` / ``val_split`` / ``test_split`` fractions in ``config``.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or ``None`` when the
        dataset could not be downloaded.
    """
    if not download_and_unzip(config["dataset_url"], config["data_path"]):
        return None  # Download failed; nothing to load

    data_root = config["data_path"]
    train_dir = os.path.join(data_root, 'train')
    val_dir = os.path.join(data_root, 'val')
    prepare_val_folder(val_dir)

    # Tensor conversion followed by per-channel normalisation; no augmentation.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    full_dataset = torch.utils.data.ConcatDataset([
        ImageFolder(train_dir, transform=preprocessing),
        ImageFolder(val_dir, transform=preprocessing),
    ])

    # Seed the global RNGs before splitting
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])

    total_size = len(full_dataset)
    train_size = int(config["train_split"] * total_size)
    val_size = int(config["val_split"] * total_size)
    test_size = int(config["test_split"] * total_size)
    # Truncation may leave a few samples unassigned; they go to the test split.
    test_size += total_size - (train_size + val_size + test_size)

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}, Total: {total_size}")

    # Dedicated generator so the split does not depend on global RNG state
    split_rng = torch.Generator().manual_seed(config["split_seed"])
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_rng
    )

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic NumPy seed.
        np.random.seed(config["split_seed"] + worker_id)

    def _build_loader(subset, **loader_kwargs):
        # Settings shared by all three loaders live here.
        return DataLoader(
            subset,
            batch_size=config["batch_size"],
            num_workers=2,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
            **loader_kwargs
        )

    train_loader = _build_loader(
        train_subset,
        shuffle=True,
        generator=torch.Generator().manual_seed(config["split_seed"]),
    )
    val_loader = _build_loader(val_subset, shuffle=False)
    test_loader = _build_loader(test_subset, shuffle=False)

    return train_loader, val_loader, test_loader

# ==========================================
# Model Architecture and Training
# ==========================================

class RotaryEmbedding(nn.Module):
    """
    The Rotary Positional Embedding (RoPE) module.
    This implementation is based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".

    The module only produces the cosine / sine tables; applying them to the
    query and key tensors is left to the caller.

    Args:
        dim: Width of the vectors that will be rotated (the per-head dimension).
        base: Base of the geometric progression of rotation frequencies.
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # Create inverse frequencies and register as a buffer
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Lazily built tables. They are plain attributes, not buffers, so
        # they do not follow the module through .to()/.cuda(); forward()
        # therefore re-validates them on every call.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: Tensor whose dimension 1 is the sequence length and whose
                device decides where the tables are placed.

        Returns:
            Two tensors of shape ``[seq_len, 1, 2 * len(inv_freq)]``.
        """
        seq_len = x.shape[1]
        # Rebuild when the length changes, and also when the input device or
        # the module's dtype has changed since the tables were cached;
        # otherwise tables from the previous device/dtype would be returned.
        cache_is_stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
            or self.cos_cached.dtype != self.inv_freq.dtype
        )
        if cache_is_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Each frequency appears twice so the table lines up with the two
            # halves that rotate_half() swaps.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, :]
            self.sin_cached = emb.sin()[:, None, :]

        return self.cos_cached, self.sin_cached

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split along the last axis this returns ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    last_axis = front.ndim - 1
    return torch.cat((-back, front), dim=last_axis)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate the query and key tensors with the given cos / sin tables.

    Both tensors get the same transform, ``t * cos + rotate_half(t) * sin``,
    and are returned as a ``(q_rotated, k_rotated)`` pair.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated

class MultiHeadAttentionWithRoPE(nn.Module):
    """Multi-head scaled dot-product attention with rotary position embeddings.

    Queries and keys are rotated with RoPE tables sized from the query's
    sequence length before the attention scores are computed.

    Notes on the arguments of ``forward``:
      * Inputs are always treated as ``[batch, seq, d_model]``; ``batch_first``
        is stored on the module but not consulted.
      * The key / value projections are reshaped with the query's sequence
        length, so all three inputs must share that length.
      * ``attn_mask`` entries equal to 0 are masked out; ``key_padding_mask``
        entries that are True are masked out.
      * ``is_causal`` is accepted but has no effect.
    """
    def __init__(self, d_model, nhead, dropout=0.1, batch_first=True):
        super().__init__()
        assert d_model % nhead == 0

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead  # width of a single head
        self.batch_first = batch_first

        # Input projections carry no bias; the output projection does.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.rotary_emb = RotaryEmbedding(self.d_k)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, is_causal=False):
        """Attend over ``key`` / ``value`` with ``query``.

        Returns the attended output ``[batch, seq, d_model]``; when
        ``need_weights`` is true, a tuple of that output and the attention
        weights (after dropout) averaged over heads.
        """
        batch_size, seq_len, _ = query.size()
        per_head_shape = (batch_size, seq_len, self.nhead, self.d_k)

        # Project, then split the model dimension into heads: [batch, seq, heads, d_k]
        q_heads = self.w_q(query).view(per_head_shape)
        k_heads = self.w_k(key).view(per_head_shape)
        v_heads = self.w_v(value).view(per_head_shape)

        # Positional rotation is applied to queries and keys only.
        cos, sin = self.rotary_emb(query)
        q_heads, k_heads = apply_rotary_pos_emb(q_heads, k_heads, cos, sin)

        # Bring heads in front of the sequence axis: [batch, heads, seq, d_k]
        q_heads = q_heads.transpose(1, 2)
        k_heads = k_heads.transpose(1, 2)
        v_heads = v_heads.transpose(1, 2)

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / scale

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, -1e9)

        if key_padding_mask is not None:
            # [batch, seq] -> [batch, 1, 1, seq] so it broadcasts over heads and queries
            padding = key_padding_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(padding, -1e9)

        attention_weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back into d_model.
        context = torch.matmul(attention_weights, v_heads)  # [batch, heads, seq, d_k]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)

        output = self.w_o(context)

        if not need_weights:
            return output
        return output, attention_weights.mean(dim=1)  # Average over heads for compatibility

class TransformerEncoderLayerWithRoPE(nn.Module):
    """
    Transformer encoder layer whose self-attention uses RoPE.

    Each of the two sub-blocks (self-attention, feed-forward) is followed by
    dropout, a residual addition and then layer normalisation.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first=True):
        super().__init__()
        self.self_attn = MultiHeadAttentionWithRoPE(d_model, nhead, dropout, batch_first)
        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Separate dropouts for the two residual branches
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)
        self.batch_first = batch_first

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Run one encoder layer over ``src``; ``is_causal`` is accepted but unused."""
        # --- Self-attention sub-block ---
        attended = self.self_attn(
            src, src, src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
        )
        src = self.norm1(src + self.dropout1(attended))

        # --- Feed-forward sub-block ---
        hidden = self.linear1(src)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        projected = self.linear2(hidden)
        return self.norm2(src + self.dropout2(projected))

class NonResidualBlock(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages without a skip connection.

    Only the first convolution applies ``stride``; the second keeps the
    spatial size and channel count.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)  # shared by both stages
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply both stages to ``x`` and return the activated feature map."""
        first_stage = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(first_stage)))

class ResNetTransformer(nn.Module):
    """ResNet-style convolutional backbone whose stage outputs feed a transformer.

    The stem output and the outputs of the four convolutional stages are each
    pooled to a single vector and projected to ``embedding_dim``. The five
    resulting tokens pass through the RoPE transformer encoder layers, are
    averaged, and the average is classified.

    Args:
        block: Callable ``block(in_channels, out_channels, stride)`` returning a module.
        layers: Sequence with the number of blocks for each of the four stages.
        num_classes: Width of the classifier output.
        t_config: Dict with ``embedding_dim``, ``nhead``, ``num_encoder_layers``,
            ``dim_feedforward`` and ``dropout``.
    """
    def __init__(self, block, layers, num_classes, t_config):
        super().__init__()
        self.in_channels = 64
        self.embedding_dim = t_config["embedding_dim"]

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four convolutional stages; stages 2-4 halve the spatial size.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # One projection per token source: the stem plus each of the four stages.
        token_source_channels = (64, 64, 128, 256, 512)
        self.projections = nn.ModuleList([
            self._create_projection(channels, self.embedding_dim)
            for channels in token_source_channels
        ])

        def _new_encoder_layer():
            return TransformerEncoderLayerWithRoPE(
                d_model=self.embedding_dim,
                nhead=t_config["nhead"],
                dim_feedforward=t_config["dim_feedforward"],
                dropout=t_config["dropout"],
                batch_first=True
            )

        self.transformer_layers = nn.ModuleList([
            _new_encoder_layer() for _ in range(t_config["num_encoder_layers"])
        ])

        self.classifier = nn.Linear(self.embedding_dim, num_classes)

    def _create_projection(self, in_features, out_features):
        """Global-average-pool a feature map and project it to a token vector."""
        pool = nn.AdaptiveAvgPool2d((1, 1))
        to_token = nn.Linear(in_features, out_features)
        return nn.Sequential(pool, nn.Flatten(), to_token)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        # The first block consumes the running channel count; later ones see out_channels.
        input_widths = [self.in_channels] + [out_channels] * (len(strides) - 1)
        blocks = [
            block(width_in, out_channels, block_stride)
            for width_in, block_stride in zip(input_widths, strides)
        ]
        self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Keep the stem output and every stage output as token sources.
        stage_outputs = [x]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)

        tokens = [
            self.projections[idx](feature_map)
            for idx, feature_map in enumerate(stage_outputs)
        ]
        token_sequence = torch.stack(tokens, dim=1)

        for encoder_layer in self.transformer_layers:
            token_sequence = encoder_layer(token_sequence)

        pooled = token_sequence.mean(dim=1)
        return self.classifier(pooled)

def evaluate_model(model, data_loader, criterion, device):
    """Compute the mean loss and top-1 accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        ``(avg_loss, accuracy)`` with the loss averaged per sample and the
        accuracy expressed as a percentage.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)
            # The criterion returns a batch mean; weight it by the batch size
            # so the final average is per sample.
            loss_sum += batch_loss.item() * images.size(0)

            predicted = torch.max(logits.data, 1)[1]
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

    return loss_sum / num_seen, 100 * num_correct / num_seen

def train_and_validate(model, train_loader, val_loader, config):
    """Main training loop with callbacks.

    Trains ``model`` with Adam and cross-entropy loss for at most
    ``config["num_epochs"]`` epochs, evaluating on ``val_loader`` after every
    epoch and feeding the metrics to the callbacks created below
    (learning-rate reduction, early stopping, checkpointing).

    Args:
        model: Network to train; it is moved to ``config["device"]`` in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        config: Dict providing ``device``, ``split_seed``, ``learning_rate``,
            ``weight_decay``, ``checkpoint_path`` and ``num_epochs``.

    Returns:
        The same ``model`` object, after the callbacks' end-of-training hooks
        have run.
    """
    device = config["device"]
    model.to(device)

    # Set seeds for reproducible training
    # (the model arrives already constructed, so these seeds are set after
    # its parameters were created)
    torch.manual_seed(config["split_seed"])
    np.random.seed(config["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["split_seed"])
        torch.cuda.manual_seed_all(config["split_seed"])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # Initialize callbacks
    # All three monitor validation accuracy. The list order is also the
    # order in which they run at the end of every epoch.
    callbacks = [
        ReduceLROnPlateau(
            optimizer=optimizer,
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=config["checkpoint_path"],
            save_best_only=True,
            monitor='val_accuracy',
            verbose=1
        )
    ]

    print("\n--- Starting Training ---")
    early_stop = False

    for epoch in range(config["num_epochs"]):
        # Set by the EarlyStopping callback at the end of the previous epoch.
        if early_stop:
            break

        model.train()
        start_time = time.time()
        running_loss, train_correct, train_total = 0.0, 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by batch size so the epoch
            # average below is per sample.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(loss=loss.item())

        # Training metrics are accumulated while the weights are still being
        # updated; validation metrics come from a separate eval-mode pass.
        train_loss = running_loss / train_total
        train_acc = 100 * train_correct / train_total
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{config['num_epochs']} | Time: {time.time() - start_time:.2f}s | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Prepare logs for callbacks
        logs = {
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }

        # Execute callbacks
        # Dispatch by type because the hooks take different arguments:
        # EarlyStopping and ModelCheckpoint also need the model.
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                if callback.on_epoch_end(epoch, logs, model):
                    early_stop = True
                    # Leaving the loop here skips the callbacks listed after
                    # EarlyStopping (ModelCheckpoint) for this final epoch.
                    break
            elif isinstance(callback, ModelCheckpoint):
                callback.on_epoch_end(epoch, logs, model)
            else:
                callback.on_epoch_end(epoch, logs)

    # Execute callback cleanup
    # EarlyStopping receives the model so it can load its stored best weights.
    for callback in callbacks:
        if isinstance(callback, EarlyStopping):
            callback.on_training_end(model)
        else:
            callback.on_training_end()

    print("--- Training Finished ---\n")
    return model

# ==========================================
# Evaluation
# ==========================================

def calculate_top_k_accuracy(outputs, labels, k_values=(1, 3, 5)):
    """Calculate top-k accuracy for given k values.

    Args:
        outputs: Score tensor of shape ``[batch, num_classes]``.
        labels: Tensor of ``batch`` true class indices.
        k_values: Iterable of k values to evaluate. A k larger than the
            number of classes is treated as "all classes" instead of making
            ``topk`` fail.

    Returns:
        Dict mapping each k to the fraction (0.0 - 1.0) of samples whose true
        label is among the k highest-scoring classes.
    """
    batch_size = labels.size(0)
    num_classes = outputs.size(1)
    # topk cannot return more entries than there are classes.
    max_k = min(max(k_values), num_classes)

    _, pred = outputs.topk(max_k, 1, True, True)
    pred = pred.t()  # [max_k, batch]: row r holds every sample's r-th best class
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    # A label can match at most one row per sample, so counting the matches
    # in the first k rows counts the samples that were hit.
    return {k: correct[:k].sum().item() / batch_size for k in k_values}

def calculate_entropy(probs):
    """Return the entropy (natural log) of each row of ``probs``.

    A small constant is added to every probability so that zero entries do
    not produce ``log(0)``.
    """
    smoothed = probs + 1e-8
    weighted_log = smoothed * torch.log(smoothed)
    return -torch.sum(weighted_log, dim=1)

def calculate_ece(confidences, accuracies, n_bins=10):
    """Calculate Expected Calibration Error.

    Samples are grouped into ``n_bins`` equal-width confidence bins
    ``(lower, upper]`` over [0, 1]. The result is the sum over bins of
    ``|mean confidence - mean accuracy|`` weighted by the fraction of samples
    in the bin.

    Args:
        confidences: 1-D tensor of predicted-class confidences.
        accuracies: Tensor of the same length marking correct predictions
            (bool or 0/1 values).
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a Python float; ``0.0`` when no sample falls into any bin
        (for example for empty input).
    """
    if confidences.numel() == 0:
        return 0.0

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1].tolist()
    bin_uppers = bin_boundaries[1:].tolist()

    # Accumulate as a plain float so the function still returns a number
    # when every bin turns out to be empty.
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = in_bin.float().mean().item()

        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin).item() * prop_in_bin

    return ece

def comprehensive_evaluation(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and return a dict of metrics.

    Puts the model in eval mode, runs every batch under ``torch.no_grad()``,
    and collects predictions, confidences (max softmax probability) and
    per-sample entropies to compute loss, top-k accuracy, precision / recall /
    F1 (macro and weighted) and calibration statistics.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Dict with keys ``test_loss``, ``top1``, ``top3``, ``top5``,
        ``precision_macro``, ``recall_macro``, ``f1_macro``,
        ``precision_weighted``, ``recall_weighted``, ``f1_weighted``,
        ``avg_confidence``, ``std_confidence``, ``avg_entropy`` and ``ece``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    # Top-1 is derived from the stored predictions; only these k values need
    # a running sum.
    top_k_values = (3, 5)

    prediction_batches = []
    label_batches = []
    confidence_batches = []
    entropy_batches = []
    top_k_correct = {k: 0.0 for k in top_k_values}
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)

            # Convert outputs to probabilities
            probs = torch.softmax(outputs, dim=1)

            # Prediction = most probable class, confidence = its probability
            confidences, predictions = torch.max(probs, dim=1)

            entropies = calculate_entropy(probs)
            top_k_accs = calculate_top_k_accuracy(outputs, labels, list(top_k_values))

            # Keep one tensor per batch instead of extending Python lists
            # element by element.
            prediction_batches.append(predictions.cpu())
            label_batches.append(labels.cpu())
            confidence_batches.append(confidences.cpu())
            entropy_batches.append(entropies.cpu())

            # Weight each batch-level accuracy by its size so a smaller final
            # batch is averaged correctly, without one list entry per sample.
            for k in top_k_values:
                top_k_correct[k] += top_k_accs[k] * batch_size

            # NOTE(review): assumes criterion returns the batch mean, so
            # multiplying by batch size recovers the summed loss - confirm.
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    # Convert to numpy arrays
    all_predictions = torch.cat(prediction_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()
    all_confidences = torch.cat(confidence_batches).numpy()
    all_entropies = torch.cat(entropy_batches).numpy()

    # Calculate basic metrics
    avg_loss = total_loss / total_samples

    # Calculate top-k accuracies
    is_correct = all_predictions == all_labels
    top1 = np.mean(is_correct)
    top3 = top_k_correct[3] / total_samples
    top5 = top_k_correct[5] / total_samples

    # Calculate precision, recall, f1-score
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Calculate confidence and calibration metrics
    avg_confidence = np.mean(all_confidences)
    std_confidence = np.std(all_confidences)
    avg_entropy = np.mean(all_entropies)

    # Calculate ECE
    accuracies = is_correct.astype(float)
    ece = calculate_ece(torch.tensor(all_confidences), torch.tensor(accuracies))

    return {
        'test_loss': avg_loss,
        'top1': top1,
        'top3': top3,
        'top5': top5,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'avg_confidence': avg_confidence,
        'std_confidence': std_confidence,
        'avg_entropy': avg_entropy,
        'ece': ece
    }

# ==========================================
# Main Execution
# ==========================================

# Script entry point: seed the RNGs, build the data loaders, train the model,
# then evaluate it on the test split and print every metric.
if __name__ == '__main__':
    # Seed torch and numpy before any data or model is created, so that
    # everything created afterwards draws from the seeded RNGs.
    torch.manual_seed(CONFIG["split_seed"])
    np.random.seed(CONFIG["split_seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(CONFIG["split_seed"])
        torch.cuda.manual_seed_all(CONFIG["split_seed"])
        # Ensure deterministic behavior on CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Using device: {CONFIG['device']}")
    print(f"Reproducibility seed: {CONFIG['split_seed']}")

    # Get data loaders; a None result is treated as a data-preparation failure.
    dataloaders = get_dataloaders(CONFIG)
    if dataloaders is None:
        print("Could not prepare data. Halting execution.", file=sys.stderr)
        sys.exit(1)  # Exit if data preparation failed

    train_loader, val_loader, test_loader = dataloaders

    # Initialize and train model
    # NOTE(review): ResNetTransformer, NonResidualBlock and TRANSFORMER_CONFIG
    # are defined outside this section - presumably earlier in the same cell.
    model = ResNetTransformer(
        block=NonResidualBlock,
        layers=[2, 2, 2, 2],
        num_classes=CONFIG["num_classes"],
        t_config=TRANSFORMER_CONFIG
    )
    trained_model = train_and_validate(model, train_loader, val_loader, CONFIG)

    # Comprehensive evaluation on the test set, using a fresh cross-entropy
    # criterion for the reported test loss.
    print("--- Starting Comprehensive Evaluation on Test Set ---")
    results = comprehensive_evaluation(trained_model, test_loader, nn.CrossEntropyLoss(), CONFIG["device"])

    # Print all requested metrics
    # Top-k values are fractions in the results dict; shown here as percentages.
    print("\n--- Top-K Accuracy Results ---")
    print(f"Top-1 Accuracy: {results['top1'] * 100:.2f}%")
    print(f"Top-3 Accuracy: {results['top3'] * 100:.2f}%")
    print(f"Top-5 Accuracy: {results['top5'] * 100:.2f}%")

    # Precision / recall / F1 are printed as raw fractions, not percentages.
    print("\n--- Additional Performance Metrics ---")
    print(f"Macro Average Precision: {results['precision_macro']:.4f}")
    print(f"Macro Average Recall: {results['recall_macro']:.4f}")
    print(f"Macro Average F1-Score: {results['f1_macro']:.4f}")
    print(f"Weighted Average Precision: {results['precision_weighted']:.4f}")
    print(f"Weighted Average Recall: {results['recall_weighted']:.4f}")
    print(f"Weighted Average F1-Score: {results['f1_weighted']:.4f}")

    print("\n--- Model Confidence & Calibration Metrics ---")
    print(f"Average Prediction Confidence: {results['avg_confidence']:.4f}")
    print(f"Confidence Standard Deviation: {results['std_confidence']:.4f}")
    print(f"Average Prediction Entropy: {results['avg_entropy']:.4f}")
    print(f"Expected Calibration Error: {results['ece']:.4f}")

    print(f"\nFinal Test Loss: {results['test_loss']:.4f}")
    print("===============================================")

Using device: cuda
Reproducibility seed: 42
Downloading dataset from http://cs231n.stanford.edu/tiny-imagenet-200.zip...


Tiny ImageNet: 100%|██████████| 248M/248M [00:25<00:00, 9.70MiB/s]


Unzipping...
Formatting validation folder...
Dataset splits - Train: 77000, Val: 16500, Test: 16500, Total: 110000

--- Starting Training ---




Epoch 1/100 | Time: 60.55s | Train Loss: 4.6645, Train Acc: 4.95% | Val Loss: 4.2716, Val Acc: 8.33%

Model saved to ./best_model.pth




Epoch 2/100 | Time: 57.77s | Train Loss: 4.0776, Train Acc: 10.82% | Val Loss: 3.9235, Val Acc: 13.03%

Model saved to ./best_model.pth




Epoch 3/100 | Time: 57.42s | Train Loss: 3.7706, Train Acc: 15.49% | Val Loss: 3.7226, Val Acc: 16.55%

Model saved to ./best_model.pth




Epoch 4/100 | Time: 57.51s | Train Loss: 3.5552, Train Acc: 18.93% | Val Loss: 3.5902, Val Acc: 19.30%

Model saved to ./best_model.pth




Epoch 5/100 | Time: 56.57s | Train Loss: 3.3872, Train Acc: 21.74% | Val Loss: 3.4518, Val Acc: 21.18%

Model saved to ./best_model.pth




Epoch 6/100 | Time: 56.57s | Train Loss: 3.2484, Train Acc: 24.31% | Val Loss: 3.3680, Val Acc: 22.65%

Model saved to ./best_model.pth




Epoch 7/100 | Time: 56.71s | Train Loss: 3.1402, Train Acc: 26.36% | Val Loss: 3.2524, Val Acc: 24.76%

Model saved to ./best_model.pth




Epoch 8/100 | Time: 56.11s | Train Loss: 3.0445, Train Acc: 28.03% | Val Loss: 3.1008, Val Acc: 27.65%

Model saved to ./best_model.pth




Epoch 9/100 | Time: 57.22s | Train Loss: 2.9604, Train Acc: 29.43% | Val Loss: 3.0473, Val Acc: 28.46%

Model saved to ./best_model.pth




Epoch 10/100 | Time: 56.78s | Train Loss: 2.8796, Train Acc: 31.11% | Val Loss: 3.0805, Val Acc: 28.42%




Epoch 11/100 | Time: 58.34s | Train Loss: 2.8078, Train Acc: 32.62% | Val Loss: 3.0366, Val Acc: 29.18%

Model saved to ./best_model.pth




Epoch 12/100 | Time: 58.39s | Train Loss: 2.7410, Train Acc: 33.67% | Val Loss: 3.0729, Val Acc: 29.58%

Model saved to ./best_model.pth




Epoch 13/100 | Time: 58.23s | Train Loss: 2.6813, Train Acc: 34.74% | Val Loss: 3.0411, Val Acc: 30.50%

Model saved to ./best_model.pth




Epoch 14/100 | Time: 56.43s | Train Loss: 2.6253, Train Acc: 35.90% | Val Loss: 2.9692, Val Acc: 31.07%

Model saved to ./best_model.pth




Epoch 15/100 | Time: 57.44s | Train Loss: 2.5626, Train Acc: 37.05% | Val Loss: 2.9295, Val Acc: 32.22%

Model saved to ./best_model.pth




Epoch 16/100 | Time: 56.82s | Train Loss: 2.5182, Train Acc: 37.88% | Val Loss: 2.9728, Val Acc: 31.28%




Epoch 17/100 | Time: 58.37s | Train Loss: 2.4728, Train Acc: 38.88% | Val Loss: 2.8813, Val Acc: 32.52%

Model saved to ./best_model.pth




Epoch 18/100 | Time: 57.14s | Train Loss: 2.4268, Train Acc: 39.81% | Val Loss: 2.8642, Val Acc: 32.82%

Model saved to ./best_model.pth




Epoch 19/100 | Time: 57.64s | Train Loss: 2.3776, Train Acc: 40.83% | Val Loss: 2.8674, Val Acc: 33.66%

Model saved to ./best_model.pth




Epoch 20/100 | Time: 57.06s | Train Loss: 2.3384, Train Acc: 41.47% | Val Loss: 2.8855, Val Acc: 33.61%




Epoch 21/100 | Time: 58.11s | Train Loss: 2.2981, Train Acc: 42.30% | Val Loss: 2.8658, Val Acc: 34.30%

Model saved to ./best_model.pth




Epoch 22/100 | Time: 57.19s | Train Loss: 2.2504, Train Acc: 43.17% | Val Loss: 2.8709, Val Acc: 34.16%




Epoch 23/100 | Time: 58.30s | Train Loss: 2.2124, Train Acc: 44.09% | Val Loss: 2.8842, Val Acc: 33.97%




Epoch 24/100 | Time: 57.14s | Train Loss: 2.1763, Train Acc: 44.55% | Val Loss: 2.8612, Val Acc: 34.35%

Model saved to ./best_model.pth




Epoch 25/100 | Time: 58.26s | Train Loss: 2.1434, Train Acc: 45.14% | Val Loss: 2.9040, Val Acc: 34.17%




Epoch 26/100 | Time: 57.44s | Train Loss: 2.0981, Train Acc: 46.28% | Val Loss: 2.9893, Val Acc: 33.67%




Epoch 27/100 | Time: 57.28s | Train Loss: 2.0647, Train Acc: 46.81% | Val Loss: 2.9261, Val Acc: 34.35%

Reducing learning rate from 1.00e-03 to 2.00e-04




Epoch 28/100 | Time: 57.69s | Train Loss: 1.6809, Train Acc: 56.13% | Val Loss: 2.7975, Val Acc: 38.16%

Model saved to ./best_model.pth




Epoch 29/100 | Time: 56.98s | Train Loss: 1.5801, Train Acc: 58.33% | Val Loss: 2.8540, Val Acc: 37.86%




Epoch 30/100 | Time: 57.67s | Train Loss: 1.5274, Train Acc: 59.53% | Val Loss: 2.9179, Val Acc: 37.65%




Epoch 31/100 | Time: 56.86s | Train Loss: 1.4863, Train Acc: 60.53% | Val Loss: 2.9663, Val Acc: 37.27%

Reducing learning rate from 2.00e-04 to 4.00e-05




Epoch 32/100 | Time: 58.22s | Train Loss: 1.3576, Train Acc: 63.78% | Val Loss: 2.9652, Val Acc: 37.76%




Epoch 33/100 | Time: 57.43s | Train Loss: 1.3328, Train Acc: 64.48% | Val Loss: 2.9880, Val Acc: 37.47%




Epoch 34/100 | Time: 58.39s | Train Loss: 1.3251, Train Acc: 64.68% | Val Loss: 3.0134, Val Acc: 37.61%

Reducing learning rate from 4.00e-05 to 8.00e-06




Epoch 35/100 | Time: 57.51s | Train Loss: 1.2946, Train Acc: 65.56% | Val Loss: 3.0092, Val Acc: 37.56%

Early stopping at epoch 35
Restored model weights from the end of the best epoch: 28
--- Training Finished ---

--- Starting Comprehensive Evaluation on Test Set ---


Evaluating: 100%|██████████| 129/129 [00:08<00:00, 14.41it/s]


--- Top-K Accuracy Results ---
Top-1 Accuracy: 37.86%
Top-3 Accuracy: 56.63%
Top-5 Accuracy: 65.21%

--- Additional Performance Metrics ---
Macro Average Precision: 0.3731
Macro Average Recall: 0.3780
Macro Average F1-Score: 0.3733
Weighted Average Precision: 0.3760
Weighted Average Recall: 0.3786
Weighted Average F1-Score: 0.3751

--- Model Confidence & Calibration Metrics ---
Average Prediction Confidence: 0.5694
Confidence Standard Deviation: 0.2576
Average Prediction Entropy: 1.5253
Expected Calibration Error: 0.1908

Final Test Loss: 2.9821



