In [1]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

import numpy

from IPython.display import Image
from tqdm.notebook import tqdm
from math import *


In [2]:
# Set dataset and model

# Dataset
dataset = torchvision.datasets.Flowers102

# Pre-trained models of ViT: 
# torchvision.models.vit_b_16/vit_b_32/vit_l_16/vit_l_32
from torchvision.models import vit_b_16
model = vit_b_16(pretrained=True)

model_filename = "flowers102-vitb16"

warmup_epoch = 10
num_epoch = 30

lr_warmup = 1e-3
lr_base = 1e-2
lr_min = 1e-3

device = "cuda:2" if torch.cuda.is_available() else "cpu"

In [3]:
device

'cuda:2'

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.CenterCrop(224)
])
train_set = torchvision.datasets.Flowers102("./data", split= "train", download = True, transform=transform)
test_set = torchvision.datasets.Flowers102("./data", split= "test", download = True, transform=transform)
val_set = torchvision.datasets.Flowers102("./data", split= "val", download = True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=True)

In [6]:
in_features = model.heads[0].in_features
in_features

768

In [7]:
model.heads = nn.Linear(in_features, 102) # set in_features and out_features
model.train()

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (dropout_1): Dropout(p=0.0, inplace=False)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout_2): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
 

In [8]:
# Freeze other layers unless it is self_attention or mlp_head
for name, param in model.named_parameters():
    if ('self_attention' in name) or ('head' in name):
        print("Unfreeze " + name)
        param.requires_grad = True
    else:
        print("Freeze " + name)
        param.requires_grad = False


Freeze class_token
Freeze conv_proj.weight
Freeze conv_proj.bias
Freeze encoder.pos_embedding
Freeze encoder.layers.encoder_layer_0.ln_1.weight
Freeze encoder.layers.encoder_layer_0.ln_1.bias
Unfreeze encoder.layers.encoder_layer_0.self_attention.in_proj_weight
Unfreeze encoder.layers.encoder_layer_0.self_attention.in_proj_bias
Unfreeze encoder.layers.encoder_layer_0.self_attention.out_proj.weight
Unfreeze encoder.layers.encoder_layer_0.self_attention.out_proj.bias
Freeze encoder.layers.encoder_layer_0.ln_2.weight
Freeze encoder.layers.encoder_layer_0.ln_2.bias
Freeze encoder.layers.encoder_layer_0.mlp.linear_1.weight
Freeze encoder.layers.encoder_layer_0.mlp.linear_1.bias
Freeze encoder.layers.encoder_layer_0.mlp.linear_2.weight
Freeze encoder.layers.encoder_layer_0.mlp.linear_2.bias
Freeze encoder.layers.encoder_layer_1.ln_1.weight
Freeze encoder.layers.encoder_layer_1.ln_1.bias
Unfreeze encoder.layers.encoder_layer_1.self_attention.in_proj_weight
Unfreeze encoder.layers.encoder_laye

In [9]:
# Reading whole dataloader into memory can improve the speed of training
train_loader = list(train_loader)
test_loader = list(test_loader)

In [10]:
train_set

Dataset Flowers102
    Number of datapoints: 1020
    Root location: ./data
    split=train
    StandardTransform
Transform: Compose(
               ToTensor()
               Resize(size=224, interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(224, 224))
           )

In [12]:
loss_function = nn.CrossEntropyLoss()

In [13]:
lr = lr_base
def adjust_learning_rate(optimizer, current_epoch, max_epoch, lr_min=lr_min, lr_max=lr_base, warmup=True):
    if current_epoch < warmup_epoch:
        lr = lr_max * (current_epoch+1) / (warmup_epoch+1)
    else:
        lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print("Learning rate is set to "+str(lr))

optimiser = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                      lr=lr,
                      momentum=0.9) #only optimse non-frozen layers
model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (dropout_1): Dropout(p=0.0, inplace=False)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout_2): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
 

In [14]:
results = []

for epoch in range(warmup_epoch+num_epoch):
    running_loss = 0
    train_acc = 0
    
    adjust_learning_rate(optimizer=optimiser,
                        current_epoch=epoch,
                        max_epoch=warmup_epoch+num_epoch)
    with tqdm(train_loader, desc='Train(epoch'+str(epoch)+')') as t:
        total = 0
        correct = 0
        for data in t:
            model.train()
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimiser.zero_grad()

            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            running_loss += loss.item()
            loss.backward()
            optimiser.step()
            
            pred = torch.argmax(F.softmax(outputs), dim=1)
            total += len(labels)
            correct += sum(pred == labels)
            
        train_acc = (100.0 * correct) / total

        t.set_postfix(running_loss=running_loss,
                      runing_acc=train_acc)
            
    print("epoch %d/%d:(tr)loss=%.4f" % (epoch, warmup_epoch+num_epoch, running_loss))
    print("epoch %d/%d:(tr)acc=%.4f%%" % (epoch, warmup_epoch+num_epoch, train_acc))
    
    test_running_loss = 0
    test_acc = 0
    
    if epoch%10==9:
        
        with tqdm(test_loader, desc='test'+str(epoch)) as t:
            with torch.no_grad():
                total = 0
                correct = 0
                for data in t:
                    model.eval()
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)

                    outputs = model(inputs)

                    loss = loss_function(outputs, labels)
                    test_running_loss += loss.item()

                    pred = torch.argmax(F.softmax(outputs), dim=1)
                    total += len(labels)
                    correct += sum(pred == labels)
                test_acc = (100.0 * correct) / total

                t.set_postfix(running_loss=test_running_loss,
                              runing_acc=test_acc)

        print("epoch %d/%d:(te)loss=%.4f" % (epoch, warmup_epoch+num_epoch, test_running_loss))
        print("epoch %d/%d:(te)acc=%.4f%%" % (epoch, warmup_epoch+num_epoch, test_acc))
          
    results.append({'running_loss':running_loss,
                   'train_acc':train_acc,
                   'test_running_loss':test_running_loss,
                   'test_acc':test_acc})


Learning rate is set to 0.0009090909090909091


Train(epoch0):   0%|          | 0/16 [00:00<?, ?it/s]



epoch 0/40:(tr)loss=74.4035
epoch 0/40:(tr)acc=1.4706%
Learning rate is set to 0.0018181818181818182


Train(epoch1):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 1/40:(tr)loss=71.9159
epoch 1/40:(tr)acc=4.5098%
Learning rate is set to 0.002727272727272727


Train(epoch2):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 2/40:(tr)loss=66.3191
epoch 2/40:(tr)acc=19.4118%
Learning rate is set to 0.0036363636363636364


Train(epoch3):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 3/40:(tr)loss=56.1171
epoch 3/40:(tr)acc=53.3333%
Learning rate is set to 0.004545454545454546


Train(epoch4):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 4/40:(tr)loss=41.2294
epoch 4/40:(tr)acc=79.7059%
Learning rate is set to 0.005454545454545454


Train(epoch5):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 5/40:(tr)loss=25.6905
epoch 5/40:(tr)acc=92.5490%
Learning rate is set to 0.006363636363636364


Train(epoch6):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 6/40:(tr)loss=14.0369
epoch 6/40:(tr)acc=97.7451%
Learning rate is set to 0.007272727272727273


Train(epoch7):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 7/40:(tr)loss=7.3315
epoch 7/40:(tr)acc=99.2157%
Learning rate is set to 0.00818181818181818


Train(epoch8):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 8/40:(tr)loss=4.0013
epoch 8/40:(tr)acc=99.8039%
Learning rate is set to 0.009090909090909092


Train(epoch9):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 9/40:(tr)loss=2.3598
epoch 9/40:(tr)acc=100.0000%


test9:   0%|          | 0/97 [00:00<?, ?it/s]



epoch 9/40:(te)loss=84.2538
epoch 9/40:(te)acc=87.0548%
Learning rate is set to 0.010000000000000002


Train(epoch10):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 10/40:(tr)loss=1.5797
epoch 10/40:(tr)acc=100.0000%
Learning rate is set to 0.00997534852915723


Train(epoch11):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 11/40:(tr)loss=1.1446
epoch 11/40:(tr)acc=100.0000%
Learning rate is set to 0.009901664203302126


Train(epoch12):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 12/40:(tr)loss=0.9097
epoch 12/40:(tr)acc=100.0000%
Learning rate is set to 0.009779754323328192


Train(epoch13):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 13/40:(tr)loss=0.7457
epoch 13/40:(tr)acc=100.0000%
Learning rate is set to 0.009610954559391705


Train(epoch14):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 14/40:(tr)loss=0.6223
epoch 14/40:(tr)acc=100.0000%
Learning rate is set to 0.009397114317029977


Train(epoch15):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 15/40:(tr)loss=0.5380
epoch 15/40:(tr)acc=100.0000%
Learning rate is set to 0.009140576474687266


Train(epoch16):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 16/40:(tr)loss=0.4796
epoch 16/40:(tr)acc=100.0000%
Learning rate is set to 0.008844151714648276


Train(epoch17):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 17/40:(tr)loss=0.4342
epoch 17/40:(tr)acc=100.0000%
Learning rate is set to 0.008511087728614863


Train(epoch18):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 18/40:(tr)loss=0.3985
epoch 18/40:(tr)acc=100.0000%
Learning rate is set to 0.008145033635316129


Train(epoch19):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 19/40:(tr)loss=0.3691
epoch 19/40:(tr)acc=100.0000%


test19:   0%|          | 0/97 [00:00<?, ?it/s]

epoch 19/40:(te)loss=60.9277
epoch 19/40:(te)acc=88.9413%
Learning rate is set to 0.007750000000000001


Train(epoch20):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 20/40:(tr)loss=0.3448
epoch 20/40:(tr)acc=100.0000%
Learning rate is set to 0.007330314893841103


Train(epoch21):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 21/40:(tr)loss=0.3245
epoch 21/40:(tr)acc=100.0000%
Learning rate is set to 0.006890576474687264


Train(epoch22):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 22/40:(tr)loss=0.3074
epoch 22/40:(tr)acc=100.0000%
Learning rate is set to 0.006435602608679917


Train(epoch23):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 23/40:(tr)loss=0.2930
epoch 23/40:(tr)acc=100.0000%
Learning rate is set to 0.0059703780847044415


Train(epoch24):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 24/40:(tr)loss=0.2806
epoch 24/40:(tr)acc=100.0000%
Learning rate is set to 0.005500000000000001


Train(epoch25):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 25/40:(tr)loss=0.2701
epoch 25/40:(tr)acc=100.0000%
Learning rate is set to 0.0050296219152955604


Train(epoch26):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 26/40:(tr)loss=0.2611
epoch 26/40:(tr)acc=100.0000%
Learning rate is set to 0.0045643973913200835


Train(epoch27):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 27/40:(tr)loss=0.2534
epoch 27/40:(tr)acc=100.0000%
Learning rate is set to 0.004109423525312737


Train(epoch28):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 28/40:(tr)loss=0.2468
epoch 28/40:(tr)acc=100.0000%
Learning rate is set to 0.0036696851061589


Train(epoch29):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 29/40:(tr)loss=0.2412
epoch 29/40:(tr)acc=100.0000%


test29:   0%|          | 0/97 [00:00<?, ?it/s]

epoch 29/40:(te)loss=57.4091
epoch 29/40:(te)acc=89.1202%
Learning rate is set to 0.003250000000000001


Train(epoch30):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 30/40:(tr)loss=0.2364
epoch 30/40:(tr)acc=100.0000%
Learning rate is set to 0.0028549663646838717


Train(epoch31):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 31/40:(tr)loss=0.2323
epoch 31/40:(tr)acc=100.0000%
Learning rate is set to 0.0024889122713851394


Train(epoch32):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 32/40:(tr)loss=0.2288
epoch 32/40:(tr)acc=100.0000%
Learning rate is set to 0.0021558482853517268


Train(epoch33):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 33/40:(tr)loss=0.2259
epoch 33/40:(tr)acc=100.0000%
Learning rate is set to 0.0018594235253127371


Train(epoch34):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 34/40:(tr)loss=0.2234
epoch 34/40:(tr)acc=100.0000%
Learning rate is set to 0.0016028856829700259


Train(epoch35):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 35/40:(tr)loss=0.2213
epoch 35/40:(tr)acc=100.0000%
Learning rate is set to 0.0013890454406082957


Train(epoch36):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 36/40:(tr)loss=0.2195
epoch 36/40:(tr)acc=100.0000%
Learning rate is set to 0.001220245676671809


Train(epoch37):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 37/40:(tr)loss=0.2180
epoch 37/40:(tr)acc=100.0000%
Learning rate is set to 0.0010983357966978745


Train(epoch38):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 38/40:(tr)loss=0.2167
epoch 38/40:(tr)acc=100.0000%
Learning rate is set to 0.0010246514708427696


Train(epoch39):   0%|          | 0/16 [00:00<?, ?it/s]

epoch 39/40:(tr)loss=0.2155
epoch 39/40:(tr)acc=100.0000%


test39:   0%|          | 0/97 [00:00<?, ?it/s]

epoch 39/40:(te)loss=56.5754
epoch 39/40:(te)acc=89.1039%


In [15]:
results

[{'running_loss': 74.40354537963867,
  'train_acc': tensor(1.4706, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 71.9158844947815,
  'train_acc': tensor(4.5098, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 66.31913661956787,
  'train_acc': tensor(19.4118, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 56.11714959144592,
  'train_acc': tensor(53.3333, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 41.22936463356018,
  'train_acc': tensor(79.7059, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 25.690459728240967,
  'train_acc': tensor(92.5490, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 14.036860942840576,
  'train_acc': tensor(97.7451, device='cuda:2'),
  'test_running_loss': 0,
  'test_acc': 0},
 {'running_loss': 7.331454664468765,
  'train_acc': tensor(99.2157, device='cuda:2'),
  'te

In [16]:
torch.save(model.state_dict(), "model/" + model_filename)

In [17]:
train_loss_list = []
train_acc_list = []
for result in results:
    train_loss_list.append(result["running_loss"])
    train_acc_list.append(torch.Tensor.cpu(result["train_acc"]))

In [None]:
test_loss_list = []
test_acc_list = []
for result in results:
    test_loss_list.append(result["test_running_loss"])
    test_acc_list.append(torch.Tensor.cpu(result["test_acc"]))

In [None]:
val_loss_list = []
val_acc_list = []
for result in results:
    val_loss_list.append(result["val_loss"])
    val_acc_list.append(result["val_acc"])

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10,4))

ax[0].plot(train_loss_list)
ax[0].plot(test_loss_list)
ax[0].legend(['train','test'])
ax[0].grid()
ax[0].set_title("Loss")

ax[1].plot(train_acc_list)
ax[1].plot(test_acc_list)
ax[1].legend(['train','test'])
ax[1].grid()
ax[1].set_title("Accuracy")
plt.show()