# Visual Transformer with Linformer

Training Visual Transformer on *Lion, Lizard and Toucan dataset*

In [208]:
import sys
!pip -q install vit_pytorch linformer

## Import Libraries

In [209]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

import shutil

from vit_pytorch.efficient import ViT


In [210]:
print(f"Torch: {torch.__version__}")

Torch: 1.8.2+cu111


In [211]:
# Training settings
batch_size = 32  # was 64
epochs = 10  # was 20
lr = 3e-5
gamma = 0.7
seed = 42

In [212]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [213]:
device = 'cuda'

## Load Data

In [214]:
# JW: Remove data if exists
shutil.rmtree("data")

os.makedirs('data', exist_ok=True)
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/test', exist_ok=True)

In [215]:
train_dir = 'data/train'
test_dir = 'data/test'

In [216]:
# JW: Copy images from chosen dataset to folders, 2/3 train/test split. Add label to train name/path
split = 2/3
dataset = "images-outline"
folders = [x.replace("\\", "/") for x in glob.glob(f"{dataset}/*")]
for folder in folders:
    paths = [x.replace("\\", "/") for x in glob.glob(f"{folder}/*.*")]
    index = round(split * len(paths))
    for path in paths[:index]:
        items = path.split('/')
        new_name = f"{items[1].lower()}.{items[-1]}"
        new_path = f"{train_dir}/{new_name}"
        shutil.copy(path, new_path)
    for path in paths[index:]:
        new_name = path.split('/')[-1]
        new_path = f"{test_dir}/{new_name}"
        shutil.copy(path, new_path)

In [217]:
train_list = [x.replace("\\", "/") for x in glob.glob(f"{train_dir}/*.*")]
test_list = [x.replace("\\", "/") for x in glob.glob(f"{test_dir}/*.*")]

In [218]:
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

Train Data: 6667
Test Data: 3330


In [219]:
labels = [path.split('/')[-1].split('.')[0] for path in train_list]

## Split

In [220]:
train_list, valid_list = train_test_split(train_list, 
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=seed)

In [221]:
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

Train Data: 5333
Validation Data: 1334
Test Data: 3330


## Image Augmentation

In [222]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


## Load Datasets

In [223]:
class LionLizardToucanDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        label = img_path.split("/")[-1].split(".")[0]
        if label == "lion":
            label = 0
        elif label == "lizard":
            label = 1
        else:
            label = 2

        return img_transformed, label


In [224]:
train_data = LionLizardToucanDataset(train_list, transform=train_transforms)
valid_data = LionLizardToucanDataset(valid_list, transform=test_transforms)
test_data = LionLizardToucanDataset(test_list, transform=test_transforms)

In [225]:
train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

In [226]:
print(len(train_data), len(train_loader))

5333 84


In [227]:
print(len(valid_data), len(valid_loader))

1334 21


## Efficient Attention

### Linformer

In [228]:
efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

### Visual Transformer

In [229]:
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=3,
    transformer=efficient_transformer,
    channels=3,
).to(device)

### Training

In [230]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [231]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 1 - loss : 1.1056 - acc: 0.3332 - val_loss : 1.0970 - val_acc: 0.3462



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 2 - loss : 1.0942 - acc: 0.3738 - val_loss : 1.0742 - val_acc: 0.3452



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 3 - loss : 0.9015 - acc: 0.5904 - val_loss : 0.2814 - val_acc: 0.9303



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 4 - loss : 0.4421 - acc: 0.8356 - val_loss : 0.1818 - val_acc: 0.9993



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 5 - loss : 0.3352 - acc: 0.8876 - val_loss : 0.1336 - val_acc: 0.9768



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 6 - loss : 0.2913 - acc: 0.8932 - val_loss : 0.0336 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 7 - loss : 0.2392 - acc: 0.8923 - val_loss : 0.0269 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 8 - loss : 0.1954 - acc: 0.9064 - val_loss : 0.0180 - val_acc: 0.9985



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 9 - loss : 0.1825 - acc: 0.9073 - val_loss : 0.0039 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 10 - loss : 0.1620 - acc: 0.9200 - val_loss : 0.0076 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 11 - loss : 0.1879 - acc: 0.9042 - val_loss : 0.0085 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 12 - loss : 0.1891 - acc: 0.8990 - val_loss : 0.0430 - val_acc: 0.9844



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 13 - loss : 0.1801 - acc: 0.9046 - val_loss : 0.0018 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 14 - loss : 0.1672 - acc: 0.9107 - val_loss : 0.0019 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 15 - loss : 0.1637 - acc: 0.9137 - val_loss : 0.0017 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 16 - loss : 0.1811 - acc: 0.9113 - val_loss : 0.0021 - val_acc: 1.0000



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 17 - loss : 0.1547 - acc: 0.9185 - val_loss : 0.0020 - val_acc: 0.9993



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 18 - loss : 0.1675 - acc: 0.9070 - val_loss : 0.0032 - val_acc: 0.9991



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 19 - loss : 0.1680 - acc: 0.9057 - val_loss : 0.0047 - val_acc: 0.9993



  0%|          | 0/84 [00:00<?, ?it/s]

Epoch : 20 - loss : 0.1681 - acc: 0.9096 - val_loss : 0.0021 - val_acc: 1.0000

