# Swin Transformers

This notebook trains a  Swin Transformer on the Butterfly dataset.

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
# We use a butterfly dataset of 50 species to demonstrate the classification method

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [3]:
## Now, we import timm, torchvision image models
!pip install timm # kaggle doesnt have it installed by default
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss

Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
     |████████████████████████████████| 549 kB 575 kB/s            
Installing collected packages: timm
Successfully installed timm-0.6.12


In [5]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [8]:
import matplotlib.pyplot as plt
%matplotlib inline

In [9]:
import sys
from tqdm import tqdm
import time
import copy

In [10]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [15]:
def get_data_loader(data_dir, batch_size, train=False):
    if train:
        # train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD),
            T.RandomErasing(p=0.1, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform=transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD)
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        
        return val_loader, test_loader, len(val_data), len(test_data)

In [13]:
dataset_path = "/kaggle/input/butterfly-images40-species"

In [19]:
(train_loader, train_data_len) = get_data_loader(dataset_path, 128, train=True)
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loader(dataset_path, 32, train=False)

In [17]:
classes = get_classes("/kaggle/input/butterfly-images40-species/train/")
print(classes, len(classes))

['ADONIS', 'AFRICAN GIANT SWALLOWTAIL', 'AMERICAN SNOOT', 'AN 88', 'APPOLLO', 'ARCIGERA FLOWER MOTH', 'ATALA', 'ATLAS MOTH', 'BANDED ORANGE HELICONIAN', 'BANDED PEACOCK', 'BANDED TIGER MOTH', 'BECKERS WHITE', 'BIRD CHERRY ERMINE MOTH', 'BLACK HAIRSTREAK', 'BLUE MORPHO', 'BLUE SPOTTED CROW', 'BROOKES BIRDWING', 'BROWN ARGUS', 'BROWN SIPROETA', 'CABBAGE WHITE', 'CAIRNS BIRDWING', 'CHALK HILL BLUE', 'CHECQUERED SKIPPER', 'CHESTNUT', 'CINNABAR MOTH', 'CLEARWING MOTH', 'CLEOPATRA', 'CLODIUS PARNASSIAN', 'CLOUDED SULPHUR', 'COMET MOTH', 'COMMON BANDED AWL', 'COMMON WOOD-NYMPH', 'COPPER TAIL', 'CRECENT', 'CRIMSON PATCH', 'DANAID EGGFLY', 'EASTERN COMA', 'EASTERN DAPPLE WHITE', 'EASTERN PINE ELFIN', 'ELBOWED PIERROT', 'EMPEROR GUM MOTH', 'GARDEN TIGER MOTH', 'GIANT LEOPARD MOTH', 'GLITTERING SAPPHIRE', 'GOLD BANDED', 'GREAT EGGFLY', 'GREAT JAY', 'GREEN CELLED CATTLEHEART', 'GREEN HAIRSTREAK', 'GREY HAIRSTREAK', 'HERCULES MOTH', 'HUMMING BIRD HAWK MOTH', 'INDRA SWALLOW', 'IO MOTH', 'Iphiclus si

In [20]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [22]:
print(len(train_loader), len(val_loader), len(test_loader))

99 16 16


In [24]:
print(train_data_len, valid_data_len, test_data_len)

12639 500 500


In [25]:
# for the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [26]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
# check hubconf for more models.
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True) # load from torch hub

Downloading: "https://github.com/SharanSMenon/swin-transformer-hub/archive/main.zip" to /root/.cache/torch/hub/main.zip
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth


  0%|          | 0.00/109M [00:00<?, ?B/s]

In [27]:
for param in model.parameters(): # freeze model
    param.requires_grad = False
    
n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=100, bias=True)
)


In [28]:
print(model)

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNo

In [29]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr=0.001)

In [30]:
# lr scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [36]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-" * 10)
        
        for phase in ['train', 'val']: # do training and validation phase per epoch
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step()
                
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
                
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

In [37]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=7)

Epoch 0/6
----------


100%|██████████| 99/99 [00:51<00:00,  1.93it/s]


train Loss: 1.6606 Acc: 0.7858


100%|██████████| 16/16 [00:02<00:00,  5.87it/s]


val Loss: 1.4572 Acc: 0.8500

Epoch 1/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.97it/s]


train Loss: 1.4470 Acc: 0.8528


100%|██████████| 16/16 [00:03<00:00,  4.96it/s]


val Loss: 1.3314 Acc: 0.8960

Epoch 2/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.97it/s]


train Loss: 1.3367 Acc: 0.8854


100%|██████████| 16/16 [00:02<00:00,  6.13it/s]


val Loss: 1.2644 Acc: 0.8920

Epoch 3/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.95it/s]


train Loss: 1.2820 Acc: 0.9009


100%|██████████| 16/16 [00:02<00:00,  6.45it/s]


val Loss: 1.2269 Acc: 0.9140

Epoch 4/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.96it/s]


train Loss: 1.2341 Acc: 0.9135


100%|██████████| 16/16 [00:02<00:00,  6.26it/s]


val Loss: 1.1826 Acc: 0.9340

Epoch 5/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.97it/s]


train Loss: 1.2075 Acc: 0.9203


100%|██████████| 16/16 [00:02<00:00,  6.42it/s]


val Loss: 1.1605 Acc: 0.9280

Epoch 6/6
----------


100%|██████████| 99/99 [00:50<00:00,  1.94it/s]


train Loss: 1.1815 Acc: 0.9267


100%|██████████| 16/16 [00:02<00:00,  6.39it/s]

val Loss: 1.1401 Acc: 0.9360

Training complete in 6m 13s
Best Val Acc: 0.9360





In [38]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 16/16 [00:02<00:00,  5.37it/s]

Test Loss: 0.0419
Test Accuracy of ADONIS: 100% ( 4/ 4)
Test Accuracy of AFRICAN GIANT SWALLOWTAIL: 100% ( 5/ 5)
Test Accuracy of AMERICAN SNOOT: 80% ( 4/ 5)
Test Accuracy of AN 88: 100% ( 5/ 5)
Test Accuracy of APPOLLO: 100% ( 5/ 5)
Test Accuracy of ARCIGERA FLOWER MOTH: 100% ( 5/ 5)
Test Accuracy of ATALA: 100% ( 5/ 5)
Test Accuracy of ATLAS MOTH: 80% ( 4/ 5)
Test Accuracy of BANDED ORANGE HELICONIAN: 100% ( 5/ 5)
Test Accuracy of BANDED PEACOCK: 100% ( 5/ 5)
Test Accuracy of BANDED TIGER MOTH: 75% ( 3/ 4)
Test Accuracy of BECKERS WHITE: 100% ( 4/ 4)
Test Accuracy of BIRD CHERRY ERMINE MOTH: 80% ( 4/ 5)
Test Accuracy of BLACK HAIRSTREAK: 100% ( 5/ 5)
Test Accuracy of BLUE MORPHO: 80% ( 4/ 5)
Test Accuracy of BLUE SPOTTED CROW: 100% ( 5/ 5)
Test Accuracy of BROOKES BIRDWING: 100% ( 5/ 5)
Test Accuracy of BROWN ARGUS: 100% ( 5/ 5)
Test Accuracy of BROWN SIPROETA: 100% ( 5/ 5)
Test Accuracy of CABBAGE WHITE: 100% ( 5/ 5)
Test Accuracy of CAIRNS BIRDWING: 100% ( 5/ 5)
Test Accuracy of CH




In [39]:
# our model earns 93% test accuracy, which is very high. lets save it
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("butterfly_swin_transformer.pt")

In [None]:
# That's it for this video, see you next time