# KD only for ViT Tiny Student and ViT Small Teacher

In [None]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandomErasing
from transformers import (
    DeiTConfig,
    DeiTForImageClassification,
    ViTConfig,
    ViTPreTrainedModel,
    ViTModel
)
from tqdm import tqdm

# ---------------------------------------------
# Setup
# ---------------------------------------------
use_dp = True  # Use DataParallel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

alpha, temp = 0.5, 4.0
resolution = 32
epochs = 60
batch_size = 128

# ---------------------------------------------
# ViT with Distillation
# ---------------------------------------------
class ViTWithDistillation(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig):
        super().__init__(config)
        self.vit = ViTModel(config)
        self.distill_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.distiller = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, pixel_values, labels=None, teacher_logits=None,
                alpha=0.5, temperature=1.0):
        B = pixel_values.size(0)
        embeds = self.vit.embeddings(pixel_values, interpolate_pos_encoding=True)
        cls_emb, patch_emb = embeds[:, :1, :], embeds[:, 1:, :]
        dist_tok = self.distill_token.expand(B, -1, -1)
        x = torch.cat([cls_emb, dist_tok, patch_emb], dim=1)
        x = self.vit.encoder(x)[0]
        cls_out, dist_out = x[:, 0], x[:, 1]
        logits = self.classifier(cls_out)
        dist_logits = self.distiller(dist_out)
        output = {"logits": logits, "distill_logits": dist_logits}

        if labels is not None and teacher_logits is not None:
            loss_ce = F.cross_entropy(logits, labels)
            kd = F.kl_div(
                F.log_softmax(dist_logits / temperature, dim=1),
                F.softmax(teacher_logits / temperature, dim=1),
                reduction='batchmean'
            ) * (temperature ** 2)
            output["loss"] = (1 - alpha) * loss_ce + alpha * kd

        return output

# ---------------------------------------------
# Prepare Teacher
# ---------------------------------------------
teacher_config = DeiTConfig(
    image_size=resolution,
    patch_size=2,
    num_labels=10,
    hidden_size=384,
    num_hidden_layers=12,
    num_attention_heads=6,
    intermediate_size=1536,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    stochastic_depth_prob=0.1
)

teacher = DeiTForImageClassification(teacher_config).to(device)
if use_dp:
    teacher = nn.DataParallel(teacher)

ckpt = torch.load("/kaggle/input/best-teacher/pytorch/default/1/best_teacher.pth", map_location=device)
# Remove "module." if present in keys
state = {k.replace("module.", ""): v for k, v in ckpt.items()}
teacher.module.load_state_dict(state, strict=True) if use_dp else teacher.load_state_dict(state, strict=True)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# ---------------------------------------------
# Data Preparation
# ---------------------------------------------
train_tf = transforms.Compose([
    transforms.RandomCrop(resolution, padding=4),
    transforms.RandomHorizontalFlip(),
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    RandomErasing(p=0.2, scale=(0.02, 0.2), ratio=(0.3, 3.3))
])

val_tf = transforms.Compose([
    transforms.CenterCrop(resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

dataset = datasets.CIFAR10('./data', train=True, download=True, transform=train_tf)
num_val = int(0.1 * len(dataset))
train_idx, val_idx = torch.utils.data.random_split(
    list(range(len(dataset))), [len(dataset) - num_val, num_val],
    generator=torch.Generator().manual_seed(42)
)
train_ds = Subset(dataset, train_idx)
val_ds = Subset(datasets.CIFAR10('./data', train=True, download=False, transform=val_tf), val_idx)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

# ---------------------------------------------
# Initialize Student
# ---------------------------------------------
stu_cfg = ViTConfig(
    image_size=resolution,
    patch_size=2,
    num_labels=10,
    hidden_size=192,
    num_hidden_layers=12,
    num_attention_heads=3,
    intermediate_size=768,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    stochastic_depth_prob=0.1
)

student = ViTWithDistillation(config=stu_cfg).to(device)
if use_dp:
    student = nn.DataParallel(student)

opt = optim.AdamW(student.parameters(), lr=3e-4, weight_decay=1e-4)

# ---------------------------------------------
# Training Loop
# ---------------------------------------------
for ep in range(1, epochs + 1):
    student.train()
    loop = tqdm(train_loader, desc=f"Train Epoch {ep}/{epochs}")
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            tlog = teacher(pixel_values=x, interpolate_pos_encoding=True).logits
        out = student(pixel_values=x, labels=y, teacher_logits=tlog,
                      alpha=alpha, temperature=temp)
        loss = out['loss'].mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item())

    # Validation
    student.eval()
    correct, total = 0, 0
    val_loop = tqdm(val_loader, desc=f"Val Epoch {ep}/{epochs}")
    for xb, yb in val_loop:
        xb, yb = xb.to(device), yb.to(device)
        with torch.no_grad():
            preds = student(pixel_values=xb)['logits'].argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    val_acc = 100 * correct / total
    print(f"Epoch {ep} Validation Accuracy: {val_acc:.2f}%")

# ---------------------------------------------
# Final Test
# ---------------------------------------------
student.eval()
test_ds = datasets.CIFAR10('./data', train=False, transform=val_tf)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)
correct = 0
for xb, yb in tqdm(test_loader, desc="Final Test"):
    xb, yb = xb.to(device), yb.to(device)
    with torch.no_grad():
        preds = student(pixel_values=xb)['logits'].argmax(1)
    correct += (preds == yb).sum().item()

student_acc = 100 * correct / len(test_ds)
print(f"Final Test Accuracy: {student_acc:.2f}%")

2025-05-06 19:02:56.190837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746558176.393674      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746558176.455182      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  ckpt = torch.load("/kaggle/input/best-teacher/pytorch/default/1/best_teacher.pth", map_location=device)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 77.2MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data


Train Epoch 1/60: 100%|██████████| 352/352 [03:52<00:00,  1.52it/s, loss=1.33]
Val Epoch 1/60: 100%|██████████| 40/40 [00:05<00:00,  7.44it/s]


Epoch 1 Validation Accuracy: 32.60%


Train Epoch 2/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=1.32]
Val Epoch 2/60: 100%|██████████| 40/40 [00:04<00:00,  8.28it/s]


Epoch 2 Validation Accuracy: 45.48%


Train Epoch 3/60: 100%|██████████| 352/352 [03:52<00:00,  1.51it/s, loss=1.34]
Val Epoch 3/60: 100%|██████████| 40/40 [00:04<00:00,  8.42it/s]


Epoch 3 Validation Accuracy: 49.84%


Train Epoch 4/60: 100%|██████████| 352/352 [03:52<00:00,  1.51it/s, loss=1]    
Val Epoch 4/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 4 Validation Accuracy: 52.86%


Train Epoch 5/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.944]
Val Epoch 5/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 5 Validation Accuracy: 55.44%


Train Epoch 6/60: 100%|██████████| 352/352 [03:53<00:00,  1.51it/s, loss=1.01] 
Val Epoch 6/60: 100%|██████████| 40/40 [00:04<00:00,  8.31it/s]


Epoch 6 Validation Accuracy: 57.60%


Train Epoch 7/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.809]
Val Epoch 7/60: 100%|██████████| 40/40 [00:04<00:00,  8.25it/s]


Epoch 7 Validation Accuracy: 60.92%


Train Epoch 8/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=1.06] 
Val Epoch 8/60: 100%|██████████| 40/40 [00:04<00:00,  8.29it/s]


Epoch 8 Validation Accuracy: 61.40%


Train Epoch 9/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.839]
Val Epoch 9/60: 100%|██████████| 40/40 [00:04<00:00,  8.26it/s]


Epoch 9 Validation Accuracy: 61.68%


Train Epoch 10/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.85] 
Val Epoch 10/60: 100%|██████████| 40/40 [00:04<00:00,  8.25it/s]


Epoch 10 Validation Accuracy: 62.74%


Train Epoch 11/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.932]
Val Epoch 11/60: 100%|██████████| 40/40 [00:04<00:00,  8.25it/s]


Epoch 11 Validation Accuracy: 65.02%


Train Epoch 12/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.916]
Val Epoch 12/60: 100%|██████████| 40/40 [00:04<00:00,  8.22it/s]


Epoch 12 Validation Accuracy: 66.54%


Train Epoch 13/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.737]
Val Epoch 13/60: 100%|██████████| 40/40 [00:04<00:00,  8.29it/s]


Epoch 13 Validation Accuracy: 64.04%


Train Epoch 14/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.734]
Val Epoch 14/60: 100%|██████████| 40/40 [00:04<00:00,  8.38it/s]


Epoch 14 Validation Accuracy: 65.70%


Train Epoch 15/60: 100%|██████████| 352/352 [03:50<00:00,  1.53it/s, loss=0.829]
Val Epoch 15/60: 100%|██████████| 40/40 [00:04<00:00,  8.38it/s]


Epoch 15 Validation Accuracy: 69.10%


Train Epoch 16/60: 100%|██████████| 352/352 [03:50<00:00,  1.52it/s, loss=0.71] 
Val Epoch 16/60: 100%|██████████| 40/40 [00:04<00:00,  8.45it/s]


Epoch 16 Validation Accuracy: 68.68%


Train Epoch 17/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.744]
Val Epoch 17/60: 100%|██████████| 40/40 [00:04<00:00,  8.38it/s]


Epoch 17 Validation Accuracy: 70.58%


Train Epoch 18/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.737]
Val Epoch 18/60: 100%|██████████| 40/40 [00:04<00:00,  8.18it/s]


Epoch 18 Validation Accuracy: 69.96%


Train Epoch 19/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.671]
Val Epoch 19/60: 100%|██████████| 40/40 [00:04<00:00,  8.42it/s]


Epoch 19 Validation Accuracy: 71.28%


Train Epoch 20/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.658]
Val Epoch 20/60: 100%|██████████| 40/40 [00:04<00:00,  8.27it/s]


Epoch 20 Validation Accuracy: 71.36%


Train Epoch 21/60: 100%|██████████| 352/352 [03:50<00:00,  1.52it/s, loss=0.774]
Val Epoch 21/60: 100%|██████████| 40/40 [00:04<00:00,  8.45it/s]


Epoch 21 Validation Accuracy: 73.20%


Train Epoch 22/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.701]
Val Epoch 22/60: 100%|██████████| 40/40 [00:04<00:00,  8.34it/s]


Epoch 22 Validation Accuracy: 73.26%


Train Epoch 23/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.63] 
Val Epoch 23/60: 100%|██████████| 40/40 [00:04<00:00,  8.42it/s]


Epoch 23 Validation Accuracy: 74.30%


Train Epoch 24/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.764]
Val Epoch 24/60: 100%|██████████| 40/40 [00:04<00:00,  8.28it/s]


Epoch 24 Validation Accuracy: 74.02%


Train Epoch 25/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.653]
Val Epoch 25/60: 100%|██████████| 40/40 [00:04<00:00,  8.39it/s]


Epoch 25 Validation Accuracy: 75.24%


Train Epoch 26/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.757]
Val Epoch 26/60: 100%|██████████| 40/40 [00:04<00:00,  8.34it/s]


Epoch 26 Validation Accuracy: 74.30%


Train Epoch 27/60:  38%|███▊      | 133/352 [01:28<02:25,  1.51it/s, loss=0.669]

Training for the remaining epochs, Kaggle crashed.

This runs for the remaining 33 epochs

In [2]:
# ---------------------------------------------
# Training Loop
# ---------------------------------------------
for ep in range(1, 33 + 1):
    student.train()
    loop = tqdm(train_loader, desc=f"Train Epoch {ep}/{epochs}")
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            tlog = teacher(pixel_values=x, interpolate_pos_encoding=True).logits
        out = student(pixel_values=x, labels=y, teacher_logits=tlog,
                      alpha=alpha, temperature=temp)
        loss = out['loss'].mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item())

    # Validation
    student.eval()
    correct, total = 0, 0
    val_loop = tqdm(val_loader, desc=f"Val Epoch {ep}/{epochs}")
    for xb, yb in val_loop:
        xb, yb = xb.to(device), yb.to(device)
        with torch.no_grad():
            preds = student(pixel_values=xb)['logits'].argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    val_acc = 100 * correct / total
    print(f"Epoch {ep} Validation Accuracy: {val_acc:.2f}%")

# ---------------------------------------------
# Final Test
# ---------------------------------------------
student.eval()
test_ds = datasets.CIFAR10('./data', train=False, transform=val_tf)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)
correct = 0
for xb, yb in tqdm(test_loader, desc="Final Test"):
    xb, yb = xb.to(device), yb.to(device)
    with torch.no_grad():
        preds = student(pixel_values=xb)['logits'].argmax(1)
    correct += (preds == yb).sum().item()

student_acc = 100 * correct / len(test_ds)
print(f"Final Test Accuracy: {student_acc:.2f}%")

Train Epoch 1/60: 100%|██████████| 352/352 [03:53<00:00,  1.51it/s, loss=0.695]
Val Epoch 1/60: 100%|██████████| 40/40 [00:04<00:00,  8.40it/s]


Epoch 1 Validation Accuracy: 76.66%


Train Epoch 2/60: 100%|██████████| 352/352 [03:52<00:00,  1.52it/s, loss=0.651]
Val Epoch 2/60: 100%|██████████| 40/40 [00:04<00:00,  8.41it/s]


Epoch 2 Validation Accuracy: 74.78%


Train Epoch 3/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.723]
Val Epoch 3/60: 100%|██████████| 40/40 [00:04<00:00,  8.31it/s]


Epoch 3 Validation Accuracy: 76.42%


Train Epoch 4/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.609]
Val Epoch 4/60: 100%|██████████| 40/40 [00:04<00:00,  8.28it/s]


Epoch 4 Validation Accuracy: 76.00%


Train Epoch 5/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.535]
Val Epoch 5/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 5 Validation Accuracy: 76.70%


Train Epoch 6/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.497]
Val Epoch 6/60: 100%|██████████| 40/40 [00:04<00:00,  8.25it/s]


Epoch 6 Validation Accuracy: 75.52%


Train Epoch 7/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.626]
Val Epoch 7/60: 100%|██████████| 40/40 [00:04<00:00,  8.40it/s]


Epoch 7 Validation Accuracy: 77.76%


Train Epoch 8/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.491]
Val Epoch 8/60: 100%|██████████| 40/40 [00:05<00:00,  7.48it/s]


Epoch 8 Validation Accuracy: 77.86%


Train Epoch 9/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.533]
Val Epoch 9/60: 100%|██████████| 40/40 [00:04<00:00,  8.37it/s]


Epoch 9 Validation Accuracy: 77.10%


Train Epoch 10/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.632]
Val Epoch 10/60: 100%|██████████| 40/40 [00:04<00:00,  8.19it/s]


Epoch 10 Validation Accuracy: 78.34%


Train Epoch 11/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.607]
Val Epoch 11/60: 100%|██████████| 40/40 [00:04<00:00,  8.30it/s]


Epoch 11 Validation Accuracy: 77.96%


Train Epoch 12/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.552]
Val Epoch 12/60: 100%|██████████| 40/40 [00:04<00:00,  8.30it/s]


Epoch 12 Validation Accuracy: 79.14%


Train Epoch 13/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.563]
Val Epoch 13/60: 100%|██████████| 40/40 [00:04<00:00,  8.31it/s]


Epoch 13 Validation Accuracy: 78.02%


Train Epoch 14/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.545]
Val Epoch 14/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 14 Validation Accuracy: 77.90%


Train Epoch 15/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.639]
Val Epoch 15/60: 100%|██████████| 40/40 [00:04<00:00,  8.28it/s]


Epoch 15 Validation Accuracy: 78.38%


Train Epoch 16/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.541]
Val Epoch 16/60: 100%|██████████| 40/40 [00:04<00:00,  8.28it/s]


Epoch 16 Validation Accuracy: 79.70%


Train Epoch 17/60: 100%|██████████| 352/352 [03:55<00:00,  1.49it/s, loss=0.56] 
Val Epoch 17/60: 100%|██████████| 40/40 [00:04<00:00,  8.29it/s]


Epoch 17 Validation Accuracy: 79.94%


Train Epoch 18/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.572]
Val Epoch 18/60: 100%|██████████| 40/40 [00:04<00:00,  8.27it/s]


Epoch 18 Validation Accuracy: 78.82%


Train Epoch 19/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.564]
Val Epoch 19/60: 100%|██████████| 40/40 [00:04<00:00,  8.27it/s]


Epoch 19 Validation Accuracy: 80.18%


Train Epoch 20/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.659]
Val Epoch 20/60: 100%|██████████| 40/40 [00:04<00:00,  8.23it/s]


Epoch 20 Validation Accuracy: 79.80%


Train Epoch 21/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.483]
Val Epoch 21/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 21 Validation Accuracy: 80.24%


Train Epoch 22/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.577]
Val Epoch 22/60: 100%|██████████| 40/40 [00:04<00:00,  8.36it/s]


Epoch 22 Validation Accuracy: 80.00%


Train Epoch 23/60: 100%|██████████| 352/352 [03:54<00:00,  1.50it/s, loss=0.494]
Val Epoch 23/60: 100%|██████████| 40/40 [00:04<00:00,  8.25it/s]


Epoch 23 Validation Accuracy: 81.28%


Train Epoch 24/60: 100%|██████████| 352/352 [03:55<00:00,  1.50it/s, loss=0.481]
Val Epoch 24/60: 100%|██████████| 40/40 [00:04<00:00,  8.26it/s]


Epoch 24 Validation Accuracy: 80.70%


Train Epoch 25/60: 100%|██████████| 352/352 [03:52<00:00,  1.51it/s, loss=0.432]
Val Epoch 25/60: 100%|██████████| 40/40 [00:04<00:00,  8.23it/s]


Epoch 25 Validation Accuracy: 80.58%


Train Epoch 26/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.462]
Val Epoch 26/60: 100%|██████████| 40/40 [00:04<00:00,  8.34it/s]


Epoch 26 Validation Accuracy: 79.60%


Train Epoch 27/60: 100%|██████████| 352/352 [03:52<00:00,  1.52it/s, loss=0.392]
Val Epoch 27/60: 100%|██████████| 40/40 [00:04<00:00,  8.31it/s]


Epoch 27 Validation Accuracy: 81.14%


Train Epoch 28/60: 100%|██████████| 352/352 [03:52<00:00,  1.52it/s, loss=0.386]
Val Epoch 28/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 28 Validation Accuracy: 81.54%


Train Epoch 29/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.341]
Val Epoch 29/60: 100%|██████████| 40/40 [00:04<00:00,  8.33it/s]


Epoch 29 Validation Accuracy: 81.18%


Train Epoch 30/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.524]
Val Epoch 30/60: 100%|██████████| 40/40 [00:04<00:00,  8.34it/s]


Epoch 30 Validation Accuracy: 81.06%


Train Epoch 31/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.389]
Val Epoch 31/60: 100%|██████████| 40/40 [00:04<00:00,  8.30it/s]


Epoch 31 Validation Accuracy: 81.38%


Train Epoch 32/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.528]
Val Epoch 32/60: 100%|██████████| 40/40 [00:04<00:00,  8.32it/s]


Epoch 32 Validation Accuracy: 81.46%


Train Epoch 33/60: 100%|██████████| 352/352 [03:51<00:00,  1.52it/s, loss=0.481]
Val Epoch 33/60: 100%|██████████| 40/40 [00:04<00:00,  8.26it/s]


Epoch 33 Validation Accuracy: 82.12%


Final Test: 100%|██████████| 79/79 [00:09<00:00,  8.41it/s]

Final Test Accuracy: 82.28%



