#SWIN Transformer


In [13]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
#from linformer import Linformer

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
%pip install einops
%pip install wandb



In [16]:
%cd drive/MyDrive/Colab Notebooks/ViT

[Errno 2] No such file or directory: 'drive/MyDrive/Colab Notebooks/ViT'
/content/drive/MyDrive/Colab Notebooks/ViT


In [17]:
from model.swin import SwinTransformer
from dataloader import set_transforms
from utils.utils import seed_everything

seed_everything(42)

In [18]:
import torch.backends.cudnn as cudnn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
cudnn.benchmark = True
print('Use CUDA:', use_cuda)

Use CUDA: True


In [19]:
config={
    "epochs": 20, 
    "batch_size": 64,
    "lr" : 3e-5,
    "gamma" : 0.7,
    "image_size" : 64,
    "hidden_dim" : 96,
    "channels" : 3,
    "num_classes" : 100,
    "head_dim" : 32,
    "window_size" : 4,
    }

# Training settings
import wandb
wandb.init(config=config, project="Cifar100_ViT", entity='NoguNogu',name='Swin')

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [20]:
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


train_transforms = transforms.Compose([transforms.Resize((config['image_size'], config['image_size'])),
                                       transforms.RandomResizedCrop(config['image_size']),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor()
                                       ])

val_transforms = transforms.Compose([
                                     transforms.Resize((config['image_size'], config['image_size'])),
                                     transforms.RandomResizedCrop(config['image_size']),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor()
                                     ])


trainset = datasets.CIFAR100(
    root='./data/Cifar100',
    train=True,
    download=False,
    transform=train_transforms)
    
testset = datasets.CIFAR100(
    root='./data/Cifar100',
    train=False,
    download=False,
    transform=val_transforms)

trainloader = DataLoader(trainset,
                         batch_size=config['batch_size'],
                         shuffle=True,
                         num_workers=2)

testloader = DataLoader(testset,
                         batch_size=config['batch_size'],
                         shuffle=False,
                         num_workers=2)

In [21]:
model = SwinTransformer(
    hidden_dim= config['hidden_dim'],
    layers=(2, 2, 6, 2),
    heads=(3, 6, 12, 24),
    channels=config['channels'],
    num_classes=config['num_classes'],
    head_dim=config['head_dim'],
    window_size=config['window_size'],
    downscaling_factors=(2, 2, 2, 2),
    relative_pos_embedding=True).to(device)

In [31]:
%pip install torchinfo

from torchinfo import summary

summary(model, input_size=(config['batch_size'], 3, 64, 64))
#summary(model, (3, 64, 64))



Layer (type:depth-idx)                   Output Shape              Param #
SwinTransformer                          --                        --
├─StageModule: 1-1                       [64, 96, 32, 32]          --
│    └─PatchMerging: 2-1                 [64, 32, 32, 96]          --
│    │    └─Unfold: 3-1                  [64, 12, 1024]            --
│    │    └─Linear: 3-2                  [64, 32, 32, 96]          1,248
├─StageModule: 1-2                       [64, 192, 16, 16]         --
│    └─PatchMerging: 2-2                 [64, 16, 16, 192]         --
│    │    └─Unfold: 3-3                  [64, 384, 256]            --
│    │    └─Linear: 3-4                  [64, 16, 16, 192]         73,920
├─StageModule: 1-3                       [64, 384, 8, 8]           --
│    └─PatchMerging: 2-3                 [64, 8, 8, 384]           --
│    │    └─Unfold: 3-5                  [64, 768, 64]             --
│    │    └─Linear: 3-6                  [64, 8, 8, 384]           295,296
├─S

In [30]:
from torchsummary import summary

summary(model, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Unfold-1             [-1, 12, 1024]               0
            Linear-2           [-1, 32, 32, 96]           1,248
      PatchMerging-3           [-1, 32, 32, 96]               0
         LayerNorm-4           [-1, 32, 32, 96]             192
            Linear-5          [-1, 32, 32, 288]          27,648
            Linear-6           [-1, 32, 32, 96]           9,312
   WindowAttention-7           [-1, 32, 32, 96]               0
           PreNorm-8           [-1, 32, 32, 96]               0
          Residual-9           [-1, 32, 32, 96]               0
        LayerNorm-10           [-1, 32, 32, 96]             192
           Linear-11          [-1, 32, 32, 384]          37,248
             GELU-12          [-1, 32, 32, 384]               0
           Linear-13           [-1, 32, 32, 96]          36,960
      FeedForward-14           [-1, 32,

In [23]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=config['lr'])
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=config['gamma'])

In [24]:
# 3. Log gradients and model parameters
wandb.watch(model, log_freq=100)
for epoch in range(config['epochs']):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(trainloader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(trainloader)
        epoch_loss += loss / len(trainloader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in testloader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(testloader)
            epoch_val_loss += val_loss / len(testloader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    # 4. Log metrics to visualize performance
    wandb.log({
        "Epoch": epoch+1,
        "loss": epoch_loss,
        "acc": epoch_accuracy,
        "val_loss" : epoch_val_loss,
        "val_acc": epoch_val_accuracy
        })

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 1 - loss : 4.0753 - acc: 0.0752 - val_loss : 3.8348 - val_acc: 0.1088



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 2 - loss : 3.6936 - acc: 0.1325 - val_loss : 3.6315 - val_acc: 0.1348



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 3 - loss : 3.4836 - acc: 0.1687 - val_loss : 3.3640 - val_acc: 0.1929



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 4 - loss : 3.2950 - acc: 0.2043 - val_loss : 3.2202 - val_acc: 0.2214



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 5 - loss : 3.1641 - acc: 0.2291 - val_loss : 3.1228 - val_acc: 0.2364



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 6 - loss : 3.0407 - acc: 0.2542 - val_loss : 3.0419 - val_acc: 0.2527



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 7 - loss : 2.9428 - acc: 0.2725 - val_loss : 2.9397 - val_acc: 0.2787



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 8 - loss : 2.8449 - acc: 0.2937 - val_loss : 2.8689 - val_acc: 0.2939



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 9 - loss : 2.7651 - acc: 0.3089 - val_loss : 2.7980 - val_acc: 0.3046



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 10 - loss : 2.6902 - acc: 0.3233 - val_loss : 2.7381 - val_acc: 0.3163



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 11 - loss : 2.6237 - acc: 0.3401 - val_loss : 2.7272 - val_acc: 0.3184



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 12 - loss : 2.5622 - acc: 0.3516 - val_loss : 2.7139 - val_acc: 0.3225



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 13 - loss : 2.4932 - acc: 0.3650 - val_loss : 2.6177 - val_acc: 0.3446



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 14 - loss : 2.4444 - acc: 0.3778 - val_loss : 2.6024 - val_acc: 0.3501



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 15 - loss : 2.3919 - acc: 0.3879 - val_loss : 2.5653 - val_acc: 0.3561



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 16 - loss : 2.3285 - acc: 0.3994 - val_loss : 2.4851 - val_acc: 0.3785



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 17 - loss : 2.2850 - acc: 0.4100 - val_loss : 2.4676 - val_acc: 0.3804



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 18 - loss : 2.2299 - acc: 0.4237 - val_loss : 2.4386 - val_acc: 0.3866



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 19 - loss : 2.1814 - acc: 0.4345 - val_loss : 2.4090 - val_acc: 0.3888



HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


Epoch : 20 - loss : 2.1277 - acc: 0.4461 - val_loss : 2.3944 - val_acc: 0.3957



In [25]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sat May 15 08:43:50 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   71C    P0    50W / 250W |   4081MiB / 16280MiB |     83%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces