In [None]:
# %pip install datasets vit-pytorch linformer torch torchvision matplotlib numpy

In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import vit_pytorch
from datasets import load_dataset
from linformer import Linformer
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from vit_pytorch.cait import CaiT
from vit_pytorch.deepvit import DeepViT
from vit_pytorch.efficient import ViT

In [None]:
# Smoke-test output, then pick the compute device: CUDA when available, CPU otherwise.
print("Hello World")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'
print(f"Using {device}")

## Load Data

In [None]:
# Load the first 1000 Food-101 training examples and split them 80/20.
# The split is seeded so that Restart & Run All reproduces the same partition.
SPLIT_SEED = 42
food = load_dataset("food101", split="train[:1000]").train_test_split(
    test_size=0.2, seed=SPLIT_SEED
)
# Map label names to label ids (stored as strings) and back.
labels = food["test"].features["label"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
print(f"Train: {len(food['train'])}")
print(f"Test: {len(food['test'])}")

## Image Augmentation

In [None]:
# Edge length (pixels) of one ViT patch.
patch_size = 16
# Edge length (pixels) each source image is resized to before tiling.
half_size = 112
# Edge length of the tiled 2x2 model input. Derived from half_size (instead of
# the previous hard-coded 256) so the value handed to the models matches the
# tensors that are actually built from two half_size tiles per side.
image_size = 2 * half_size

In [None]:
# Evaluation pipeline: resize to (half_size, half_size) and convert to a tensor.
test_transform = transforms.Compose(
    [
        transforms.Resize((half_size, half_size)),
        transforms.ToTensor(),
    ]
)
# Training pipeline: same as evaluation plus a random horizontal flip.
train_transform = transforms.Compose(
    [
        transforms.Resize((half_size, half_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

In [None]:
# Make my Special Input Matrix
def mk_special_matrix(tmp, shift=None):
    """Wrap-around shift of a 3-D tensor along dims 1 and 2.

    The first `shift` slices along dim 1 are cut off and appended at the end
    of dim 1, then the first `shift` slices along dim 2 are cut off and
    appended at the end of dim 2. The output has the same shape as the input
    and contains every input element exactly once.

    The previous implementation duplicated the trailing rows and prepended the
    trailing columns, which silently discarded the leading rows/columns.

    Args:
        tmp: 3-D tensor; dims 1 and 2 are the shifted axes
            (assumed channel-first image, i.e. (C, H, W) -- TODO confirm).
        shift: number of leading rows/columns to move. Defaults to
            `patch_size // 2` (half a patch), read from the notebook globals.

    Returns:
        Tensor of the same shape as `tmp`.
    """
    if shift is None:
        shift = patch_size // 2
    # Cut the TOP `shift` rows and cat them to the BOTTOM.
    rows_shifted = torch.cat([tmp[:, shift:, :], tmp[:, :shift, :]], dim=1)
    # Cut the LEFT `shift` columns and cat them to the RIGHT.
    return torch.cat([rows_shifted[:, :, shift:], rows_shifted[:, :, :shift]], dim=2)


def mk_big_matrix(tmp, is_special=False):
    """Tile `tmp` into a 2x2 grid, doubling its size along dims 1 and 2.

    Layout of the result (dim 1 runs downwards, dim 2 runs to the right):

        +-------+-------+
        |  tmp  |  alt  |
        +-------+-------+
        |  alt  |  tmp  |
        +-------+-------+

    where `alt` is `tmp` itself when `is_special` is False, and
    `mk_special_matrix(tmp)` when `is_special` is True.

    Args:
        tmp: 3-D tensor (assumed channel-first image -- TODO confirm).
        is_special: whether the off-diagonal tiles use the special variant.

    Returns:
        Tensor whose dims 1 and 2 are twice those of `tmp`.
    """
    alt = mk_special_matrix(tmp) if is_special else tmp
    # Build the two rows of the grid first, then stack them vertically.
    top_row = torch.cat((tmp, alt), dim=2)
    bottom_row = torch.cat((alt, tmp), dim=2)
    return torch.cat((top_row, bottom_row), dim=1)

In [None]:
def _build_split(images, labels, transform):
    """Build the normal and special tiled datasets for one split.

    Each image is transformed exactly once and both variants are derived from
    that same tensor, so the two datasets differ only in the tiling (not in
    the random flip) and images are not decoded twice.

    Returns:
        (normal, special): two dicts with keys "img" (list of tiled tensors)
        and "label" (list of labels, in the same order as `images`).
    """
    tensors = [transform(img) for img in images]
    normal = {"img": [mk_big_matrix(t) for t in tensors], "label": list(labels)}
    special = {"img": [mk_big_matrix(t, True) for t in tensors], "label": list(labels)}
    return normal, special


# Normal and Special Train / Test Data
train_data, train_data_S = _build_split(food["train"]["image"], food["train"]["label"], train_transform)
test_data, test_data_S = _build_split(food["test"]["image"], food["test"]["label"], test_transform)

print(f"INPUT SIZE: {train_data['img'][0].size()}")
print(f"Train: {len(train_data['label'])}")
print(f"Test: {len(test_data['label'])}")

print(f"INPUT SIZE: {train_data_S['img'][0].size()}")
print(f"Train: {len(train_data_S['label'])}")
print(f"Test: {len(test_data_S['label'])}")

## Plots

In [None]:
# Show the first nine (normal) training inputs in a 3x3 grid, titled by label name.
to_pil = transforms.ToPILImage()
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(16, 12))
for idx, ax in enumerate(axes.flat):
    label, img = train_data["label"][idx], train_data["img"][idx]
    image = to_pil(img).convert('RGB')
    ax.imshow(image)
    ax.set_title(id2label[str(label)])


## DataLoader Wrapping

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over a dict holding parallel "img" and "label" sequences.

    Args:
        File: mapping with keys "img" and "label"; item `idx` of each belongs
            to the same sample.
        transform: optional callable applied to the image on every
            `__getitem__` call. Previously this argument was accepted but
            ignored; with the default `None` the behaviour is unchanged.
    """

    def __init__(self, File, transform=None):
        self.File = File
        self.transform = transform

    @property
    def filelength(self):
        """Number of samples, i.e. the length of the "img" sequence."""
        return len(self.File['img'])

    # Get Current File Length
    def __len__(self):
        return self.filelength

    def __getitem__(self, idx):
        """Return the `(img, label)` pair at position `idx`."""
        label = self.File["label"][idx]
        img = self.File["img"][idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

In [None]:
# Wrap the pre-built dicts so DataLoader can index them.
train_my_data, test_my_data = (MyDataset(split) for split in (train_data, test_data))
train_my_data_S, test_my_data_S = (MyDataset(split) for split in (train_data_S, test_data_S))

In [None]:
# Mini-batch size shared by all loaders.
BATCH_SIZE = 15
# Training loaders are shuffled every epoch.
train_loader = DataLoader(dataset=train_my_data, batch_size=BATCH_SIZE, shuffle=True)
train_loader_S = DataLoader(dataset=train_my_data_S, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep a fixed order: shuffling changes which samples land in
# the (smaller) last batch, which makes per-batch-averaged metrics vary between runs.
test_loader = DataLoader(dataset=test_my_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader_S = DataLoader(dataset=test_my_data_S, batch_size=BATCH_SIZE, shuffle=False)

## Model Definition

In [None]:
# Registry of the models to train, keyed by display name.
ModelDict = {}

In [None]:
# Efficient ViT Model (vit_pytorch.efficient.ViT on top of a Linformer encoder)
# The Linformer sequence length is derived from the image/patch geometry instead
# of the previous hard-coded 7x7 + 1, which did not match image_size / patch_size.
num_patches = (image_size // patch_size) ** 2
efficient_transformer = Linformer(
    dim=128,
    seq_len=num_patches + 1,  # one token per patch + 1 cls-token
    depth=12,
    heads=8,
    k=64,
)
model_vit_efficient = ViT(
    dim=128,
    image_size=image_size,
    patch_size=patch_size,
    num_classes=len(labels),
    transformer=efficient_transformer,
    channels=3,
).to(device)
# Not registered for training by default; uncomment to include it.
# ModelDict["efficientViT"] = model_vit_efficient

In [None]:
# Baseline ViT, registered for training under the key "ViT".
vit_kwargs = dict(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=len(labels),
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=128,
)
model_vit = vit_pytorch.ViT(**vit_kwargs)
model_vit = model_vit.to(device)
ModelDict["ViT"] = model_vit

In [None]:
# DeepViT (DeepViT is imported in the top-level import cell)
# Moved to `device` like the other models so it can consume batches that the
# training loop places on that device.
model_deepvit = DeepViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=len(labels),
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=128,
).to(device)
# Not registered for training by default; uncomment to include it.
# ModelDict["Deepvit"] = model_deepvit

In [None]:
# CaiT Model (CaiT is imported in the top-level import cell)
# Moved to `device` like the ViT models so it can consume batches that the
# training loop places on that device.
model_cait = CaiT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=len(labels),
    dim=128,
    depth=12,             # depth of transformer for patch to patch attention only
    cls_depth=2,          # depth of cross attention of CLS tokens to patch
    heads=8,
    mlp_dim=128,
).to(device)
ModelDict["cait"] = model_cait

## Training

In [None]:
# NOTE(review): looks like a leftover smoke-test print (the same message is
# already printed in the device-selection cell) -- consider deleting this cell.
print("Hello World")

In [None]:
def trainit(model, train_loader, test_loader, name="DefaultName", lr=3e-5, EPOCH=30, device=None):
    """
        Train a model with Adam + cross-entropy and evaluate it after every epoch.

        Fixes over the previous version: the `return` was inside the epoch
        loop (only one epoch ever ran); losses were accumulated as tensors
        still attached to the autograd graph; the model was never moved to the
        device the batches are sent to; train/eval mode was never switched.

            INPUT:
                - model:           Defined model (updated in place)
                - train_loader:    training data     ->  iterable of (data, label) batches supporting len()
                - test_loader:     test data         ->  iterable of (data, label) batches supporting len()
                - name:            Name of the model ->  (default "DefaultName"), used in the log line
                - lr:              Learning rate     ->  (default 3e-5)
                - EPOCH:           Total epochs      ->  (default 30)
                - device:          Device to run on  ->  (default None: "cuda" if available, else "cpu")
            OUTPUT:
                - model:           The same model object, after training
                - train_loss_list: Training loss per epoch (list of float, mean over batches)
                - val_loss_list:   Validation loss per epoch (list of float, mean over batches)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # nn.Module.to moves the parameters in place and returns the same object.
    model = model.to(device)

    criterio = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    n_train = len(train_loader)
    n_val = len(test_loader)
    train_loss_list = []
    val_loss_list = []

    for epoch in range(EPOCH):
        # ---- training pass ----
        model.train()
        epoch_loss = 0.0
        epoch_accuracy = 0.0
        for i, (data, label) in enumerate(train_loader, start=1):
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            loss = criterio(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the scalars so the autograd graph of each batch
            # can be freed instead of being kept alive by the running sums.
            acc = (output.argmax(dim=1) == label).float().mean().item()
            epoch_accuracy += acc / n_train
            epoch_loss += loss.item() / n_train
            print(f"Epoch: {epoch}, {i}/{n_train} | Acc: {epoch_accuracy:.4f} | Los: {epoch_loss:.4f}", end='\r')

        # ---- validation pass ----
        model.eval()
        epoch_val_acc = 0.0
        epoch_val_loss = 0.0
        with torch.no_grad():
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)
                val_output = model(data)
                val_loss = criterio(val_output, label)
                acc = (val_output.argmax(dim=1) == label).float().mean().item()
                epoch_val_acc += acc / n_val
                epoch_val_loss += val_loss.item() / n_val

        val_loss_list.append(epoch_val_loss)
        train_loss_list.append(epoch_loss)
        print(f"{name:15s} > Epoch: {epoch+1:2d} | Loss: {epoch_loss:.4f} | acc: {epoch_accuracy:.4f} | val_loss: {epoch_val_loss: .4f} | val_acc: {epoch_val_acc: .4f}")

    # Return only after ALL epochs have run.
    return model, train_loss_list, val_loss_list

In [None]:
# List the names of the models registered for training.
for key in ModelDict.keys():
    print(f"{key} ->")

In [None]:
TrainLossDict, ValLossDict, ModelOK = dict(), dict(), dict()

# "ViT_INPUTVARY" was never registered in ModelDict, so looking it up raised
# KeyError. Take an independent copy of the ViT BEFORE the loop below trains the
# registered models in place, so the special-input run starts from the same
# weights the registered ViT has at this point.
vit_input_vary = copy.deepcopy(ModelDict["ViT"])

# Train every registered model on the normal data.
for key in ModelDict:
    print(key)
    ModelOK[key], TrainLossDict[key], ValLossDict[key] = trainit(ModelDict[key], train_loader, test_loader, key, EPOCH=2)

# Train the ViT copy on the special (shifted-tile) data.
key = "ViT_INPUTVARY"
ModelOK[key], TrainLossDict[key], ValLossDict[key] = trainit(vit_input_vary, train_loader_S, test_loader_S, key, EPOCH=2)