# Setup and GPU or CPU fallback

In [None]:
"""
Deep Learning Homework 6 - CIFAR-100 with ViT, ResNet-18, and Swin Transformer
Alex Ayerbe

This code implements and compares several models on the CIFAR-100 dataset:
1. ResNet-18 (pretrained)
2. Vision Transformer (ViT) with various configurations
3. Swin Transformer (pretrained and from scratch)

Requirements:
- PyTorch
- torchvision
- torchinfo (for FLOPs calculation)
- transformers (Hugging Face, for Swin models)
- tqdm (for progress bars)
"""

import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torchinfo import summary
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import SwinForImageClassification

# Seed the PyTorch and NumPy random number generators so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)

# Select the compute device once: CUDA when it is available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"CUDA is available. Using device: {device}")
else:
    print("CUDA not available. Using CPU.")

CUDA is available. Using device: cuda


# Loading CIFAR100 Dataset

In [None]:
def get_cifar100_loaders(batch_size=64, num_workers=2):
    """
    Build CIFAR-100 train/test data loaders at the native 32x32 resolution

    The training pipeline applies a padded random crop and a random horizontal
    flip before tensor conversion and normalization; the test pipeline only
    converts and normalizes. Both datasets use root './data' with download=True.

    Args:
        batch_size: Batch size used by both loaders
        num_workers: num_workers value passed to both loaders

    Returns:
        train_loader, test_loader: shuffled training loader and unshuffled test loader
    """
    # Normalization constants (mean, std) shared by both pipelines
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    def tensor_and_normalize():
        # Fresh transform instances for each pipeline
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    # Keyed by the `train` flag: True -> augmented training pipeline, False -> test pipeline
    pipelines = {
        True: transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
            + tensor_and_normalize()
        ),
        False: transforms.Compose(tensor_and_normalize()),
    }

    # Training split first, then the test split
    splits = {
        is_train: torchvision.datasets.CIFAR100(
            root='./data',
            train=is_train,
            download=True,
            transform=pipelines[is_train],
        )
        for is_train in (True, False)
    }

    # Only the training loader shuffles its samples
    train_loader, test_loader = (
        DataLoader(
            splits[is_train],
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
        )
        for is_train in (True, False)
    )

    return train_loader, test_loader

def get_cifar100_swin_loaders(batch_size=32, num_workers=2, img_size=224):
    """
    Build CIFAR-100 train/test data loaders with images resized for Swin Transformer

    Every image is resized to (img_size, img_size) before tensor conversion and
    normalization; the training pipeline additionally applies a random
    horizontal flip. Both datasets use root './data' with download=True.

    Args:
        batch_size: Batch size used by both loaders
        num_workers: num_workers value passed to both loaders
        img_size: Side length the images are resized to

    Returns:
        train_loader, test_loader: shuffled training loader and unshuffled test loader
    """
    # Normalization constants (mean, std) shared by both pipelines
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)
    target_size = (img_size, img_size)

    def build_pipeline(with_flip):
        # Resize first, optionally flip, then convert and normalize
        steps = [transforms.Resize(target_size)]
        if with_flip:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(channel_mean, channel_std))
        return transforms.Compose(steps)

    # Keyed by the `train` flag: only the training pipeline flips
    pipelines = {True: build_pipeline(True), False: build_pipeline(False)}

    # Training split first, then the test split
    splits = {
        is_train: torchvision.datasets.CIFAR100(
            root='./data',
            train=is_train,
            download=True,
            transform=pipelines[is_train],
        )
        for is_train in (True, False)
    }

    # Only the training loader shuffles its samples
    train_loader, test_loader = (
        DataLoader(
            splits[is_train],
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
        )
        for is_train in (True, False)
    )

    return train_loader, test_loader

# PROBLEM 1: ResNet-18 and ViT Models

## ResNet-18 Baseline

In [None]:
class ResNet18Model(nn.Module):
    """
    Pretrained torchvision ResNet-18 with its stem and classifier replaced for CIFAR-100
    """
    def __init__(self, num_classes=100):
        super().__init__()
        backbone = resnet18(pretrained=True)
        # Swap the stem for a 3x3, stride-1 convolution; this layer is newly
        # constructed here, so it does not carry the pretrained conv1 weights
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Drop the max-pool downsampling, which is not needed for small images
        backbone.maxpool = nn.Identity()
        # New classifier sized for num_classes outputs
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        # Delegate directly to the adapted backbone
        logits = self.model(x)
        return logits

def train_resnet18():
    """
    Train the ResNet-18 baseline on CIFAR-100 and collect its metrics

    Runs 20 epochs of Adam (lr 0.001, batch size 64) with cross-entropy loss.
    After every epoch evaluate_model (defined elsewhere in this file) is called
    on the test loader and its return value is recorded as that epoch's test
    accuracy.

    Returns:
        dict: model name, parameter count, torchinfo multiply-add count, total
        and average-per-epoch wall-clock time, per-epoch test accuracies, and
        the last test accuracy. The timing spans the whole loop, so it includes
        the per-epoch evaluation.
    """
    print("Training ResNet-18 baseline on CIFAR-100")

    # Fixed training settings for the baseline
    batch_size, num_epochs, learning_rate = 64, 20, 0.001

    train_loader, test_loader = get_cifar100_loaders(batch_size)
    model = ResNet18Model().to(device)

    # torchinfo reports the multiply-add count; parameters are counted directly
    model_stats = summary(model, input_size=(batch_size, 3, 32, 32), verbose=0)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"FLOPs per forward pass: {model_stats.total_mult_adds:,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    total_start_time = time.time()
    test_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        epoch_start_time = time.time()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for step, (inputs, targets) in enumerate(progress_bar, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # One optimization step
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Running training statistics shown on the progress bar
            loss_sum += loss.item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

            progress_bar.set_postfix({
                'loss': loss_sum / step,
                'acc': 100. * n_correct / n_seen
            })

        # Training-only duration of this epoch (evaluation happens afterwards)
        epoch_time = time.time() - epoch_start_time

        test_acc = evaluate_model(model, test_loader)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}/{num_epochs} - Time: {epoch_time:.2f}s - Test Acc: {test_acc:.2f}%")

    total_training_time = time.time() - total_start_time

    return {
        'model': 'ResNet-18',
        'total_params': total_params,
        'flops': model_stats.total_mult_adds,
        'total_training_time': total_training_time,
        'avg_epoch_time': total_training_time / num_epochs,
        'test_accuracies': test_accuracies,
        'final_test_accuracy': test_accuracies[-1]
    }

## Custom Vision Transformer (ViT)

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of linearly embedded, non-overlapping patches
    """
    def __init__(self, img_size=32, patch_size=8, in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # A convolution whose kernel and stride both equal the patch size
        # projects each patch to an embed_dim vector
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (batch, channels, H, W) -> (batch, embed_dim, H/patch, W/patch)
        projected = self.proj(x)
        # Collapse the spatial grid, then put the patch axis before the feature axis:
        # (batch, embed_dim, n_patches) -> (batch, n_patches, embed_dim)
        return projected.flatten(2).transpose(1, 2)

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a token sequence
    """
    def __init__(self, embed_dim=256, num_heads=4, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # One fused projection produces Q, K and V; `proj` mixes the heads afterwards
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        # The same dropout module is applied to the attention weights and to the output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch_size, n_tokens, embed_dim)
        batch_size, n_tokens, embed_dim = x.shape

        # Fused projection split into (3, batch_size, num_heads, n_tokens, head_dim)
        packed = self.qkv(x).reshape(batch_size, n_tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Attention weights: (batch_size, num_heads, n_tokens, n_tokens)
        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted values, heads moved next to head_dim and merged back into embed_dim
        context = (weights @ v).transpose(1, 2).contiguous()
        merged = context.reshape(batch_size, n_tokens, embed_dim)

        return self.dropout(self.proj(merged))

class MLP(nn.Module):
    """
    Two-layer feed-forward network with GELU activation and dropout
    """
    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # Shared dropout module, applied after the activation and after fc2
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Expand -> activate -> drop
        hidden = self.dropout(self.act(self.fc1(x)))
        # Project to the output width -> drop
        return self.dropout(self.fc2(hidden))

class TransformerEncoderBlock(nn.Module):
    """
    Pre-norm Transformer encoder block: self-attention and MLP, each with a residual connection
    """
    def __init__(self, embed_dim=256, num_heads=4, mlp_ratio=4, dropout=0.0):
        super().__init__()
        # Attention sub-layer, preceded by its own LayerNorm
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        # MLP sub-layer, preceded by its own LayerNorm; hidden width is embed_dim * mlp_ratio
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=hidden_width,
            out_features=embed_dim,
            dropout=dropout
        )

    def forward(self, x):
        # Normalize, attend, add back to the input
        x = x + self.attn(self.norm1(x))
        # Normalize, feed forward, add back again
        return x + self.mlp(self.norm2(x))

class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) classifier

    Images are split into patch embeddings, a learnable class token is
    prepended, learnable positional embeddings are added, and the sequence is
    passed through `depth` encoder blocks. The normalized class-token output
    feeds a linear classification head.
    """
    def __init__(
        self,
        img_size=32,
        patch_size=8,
        in_channels=3,
        num_classes=100,
        embed_dim=256,
        depth=4,
        num_heads=4,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )
        self.n_patches = self.patch_embed.n_patches

        # Learnable class token and one positional embedding per token (patches + class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))

        # Dropout applied right after the positional embedding is added
        self.pos_drop = nn.Dropout(dropout)

        # Stack of `depth` identical encoder blocks
        self.transformer_blocks = nn.ModuleList()
        for _ in range(depth):
            self.transformer_blocks.append(
                TransformerEncoderBlock(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout
                )
            )

        # Final normalization and classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for the positional embedding and the class token,
        # then the per-module initializer for every submodule
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)
        self.apply(self._init_weights_general)

    def _init_weights_general(self, m):
        # LayerNorm: zero bias, unit weight
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return
        # Linear: truncated-normal weight, zero bias when a bias exists
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # (batch_size, n_patches, embed_dim)
        tokens = self.patch_embed(x)

        # Prepend one class token per sample -> (batch_size, n_patches+1, embed_dim)
        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Positional information, then dropout
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.transformer_blocks:
            tokens = block(tokens)

        # Normalize and keep only the class-token position -> (batch_size, embed_dim)
        cls_repr = self.norm(tokens)[:, 0]

        # (batch_size, num_classes)
        return self.head(cls_repr)

def train_vit(config):
    """
    Train one Vision Transformer configuration on CIFAR-100 and collect its metrics

    Runs 20 epochs of Adam (lr 0.001, batch size 64) with cross-entropy loss.
    After every epoch evaluate_model (defined elsewhere in this file) is called
    on the test loader and its return value is recorded as that epoch's test
    accuracy.

    Args:
        config: Dict with the keys 'patch_size', 'embed_dim', 'depth',
            'num_heads' and 'mlp_ratio'

    Returns:
        dict: configuration label and values, parameter count, torchinfo
        multiply-add count, total and average-per-epoch wall-clock time,
        per-epoch test accuracies, and the last test accuracy. The timing spans
        the whole loop, so it includes the per-epoch evaluation.
    """
    # Pull the architecture settings out of the config dict
    patch_size, embed_dim, depth, num_heads, mlp_ratio = (
        config[key]
        for key in ('patch_size', 'embed_dim', 'depth', 'num_heads', 'mlp_ratio')
    )

    print(f"Training ViT (patch_size={patch_size}, embed_dim={embed_dim}, depth/transformer layers={depth}, heads={num_heads}, MLP Ratio={mlp_ratio})")

    # Fixed training settings shared by all ViT configurations
    batch_size, num_epochs, learning_rate = 64, 20, 0.001

    train_loader, test_loader = get_cifar100_loaders(batch_size)

    model = VisionTransformer(
        img_size=32,
        patch_size=patch_size,
        in_channels=3,
        num_classes=100,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=0.1
    ).to(device)

    # torchinfo reports the multiply-add count; parameters are counted directly
    model_stats = summary(model, input_size=(batch_size, 3, 32, 32), verbose=0)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"FLOPs per forward pass: {model_stats.total_mult_adds:,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    total_start_time = time.time()
    test_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        epoch_start_time = time.time()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for step, (inputs, targets) in enumerate(progress_bar, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # One optimization step
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Running training statistics shown on the progress bar
            loss_sum += loss.item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

            progress_bar.set_postfix({
                'loss': loss_sum / step,
                'acc': 100. * n_correct / n_seen
            })

        # Training-only duration of this epoch (evaluation happens afterwards)
        epoch_time = time.time() - epoch_start_time

        test_acc = evaluate_model(model, test_loader)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}/{num_epochs} - Time: {epoch_time:.2f}s - Test Acc: {test_acc:.2f}%")

    total_training_time = time.time() - total_start_time

    # Label that encodes the configuration, e.g. patch size, width, depth and heads
    config_str = f"ViT-p{patch_size}-e{embed_dim}-d{depth}-h{num_heads}"
    return {
        'model': config_str,
        'patch_size': patch_size,
        'embed_dim': embed_dim,
        'depth': depth,
        'num_heads': num_heads,
        'mlp_ratio': mlp_ratio,
        'total_params': total_params,
        'flops': model_stats.total_mult_adds,
        'total_training_time': total_training_time,
        'avg_epoch_time': total_training_time / num_epochs,
        'test_accuracies': test_accuracies,
        'final_test_accuracy': test_accuracies[-1]
    }

# PROBLEM 2: Swin Transformer Models

## Fine-Tuning Pretrained Swin Transformers

In [None]:
def fine_tune_swin(model_name, num_epochs=5):
    """
    Fine-tune a pretrained Swin Transformer model on CIFAR-100

    The checkpoint is loaded with a new 100-way classification head. Every
    parameter under `model.swin` is frozen, so only the remaining parameters
    (those with requires_grad=True) are handed to the optimizer. After every
    epoch evaluate_swin_model (defined elsewhere in this file) is called on the
    test loader and its return value is recorded as that epoch's test accuracy.

    Args:
        model_name: HuggingFace model name (e.g., "microsoft/swin-tiny-patch4-window7-224")
        num_epochs: Number of training epochs

    Returns:
        dict: Results including the model label, total and trainable parameter
        counts, torchinfo multiply-add count, total and average-per-epoch
        wall-clock time (the timing spans the whole loop, so it includes the
        per-epoch evaluation), per-epoch test accuracies, and the last test
        accuracy. The label is "Swin-<size>-pretrained", where <size> is the
        first of tiny/small/base/large found in model_name, or "unknown".
    """
    print(f"Fine-tuning {model_name} on CIFAR-100")

    # Hyperparameters
    batch_size = 32
    learning_rate = 2e-5

    # Load data (images resized by the Swin-specific loaders)
    train_loader, test_loader = get_cifar100_swin_loaders(batch_size)

    # Initialize model; ignore_mismatched_sizes lets the 100-class head replace
    # the checkpoint's classifier of a different size
    model = SwinForImageClassification.from_pretrained(
        model_name,
        num_labels=100,
        ignore_mismatched_sizes=True
    ).to(device)

    # Freeze backbone for fine-tuning
    for param in model.swin.parameters():
        param.requires_grad = False

    # Calculate parameters and FLOPs
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Get a sample input for FLOPs calculation (single 224x224 image)
    sample_input = torch.randn(1, 3, 224, 224).to(device)
    model_stats = summary(model, input_data=sample_input, verbose=0)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"FLOPs per forward pass: {model_stats.total_mult_adds:,}")

    # Loss and optimizer (only parameters that still require gradients)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

    # Training loop
    total_start_time = time.time()
    test_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        epoch_start_time = time.time()
        running_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass (the HuggingFace model returns an output object; use its logits)
            optimizer.zero_grad()
            outputs = model(inputs).logits
            loss = criterion(outputs, targets)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # Track statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix({
                'loss': running_loss/(batch_idx+1),
                'acc': 100.*correct/total
            })

        # Measure epoch time (training only; evaluation happens afterwards)
        epoch_time = time.time() - epoch_start_time

        # Test accuracy
        test_acc = evaluate_swin_model(model, test_loader)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch+1}/{num_epochs} - Time: {epoch_time:.2f}s - Test Acc: {test_acc:.2f}%")

    # Calculate total training time
    total_training_time = time.time() - total_start_time

    # Derive the size label from the checkpoint name. Checking every known
    # size avoids reporting a base/large checkpoint as "small".
    lowered_name = model_name.lower()
    model_size = next(
        (size for size in ("tiny", "small", "base", "large") if size in lowered_name),
        "unknown",
    )

    # Save results
    results = {
        'model': f"Swin-{model_size}-pretrained",
        'total_params': total_params,
        'trainable_params': trainable_params,
        'flops': model_stats.total_mult_adds,
        'total_training_time': total_training_time,
        'avg_epoch_time': total_training_time / num_epochs,
        'test_accuracies': test_accuracies,
        'final_test_accuracy': test_accuracies[-1]
    }

    return results

## Training a Swin Transformer from Scratch

In [None]:
class PatchMerging(nn.Module):
    """
    Patch merging layer: halves the spatial resolution and doubles the channel width
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """
        x: (B, H*W, C), treated as a square H x W grid with H = W = int(sqrt(H*W))
        returns: (B, H/2*W/2, 2*C)
        """
        B, L, C = x.shape
        side = int(L ** 0.5)

        # Restore the 2-D layout: (B, H, W, C)
        grid = x.view(B, side, side, C)

        # The four members of every 2x2 neighbourhood, as (row offset, column offset);
        # each slice is (B, H/2, W/2, C)
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        corners = [grid[:, row::2, col::2, :] for row, col in offsets]

        # Stack the neighbours along the channel axis and flatten the grid again:
        # (B, H/2, W/2, 4*C) -> (B, H/2*W/2, 4*C)
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)

        # Normalize, then project 4*C -> 2*C
        return self.reduction(self.norm(merged))

class WindowAttention(nn.Module):
    """
    Window-based Multi-Head Self-Attention module with dynamic window size support

    Attention is computed independently inside each window and a learned
    relative position bias is added to the attention scores. The bias table is
    sized for windows of up to `max_window_size` tokens per side, so the
    `window_size` attribute may be lowered after construction; the index into
    the table is rebuilt on the next forward pass.
    """
    def __init__(self, dim, window_size=4, num_heads=4, qkv_bias=True, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # side length of the square window (a single int)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Define projections; `dropout` is applied to the projected output only
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

        # Define relative position bias table with a maximum size so that
        # one table can serve every window size up to max_window_size
        max_window_size = 8  # Maximum window size we'll support
        self.max_window_size = max_window_size
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * max_window_size - 1) * (2 * max_window_size - 1), num_heads)
        )

        # Initialize bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Window size the current relative position index was built for
        self.current_window_size = window_size

        # Initialize relative position indices for the initial window size
        self._init_rel_pos_index()

    def _init_rel_pos_index(self):
        """
        Calculate relative position indices for the current window size

        Registers (or replaces) the `relative_position_index` buffer of shape
        (window_size**2, window_size**2) and records the window size it was
        built for in `current_window_size`.
        """
        # Build the index on the bias table's device so that a rebuild after
        # the module has been moved does not leave a CPU tensor behind
        index_device = self.relative_position_bias_table.device

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size, device=index_device)
        coords_w = torch.arange(self.window_size, device=index_device)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww

        # Calculate relative coordinates between each pair of tokens
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

        # Shift to start from 0 and flatten (row, col) offsets into one index.
        # max_window_size is used so the index matches the layout of the bias table.
        relative_coords[:, :, 0] += self.max_window_size - 1
        relative_coords[:, :, 1] += self.max_window_size - 1
        relative_coords[:, :, 0] *= 2 * self.max_window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        self.register_buffer("relative_position_index", relative_position_index)
        # Update the tracked window size
        self.current_window_size = self.window_size

    def forward(self, x, mask=None):
        """
        x: (B*num_windows, window_size*window_size, C)
        mask: (num_windows, window_size*window_size, window_size*window_size) or None

        Returns a tensor with the same shape as x.

        Raises:
            ValueError: if window_size exceeds max_window_size, because the
                bias table has no entries for such relative offsets
        """
        B_, N, C = x.shape

        # A larger window would index outside the bias table (or wrap around
        # through negative indices), so fail with a clear message instead
        if self.window_size > self.max_window_size:
            raise ValueError(
                f"window_size {self.window_size} exceeds the supported maximum "
                f"of {self.max_window_size}"
            )

        # Check if window size has changed, if so recalculate relative position index
        if self.window_size != self.current_window_size:
            self._init_rel_pos_index()

        # QKV projection: (B*num_windows, window_size*window_size, 3*dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B*num_windows, num_heads, window_size*window_size, head_dim)

        # Separate Q, K, V
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B*num_windows, num_heads, window_size*window_size, head_dim)

        # Scaled dot-product attention
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B*num_windows, num_heads, window_size*window_size, window_size*window_size)

        # Look up the bias for every (query, key) pair of the current window
        window_area = self.window_size * self.window_size
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(window_area, window_area, -1)  # window_size*window_size, window_size*window_size, num_heads

        # Permute to match attn dimensions
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # num_heads, window_size*window_size, window_size*window_size

        # Add bias to attention scores (broadcast over the window/batch axis)
        attn = attn + relative_position_bias.unsqueeze(0)

        # Apply mask if provided: the same per-window mask is added for every
        # image in the batch and every head
        if mask is not None:
            nW = mask.shape[0]  # num_windows
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        # Apply softmax
        attn = attn.softmax(dim=-1)

        # Apply attention to V
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B*num_windows, window_size*window_size, dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with improved robustness for small image sizes

    Applies window attention (optionally on a cyclically shifted feature map)
    and an MLP, each behind a LayerNorm and a residual connection. The input is
    a flattened square feature map of shape (B, H*W, C) and the output has the
    same shape. If the feature map is smaller than the window, the window (and
    shift) are shrunk in place to fit.

    Relies on window_partition and window_reverse, which are defined elsewhere
    in this file.
    """
    def __init__(self, dim, num_heads, window_size=4, shift_size=0, mlp_ratio=4., dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        # A requested shift is clamped to at most half a window; 0 disables shifting
        self.shift_size = min(shift_size, window_size // 2) if shift_size > 0 else 0
        self.mlp_ratio = mlp_ratio

        # Layer Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Window Attention
        self.attn = WindowAttention(
            dim=dim,
            window_size=window_size,
            num_heads=num_heads,
            dropout=dropout
        )

        # MLP
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            dropout=dropout
        )

        # Attention mask for SW-MSA (for shifted windows). It is a cache that is
        # filled lazily in forward(), so it is registered as non-persistent:
        # otherwise the set of state_dict keys would change after the first
        # forward pass and checkpoints would not load into a fresh model.
        self.register_buffer("attn_mask", None, persistent=False)

    def _build_attn_mask(self, H, W, device):
        """
        Build the SW-MSA attention mask for an H x W feature map.

        Regions of the shifted map are labelled with distinct ids; token pairs
        inside one window that carry different ids get -100.0 and pairs with
        the same id get 0.0.
        """
        img_mask = torch.zeros((1, H, W, 1), device=device)
        # Three bands per axis: untouched area, last window minus the shift, wrapped-around shift
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # Partition the id map into windows and flatten each window
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.reshape(-1, self.window_size * self.window_size)
        # Pairwise id difference: non-zero means the two tokens come from different regions
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """
        x: (B, H*W, C)
        """
        B, L, C = x.shape
        # Compute the side length of the feature map, assuming it's square
        H = W = int(L ** 0.5)

        # Ensure H and W are integers (L must be a perfect square)
        assert H * W == L, f"Input length {L} is not a perfect square, cannot reshape to square feature map"

        # Re-validate if current window size is appropriate for this feature map size
        if H < self.window_size:
            # Adjust window size dynamically if feature map is smaller than window_size
            # This is a safety check in case SwinTransformerStage didn't adjust it
            self.window_size = H
            self.shift_size = min(self.shift_size, self.window_size // 2)
            # WindowAttention rebuilds its relative position index when it sees the new size
            self.attn.window_size = self.window_size

        # Store shortcut connection
        shortcut = x

        # Apply first normalization
        x = self.norm1(x)

        # Reshape to (B, H, W, C) for spatial operations
        x = x.reshape(B, H, W, C)

        # Cyclic shift (for SW-MSA)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

            # Reuse the cached mask only if it still matches the number of
            # windows, the window area and the input device; checking the
            # window count alone misses a window-size change that happens to
            # keep the count unchanged.
            num_windows = (H // self.window_size) * (W // self.window_size)
            window_area = self.window_size * self.window_size
            cached_mask = self.attn_mask
            mask_is_stale = (
                cached_mask is None
                or cached_mask.size(0) != num_windows
                or cached_mask.size(-1) != window_area
                or cached_mask.device != x.device
            )
            if mask_is_stale:
                self.attn_mask = self._build_attn_mask(H, W, x.device)
        else:
            shifted_x = x
            self.attn_mask = None

        # NOTE(review): no padding is applied here, so this presumably requires
        # H and W to be divisible by window_size - confirm against window_partition.

        # Partition windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.reshape(-1, self.window_size * self.window_size, C)  # (num_windows*B, window_size*window_size, C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # (num_windows*B, window_size*window_size, C)

        # Merge windows back into a feature map
        attn_windows = attn_windows.reshape(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Reshape back to (B, H*W, C)
        x = x.reshape(B, H * W, C)

        # First residual connection
        x = shortcut + x

        # Second normalization and MLP
        x = x + self.mlp(self.norm2(x))

        return x

class SwinTransformerStage(nn.Module):
    """
    Stage of Swin Transformer with dynamic window size adjustment.

    A stage is a stack of ``depth`` SwinTransformerBlocks that all work at the
    same feature dimension. Even-indexed blocks are built with ``shift_size=0``
    and odd-indexed blocks with ``shift_size=window_size // 2``.

    At forward time the window size is clamped to the side length of the
    incoming (assumed square) token grid, and every block is kept in sync
    with that clamped value.

    Args:
        dim: Feature dimension passed to every block
        depth: Number of blocks in the stage
        num_heads: Number of attention heads passed to every block
        window_size: Requested attention window size
        mlp_ratio: MLP expansion ratio passed to every block
        dropout: Dropout rate passed to every block
    """
    def __init__(self, dim, depth, num_heads, window_size=7, mlp_ratio=4., dropout=0.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                dropout=dropout
            )
            for i in range(depth)
        ])
        self.window_size = window_size

    def _sync_window_size(self, effective_ws):
        """
        Make every block (and its attention module) use ``effective_ws``.

        Blocks that already use ``effective_ws`` are left untouched, so the
        attention module's ``_init_rel_pos_index()`` runs only when the window
        size really changes instead of on every forward pass. Comparing against
        each block's current value (rather than the stage's configured value)
        also lets the blocks grow back if a larger feature map shows up later.
        """
        for block in self.blocks:
            if block.window_size == effective_ws:
                continue

            block.window_size = effective_ws
            # Only shifted blocks get a shift; regular blocks stay at 0
            if block.shift_size > 0:
                block.shift_size = effective_ws // 2

            # propagate into the attention module
            block.attn.window_size = effective_ws
            # force it to rebuild its relative_position_index
            block.attn.current_window_size = effective_ws
            block.attn._init_rel_pos_index()

    def forward(self, x):
        """
        Run the token sequence through all blocks of the stage.

        Args:
            x: Tensor of shape (B, L, C); L is treated as a square grid of
               side int(L ** 0.5)

        Returns:
            Tensor produced by the last block (or x itself for grids of side <= 1)
        """
        _, L, _ = x.shape
        H = int(L ** 0.5)

        # if spatial dims < 2x2, just pass through
        if H <= 1:
            return x

        # the window can never be larger than the feature map itself
        effective_ws = min(self.window_size, H)
        self._sync_window_size(effective_ws)

        # now run through the blocks
        for block in self.blocks:
            x = block(x)

        return x


class BasicSwinTransformer(nn.Module):
    """
    A simplified Swin Transformer for training from scratch on CIFAR-100.
    Designed specifically to handle the small 32x32 input size of CIFAR-100.

    Pipeline: strided-conv patch embedding -> a sequence of
    SwinTransformerStage modules with a merge layer between consecutive
    stages -> LayerNorm -> average pooling over tokens -> linear classifier.

    The number of stages is trimmed so that no stage operates on a token grid
    smaller than 2x2 (a warning is printed when trimming happens).

    Args:
        img_size: Side length of the (square) input image
        patch_size: Kernel size and stride of the patch-embedding convolution
        in_channels: Number of input image channels
        num_classes: Number of output classes
        embed_dim: Feature dimension after patch embedding (doubled after every merge)
        depths: Number of blocks per stage (never mutated)
        num_heads: Number of attention heads per stage (never mutated)
        window_size: Attention window size passed to every stage
        mlp_ratio: MLP expansion ratio passed to every stage
        dropout: Dropout rate passed to every stage
    """
    def __init__(
        self,
        img_size=32,
        patch_size=2,  # Smaller patch size for 32x32 images
        in_channels=3,
        num_classes=100,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=4,  # Smaller window size for 32x32 images
        mlp_ratio=4.,
        dropout=0.1
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.window_size = window_size
        self.img_size = img_size

        # Check that window size can evenly divide the image patches
        patches_per_side = img_size // patch_size
        assert window_size <= patches_per_side, f"Window size ({window_size}) must be <= patches per side ({patches_per_side})"
        assert patches_per_side % window_size == 0, f"Patches per side ({patches_per_side}) must be divisible by window size ({window_size})"

        # Calculate the number of patches
        self.patches_resolution = patches_per_side
        self.num_patches = self.patches_resolution ** 2

        # Patch embedding
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Check how many downsamplings we can do before resolution becomes too small.
        # Every stage after the first sees half the resolution of the previous one.
        max_layers = 0
        curr_resolution = patches_per_side
        for i in range(len(depths)):
            if i > 0:  # After the first layer, we downsample
                curr_resolution = curr_resolution // 2
                if curr_resolution < 2:  # Stop if resolution gets too small
                    break
            max_layers += 1

        self.max_layers = max_layers

        # Reduce depths array if necessary to match max_layers
        # (slicing builds new lists, so the caller's/default lists stay untouched)
        if max_layers < len(depths):
            print(f"Warning: Reducing model depth from {len(depths)} to {max_layers} layers due to resolution constraints")
            depths = depths[:max_layers]
            num_heads = num_heads[:max_layers]

        # Layers
        self.stages = nn.ModuleList()
        self.patch_merging_layers = nn.ModuleList()

        # Indices of stages whose merge layer is a spatial PatchMerging
        # (as opposed to a plain channel-doubling Linear)
        self._spatial_merge_stages = set()

        # Registered fallback projections, created on demand by forward() when a
        # PatchMerging layer meets a feature map that is too small to merge.
        # Empty for correctly sized inputs, so it adds no parameters by default.
        self.small_map_projections = nn.ModuleDict()

        # Current feature resolution
        curr_resolution = self.patches_resolution

        # Feature dimension at each stage
        curr_dim = embed_dim

        # Build stages - only up to max_layers
        last_layer = len(depths) - 1
        for i_layer in range(len(depths)):
            self.stages.append(
                SwinTransformerStage(
                    dim=curr_dim,
                    depth=depths[i_layer],
                    num_heads=num_heads[i_layer],
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout
                )
            )

            # No merge layer after the last stage
            if i_layer == last_layer:
                continue

            if curr_resolution >= 2:
                # Enough resolution left: merge neighbouring patches
                merge_layer = PatchMerging(dim=curr_dim)
                curr_resolution //= 2  # Halve the resolution
                self._spatial_merge_stages.add(i_layer)
            else:
                # Too small to merge spatially: only adjust the feature dimension
                print(f"Warning: Resolution too small to merge at stage {i_layer}")
                merge_layer = nn.Linear(curr_dim, curr_dim * 2)

            curr_dim *= 2  # Double the feature dimension
            self.patch_merging_layers.append(merge_layer)

        # Final normalization and classification head
        self.norm = nn.LayerNorm(curr_dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(curr_dim, num_classes)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _small_map_projection(self, index, x):
        """
        Channel-doubling replacement for PatchMerging on feature maps of side <= 2.

        The projection for merge position ``index`` is created once, registered
        in ``self.small_map_projections`` and reused on every later call, so the
        output is deterministic and the weights are visible to ``parameters()``
        and ``state_dict()``.

        NOTE: a projection created after the optimizer was built is not part of
        that optimizer; run one forward pass on a representative input before
        constructing the optimizer if this fallback is expected to trigger.
        """
        key = str(index)
        if key not in self.small_map_projections:
            dim = x.shape[2]
            projection = nn.Linear(dim, dim * 2, device=x.device)
            self._init_weights(projection)
            self.small_map_projections[key] = projection
        return self.small_map_projections[key](x)

    def forward(self, x):
        """
        Classify a batch of images.

        Args:
            x: Tensor of shape [B, C, H, W]

        Returns:
            Tensor of shape [B, num_classes] with the class logits
        """
        # Patch embedding: [B, C, H, W] -> [B, embed_dim, H//patch_size, W//patch_size]
        x = self.patch_embed(x)

        # Reshape for transformer: [B, embed_dim, H', W'] -> [B, H'*W', embed_dim]
        x = x.flatten(2).transpose(1, 2)

        # Go through stages and patch merging layers
        for i, stage in enumerate(self.stages):
            # Apply transformer stage
            x = stage(x)  # [B, H*W, C]

            # The last stage has no merge layer behind it
            if i >= len(self.patch_merging_layers):
                continue

            # Side length of the (assumed square) token grid
            curr_size = int((x.shape[1]) ** 0.5)

            if curr_size <= 2 and i in self._spatial_merge_stages:
                # Very small feature map: skip spatial merging and only double
                # the channels with a registered (cached) linear projection
                x = self._small_map_projection(i, x)
            else:
                # Apply regular patch merging or linear layer
                x = self.patch_merging_layers[i](x)

        # Normalization
        x = self.norm(x)  # [B, H*W, C]

        # Global pooling
        x = x.transpose(1, 2)  # [B, C, H*W]
        x = self.avgpool(x).flatten(1)  # [B, C]

        # Classification head
        x = self.head(x)  # [B, num_classes]

        return x

def window_partition(x, window_size):
    """
    Cut a feature map into non-overlapping square windows.

    Windows are ordered batch-first, then row-major over the window grid.

    Args:
        x: (B, H, W, C)
        window_size: side length of each window; must divide both H and W
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    # Both spatial dimensions have to tile exactly
    assert height % window_size == 0, f"Height {height} not divisible by window size {window_size}"
    assert width % window_size == 0, f"Width {width} not divisible by window size {window_size}"

    rows = height // window_size
    cols = width // window_size

    # Split each spatial axis into (window index, offset inside window)
    grid = x.reshape(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together, then flatten them into the batch axis
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.reshape(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """
    Stitch windows back into full feature maps (inverse of window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image
    Returns:
        x: (B, H, W, C)
    """
    # The target map must be an exact tiling of windows
    assert H % window_size == 0, f"Height {H} not divisible by window size {window_size}"
    assert W % window_size == 0, f"Width {W} not divisible by window size {window_size}"

    rows = H // window_size
    cols = W // window_size

    # Each image contributes rows * cols windows
    batch = windows.shape[0] // (rows * cols)

    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) -> (B, H, W, C)
    grid = windows.reshape(batch, rows, cols, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.reshape(batch, H, W, -1)

def train_swin_from_scratch():
    """
    Train a randomly initialised BasicSwinTransformer on CIFAR-100 and evaluate it
    on the test set after every epoch.

    All parameters are trained (nothing is frozen).

    Returns:
        dict: Results with the keys 'model', 'total_params', 'trainable_params',
              'flops', 'total_training_time', 'avg_epoch_time',
              'test_accuracies' (one entry per epoch) and 'final_test_accuracy'
    """
    print("Training Swin Transformer from scratch on CIFAR-100")

    # Hyperparameters
    batch_size, num_epochs, learning_rate = 32, 5, 2e-5

    # Data
    train_loader, test_loader = get_cifar100_loaders(batch_size)

    # Model configured for CIFAR-100's small 32x32 images
    model = BasicSwinTransformer(
        img_size=32,
        patch_size=2,  # small patches for small images
        in_channels=3,
        num_classes=100,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=4,  # small windows for small images
        mlp_ratio=4.
    )
    model = model.to(device)

    # Architecture overview
    print("Swin Transformer architecture:")
    print("- Patch size: 2")
    print("- Window size: 4")
    print(f"- Patches resolution: {model.patches_resolution}")
    print(f"- Number of patches: {model.num_patches}")

    # Parameter counts (identical here because nothing is frozen)
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # FLOPs are measured on a single-image batch
    sample_input = torch.randn(1, 3, 32, 32).to(device)
    model_stats = summary(model, input_data=sample_input, verbose=0)
    print(f"Total parameters: {total_params:,}")
    print(f"FLOPs per forward pass: {model_stats.total_mult_adds:,}")

    # Loss and optimizer over every parameter of the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    test_accuracies = []
    total_start_time = time.time()

    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_start_time = time.time()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        for step, (inputs, targets) in enumerate(progress_bar, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # Running statistics shown in the progress bar
            loss_sum += loss.item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix({
                'loss': loss_sum / step,
                'acc': 100. * n_correct / n_seen
            })

        # Epoch time covers training only; evaluation below is not included
        epoch_time = time.time() - epoch_start_time

        test_acc = evaluate_model(model, test_loader)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch}/{num_epochs} - Time: {epoch_time:.2f}s - Test Acc: {test_acc:.2f}%")

    # Wall-clock time of the whole loop (training plus per-epoch evaluation)
    total_training_time = time.time() - total_start_time

    return {
        'model': 'Swin-from-scratch',
        'total_params': total_params,
        'trainable_params': trainable_params,
        'flops': model_stats.total_mult_adds,
        'total_training_time': total_training_time,
        'avg_epoch_time': total_training_time / num_epochs,
        'test_accuracies': test_accuracies,
        'final_test_accuracy': test_accuracies[-1]
    }

# Evaluating Results

In [None]:
def evaluate_model(model, test_loader):
    """
    Evaluate a model on the test set

    The model is switched to eval mode and is left in eval mode afterwards;
    callers that continue training must call ``model.train()`` again.

    Args:
        model: PyTorch model whose forward pass returns the logits tensor
        test_loader: DataLoader (or any iterable of (inputs, targets) batches)
                     for test data

    Returns:
        float: Accuracy percentage, or 0.0 if the loader yields no samples
    """
    model.eval()
    correct = 0
    total = 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError
    if total == 0:
        return 0.0

    accuracy = 100. * correct / total
    return accuracy

def evaluate_swin_model(model, test_loader):
    """
    Evaluate a Swin Transformer model from HuggingFace on the test set

    Identical to a plain evaluation loop except that the logits are read from
    the ``.logits`` attribute of the model output. The model is switched to
    eval mode and is left in eval mode afterwards.

    Args:
        model: PyTorch model whose forward pass returns an object with ``.logits``
        test_loader: DataLoader (or any iterable of (inputs, targets) batches)
                     for test data

    Returns:
        float: Accuracy percentage, or 0.0 if the loader yields no samples
    """
    model.eval()
    correct = 0
    total = 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs).logits
            # Predicted class = index of the largest logit
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError
    if total == 0:
        return 0.0

    accuracy = 100. * correct / total
    return accuracy

# Model Training Functions

In [None]:
def run_problem1_training_best():
    """
    Problem 1, short run: ResNet-18 baseline followed by four hand-picked
    ViT configurations.

    Returns:
        list: Result dicts, ResNet-18 first, then one per ViT configuration
              in the order listed below
    """
    print("Running experiments for Problem 1: ResNet-18 vs ViT")

    # The baseline comes first in the result list
    results = [train_resnet18()]

    # One tuple per ViT configuration, in the column order given by `fields`
    fields = ('patch_size', 'embed_dim', 'depth', 'num_heads', 'mlp_ratio')
    settings = (
        (4, 256, 4, 4, 4),
        (4, 512, 4, 8, 4),
        (8, 256, 4, 4, 4),
        (8, 512, 8, 8, 4),
    )

    for values in settings:
        config = dict(zip(fields, values))
        results.append(train_vit(config))

    return results

def run_problem1_training_full():
    """
    Problem 1, full sweep: ResNet-18 baseline followed by every combination of
    the ViT hyperparameter values below (2 values for each of 5 settings).

    Returns:
        list: Result dicts, ResNet-18 first, then one per ViT configuration
    """
    print("Running experiments for Problem 1: ResNet-18 vs ViT")

    # The baseline comes first in the result list
    results = [train_resnet18()]

    # Values to sweep for each ViT hyperparameter
    patch_sizes = [4, 8]
    embed_dims = [256, 512]
    depths = [4, 8]
    num_heads = [2, 4]
    mlp_ratios = [2, 4]  # 2x or 4x embedding dimension

    # Patch size varies slowest, MLP ratio fastest
    vit_configs = []
    for patch in patch_sizes:
        for embed in embed_dims:
            for depth in depths:
                for heads in num_heads:
                    for ratio in mlp_ratios:
                        vit_configs.append({
                            'patch_size': patch,
                            'embed_dim': embed,
                            'depth': depth,
                            'num_heads': heads,
                            'mlp_ratio': ratio
                        })

    results.extend(train_vit(config) for config in vit_configs)

    return results

def run_problem2_training_full():
    """
    Problem 2: fine-tune two pretrained Swin checkpoints, then train a Swin
    Transformer from scratch.

    Returns:
        list: Result dicts in the order Swin-Tiny, Swin-Small, Swin from scratch
    """
    print("Running experiments for Problem 2: Swin Transformer")

    # Pretrained checkpoints to fine-tune: Swin-Tiny first, then Swin-Small
    pretrained_checkpoints = (
        "microsoft/swin-tiny-patch4-window7-224",
        "microsoft/swin-small-patch4-window7-224",
    )
    results = [fine_tune_swin(checkpoint) for checkpoint in pretrained_checkpoints]

    # Finally the model trained from random initialisation
    results.append(train_swin_from_scratch())

    return results

# Exporting Results Functions

In [None]:
def print_results_table(results, title, csv_filename=None):
    """
    Store results in a pandas DataFrame, print the table and optionally save it as CSV.

    Args:
        results: List of result dictionaries. Each needs the keys 'model',
                 'total_params', 'total_training_time', 'avg_epoch_time' and
                 'final_test_accuracy'; 'flops' and 'test_accuracies' are optional
        title: Table title
        csv_filename: Path of the CSV file to write. Missing parent directories
                      are created. Nothing is written when this is None or empty
    """
    print(f"\n{title}")
    print("=" * 130)

    # Collect rows
    table_rows = []
    for result in results:
        test_accuracies = result.get('test_accuracies', [])
        if len(test_accuracies) >= 10:
            # Accuracy reached after exactly 10 epochs
            acc_10 = f"{test_accuracies[9]:.2f}%"
        elif test_accuracies:
            # Fewer than 10 epochs were run: report the best accuracy seen instead
            acc_10 = f"{max(test_accuracies):.2f}%"
        else:
            acc_10 = "N/A"

        table_rows.append({
            "Model": result['model'],
            "Params": f"{result['total_params']:,}",
            "FLOPs": f"{result['flops']:,}" if 'flops' in result else "N/A",
            "Training Time": f"{result['total_training_time']:.2f}s",
            "Epoch Time": f"{result['avg_epoch_time']:.2f}s",
            "Final Acc": f"{result['final_test_accuracy']:.2f}%",
            "10-Epoch Acc": acc_10
        })

    # Convert to DataFrame and print
    df = pd.DataFrame(table_rows)
    print(df.to_string(index=False))
    print("=" * 130)

    if csv_filename:
        # A bare file name has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one
        output_dir = os.path.dirname(csv_filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df.to_csv(csv_filename, index=False)
        print(f"Results saved to: {csv_filename}")

def plot_test_accuracies(results, title, filename):
    """
    Draw one test-accuracy curve per model and save the figure to disk.
    The legend is placed outside the axes, to the right.

    Args:
        results: List of result dictionaries (each with 'test_accuracies' and 'model')
        title: Plot title
        filename: Output filename
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # One curve per model; epochs are numbered from 1
    for entry in results:
        curve = entry['test_accuracies']
        ax.plot(range(1, len(curve) + 1), curve, marker='o', label=entry['model'])

    # Titles, axis labels and a light dashed grid
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Test Accuracy (%)')
    ax.grid(True, linestyle='--', alpha=0.5)

    # Fix the layout first, then attach the legend to the right of the axes
    fig.tight_layout()
    ax.legend(
        loc='center left',
        bbox_to_anchor=(1.02, 0.5),
        fontsize='small',
        title='Model',
        frameon=True
    )

    # bbox_inches='tight' enlarges the saved area so the outside legend is included
    fig.savefig(filename, bbox_inches='tight', dpi=150)
    plt.close(fig)


# Main Execution

In [None]:
def main():
    """
    Run all homework experiments in order and write a CSV table and an
    accuracy plot for each of them into the 'results' directory.
    """
    # Create output directory
    os.makedirs('results', exist_ok=True)

    # (experiment runner, table title, CSV path, plot title, plot path)
    experiments = [
        # Problem 1: ResNet-18 vs ViT Full (every combination == 32)
        (
            run_problem1_training_full,
            "Problem 1 Results: ResNet-18 vs ViT Full",
            "results/run_problem1_training_full_results.csv",
            "ResNet-18 vs ViT Test Accuracy Full",
            "results/resnet_vs_vit_full.png",
        ),
        # Problem 1: ResNet-18 vs ViT Best (theoretically ideal combinations)
        (
            run_problem1_training_best,
            "Problem 1 Results: ResNet-18 vs ViT Best",
            "results/run_problem1_training_best.csv",
            "ResNet-18 vs ViT Test Accuracy Best",
            "results/resnet_vs_vit_best.png",
        ),
        # Problem 2: Swin Transformer
        (
            run_problem2_training_full,
            "Problem 2 Results: Swin Transformer",
            "results/swin_results.csv",
            "Swin Transformer Test Accuracy",
            "results/swin_results.png",
        ),
    ]

    # Each experiment is trained, tabulated and plotted before the next one starts
    for run_experiment, table_title, csv_path, plot_title, plot_path in experiments:
        outcome = run_experiment()
        print_results_table(outcome, table_title, csv_path)
        plot_test_accuracies(outcome, plot_title, plot_path)

    print("\nHomework 6 Execution Complete")

if __name__ == "__main__":
    main()

Running experiments for Problem 1: ResNet-18 vs ViT
Training ResNet-18 baseline on CIFAR-100




Total parameters: 11,220,132
FLOPs per forward pass: 35,550,624,000


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.45it/s, loss=2.67, acc=32.1]


Epoch 1/20 - Time: 40.20s - Test Acc: 44.21%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.09it/s, loss=1.76, acc=50.9]


Epoch 2/20 - Time: 40.97s - Test Acc: 54.13%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.70it/s, loss=1.45, acc=58.7]


Epoch 3/20 - Time: 37.79s - Test Acc: 57.73%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.73it/s, loss=1.26, acc=63.2]


Epoch 4/20 - Time: 37.72s - Test Acc: 62.67%


Epoch 5/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.78it/s, loss=1.12, acc=67]


Epoch 5/20 - Time: 37.63s - Test Acc: 63.67%


Epoch 6/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.56it/s, loss=0.992, acc=70.5]


Epoch 6/20 - Time: 38.04s - Test Acc: 64.40%


Epoch 7/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.03it/s, loss=0.898, acc=72.7]


Epoch 7/20 - Time: 37.20s - Test Acc: 65.69%


Epoch 8/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.97it/s, loss=0.808, acc=75.3]


Epoch 8/20 - Time: 37.28s - Test Acc: 65.43%


Epoch 9/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=0.739, acc=77.3]


Epoch 9/20 - Time: 37.31s - Test Acc: 67.25%


Epoch 10/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=0.67, acc=79]


Epoch 10/20 - Time: 37.34s - Test Acc: 68.09%


Epoch 11/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s, loss=0.598, acc=81.2]


Epoch 11/20 - Time: 37.83s - Test Acc: 69.57%


Epoch 12/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s, loss=0.547, acc=82.7]


Epoch 12/20 - Time: 37.83s - Test Acc: 70.11%


Epoch 13/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.69it/s, loss=0.506, acc=83.8]


Epoch 13/20 - Time: 37.80s - Test Acc: 69.49%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.53it/s, loss=0.46, acc=85.3]


Epoch 14/20 - Time: 38.09s - Test Acc: 69.84%


Epoch 15/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s, loss=0.426, acc=86.3]


Epoch 15/20 - Time: 37.83s - Test Acc: 69.60%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s, loss=0.39, acc=87.4]


Epoch 16/20 - Time: 37.84s - Test Acc: 69.78%


Epoch 17/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s, loss=0.354, acc=88.4]


Epoch 17/20 - Time: 37.84s - Test Acc: 70.04%


Epoch 18/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.61it/s, loss=0.341, acc=88.8]


Epoch 18/20 - Time: 37.95s - Test Acc: 69.63%


Epoch 19/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=0.305, acc=90]


Epoch 19/20 - Time: 37.32s - Test Acc: 69.61%


Epoch 20/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=0.282, acc=90.7]


Epoch 20/20 - Time: 37.32s - Test Acc: 69.88%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=4, heads=2, MLP Ratio=2)
Total parameters: 2,164,068
FLOPs per forward pass: 187,996,416


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.91it/s, loss=4.09, acc=7.02]


Epoch 1/20 - Time: 23.06s - Test Acc: 10.25%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.07it/s, loss=3.82, acc=10.4]


Epoch 2/20 - Time: 22.95s - Test Acc: 12.74%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.28it/s, loss=3.76, acc=11.5]


Epoch 3/20 - Time: 22.81s - Test Acc: 12.31%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.93it/s, loss=3.69, acc=12.7]


Epoch 4/20 - Time: 23.06s - Test Acc: 15.43%


Epoch 5/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.18it/s, loss=3.61, acc=14]


Epoch 5/20 - Time: 22.88s - Test Acc: 15.26%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.11it/s, loss=3.56, acc=14.9]


Epoch 6/20 - Time: 22.92s - Test Acc: 17.26%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.08it/s, loss=3.5, acc=15.8]


Epoch 7/20 - Time: 22.95s - Test Acc: 17.93%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.82it/s, loss=3.46, acc=16.7]


Epoch 8/20 - Time: 23.12s - Test Acc: 18.47%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.91it/s, loss=3.44, acc=16.7]


Epoch 9/20 - Time: 23.06s - Test Acc: 19.17%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.05it/s, loss=3.43, acc=17.1]


Epoch 10/20 - Time: 22.97s - Test Acc: 18.50%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.83it/s, loss=3.41, acc=17.2]


Epoch 11/20 - Time: 23.12s - Test Acc: 19.11%


Epoch 12/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.98it/s, loss=3.37, acc=18]


Epoch 12/20 - Time: 23.01s - Test Acc: 19.60%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 34.00it/s, loss=3.38, acc=18.1]


Epoch 13/20 - Time: 23.00s - Test Acc: 21.20%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.99it/s, loss=3.36, acc=18.5]


Epoch 14/20 - Time: 23.00s - Test Acc: 21.07%


Epoch 15/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.93it/s, loss=3.37, acc=18]


Epoch 15/20 - Time: 23.05s - Test Acc: 20.86%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.98it/s, loss=3.32, acc=19.1]


Epoch 16/20 - Time: 23.03s - Test Acc: 22.11%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.99it/s, loss=3.32, acc=18.9]


Epoch 17/20 - Time: 23.01s - Test Acc: 22.23%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.59it/s, loss=3.3, acc=19.5]


Epoch 18/20 - Time: 23.28s - Test Acc: 20.57%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.52it/s, loss=3.31, acc=19.3]


Epoch 19/20 - Time: 23.33s - Test Acc: 21.25%


Epoch 20/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.45it/s, loss=3.32, acc=19]


Epoch 20/20 - Time: 23.37s - Test Acc: 21.47%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=4, heads=2, MLP Ratio=4)
Total parameters: 3,214,692
FLOPs per forward pass: 255,236,352


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.28it/s, loss=4.03, acc=7.71]


Epoch 1/20 - Time: 26.71s - Test Acc: 10.70%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.47it/s, loss=3.7, acc=12.4]


Epoch 2/20 - Time: 26.55s - Test Acc: 15.42%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.44it/s, loss=3.53, acc=15.3]


Epoch 3/20 - Time: 26.56s - Test Acc: 16.93%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.43it/s, loss=3.42, acc=16.9]


Epoch 4/20 - Time: 26.59s - Test Acc: 19.01%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.54it/s, loss=3.34, acc=18.7]


Epoch 5/20 - Time: 26.47s - Test Acc: 20.07%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.57it/s, loss=3.25, acc=20.3]


Epoch 6/20 - Time: 26.45s - Test Acc: 23.40%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.28it/s, loss=3.17, acc=21.9]


Epoch 7/20 - Time: 26.71s - Test Acc: 22.83%


Epoch 8/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.41it/s, loss=3.1, acc=22.9]


Epoch 8/20 - Time: 26.60s - Test Acc: 24.91%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.39it/s, loss=3.04, acc=24.1]


Epoch 9/20 - Time: 26.61s - Test Acc: 25.68%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.59it/s, loss=2.99, acc=25.3]


Epoch 10/20 - Time: 26.43s - Test Acc: 27.39%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.42it/s, loss=2.93, acc=26.2]


Epoch 11/20 - Time: 26.58s - Test Acc: 27.61%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.67it/s, loss=2.89, acc=27.1]


Epoch 12/20 - Time: 26.38s - Test Acc: 29.16%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.62it/s, loss=2.84, acc=27.8]


Epoch 13/20 - Time: 26.40s - Test Acc: 29.81%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.62it/s, loss=2.79, acc=28.9]


Epoch 14/20 - Time: 26.40s - Test Acc: 30.26%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.43it/s, loss=2.74, acc=29.7]


Epoch 15/20 - Time: 26.59s - Test Acc: 31.69%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.49it/s, loss=2.7, acc=30.7]


Epoch 16/20 - Time: 26.51s - Test Acc: 31.77%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.37it/s, loss=2.66, acc=31.6]


Epoch 17/20 - Time: 26.63s - Test Acc: 32.62%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.45it/s, loss=2.63, acc=31.9]


Epoch 18/20 - Time: 26.55s - Test Acc: 32.53%


Epoch 19/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.46it/s, loss=2.58, acc=33]


Epoch 19/20 - Time: 26.54s - Test Acc: 32.72%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.59it/s, loss=2.55, acc=33.7]


Epoch 20/20 - Time: 26.43s - Test Acc: 32.70%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=2)
Total parameters: 2,164,068
FLOPs per forward pass: 187,996,416


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.80it/s, loss=3.98, acc=8.26]


Epoch 1/20 - Time: 23.84s - Test Acc: 10.63%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.11it/s, loss=3.63, acc=13.4]


Epoch 2/20 - Time: 23.62s - Test Acc: 16.47%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.14it/s, loss=3.45, acc=16.8]


Epoch 3/20 - Time: 23.62s - Test Acc: 19.04%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.21it/s, loss=3.34, acc=18.8]


Epoch 4/20 - Time: 23.55s - Test Acc: 21.51%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.93it/s, loss=3.25, acc=20.3]


Epoch 5/20 - Time: 23.75s - Test Acc: 23.53%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.07it/s, loss=3.17, acc=21.9]


Epoch 6/20 - Time: 23.65s - Test Acc: 24.65%


Epoch 7/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.04it/s, loss=3.11, acc=23]


Epoch 7/20 - Time: 23.67s - Test Acc: 25.19%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.10it/s, loss=3.06, acc=23.7]


Epoch 8/20 - Time: 23.63s - Test Acc: 25.63%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.20it/s, loss=3, acc=25]


Epoch 9/20 - Time: 23.59s - Test Acc: 26.69%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.95it/s, loss=2.95, acc=25.9]


Epoch 10/20 - Time: 23.73s - Test Acc: 28.84%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.06it/s, loss=2.88, acc=27.5]


Epoch 11/20 - Time: 23.65s - Test Acc: 29.27%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.17it/s, loss=2.84, acc=28.2]


Epoch 12/20 - Time: 23.58s - Test Acc: 30.58%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.20it/s, loss=2.78, acc=29.4]


Epoch 13/20 - Time: 23.57s - Test Acc: 31.39%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.86it/s, loss=2.74, acc=30.4]


Epoch 14/20 - Time: 23.79s - Test Acc: 31.29%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.20it/s, loss=2.69, acc=31.1]


Epoch 15/20 - Time: 23.55s - Test Acc: 33.17%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.91it/s, loss=2.65, acc=31.8]


Epoch 16/20 - Time: 23.78s - Test Acc: 33.87%


Epoch 17/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.11it/s, loss=2.6, acc=32.9]


Epoch 17/20 - Time: 23.62s - Test Acc: 34.82%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.07it/s, loss=2.56, acc=33.7]


Epoch 18/20 - Time: 23.66s - Test Acc: 34.21%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.15it/s, loss=2.51, acc=34.5]


Epoch 19/20 - Time: 23.59s - Test Acc: 36.15%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.15it/s, loss=2.47, acc=35.5]


Epoch 20/20 - Time: 23.59s - Test Acc: 36.00%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 3,214,692
FLOPs per forward pass: 255,236,352


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.86it/s, loss=4.04, acc=7.34]


Epoch 1/20 - Time: 27.11s - Test Acc: 9.94%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.77it/s, loss=3.79, acc=10.9]


Epoch 2/20 - Time: 27.18s - Test Acc: 12.07%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.77it/s, loss=3.67, acc=12.6]


Epoch 3/20 - Time: 27.18s - Test Acc: 13.68%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.93it/s, loss=3.61, acc=13.9]


Epoch 4/20 - Time: 27.03s - Test Acc: 16.07%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.65it/s, loss=3.52, acc=15.5]


Epoch 5/20 - Time: 27.29s - Test Acc: 16.27%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.73it/s, loss=3.48, acc=16.3]


Epoch 6/20 - Time: 27.21s - Test Acc: 18.56%


Epoch 7/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.68it/s, loss=3.44, acc=17]


Epoch 7/20 - Time: 27.27s - Test Acc: 19.50%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.57it/s, loss=3.39, acc=17.7]


Epoch 8/20 - Time: 27.37s - Test Acc: 17.15%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.49it/s, loss=3.43, acc=17.3]


Epoch 9/20 - Time: 27.45s - Test Acc: 19.14%


Epoch 10/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.81it/s, loss=3.4, acc=18]


Epoch 10/20 - Time: 27.14s - Test Acc: 20.55%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.61it/s, loss=3.35, acc=18.6]


Epoch 11/20 - Time: 27.33s - Test Acc: 20.77%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.70it/s, loss=3.33, acc=18.9]


Epoch 12/20 - Time: 27.25s - Test Acc: 21.51%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.80it/s, loss=3.31, acc=19.3]


Epoch 13/20 - Time: 27.15s - Test Acc: 20.11%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.82it/s, loss=3.31, acc=19.3]


Epoch 14/20 - Time: 27.14s - Test Acc: 21.33%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.69it/s, loss=3.3, acc=19.6]


Epoch 15/20 - Time: 27.27s - Test Acc: 21.43%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.84it/s, loss=3.26, acc=20.2]


Epoch 16/20 - Time: 27.12s - Test Acc: 22.40%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.71it/s, loss=3.25, acc=20.4]


Epoch 17/20 - Time: 27.24s - Test Acc: 21.17%


Epoch 18/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.82it/s, loss=3.22, acc=21]


Epoch 18/20 - Time: 27.13s - Test Acc: 23.07%


Epoch 19/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.71it/s, loss=3.2, acc=21.1]


Epoch 19/20 - Time: 27.23s - Test Acc: 23.49%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.77it/s, loss=3.25, acc=20.5]


Epoch 20/20 - Time: 27.20s - Test Acc: 23.05%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=8, heads=2, MLP Ratio=2)
Total parameters: 4,272,484
FLOPs per forward pass: 322,935,040


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.93it/s, loss=4.2, acc=5.44]


Epoch 1/20 - Time: 37.39s - Test Acc: 5.90%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=4.17, acc=5.72]


Epoch 2/20 - Time: 37.33s - Test Acc: 5.30%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.90it/s, loss=4.23, acc=5.28]


Epoch 3/20 - Time: 37.41s - Test Acc: 4.64%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=4.24, acc=5.08]


Epoch 4/20 - Time: 37.21s - Test Acc: 4.17%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.92it/s, loss=4.22, acc=5.25]


Epoch 5/20 - Time: 37.39s - Test Acc: 5.84%


Epoch 6/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.90it/s, loss=4.31, acc=4.2]


Epoch 6/20 - Time: 37.42s - Test Acc: 4.50%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=4.36, acc=3.93]


Epoch 7/20 - Time: 37.32s - Test Acc: 4.17%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.91it/s, loss=4.37, acc=3.62]


Epoch 8/20 - Time: 37.41s - Test Acc: 3.43%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=4.42, acc=3.22]


Epoch 9/20 - Time: 37.33s - Test Acc: 3.52%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=4.44, acc=3.09]


Epoch 10/20 - Time: 37.33s - Test Acc: 3.38%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.91it/s, loss=4.42, acc=3.16]


Epoch 11/20 - Time: 37.39s - Test Acc: 3.31%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.95it/s, loss=4.45, acc=2.85]


Epoch 12/20 - Time: 37.33s - Test Acc: 3.60%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.01it/s, loss=4.43, acc=3.15]


Epoch 13/20 - Time: 37.21s - Test Acc: 3.64%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.93it/s, loss=4.44, acc=2.99]


Epoch 14/20 - Time: 37.36s - Test Acc: 2.77%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.99it/s, loss=4.44, acc=3.12]


Epoch 15/20 - Time: 37.27s - Test Acc: 3.43%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=4.43, acc=3.29]


Epoch 16/20 - Time: 37.31s - Test Acc: 3.44%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=4.43, acc=3.02]


Epoch 17/20 - Time: 37.32s - Test Acc: 3.45%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.04it/s, loss=4.41, acc=3.17]


Epoch 18/20 - Time: 37.18s - Test Acc: 3.74%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.01it/s, loss=4.41, acc=3.29]


Epoch 19/20 - Time: 37.22s - Test Acc: 3.60%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.06it/s, loss=4.41, acc=3.24]


Epoch 20/20 - Time: 37.14s - Test Acc: 2.68%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=8, heads=2, MLP Ratio=4)
Total parameters: 6,373,732
FLOPs per forward pass: 457,414,912


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.61it/s, loss=4.13, acc=6.12]


Epoch 1/20 - Time: 44.43s - Test Acc: 8.64%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.63it/s, loss=3.91, acc=8.99]


Epoch 2/20 - Time: 44.35s - Test Acc: 8.24%


Epoch 3/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.63it/s, loss=3.9, acc=9.4]


Epoch 3/20 - Time: 44.37s - Test Acc: 8.48%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.63it/s, loss=4.03, acc=7.57]


Epoch 4/20 - Time: 44.35s - Test Acc: 8.35%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.64it/s, loss=4.25, acc=5.01]


Epoch 5/20 - Time: 44.33s - Test Acc: 5.15%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.62it/s, loss=4.32, acc=4.22]


Epoch 6/20 - Time: 44.41s - Test Acc: 4.99%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.64it/s, loss=4.27, acc=4.77]


Epoch 7/20 - Time: 44.34s - Test Acc: 5.56%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.68it/s, loss=4.27, acc=4.83]


Epoch 8/20 - Time: 44.23s - Test Acc: 4.15%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.68it/s, loss=4.28, acc=4.76]


Epoch 9/20 - Time: 44.23s - Test Acc: 4.59%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.65it/s, loss=4.34, acc=4.08]


Epoch 10/20 - Time: 44.31s - Test Acc: 5.76%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.60it/s, loss=4.29, acc=4.96]


Epoch 11/20 - Time: 44.43s - Test Acc: 5.77%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.64it/s, loss=4.31, acc=4.46]


Epoch 12/20 - Time: 44.34s - Test Acc: 4.91%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.63it/s, loss=4.32, acc=4.46]


Epoch 13/20 - Time: 44.37s - Test Acc: 5.28%


Epoch 14/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.61it/s, loss=4.3, acc=4.53]


Epoch 14/20 - Time: 44.40s - Test Acc: 4.43%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.61it/s, loss=4.32, acc=4.18]


Epoch 15/20 - Time: 44.41s - Test Acc: 4.37%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.63it/s, loss=4.34, acc=4.36]


Epoch 16/20 - Time: 44.35s - Test Acc: 3.58%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.59it/s, loss=4.32, acc=4.29]


Epoch 17/20 - Time: 44.45s - Test Acc: 5.19%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:45<00:00, 17.29it/s, loss=4.28, acc=4.79]


Epoch 18/20 - Time: 45.23s - Test Acc: 5.32%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.55it/s, loss=4.27, acc=4.84]


Epoch 19/20 - Time: 44.57s - Test Acc: 5.64%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.65it/s, loss=4.31, acc=4.35]


Epoch 20/20 - Time: 44.31s - Test Acc: 4.39%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=8, heads=4, MLP Ratio=2)
Total parameters: 4,272,484
FLOPs per forward pass: 322,935,040


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.19it/s, loss=4.1, acc=6.82]


Epoch 1/20 - Time: 38.74s - Test Acc: 8.62%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.26it/s, loss=3.89, acc=9.46]


Epoch 2/20 - Time: 38.60s - Test Acc: 9.08%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.15it/s, loss=3.91, acc=9.05]


Epoch 3/20 - Time: 38.82s - Test Acc: 9.78%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.08it/s, loss=3.92, acc=9.07]


Epoch 4/20 - Time: 38.95s - Test Acc: 9.18%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.14it/s, loss=4.03, acc=7.62]


Epoch 5/20 - Time: 38.83s - Test Acc: 8.27%


Epoch 6/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.10it/s, loss=4, acc=7.86]


Epoch 6/20 - Time: 38.91s - Test Acc: 8.24%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.14it/s, loss=4.04, acc=7.55]


Epoch 7/20 - Time: 38.85s - Test Acc: 8.50%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.15it/s, loss=4.12, acc=6.52]


Epoch 8/20 - Time: 38.80s - Test Acc: 7.25%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.17it/s, loss=4.18, acc=5.8]


Epoch 9/20 - Time: 38.77s - Test Acc: 6.81%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.11it/s, loss=4.19, acc=5.86]


Epoch 10/20 - Time: 38.88s - Test Acc: 6.17%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.07it/s, loss=4.17, acc=6.01]


Epoch 11/20 - Time: 38.97s - Test Acc: 6.76%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.16it/s, loss=4.16, acc=6.11]


Epoch 12/20 - Time: 38.79s - Test Acc: 6.69%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.15it/s, loss=4.21, acc=5.46]


Epoch 13/20 - Time: 38.81s - Test Acc: 5.15%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.16it/s, loss=4.24, acc=5.15]


Epoch 14/20 - Time: 38.80s - Test Acc: 6.10%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.14it/s, loss=4.24, acc=4.95]


Epoch 15/20 - Time: 38.83s - Test Acc: 5.76%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.09it/s, loss=4.21, acc=5.25]


Epoch 16/20 - Time: 38.93s - Test Acc: 5.37%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.09it/s, loss=4.26, acc=4.83]


Epoch 17/20 - Time: 38.92s - Test Acc: 5.34%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.16it/s, loss=4.24, acc=5.12]


Epoch 18/20 - Time: 38.79s - Test Acc: 6.20%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.06it/s, loss=4.23, acc=5.29]


Epoch 19/20 - Time: 39.01s - Test Acc: 5.72%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.21it/s, loss=4.26, acc=4.93]


Epoch 20/20 - Time: 38.69s - Test Acc: 5.11%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=8, heads=4, MLP Ratio=4)
Total parameters: 6,373,732
FLOPs per forward pass: 457,414,912


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.98it/s, loss=4.14, acc=6.21]


Epoch 1/20 - Time: 46.05s - Test Acc: 8.57%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.95it/s, loss=3.9, acc=9.15]


Epoch 2/20 - Time: 46.12s - Test Acc: 10.47%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.92it/s, loss=3.95, acc=8.83]


Epoch 3/20 - Time: 46.21s - Test Acc: 9.78%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:45<00:00, 17.00it/s, loss=3.83, acc=10.4]


Epoch 4/20 - Time: 46.00s - Test Acc: 9.69%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.98it/s, loss=3.85, acc=10.2]


Epoch 5/20 - Time: 46.07s - Test Acc: 10.74%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.97it/s, loss=3.84, acc=10.5]


Epoch 6/20 - Time: 46.09s - Test Acc: 11.08%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.99it/s, loss=3.84, acc=10.4]


Epoch 7/20 - Time: 46.01s - Test Acc: 11.66%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.98it/s, loss=3.78, acc=11.4]


Epoch 8/20 - Time: 46.05s - Test Acc: 11.12%


Epoch 9/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.96it/s, loss=3.81, acc=11]


Epoch 9/20 - Time: 46.11s - Test Acc: 11.46%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.83it/s, loss=3.82, acc=10.8]


Epoch 10/20 - Time: 46.48s - Test Acc: 10.72%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.79it/s, loss=3.84, acc=10.3]


Epoch 11/20 - Time: 46.59s - Test Acc: 11.57%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.97it/s, loss=3.83, acc=10.6]


Epoch 12/20 - Time: 46.07s - Test Acc: 11.06%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.97it/s, loss=3.84, acc=10.3]


Epoch 13/20 - Time: 46.08s - Test Acc: 11.44%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.92it/s, loss=3.81, acc=11.2]


Epoch 14/20 - Time: 46.22s - Test Acc: 12.00%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.97it/s, loss=3.81, acc=10.8]


Epoch 15/20 - Time: 46.09s - Test Acc: 11.37%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.89it/s, loss=3.94, acc=9.33]


Epoch 16/20 - Time: 46.29s - Test Acc: 9.69%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.93it/s, loss=3.93, acc=9.16]


Epoch 17/20 - Time: 46.19s - Test Acc: 8.30%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.91it/s, loss=3.94, acc=8.97]


Epoch 18/20 - Time: 46.24s - Test Acc: 9.21%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.90it/s, loss=3.96, acc=8.94]


Epoch 19/20 - Time: 46.27s - Test Acc: 9.58%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:46<00:00, 16.97it/s, loss=3.96, acc=8.94]


Epoch 20/20 - Time: 46.09s - Test Acc: 9.86%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=4, heads=2, MLP Ratio=2)
Total parameters: 8,522,340
FLOPs per forward pass: 644,421,888


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.06it/s, loss=4.41, acc=3.71]


Epoch 1/20 - Time: 43.31s - Test Acc: 4.71%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.07it/s, loss=4.38, acc=3.81]


Epoch 2/20 - Time: 43.28s - Test Acc: 4.42%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.09it/s, loss=4.31, acc=4.39]


Epoch 3/20 - Time: 43.23s - Test Acc: 4.67%


Epoch 4/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.10it/s, loss=4.3, acc=4.68]


Epoch 4/20 - Time: 43.21s - Test Acc: 4.01%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.10it/s, loss=4.31, acc=4.49]


Epoch 5/20 - Time: 43.21s - Test Acc: 3.33%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.11it/s, loss=4.31, acc=4.52]


Epoch 6/20 - Time: 43.19s - Test Acc: 5.22%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.14it/s, loss=4.28, acc=4.92]


Epoch 7/20 - Time: 43.12s - Test Acc: 4.80%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.13it/s, loss=4.29, acc=4.84]


Epoch 8/20 - Time: 43.13s - Test Acc: 5.00%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.13it/s, loss=4.36, acc=3.98]


Epoch 9/20 - Time: 43.13s - Test Acc: 3.72%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.12it/s, loss=4.38, acc=3.84]


Epoch 10/20 - Time: 43.17s - Test Acc: 4.44%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.14it/s, loss=4.32, acc=4.31]


Epoch 11/20 - Time: 43.12s - Test Acc: 4.35%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.14it/s, loss=4.32, acc=4.41]


Epoch 12/20 - Time: 43.10s - Test Acc: 5.03%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.16it/s, loss=4.34, acc=4.29]


Epoch 13/20 - Time: 43.07s - Test Acc: 4.55%


Epoch 14/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.16it/s, loss=4.3, acc=4.54]


Epoch 14/20 - Time: 43.05s - Test Acc: 4.93%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.14it/s, loss=4.27, acc=4.89]


Epoch 15/20 - Time: 43.12s - Test Acc: 4.86%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.18it/s, loss=4.27, acc=5.19]


Epoch 16/20 - Time: 43.02s - Test Acc: 5.51%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.16it/s, loss=4.24, acc=5.37]


Epoch 17/20 - Time: 43.08s - Test Acc: 5.74%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.12it/s, loss=4.23, acc=5.37]


Epoch 18/20 - Time: 43.15s - Test Acc: 5.20%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.13it/s, loss=4.24, acc=5.27]


Epoch 19/20 - Time: 43.14s - Test Acc: 5.47%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.14it/s, loss=4.26, acc=5.2]


Epoch 20/20 - Time: 43.10s - Test Acc: 5.47%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=4, heads=2, MLP Ratio=4)
Total parameters: 12,720,740
FLOPs per forward pass: 913,119,488


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.04it/s, loss=4.37, acc=3.94]


Epoch 1/20 - Time: 55.69s - Test Acc: 3.93%


Epoch 2/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.05it/s, loss=4.35, acc=4]


Epoch 2/20 - Time: 55.66s - Test Acc: 4.56%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.04it/s, loss=4.31, acc=4.45]


Epoch 3/20 - Time: 55.70s - Test Acc: 4.67%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.08it/s, loss=4.34, acc=4.19]


Epoch 4/20 - Time: 55.53s - Test Acc: 4.26%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.08it/s, loss=4.35, acc=4.26]


Epoch 5/20 - Time: 55.55s - Test Acc: 4.94%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.10it/s, loss=4.34, acc=4.24]


Epoch 6/20 - Time: 55.48s - Test Acc: 4.52%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.14it/s, loss=4.35, acc=4.03]


Epoch 7/20 - Time: 55.32s - Test Acc: 4.83%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.29, acc=4.72]


Epoch 8/20 - Time: 55.41s - Test Acc: 4.98%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.3, acc=4.58]


Epoch 9/20 - Time: 55.35s - Test Acc: 4.73%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.31, acc=4.61]


Epoch 10/20 - Time: 55.38s - Test Acc: 4.78%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.31, acc=4.78]


Epoch 11/20 - Time: 55.36s - Test Acc: 5.01%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.29, acc=5.04]


Epoch 12/20 - Time: 55.37s - Test Acc: 5.86%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.25, acc=5.34]


Epoch 13/20 - Time: 55.35s - Test Acc: 6.20%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.24, acc=5.34]


Epoch 14/20 - Time: 55.38s - Test Acc: 5.83%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.27, acc=5.28]


Epoch 15/20 - Time: 55.35s - Test Acc: 5.34%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.31, acc=4.68]


Epoch 16/20 - Time: 55.37s - Test Acc: 4.72%


Epoch 17/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.09it/s, loss=4.3, acc=4.72]


Epoch 17/20 - Time: 55.49s - Test Acc: 5.45%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.27, acc=5.06]


Epoch 18/20 - Time: 55.38s - Test Acc: 5.00%


Epoch 19/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.3, acc=4.74]


Epoch 19/20 - Time: 55.40s - Test Acc: 5.06%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.26, acc=5.1]


Epoch 20/20 - Time: 55.43s - Test Acc: 5.44%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=4, heads=4, MLP Ratio=2)
Total parameters: 8,522,340
FLOPs per forward pass: 644,421,888


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.84it/s, loss=4.29, acc=4.59]


Epoch 1/20 - Time: 43.82s - Test Acc: 5.05%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.94it/s, loss=4.28, acc=4.81]


Epoch 2/20 - Time: 43.60s - Test Acc: 5.96%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.93it/s, loss=4.23, acc=5.18]


Epoch 3/20 - Time: 43.63s - Test Acc: 6.26%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.99it/s, loss=4.17, acc=5.91]


Epoch 4/20 - Time: 43.47s - Test Acc: 6.61%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.96it/s, loss=4.22, acc=5.23]


Epoch 5/20 - Time: 43.55s - Test Acc: 5.80%


Epoch 6/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.96it/s, loss=4.22, acc=5.3]


Epoch 6/20 - Time: 43.54s - Test Acc: 5.27%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.98it/s, loss=4.28, acc=4.5]


Epoch 7/20 - Time: 43.49s - Test Acc: 5.38%


Epoch 8/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.02it/s, loss=4.3, acc=4.55]


Epoch 8/20 - Time: 43.39s - Test Acc: 5.21%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.01it/s, loss=4.27, acc=4.86]


Epoch 9/20 - Time: 43.43s - Test Acc: 5.06%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.00it/s, loss=4.27, acc=4.79]


Epoch 10/20 - Time: 43.44s - Test Acc: 5.34%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.00it/s, loss=4.25, acc=4.97]


Epoch 11/20 - Time: 43.45s - Test Acc: 5.47%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.01it/s, loss=4.26, acc=5.01]


Epoch 12/20 - Time: 43.42s - Test Acc: 5.67%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.00it/s, loss=4.22, acc=5.47]


Epoch 13/20 - Time: 43.44s - Test Acc: 5.86%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.01it/s, loss=4.21, acc=5.54]


Epoch 14/20 - Time: 43.43s - Test Acc: 6.00%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.02it/s, loss=4.24, acc=5.26]


Epoch 15/20 - Time: 43.42s - Test Acc: 5.25%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.99it/s, loss=4.24, acc=5.23]


Epoch 16/20 - Time: 43.47s - Test Acc: 5.57%


Epoch 17/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.01it/s, loss=4.23, acc=5.3]


Epoch 17/20 - Time: 43.43s - Test Acc: 5.84%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.04it/s, loss=4.24, acc=5.34]


Epoch 18/20 - Time: 43.35s - Test Acc: 5.31%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.98it/s, loss=4.25, acc=5.17]


Epoch 19/20 - Time: 43.49s - Test Acc: 5.84%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 17.99it/s, loss=4.24, acc=5.41]


Epoch 20/20 - Time: 43.48s - Test Acc: 4.98%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 12,720,740
FLOPs per forward pass: 913,119,488


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.03it/s, loss=4.27, acc=4.79]


Epoch 1/20 - Time: 55.73s - Test Acc: 4.94%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.00it/s, loss=4.29, acc=4.39]


Epoch 2/20 - Time: 55.87s - Test Acc: 4.69%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.02it/s, loss=4.32, acc=4.19]


Epoch 3/20 - Time: 55.79s - Test Acc: 3.71%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.01it/s, loss=4.29, acc=4.54]


Epoch 4/20 - Time: 55.81s - Test Acc: 5.03%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.07it/s, loss=4.25, acc=5.12]


Epoch 5/20 - Time: 55.56s - Test Acc: 4.99%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.08it/s, loss=4.25, acc=5.27]


Epoch 6/20 - Time: 55.53s - Test Acc: 5.32%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.09it/s, loss=4.26, acc=4.91]


Epoch 7/20 - Time: 55.50s - Test Acc: 5.66%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.07it/s, loss=4.26, acc=4.89]


Epoch 8/20 - Time: 55.59s - Test Acc: 4.69%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.07it/s, loss=4.32, acc=4.26]


Epoch 9/20 - Time: 55.58s - Test Acc: 4.17%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.09it/s, loss=4.29, acc=4.55]


Epoch 10/20 - Time: 55.49s - Test Acc: 5.53%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.27, acc=4.89]


Epoch 11/20 - Time: 55.44s - Test Acc: 4.80%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.09it/s, loss=4.3, acc=4.46]


Epoch 12/20 - Time: 55.50s - Test Acc: 5.36%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.26, acc=5.11]


Epoch 13/20 - Time: 55.44s - Test Acc: 5.44%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.25, acc=5.06]


Epoch 14/20 - Time: 55.39s - Test Acc: 5.56%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.08it/s, loss=4.29, acc=4.6]


Epoch 15/20 - Time: 55.54s - Test Acc: 4.69%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.28, acc=4.67]


Epoch 16/20 - Time: 55.44s - Test Acc: 5.19%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.28, acc=4.59]


Epoch 17/20 - Time: 55.34s - Test Acc: 5.44%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.12it/s, loss=4.32, acc=4.26]


Epoch 18/20 - Time: 55.37s - Test Acc: 3.67%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.11it/s, loss=4.32, acc=4.28]


Epoch 19/20 - Time: 55.42s - Test Acc: 4.59%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:55<00:00, 14.13it/s, loss=4.3, acc=4.46]


Epoch 20/20 - Time: 55.33s - Test Acc: 4.83%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=8, heads=2, MLP Ratio=2)
Total parameters: 16,933,476
FLOPs per forward pass: 1,182,734,592


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.94it/s, loss=4.48, acc=3.01]


Epoch 1/20 - Time: 78.64s - Test Acc: 2.70%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.95it/s, loss=4.46, acc=2.81]


Epoch 2/20 - Time: 78.58s - Test Acc: 3.20%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.96it/s, loss=4.47, acc=2.69]


Epoch 3/20 - Time: 78.50s - Test Acc: 2.81%


Epoch 4/20: 100%|████████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.97it/s, loss=4.48, acc=2.6]


Epoch 4/20 - Time: 78.41s - Test Acc: 2.94%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.98it/s, loss=4.48, acc=2.54]


Epoch 5/20 - Time: 78.33s - Test Acc: 1.99%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.51, acc=2.13]


Epoch 6/20 - Time: 78.13s - Test Acc: 2.26%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.49, acc=2.31]


Epoch 7/20 - Time: 78.10s - Test Acc: 2.36%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.47, acc=2.51]


Epoch 8/20 - Time: 78.11s - Test Acc: 2.70%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.47, acc=2.52]


Epoch 9/20 - Time: 78.13s - Test Acc: 2.74%


Epoch 10/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.46, acc=2.6]


Epoch 10/20 - Time: 78.10s - Test Acc: 2.44%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.00it/s, loss=4.46, acc=2.71]


Epoch 11/20 - Time: 78.23s - Test Acc: 2.47%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.44, acc=2.64]


Epoch 12/20 - Time: 78.07s - Test Acc: 2.27%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.00it/s, loss=4.46, acc=2.53]


Epoch 13/20 - Time: 78.19s - Test Acc: 2.58%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.53, acc=2.19]


Epoch 14/20 - Time: 78.08s - Test Acc: 1.73%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:17<00:00, 10.05it/s, loss=4.55, acc=1.99]


Epoch 15/20 - Time: 77.82s - Test Acc: 1.82%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.00it/s, loss=4.54, acc=2.13]


Epoch 16/20 - Time: 78.20s - Test Acc: 1.84%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:17<00:00, 10.03it/s, loss=4.55, acc=2.15]


Epoch 17/20 - Time: 77.96s - Test Acc: 1.57%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:17<00:00, 10.03it/s, loss=4.56, acc=2.02]


Epoch 18/20 - Time: 77.99s - Test Acc: 1.58%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.57, acc=1.99]


Epoch 19/20 - Time: 78.05s - Test Acc: 1.67%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.55, acc=2.02]


Epoch 20/20 - Time: 78.12s - Test Acc: 1.99%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=8, heads=2, MLP Ratio=4)
Total parameters: 25,330,276
FLOPs per forward pass: 1,720,129,792


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.55it/s, loss=4.53, acc=2.25]


Epoch 1/20 - Time: 103.57s - Test Acc: 2.80%


Epoch 2/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.56it/s, loss=4.5, acc=2.2]


Epoch 2/20 - Time: 103.44s - Test Acc: 2.40%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.51, acc=2.04]


Epoch 3/20 - Time: 103.37s - Test Acc: 2.05%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.47, acc=2.52]


Epoch 4/20 - Time: 103.25s - Test Acc: 2.54%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.58it/s, loss=4.46, acc=2.42]


Epoch 5/20 - Time: 103.14s - Test Acc: 2.26%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.59it/s, loss=4.46, acc=2.55]


Epoch 6/20 - Time: 103.03s - Test Acc: 2.21%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.59it/s, loss=4.51, acc=2.13]


Epoch 7/20 - Time: 103.10s - Test Acc: 2.14%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.59it/s, loss=4.52, acc=2.15]


Epoch 8/20 - Time: 103.07s - Test Acc: 2.18%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.59it/s, loss=4.52, acc=2.06]


Epoch 9/20 - Time: 102.98s - Test Acc: 2.50%


Epoch 10/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.52, acc=2.1]


Epoch 10/20 - Time: 102.96s - Test Acc: 2.83%


Epoch 11/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.5, acc=2.38]


Epoch 11/20 - Time: 102.85s - Test Acc: 2.45%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.48, acc=2.4]


Epoch 12/20 - Time: 102.88s - Test Acc: 2.81%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.48, acc=2.46]


Epoch 13/20 - Time: 102.86s - Test Acc: 3.01%


Epoch 14/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.5, acc=2.25]


Epoch 14/20 - Time: 102.92s - Test Acc: 2.62%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.61it/s, loss=4.49, acc=2.42]


Epoch 15/20 - Time: 102.81s - Test Acc: 3.06%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.61it/s, loss=4.48, acc=2.43]


Epoch 16/20 - Time: 102.73s - Test Acc: 3.04%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.62it/s, loss=4.49, acc=2.26]


Epoch 17/20 - Time: 102.60s - Test Acc: 2.87%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.61it/s, loss=4.48, acc=2.2]


Epoch 18/20 - Time: 102.80s - Test Acc: 2.72%


Epoch 19/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.62it/s, loss=4.5, acc=2.18]


Epoch 19/20 - Time: 102.66s - Test Acc: 2.30%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.61it/s, loss=4.5, acc=2.12]


Epoch 20/20 - Time: 102.82s - Test Acc: 2.81%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=8, heads=4, MLP Ratio=2)
Total parameters: 16,933,476
FLOPs per forward pass: 1,182,734,592


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:19<00:00,  9.89it/s, loss=4.44, acc=3.07]


Epoch 1/20 - Time: 79.08s - Test Acc: 4.58%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.93it/s, loss=4.4, acc=3.23]


Epoch 2/20 - Time: 78.73s - Test Acc: 3.38%


Epoch 3/20: 100%|████████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.94it/s, loss=4.4, acc=3.11]


Epoch 3/20 - Time: 78.66s - Test Acc: 3.04%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.94it/s, loss=4.42, acc=3.07]


Epoch 4/20 - Time: 78.67s - Test Acc: 2.72%


Epoch 5/20: 100%|████████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.96it/s, loss=4.41, acc=3.1]


Epoch 5/20 - Time: 78.51s - Test Acc: 3.47%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.98it/s, loss=4.42, acc=3.01]


Epoch 6/20 - Time: 78.37s - Test Acc: 2.69%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.98it/s, loss=4.42, acc=2.93]


Epoch 7/20 - Time: 78.37s - Test Acc: 3.21%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.99it/s, loss=4.44, acc=2.75]


Epoch 8/20 - Time: 78.26s - Test Acc: 2.58%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:18<00:00,  9.99it/s, loss=4.44, acc=2.69]


Epoch 9/20 - Time: 78.25s - Test Acc: 3.17%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.44, acc=2.81]


Epoch 10/20 - Time: 78.13s - Test Acc: 2.66%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.00it/s, loss=4.44, acc=2.84]


Epoch 11/20 - Time: 78.24s - Test Acc: 2.79%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.43, acc=3.01]


Epoch 12/20 - Time: 78.15s - Test Acc: 3.15%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.00it/s, loss=4.46, acc=2.63]


Epoch 13/20 - Time: 78.18s - Test Acc: 2.91%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.43, acc=3.01]


Epoch 14/20 - Time: 78.11s - Test Acc: 3.03%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.43, acc=3.02]


Epoch 15/20 - Time: 78.08s - Test Acc: 2.98%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.43, acc=3.05]


Epoch 16/20 - Time: 78.09s - Test Acc: 2.78%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.01it/s, loss=4.43, acc=3.16]


Epoch 17/20 - Time: 78.12s - Test Acc: 2.89%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.42, acc=3.04]


Epoch 18/20 - Time: 78.05s - Test Acc: 3.20%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.42, acc=3.39]


Epoch 19/20 - Time: 78.05s - Test Acc: 3.43%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:18<00:00, 10.02it/s, loss=4.41, acc=3.37]


Epoch 20/20 - Time: 78.05s - Test Acc: 3.37%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=8, heads=4, MLP Ratio=4)
Total parameters: 25,330,276
FLOPs per forward pass: 1,720,129,792


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [01:44<00:00,  7.52it/s, loss=4.4, acc=3.35]


Epoch 1/20 - Time: 104.05s - Test Acc: 4.10%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:44<00:00,  7.51it/s, loss=4.43, acc=2.89]


Epoch 2/20 - Time: 104.08s - Test Acc: 4.12%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.52it/s, loss=4.35, acc=3.46]


Epoch 3/20 - Time: 103.96s - Test Acc: 4.13%


Epoch 4/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.52it/s, loss=4.4, acc=3.3]


Epoch 4/20 - Time: 103.96s - Test Acc: 3.21%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.54it/s, loss=4.41, acc=3.09]


Epoch 5/20 - Time: 103.77s - Test Acc: 2.67%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.55it/s, loss=4.41, acc=3.05]


Epoch 6/20 - Time: 103.58s - Test Acc: 3.49%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.56it/s, loss=4.42, acc=2.89]


Epoch 7/20 - Time: 103.48s - Test Acc: 3.31%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.55it/s, loss=4.43, acc=2.73]


Epoch 8/20 - Time: 103.64s - Test Acc: 2.55%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.55it/s, loss=4.45, acc=2.51]


Epoch 9/20 - Time: 103.59s - Test Acc: 2.84%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.55it/s, loss=4.42, acc=2.77]


Epoch 10/20 - Time: 103.52s - Test Acc: 2.73%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.56it/s, loss=4.42, acc=2.93]


Epoch 11/20 - Time: 103.44s - Test Acc: 3.07%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.56it/s, loss=4.4, acc=3.15]


Epoch 12/20 - Time: 103.44s - Test Acc: 3.08%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.41, acc=3.29]


Epoch 13/20 - Time: 103.27s - Test Acc: 2.94%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.41, acc=3.06]


Epoch 14/20 - Time: 103.25s - Test Acc: 2.44%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.41, acc=3.14]


Epoch 15/20 - Time: 103.30s - Test Acc: 2.67%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.58it/s, loss=4.4, acc=3.26]


Epoch 16/20 - Time: 103.12s - Test Acc: 2.80%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.57it/s, loss=4.39, acc=3.32]


Epoch 17/20 - Time: 103.29s - Test Acc: 3.17%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:43<00:00,  7.59it/s, loss=4.39, acc=3.41]


Epoch 18/20 - Time: 103.03s - Test Acc: 2.87%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.59it/s, loss=4.39, acc=3.42]


Epoch 19/20 - Time: 102.99s - Test Acc: 2.97%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [01:42<00:00,  7.60it/s, loss=4.38, acc=3.51]


Epoch 20/20 - Time: 102.89s - Test Acc: 2.94%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=4, heads=2, MLP Ratio=2)
Total parameters: 2,188,644
FLOPs per forward pass: 187,209,984


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.79it/s, loss=4.22, acc=5.44]


Epoch 1/20 - Time: 15.10s - Test Acc: 6.42%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.44it/s, loss=4.14, acc=6.49]


Epoch 2/20 - Time: 15.20s - Test Acc: 7.75%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.20it/s, loss=4.06, acc=7.19]


Epoch 3/20 - Time: 15.00s - Test Acc: 8.00%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.32it/s, loss=4.04, acc=7.61]


Epoch 4/20 - Time: 14.96s - Test Acc: 8.30%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.86it/s, loss=4.02, acc=8.02]


Epoch 5/20 - Time: 15.08s - Test Acc: 9.12%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.49it/s, loss=4.01, acc=8.04]


Epoch 6/20 - Time: 15.19s - Test Acc: 8.48%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.68it/s, loss=4.05, acc=7.7]


Epoch 7/20 - Time: 15.13s - Test Acc: 8.61%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.74it/s, loss=4.02, acc=8.08]


Epoch 8/20 - Time: 15.11s - Test Acc: 9.02%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.31it/s, loss=3.99, acc=8.51]


Epoch 9/20 - Time: 15.24s - Test Acc: 8.43%


Epoch 10/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.10it/s, loss=4.14, acc=6.8]


Epoch 10/20 - Time: 15.30s - Test Acc: 7.71%


Epoch 11/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.58it/s, loss=4.07, acc=7.6]


Epoch 11/20 - Time: 15.16s - Test Acc: 8.53%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.12it/s, loss=4.09, acc=7.07]


Epoch 12/20 - Time: 15.30s - Test Acc: 7.23%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.63it/s, loss=4.18, acc=5.97]


Epoch 13/20 - Time: 15.16s - Test Acc: 5.53%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.17it/s, loss=4.18, acc=5.97]


Epoch 14/20 - Time: 14.99s - Test Acc: 7.32%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.84it/s, loss=4.13, acc=6.87]


Epoch 15/20 - Time: 15.12s - Test Acc: 7.79%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.12it/s, loss=4.15, acc=6.44]


Epoch 16/20 - Time: 15.30s - Test Acc: 6.65%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.70it/s, loss=4.15, acc=6.36]


Epoch 17/20 - Time: 15.13s - Test Acc: 7.68%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.90it/s, loss=4.11, acc=7.06]


Epoch 18/20 - Time: 15.08s - Test Acc: 7.39%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.73it/s, loss=4.11, acc=7.12]


Epoch 19/20 - Time: 15.12s - Test Acc: 6.13%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 50.83it/s, loss=4.16, acc=6.28]


Epoch 20/20 - Time: 15.39s - Test Acc: 6.84%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=4, heads=2, MLP Ratio=4)
Total parameters: 3,239,268
FLOPs per forward pass: 254,449,920


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.03it/s, loss=4.28, acc=4.72]


Epoch 1/20 - Time: 16.28s - Test Acc: 6.67%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.12it/s, loss=4.12, acc=6.46]


Epoch 2/20 - Time: 16.25s - Test Acc: 8.34%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.24it/s, loss=4.11, acc=6.87]


Epoch 3/20 - Time: 16.21s - Test Acc: 6.94%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.57it/s, loss=4.05, acc=7.52]


Epoch 4/20 - Time: 16.10s - Test Acc: 7.64%


Epoch 5/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.49it/s, loss=4.09, acc=6.9]


Epoch 5/20 - Time: 16.13s - Test Acc: 7.67%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.41it/s, loss=4.12, acc=6.72]


Epoch 6/20 - Time: 16.17s - Test Acc: 7.06%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.89it/s, loss=4.13, acc=6.64]


Epoch 7/20 - Time: 16.33s - Test Acc: 6.13%


Epoch 8/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.12it/s, loss=4.15, acc=6.3]


Epoch 8/20 - Time: 16.25s - Test Acc: 6.87%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.59it/s, loss=4.11, acc=6.87]


Epoch 9/20 - Time: 16.09s - Test Acc: 7.60%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.60it/s, loss=4.17, acc=6.15]


Epoch 10/20 - Time: 16.09s - Test Acc: 6.66%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.67it/s, loss=4.14, acc=6.35]


Epoch 11/20 - Time: 16.07s - Test Acc: 7.20%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.82it/s, loss=4.15, acc=6.37]


Epoch 12/20 - Time: 16.35s - Test Acc: 7.48%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.46it/s, loss=4.14, acc=6.44]


Epoch 13/20 - Time: 16.15s - Test Acc: 6.51%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.95it/s, loss=4.14, acc=6.39]


Epoch 14/20 - Time: 16.31s - Test Acc: 6.45%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.32it/s, loss=4.17, acc=6.37]


Epoch 15/20 - Time: 16.19s - Test Acc: 6.42%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.25it/s, loss=4.16, acc=6.2]


Epoch 16/20 - Time: 16.21s - Test Acc: 6.56%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.44it/s, loss=4.19, acc=5.89]


Epoch 17/20 - Time: 16.14s - Test Acc: 6.98%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.77it/s, loss=4.16, acc=6.37]


Epoch 18/20 - Time: 16.04s - Test Acc: 7.18%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.37it/s, loss=4.17, acc=6.21]


Epoch 19/20 - Time: 16.17s - Test Acc: 6.39%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.69it/s, loss=4.19, acc=5.95]


Epoch 20/20 - Time: 16.06s - Test Acc: 6.87%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=2)
Total parameters: 2,188,644
FLOPs per forward pass: 187,209,984


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.86it/s, loss=4.11, acc=6.73]


Epoch 1/20 - Time: 15.08s - Test Acc: 8.72%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.84it/s, loss=3.94, acc=9.02]


Epoch 2/20 - Time: 15.09s - Test Acc: 9.10%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.85it/s, loss=3.91, acc=9.32]


Epoch 3/20 - Time: 15.08s - Test Acc: 10.28%


Epoch 4/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 50.90it/s, loss=3.87, acc=10]


Epoch 4/20 - Time: 15.36s - Test Acc: 10.16%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 50.60it/s, loss=3.87, acc=9.81]


Epoch 5/20 - Time: 15.45s - Test Acc: 9.66%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.24it/s, loss=3.85, acc=10.2]


Epoch 6/20 - Time: 14.97s - Test Acc: 10.91%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 52.09it/s, loss=3.81, acc=11.1]


Epoch 7/20 - Time: 15.01s - Test Acc: 10.73%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.68it/s, loss=3.78, acc=11.5]


Epoch 8/20 - Time: 14.84s - Test Acc: 11.02%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.83it/s, loss=3.75, acc=11.9]


Epoch 9/20 - Time: 15.10s - Test Acc: 12.58%


Epoch 10/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.57it/s, loss=3.7, acc=12.5]


Epoch 10/20 - Time: 14.88s - Test Acc: 12.22%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.29it/s, loss=3.72, acc=12.5]


Epoch 11/20 - Time: 14.96s - Test Acc: 13.00%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 52.04it/s, loss=3.7, acc=12.6]


Epoch 12/20 - Time: 15.03s - Test Acc: 13.16%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.41it/s, loss=3.68, acc=12.9]


Epoch 13/20 - Time: 14.92s - Test Acc: 11.23%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.84it/s, loss=3.64, acc=13.7]


Epoch 14/20 - Time: 14.82s - Test Acc: 14.21%


Epoch 15/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.34it/s, loss=3.61, acc=14]


Epoch 15/20 - Time: 14.94s - Test Acc: 13.88%


Epoch 16/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.50it/s, loss=3.55, acc=15]


Epoch 16/20 - Time: 14.90s - Test Acc: 15.28%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.48it/s, loss=3.51, acc=15.8]


Epoch 17/20 - Time: 14.90s - Test Acc: 15.66%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 51.90it/s, loss=3.48, acc=16.4]


Epoch 18/20 - Time: 15.07s - Test Acc: 16.45%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.45it/s, loss=3.45, acc=16.9]


Epoch 19/20 - Time: 14.91s - Test Acc: 15.40%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.64it/s, loss=3.41, acc=17.4]


Epoch 20/20 - Time: 14.87s - Test Acc: 17.44%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 3,239,268
FLOPs per forward pass: 254,449,920


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.29it/s, loss=4.12, acc=6.37]


Epoch 1/20 - Time: 16.54s - Test Acc: 8.67%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.30it/s, loss=3.99, acc=8.34]


Epoch 2/20 - Time: 16.19s - Test Acc: 9.75%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.84it/s, loss=3.94, acc=8.85]


Epoch 3/20 - Time: 16.35s - Test Acc: 10.02%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.15it/s, loss=3.94, acc=9.22]


Epoch 4/20 - Time: 16.24s - Test Acc: 9.08%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.34it/s, loss=3.94, acc=8.85]


Epoch 5/20 - Time: 16.18s - Test Acc: 8.38%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.19it/s, loss=3.94, acc=9.09]


Epoch 6/20 - Time: 16.23s - Test Acc: 9.49%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.78it/s, loss=3.94, acc=8.97]


Epoch 7/20 - Time: 16.38s - Test Acc: 9.71%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.01it/s, loss=3.92, acc=9.22]


Epoch 8/20 - Time: 16.29s - Test Acc: 9.10%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.51it/s, loss=3.91, acc=9.14]


Epoch 9/20 - Time: 16.12s - Test Acc: 8.57%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.29it/s, loss=3.92, acc=9.33]


Epoch 10/20 - Time: 16.19s - Test Acc: 10.28%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.16it/s, loss=3.91, acc=9.38]


Epoch 11/20 - Time: 16.25s - Test Acc: 10.08%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.10it/s, loss=3.91, acc=9.42]


Epoch 12/20 - Time: 16.27s - Test Acc: 10.33%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.10it/s, loss=3.95, acc=9.07]


Epoch 13/20 - Time: 16.26s - Test Acc: 9.41%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.70it/s, loss=3.94, acc=9.01]


Epoch 14/20 - Time: 16.39s - Test Acc: 9.15%


Epoch 15/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.97it/s, loss=4, acc=8.26]


Epoch 15/20 - Time: 16.30s - Test Acc: 7.22%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.68it/s, loss=4.04, acc=7.84]


Epoch 16/20 - Time: 16.40s - Test Acc: 8.65%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.90it/s, loss=4.04, acc=7.89]


Epoch 17/20 - Time: 16.32s - Test Acc: 8.07%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.80it/s, loss=4.1, acc=7]


Epoch 18/20 - Time: 16.36s - Test Acc: 7.28%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.80it/s, loss=4.09, acc=7.04]


Epoch 19/20 - Time: 16.36s - Test Acc: 7.84%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.27it/s, loss=4.07, acc=7.39]


Epoch 20/20 - Time: 16.20s - Test Acc: 8.96%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=8, heads=2, MLP Ratio=2)
Total parameters: 4,297,060
FLOPs per forward pass: 322,148,608


Epoch 1/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.52it/s, loss=4.24, acc=5]


Epoch 1/20 - Time: 20.84s - Test Acc: 4.04%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.29it/s, loss=4.24, acc=5.05]


Epoch 2/20 - Time: 20.97s - Test Acc: 5.83%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.30it/s, loss=4.31, acc=4.38]


Epoch 3/20 - Time: 20.97s - Test Acc: 3.51%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.81it/s, loss=4.37, acc=3.72]


Epoch 4/20 - Time: 21.24s - Test Acc: 3.34%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.37it/s, loss=4.36, acc=3.87]


Epoch 5/20 - Time: 20.92s - Test Acc: 4.22%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.40it/s, loss=4.39, acc=3.62]


Epoch 6/20 - Time: 20.91s - Test Acc: 3.78%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.15it/s, loss=4.38, acc=3.63]


Epoch 7/20 - Time: 21.05s - Test Acc: 3.72%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.48it/s, loss=4.35, acc=3.95]


Epoch 8/20 - Time: 21.44s - Test Acc: 3.66%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.48it/s, loss=4.37, acc=3.7]


Epoch 9/20 - Time: 20.87s - Test Acc: 3.28%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.91it/s, loss=4.38, acc=3.41]


Epoch 10/20 - Time: 21.19s - Test Acc: 3.29%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.37it/s, loss=4.36, acc=3.58]


Epoch 11/20 - Time: 20.93s - Test Acc: 3.51%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.55it/s, loss=4.37, acc=3.5]


Epoch 12/20 - Time: 20.84s - Test Acc: 3.15%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.42it/s, loss=4.37, acc=3.62]


Epoch 13/20 - Time: 20.90s - Test Acc: 3.75%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.65it/s, loss=4.36, acc=3.67]


Epoch 14/20 - Time: 20.77s - Test Acc: 3.68%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.51it/s, loss=4.4, acc=3.36]


Epoch 15/20 - Time: 20.85s - Test Acc: 3.37%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.73it/s, loss=4.39, acc=3.41]


Epoch 16/20 - Time: 20.73s - Test Acc: 3.60%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.39it/s, loss=4.41, acc=3.06]


Epoch 17/20 - Time: 20.92s - Test Acc: 2.43%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.18it/s, loss=4.42, acc=2.95]


Epoch 18/20 - Time: 21.03s - Test Acc: 3.57%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.43it/s, loss=4.42, acc=3.04]


Epoch 19/20 - Time: 20.89s - Test Acc: 3.72%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.32it/s, loss=4.4, acc=3.23]


Epoch 20/20 - Time: 20.95s - Test Acc: 3.43%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=8, heads=2, MLP Ratio=4)
Total parameters: 6,398,308
FLOPs per forward pass: 456,628,480


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.92it/s, loss=4.28, acc=4.63]


Epoch 1/20 - Time: 23.06s - Test Acc: 5.70%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.81it/s, loss=4.22, acc=5.3]


Epoch 2/20 - Time: 23.13s - Test Acc: 4.82%


Epoch 3/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.13it/s, loss=4.4, acc=3.15]


Epoch 3/20 - Time: 22.91s - Test Acc: 3.49%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.77it/s, loss=4.37, acc=3.62]


Epoch 4/20 - Time: 23.16s - Test Acc: 3.72%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.02it/s, loss=4.38, acc=3.29]


Epoch 5/20 - Time: 22.99s - Test Acc: 3.86%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.19it/s, loss=4.41, acc=2.93]


Epoch 6/20 - Time: 22.87s - Test Acc: 3.13%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.11it/s, loss=4.39, acc=3.39]


Epoch 7/20 - Time: 22.92s - Test Acc: 3.78%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.10it/s, loss=4.37, acc=3.45]


Epoch 8/20 - Time: 22.93s - Test Acc: 4.00%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.87it/s, loss=4.36, acc=3.69]


Epoch 9/20 - Time: 23.09s - Test Acc: 4.07%


Epoch 10/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.95it/s, loss=4.38, acc=3.4]


Epoch 10/20 - Time: 23.03s - Test Acc: 3.29%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.90it/s, loss=4.39, acc=3.35]


Epoch 11/20 - Time: 23.07s - Test Acc: 3.65%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.68it/s, loss=4.37, acc=3.57]


Epoch 12/20 - Time: 23.22s - Test Acc: 3.78%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.80it/s, loss=4.37, acc=3.66]


Epoch 13/20 - Time: 23.14s - Test Acc: 3.68%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.95it/s, loss=4.35, acc=3.74]


Epoch 14/20 - Time: 23.03s - Test Acc: 3.75%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.03it/s, loss=4.35, acc=3.76]


Epoch 15/20 - Time: 22.99s - Test Acc: 3.75%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.07it/s, loss=4.36, acc=3.71]


Epoch 16/20 - Time: 22.95s - Test Acc: 3.40%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.04it/s, loss=4.39, acc=3.48]


Epoch 17/20 - Time: 22.97s - Test Acc: 3.71%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.97it/s, loss=4.35, acc=3.84]


Epoch 18/20 - Time: 23.02s - Test Acc: 3.88%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.76it/s, loss=4.33, acc=3.86]


Epoch 19/20 - Time: 23.16s - Test Acc: 3.92%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.66it/s, loss=4.35, acc=3.68]


Epoch 20/20 - Time: 23.24s - Test Acc: 3.64%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=8, heads=4, MLP Ratio=2)
Total parameters: 4,297,060
FLOPs per forward pass: 322,148,608


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.96it/s, loss=4.12, acc=6.5]


Epoch 1/20 - Time: 21.16s - Test Acc: 7.80%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.18it/s, loss=3.97, acc=8.49]


Epoch 2/20 - Time: 21.03s - Test Acc: 9.73%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.98it/s, loss=3.91, acc=9.28]


Epoch 3/20 - Time: 21.15s - Test Acc: 10.41%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.53it/s, loss=3.86, acc=10.1]


Epoch 4/20 - Time: 20.84s - Test Acc: 9.74%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.21it/s, loss=3.85, acc=10.2]


Epoch 5/20 - Time: 21.02s - Test Acc: 10.00%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.77it/s, loss=3.89, acc=9.68]


Epoch 6/20 - Time: 21.26s - Test Acc: 9.53%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.95it/s, loss=3.93, acc=9.08]


Epoch 7/20 - Time: 21.16s - Test Acc: 9.35%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.26it/s, loss=3.93, acc=9.17]


Epoch 8/20 - Time: 20.99s - Test Acc: 9.29%


Epoch 9/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.74it/s, loss=3.95, acc=9]


Epoch 9/20 - Time: 21.28s - Test Acc: 9.50%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.64it/s, loss=3.97, acc=8.61]


Epoch 10/20 - Time: 21.34s - Test Acc: 8.58%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.11it/s, loss=3.99, acc=8.44]


Epoch 11/20 - Time: 21.07s - Test Acc: 9.28%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.19it/s, loss=3.95, acc=9.1]


Epoch 12/20 - Time: 21.03s - Test Acc: 9.51%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.01it/s, loss=3.95, acc=8.96]


Epoch 13/20 - Time: 21.13s - Test Acc: 8.87%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.13it/s, loss=3.95, acc=8.83]


Epoch 14/20 - Time: 21.06s - Test Acc: 9.99%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.06it/s, loss=3.99, acc=8.46]


Epoch 15/20 - Time: 21.10s - Test Acc: 7.52%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.88it/s, loss=4.05, acc=7.71]


Epoch 16/20 - Time: 21.21s - Test Acc: 8.87%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:20<00:00, 37.36it/s, loss=4.08, acc=7.55]


Epoch 17/20 - Time: 20.93s - Test Acc: 8.41%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 37.08it/s, loss=4.06, acc=7.6]


Epoch 18/20 - Time: 21.09s - Test Acc: 7.59%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.85it/s, loss=4.07, acc=7.61]


Epoch 19/20 - Time: 21.24s - Test Acc: 7.60%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:21<00:00, 36.83it/s, loss=4.05, acc=7.7]


Epoch 20/20 - Time: 21.23s - Test Acc: 8.55%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=8, heads=4, MLP Ratio=4)
Total parameters: 6,398,308
FLOPs per forward pass: 456,628,480


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.57it/s, loss=4.18, acc=5.97]


Epoch 1/20 - Time: 23.30s - Test Acc: 6.83%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.81it/s, loss=4.09, acc=6.8]


Epoch 2/20 - Time: 23.13s - Test Acc: 6.64%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.73it/s, loss=4.17, acc=6.26]


Epoch 3/20 - Time: 23.18s - Test Acc: 6.41%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.70it/s, loss=4.13, acc=6.54]


Epoch 4/20 - Time: 23.21s - Test Acc: 6.86%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.74it/s, loss=4.06, acc=7.33]


Epoch 5/20 - Time: 23.18s - Test Acc: 8.18%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.55it/s, loss=4.14, acc=6.34]


Epoch 6/20 - Time: 23.31s - Test Acc: 6.42%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.88it/s, loss=4.13, acc=6.65]


Epoch 7/20 - Time: 23.78s - Test Acc: 7.57%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 32.73it/s, loss=4.11, acc=6.82]


Epoch 8/20 - Time: 23.89s - Test Acc: 6.37%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.86it/s, loss=4.14, acc=6.51]


Epoch 9/20 - Time: 23.10s - Test Acc: 6.64%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.72it/s, loss=4.14, acc=6.35]


Epoch 10/20 - Time: 23.19s - Test Acc: 6.47%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.75it/s, loss=4.14, acc=6.37]


Epoch 11/20 - Time: 23.17s - Test Acc: 5.97%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.66it/s, loss=4.14, acc=6.58]


Epoch 12/20 - Time: 23.23s - Test Acc: 7.66%


Epoch 13/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.51it/s, loss=4.16, acc=6.2]


Epoch 13/20 - Time: 23.33s - Test Acc: 6.52%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.56it/s, loss=4.17, acc=6.17]


Epoch 14/20 - Time: 23.30s - Test Acc: 6.93%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.76it/s, loss=4.19, acc=5.83]


Epoch 15/20 - Time: 23.18s - Test Acc: 5.72%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.47it/s, loss=4.19, acc=5.86]


Epoch 16/20 - Time: 23.36s - Test Acc: 6.24%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.59it/s, loss=4.17, acc=6.23]


Epoch 17/20 - Time: 23.28s - Test Acc: 6.33%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.75it/s, loss=4.19, acc=5.84]


Epoch 18/20 - Time: 23.17s - Test Acc: 5.55%


Epoch 19/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.85it/s, loss=4.2, acc=5.94]


Epoch 19/20 - Time: 23.10s - Test Acc: 5.89%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.84it/s, loss=4.23, acc=5.4]


Epoch 20/20 - Time: 23.11s - Test Acc: 5.87%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=4, heads=2, MLP Ratio=2)
Total parameters: 8,571,492
FLOPs per forward pass: 642,849,024


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.84it/s, loss=4.48, acc=2.94]


Epoch 1/20 - Time: 22.45s - Test Acc: 3.22%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.04it/s, loss=4.52, acc=2.4]


Epoch 2/20 - Time: 22.32s - Test Acc: 1.79%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.43it/s, loss=4.47, acc=2.72]


Epoch 3/20 - Time: 22.07s - Test Acc: 2.93%


Epoch 4/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.20it/s, loss=4.47, acc=2.7]


Epoch 4/20 - Time: 22.22s - Test Acc: 3.06%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.25it/s, loss=4.46, acc=2.71]


Epoch 5/20 - Time: 22.19s - Test Acc: 3.43%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.28it/s, loss=4.41, acc=3.07]


Epoch 6/20 - Time: 22.17s - Test Acc: 4.14%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:23<00:00, 33.86it/s, loss=4.39, acc=3.46]


Epoch 7/20 - Time: 23.09s - Test Acc: 3.69%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.03it/s, loss=4.35, acc=3.78]


Epoch 8/20 - Time: 22.33s - Test Acc: 4.82%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.41it/s, loss=4.33, acc=4.1]


Epoch 9/20 - Time: 22.09s - Test Acc: 4.36%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.24it/s, loss=4.32, acc=4.26]


Epoch 10/20 - Time: 22.19s - Test Acc: 4.36%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.10it/s, loss=4.29, acc=4.54]


Epoch 11/20 - Time: 22.28s - Test Acc: 4.81%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.08it/s, loss=4.34, acc=4.09]


Epoch 12/20 - Time: 22.29s - Test Acc: 4.91%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.06it/s, loss=4.34, acc=4.13]


Epoch 13/20 - Time: 22.31s - Test Acc: 4.64%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.18it/s, loss=4.34, acc=3.93]


Epoch 14/20 - Time: 22.23s - Test Acc: 5.22%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.15it/s, loss=4.32, acc=4.28]


Epoch 15/20 - Time: 22.25s - Test Acc: 5.80%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.16it/s, loss=4.33, acc=4.21]


Epoch 16/20 - Time: 22.24s - Test Acc: 4.55%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.13it/s, loss=4.36, acc=3.94]


Epoch 17/20 - Time: 22.26s - Test Acc: 4.22%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.34it/s, loss=4.37, acc=3.7]


Epoch 18/20 - Time: 22.13s - Test Acc: 3.99%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.87it/s, loss=4.36, acc=3.67]


Epoch 19/20 - Time: 22.44s - Test Acc: 4.47%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.25it/s, loss=4.32, acc=4.32]


Epoch 20/20 - Time: 22.18s - Test Acc: 4.76%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=4, heads=2, MLP Ratio=4)
Total parameters: 12,769,892
FLOPs per forward pass: 911,546,624


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.78it/s, loss=4.41, acc=3.43]


Epoch 1/20 - Time: 25.41s - Test Acc: 3.78%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.85it/s, loss=4.4, acc=3.49]


Epoch 2/20 - Time: 25.35s - Test Acc: 3.03%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.80it/s, loss=4.41, acc=3.45]


Epoch 3/20 - Time: 25.39s - Test Acc: 4.29%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.86it/s, loss=4.35, acc=4.22]


Epoch 4/20 - Time: 25.36s - Test Acc: 4.54%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.87it/s, loss=4.39, acc=3.47]


Epoch 5/20 - Time: 25.33s - Test Acc: 4.50%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.94it/s, loss=4.41, acc=3.53]


Epoch 6/20 - Time: 25.29s - Test Acc: 4.18%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.82it/s, loss=4.35, acc=4.47]


Epoch 7/20 - Time: 25.37s - Test Acc: 4.68%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.94it/s, loss=4.34, acc=4.24]


Epoch 8/20 - Time: 25.27s - Test Acc: 3.44%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.40it/s, loss=4.35, acc=4.1]


Epoch 9/20 - Time: 26.60s - Test Acc: 4.41%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.33it/s, loss=4.31, acc=4.66]


Epoch 10/20 - Time: 25.78s - Test Acc: 4.95%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.43it/s, loss=4.33, acc=4.33]


Epoch 11/20 - Time: 25.69s - Test Acc: 4.04%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.12it/s, loss=4.33, acc=4.19]


Epoch 12/20 - Time: 25.98s - Test Acc: 4.95%


Epoch 13/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:26<00:00, 29.80it/s, loss=4.28, acc=5.2]


Epoch 13/20 - Time: 26.24s - Test Acc: 5.21%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.72it/s, loss=4.25, acc=5.42]


Epoch 14/20 - Time: 25.46s - Test Acc: 5.45%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.74it/s, loss=4.23, acc=5.63]


Epoch 15/20 - Time: 25.44s - Test Acc: 5.50%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.75it/s, loss=4.27, acc=5.1]


Epoch 16/20 - Time: 25.44s - Test Acc: 5.77%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.80it/s, loss=4.31, acc=4.54]


Epoch 17/20 - Time: 25.41s - Test Acc: 5.52%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.76it/s, loss=4.28, acc=4.87]


Epoch 18/20 - Time: 25.42s - Test Acc: 5.59%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.81it/s, loss=4.27, acc=4.98]


Epoch 19/20 - Time: 25.38s - Test Acc: 5.36%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.91it/s, loss=4.28, acc=4.83]


Epoch 20/20 - Time: 25.30s - Test Acc: 4.78%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=4, heads=4, MLP Ratio=2)
Total parameters: 8,571,492
FLOPs per forward pass: 642,849,024


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 34.90it/s, loss=4.36, acc=4.08]


Epoch 1/20 - Time: 22.40s - Test Acc: 4.72%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.25it/s, loss=4.33, acc=4.28]


Epoch 2/20 - Time: 22.20s - Test Acc: 4.46%


Epoch 3/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.22it/s, loss=4.3, acc=4.45]


Epoch 3/20 - Time: 22.20s - Test Acc: 4.86%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.19it/s, loss=4.28, acc=4.78]


Epoch 4/20 - Time: 22.22s - Test Acc: 4.38%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.18it/s, loss=4.29, acc=4.57]


Epoch 5/20 - Time: 22.23s - Test Acc: 4.45%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.13it/s, loss=4.28, acc=4.67]


Epoch 6/20 - Time: 22.26s - Test Acc: 4.81%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.31it/s, loss=4.22, acc=5.49]


Epoch 7/20 - Time: 22.16s - Test Acc: 6.08%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.31it/s, loss=4.23, acc=5.33]


Epoch 8/20 - Time: 22.14s - Test Acc: 5.64%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.29it/s, loss=4.3, acc=4.84]


Epoch 9/20 - Time: 22.16s - Test Acc: 5.89%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.06it/s, loss=4.22, acc=5.61]


Epoch 10/20 - Time: 22.30s - Test Acc: 6.18%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.10it/s, loss=4.26, acc=5.18]


Epoch 11/20 - Time: 22.28s - Test Acc: 5.39%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.08it/s, loss=4.26, acc=5.08]


Epoch 12/20 - Time: 22.29s - Test Acc: 5.41%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.27it/s, loss=4.23, acc=5.44]


Epoch 13/20 - Time: 22.19s - Test Acc: 6.59%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.35it/s, loss=4.23, acc=5.41]


Epoch 14/20 - Time: 22.14s - Test Acc: 6.36%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.41it/s, loss=4.22, acc=5.5]


Epoch 15/20 - Time: 22.08s - Test Acc: 5.63%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.19it/s, loss=4.22, acc=5.45]


Epoch 16/20 - Time: 22.24s - Test Acc: 6.19%


Epoch 17/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.06it/s, loss=4.19, acc=6]


Epoch 17/20 - Time: 22.31s - Test Acc: 6.50%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.26it/s, loss=4.2, acc=5.72]


Epoch 18/20 - Time: 22.18s - Test Acc: 6.47%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.38it/s, loss=4.21, acc=5.76]


Epoch 19/20 - Time: 22.12s - Test Acc: 6.14%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:22<00:00, 35.47it/s, loss=4.22, acc=5.6]


Epoch 20/20 - Time: 22.05s - Test Acc: 6.01%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 12,769,892
FLOPs per forward pass: 911,546,624


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 31.00it/s, loss=4.35, acc=4.1]


Epoch 1/20 - Time: 25.22s - Test Acc: 4.91%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 31.04it/s, loss=4.3, acc=4.62]


Epoch 2/20 - Time: 25.21s - Test Acc: 4.06%


Epoch 3/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.87it/s, loss=4.3, acc=4.37]


Epoch 3/20 - Time: 25.33s - Test Acc: 5.31%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.98it/s, loss=4.28, acc=4.62]


Epoch 4/20 - Time: 25.24s - Test Acc: 4.14%


Epoch 5/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.96it/s, loss=4.28, acc=4.8]


Epoch 5/20 - Time: 25.26s - Test Acc: 5.57%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.84it/s, loss=4.27, acc=5.04]


Epoch 6/20 - Time: 25.35s - Test Acc: 4.92%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.83it/s, loss=4.29, acc=4.76]


Epoch 7/20 - Time: 25.37s - Test Acc: 5.31%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.85it/s, loss=4.28, acc=4.89]


Epoch 8/20 - Time: 25.35s - Test Acc: 5.15%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.94it/s, loss=4.26, acc=5.03]


Epoch 9/20 - Time: 25.28s - Test Acc: 5.40%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.87it/s, loss=4.24, acc=5.33]


Epoch 10/20 - Time: 25.33s - Test Acc: 5.59%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.89it/s, loss=4.23, acc=5.43]


Epoch 11/20 - Time: 25.32s - Test Acc: 6.06%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.65it/s, loss=4.21, acc=5.61]


Epoch 12/20 - Time: 25.53s - Test Acc: 6.08%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.97it/s, loss=4.21, acc=5.77]


Epoch 13/20 - Time: 25.26s - Test Acc: 5.57%


Epoch 14/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.92it/s, loss=4.22, acc=5.4]


Epoch 14/20 - Time: 25.30s - Test Acc: 6.07%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.78it/s, loss=4.22, acc=5.53]


Epoch 15/20 - Time: 25.40s - Test Acc: 5.88%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.97it/s, loss=4.21, acc=5.53]


Epoch 16/20 - Time: 25.25s - Test Acc: 6.46%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.76it/s, loss=4.21, acc=5.61]


Epoch 17/20 - Time: 25.44s - Test Acc: 5.38%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.84it/s, loss=4.24, acc=5.12]


Epoch 18/20 - Time: 25.36s - Test Acc: 5.24%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.90it/s, loss=4.26, acc=4.96]


Epoch 19/20 - Time: 25.31s - Test Acc: 5.99%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:25<00:00, 30.90it/s, loss=4.23, acc=5.42]


Epoch 20/20 - Time: 25.31s - Test Acc: 6.00%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=8, heads=2, MLP Ratio=2)
Total parameters: 16,982,628
FLOPs per forward pass: 1,181,161,728


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.92it/s, loss=4.46, acc=2.95]


Epoch 1/20 - Time: 35.69s - Test Acc: 3.18%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.03it/s, loss=4.45, acc=2.89]


Epoch 2/20 - Time: 35.51s - Test Acc: 2.98%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.97it/s, loss=4.45, acc=2.74]


Epoch 3/20 - Time: 35.60s - Test Acc: 2.83%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.94it/s, loss=4.44, acc=2.99]


Epoch 4/20 - Time: 35.65s - Test Acc: 3.02%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.93it/s, loss=4.45, acc=2.75]


Epoch 5/20 - Time: 35.67s - Test Acc: 2.34%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.01it/s, loss=4.48, acc=2.32]


Epoch 6/20 - Time: 35.53s - Test Acc: 2.37%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.98it/s, loss=4.53, acc=1.86]


Epoch 7/20 - Time: 35.58s - Test Acc: 2.06%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.05it/s, loss=4.53, acc=1.94]


Epoch 8/20 - Time: 35.47s - Test Acc: 1.91%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.04it/s, loss=4.54, acc=2.06]


Epoch 9/20 - Time: 35.49s - Test Acc: 2.39%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.06it/s, loss=4.53, acc=2.12]


Epoch 10/20 - Time: 35.46s - Test Acc: 2.27%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.07it/s, loss=4.51, acc=2.26]


Epoch 11/20 - Time: 35.42s - Test Acc: 2.15%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.04it/s, loss=4.51, acc=2.29]


Epoch 12/20 - Time: 35.47s - Test Acc: 2.30%


Epoch 13/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.07it/s, loss=4.5, acc=2.43]


Epoch 13/20 - Time: 35.44s - Test Acc: 2.42%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.04it/s, loss=4.51, acc=2.18]


Epoch 14/20 - Time: 35.47s - Test Acc: 2.38%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.08it/s, loss=4.47, acc=2.53]


Epoch 15/20 - Time: 35.42s - Test Acc: 2.43%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.01it/s, loss=4.47, acc=2.5]


Epoch 16/20 - Time: 35.53s - Test Acc: 2.66%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.10it/s, loss=4.47, acc=2.59]


Epoch 17/20 - Time: 35.40s - Test Acc: 2.50%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.98it/s, loss=4.47, acc=2.52]


Epoch 18/20 - Time: 35.58s - Test Acc: 2.35%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.07it/s, loss=4.47, acc=2.65]


Epoch 19/20 - Time: 35.43s - Test Acc: 2.71%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.06it/s, loss=4.47, acc=2.7]


Epoch 20/20 - Time: 35.44s - Test Acc: 2.89%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=8, heads=2, MLP Ratio=4)
Total parameters: 25,379,428
FLOPs per forward pass: 1,718,556,928


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.49it/s, loss=4.54, acc=2.32]


Epoch 1/20 - Time: 42.30s - Test Acc: 2.96%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.47it/s, loss=4.46, acc=2.63]


Epoch 2/20 - Time: 42.36s - Test Acc: 2.96%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.52it/s, loss=4.55, acc=2.07]


Epoch 3/20 - Time: 42.22s - Test Acc: 1.84%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.56, acc=1.83]


Epoch 4/20 - Time: 42.20s - Test Acc: 2.05%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.58, acc=1.62]


Epoch 5/20 - Time: 42.19s - Test Acc: 1.59%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.45it/s, loss=4.57, acc=1.73]


Epoch 6/20 - Time: 42.38s - Test Acc: 2.35%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.52it/s, loss=4.53, acc=2.22]


Epoch 7/20 - Time: 42.22s - Test Acc: 2.68%


Epoch 8/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.52it/s, loss=4.52, acc=2.2]


Epoch 8/20 - Time: 42.23s - Test Acc: 2.10%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.51it/s, loss=4.53, acc=2.15]


Epoch 9/20 - Time: 42.27s - Test Acc: 2.68%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.51, acc=2.36]


Epoch 10/20 - Time: 42.17s - Test Acc: 2.35%


Epoch 11/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.5, acc=2.43]


Epoch 11/20 - Time: 42.15s - Test Acc: 3.14%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.48, acc=2.77]


Epoch 12/20 - Time: 42.16s - Test Acc: 2.88%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.52it/s, loss=4.48, acc=2.64]


Epoch 13/20 - Time: 42.23s - Test Acc: 2.70%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.48, acc=2.54]


Epoch 14/20 - Time: 42.15s - Test Acc: 3.03%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.49, acc=2.59]


Epoch 15/20 - Time: 42.21s - Test Acc: 3.01%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.46, acc=2.79]


Epoch 16/20 - Time: 42.20s - Test Acc: 3.06%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.48it/s, loss=4.46, acc=2.98]


Epoch 17/20 - Time: 42.33s - Test Acc: 3.08%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.49it/s, loss=4.45, acc=2.9]


Epoch 18/20 - Time: 42.29s - Test Acc: 3.38%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.51it/s, loss=4.44, acc=3.02]


Epoch 19/20 - Time: 42.24s - Test Acc: 2.95%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.52it/s, loss=4.44, acc=2.98]


Epoch 20/20 - Time: 42.22s - Test Acc: 3.03%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=8, heads=4, MLP Ratio=2)
Total parameters: 16,982,628
FLOPs per forward pass: 1,181,161,728


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.97it/s, loss=4.38, acc=3.88]


Epoch 1/20 - Time: 35.59s - Test Acc: 4.83%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.09it/s, loss=4.3, acc=4.56]


Epoch 2/20 - Time: 35.40s - Test Acc: 4.77%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.08it/s, loss=4.26, acc=5.03]


Epoch 3/20 - Time: 35.42s - Test Acc: 5.19%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.04it/s, loss=4.27, acc=4.98]


Epoch 4/20 - Time: 35.47s - Test Acc: 4.32%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.00it/s, loss=4.33, acc=4.07]


Epoch 5/20 - Time: 35.54s - Test Acc: 4.35%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.05it/s, loss=4.32, acc=4.34]


Epoch 6/20 - Time: 35.46s - Test Acc: 4.65%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.98it/s, loss=4.32, acc=4.22]


Epoch 7/20 - Time: 35.57s - Test Acc: 4.52%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.98it/s, loss=4.35, acc=4.04]


Epoch 8/20 - Time: 35.58s - Test Acc: 4.16%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.90it/s, loss=4.37, acc=3.71]


Epoch 9/20 - Time: 35.70s - Test Acc: 4.37%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.02it/s, loss=4.33, acc=4.07]


Epoch 10/20 - Time: 35.52s - Test Acc: 4.17%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.01it/s, loss=4.35, acc=3.94]


Epoch 11/20 - Time: 35.52s - Test Acc: 4.43%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.00it/s, loss=4.36, acc=3.84]


Epoch 12/20 - Time: 35.55s - Test Acc: 4.44%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.05it/s, loss=4.35, acc=4.12]


Epoch 13/20 - Time: 35.47s - Test Acc: 4.63%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.01it/s, loss=4.36, acc=3.82]


Epoch 14/20 - Time: 35.52s - Test Acc: 4.43%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.06it/s, loss=4.4, acc=3.31]


Epoch 15/20 - Time: 35.45s - Test Acc: 3.42%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.95it/s, loss=4.4, acc=3.29]


Epoch 16/20 - Time: 35.62s - Test Acc: 3.60%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 22.03it/s, loss=4.38, acc=3.53]


Epoch 17/20 - Time: 35.50s - Test Acc: 4.43%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.93it/s, loss=4.36, acc=3.74]


Epoch 18/20 - Time: 35.65s - Test Acc: 4.47%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.98it/s, loss=4.36, acc=3.92]


Epoch 19/20 - Time: 35.57s - Test Acc: 4.17%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.90it/s, loss=4.35, acc=3.86]


Epoch 20/20 - Time: 35.71s - Test Acc: 3.05%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=8, heads=4, MLP Ratio=4)
Total parameters: 25,379,428
FLOPs per forward pass: 1,718,556,928


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.54it/s, loss=4.41, acc=3.23]


Epoch 1/20 - Time: 42.17s - Test Acc: 3.81%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.57it/s, loss=4.39, acc=3.29]


Epoch 2/20 - Time: 42.11s - Test Acc: 3.27%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.50it/s, loss=4.43, acc=3.02]


Epoch 3/20 - Time: 42.27s - Test Acc: 3.08%


Epoch 4/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.50it/s, loss=4.4, acc=3.28]


Epoch 4/20 - Time: 42.28s - Test Acc: 3.47%


Epoch 5/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.50it/s, loss=4.38, acc=3.4]


Epoch 5/20 - Time: 42.26s - Test Acc: 2.84%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.37, acc=3.52]


Epoch 6/20 - Time: 42.16s - Test Acc: 3.12%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.58it/s, loss=4.36, acc=3.61]


Epoch 7/20 - Time: 42.09s - Test Acc: 3.12%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.59it/s, loss=4.38, acc=3.29]


Epoch 8/20 - Time: 42.08s - Test Acc: 3.55%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.60it/s, loss=4.38, acc=3.32]


Epoch 9/20 - Time: 42.04s - Test Acc: 3.58%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.36, acc=3.39]


Epoch 10/20 - Time: 42.20s - Test Acc: 3.58%


Epoch 11/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.58it/s, loss=4.36, acc=3.5]


Epoch 11/20 - Time: 42.09s - Test Acc: 3.22%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.61it/s, loss=4.35, acc=3.56]


Epoch 12/20 - Time: 42.02s - Test Acc: 3.47%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.57it/s, loss=4.35, acc=3.68]


Epoch 13/20 - Time: 42.10s - Test Acc: 3.14%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.62it/s, loss=4.36, acc=3.43]


Epoch 14/20 - Time: 42.02s - Test Acc: 3.56%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:41<00:00, 18.63it/s, loss=4.34, acc=3.75]


Epoch 15/20 - Time: 41.97s - Test Acc: 3.46%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.58it/s, loss=4.33, acc=3.84]


Epoch 16/20 - Time: 42.08s - Test Acc: 3.58%


Epoch 17/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.55it/s, loss=4.34, acc=3.7]


Epoch 17/20 - Time: 42.15s - Test Acc: 3.64%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.59it/s, loss=4.33, acc=3.69]


Epoch 18/20 - Time: 42.06s - Test Acc: 3.75%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.53it/s, loss=4.34, acc=3.81]


Epoch 19/20 - Time: 42.20s - Test Acc: 3.83%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.61it/s, loss=4.34, acc=3.67]


Epoch 20/20 - Time: 42.02s - Test Acc: 3.62%

Problem 1 Results: ResNet-18 vs ViT Full
            Model     Params          FLOPs Training Time Epoch Time Final Acc 10-Epoch Acc
        ResNet-18 11,220,132 35,550,624,000       909.72s     45.49s    69.88%       68.09%
ViT-p4-e256-d4-h2  2,164,068    187,996,416       592.03s     29.60s    21.47%       18.50%
ViT-p4-e256-d4-h2  3,214,692    255,236,352       664.25s     33.21s    32.70%       27.39%
ViT-p4-e256-d4-h4  2,164,068    187,996,416       604.51s     30.23s    36.00%       28.84%
ViT-p4-e256-d4-h4  3,214,692    255,236,352       678.26s     33.91s    23.05%       20.55%
ViT-p4-e256-d8-h2  4,272,484    322,935,040       891.99s     44.60s     2.68%        3.38%
ViT-p4-e256-d8-h2  6,373,732    457,414,912      1040.01s     52.00s     4.39%        5.76%
ViT-p4-e256-d8-h4  4,272,484    322,935,040       924.34s     46.22s     5.11%        6.17%
ViT-p4-e256-d8-h4  6,373,732    457,414,912      1079.06s     53.95s     9.86%       



Total parameters: 11,220,132
FLOPs per forward pass: 35,550,624,000


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=2.78, acc=29.5]


Epoch 1/20 - Time: 37.20s - Test Acc: 42.35%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.05it/s, loss=1.81, acc=49.6]


Epoch 2/20 - Time: 37.15s - Test Acc: 53.51%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.03it/s, loss=1.49, acc=57.7]


Epoch 3/20 - Time: 37.19s - Test Acc: 56.87%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=1.29, acc=62.4]


Epoch 4/20 - Time: 37.31s - Test Acc: 60.38%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=1.14, acc=66.2]


Epoch 5/20 - Time: 37.32s - Test Acc: 62.17%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=1.03, acc=69.4]


Epoch 6/20 - Time: 37.31s - Test Acc: 64.93%


Epoch 7/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.98it/s, loss=0.922, acc=72.4]


Epoch 7/20 - Time: 37.28s - Test Acc: 64.72%


Epoch 8/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=0.835, acc=74.5]


Epoch 8/20 - Time: 37.20s - Test Acc: 67.05%


Epoch 9/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.00it/s, loss=0.747, acc=77]


Epoch 9/20 - Time: 37.24s - Test Acc: 66.13%


Epoch 10/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=0.678, acc=79.1]


Epoch 10/20 - Time: 37.30s - Test Acc: 66.99%


Epoch 11/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.01it/s, loss=0.626, acc=80.5]


Epoch 11/20 - Time: 37.23s - Test Acc: 67.51%


Epoch 12/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.99it/s, loss=0.569, acc=81.9]


Epoch 12/20 - Time: 37.26s - Test Acc: 68.69%


Epoch 13/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=0.516, acc=83.5]


Epoch 13/20 - Time: 37.20s - Test Acc: 67.89%


Epoch 14/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=0.471, acc=84.9]


Epoch 14/20 - Time: 37.19s - Test Acc: 69.38%


Epoch 15/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.96it/s, loss=0.432, acc=86.2]


Epoch 15/20 - Time: 37.31s - Test Acc: 69.73%


Epoch 16/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.99it/s, loss=0.392, acc=87.2]


Epoch 16/20 - Time: 37.25s - Test Acc: 69.68%


Epoch 17/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.02it/s, loss=0.372, acc=88.1]


Epoch 17/20 - Time: 37.19s - Test Acc: 69.80%


Epoch 18/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.00it/s, loss=0.347, acc=88.5]


Epoch 18/20 - Time: 37.24s - Test Acc: 70.02%


Epoch 19/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.99it/s, loss=0.322, acc=89.6]


Epoch 19/20 - Time: 37.26s - Test Acc: 69.41%


Epoch 20/20: 100%|█████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 21.00it/s, loss=0.288, acc=90.5]


Epoch 20/20 - Time: 37.24s - Test Acc: 69.41%
Training ViT (patch_size=4, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 3,214,692
FLOPs per forward pass: 255,236,352


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.45it/s, loss=4.03, acc=7.59]


Epoch 1/20 - Time: 27.49s - Test Acc: 11.28%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.67it/s, loss=3.67, acc=12.6]


Epoch 2/20 - Time: 27.27s - Test Acc: 13.13%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.86it/s, loss=3.49, acc=15.9]


Epoch 3/20 - Time: 27.10s - Test Acc: 18.02%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.79it/s, loss=3.35, acc=18.4]


Epoch 4/20 - Time: 27.16s - Test Acc: 20.81%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.68it/s, loss=3.26, acc=20.2]


Epoch 5/20 - Time: 27.26s - Test Acc: 22.84%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.80it/s, loss=3.18, acc=21.4]


Epoch 6/20 - Time: 27.15s - Test Acc: 24.90%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.71it/s, loss=3.1, acc=23.1]


Epoch 7/20 - Time: 27.24s - Test Acc: 26.21%


Epoch 8/20: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.53it/s, loss=3.04, acc=24]


Epoch 8/20 - Time: 27.41s - Test Acc: 26.96%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.62it/s, loss=2.97, acc=25.7]


Epoch 9/20 - Time: 27.32s - Test Acc: 27.95%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.81it/s, loss=2.91, acc=26.8]


Epoch 10/20 - Time: 27.14s - Test Acc: 29.53%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.67it/s, loss=2.86, acc=27.7]


Epoch 11/20 - Time: 27.27s - Test Acc: 29.28%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.79it/s, loss=2.8, acc=28.8]


Epoch 12/20 - Time: 27.16s - Test Acc: 32.04%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.67it/s, loss=2.75, acc=29.7]


Epoch 13/20 - Time: 27.27s - Test Acc: 32.46%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.69it/s, loss=2.69, acc=30.9]


Epoch 14/20 - Time: 27.27s - Test Acc: 32.94%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.78it/s, loss=2.65, acc=31.7]


Epoch 15/20 - Time: 27.17s - Test Acc: 33.74%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.72it/s, loss=2.59, acc=32.8]


Epoch 16/20 - Time: 27.23s - Test Acc: 34.62%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.74it/s, loss=2.55, acc=33.9]


Epoch 17/20 - Time: 27.21s - Test Acc: 34.80%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.82it/s, loss=2.5, acc=34.8]


Epoch 18/20 - Time: 27.14s - Test Acc: 36.15%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.71it/s, loss=2.44, acc=35.8]


Epoch 19/20 - Time: 27.24s - Test Acc: 36.32%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:27<00:00, 28.68it/s, loss=2.38, acc=37.1]


Epoch 20/20 - Time: 27.29s - Test Acc: 36.60%
Training ViT (patch_size=4, embed_dim=512, depth/transformer layers=4, heads=8, MLP Ratio=4)
Total parameters: 12,720,740
FLOPs per forward pass: 913,119,488


Epoch 1/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.69it/s, loss=4.22, acc=5.5]


Epoch 1/20 - Time: 57.12s - Test Acc: 6.58%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.68it/s, loss=4.05, acc=7.34]


Epoch 2/20 - Time: 57.15s - Test Acc: 8.21%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.69it/s, loss=4.02, acc=7.69]


Epoch 3/20 - Time: 57.14s - Test Acc: 8.26%


Epoch 4/20: 100%|██████████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.69it/s, loss=4, acc=8.12]


Epoch 4/20 - Time: 57.10s - Test Acc: 9.51%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.71it/s, loss=4, acc=8.2]


Epoch 5/20 - Time: 57.05s - Test Acc: 9.82%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.71it/s, loss=4.06, acc=7.54]


Epoch 6/20 - Time: 57.04s - Test Acc: 8.39%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.71it/s, loss=4.04, acc=7.44]


Epoch 7/20 - Time: 57.04s - Test Acc: 9.12%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.73it/s, loss=4.01, acc=7.74]


Epoch 8/20 - Time: 56.97s - Test Acc: 8.68%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:57<00:00, 13.71it/s, loss=4.07, acc=7.28]


Epoch 9/20 - Time: 57.05s - Test Acc: 7.36%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.75it/s, loss=4.05, acc=7.59]


Epoch 10/20 - Time: 56.87s - Test Acc: 7.95%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.76it/s, loss=4.04, acc=7.74]


Epoch 11/20 - Time: 56.84s - Test Acc: 8.65%


Epoch 12/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.74it/s, loss=4.02, acc=7.9]


Epoch 12/20 - Time: 56.91s - Test Acc: 9.04%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.73it/s, loss=4.02, acc=7.85]


Epoch 13/20 - Time: 56.97s - Test Acc: 8.11%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.77it/s, loss=4.08, acc=7.31]


Epoch 14/20 - Time: 56.78s - Test Acc: 8.27%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.76it/s, loss=4.04, acc=7.97]


Epoch 15/20 - Time: 56.82s - Test Acc: 8.65%


Epoch 16/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.76it/s, loss=4.1, acc=6.97]


Epoch 16/20 - Time: 56.81s - Test Acc: 7.61%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.77it/s, loss=4.13, acc=6.76]


Epoch 17/20 - Time: 56.80s - Test Acc: 6.92%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.76it/s, loss=4.17, acc=6.05]


Epoch 18/20 - Time: 56.83s - Test Acc: 7.14%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.77it/s, loss=4.14, acc=6.48]


Epoch 19/20 - Time: 56.81s - Test Acc: 7.18%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:56<00:00, 13.76it/s, loss=4.11, acc=6.86]


Epoch 20/20 - Time: 56.84s - Test Acc: 7.46%
Training ViT (patch_size=8, embed_dim=256, depth/transformer layers=4, heads=4, MLP Ratio=4)
Total parameters: 3,239,268
FLOPs per forward pass: 254,449,920


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.97it/s, loss=4.14, acc=6.39]


Epoch 1/20 - Time: 16.30s - Test Acc: 8.74%


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.26it/s, loss=3.98, acc=8.34]


Epoch 2/20 - Time: 16.55s - Test Acc: 8.90%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.54it/s, loss=3.92, acc=9.21]


Epoch 3/20 - Time: 16.45s - Test Acc: 9.03%


Epoch 4/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.75it/s, loss=3.9, acc=9.47]


Epoch 4/20 - Time: 16.38s - Test Acc: 10.17%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.35it/s, loss=3.86, acc=10.2]


Epoch 5/20 - Time: 16.17s - Test Acc: 9.90%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.91it/s, loss=3.87, acc=9.82]


Epoch 6/20 - Time: 16.32s - Test Acc: 10.30%


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.57it/s, loss=3.86, acc=10.3]


Epoch 7/20 - Time: 16.12s - Test Acc: 10.25%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.66it/s, loss=3.86, acc=10.5]


Epoch 8/20 - Time: 16.07s - Test Acc: 11.16%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.12it/s, loss=3.84, acc=10.6]


Epoch 9/20 - Time: 16.25s - Test Acc: 10.96%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.36it/s, loss=3.88, acc=9.94]


Epoch 10/20 - Time: 16.17s - Test Acc: 9.84%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.17it/s, loss=3.88, acc=9.82]


Epoch 11/20 - Time: 16.58s - Test Acc: 10.63%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.72it/s, loss=3.86, acc=9.83]


Epoch 12/20 - Time: 16.39s - Test Acc: 11.13%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.51it/s, loss=3.85, acc=10.3]


Epoch 13/20 - Time: 16.12s - Test Acc: 11.73%


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.62it/s, loss=3.84, acc=10.5]


Epoch 14/20 - Time: 16.08s - Test Acc: 10.63%


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.08it/s, loss=3.91, acc=9.47]


Epoch 15/20 - Time: 16.26s - Test Acc: 9.63%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.59it/s, loss=3.91, acc=9.55]


Epoch 16/20 - Time: 16.43s - Test Acc: 9.85%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.17it/s, loss=3.93, acc=9.15]


Epoch 17/20 - Time: 16.59s - Test Acc: 10.30%


Epoch 18/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.95it/s, loss=3.9, acc=9.66]


Epoch 18/20 - Time: 16.31s - Test Acc: 10.41%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.96it/s, loss=3.92, acc=9.17]


Epoch 19/20 - Time: 16.32s - Test Acc: 9.24%


Epoch 20/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 47.81it/s, loss=3.9, acc=9.77]


Epoch 20/20 - Time: 16.36s - Test Acc: 9.59%
Training ViT (patch_size=8, embed_dim=512, depth/transformer layers=8, heads=8, MLP Ratio=4)
Total parameters: 25,379,428
FLOPs per forward pass: 1,718,556,928


Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.30it/s, loss=4.32, acc=4.22]


Epoch 1/20 - Time: 42.73s - Test Acc: 5.78%


Epoch 2/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.34it/s, loss=4.2, acc=5.67]


Epoch 2/20 - Time: 42.64s - Test Acc: 5.35%


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.33it/s, loss=4.17, acc=6.19]


Epoch 3/20 - Time: 42.69s - Test Acc: 6.01%


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.34it/s, loss=4.18, acc=5.89]


Epoch 4/20 - Time: 42.63s - Test Acc: 6.32%


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.28it/s, loss=4.19, acc=6.15]


Epoch 5/20 - Time: 42.77s - Test Acc: 6.68%


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.34it/s, loss=4.18, acc=5.93]


Epoch 6/20 - Time: 42.64s - Test Acc: 5.49%


Epoch 7/20: 100%|████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.32it/s, loss=4.25, acc=5.1]


Epoch 7/20 - Time: 42.69s - Test Acc: 5.88%


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.34it/s, loss=4.25, acc=5.21]


Epoch 8/20 - Time: 42.63s - Test Acc: 5.07%


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.36it/s, loss=4.25, acc=5.01]


Epoch 9/20 - Time: 42.58s - Test Acc: 5.39%


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.39it/s, loss=4.24, acc=5.14]


Epoch 10/20 - Time: 42.53s - Test Acc: 5.88%


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.38it/s, loss=4.21, acc=5.47]


Epoch 11/20 - Time: 42.54s - Test Acc: 6.34%


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.36it/s, loss=4.21, acc=5.45]


Epoch 12/20 - Time: 42.60s - Test Acc: 6.07%


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.38it/s, loss=4.22, acc=5.52]


Epoch 13/20 - Time: 42.54s - Test Acc: 5.56%


Epoch 14/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.35it/s, loss=4.2, acc=5.87]


Epoch 14/20 - Time: 42.62s - Test Acc: 6.23%


Epoch 15/20: 100%|███████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.32it/s, loss=4.2, acc=5.93]


Epoch 15/20 - Time: 42.70s - Test Acc: 5.75%


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.38it/s, loss=4.24, acc=5.42]


Epoch 16/20 - Time: 42.54s - Test Acc: 5.72%


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.30it/s, loss=4.26, acc=5.08]


Epoch 17/20 - Time: 42.74s - Test Acc: 5.71%


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.31it/s, loss=4.24, acc=5.49]


Epoch 18/20 - Time: 42.71s - Test Acc: 5.40%


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.30it/s, loss=4.23, acc=5.37]


Epoch 19/20 - Time: 42.74s - Test Acc: 5.73%


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.33it/s, loss=4.24, acc=5.27]


Epoch 20/20 - Time: 42.68s - Test Acc: 5.56%

Problem 1 Results: ResNet-18 vs ViT Best
            Model     Params          FLOPs Training Time Epoch Time Final Acc 10-Epoch Acc
        ResNet-18 11,220,132 35,550,624,000       893.62s     44.68s    69.41%       66.99%
ViT-p4-e256-d4-h4  3,214,692    255,236,352       678.15s     33.91s    36.60%       29.53%
ViT-p4-e512-d4-h8 12,720,740    913,119,488      1311.76s     65.59s     7.46%        7.95%
ViT-p8-e256-d4-h4  3,239,268    254,449,920       447.46s     22.37s     9.59%        9.84%
ViT-p8-e512-d8-h8 25,379,428  1,718,556,928       998.32s     49.92s     5.56%        5.88%
Results saved to: results/run_problem1_training_best.csv
Running experiments for Problem 2: Swin Transformer
Fine-tuning microsoft/swin-tiny-patch4-window7-224 on CIFAR-100


Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 27,596,254
Trainable parameters: 76,900
FLOPs per forward pass: 62,104,420


Epoch 1/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [02:52<00:00,  9.09it/s, loss=4.11, acc=22.8]


Epoch 1/5 - Time: 172.04s - Test Acc: 44.97%


Epoch 2/5: 100%|████████████████████████████████████████████████████████████| 1563/1563 [02:52<00:00,  9.09it/s, loss=3.19, acc=52]


Epoch 2/5 - Time: 172.03s - Test Acc: 57.17%


Epoch 3/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [02:52<00:00,  9.08it/s, loss=2.54, acc=59.3]


Epoch 3/5 - Time: 172.05s - Test Acc: 61.07%


Epoch 4/5: 100%|███████████████████████████████████████████████████████████| 1563/1563 [02:51<00:00,  9.09it/s, loss=2.1, acc=62.8]


Epoch 4/5 - Time: 172.00s - Test Acc: 63.54%


Epoch 5/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [02:52<00:00,  9.08it/s, loss=1.81, acc=64.8]


Epoch 5/5 - Time: 172.07s - Test Acc: 65.20%
Fine-tuning microsoft/swin-small-patch4-window7-224 on CIFAR-100


Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-small-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 48,914,158
Trainable parameters: 76,900
FLOPs per forward pass: 104,686,948


Epoch 1/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [04:26<00:00,  5.87it/s, loss=4.04, acc=27.2]


Epoch 1/5 - Time: 266.39s - Test Acc: 52.25%


Epoch 2/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [04:26<00:00,  5.86it/s, loss=3.02, acc=58.1]


Epoch 2/5 - Time: 266.59s - Test Acc: 61.55%


Epoch 3/5: 100%|████████████████████████████████████████████████████████████| 1563/1563 [04:26<00:00,  5.86it/s, loss=2.31, acc=64]


Epoch 3/5 - Time: 266.53s - Test Acc: 64.97%


Epoch 4/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [04:26<00:00,  5.86it/s, loss=1.86, acc=66.5]


Epoch 4/5 - Time: 266.60s - Test Acc: 67.05%


Epoch 5/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [04:26<00:00,  5.86it/s, loss=1.59, acc=68.3]


Epoch 5/5 - Time: 266.63s - Test Acc: 68.65%
Training Swin Transformer from scratch on CIFAR-100
Swin Transformer architecture:
- Patch size: 2
- Window size: 4
- Patches resolution: 16
- Number of patches: 256
Total parameters: 27,600,334
Trainable parameters: 27,600,334
Total parameters: 27,600,334
FLOPs per forward pass: 29,894,308


Epoch 1/5: 100%|█████████████████████████████████████████████████████████████| 1563/1563 [01:07<00:00, 22.99it/s, loss=4, acc=8.61]


Epoch 1/5 - Time: 67.99s - Test Acc: 12.17%


Epoch 2/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [01:07<00:00, 23.18it/s, loss=3.57, acc=15.1]


Epoch 2/5 - Time: 67.43s - Test Acc: 18.07%


Epoch 3/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [01:07<00:00, 23.17it/s, loss=3.31, acc=19.8]


Epoch 3/5 - Time: 67.46s - Test Acc: 21.13%


Epoch 4/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [01:07<00:00, 23.21it/s, loss=3.11, acc=23.6]


Epoch 4/5 - Time: 67.34s - Test Acc: 25.45%


Epoch 5/5: 100%|██████████████████████████████████████████████████████████| 1563/1563 [01:07<00:00, 23.17it/s, loss=2.94, acc=26.6]


Epoch 5/5 - Time: 67.45s - Test Acc: 28.31%

Problem 2 Results: Swin Transformer
                Model     Params       FLOPs Training Time Epoch Time Final Acc 10-Epoch Acc
 Swin-tiny-pretrained 27,596,254  62,104,420      1045.35s    209.07s    65.20%       65.20%
Swin-small-pretrained 48,914,158 104,686,948      1610.08s    322.02s    68.65%       68.65%
    Swin-from-scratch 27,600,334  29,894,308       381.97s     76.39s    28.31%       28.31%
Results saved to: results/swin_results.csv

Homework 6 Execution Complete
