In [1]:
# Fetch the project repository; the local modules imported below (utils, config,
# models, evaluation_utils) and main.py are presumably provided by it.
# NOTE: git refuses to clone into an existing non-empty directory, so re-running
# this cell after a successful clone only prints an error.
!git clone 'https://github.com/aakashvardhan/Transformers-Collab.git'

Cloning into 'Transformers-Collab'...
remote: Enumerating objects: 388, done.[K
remote: Counting objects: 100% (388/388), done.[K
remote: Compressing objects: 100% (333/333), done.[K
remote: Total 388 (delta 93), reused 342 (delta 47), pack-reused 0[K
Receiving objects: 100% (388/388), 13.88 MiB | 15.99 MiB/s, done.
Resolving deltas: 100% (93/93), done.


In [2]:
# Work from the repo root so its local modules are importable by name.
%cd Transformers-Collab

/content/Transformers-Collab


In [3]:
import torch
import torch.nn as nn
import torchvision

import utils
from config import VITConfig
from utils import (create_dataloader,
                   get_img_batch,
                   show_img,
                   patchify_img,
                   show_conv2d_feature_maps)

ModuleNotFoundError: ignored

In [None]:
# Project configuration; later cells read lr, betas, weight_decay, device and the
# epoch count from it.
config = VITConfig()

# Build train/test loaders from the config. class_names is reused below to size
# the classifier heads.
train_dataloader, test_dataloader, class_names = create_dataloader(config)

In [None]:
# Pull sample data from the training loader for the visualisation cells below
# (presumably a single batch; see utils.get_img_batch).
image, label = get_img_batch(train_dataloader)

In [None]:
# Visualise the sampled image; class_names presumably maps the label index to a name.
show_img(image, label, class_names)

In [None]:
# Show the sampled image split into ViT patches (patch size presumably read from config).
patchify_img(image, label, class_names, config)

In [None]:
# Visualise Conv2d feature maps for the sampled image (presumably the patch-embedding
# convolution described by config; see utils.show_conv2d_feature_maps).
show_conv2d_feature_maps(image, config)

In [None]:
# Visualise the flattened feature map for the sampled image.
# `show_flattened_feature_map` is not in the notebook's import list; it is assumed
# to live in utils next to show_conv2d_feature_maps — TODO confirm.
utils.show_flattened_feature_map(image, config)

In [None]:
from models.transformer import VIT

In [None]:
# Build the from-scratch ViT with one output per dataset class.
# NOTE(review): config is not passed here; VIT presumably uses its own defaults — confirm
# they match VITConfig.
vit = VIT(num_classes=len(class_names))

In [None]:
# Summarise the from-scratch model.
# `get_vit_model_summary` is never imported in this notebook; it is assumed to be
# provided by utils — TODO confirm.
utils.get_vit_model_summary(vit)

In [None]:
# Adam optimiser for the from-scratch ViT. Values come from config; the source notes
# attribute them to the ViT paper:
#   lr           - base LR from Table 3 for ViT-* ImageNet-1k
#   betas        - default values, also mentioned in section 4.1 (Training & Fine-tuning)
#   weight_decay - section 4.1 (Training & Fine-tuning) and Table 3 for ViT-* ImageNet-1k
optimizer = torch.optim.Adam(
    params=vit.parameters(),
    lr=config.lr,
    betas=config.betas,
    weight_decay=config.weight_decay,
)

# Multi-class classification objective.
loss_fn = torch.nn.CrossEntropyLoss()

# Run training; the returned `results` is plotted with plot_loss_curves below.
# NOTE(review): `train` is not imported anywhere in this notebook, and this cell reads
# `config.epoch` while the fine-tuning cell reads `config.n_epochs` — confirm which
# attribute VITConfig defines.
results = train(
    model=vit,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=config.epoch,
    device=config.device,
)

In [None]:
from evaluation_utils import plot_loss_curves

# Plot the loss curves stored in `results` by the from-scratch training cell.
plot_loss_curves(results)

In [None]:
# 1. Pretrained ViT-B/16 weights. "DEFAULT" means best available (requires torchvision >= 0.13).
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

# 2. Instantiate ViT-B/16 with those weights.
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)

# 3. Freeze every backbone parameter so only the replacement head is trained.
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# 4. Replace the ImageNet head with one sized to this dataset. The new nn.Linear is
#    created after freezing, so its parameters remain trainable. The embedding width
#    (768 for ViT-B) is read from the model instead of being hard-coded.
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim,
                                 out_features=len(class_names))

# Move the whole model, including the new head, to the device the training cells use.
pretrained_vit = pretrained_vit.to(config.device)

In [None]:
# Summarise the frozen-backbone model with its new head.
# `get_vit_model_summary` is never imported in this notebook; it is assumed to be
# provided by utils — TODO confirm.
utils.get_vit_model_summary(pretrained_vit)

In [None]:
# Preprocessing bundled with the pretrained weights (resize, centre-crop, normalise).
# The last expression is shown via the notebook's rich display instead of print().
pretrained_vit_transforms = pretrained_vit_weights.transforms()
pretrained_vit_transforms

In [None]:
# Setup dataloaders that apply the pretrained weights' transforms.
# FIXME: `data_setup`, `train_dir` and `test_dir` are never defined in this notebook
# (the call appears to come from a different project layout), and the project's
# create_dataloader is called as create_dataloader(config) earlier. Rewire this to the
# project loader once it can accept `pretrained_vit_transforms`.
# NOTE(review): this rebinds the global `class_names` used to size the heads above.
train_dataloader_pretrained, test_dataloader_pretrained, class_names = data_setup.create_dataloader(train_dir=train_dir,
                                                                                                     test_dir=test_dir,
                                                                                                     transform=pretrained_vit_transforms,
                                                                                                     batch_size=32)

In [None]:
# Only the freshly created head has requires_grad=True, so give the optimiser just
# those parameters. Frozen backbone weights never receive gradients, so Adam would
# skip them anyway; filtering makes the head-only intent explicit.
optimizer = torch.optim.Adam(
    params=[p for p in pretrained_vit.parameters() if p.requires_grad],
    lr=1e-3,  # head-only fine-tuning learning rate
)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the classifier head of the frozen pretrained ViT feature extractor.
# NOTE(review): `train` is never imported in this notebook, and this cell reads
# `config.n_epochs` while the from-scratch cell reads `config.epoch` — confirm which
# attribute VITConfig defines.
pretrained_vit_results = train(model=pretrained_vit,
                               train_dataloader=train_dataloader_pretrained,
                               test_dataloader=test_dataloader_pretrained,
                               optimizer=optimizer,
                               loss_fn=loss_fn,
                               epochs=config.n_epochs,
                               device=config.device)

In [None]:
# Plot the loss curves of the head-only fine-tuning run.
plot_loss_curves(pretrained_vit_results)

In [None]:
# Persist the fine-tuned feature extractor via the project's utils.save_model.
# This needs the module name `utils` itself to be bound (e.g. `import utils`).
# NOTE(review): the "pizza_steak_sushi" file name looks copied from another project —
# confirm it matches this dataset. The "models" target directory is also the name of
# the local `models` package.
utils.save_model(
    model=pretrained_vit,
    target_dir="models",
    model_name="08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth",
)

In [None]:
# Run the repo's main.py with the 'vit' argument (presumably the scripted equivalent
# of the training above — confirm in main.py).
!python main.py 'vit'