# MLP-Mixer on CIFAR-10: A Minimalist Vision Transformer Alternative

## Introduction

This notebook implements **MLP-Mixer**, an image classification architecture built entirely from multi-layer perceptrons (MLPs), and trains it on CIFAR-10. It walks through a complete training pipeline and shows how plain MLPs can reach competitive accuracy on a vision task without convolutional layers or attention mechanisms.

### What is MLP-Mixer?

MLP-Mixer, introduced by Tolstikhin et al. (2021), replaces both convolutions and self-attention with two types of MLP layers:
1. **Token-mixing MLPs**: Mix information across spatial locations (patches), applied independently to each channel
2. **Channel-mixing MLPs**: Mix information across channels, applied independently to each location (patch)

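This "mixer" approach keeps the architecture simple (only matrix multiplications, layer normalization, GELU nonlinearities, and residual connections) while still letting every patch and every channel exchange information, and it performs surprisingly well for that simplicity.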
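
The difference between the two mixing directions is easiest to see on tensor shapes. The sketch below is illustrative only: it uses a single `nn.Linear` as a stand-in for each two-layer MLP, and the sizes (64 patches, 160 channels) are taken from the configuration used later in this notebook. It perturbs one patch and counts how many output patches change under each kind of mixing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, C = 2, 64, 160             # batch, patches (tokens), channels
x = torch.randn(B, S, C)

token_mlp = nn.Linear(S, S)      # stand-in for a token-mixing MLP (one layer only)
channel_mlp = nn.Linear(C, C)    # stand-in for a channel-mixing MLP (one layer only)

def token_mix(t):
    # Move the patch axis last, mix across patches, then restore (B, S, C).
    return token_mlp(t.transpose(1, 2)).transpose(1, 2)

def channel_mix(t):
    # The channel axis is already last, so the layer mixes across channels.
    return channel_mlp(t)

# Perturb patch 0 only and measure which output patches change.
x2 = x.clone()
x2[:, 0, :] += 1.0
with torch.no_grad():
    token_delta = (token_mix(x2) - token_mix(x)).abs().amax(dim=(0, 2))
    channel_delta = (channel_mix(x2) - channel_mix(x)).abs().amax(dim=(0, 2))

n_token = int((token_delta > 1e-6).sum())
n_channel = int((channel_delta > 1e-6).sum())
print("patches changed by token mixing:  ", n_token)
print("patches changed by channel mixing:", n_channel)
```

Token mixing should spread the change to other patches, while channel mixing should keep it inside the perturbed patch.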

### Key Features of This Implementation

This notebook provides a complete training framework with:

- **Complete MLP-Mixer Implementation**: From patch embedding to mixer blocks with residual connections
- **Advanced Training Techniques**:
  - Optional Mixup data augmentation (enabled with `--mixup`, off by default) for improved generalization
  - Exponential Moving Average (EMA) of the weights, used for validation and checkpointing
  - Cosine annealing with warmup learning rate scheduling
  - AdamW optimizer with weight decay
- **Modern PyTorch Practices**:
  - `torch.compile()` integration for performance optimization
  - TF32 precision support on compatible GPUs
  - Efficient data loading with AutoAugment policies
- **Comprehensive Evaluation**:
  - Precision, recall, F1-score per class
  - Confusion matrix visualization
  - Model checkpointing and metrics logging

### Model Specifications

- **Architecture**: Tiny MLP-Mixer (< 1M parameters)
- **Input**: 32×32 RGB images (CIFAR-10)
- **Patch Size**: 4×4 (64 patches per image)
- **Embedding Dimension**: 160
- **Mixer Blocks**: 6 layers
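- **Token/Channel Hidden Dimensions**: 64 and 480, derived from the multipliers `token_hidden_mul=1.0` (× 64 patches) and `channel_hidden_mul=3.0` (× 160 embedding dimensions)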
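
The "< 1M parameters" figure can be checked by hand from these specifications. The sketch below is plain arithmetic, not part of the training code. It assumes the layer layout defined later in this notebook (one linear patch embedding, two `LayerNorm` layers and four `Linear` layers per block, then a final `LayerNorm` and a linear head) and the default multipliers `token_hidden_mul=1.0` and `channel_hidden_mul=3.0`. The helper names are illustrative.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def layernorm_params(dim):
    """Scale plus shift of a LayerNorm over `dim` features."""
    return 2 * dim

img_size, patch_size, in_ch = 32, 4, 3
embed_dim, num_blocks, num_classes = 160, 6, 10
token_hidden_mul, channel_hidden_mul = 1.0, 3.0

num_patches = (img_size // patch_size) ** 2           # patches per image
patch_dim = in_ch * patch_size * patch_size           # values per flattened patch
token_hidden = int(num_patches * token_hidden_mul)
channel_hidden = int(embed_dim * channel_hidden_mul)

patch_embed = linear_params(patch_dim, embed_dim)
per_block = (
    2 * layernorm_params(embed_dim)
    + linear_params(num_patches, token_hidden)
    + linear_params(token_hidden, num_patches)
    + linear_params(embed_dim, channel_hidden)
    + linear_params(channel_hidden, embed_dim)
)
head = layernorm_params(embed_dim) + linear_params(embed_dim, num_classes)

total = patch_embed + num_blocks * per_block + head
print(f"patch embedding: {patch_embed:,}")
print(f"per mixer block: {per_block:,}")
print(f"norm + head:     {head:,}")
print(f"total:           {total:,}")
```

Almost all of the budget sits in the channel-mixing layers, so `embed_dim` and `channel_hidden_mul` are the settings to adjust when changing the model size.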

### Why This Implementation Matters

This notebook serves as both an educational resource and a practical template for:
- Understanding the MLP-Mixer architecture in depth
- Learning modern PyTorch training best practices
- Experimenting with alternative vision architectures
- Building a foundation for custom computer vision projects

Whether you're a researcher exploring beyond transformers, a practitioner looking for efficient vision models, or a student learning about modern deep learning architectures, this implementation offers valuable insights into minimalist yet powerful approaches to computer vision.

Let's dive into the code and explore how simple MLPs can see!

## Imports

In [1]:
import argparse
import math
import os
import random
import copy
import warnings
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Optional: metrics & visualization
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    print("[Warning] matplotlib/seaborn/sklearn unavailable; text-only metrics.")

## Model Definitions

In [2]:
class PatchEmbedding(nn.Module):
    """Linear patch embedding: split image into non-overlapping patches and project to embedding space."""
    def __init__(self, img_size=32, patch_size=4, in_ch=3, embed_dim=160):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_ch * patch_size * patch_size
        self.proj = nn.Linear(self.patch_dim, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        x = x.unfold(2, p, p).unfold(3, p, p)
        x = x.contiguous().view(B, C, H // p, W // p, p * p)
        x = x.permute(0, 2, 3, 1, 4).contiguous()
        x = x.view(B, -1, self.patch_dim)
        x = self.proj(x)
        return x


class MixerBlock(nn.Module):
    """One Mixer block: alternates token-mixing and channel-mixing MLPs with residual connections."""
    def __init__(self, num_patches, embed_dim, token_hidden, channel_hidden, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.token_fc1 = nn.Linear(num_patches, token_hidden)
        self.token_fc2 = nn.Linear(token_hidden, num_patches)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.channel_fc1 = nn.Linear(embed_dim, channel_hidden)
        self.channel_fc2 = nn.Linear(channel_hidden, embed_dim)
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        # Token-mixing
        y = self.norm1(x)
        y = y.permute(0, 2, 1)
        y = F.gelu(self.token_fc1(y))
        y = self.token_fc2(y)
        y = y.permute(0, 2, 1)
        x = x + self.dropout(y)
        # Channel-mixing
        y = self.norm2(x)
        y = F.gelu(self.channel_fc1(y))
        y = self.channel_fc2(y)
        x = x + self.dropout(y)
        return x


class MLPMixerTiny(nn.Module):
    """Tiny MLP-Mixer for CIFAR-10 (<1M params)."""
    def __init__(self, img_size=32, patch_size=4, in_ch=3, embed_dim=160, num_blocks=6,
                 token_hidden_mul=1.0, channel_hidden_mul=3.0, num_classes=10, drop=0.0):
        super().__init__()
        self.patch = PatchEmbedding(img_size, patch_size, in_ch, embed_dim)
        num_patches = self.patch.num_patches
        token_hidden = max(1, int(num_patches * token_hidden_mul))
        channel_hidden = max(1, int(embed_dim * channel_hidden_mul))
        self.blocks = nn.ModuleList([
            MixerBlock(num_patches, embed_dim, token_hidden, channel_hidden, drop)
            for _ in range(num_blocks)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x
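
As a sanity check on the patch extraction, the sketch below repeats the `unfold`-based steps from `PatchEmbedding.forward` on a small tensor and compares them with a reshape-based formulation. The reshape variant is an alternative written for this check, not something the notebook uses. It also confirms that the first 16 values of the first patch are the top-left 4×4 block of channel 0.

```python
import torch

B, C, H, W, p = 2, 3, 32, 32, 4
x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)

# unfold-based extraction, mirroring PatchEmbedding.forward
u = x.unfold(2, p, p).unfold(3, p, p)                  # (B, C, H/p, W/p, p, p)
u = u.contiguous().view(B, C, H // p, W // p, p * p)   # flatten each p x p patch
u = u.permute(0, 2, 3, 1, 4).contiguous()              # (B, H/p, W/p, C, p*p)
u = u.view(B, -1, C * p * p)                           # (B, num_patches, patch_dim)

# reshape-based extraction (alternative formulation used only for comparison)
r = x.view(B, C, H // p, p, W // p, p)                 # split H and W into grid and offset
r = r.permute(0, 2, 4, 1, 3, 5)                        # (B, H/p, W/p, C, p, p)
r = r.reshape(B, (H // p) * (W // p), C * p * p)

same = torch.equal(u, r)
top_left_ok = torch.equal(u[0, 0, : p * p].view(p, p), x[0, 0, :p, :p])
print("patch tensor shape:", tuple(u.shape))
print("unfold and reshape agree:", same)
print("first patch is the top-left block:", top_left_ok)
```

Each flattened patch is ordered channel first, then row, then column, which is the layout the `nn.Linear(self.patch_dim, embed_dim)` projection sees.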

## Utilities

In [3]:
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


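def mixup_data(x, y, alpha=1.0, device=None):
    """Mix each sample with a randomly chosen partner from the same batch.

    Returns the mixed inputs, both label sets, and the mixing coefficient lam.
    """
    if alpha <= 0:
        return x, y, y, 1.0
    lam = float(np.random.beta(alpha, alpha))  # lam ~ Beta(alpha, alpha)
    batch_size = x.size(0)
    # Build the permutation on the input's device unless one is given explicitly
    index = torch.randperm(batch_size, device=device if device is not None else x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam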


class ModelEMA:
    """EMA robust to torch.compile() via deepcopy and _orig_mod unwrapping."""
    def __init__(self, model, decay=0.9999, device='cuda'):
        self.decay = decay
        self.ema = self._clone_model(model).to(device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _clone_model(self, model):
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        m = copy.deepcopy(model)
        m.eval()
        return m

    def update(self, model):
        with torch.no_grad():
            model_state = model._orig_mod.state_dict() if hasattr(model, "_orig_mod") else model.state_dict()
            ema_state = self.ema.state_dict()
            for k, ema_param in ema_state.items():
                model_param = model_state[k]
                if ema_param.dtype.is_floating_point:
                    ema_param.mul_(self.decay).add_(model_param.detach(), alpha=1.0 - self.decay)
                else:
                    ema_param.copy_(model_param.detach())
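
With `decay=0.9999`, the EMA weights move slowly. After `n` updates, the weights the EMA copy started from still carry a share of `decay ** n`. The sketch below is a back-of-the-envelope estimate of that share. It assumes 50,000 training images and the default batch size of 128 with no dropped last batch.

```python
import math

decay = 0.9999
train_images, batch_size = 50_000, 128
steps_per_epoch = math.ceil(train_images / batch_size)   # optimizer steps per epoch

def initial_weight_share(num_steps, decay=decay):
    """Fraction of the EMA still made up of its starting weights after num_steps updates."""
    return decay ** num_steps

for epochs in (1, 10, 50, 200):
    share = initial_weight_share(epochs * steps_per_epoch)
    print(f"after {epochs:3d} epochs: initial weights still account for {share:.4f}")

# Number of updates after which the starting weights have decayed to one half.
half_life_steps = math.log(0.5) / math.log(decay)
print(f"half-life: {half_life_steps:.0f} steps, about {half_life_steps / steps_per_epoch:.1f} epochs")
```

Because validation runs on the EMA copy, early-epoch validation accuracy can lag well behind training accuracy. A smaller decay or an EMA warmup are common ways to reduce that lag, though neither is used here.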

## Training & Evaluation

In [4]:
def train_one_epoch(model, ema, loader, optimizer, device, epoch, args):
    model.train()
    total_loss = 0.0
    total = 0
    correct = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1:03d} [Train]", leave=False)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if args.mixup:
            x, y_a, y_b, lam = mixup_data(x, y, alpha=args.mixup_alpha, device=device)
            outputs = model(x)
            loss = lam * F.cross_entropy(outputs, y_a) + (1 - lam) * F.cross_entropy(outputs, y_b)
        else:
            outputs = model(x)
            loss = F.cross_entropy(outputs, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if ema is not None:
            ema.update(model)

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        total += batch_size
        _, pred = outputs.max(1)
        if args.mixup:
            # With mixup there is no single true label, so count lam-weighted matches against both label sets
            correct += (lam * pred.eq(y_a).sum().item() + (1 - lam) * pred.eq(y_b).sum().item())
        else:
            correct += pred.eq(y).sum().item()

        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'acc': f"{100. * correct / total:.2f}%"})

    return total_loss / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(model, loader, device, compute_metrics=False):
    model.eval()
    total = 0
    correct = 0
    total_loss = 0.0
    all_preds, all_labels = [], []

    pbar = tqdm(loader, desc="Evaluating", leave=False)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        out = model(x)
        loss = F.cross_entropy(out, y)
        total_loss += loss.item() * x.size(0)
        total += x.size(0)
        _, pred = out.max(1)
        correct += pred.eq(y).sum().item()

        if compute_metrics:
            all_preds.append(pred.cpu())
            all_labels.append(y.cpu())
        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'acc': f"{100. * correct / total:.2f}%"})

    avg_loss = total_loss / total
    acc = 100.0 * correct / total
    if compute_metrics:
        return avg_loss, acc, torch.cat(all_preds), torch.cat(all_labels)
    return avg_loss, acc
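
The mixup branch in `train_one_epoch` combines two cross-entropy terms instead of building soft labels. The two formulations are equivalent, which the short check below illustrates on random logits. The batch size, class count, and `lam` value are arbitrary choices for the demonstration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 10
logits = torch.randn(8, num_classes)
y_a = torch.randint(0, num_classes, (8,))
y_b = y_a[torch.randperm(8)]          # labels of the shuffled partners
lam = 0.3

# Form used in train_one_epoch: weighted sum of two hard-label losses
pairwise = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)

# Equivalent form: a single cross-entropy against the mixed (soft) label distribution
soft = lam * F.one_hot(y_a, num_classes).float() + (1 - lam) * F.one_hot(y_b, num_classes).float()
soft_ce = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print("pairwise form:  ", float(pairwise))
print("soft-label form:", float(soft_ce))
```

The pairwise form avoids materializing one-hot tensors and works with integer class labels directly.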

## Metrics & Visualization

In [5]:
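def compute_final_metrics(preds, labels, num_classes=10):
    """Per-class and macro precision/recall/F1 plus the confusion matrix (needs scikit-learn)."""
    if not MATPLOTLIB_AVAILABLE:
        raise RuntimeError("Metrics need the optional matplotlib/seaborn/sklearn imports to succeed.")
    preds_np = preds.numpy()
    labels_np = labels.numpy()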
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_np, preds_np, labels=list(range(num_classes)), zero_division=0
    )
    cm = confusion_matrix(labels_np, preds_np, labels=list(range(num_classes)))
    return {
        'precision_per_class': precision,
        'recall_per_class': recall,
        'f1_per_class': f1,
        'precision_macro': precision.mean(),
        'recall_macro': recall.mean(),
        'f1_macro': f1.mean(),
        'confusion_matrix': cm
    }


def plot_confusion_matrix(cm, class_names, save_path=None):
    if not MATPLOTLIB_AVAILABLE:
        return
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
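
To make the relationship between the confusion matrix and the reported metrics concrete, the sketch below derives per-class precision, recall, and F1 from a small made-up 3-class matrix in plain Python. Rows are true classes and columns are predicted classes, the same convention as `sklearn.metrics.confusion_matrix`. The numbers are invented for illustration and are not results from this model.

```python
# Rows = true class, columns = predicted class (made-up counts)
cm = [
    [8, 1, 1],
    [2, 7, 1],
    [0, 2, 8],
]
n = len(cm)

precision, recall, f1 = [], [], []
for k in range(n):
    tp = cm[k][k]
    predicted_k = sum(cm[i][k] for i in range(n))   # column sum: everything predicted as k
    actual_k = sum(cm[k])                           # row sum: everything that truly is k
    p = tp / predicted_k if predicted_k else 0.0    # mirrors zero_division=0
    r = tp / actual_k if actual_k else 0.0
    precision.append(p)
    recall.append(r)
    f1.append(2 * p * r / (p + r) if (p + r) else 0.0)

macro_precision = sum(precision) / n
macro_recall = sum(recall) / n
macro_f1 = sum(f1) / n
accuracy = sum(cm[k][k] for k in range(n)) / sum(map(sum, cm))

for k in range(n):
    print(f"class {k}: P={precision[k]:.4f} R={recall[k]:.4f} F1={f1[k]:.4f}")
print(f"macro: P={macro_precision:.4f} R={macro_recall:.4f} F1={macro_f1:.4f} acc={accuracy:.4f}")
```

The macro averages are unweighted means over classes, matching the `.mean()` calls in `compute_final_metrics`.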

## Main

In [6]:
def main():
    parser = argparse.ArgumentParser(description="Train MLPMixerTiny on CIFAR-10")
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--mixup-alpha', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save-dir', type=str, default='./checkpoints')
    parser.add_argument('--resume', type=str, default='')
    args = parser.parse_args(args=[])

    # Reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device(args.device)
    print("Using device:", device)

    # ----------------------------- TF32 CONFIG (NEW API, PyTorch >= 2.9) -----------------------------
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        if cap >= (8, 0):
            torch.backends.fp32_precision = "tf32"
            torch.backends.cuda.matmul.fp32_precision = "tf32"
            torch.backends.cudnn.fp32_precision = "tf32"
            torch.backends.cudnn.conv.fp32_precision = "tf32"
            print("TF32 enabled via new API (GPU capability:", cap, ")")
        else:
            torch.backends.fp32_precision = "ieee"
            torch.backends.cuda.matmul.fp32_precision = "ieee"
            torch.backends.cudnn.fp32_precision = "ieee"
            torch.backends.cudnn.conv.fp32_precision = "ieee"
            print("IEEE FP32 enforced (GPU capability:", cap, ")")
    else:
        print("CUDA not available; running on CPU")

    # Suppress non-critical max_autotune_gemm warning
    warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")

    # ----------------------------- DATA -----------------------------
    print("Loading CIFAR-10 dataset...")
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    class_names = testset.classes

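    # Use num_workers=2 to avoid DataLoader warning on smaller systems
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    # Note: the CIFAR-10 test split doubles as the validation set, so the best
    # checkpoint is selected on the same data used for the final evaluation.
    val_loader = DataLoader(testset, batch_size=256, shuffle=False,
                            num_workers=2, pin_memory=True)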

    # ----------------------------- MODEL -----------------------------
    print("Building MLPMixerTiny...")
    model = MLPMixerTiny(
        img_size=32, patch_size=4, in_ch=3, embed_dim=160,
        num_blocks=6, token_hidden_mul=1.0, channel_hidden_mul=3.0,
        num_classes=10, drop=0.0
    ).to(device)

    total_params = count_parameters(model)
    print("Model built. Trainable params:", f"{total_params:,}", "(< 1M)")

    # torch.compile (with graceful fallback)
    if hasattr(torch, 'compile') and device.type == 'cuda':
        try:
            model = torch.compile(model, mode="reduce-overhead")
            print("Model compiled (mode='reduce-overhead')")
        except Exception as e:
            print("[Warning] torch.compile failed:", e, ". Using eager mode.")

    # ----------------------------- OPTIMIZER & SCHEDULER -----------------------------
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    def lr_lambda(epoch):
        warmup = 10
        if epoch < warmup:
            return (epoch + 1) / max(1, warmup)
        # max(1, ...) avoids a division by zero when epochs <= warmup
        return 0.5 * (1.0 + math.cos(math.pi * (epoch - warmup) / max(1, args.epochs - warmup)))
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # ----------------------------- EMA & CHECKPOINT -----------------------------
    ema = ModelEMA(model, decay=0.9999, device=device)
    os.makedirs(args.save_dir, exist_ok=True)
    dataset_name = "cifar10"
    best_model_path = os.path.join(args.save_dir, f"{dataset_name}_best.pt")

    # ----------------------------- RESUME -----------------------------
    start_epoch, best_acc = 0, 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print("Resuming from:", args.resume)
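            ckpt = torch.load(args.resume, map_location=device)
            # Checkpoints hold the uncompiled EMA weights, so load them into the raw module:
            # a torch.compile wrapper prefixes its state_dict keys with "_orig_mod.".
            raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
            raw_model.load_state_dict(ckpt['model_state'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # Skip the scheduler if it is missing rather than loading an empty dict
            if 'scheduler' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler'])
            start_epoch = ckpt.get('epoch', 0)
            best_acc = ckpt.get('val_acc', 0.0)
            print("Resuming from epoch", start_epoch, ", best val acc:", f"{best_acc:.2f}%")
            # Restart the EMA from the restored weights
            ema = ModelEMA(model, decay=0.9999, device=device)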
        else:
            print("[Error] Checkpoint not found:", args.resume)
            return

    # ----------------------------- TRAINING LOOP -----------------------------
    print("\nStarting training from epoch", start_epoch+1, "to", args.epochs, "...")
    for epoch in range(start_epoch, args.epochs):
        train_loss, train_acc = train_one_epoch(model, ema, train_loader, optimizer, device, epoch, args)
        scheduler.step()

        eval_model = ema.ema if ema is not None else model
        val_loss, val_acc = evaluate(eval_model, val_loader, device)

        lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch+1:03d} | "
              f"Train Loss {train_loss:.4f} Acc {train_acc:.2f}% | "
              f"Val Loss {val_loss:.4f} Acc {val_acc:.2f}% | "
              f"LR {lr:.2e}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'epoch': epoch + 1,
                'model_state': eval_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'val_acc': best_acc,
                'args': vars(args)
            }, best_model_path)
            print("New best model saved:", best_model_path, "(Val Acc:", f"{best_acc:.2f}%)")

    print("\nTraining finished. Best validation accuracy:", f"{best_acc:.2f}%")
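
The learning-rate schedule defined by `lr_lambda` is easier to reason about with a few concrete values. The sketch below copies the same formula into a standalone function and evaluates it with the default settings (200 epochs, 10 warmup epochs, base rate `1e-3`). The name `lr_factor` and its keyword arguments are illustrative.

```python
import math

def lr_factor(epoch, total_epochs=200, warmup=10):
    """Multiplier applied to the base learning rate at a given 0-based epoch index."""
    if epoch < warmup:
        return (epoch + 1) / max(1, warmup)            # linear warmup
    progress = (epoch - warmup) / (total_epochs - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay

base_lr = 1e-3
for epoch in (0, 9, 10, 105, 199):
    print(f"epoch index {epoch:3d}: lr = {base_lr * lr_factor(epoch):.2e}")
```

The rate ramps linearly to the base value over the first 10 epochs, then follows a half cosine toward zero. Note that the training loop reads `scheduler.get_last_lr()` after calling `scheduler.step()`, so the `LR` column it prints is the rate for the next epoch.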

## Final Evaluation

In [7]:
def evaluate_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dataset_name: str = "cifar10",
    save_dir: str = "./checkpoints",
    batch_size: int = 256,
    num_workers: int = 2,
    class_names: list = None,
    plot_cm: bool = True
) -> dict:
    """
    Standalone function to evaluate a saved checkpoint on CIFAR-10 test set.

    Args:
        checkpoint_path (str): Path to the .pt checkpoint file.
        device (torch.device): Device to run evaluation on.
        dataset_name (str): Name for labeling outputs (e.g., 'cifar10').
        save_dir (str): Directory to save metrics and plots.
        batch_size (int): Batch size for DataLoader.
        num_workers (int): Workers for DataLoader.
        class_names (list, optional): List of class names. Defaults to CIFAR-10.
        plot_cm (bool): Whether to generate and save confusion matrix plot.

    Returns:
        dict: Dictionary containing test loss, accuracy, and full metrics.
    """
    # ----------------------------- DEFAULTS -----------------------------
    if class_names is None:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

    # ----------------------------- DATA -----------------------------
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    testset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    # ----------------------------- MODEL (RAW, UNCOMPILED) -----------------------------
    print("Building raw model for evaluation...")
    model = MLPMixerTiny(
        img_size=32, patch_size=4, in_ch=3, embed_dim=160,
        num_blocks=6, token_hidden_mul=1.0, channel_hidden_mul=3.0,
        num_classes=10, drop=0.0
    ).to(device)

    # Load checkpoint
    print(f"Loading checkpoint: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    model.eval()

    # Optional: compile for faster inference (safe fallback)
    if hasattr(torch, 'compile') and device.type == 'cuda':
        try:
            model = torch.compile(model, mode="default")
            print("Model compiled for inference (mode='default').")
        except Exception as e:
            print("[Warning] torch.compile failed during eval:", e)

    # ----------------------------- EVALUATION -----------------------------
    print("Running evaluation on test set...")
    test_loss, test_acc, preds, labels = evaluate(model, test_loader, device, compute_metrics=True)
    print(f"\n✅ Test Results → Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%")

    # ----------------------------- METRICS -----------------------------
    metrics = compute_final_metrics(preds, labels, num_classes=10)

    # ----------------------------- SAVE RESULTS -----------------------------
    os.makedirs(save_dir, exist_ok=True)

    # Text metrics
    metrics_path = os.path.join(save_dir, f"{dataset_name}_final_metrics.txt")
    with open(metrics_path, 'w') as f:
        f.write(f"Checkpoint: {checkpoint_path}\n")
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Test Accuracy: {test_acc:.2f}%\n")
        f.write(f"Test Loss: {test_loss:.4f}\n\n")

        f.write(f"Macro Precision: {metrics['precision_macro']:.4f}\n")
        f.write(f"Macro Recall:    {metrics['recall_macro']:.4f}\n")
        f.write(f"Macro F1:        {metrics['f1_macro']:.4f}\n\n")

        f.write("Per-class metrics:\n")
        for i, cls in enumerate(class_names):
            p, r, f1 = (
                metrics['precision_per_class'][i],
                metrics['recall_per_class'][i],
                metrics['f1_per_class'][i]
            )
            f.write(f"{cls:10}: P={p:.4f}, R={r:.4f}, F1={f1:.4f}\n")
        f.write(f"\nConfusion Matrix:\n{metrics['confusion_matrix']}")
    print("📝 Metrics saved to:", metrics_path)

    # Confusion matrix plot
    if plot_cm and MATPLOTLIB_AVAILABLE:
        cm_path = os.path.join(save_dir, f"{dataset_name}_confusion_matrix.png")
        plot_confusion_matrix(metrics['confusion_matrix'], class_names, save_path=cm_path)
        print("🖼️  Confusion matrix saved to:", cm_path)
    elif plot_cm:
        print("[Info] Skipping confusion matrix plot (matplotlib not available).")

    # Return structured results
    results = {
        'test_loss': test_loss,
        'test_acc': test_acc,
        'precision_macro': metrics['precision_macro'],
        'recall_macro': metrics['recall_macro'],
        'f1_macro': metrics['f1_macro'],
        'precision_per_class': metrics['precision_per_class'],
        'recall_per_class': metrics['recall_per_class'],
        'f1_per_class': metrics['f1_per_class'],
        'confusion_matrix': metrics['confusion_matrix'],
        'checkpoint_path': checkpoint_path,
        'metrics_file': metrics_path,
    }
    return results
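
`evaluate_checkpoint` loads the checkpoint into an uncompiled model. This works because the training loop saves the state of the uncompiled EMA copy. If a checkpoint were instead saved from a `torch.compile` wrapper, its keys would carry an `_orig_mod.` prefix and would not match. The helper below is a hypothetical addition, not part of the notebook, that normalizes such keys before loading. Plain integers stand in for tensors in the demonstration.

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with a leading torch.compile prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Toy state dicts: integers stand in for parameter tensors
compiled_style = {"_orig_mod.head.weight": 1, "_orig_mod.head.bias": 2}
raw_style = {"head.weight": 1, "head.bias": 2}

print(strip_compile_prefix(compiled_style))
print(strip_compile_prefix(raw_style))
```

A possible use would be `model.load_state_dict(strip_compile_prefix(ckpt['model_state']))`, which leaves keys without the prefix unchanged.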

## Training

In [None]:
if __name__ == '__main__':
    main()

## Evaluation

In [None]:
import torch
from torchvision import datasets, transforms
import os # Import os for path joining

# Re-initialize device (was defined in main() scope)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Re-create testset to get class_names (was defined in main() scope)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Define save_dir and checkpoint_path (derived from args in main())
save_dir = './checkpoints'
dataset_name = "cifar10"
checkpoint_file = os.path.join(save_dir, f"{dataset_name}_best.pt")

# ----------------------------- FINAL EVALUATION -----------------------------
print("\nFinal evaluation on test set...")
results = evaluate_checkpoint(
    checkpoint_path=checkpoint_file,
    device=device,
    dataset_name=dataset_name,
    save_dir=save_dir,
    batch_size=256,
    class_names=testset.classes  # Now testset.classes is available
)
print("Macro Avg → P:", f"{results['precision_macro']:.4f},",
      "R:", f"{results['recall_macro']:.4f},",
      "F1:", f"{results['f1_macro']:.4f}")