In [1]:
import importlib
import math
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import torch.backends.mps
import torch.nn as nn
from torchinfo import summary
import torchvision.transforms
from torchvision.models.vision_transformer import ViT_B_16_Weights
from torchvision.models.vision_transformer import vit_b_16

import data
import eval
import train
import utils
import vit

In [2]:

# Prefer CUDA, fall back to the MPS backend, and use the CPU when neither is available.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")

Using device: mps


In [3]:
# Notebook-wide settings: SEED feeds torch.random.manual_seed, BATCH is handed to
# data.load_imagenette (presumably as the DataLoader batch size).
SEED, BATCH = 100, 32

Prepare data

In [None]:
# Inputs for the Imagenette preparation step: the dataset folder and the ImageNet class-index JSON.
imagenette2_path = Path("imagenette2")
imagenet_classes_path = Path("imagenet_class_index.json")

# Returns the path later cells read the train/test splits from, plus the class collection.
data_path, classes = data.prepare_imagenette(
    imagenette2_path,
    imagenet_classes_path,
)

Load data

In [None]:
# The train and test splits sit in sub-folders of the prepared dataset path.
train_data_path = data_path / "train"
test_data_path = data_path / "test"

# Preprocess with the transforms bundled with the pretrained ViT-B/16 weights.
train_ds, test_ds, train_dl, test_dl = data.load_imagenette(
    train_data_path,
    test_data_path,
    classes,
    BATCH,
    ViT_B_16_Weights.DEFAULT.transforms(),
)

In [None]:
# Sanity check: print the label of one test sample and display the image with the
# per-channel normalisation undone.
i = 31  # index of the sample inside the first batch drawn from test_dl
torch.random.manual_seed(SEED)  # seed right before drawing the batch (matters if test_dl shuffles — TODO confirm)

# Per-channel mean / std used to reverse the normalisation applied by the input transforms.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

imgs, ls = next(iter(test_dl))
assert i < len(imgs), f"Sample index {i} is outside the batch of {len(imgs)} images"
print(test_ds.classes[ls[i]])

# Keep the arithmetic in torch, in the batch's own dtype/device, instead of mixing a
# tensor with float64 NumPy arrays (assumes imgs is a floating-point tensor — TODO confirm).
mean_t = torch.as_tensor(mean, dtype=imgs.dtype, device=imgs.device).reshape(3, 1, 1)
std_t = torch.as_tensor(std, dtype=imgs.dtype, device=imgs.device).reshape(3, 1, 1)

# Clamp to [0, 1] so rounding error cannot push values outside the range expected
# for a float image before it is converted for display.
img_orig = (imgs[i] * std_t + mean_t).clamp(0.0, 1.0)
torchvision.transforms.ToPILImage()(img_orig)

Evaluate the reference model

In [None]:
# Reference implementation: torchvision's ViT-B/16 initialised from its default pretrained weights.
weights = ViT_B_16_Weights.DEFAULT
ref_model = vit_b_16(weights=weights)

In [None]:
# Baseline: score the pretrained torchvision model on the test loader.
eval.eval(
    ref_model,
    test_dl,
    len(test_ds.classes),
    device,
)

In [None]:
# eval.eval_show(ref_model, test_ds, n=8, page=100)

Replication

In [4]:
# Hyper-parameters of the ViT replica (passed to vit.ViT below).
L = 12       # number of encoder layers
D = 768      # latent vector size
HEADS = 12   # number of attention heads

PATCH = 16     # patch side length
IMAGE_W = 224  # input image side length
assert IMAGE_W % PATCH == 0, "Image size must be divisible by the patch size"
# Integer division keeps the counts exact ints instead of round-tripping through float.
N = (IMAGE_W // PATCH) ** 2  # number of patches per image
assert D % HEADS == 0, "The latent vector size D must be divisible by the number of heads"
DH = D // HEADS  # To keep num of params constant we set DH = D/HEADS
DMSA = HEADS * DH * 3  # presumably the combined width of the q, k and v projections
DMLP = 4 * D  # MLP hidden size: 4 times the D (3072 for D = 768)

NORM_EPS = 1e-6
DROPOUT = 0.1

In [None]:
# Instantiate the from-scratch ViT replica and print a per-layer summary for one
# 3-channel IMAGE_W x IMAGE_W input.
m = vit.ViT(D, IMAGE_W, PATCH, HEADS, DMLP, L, len(classes), DROPOUT, NORM_EPS)
summary(
    m,
    depth=4,
    input_size=(1, 3, IMAGE_W, IMAGE_W),
    col_names=["kernel_size", "input_size", "output_size", "num_params"],
    row_settings=["var_names"],
)

In [None]:
# Snapshot of the replica's state dict; the weight-porting cell overwrites its entries.
state = m.state_dict()
keys = list(state.keys())

rows = 10    # number of printed lines, i.e. entries per column
margin = 4   # spaces between neighbouring columns

# The widest key fixes the column width; default=0 keeps an empty state dict from raising.
max_len = max((len(key) for key in keys), default=0)

# Column-major layout: column c holds keys[c*rows:(c+1)*rows], so printed line r
# consists of keys[r], keys[r + rows], keys[r + 2*rows], ...
for row_idx in range(rows):
    print((' ' * margin).join(key.ljust(max_len) for key in keys[row_idx::rows]))

In [None]:
# Pretrained torchvision checkpoint whose tensors are copied into the replica's state dict.
ref_state = ViT_B_16_Weights.DEFAULT.get_state_dict(progress=True)

# (replica key, torchvision key) pairs for tensors outside the encoder layers.
top_level_keys = [
    ("conv_proj.weight", "conv_proj.weight"),
    ("conv_proj.bias", "conv_proj.bias"),
    ("encoder.pos_embeddings", "encoder.pos_embedding"),
    ("encoder.ln.weight", "encoder.ln.weight"),
    ("encoder.ln.bias", "encoder.ln.bias"),
    ("head.weight", "heads.head.weight"),
    ("head.bias", "heads.head.bias"),
]

# (replica suffix, torchvision suffix) pairs repeated for every encoder layer.
per_layer_keys = [
    ("ln_1.weight", "ln_1.weight"),
    ("ln_1.bias", "ln_1.bias"),
    ("msa.qkv", "self_attention.in_proj_weight"),
    ("msa.qkv_bias", "self_attention.in_proj_bias"),
    ("msa.w0", "self_attention.out_proj.weight"),
    ("msa.w0_bias", "self_attention.out_proj.bias"),
    ("ln_2.weight", "ln_2.weight"),
    ("ln_2.bias", "ln_2.bias"),
    ("mlp.lin1.weight", "mlp.linear_1.weight"),
    ("mlp.lin1.bias", "mlp.linear_1.bias"),
    ("mlp.lin2.weight", "mlp.linear_2.weight"),
    ("mlp.lin2.bias", "mlp.linear_2.bias"),
]

# The class token is the only tensor that is reshaped (squeezed) on the way in.
state["class_token"] = ref_state["class_token"].squeeze()

for dst_key, src_key in top_level_keys:
    state[dst_key] = ref_state[src_key]

for layer_idx in range(L):
    dst_prefix = f"encoder.layers.{layer_idx}."
    src_prefix = f"encoder.layers.encoder_layer_{layer_idx}."
    for dst_suffix, src_suffix in per_layer_keys:
        state[dst_prefix + dst_suffix] = ref_state[src_prefix + src_suffix]

# Load the remapped tensors into the replica; the returned key report is the cell output.
m.load_state_dict(state)

In [None]:
# Score the replica (now carrying the ported weights) on the same test loader as the reference model.
eval.eval(
    m,
    test_dl,
    len(test_ds.classes),
    device,
)

In [None]:
# eval.eval_show(m, test_ds, 8, 100)

<h1>Fine-tuning</h1>

In [29]:
# Re-import the local helper modules so edits to their .py files take effect without
# restarting the kernel. importlib.reload updates each module object in place, so the
# existing names keep working. Looping (rather than ending the cell on a reload call)
# avoids echoing a module repr that contains an absolute local path.
for module in (utils, data, train, eval):
    importlib.reload(module)

<module 'eval' from '/Users/yehormanevych/Projects/ViT/eval.py'>

In [21]:
# Size of the fine-tuning subset, expressed in batches.
batch_size = 32
train_batch_n = 300

# 80/20 train/test split. Integer arithmetic gives the exact floor of
# train_batch_n * 20 / 80, with no dependence on floating-point rounding.
TRAIN_SHARE, TEST_SHARE = 80, 20
test_batch_n = train_batch_n * TEST_SHARE // TRAIN_SHARE
test_batch_n

75

In [22]:
# CIFAR subsets for fine-tuning, preprocessed with the transforms bundled with the
# pretrained ViT-B/16 weights.
cifar_train_ds, cifar_train_dl = data.load_cifar(
    "cifar/train",
    train=True,
    batch_size=batch_size,
    batch_n=train_batch_n,
    transforms=ViT_B_16_Weights.DEFAULT.transforms(),
)
cifar_test_ds, cifar_test_dl = data.load_cifar(
    "cifar/test",
    train=False,
    batch_size=batch_size,
    batch_n=test_batch_n,
    transforms=ViT_B_16_Weights.DEFAULT.transforms(),
)

Files already downloaded and verified
Created 261 train batches of size 32
Files already downloaded and verified
Created 63 test batches of size 32


In [23]:
# Class list of the CIFAR test dataset; its length sizes the new classification head
# and is passed to train.train.
cifar_classes = cifar_test_ds.classes

In [None]:
# i = 0
# img, l = cifar_train_ds[i]
# print(cifar_classes[l])
# utils.whitened_to_PIL(img)

In [25]:
# Fresh pretrained ViT-B/16 for fine-tuning, moved to the selected device.
weights = ViT_B_16_Weights.DEFAULT
ref_model = vit_b_16(weights=weights).to(device)

In [26]:
# Freeze every pretrained parameter of the reference model.
ref_model.requires_grad_(False)

# Replace the classifier with a new, trainable linear layer sized for the CIFAR classes.
ref_model.heads = nn.Sequential(
    nn.Linear(in_features=D, out_features=len(cifar_classes)),
).to(device)
# summary(ref_model, depth=4, input_size=(1, 3, IMAGE_W, IMAGE_W),col_names=["kernel_size", "input_size", "output_size", "num_params","trainable"], row_settings=["var_names"],)


In [27]:
# SGD with momentum over all model parameters; the frozen ones never receive a
# gradient, so in practice only the new head is updated.
optim = torch.optim.SGD(
    ref_model.parameters(),
    lr=0.003,
    momentum=0.9,
)

In [30]:
# Fine-tune on the CIFAR subset and keep the metrics returned for the train and test sets.
train_metrics, test_metrics = train.train(
    ref_model,
    1,  # presumably the number of epochs — confirm against train.train
    cifar_train_dl,
    cifar_test_dl,
    device,
    len(cifar_classes),
    optim,
    nn.CrossEntropyLoss(),
)

Epoch 0:
	Batch 0

	Train: MulticlassAccuracy = 0.062, MulticlassPrecision = 0.062, MulticlassRecall = 0.062, CrossEntropyLoss = 0.073

	Test: MulticlassAccuracy = 0.122, MulticlassPrecision = 0.119, MulticlassRecall = 0.122, CrossEntropyLoss = 0.073
	Batch 1
	Batch 2
	Batch 3
	Batch 4
	Batch 5
	Batch 6
	Batch 7
	Batch 8
	Batch 9
	Batch 10
	Batch 11
	Batch 12
	Batch 13
	Batch 14
	Batch 15
	Batch 16
	Batch 17
	Batch 18
	Batch 19
	Batch 20
	Batch 21
	Batch 22
	Batch 23
	Batch 24
	Batch 25
	Batch 26
	Batch 27
	Batch 28
	Batch 29
	Batch 30
	Batch 31
	Batch 32
	Batch 33
	Batch 34
	Batch 35
	Batch 36
	Batch 37
	Batch 38
	Batch 39
	Batch 40
	Batch 41
	Batch 42
	Batch 43
	Batch 44
	Batch 45
	Batch 46
	Batch 47
	Batch 48
	Batch 49
	Batch 50
	Batch 51
	Batch 52
	Batch 53

	Train: MulticlassAccuracy = 0.782, MulticlassPrecision = 0.782, MulticlassRecall = 0.782, CrossEntropyLoss = 0.029

	Test: MulticlassAccuracy = 0.907, MulticlassPrecision = 0.907, MulticlassRecall = 0.907, CrossEntropyLoss = 0

In [None]:
# train_acc = [m["MulticlassAccuracy"].cpu() for m in train_metrics]
# train_loss = [m["CrossEntropyLoss"].cpu() for m in train_metrics]
# test_acc = [m["MulticlassAccuracy"].cpu() for m in test_metrics]
# test_loss = [m["CrossEntropyLoss"].cpu() for m in test_metrics]

# fig = make_subplots(cols=2, rows=1, subplot_titles=["Loss", "Accuracy"])
# fig.add_scatter(x = np.arange(len(train_loss)), y=train_loss, col=1, row=1, name="Train loss")
# fig.add_scatter(x = np.arange(len(train_acc)), y=train_acc, col=2, row=1, name="Train acc")
# fig.add_scatter(x = np.arange(len(test_loss)), y=test_loss, col=1, row=1, name="Test loss")
# fig.add_scatter(x = np.arange(len(test_acc)), y=test_acc, col=2, row=1, name="Test acc")

In [37]:
# Persist the fine-tuned weights under ./models, creating the folder if it does not exist yet.
models_path = Path("models")
models_path.mkdir(exist_ok=True)
# NOTE(review): the file name advertises "93acc" while the training log above reports a
# test accuracy of 0.907 — confirm which run this checkpoint belongs to.
model_path = models_path / "ref_model_b32_93acc.pth"
# Only the state dict is stored, not the whole module object.
torch.save(ref_model.state_dict(), model_path)

In [42]:
# eval.eval_show(ref_model.cpu(), cifar_test_ds, n=8, page=2)