# This notebook fine-tunes the replicated ViT model on the CIFAR-10 dataset

In [None]:
import torch
import torch.nn as nn
from torchvision.models.vision_transformer import ViT_B_16_Weights
import torch.backends.mps

from consts import *
import data
import eval
import vit
import train
import utils

In [None]:
# Select the torch device to run on via the project helper (presumably
# CUDA/MPS when available, otherwise CPU -- confirm in utils.get_device).
device = utils.get_device()

In [None]:
print("Loading CIFAR-10 dataset")
# Both splits use torchvision's preset preprocessing for the default
# ViT-B/16 weights; build the preset once and share it.
vit_transforms = ViT_B_16_Weights.DEFAULT.transforms()
train_ds, train_dl = data.load_CIFAR(
    train=True, batch_size=BATCH, transforms=vit_transforms, seed=SEED
)
test_ds, test_dl = data.load_CIFAR(
    train=False, batch_size=BATCH, transforms=vit_transforms, seed=SEED
)

In [None]:
# Class names of the test split; their count sizes the replacement
# classification head and is passed to train.train when fine-tuning.
classes = test_ds.classes

In [None]:
# Visual sanity check of the test set (assumes the second argument is the
# number of samples to display -- confirm in utils.show).
utils.show(test_ds, 100)

# Fine-tuning the replicated model

In [None]:
print("Creating the replicated model")
# Seed before construction so the model's initialisation is reproducible.
torch.manual_seed(SEED)
# Built with an IMAGENET_CLASSES_N-way head, presumably so parameter shapes
# match the checkpoint restored in the next cell; the head is replaced with one
# sized to the dataset's classes before fine-tuning. Positional argument order
# follows vit.ViT's signature (not visible here).
m = vit.ViT(D, IMAGE_W, PATCH, HEADS, DMLP, L, IMAGENET_CLASSES_N, DROPOUT, NORM_EPS)

In [None]:
print("Loading last saved ViT model")
# Restore the last saved ViT weights into the freshly built model. The helper's
# return value replaces `m` (presumably the same module with its state loaded
# -- confirm in utils.load_last_model).
m = utils.load_last_model(m)

In [None]:
print("Freezing the model and swapping its classification head")
# Freeze every parameter loaded so far so the backbone is not updated.
# nn.Module.requires_grad_ applies the flag to all parameters recursively.
m.requires_grad_(False)

# Replace the classification head with a fresh layer sized to this dataset's
# classes. A newly created nn.Linear has requires_grad=True, so the head stays
# trainable even though the rest of the model is frozen.
# Re-seed so the head's random initialisation is reproducible.
torch.manual_seed(SEED)
m.head = nn.Linear(in_features=D, out_features=len(classes))
# Move the whole model (including the new head) to the selected device.
m.to(device)

In [None]:
# Hand the optimizer only the parameters that are actually trained (the new
# head). The frozen backbone never receives gradients, so registering it only
# obscures what is being fine-tuned.
optim = torch.optim.SGD(
    [p for p in m.parameters() if p.requires_grad], lr=LR, momentum=MOMENTUM
)

In [None]:
print("Fine-tuning the replicated model")
# Seed immediately before training so data shuffling / dropout are reproducible.
torch.manual_seed(SEED)
n_classes = len(classes)
criterion = nn.CrossEntropyLoss()
# Arguments are positional; their order follows train.train's signature.
train_metrics, test_metrics = train.train(
    m, EPOCHS, train_dl, test_dl, device, n_classes, optim, criterion
)

In [None]:
# Plot the training and test metric histories returned by train.train.
eval.plot_metrics(train_metrics, test_metrics)

In [None]:
# Tag the saved fine-tuned model with the last recorded test accuracy.
acc_history = test_metrics[eval.Metrics.ACCURACY.value]
accuracy = acc_history[-1].item()
utils.save_model(m, accuracy, "FT_CIFAR")