In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
# Helper Module: Stochastic Depth (DropPath)
# Randomly drops a residual connection during training.
# This is a form of regularization that improves performance for deep networks.
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

In [4]:
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

# Core ViT Components
class PatchEmbedding(nn.Module):
    """
    Splits the image into patches and embeds them.
    Parameters:
        img_size (int): Size of the input image (e.g., 32 for CIFAR-10).
        patch_size (int): Size of each patch (e.g., 4 for CIFAR-10).
        in_channels (int): Number of input channels (e.g., 3 for RGB).
        embed_dim (int): The embedding dimension.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # A convolution layer to convert image to patches and embed them
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # Input x: (B, C, H, W)
        x = self.proj(x)  # (B, E, H/P, W/P)
        x = x.flatten(2)  # (B, E, N) where N is number of patches
        x = x.transpose(1, 2)  # (B, N, E)
        return x

class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention block."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    """A simple Feed-Forward Network."""
    def __init__(self, in_features, hidden_features, out_features, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerEncoderBlock(nn.Module):
    """A single block of the Transformer Encoder."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class VisionTransformer(nn.Module):
    """The main Vision Transformer model."""
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=512, depth=6,
                 num_heads=8, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.,
                 drop_path_rate=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbedding(
            img_size=img_size, patch_size=patch_size, in_channels=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            TransformerEncoderBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.blocks(x)
        x = self.norm(x)

        # Use the CLS token for classification
        return self.head(x[:, 0])

In [5]:
# --- Data Loading and Training Setup ---

def get_loaders(batch_size=256):
    """Prepare CIFAR-10 DataLoaders."""
    # Data augmentation and normalization for training
    # Just normalization for validation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return trainloader, testloader

def train_one_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch):
    """Train the model for one epoch."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    for i, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)

        # Linear warm-up
        if epoch < 5: # Warm-up for first 5 epochs
            lr_scale = (i + 1 + len(dataloader) * epoch) / (len(dataloader) * 5)
            for pg in optimizer.param_groups:
                pg['lr'] = CFG.lr * lr_scale

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar.set_postfix({
            'loss': f'{total_loss/(i+1):.4f}',
            'acc': f'{100.*correct/total:.2f}%',
            'lr': f'{optimizer.param_groups[0]["lr"]:.5f}'
        })

    if epoch >= 5: # Start scheduler after warm-up
        scheduler.step()

def evaluate(model, dataloader, criterion, device):
    """Evaluate the model on the test set."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    accuracy = 100. * correct / total
    avg_loss = total_loss / len(dataloader)
    print(f'Test Results: Accuracy: {accuracy:.2f}% | Avg Loss: {avg_loss:.4f}')
    return accuracy

In [6]:
# --- Main Execution Block ---

if __name__ == '__main__':
    # Configuration
    class CFG:
        img_size = 32
        patch_size = 4
        in_chans = 3
        num_classes = 10
        embed_dim = 512
        depth = 6
        num_heads = 8
        mlp_ratio = 2.0 # Reduced mlp_ratio for smaller model
        qkv_bias = True
        drop_rate = 0.1
        attn_drop_rate = 0.1
        drop_path_rate = 0.1

        batch_size = 256
        epochs = 100
        lr = 1e-3
        weight_decay = 0.05

    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Data
    trainloader, testloader = get_loaders(CFG.batch_size)

    # Model
    model = VisionTransformer(
        img_size=CFG.img_size,
        patch_size=CFG.patch_size,
        in_chans=CFG.in_chans,
        num_classes=CFG.num_classes,
        embed_dim=CFG.embed_dim,
        depth=CFG.depth,
        num_heads=CFG.num_heads,
        mlp_ratio=CFG.mlp_ratio,
        qkv_bias=CFG.qkv_bias,
        drop_rate=CFG.drop_rate,
        attn_drop_rate=CFG.attn_drop_rate,
        drop_path_rate=CFG.drop_path_rate
    ).to(device)

    print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

    # Loss, Optimizer, Scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # Cosine Annealing scheduler, starts after the warm-up period
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs - 5)

    # Training Loop
    best_acc = 0.0
    for epoch in range(CFG.epochs):
        train_one_epoch(model, trainloader, criterion, optimizer, scheduler, device, epoch)
        current_acc = evaluate(model, testloader, criterion, device)
        if current_acc > best_acc:
            best_acc = current_acc
            print(f"🎉 New best accuracy: {best_acc:.2f}%")
            # You can save your best model here if you want
            # torch.save(model.state_dict(), 'best_vit_cifar10.pth')

    print(f"\n--- Training Finished ---")
    print(f"🏆 Best Test Accuracy: {best_acc:.2f}%")

Using device: cuda


100%|██████████| 170M/170M [00:03<00:00, 43.1MB/s]


Model has 12,681,738 trainable parameters.


Epoch 1: 100%|██████████| 196/196 [01:31<00:00,  2.15it/s, loss=2.0955, acc=21.95%, lr=0.00020]


Test Results: Accuracy: 28.39% | Avg Loss: 1.9285
🎉 New best accuracy: 28.39%


Epoch 2: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.8314, acc=32.83%, lr=0.00040]


Test Results: Accuracy: 44.29% | Avg Loss: 1.5437
🎉 New best accuracy: 44.29%


Epoch 3: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6920, acc=38.29%, lr=0.00060]


Test Results: Accuracy: 47.06% | Avg Loss: 1.4749
🎉 New best accuracy: 47.06%


Epoch 4: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6624, acc=39.62%, lr=0.00080]


Test Results: Accuracy: 46.35% | Avg Loss: 1.4542


Epoch 5: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6891, acc=38.98%, lr=0.00100]


Test Results: Accuracy: 43.66% | Avg Loss: 1.5545


Epoch 6: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.7093, acc=38.09%, lr=0.00100]


Test Results: Accuracy: 45.24% | Avg Loss: 1.5089


Epoch 7: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.7324, acc=36.79%, lr=0.00100]


Test Results: Accuracy: 43.91% | Avg Loss: 1.5356


Epoch 8: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.7231, acc=37.22%, lr=0.00100]


Test Results: Accuracy: 46.29% | Avg Loss: 1.4953


Epoch 9: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6953, acc=38.45%, lr=0.00100]


Test Results: Accuracy: 45.53% | Avg Loss: 1.4876


Epoch 10: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6880, acc=38.80%, lr=0.00100]


Test Results: Accuracy: 48.63% | Avg Loss: 1.4220
🎉 New best accuracy: 48.63%


Epoch 11: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6580, acc=39.78%, lr=0.00099]


Test Results: Accuracy: 48.19% | Avg Loss: 1.4249


Epoch 12: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.6264, acc=41.24%, lr=0.00099]


Test Results: Accuracy: 51.86% | Avg Loss: 1.3321
🎉 New best accuracy: 51.86%


Epoch 13: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.5881, acc=42.18%, lr=0.00099]


Test Results: Accuracy: 53.80% | Avg Loss: 1.2568
🎉 New best accuracy: 53.80%


Epoch 14: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.5553, acc=43.82%, lr=0.00098]


Test Results: Accuracy: 53.52% | Avg Loss: 1.2781


Epoch 15: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.5246, acc=44.63%, lr=0.00098]


Test Results: Accuracy: 54.47% | Avg Loss: 1.2414
🎉 New best accuracy: 54.47%


Epoch 16: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.5093, acc=45.15%, lr=0.00097]


Test Results: Accuracy: 54.87% | Avg Loss: 1.2563
🎉 New best accuracy: 54.87%


Epoch 17: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.4705, acc=46.58%, lr=0.00097]


Test Results: Accuracy: 56.05% | Avg Loss: 1.1956
🎉 New best accuracy: 56.05%


Epoch 18: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.4540, acc=47.60%, lr=0.00096]


Test Results: Accuracy: 58.44% | Avg Loss: 1.1436
🎉 New best accuracy: 58.44%


Epoch 19: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.4312, acc=48.34%, lr=0.00095]


Test Results: Accuracy: 58.78% | Avg Loss: 1.1459
🎉 New best accuracy: 58.78%


Epoch 20: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.4122, acc=49.16%, lr=0.00095]


Test Results: Accuracy: 58.76% | Avg Loss: 1.1316


Epoch 21: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.3906, acc=49.92%, lr=0.00094]


Test Results: Accuracy: 59.26% | Avg Loss: 1.1063
🎉 New best accuracy: 59.26%


Epoch 22: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.3688, acc=50.80%, lr=0.00093]


Test Results: Accuracy: 60.15% | Avg Loss: 1.0783
🎉 New best accuracy: 60.15%


Epoch 23: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.3554, acc=51.31%, lr=0.00092]


Test Results: Accuracy: 61.48% | Avg Loss: 1.0729
🎉 New best accuracy: 61.48%


Epoch 24: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.3222, acc=52.67%, lr=0.00091]


Test Results: Accuracy: 62.19% | Avg Loss: 1.0290
🎉 New best accuracy: 62.19%


Epoch 25: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.3080, acc=52.88%, lr=0.00090]


Test Results: Accuracy: 62.69% | Avg Loss: 1.0149
🎉 New best accuracy: 62.69%


Epoch 26: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.2880, acc=53.67%, lr=0.00089]


Test Results: Accuracy: 63.91% | Avg Loss: 0.9908
🎉 New best accuracy: 63.91%


Epoch 27: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.2694, acc=54.28%, lr=0.00088]


Test Results: Accuracy: 63.49% | Avg Loss: 1.0211


Epoch 28: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.2451, acc=55.58%, lr=0.00087]


Test Results: Accuracy: 63.73% | Avg Loss: 0.9953


Epoch 29: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.2263, acc=56.08%, lr=0.00086]


Test Results: Accuracy: 65.16% | Avg Loss: 0.9533
🎉 New best accuracy: 65.16%


Epoch 30: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.2092, acc=56.48%, lr=0.00085]


Test Results: Accuracy: 66.12% | Avg Loss: 0.9281
🎉 New best accuracy: 66.12%


Epoch 31: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.1882, acc=57.49%, lr=0.00084]


Test Results: Accuracy: 66.68% | Avg Loss: 0.9100
🎉 New best accuracy: 66.68%


Epoch 32: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.1575, acc=58.68%, lr=0.00083]


Test Results: Accuracy: 68.91% | Avg Loss: 0.8614
🎉 New best accuracy: 68.91%


Epoch 33: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.1435, acc=59.11%, lr=0.00081]


Test Results: Accuracy: 70.32% | Avg Loss: 0.8385
🎉 New best accuracy: 70.32%


Epoch 34: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.1256, acc=59.56%, lr=0.00080]


Test Results: Accuracy: 69.96% | Avg Loss: 0.8363


Epoch 35: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.1158, acc=60.52%, lr=0.00079]


Test Results: Accuracy: 71.95% | Avg Loss: 0.7819
🎉 New best accuracy: 71.95%


Epoch 36: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.0913, acc=61.18%, lr=0.00077]


Test Results: Accuracy: 71.35% | Avg Loss: 0.7918


Epoch 37: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=1.0684, acc=61.95%, lr=0.00076]


Test Results: Accuracy: 71.28% | Avg Loss: 0.8201


Epoch 38: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.0536, acc=62.46%, lr=0.00075]


Test Results: Accuracy: 72.25% | Avg Loss: 0.7736
🎉 New best accuracy: 72.25%


Epoch 39: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.0410, acc=63.07%, lr=0.00073]


Test Results: Accuracy: 72.77% | Avg Loss: 0.7735
🎉 New best accuracy: 72.77%


Epoch 40: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.0243, acc=63.65%, lr=0.00072]


Test Results: Accuracy: 73.67% | Avg Loss: 0.7334
🎉 New best accuracy: 73.67%


Epoch 41: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.9988, acc=64.80%, lr=0.00070]


Test Results: Accuracy: 74.48% | Avg Loss: 0.7115
🎉 New best accuracy: 74.48%


Epoch 42: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=1.0012, acc=64.63%, lr=0.00069]


Test Results: Accuracy: 73.66% | Avg Loss: 0.7351


Epoch 43: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.9776, acc=65.36%, lr=0.00067]


Test Results: Accuracy: 73.97% | Avg Loss: 0.7351


Epoch 44: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.9639, acc=65.67%, lr=0.00065]


Test Results: Accuracy: 75.20% | Avg Loss: 0.6964
🎉 New best accuracy: 75.20%


Epoch 45: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.9497, acc=66.19%, lr=0.00064]


Test Results: Accuracy: 75.87% | Avg Loss: 0.6719
🎉 New best accuracy: 75.87%


Epoch 46: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.9320, acc=67.16%, lr=0.00062]


Test Results: Accuracy: 75.48% | Avg Loss: 0.6832


Epoch 47: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.9165, acc=67.51%, lr=0.00061]


Test Results: Accuracy: 76.85% | Avg Loss: 0.6430
🎉 New best accuracy: 76.85%


Epoch 48: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.9087, acc=67.66%, lr=0.00059]


Test Results: Accuracy: 77.96% | Avg Loss: 0.6348
🎉 New best accuracy: 77.96%


Epoch 49: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8929, acc=68.34%, lr=0.00057]


Test Results: Accuracy: 77.01% | Avg Loss: 0.6432


Epoch 50: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8897, acc=68.36%, lr=0.00056]


Test Results: Accuracy: 77.17% | Avg Loss: 0.6437


Epoch 51: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.8698, acc=69.13%, lr=0.00054]


Test Results: Accuracy: 78.58% | Avg Loss: 0.6140
🎉 New best accuracy: 78.58%


Epoch 52: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8596, acc=69.49%, lr=0.00052]


Test Results: Accuracy: 78.86% | Avg Loss: 0.6079
🎉 New best accuracy: 78.86%


Epoch 53: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8455, acc=70.01%, lr=0.00051]


Test Results: Accuracy: 78.12% | Avg Loss: 0.6040


Epoch 54: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8323, acc=70.70%, lr=0.00049]


Test Results: Accuracy: 79.36% | Avg Loss: 0.5817
🎉 New best accuracy: 79.36%


Epoch 55: 100%|██████████| 196/196 [01:36<00:00,  2.04it/s, loss=0.8165, acc=70.80%, lr=0.00048]


Test Results: Accuracy: 78.73% | Avg Loss: 0.6027


Epoch 56: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8133, acc=71.22%, lr=0.00046]


Test Results: Accuracy: 79.35% | Avg Loss: 0.5870


Epoch 57: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.8041, acc=71.47%, lr=0.00044]


Test Results: Accuracy: 78.52% | Avg Loss: 0.6092


Epoch 58: 100%|██████████| 196/196 [01:36<00:00,  2.04it/s, loss=0.7833, acc=72.17%, lr=0.00043]


Test Results: Accuracy: 80.52% | Avg Loss: 0.5537
🎉 New best accuracy: 80.52%


Epoch 59: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7810, acc=72.56%, lr=0.00041]


Test Results: Accuracy: 80.10% | Avg Loss: 0.5633


Epoch 60: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7740, acc=72.50%, lr=0.00039]


Test Results: Accuracy: 80.38% | Avg Loss: 0.5497


Epoch 61: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.7609, acc=72.93%, lr=0.00038]


Test Results: Accuracy: 81.04% | Avg Loss: 0.5363
🎉 New best accuracy: 81.04%


Epoch 62: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7511, acc=73.24%, lr=0.00036]


Test Results: Accuracy: 81.04% | Avg Loss: 0.5354


Epoch 63: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7344, acc=73.89%, lr=0.00035]


Test Results: Accuracy: 80.15% | Avg Loss: 0.5593


Epoch 64: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7289, acc=74.19%, lr=0.00033]


Test Results: Accuracy: 81.27% | Avg Loss: 0.5323
🎉 New best accuracy: 81.27%


Epoch 65: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7206, acc=74.45%, lr=0.00031]


Test Results: Accuracy: 81.81% | Avg Loss: 0.5206
🎉 New best accuracy: 81.81%


Epoch 66: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.7115, acc=75.00%, lr=0.00030]


Test Results: Accuracy: 81.85% | Avg Loss: 0.5192
🎉 New best accuracy: 81.85%


Epoch 67: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.6993, acc=75.01%, lr=0.00028]


Test Results: Accuracy: 82.69% | Avg Loss: 0.5070
🎉 New best accuracy: 82.69%


Epoch 68: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6880, acc=75.64%, lr=0.00027]


Test Results: Accuracy: 82.18% | Avg Loss: 0.5062


Epoch 69: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6803, acc=75.73%, lr=0.00025]


Test Results: Accuracy: 82.46% | Avg Loss: 0.5062


Epoch 70: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6691, acc=76.24%, lr=0.00024]


Test Results: Accuracy: 82.35% | Avg Loss: 0.5069


Epoch 71: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6600, acc=76.53%, lr=0.00023]


Test Results: Accuracy: 82.81% | Avg Loss: 0.4915
🎉 New best accuracy: 82.81%


Epoch 72: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6593, acc=76.74%, lr=0.00021]


Test Results: Accuracy: 83.00% | Avg Loss: 0.4903
🎉 New best accuracy: 83.00%


Epoch 73: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6481, acc=76.98%, lr=0.00020]


Test Results: Accuracy: 83.13% | Avg Loss: 0.4895
🎉 New best accuracy: 83.13%


Epoch 74: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6404, acc=77.38%, lr=0.00019]


Test Results: Accuracy: 82.57% | Avg Loss: 0.4990


Epoch 75: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6349, acc=77.49%, lr=0.00017]


Test Results: Accuracy: 83.42% | Avg Loss: 0.4787
🎉 New best accuracy: 83.42%


Epoch 76: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6238, acc=78.03%, lr=0.00016]


Test Results: Accuracy: 83.49% | Avg Loss: 0.4748
🎉 New best accuracy: 83.49%


Epoch 77: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.6134, acc=78.22%, lr=0.00015]


Test Results: Accuracy: 83.50% | Avg Loss: 0.4742
🎉 New best accuracy: 83.50%


Epoch 78: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.6068, acc=78.37%, lr=0.00014]


Test Results: Accuracy: 83.61% | Avg Loss: 0.4801
🎉 New best accuracy: 83.61%


Epoch 79: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5996, acc=78.75%, lr=0.00013]


Test Results: Accuracy: 83.86% | Avg Loss: 0.4621
🎉 New best accuracy: 83.86%


Epoch 80: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5942, acc=78.88%, lr=0.00012]


Test Results: Accuracy: 83.79% | Avg Loss: 0.4681


Epoch 81: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5869, acc=79.17%, lr=0.00011]


Test Results: Accuracy: 84.37% | Avg Loss: 0.4522
🎉 New best accuracy: 84.37%


Epoch 82: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5826, acc=79.35%, lr=0.00010]


Test Results: Accuracy: 84.21% | Avg Loss: 0.4548


Epoch 83: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5811, acc=79.48%, lr=0.00009]


Test Results: Accuracy: 84.37% | Avg Loss: 0.4544


Epoch 84: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5758, acc=79.59%, lr=0.00008]


Test Results: Accuracy: 84.37% | Avg Loss: 0.4490


Epoch 85: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5694, acc=79.70%, lr=0.00007]


Test Results: Accuracy: 84.59% | Avg Loss: 0.4509
🎉 New best accuracy: 84.59%


Epoch 86: 100%|██████████| 196/196 [01:36<00:00,  2.04it/s, loss=0.5671, acc=80.02%, lr=0.00006]


Test Results: Accuracy: 84.68% | Avg Loss: 0.4547
🎉 New best accuracy: 84.68%


Epoch 87: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5622, acc=80.15%, lr=0.00005]


Test Results: Accuracy: 84.51% | Avg Loss: 0.4509


Epoch 88: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5562, acc=80.17%, lr=0.00005]


Test Results: Accuracy: 84.68% | Avg Loss: 0.4484


Epoch 89: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5508, acc=80.40%, lr=0.00004]


Test Results: Accuracy: 84.71% | Avg Loss: 0.4470
🎉 New best accuracy: 84.71%


Epoch 90: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5477, acc=80.53%, lr=0.00003]


Test Results: Accuracy: 84.66% | Avg Loss: 0.4503


Epoch 91: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5462, acc=80.84%, lr=0.00003]


Test Results: Accuracy: 84.90% | Avg Loss: 0.4453
🎉 New best accuracy: 84.90%


Epoch 92: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5455, acc=80.74%, lr=0.00002]


Test Results: Accuracy: 84.74% | Avg Loss: 0.4436


Epoch 93: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5406, acc=80.79%, lr=0.00002]


Test Results: Accuracy: 84.83% | Avg Loss: 0.4461


Epoch 94: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5422, acc=80.80%, lr=0.00001]


Test Results: Accuracy: 84.87% | Avg Loss: 0.4449


Epoch 95: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5369, acc=81.01%, lr=0.00001]


Test Results: Accuracy: 84.87% | Avg Loss: 0.4469


Epoch 96: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.5437, acc=80.70%, lr=0.00001]


Test Results: Accuracy: 84.95% | Avg Loss: 0.4441
🎉 New best accuracy: 84.95%


Epoch 97: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5319, acc=81.09%, lr=0.00000]


Test Results: Accuracy: 84.87% | Avg Loss: 0.4437


Epoch 98: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5351, acc=81.00%, lr=0.00000]


Test Results: Accuracy: 84.95% | Avg Loss: 0.4420


Epoch 99: 100%|██████████| 196/196 [01:35<00:00,  2.05it/s, loss=0.5350, acc=81.05%, lr=0.00000]


Test Results: Accuracy: 84.88% | Avg Loss: 0.4429


Epoch 100: 100%|██████████| 196/196 [01:35<00:00,  2.06it/s, loss=0.5354, acc=80.98%, lr=0.00000]


Test Results: Accuracy: 84.90% | Avg Loss: 0.4429

--- Training Finished ---
🏆 Best Test Accuracy: 84.95%
