In [None]:
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[1]) >= 12, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

# Continue with regular imports
import os
import sys
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

Force_Redownload = 1
try:
    if Force_Redownload: raise Exception("Force Redownload")
    from CSE203B_Final_Project import engine
    from CSE203B_Final_Project.helper_functions import set_seeds, plot_loss_curves
    from CSE203B_Final_Project.data_prep import prepare_loaders
    from CSE203B_Final_Project.models import SViT, prepare_model
except Exception as E:
    print("[INFO] Downloading Code from GitHub.")
    !rm -rf CSE203B_Final_Project
    !git clone https://github.com/ArmanOmmid/CSE203B_Final_Project
    !mv CSE203B_Final_Project .
finally:
    from CSE203B_Final_Project import engine
    from CSE203B_Final_Project.helper_functions import set_seeds, plot_loss_curves
    from CSE203B_Final_Project.data_prep import prepare_loaders
    from CSE203B_Final_Project.models import SViT, prepare_model

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# A bare `device` expression is only echoed when it is the last line of a cell,
# so print it explicitly to make the selected device visible.
print(f"Using device: {device}")
# Data directory name; passed to prepare_loaders as its first argument.
data_folder = 'datasets'

In [None]:
# Image side length: the ViT paper (Table 3) uses 224.
IMG_SIZE = 224

# Transform pipeline built by hand: resize to a square, convert to a tensor,
# then expand single-channel tensors to three channels.
example_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Leave multi-channel tensors alone; tile a single channel three times (RGB).
        transforms.Lambda(lambda img: img if img.size(0) != 1 else img.repeat(3, 1, 1)),
    ]
)
data_folder = 'datasets'

# Loader and dataset settings for this example.
BATCH_SIZE, NUM_WORKERS = 3, 2
total_images, test_proportion = 25000, 0.2
dataset_name = "Caltech256"

train_dataloader, test_dataloader, class_names = prepare_loaders(
    data_folder,
    dataset_name,
    total_images,
    test_proportion,
    example_transforms,
    BATCH_SIZE,
    NUM_WORKERS,
)
# Sizes of the train split, the test split, and the class list.
print(*map(len, (train_dataloader.dataset.indices, test_dataloader.dataset.indices, class_names)))

In [None]:
# Plot one training image with matplotlib
image_batch, label_batch = next(iter(train_dataloader))  # Get a batch of images
image, label = image_batch[0], label_batch[0]  # Get a single image from the batch
# Print the shape and label explicitly: a bare expression in the middle of a
# cell is evaluated but never displayed.
print(f"Image shape: {tuple(image.shape)} | Label: {int(label)}")
# Rearrange image dimensions to suit matplotlib:
# [color_channels, height, width] -> [height, width, color_channels]
plt.imshow(image.permute(1, 2, 0))
# Convert the label to a plain int before indexing the class-name sequence.
plt.title(class_names[int(label)])
plt.axis(False);  # trailing ';' suppresses the return-value echo under the figure

In [None]:
# Pretrained ViT (Standard)

# "DEFAULT" means the best available weights for ViT-B/16.
vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=vit_weights)

# Hand-built preprocessing: square resize, tensor conversion, grayscale -> 3 channels.
# NOTE(review): unlike the SWAG cells, this pipeline does not use the preprocessing
# bundled with the weights (vit_weights.transforms()) - confirm this is intended.
IMG_SIZE = 224
vit_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Leave multi-channel tensors alone; tile a single channel three times (RGB).
        transforms.Lambda(lambda img: img if img.size(0) != 1 else img.repeat(3, 1, 1)),
    ]
)

# Loader and dataset settings for this experiment.
BATCH_SIZE, NUM_WORKERS = 32, 2
total_images, test_proportion = 25000, 0.1
dataset_name = "Caltech256"

train_dataloader, test_dataloader, class_names = prepare_loaders(
    data_folder,
    dataset_name,
    total_images,
    test_proportion,
    vit_transforms,
    BATCH_SIZE,
    NUM_WORKERS,
)
# Sizes of the train split, the test split, and the class list.
print(*map(len, (train_dataloader.dataset.indices, test_dataloader.dataset.indices, class_names)))

# prepare_model receives the model and the number of classes and hands back the
# model together with a summary object.
num_classes = len(class_names)
pretrained_vit, model_summary = prepare_model(pretrained_vit, num_classes)
# print(model_summary)

In [None]:
# Train the pretrained ViT (standard weights)
# Loss and optimizer: cross-entropy, Adam over every parameter the model exposes.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pretrained_vit.parameters(), lr=1e-3)
set_seeds()

epochs = 20
pretrained_vit_results = engine.train(
    model=pretrained_vit,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=epochs,
    device=device,
)

# Blank line between the training log and the curves.
print("")
plot_loss_curves(pretrained_vit_results)

In [None]:
# "DEFAULT" means the best available weights for ViT-B/16.
svit_backbone_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

# Hand-built preprocessing: square resize, tensor conversion, grayscale -> 3 channels.
IMG_SIZE = 224
vit_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Leave multi-channel tensors alone; tile a single channel three times (RGB).
        transforms.Lambda(lambda img: img if img.size(0) != 1 else img.repeat(3, 1, 1)),
    ]
)

# Loader and dataset settings for this experiment.
BATCH_SIZE, NUM_WORKERS = 32, 2
total_images, test_proportion = 25000, 0.1
dataset_name = "Caltech256"

train_dataloader, test_dataloader, class_names = prepare_loaders(
    data_folder,
    dataset_name,
    total_images,
    test_proportion,
    vit_transforms,
    BATCH_SIZE,
    NUM_WORKERS,
)
# Sizes of the train split, the test split, and the class list.
print(*map(len, (train_dataloader.dataset.indices, test_dataloader.dataset.indices, class_names)))

# Pretrained SVM-ViT (Standard)
# The backbone goes through prepare_model without a class count, then is wrapped in SViT.
svit_backbone = torchvision.models.vit_b_16(weights=svit_backbone_weights)
svit_backbone, model_summary = prepare_model(svit_backbone)

svit = SViT(svit_backbone)

In [None]:
# Fit the SViT model on the training loader, then score it on the test loader.
# NOTE(review): what fit/score return is defined in the project's models module - confirm.
fit_results = svit.fit(train_dataloader)
score_results = svit.score(test_dataloader)
# Last expression of the cell, so the score is displayed as the cell output.
score_results

In [None]:
# Pretrained ViT (SWAG)

vit_weights_swag = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 # get SWAG weights
pretrained_vit_swag = torchvision.models.vit_b_16(weights=vit_weights_swag)

# Convert every image to RGB first, then apply the preprocessing bundled with the weights.
swag_transforms = transforms.Compose([
    transforms.Lambda(lambda x: x.convert('RGB')), # Turns to RGB
    vit_weights_swag.transforms()
])

BATCH_SIZE = 32
# Defined here like in the sibling cells, so this cell does not depend on a
# value left behind by an earlier cell.
NUM_WORKERS = 2
total_images = 25000
test_proportion = 0.125
dataset_name = "Caltech256"

train_dataloader, test_dataloader, class_names = prepare_loaders(data_folder, dataset_name, total_images, test_proportion, swag_transforms, BATCH_SIZE, NUM_WORKERS)
# Sizes of the train split, the test split, and the class list.
print(len(train_dataloader.dataset.indices), len(test_dataloader.dataset.indices), len(class_names))

# prepare_model receives the model and the number of classes and hands back the
# model together with a summary object.
pretrained_vit_swag, model_summary = prepare_model(pretrained_vit_swag, len(class_names))
# print(model_summary)

In [None]:
# Train the pretrained ViT (SWAG weights)
# Loss and optimizer: cross-entropy, Adam over every parameter the model exposes.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pretrained_vit_swag.parameters(), lr=1e-3)
set_seeds()

epochs = 20
pretrained_vit_swag_results = engine.train(
    model=pretrained_vit_swag,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=epochs,
    device=device,
)

# Blank line between the training log and the curves.
print("")
plot_loss_curves(pretrained_vit_swag_results)

In [None]:
# SWAG weights for the ViT-B/16 backbone.
svit_swag_backbone_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1

# Convert every image to RGB first, then apply the preprocessing bundled with the weights.
swag_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda pil_image: pil_image.convert('RGB')),
        svit_swag_backbone_weights.transforms(),
    ]
)

# Loader and dataset settings for this experiment.
BATCH_SIZE, NUM_WORKERS = 32, 2
total_images, test_proportion = 25000, 0.125
dataset_name = "Caltech256"

train_dataloader, test_dataloader, class_names = prepare_loaders(
    data_folder,
    dataset_name,
    total_images,
    test_proportion,
    swag_transforms,
    BATCH_SIZE,
    NUM_WORKERS,
)
# Sizes of the train split, the test split, and the class list.
print(*map(len, (train_dataloader.dataset.indices, test_dataloader.dataset.indices, class_names)))

# Pretrained SVM-ViT (SWAG)
# The backbone goes through prepare_model without a class count, then is wrapped in SViT.
svit_swag_backbone = torchvision.models.vit_b_16(weights=svit_swag_backbone_weights)
svit_swag_backbone, model_summary = prepare_model(svit_swag_backbone)

svit_swag = SViT(svit_swag_backbone)

In [None]:
# Fit the SWAG-backed SViT model on the training loader, then score it on the test loader.
# NOTE: fit_results / score_results reuse the names from the standard SViT cell,
# so the earlier values are overwritten here.
fit_results = svit_swag.fit(train_dataloader)
score_results = svit_swag.score(test_dataloader)
# Last expression of the cell, so the score is displayed as the cell output.
score_results