In [None]:
# If you are running this inside a Python venv and get "ModuleNotFoundError: No module named 'numpy'",
# make sure the notebook kernel is set to $VENV_NAME; register it with: python -m ipykernel install --user --name=$VENV_NAME
import test_infra as infra
import os
import torch
# Train on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

# Training settings
#   epochs     - monitor validation loss during training to avoid overfitting, see
#                https://datascience.stackexchange.com/questions/46523/is-a-large-number-of-epochs-good-or-bad-idea-in-cnn
#   batch_size - conventionally kept to a power of two
#   lr         - base learning rate
#   gamma      - learning-rate decay factor; StepLR is imported above, so this is
#                presumably the StepLR gamma used inside test_infra (TODO confirm)
epochs, batch_size, lr, gamma = 15, 32, 3e-5, 0.7

In [None]:
# Architectures compared in this notebook; each string is the label later passed
# to vit_experiment.train_model for the corresponding model cell.
model_names = [
    "VanillaViT",
    "SimpleViT",
    "T2TViT",
    "CrossViT",
    "PiT",
    "LeViT",
    "CvT",
    "MobileViT",
    "SmallDataViT",
]
vit_experiment = infra.training_statistics(
    model_names, epochs, batch_size, lr, gamma
)

In [None]:
from torchvision import transforms
# Preprocessing applied to every image as it is loaded: resize to 128x128, then
# convert to a tensor.
# Optional ImageNet normalisation (mean/std values commonly used with ImageNet), currently disabled:
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)

In [None]:
# The dataset lives in ../data/ga_imgs/ relative to the notebook's working directory.
dataset_path = f"{os.getcwd()}/../data/ga_imgs/"
vit_experiment.load_data(dataset_path, transform)

In [None]:
# VanillaViT
# TODO: replace the magic numbers below (image size, class count, model dimensions) with named constants
from vit_pytorch import ViT

# Standard ViT on 128x128 inputs (matching the Resize in `transform`) split into 8x8 patches.
VanillaViT = ViT(
    image_size=128, patch_size=8, num_classes=23,
    dim=1024, depth=6, heads=16, mlp_dim=2048,
    dropout=0.1, emb_dropout=0.1,
)
VanillaViT.to(device)  # nn.Module.to moves the parameters in place and returns the same module
vit_experiment.train_model(VanillaViT, "VanillaViT")

In [None]:
# SimpleViT 
from vit_pytorch import SimpleViT

# Note: this rebinds the name `SimpleViT` from the imported class to the model
# instance; re-running the cell still works because the import above runs first.
SimpleViT = SimpleViT(
    image_size = 128,
    patch_size = 8,
    num_classes = 23,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(SimpleViT, "SimpleViT")

In [None]:
# T2TViT
from vit_pytorch.t2t import T2TViT

T2TViT = T2TViT(
    dim = 512,
    image_size = 128,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 23,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # (kernel size, stride) of each consecutive layer of the initial tokens-to-token module
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(T2TViT, "T2TViT")

In [None]:
# CrossViT
from vit_pytorch.cross_vit import CrossViT

CrossViT = CrossViT(
    image_size = 128,
    num_classes = 23,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(CrossViT, "CrossViT")

In [None]:
# PiT
from vit_pytorch.pit import PiT

PiT = PiT(
    image_size = 128,
    patch_size = 16,
    dim = 256,
    num_classes = 23,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(PiT, "PiT")

In [None]:
# LeViT
from vit_pytorch.levit import LeViT

LeViT = LeViT(
    image_size = 128,
    num_classes = 23,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(LeViT, "LeViT")

In [None]:
# CvT
from vit_pytorch.cvt import CvT

CvT = CvT(
    num_classes = 23,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0,
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(CvT, "CvT")

In [None]:
"""
# ScalableViT
from vit_pytorch.scalable_vit import ScalableViT

ScalableViT = ScalableViT(
    num_classes = 23,
    dim = 64,                               # starting model dimension. at every stage, dimension is doubled
    heads = (2, 4, 8, 16),                  # number of attention heads at each stage
    depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage
    ssa_dim_key = (40, 40, 40, 32),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
    reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
    window_size = (32, 16, None, None),     # window size of the IWSA at each stage. None means no windowing needed
    dropout = 0.1,                          # attention and feedforward dropout
).to(device)
train_model(ScalableViT, "ScalableViT")


# Test Run "runs out of memory" here for some reason, despite not actually running out of memory 
"""

In [None]:
# MobileViT
from vit_pytorch.mobile_vit import MobileViT

MobileViT = MobileViT(
    image_size = (128, 128),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 23,
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(MobileViT, "MobileViT")

In [None]:
# SmallDataViT
from vit_pytorch.vit_for_small_dataset import ViT as SmallDataViT

SmallDataViT = SmallDataViT(
    image_size = 128,
    patch_size = 16,
    num_classes = 23,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)
# `train_model` is a method of the training_statistics object; the bare name is
# not defined at notebook scope.
vit_experiment.train_model(SmallDataViT, "SmallDataViT")

In [None]:
# Accuracy Convergence: one training-accuracy curve per model.
import matplotlib.pyplot as plt

# NOTE(review): neither `models` nor `accuracy_dict` is defined anywhere in this
# notebook; they look like leftovers from before training moved into test_infra.
# The per-model accuracy history presumably lives on the training_statistics
# object as `accuracy_dict` -- confirm the attribute name in test_infra.
accuracy_dict = vit_experiment.accuracy_dict

fig, ax = plt.subplots(figsize=(20, 10))
for model_name in model_names:
    ax.plot(accuracy_dict[model_name], label=model_name)

# Points are plotted at list indices 0..epochs-1; label them 1..epochs.
ax.set_xticks(list(range(epochs)))
ax.set_xticklabels([str(epoch) for epoch in range(1, epochs + 1)])
ax.grid(True)
ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("Accuracy (Percentage Points)", fontsize=14)
ax.set_title("Training Accuracy Convergence")
ax.legend()

plt.show()

In [None]:
# Perform Inference using model
import torchvision.transforms.functional as TF

# Tribe names in alphabetical order; demo() maps an argmax index i to classes[i]
# (presumably matching the dataset loader's alphabetical class indexing -- confirm).
# NOTE(review): this list has 22 names, but every model above is built with
# num_classes = 23, so one tribe appears to be missing -- confirm against the
# dataset's class list.
classes = """
    Ammobatoidini Andrenini Anthidiini Anthophorini Apini Augochlorini
    Bombini Calliopsini Caupolicanini Ceratinini Emphorini Epeolini
    Eucerini Halictini Megachilini Melectini Nomadini Osmiini
    Panurgini Protandrenini Sphecodini Xylocopini
""".split()

# Demo images (sourced from Google Images): file name -> ground-truth tribe.
# The files themselves are opened later, from ../data/demo_img/, by demo().
demo_images = dict([
    ("augochlorini.jpg", "Augochlorini"),
    ("bombini.jpg", "Bombini"),
    ("halictini.jpg", "Halictini"),
    ("osmiini.jpg", "Osmiini"),
    ("xylocopa.jpg", "Xylocopini"),
])

def demo(filename, ground_truth):
    """Classify one demo image with the trained VanillaViT and print the result.

    Args:
        filename: Image file name inside ``<cwd>/../data/demo_img/``.
        ground_truth: Human-readable tribe label, printed for comparison only.
    """
    # `Image` was never imported in this notebook; import locally so the cell is self-contained.
    from IPython.display import display
    from PIL import Image

    demo_image_path = "{}/../data/demo_img/{}".format(os.getcwd(), filename)
    # Force RGB so grayscale/CMYK/RGBA files still produce a 3-channel tensor
    # (presumably what the RGB training images produced); `with` closes the file.
    with Image.open(demo_image_path) as raw_img:
        rgb_img = raw_img.convert("RGB")

    # Display a 128x128 preview of the image being classified
    display(rgb_img.resize((128, 128)))

    # Reuse the training-time `transform` (Resize((128, 128)) + ToTensor) so the
    # model sees the same resizing as during training, not a nearest-neighbour resize.
    input_batch = transform(rgb_img).unsqueeze(0).to(device)  # add the batch dimension at dim 0

    # eval() disables VanillaViT's dropout layers (dropout/emb_dropout = 0.1) so the
    # prediction is deterministic; no_grad() avoids building an autograd graph.
    VanillaViT.eval()
    with torch.no_grad():
        predictions = VanillaViT(input_batch)

    # Index of the highest score = most confident prediction
    tribe_idx = int(torch.argmax(predictions, dim=1).item())

    if tribe_idx < len(classes):
        print("The VanillaViT model believes this bee is a {}".format(classes[tribe_idx]))
    else:
        # `classes` has fewer names than the model's num_classes output units.
        print("The VanillaViT model predicted class index {}, which has no name in `classes`".format(tribe_idx))
    print("Ground truth: This bee is a {}".format(ground_truth))

# Run the demo on every image, printing the prediction next to the ground truth.
for demo_filename, true_tribe in demo_images.items():
    demo(demo_filename, true_tribe)