In [1]:
import yaml
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
import timm


In [2]:
with open("../config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

In [3]:
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

torch.manual_seed(42)
print(f"Using device: {device}")

Using device: mps


In [4]:
%run data_ingestion.ipynb

train_loader, test_loader = get_cifar10_dataloaders(cfg)



In [5]:
model = timm.create_model(
    "vit_tiny_patch16_224",
    pretrained=True,
    num_classes=10
)

model.to(device)


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)


In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=cfg.get("training", {}).get("epochs", 10)
)

In [None]:
def train_one_epoch(model, loader):
    model.train()
    running_loss = 0.0

    for x, y in tqdm(loader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(loader)


In [8]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


In [9]:
epochs = 2
best_acc = 0.0

for epoch in range(epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_acc = evaluate(model, test_loader)

    scheduler.step()

    print(f"epoch {epoch+1} | loss {train_loss:.4f} | val acc {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_vit.pt")


100%|██████████| 782/782 [11:15<00:00,  1.16it/s]


epoch 1 | loss 0.2407 | val acc 0.9507


100%|██████████| 782/782 [10:47<00:00,  1.21it/s]


epoch 2 | loss 0.0876 | val acc 0.9548


In [10]:
print("best validation accuracy:", best_acc)


best validation accuracy: 0.9548
