# Swin Transformers

This notebook trains a  Vision Transformer on the Butterfly dataset.

## Loading Dataset

---

In [10]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [12]:
!cp '/content/gdrive/MyDrive/ML Project/archive.zip' /content/

In [13]:
!cd /content/
!unzip archive.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: train/JULIA/042.jpg     
  inflating: train/JULIA/043.jpg     
  inflating: train/JULIA/044.jpg     
  inflating: train/JULIA/045.jpg     
  inflating: train/JULIA/046.jpg     
  inflating: train/JULIA/047.jpg     
  inflating: train/JULIA/048.jpg     
  inflating: train/JULIA/049.jpg     
  inflating: train/JULIA/050.jpg     
  inflating: train/JULIA/051.jpg     
  inflating: train/JULIA/052.jpg     
  inflating: train/JULIA/053.jpg     
  inflating: train/JULIA/054.jpg     
  inflating: train/JULIA/055.jpg     
  inflating: train/JULIA/056.jpg     
  inflating: train/JULIA/057.jpg     
  inflating: train/JULIA/058.jpg     
  inflating: train/JULIA/059.jpg     
  inflating: train/JULIA/060.jpg     
  inflating: train/JULIA/061.jpg     
  inflating: train/JULIA/062.jpg     
  inflating: train/JULIA/063.jpg     
  inflating: train/JULIA/064.jpg     
  inflating: train/JULIA/065.jpg     
  inflating: train/JULI

## Main Code

---



In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
# We use a butterfly dataset of 50 species to demonstrate the classification method

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [3]:
## Now, we import timm, torchvision image models
!pip install timm # kaggle doesnt have it installed by default
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[?25l[K     |▊                               | 10 kB 27.3 MB/s eta 0:00:01[K     |█▌                              | 20 kB 34.4 MB/s eta 0:00:01[K     |██▎                             | 30 kB 26.3 MB/s eta 0:00:01[K     |███                             | 40 kB 20.3 MB/s eta 0:00:01[K     |███▉                            | 51 kB 21.1 MB/s eta 0:00:01[K     |████▋                           | 61 kB 24.0 MB/s eta 0:00:01[K     |█████▎                          | 71 kB 19.6 MB/s eta 0:00:01[K     |██████                          | 81 kB 21.3 MB/s eta 0:00:01[K     |██████▉                         | 92 kB 23.3 MB/s eta 0:00:01[K     |███████▋                        | 102 kB 23.6 MB/s eta 0:00:01[K     |████████▍                       | 112 kB 23.6 MB/s eta 0:00:01[K     |█████████▏                      | 122 kB 23.6 MB/s eta 0:00:01[K     |█████████▉                      | 133 kB 23.6 MB/s eta 0:00:01

In [4]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
import sys
from tqdm import tqdm
import time
import copy

In [7]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [8]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        #train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), # imagenet means
            T.RandomErasing(p=0.1, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([ # We dont need augmentation for test transforms
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), # imagenet means
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [14]:
dataset_path = "/content/"

In [15]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, 128, train=True)
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loaders(dataset_path, 32, train=False)

In [16]:
classes = get_classes("/content/train/")
print(classes, len(classes))

['ADONIS', 'AFRICAN GIANT SWALLOWTAIL', 'AMERICAN SNOOT', 'AN 88', 'APPOLLO', 'ATALA', 'BANDED ORANGE HELICONIAN', 'BANDED PEACOCK', 'BECKERS WHITE', 'BLACK HAIRSTREAK', 'BLUE MORPHO', 'BLUE SPOTTED CROW', 'BROWN SIPROETA', 'CABBAGE WHITE', 'CAIRNS BIRDWING', 'CHECQUERED SKIPPER', 'CHESTNUT', 'CLEOPATRA', 'CLODIUS PARNASSIAN', 'CLOUDED SULPHUR', 'COMMON BANDED AWL', 'COMMON WOOD-NYMPH', 'COPPER TAIL', 'CRECENT', 'CRIMSON PATCH', 'DANAID EGGFLY', 'EASTERN COMA', 'EASTERN DAPPLE WHITE', 'EASTERN PINE ELFIN', 'ELBOWED PIERROT', 'GOLD BANDED', 'GREAT EGGFLY', 'GREAT JAY', 'GREEN CELLED CATTLEHEART', 'GREY HAIRSTREAK', 'INDRA SWALLOW', 'IPHICLUS SISTER', 'JULIA', 'LARGE MARBLE', 'MALACHITE', 'MANGROVE SKIPPER', 'MESTRA', 'METALMARK', 'MILBERTS TORTOISESHELL', 'MONARCH', 'MOURNING CLOAK', 'ORANGE OAKLEAF', 'ORANGE TIP', 'ORCHARD SWALLOW', 'PAINTED LADY', 'PAPER KITE', 'PEACOCK', 'PINE WHITE', 'PIPEVINE SWALLOW', 'POPINJAY', 'PURPLE HAIRSTREAK', 'PURPLISH COPPER', 'QUESTION MARK', 'RED ADMIRA

In [17]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [18]:
print(len(train_loader), len(val_loader), len(test_loader))

73 12 12


In [19]:
print(train_data_len, valid_data_len, test_data_len)

9285 375 375


In [21]:
# now, for the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [35]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
# check hubconf for more models.
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True) # load from torch hub

Using cache found in /root/.cache/torch/hub/SharanSMenon_swin-transformer-hub_main


In [36]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=75, bias=True)
)


In [37]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr=0.001)

In [38]:
# lr scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [39]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)
        
        for phase in ['train', 'val']: # We do training and validation phase per epoch
            if phase == 'train':
                model.train() # model to training mode
            else:
                model.eval() # model to evaluate
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step() # step at end of epoch
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
        print()
    time_elapsed = time.time() - since # slight error
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

In [40]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=20) # now it is a lot faster
# I will come back after 10 epochs

Epoch 0/19
----------


100%|██████████| 73/73 [00:52<00:00,  1.39it/s]


train Loss: 2.8263 Acc: 0.4178


100%|██████████| 12/12 [00:02<00:00,  4.14it/s]


val Loss: 1.8997 Acc: 0.6907

Epoch 1/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.7114 Acc: 0.7417


100%|██████████| 12/12 [00:02<00:00,  4.18it/s]


val Loss: 1.4909 Acc: 0.8187

Epoch 2/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.4673 Acc: 0.8240


100%|██████████| 12/12 [00:02<00:00,  4.36it/s]


val Loss: 1.3352 Acc: 0.8640

Epoch 3/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.3412 Acc: 0.8654


100%|██████████| 12/12 [00:02<00:00,  4.30it/s]


val Loss: 1.2591 Acc: 0.8747

Epoch 4/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.2751 Acc: 0.8835


100%|██████████| 12/12 [00:02<00:00,  4.32it/s]


val Loss: 1.2143 Acc: 0.9067

Epoch 5/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.2312 Acc: 0.8974


100%|██████████| 12/12 [00:02<00:00,  4.27it/s]


val Loss: 1.1762 Acc: 0.9067

Epoch 6/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.1887 Acc: 0.9103


100%|██████████| 12/12 [00:02<00:00,  4.28it/s]


val Loss: 1.1323 Acc: 0.9173

Epoch 7/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.1579 Acc: 0.9217


100%|██████████| 12/12 [00:02<00:00,  4.19it/s]


val Loss: 1.1058 Acc: 0.9227

Epoch 8/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.1362 Acc: 0.9250


100%|██████████| 12/12 [00:02<00:00,  4.34it/s]


val Loss: 1.0893 Acc: 0.9227

Epoch 9/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.1189 Acc: 0.9348


100%|██████████| 12/12 [00:02<00:00,  4.29it/s]


val Loss: 1.0712 Acc: 0.9520

Epoch 10/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.0965 Acc: 0.9407


100%|██████████| 12/12 [00:02<00:00,  4.38it/s]


val Loss: 1.0651 Acc: 0.9227

Epoch 11/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0909 Acc: 0.9393


100%|██████████| 12/12 [00:02<00:00,  4.33it/s]


val Loss: 1.0650 Acc: 0.9360

Epoch 12/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.0740 Acc: 0.9428


100%|██████████| 12/12 [00:02<00:00,  4.24it/s]


val Loss: 1.0567 Acc: 0.9387

Epoch 13/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0604 Acc: 0.9478


100%|██████████| 12/12 [00:02<00:00,  4.32it/s]


val Loss: 1.0483 Acc: 0.9387

Epoch 14/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.0556 Acc: 0.9491


100%|██████████| 12/12 [00:02<00:00,  4.31it/s]


val Loss: 1.0474 Acc: 0.9360

Epoch 15/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.37it/s]


train Loss: 1.0480 Acc: 0.9496


100%|██████████| 12/12 [00:02<00:00,  4.24it/s]


val Loss: 1.0408 Acc: 0.9307

Epoch 16/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0375 Acc: 0.9564


100%|██████████| 12/12 [00:02<00:00,  4.25it/s]


val Loss: 1.0293 Acc: 0.9440

Epoch 17/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0312 Acc: 0.9558


100%|██████████| 12/12 [00:02<00:00,  4.25it/s]


val Loss: 1.0272 Acc: 0.9493

Epoch 18/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0223 Acc: 0.9596


100%|██████████| 12/12 [00:02<00:00,  4.27it/s]


val Loss: 1.0293 Acc: 0.9440

Epoch 19/19
----------


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


train Loss: 1.0151 Acc: 0.9621


100%|██████████| 12/12 [00:02<00:00,  4.24it/s]

val Loss: 1.0285 Acc: 0.9333

Training complete in 18m 46s
Best Val Acc: 0.9520





## Testing

Ok, now we finished training. Lets run the dataset on the test loader and calculate accuracy

In [41]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 12/12 [00:02<00:00,  4.19it/s]

Test Loss: 0.0626
Test Accuracy of ADONIS: 100% ( 4/ 4)
Test Accuracy of AFRICAN GIANT SWALLOWTAIL: 100% ( 5/ 5)
Test Accuracy of AMERICAN SNOOT: 75% ( 3/ 4)
Test Accuracy of AN 88: 100% ( 5/ 5)
Test Accuracy of APPOLLO: 100% ( 5/ 5)
Test Accuracy of ATALA: 100% ( 4/ 4)
Test Accuracy of BANDED ORANGE HELICONIAN: 100% ( 5/ 5)
Test Accuracy of BANDED PEACOCK: 100% ( 4/ 4)
Test Accuracy of BECKERS WHITE: 100% ( 3/ 3)
Test Accuracy of BLACK HAIRSTREAK: 100% ( 5/ 5)
Test Accuracy of BLUE MORPHO: 80% ( 4/ 5)
Test Accuracy of BLUE SPOTTED CROW: 100% ( 5/ 5)
Test Accuracy of BROWN SIPROETA: 100% ( 5/ 5)
Test Accuracy of CABBAGE WHITE: 100% ( 4/ 4)
Test Accuracy of CAIRNS BIRDWING: 100% ( 5/ 5)
Test Accuracy of CHECQUERED SKIPPER: 100% ( 4/ 4)
Test Accuracy of CHESTNUT: 100% ( 5/ 5)
Test Accuracy of CLEOPATRA: 80% ( 4/ 5)
Test Accuracy of CLODIUS PARNASSIAN: 80% ( 4/ 5)
Test Accuracy of CLOUDED SULPHUR: 100% ( 5/ 5)
Test Accuracy of COMMON BANDED AWL: 100% ( 5/ 5)
Test Accuracy of COMMON WOOD-N




In [42]:
# our model earns 93% test accuracy, which is very high. lets save it
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("butterfly_swin_transformer_20.pt")

In [32]:
!cp '/content/butterfly_swin_transformer_.pt' '/content/gdrive/MyDrive/ML Project'