##### Vision Transformers

This notebook trains a  Vision Transformer on the Butterfly dataset.

Code for the following video:
https://youtu.be/0tjuRnkFHKg

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models


In [3]:
## Now, we import timm, torchvision image models
!pip install timm # kaggle doesnt have it installed by default
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss
import timm.optim
from timm.scheduler import CosineLRScheduler

Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
     |████████████████████████████████| 549 kB 901 kB/s            
Installing collected packages: timm
Successfully installed timm-0.6.12


In [4]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
import sys
from tqdm import tqdm
import time
import copy

In [7]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [8]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        #train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # imagenet means
            T.RandomErasing(p=0.2, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([ # We dont need augmentation for test transforms
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # imagenet means
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [9]:
dataset_path = "/kaggle/input/butterfly-images40-species/"

In [10]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, 128, train=True)
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loaders(dataset_path, 32, train=False)

In [11]:
classes = get_classes("/kaggle/input/butterfly-images40-species/train/")
print(classes, len(classes))

['ADONIS', 'AFRICAN GIANT SWALLOWTAIL', 'AMERICAN SNOOT', 'AN 88', 'APPOLLO', 'ARCIGERA FLOWER MOTH', 'ATALA', 'ATLAS MOTH', 'BANDED ORANGE HELICONIAN', 'BANDED PEACOCK', 'BANDED TIGER MOTH', 'BECKERS WHITE', 'BIRD CHERRY ERMINE MOTH', 'BLACK HAIRSTREAK', 'BLUE MORPHO', 'BLUE SPOTTED CROW', 'BROOKES BIRDWING', 'BROWN ARGUS', 'BROWN SIPROETA', 'CABBAGE WHITE', 'CAIRNS BIRDWING', 'CHALK HILL BLUE', 'CHECQUERED SKIPPER', 'CHESTNUT', 'CINNABAR MOTH', 'CLEARWING MOTH', 'CLEOPATRA', 'CLODIUS PARNASSIAN', 'CLOUDED SULPHUR', 'COMET MOTH', 'COMMON BANDED AWL', 'COMMON WOOD-NYMPH', 'COPPER TAIL', 'CRECENT', 'CRIMSON PATCH', 'DANAID EGGFLY', 'EASTERN COMA', 'EASTERN DAPPLE WHITE', 'EASTERN PINE ELFIN', 'ELBOWED PIERROT', 'EMPEROR GUM MOTH', 'GARDEN TIGER MOTH', 'GIANT LEOPARD MOTH', 'GLITTERING SAPPHIRE', 'GOLD BANDED', 'GREAT EGGFLY', 'GREAT JAY', 'GREEN CELLED CATTLEHEART', 'GREEN HAIRSTREAK', 'GREY HAIRSTREAK', 'HERCULES MOTH', 'HUMMING BIRD HAWK MOTH', 'INDRA SWALLOW', 'IO MOTH', 'Iphiclus si

In [12]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [13]:
print(len(train_loader), len(val_loader), len(test_loader))

99 16 16


In [14]:
print(train_data_len, valid_data_len, test_data_len)

12639 500 500


In [15]:
# now, for the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [16]:
#model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
model = timm.create_model("vit_base_patch16_224", pretrained=True)

In [17]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=100, bias=True)
)


In [18]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
#optimizer = optim.Adam(model.head.parameters(), lr=0.001)

In [19]:
# lr scheduler
#exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [20]:
def train_model(model, criterion, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    optimizer = timm.optim.AdamP(model.head.parameters(), lr=0.001)
    scheduler = timm.scheduler.CosineLRScheduler(optimizer, t_initial=num_epochs)
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)
        
        for phase in ['train', 'val']: # We do training and validation phase per epoch
            if phase == 'train':
                model.train() # model to training mode
            else:
                model.eval() # model to evaluate
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step(epoch) # step at end of epoch
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
        print()
    time_elapsed = time.time() - since # slight error
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

In [21]:
model_ft = train_model(model, criterion) # now it is a lot faster
# I will come back after 10 epochs

Epoch 0/9
----------


100%|██████████| 99/99 [01:26<00:00,  1.15it/s]


train Loss: 1.8870 Acc: 0.7060


100%|██████████| 16/16 [00:04<00:00,  3.86it/s]


val Loss: 1.1799 Acc: 0.9140

Epoch 1/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 1.2159 Acc: 0.9057


100%|██████████| 16/16 [00:03<00:00,  4.09it/s]


val Loss: 1.0904 Acc: 0.9360

Epoch 2/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 1.1327 Acc: 0.9318


100%|██████████| 16/16 [00:03<00:00,  4.00it/s]


val Loss: 1.0675 Acc: 0.9580

Epoch 3/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 1.0822 Acc: 0.9477


100%|██████████| 16/16 [00:03<00:00,  4.10it/s]


val Loss: 1.0272 Acc: 0.9580

Epoch 4/9
----------


100%|██████████| 99/99 [01:19<00:00,  1.25it/s]


train Loss: 1.0430 Acc: 0.9585


100%|██████████| 16/16 [00:03<00:00,  4.13it/s]


val Loss: 1.0206 Acc: 0.9540

Epoch 5/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 1.0179 Acc: 0.9656


100%|██████████| 16/16 [00:04<00:00,  3.83it/s]


val Loss: 1.0189 Acc: 0.9560

Epoch 6/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 0.9934 Acc: 0.9734


100%|██████████| 16/16 [00:03<00:00,  4.09it/s]


val Loss: 1.0053 Acc: 0.9560

Epoch 7/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 0.9717 Acc: 0.9767


100%|██████████| 16/16 [00:03<00:00,  4.08it/s]


val Loss: 0.9892 Acc: 0.9640

Epoch 8/9
----------


100%|██████████| 99/99 [01:18<00:00,  1.26it/s]


train Loss: 0.9583 Acc: 0.9801


100%|██████████| 16/16 [00:03<00:00,  4.04it/s]


val Loss: 0.9874 Acc: 0.9640

Epoch 9/9
----------


100%|██████████| 99/99 [01:19<00:00,  1.25it/s]


train Loss: 0.9499 Acc: 0.9828


100%|██████████| 16/16 [00:04<00:00,  3.97it/s]

val Loss: 0.9818 Acc: 0.9700

Training complete in 13m 55s
Best Val Acc: 0.9700





## Testing

Ok, now we finished training. Lets run the dataset on the test loader and calculate accuracy

In [22]:
# anyways, we can use model, for it is also updated witht he latest weights

In [23]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 16/16 [00:04<00:00,  3.93it/s]

Test Loss: 0.0456
Test Accuracy of ADONIS: 100% ( 5/ 5)
Test Accuracy of AFRICAN GIANT SWALLOWTAIL: 100% ( 5/ 5)
Test Accuracy of AMERICAN SNOOT: 100% ( 5/ 5)
Test Accuracy of AN 88: 100% ( 4/ 4)
Test Accuracy of APPOLLO: 100% ( 5/ 5)
Test Accuracy of ARCIGERA FLOWER MOTH: 100% ( 5/ 5)
Test Accuracy of ATALA: 100% ( 5/ 5)
Test Accuracy of ATLAS MOTH: 100% ( 5/ 5)
Test Accuracy of BANDED ORANGE HELICONIAN: 100% ( 5/ 5)
Test Accuracy of BANDED PEACOCK: 100% ( 5/ 5)
Test Accuracy of BANDED TIGER MOTH: 100% ( 4/ 4)
Test Accuracy of BECKERS WHITE: 100% ( 5/ 5)
Test Accuracy of BIRD CHERRY ERMINE MOTH: 100% ( 4/ 4)
Test Accuracy of BLACK HAIRSTREAK: 100% ( 5/ 5)
Test Accuracy of BLUE MORPHO: 100% ( 5/ 5)
Test Accuracy of BLUE SPOTTED CROW: 100% ( 4/ 4)
Test Accuracy of BROOKES BIRDWING: 100% ( 5/ 5)
Test Accuracy of BROWN ARGUS: 100% ( 5/ 5)
Test Accuracy of BROWN SIPROETA: 100% ( 5/ 5)
Test Accuracy of CABBAGE WHITE: 100% ( 5/ 5)
Test Accuracy of CAIRNS BIRDWING: 100% ( 5/ 5)
Test Accuracy 




In [24]:
# our model earns 93% test accuracy, which is very high. lets save it
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("butterfly_deit_video.pt")

In [25]:
# That's it for this video, see you next time