### In this part, we fine-tune the ViT (DeiT-Small) model on the Stanford Dogs dataset

In [1]:
import PIL
import numpy as np
from tqdm import tqdm
import torch
import torch.optim as optim
from torchvision import transforms
from models.models_to_finetune import deit_small_patch16_224
from datasets import CUBDataset, DOGDataset
import math

In [2]:
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## Set up the dataset

In [3]:
# Candidate preprocessing pipelines for the train and test sets.
# Per-channel normalization statistics (the commonly used ImageNet mean/std values).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)


def _to_normalized_tensor(*augmentations):
    """Return a Compose that applies *augmentations*, then ToTensor and Normalize(mean, std)."""
    return transforms.Compose([
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# Plain resize to 224x224 (also the transform applied to the test split). Accuracy: 87.622%
data_transform = _to_normalized_tensor(
    transforms.Resize((224, 224)),
)

# Center crop (no resize) plus color jitter. Accuracy: 83.263%
data_transform1 = _to_normalized_tensor(
    transforms.CenterCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.1, hue=0.1),
)

# Resize plus random horizontal flip. Accuracy: 85.524%
data_transform2 = _to_normalized_tensor(
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
)

# Resize plus random rotation (up to 20 degrees) and Gaussian blur; accuracy not recorded.
data_transform3 = _to_normalized_tensor(
    transforms.Resize((224, 224)),
    transforms.RandomRotation(20),
    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
)

# NOTE(review): this pipeline is identical to data_transform3; accuracy not recorded.
data_transform4 = _to_normalized_tensor(
    transforms.Resize((224, 224)),
    transforms.RandomRotation(20),
    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
)


In [9]:
data_root = "./dog/"

for tr in [data_transform]:
    # The training split uses the candidate transform `tr`; the test split always uses the
    # deterministic `data_transform`, so results are comparable across candidates.
    train_dataset = DOGDataset(image_root_path=data_root, transform=tr, split="train")
    test_dataset = DOGDataset(image_root_path=data_root, transform=data_transform, split="test")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, drop_last=True, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, drop_last=False)

    # Place the model on `device` (not .cuda()) so the cell also works on CPU-only machines.
    model = deit_small_patch16_224(pretrained=True, use_top_n_heads=12, use_patch_outputs=False).to(device)

    # Replace the classifier with a freshly initialised linear head for 120 classes
    # (the dogs dataset has 120 classes).
    model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=120)
    model.head.apply(model._init_weights)
    model = model.to(device)  # also moves the newly created head

    # Full fine-tuning: every parameter, backbone and head, is trainable.
    for param in model.parameters():
        param.requires_grad = True

    criterion = torch.nn.CrossEntropyLoss()

    # Split parameter names into the classifier head and everything else, and collect
    # the names of the last four transformer blocks (for per-block learning rates).
    my_list = ['head.weight', 'head.bias']
    params = [kv for kv in model.named_parameters() if kv[0] in my_list]
    base_params = [kv for kv in model.named_parameters() if kv[0] not in my_list]
    param_list = [name for name, _ in base_params]

    block11 = [name for name in param_list if 'blocks.11' in name]
    block10 = [name for name in param_list if 'blocks.10' in name]
    block9 = [name for name in param_list if 'blocks.9' in name]
    block8 = [name for name in param_list if 'blocks.8' in name]

    # map_location lets a checkpoint saved on a GPU be restored on whatever `device` is.
    checkpoint = torch.load('model_best.pth.tar', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])

    # NOTE(review): the disabled training loop below needs fixes before re-enabling:
    #   * block8..block11 in this cell hold parameter *names*, but Adam needs the
    #     Parameter objects themselves;
    #   * `sys` is not imported in this notebook;
    #   * call model.train() first, because the evaluation below switches to eval().
    # optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0.5, 0.999))
    # optimizer = optim.Adam([{'params': block11, 'lr':0.001, 'betas': (0.5, 0.999)},
    #                         {'params': block10, 'lr':0.0001, 'betas': (0.5, 0.999)},
    #                         {'params': block9, 'lr': 0.0001, 'betas': (0.5, 0.999)},
    #                         {'params': block8, 'lr': 0.00001, 'betas': (0.5, 0.999)},
    #                         {'params':  [i[1]for i in params], 'lr': 0.001, 'betas': (0.5, 0.999)}])
    # epochs = 50
    # print('Training....')
    # for epoch in range(epochs):
    #     with tqdm(train_loader) as p_bar:
    #         for samples, targets in p_bar:
    #             p_bar.set_description(f"Epoch {epoch}")
    #             samples = samples.to(device)
    #             targets = targets.to(device)
    #
    #             outputs = model(samples, fine_tune=True)
    #             loss = criterion(outputs, targets)
    #
    #             loss_value = loss.item()
    #             if not math.isfinite(loss_value):
    #                 print("Loss is {}, stopping training".format(loss_value))
    #                 sys.exit(1)
    #             optimizer.zero_grad()
    #             loss.backward()
    #             optimizer.step()
    #             p_bar.set_postfix({'loss': loss_value})
    # print('Testing....')

    # Evaluate top-1 accuracy. eval() puts any dropout/normalization layers into
    # inference mode, and no_grad() avoids building an autograd graph per batch.
    model.eval()
    acc = 0
    with torch.no_grad(), tqdm(test_loader) as p_bar:
        for samples, targets in p_bar:
            samples = samples.to(device)
            targets = targets.to(device)

            outputs = model(samples, fine_tune=True)
            acc += torch.sum(outputs.argmax(dim=-1) == targets).item()

    # drop_last=False on the test loader, so every test sample is counted.
    print('Accuracy:{0:.3%}'.format(acc / len(test_dataset)))

    

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])


100%|█████████████████████████████████████████| 537/537 [01:17<00:00,  6.94it/s]

Accuracy:87.634%





In [82]:
# Split parameters into the classifier head (`params`) and everything else (`base_params`).
my_list = ['head.weight', 'head.bias']
params = [kv for kv in model.named_parameters() if kv[0] in my_list]
base_params = [kv for kv in model.named_parameters() if kv[0] not in my_list]

# (name, parameter) pairs for every non-head parameter.
param_list = [(name, param) for name, param in base_params]


def _block_params(index):
    """Return the parameters of transformer block `index` (names containing 'blocks.{index}.')."""
    # The trailing dot keeps e.g. 'blocks.1' from also matching 'blocks.10'/'blocks.11'.
    tag = f'blocks.{index}.'
    return [param for name, param in param_list if tag in name]


# Block 11 is trained together with the final LayerNorm. The norm parameters are selected
# by name instead of by their position at the end of param_list.
block11 = _block_params(11)
block11.extend(param for name, param in param_list if name in ('norm.weight', 'norm.bias'))
block10 = _block_params(10)
block9 = _block_params(9)
block8 = _block_params(8)

# Discriminative learning rates: the head and block 11 (+ final norm) learn fastest and
# earlier blocks progressively slower. Parameters not listed here (blocks 0-7, cls_token,
# and anything else not matched above) are NOT updated by this optimizer.
optimizer = optim.Adam(
    [
        {'params': block11, 'lr': 0.001},
        {'params': block10, 'lr': 0.0001},
        {'params': block9, 'lr': 0.0001},
        {'params': block8, 'lr': 0.00001},
        {'params': [param for _, param in params], 'lr': 0.001},
    ],
    betas=(0.5, 0.999),  # shared by every group
)


In [76]:
# Display the parameter tensor of the first head entry collected in `params`.
_, head_weight = params[0]
head_weight

Parameter containing:
tensor([[ 0.0243, -0.0053,  0.0073,  ...,  0.0148,  0.0108, -0.0129],
        [ 0.0030, -0.0227,  0.0011,  ..., -0.0187,  0.0345, -0.0152],
        [-0.0171, -0.0013,  0.0056,  ..., -0.0152, -0.0105, -0.0100],
        ...,
        [ 0.0226, -0.0216, -0.0181,  ..., -0.0209,  0.0218,  0.0052],
        [ 0.0410, -0.0425, -0.0115,  ...,  0.0199, -0.0172, -0.0212],
        [-0.0351,  0.0171,  0.0083,  ..., -0.0037,  0.0226, -0.0055]],
       device='cuda:0', requires_grad=True)

In [80]:
# Rebuild the (name, parameter) pairs for every non-head parameter and display them.
param_list = [(pname, pval) for pname, pval in base_params]
param_list

[('cls_token',
  Parameter containing:
  tensor([[[-3.9859e-02, -4.7006e-02,  3.0603e-02, -8.0471e-01, -5.1156e-01,
            -6.2574e-01,  2.4592e-01, -6.9993e-01,  4.2092e-02, -1.0225e+00,
            -4.3548e-02,  1.1119e-02,  1.6613e-02, -6.3589e-01, -1.1603e+00,
             5.0059e-01, -3.2553e-01,  1.4249e-01,  5.1884e-01,  2.4932e-02,
            -1.6074e-01,  1.6326e-01,  5.4418e-03, -1.3323e-02, -2.1773e-01,
            -2.3380e-01, -1.1907e-01,  1.6013e+00,  3.7736e-01, -4.4444e-02,
             1.5786e-02, -2.7299e-02,  1.1802e-02,  4.5447e-01,  3.9871e-03,
            -9.9099e-01, -3.7284e-02, -2.3835e-02,  8.9959e-02,  3.1242e-01,
            -3.9857e-02,  1.5268e-01,  8.9504e-03,  1.3713e+00,  1.3731e-01,
             1.9936e-01, -2.9797e-01,  3.1435e-01,  1.0969e-01,  5.7547e-01,
             1.0796e-01, -5.1769e-02,  8.6791e-01,  5.4784e-01,  7.1095e-02,
             3.5920e-02,  2.7323e-02,  3.0601e-01,  4.0461e-01,  3.1754e-01,
            -2.1408e-01, -2.3908e-03,

In [85]:
# Inspect the last two non-head (name, parameter) pairs.
tail_params = param_list[-2:]
tail_params

[('norm.weight',
  Parameter containing:
  tensor([1.4399, 1.2805, 1.3020, 1.4227, 1.3177, 1.3913, 1.3326, 1.1943, 1.4930,
          1.1751, 1.3554, 1.2137, 1.2393, 1.2155, 1.3559, 1.3057, 1.2435, 1.2911,
          1.3832, 1.3979, 1.2391, 1.3050, 1.2539, 1.3349, 1.2524, 1.3131, 1.3080,
          1.2083, 1.3156, 1.2425, 1.3523, 1.1901, 1.3903, 1.2259, 1.3353, 1.2022,
          1.3274, 1.2503, 1.1622, 1.2007, 1.3240, 1.2621, 1.3331, 1.3821, 1.2320,
          1.2288, 1.3627, 1.4508, 1.2987, 1.3425, 1.1675, 1.2485, 1.2395, 1.1680,
          1.2873, 1.3542, 1.3866, 1.2758, 1.2942, 1.2809, 1.1959, 1.2301, 1.2234,
          1.3758, 1.2916, 1.2703, 1.3651, 1.4469, 1.2882, 1.5075, 1.2381, 1.2058,
          1.2732, 1.2979, 1.2687, 1.3001, 1.2198, 1.2952, 1.3041, 1.2082, 1.2485,
          1.2935, 1.1888, 1.2435, 1.3419, 1.2739, 1.4323, 1.3538, 1.2715, 1.4460,
          1.3190, 1.2579, 1.3563, 1.2388, 1.2577, 1.2088, 1.2554, 1.1991, 1.3488,
          1.3131, 1.1981, 1.3003, 1.1936, 1.2625, 1.2583,

In [87]:
# Inspect the final non-head (name, parameter) pair.
last_param = param_list[-1]
last_param

('norm.bias',
 Parameter containing:
 tensor([ 7.7214e-03,  2.4394e-01, -1.0672e-01, -2.9516e-01, -1.0609e-01,
          3.2042e-02, -1.0961e-01,  2.6324e-01,  2.4436e-01, -3.8204e-01,
          1.0117e-01, -2.1206e-02,  1.2525e-01, -8.9974e-02,  9.6729e-02,
          1.7676e-01, -3.1723e-02, -1.2963e-01, -3.1566e-02, -1.8053e-02,
         -3.3319e-02, -1.1454e-01, -1.3713e-01, -1.3234e-01,  1.8526e-01,
         -2.5303e-02,  1.2068e-01, -3.8178e-01,  6.4466e-02, -5.2894e-02,
          2.1485e-02, -5.0624e-01, -1.3696e-01, -1.8588e-02,  1.1969e-01,
         -3.7775e-01, -1.3086e-01, -1.5258e-02,  1.1189e-01, -1.4449e-01,
          2.3175e-01,  8.4147e-02, -8.4270e-02,  2.3193e-01,  1.3954e-01,
          8.0393e-02,  1.7572e-02, -1.0284e-01, -1.4451e-01, -3.1911e-01,
          3.5000e-01, -2.2037e-01,  2.6579e-01,  1.1214e-02,  5.5485e-02,
         -2.2960e-03, -5.3662e-02, -3.8542e-02, -1.7256e-02,  2.9221e-01,
         -3.2865e-01, -8.0416e-02, -5.9403e-02,  1.8221e-01, -9.7802e-03,
 