# This notebook fine-tunes the pretrained torchvision ViT-B/16 reference model on CIFAR-10, aiming for 95% test accuracy

In [124]:
import torch
import torch.nn as nn
from torchvision.models.vision_transformer import ViT_B_16_Weights
from torchvision.models.vision_transformer import vit_b_16
import torch.backends.mps
import math

import data
import eval
import vit
import train
import utils

In [125]:
import importlib


def reload():
    """Re-import the project modules so on-disk edits are picked up without a kernel restart.

    Modules are reloaded in a fixed order: data, eval, vit, train, utils.
    """
    for module in (data, eval, vit, train, utils):
        importlib.reload(module)


reload()

In [None]:
import utils

device = utils.get_device()  # compute device chosen by the project helper (logic lives in utils)

# ------------------Model architecture hyperparameters--------------------
L = 12
D = 768  # latent vector size
HEADS = 12  # number of attention heads
PATCH = 16  # patch side length, in pixels
IMAGE_W = 224  # input image side length, in pixels
assert IMAGE_W % PATCH == 0, "Image size must be divisible by the patch size"
# Number of patches per image. Integer division is exact thanks to the assert
# above, so no float round-trip (and int() truncation) is needed.
N = (IMAGE_W // PATCH) ** 2
assert D % HEADS == 0, "The latent vector size D must be divisible by the number of heads"
# To keep num of params constant we set DH = D/HEADS (exact by the assert above)
DH = D // HEADS
# HEADS * DH * 3 == 3 * D; presumably the width of a fused Q/K/V projection.
DMSA = HEADS * DH * 3
DMLP = 3072
NORM_EPS = 1e-6

# ---------------------Fine-tuning hyperparameters------------------------
LR = 0.003  # SGD learning rate
MOMENTUM = 0.9  # SGD momentum
STEPS = 10000  # target number of optimisation steps (converted to epochs below)
BATCH = 32
DROPOUT = 0.0  # passed to vit_b_16 when building the reference model

# ----------------------------Other consts--------------------------------
SEED = 100
EVALS_PER_EPOCH = 1
IMAGENET_CLASSES_N = 1000

In [None]:
device  # already selected via utils.get_device() in the hyperparameter cell; just display it

In [None]:
print("Loading CIFAR-10 dataset")
# Build the ViT-B/16 preprocessing pipeline once and share it between both
# splits so train and test are guaranteed to be preprocessed identically.
preprocess = ViT_B_16_Weights.DEFAULT.transforms()
train_ds, train_dl = data.load_CIFAR(
    train=True, batch_size=BATCH, transforms=preprocess, seed=SEED
)
test_ds, test_dl = data.load_CIFAR(
    train=False, batch_size=BATCH, transforms=preprocess, seed=SEED
)

In [129]:
# Convert the step budget into whole epochs, rounding up so training sees at
# least STEPS * BATCH samples.
samples_budget = STEPS * BATCH
epochs = math.ceil(samples_budget / len(train_ds))
epochs

7

In [123]:
classes = test_ds.classes  # class names of the test split; len(classes) sizes the new head below

In [None]:
utils.show(test_ds, 1)  # presumably previews a sample from the test set — see utils.show

# Going wild with fine-tuning

In [None]:
print("Creating the reference model")
weights = ViT_B_16_Weights.DEFAULT
# Seed before construction so any randomness in model creation is reproducible.
torch.manual_seed(SEED)
ref_model = vit_b_16(weights=weights, dropout=DROPOUT)
ref_model = ref_model.to(device)  # Module.to returns the same module, moved to `device`

In [None]:
print("Freezing the model and swapping its classification head")
# Freeze every pretrained parameter so only the new head below is trained.
ref_model.requires_grad_(False)

# Swap the classification layer for one sized to this dataset's classes.
# nn.Linear's default init still draws from the RNG (kept for reproducibility
# with the seed), then both weight and bias are zeroed in place.
torch.manual_seed(SEED)
lin = nn.Linear(in_features=D, out_features=len(classes))
nn.init.zeros_(lin.weight)
nn.init.zeros_(lin.bias)
ref_model.heads = nn.Sequential(lin).to(device)
# summary(ref_model, depth=4, input_size=(1, 3, IMAGE_W, IMAGE_W),col_names=["kernel_size", "input_size", "output_size", "num_params","trainable"], row_settings=["var_names"],)

In [106]:
# Only the new head is trainable; hand SGD just those parameters rather than
# every (frozen) backbone parameter.
trainable_params = [p for p in ref_model.parameters() if p.requires_grad]
optim = torch.optim.SGD(trainable_params, lr=LR, momentum=MOMENTUM)

In [None]:
print("Fine-tuning the reference model")
num_classes = len(classes)
loss_fn = nn.CrossEntropyLoss()
# Reseed immediately before training for reproducible runs.
torch.manual_seed(SEED)
train_metrics, test_metrics = train.train(
    ref_model, epochs, train_dl, test_dl, device,
    num_classes, optim, loss_fn, EVALS_PER_EPOCH,
    checkpoints=epochs,
)

In [None]:
eval.plot_metrics(train_metrics, test_metrics)  # plot the metric histories returned by train.train

In [None]:
eval.eval_show(ref_model, test_ds, n=16, page=0)  # presumably shows model predictions for 16 test images (first page) — see eval.eval_show

In [None]:
# Last recorded test accuracy (a tensor) becomes part of the saved model's metadata.
accuracy_history = test_metrics[eval.Metrics.ACCURACY.value]
accuracy = accuracy_history[-1].item()
utils.save_model(ref_model, accuracy, "FT_CIFAR")