## Import Libraries

In [1]:
from __future__ import print_function

import os
import time
import random
import zipfile
from itertools import chain

import timm
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from LATransformer.model import ClassBlock, LATransformer
from LATransformer.utils import save_network, update_summary

os.environ['CUDA_VISIBLE_DEVICES']='1'
device = "cuda"

### Set Config Parameters

In [2]:
batch_size = 32
num_epochs = 30
lr = 3e-4
gamma = 0.7
unfreeze_after=2
lr_decay=.8
lmbd = 8

## Load Data

In [3]:
transform_train_list = [
    transforms.Resize((224,224), interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
transform_val_list = [
    transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
data_transforms = {
'train': transforms.Compose( transform_train_list ),
'val': transforms.Compose(transform_val_list),
}

In [4]:
image_datasets = {}
data_dir = "data/Market-Pytorch/Market/"

image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                          data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          data_transforms['val'])
train_loader = DataLoader(dataset = image_datasets['train'], batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = image_datasets['val'], batch_size=batch_size, shuffle=True)
# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
#                                              shuffle=True, num_workers=8, pin_memory=True) # 8 workers may work faster
#               for x in ['train', 'val']}
# dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(len(class_names))

751


## Load Model

In [5]:
# Load pre-trained ViT
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
vit_base= vit_base.to(device)
vit_base.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (



###  Train

In [6]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [7]:
def validate(model, loader, loss_fn):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()
    epoch_accuracy = 0
    epoch_loss = 0
    end = time.time()
    last_idx = len(loader) - 1
    
    running_loss = 0.0
    running_corrects = 0.0

    with torch.no_grad():
        for input, target in tqdm(loader):

            input, target = input.to(device), target.to(device)
            
            output = model(input)
            
            score = 0.0
            sm = nn.Softmax(dim=1)
            for k, v in output.items():
                score += sm(output[k])
            _, preds = torch.max(score.data, 1)

            loss = 0.0
            for k,v in output.items():
                loss += loss_fn(output[k], target)


            batch_time_m.update(time.time() - end)
            acc = (preds == target.data).float().mean()
            epoch_loss += loss/len(loader)
            epoch_accuracy += acc / len(loader)
            
            print(f"Epoch : {epoch+1} - val_loss : {epoch_loss:.4f} - val_acc: {epoch_accuracy:.4f}", end="\r")
    print()    
    metrics = OrderedDict([('val_loss', epoch_loss.data.item()), ("val_accuracy", epoch_accuracy.data.item())])


    return metrics

In [8]:
def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn,
        lr_scheduler=None, saver=None, output_dir='', 
        loss_scaler=None, model_ema=None, mixup_fn=None):

 

    
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    epoch_accuracy = 0
    epoch_loss = 0
    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    running_loss = 0.0
    running_corrects = 0.0

    for data, target in tqdm(loader):
        data, target = data.to(device), target.to(device)

            
        data_time_m.update(time.time() - end)

        optimizer.zero_grad()
        output = model(data)
        score = 0.0
        sm = nn.Softmax(dim=1)
        for k, v in output.items():
            score += sm(output[k])
        _, preds = torch.max(score.data, 1)
        
        loss = 0.0
        for k,v in output.items():
            loss += loss_fn(output[k], target)
        loss.backward()

        optimizer.step()

        batch_time_m.update(time.time() - end)
        
#         print(preds, target.data)
        acc = (preds == target.data).float().mean()
        
#         print(acc)
        epoch_loss += loss/len(loader)
        epoch_accuracy += acc / len(loader)
#         if acc:
#             print(acc, epreds, target.data)
        print(
    f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}"
, end="\r")

    print()

    return OrderedDict([('train_loss', epoch_loss.data.item()), ("train_accuracy", epoch_accuracy.data.item())])


In [9]:
def freeze_all_blocks(model):
    frozen_blocks = 12
    for block in model.model.blocks[:frozen_blocks]:
        for param in block.parameters():
            param.requires_grad=False
    

In [10]:
def unfreeze_blocks(model, amount= 1):
    
    for block in model.model.blocks[11-amount:]:
        for param in block.parameters():
            param.requires_grad=True
    return model

## Training Loop

In [11]:
# Create LA Transformer
model = LATransformer(vit_base, lmbd).to(device)
print(model.eval())

# loss function
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.Adam(model.parameters(),weight_decay=5e-4, lr=lr)

# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
freeze_all_blocks(model)

LATransformer(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768

In [None]:
best_acc = 0.0
y_loss = {} # loss history
y_loss['train'] = []
y_loss['val'] = []
y_err = {}
y_err['train'] = []
y_err['val'] = []
print("training...")
output_dir = ""
best_acc = 0
name = "la_with_lmbd_{}".format(lmbd)

try:
    os.mkdir("model/" + name)

except:
    pass
output_dir = "model/" + name
unfrozen_blocks = 0

for epoch in range(num_epochs):

    if epoch%unfreeze_after==0:
        unfrozen_blocks += 1
        model = unfreeze_blocks(model, unfrozen_blocks)
        optimizer.param_groups[0]['lr'] *= lr_decay 
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("Unfrozen Blocks: {}, Current lr: {}, Trainable Params: {}".format(unfrozen_blocks, 
                                                                             optimizer.param_groups[0]['lr'], 
                                                                             trainable_params))

    train_metrics = train_one_epoch(
        epoch, model, train_loader, optimizer, criterion,
        lr_scheduler=None, saver=None)

    eval_metrics = validate(model, valid_loader, criterion)


    # update summary
    update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                   write_header=True)

    # deep copy the model
    last_model_wts = model.state_dict()
    if eval_metrics['val_accuracy'] > best_acc:
        best_acc = eval_metrics['val_accuracy']
        save_network(model, epoch,name)
        print("SAVED!")

training...
Unfrozen Blocks: 1, Current lr: 0.00023999999999999998, Trainable Params: 20962817


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 1 - loss : 82.7351 - acc: 0.0880



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 1 - val_loss : 77.1901 - val_acc: 0.0497

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 2 - loss : 59.0334 - acc: 0.2364



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 2 - val_loss : 58.8111 - val_acc: 0.1918

SAVED!
Unfrozen Blocks: 2, Current lr: 0.000192, Trainable Params: 28050689


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 3 - loss : 41.1694 - acc: 0.4632



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 3 - val_loss : 47.2650 - val_acc: 0.3353

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 4 - loss : 28.3517 - acc: 0.6674



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 4 - val_loss : 33.9487 - val_acc: 0.5391

SAVED!
Unfrozen Blocks: 3, Current lr: 0.00015360000000000002, Trainable Params: 35138561


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 5 - loss : 18.7140 - acc: 0.8141



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 5 - val_loss : 25.3060 - val_acc: 0.6617

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 6 - loss : 12.2253 - acc: 0.9050



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 6 - val_loss : 19.0367 - val_acc: 0.7506

SAVED!
Unfrozen Blocks: 4, Current lr: 0.00012288000000000002, Trainable Params: 42226433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 7 - loss : 8.0031 - acc: 0.9542



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 7 - val_loss : 14.0309 - val_acc: 0.8325

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 8 - loss : 5.4122 - acc: 0.9771



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 8 - val_loss : 11.0224 - val_acc: 0.8602

SAVED!
Unfrozen Blocks: 5, Current lr: 9.830400000000001e-05, Trainable Params: 49314305


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 9 - loss : 3.7149 - acc: 0.9906



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 9 - val_loss : 8.5832 - val_acc: 0.8944

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 10 - loss : 2.7142 - acc: 0.9950



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 10 - val_loss : 7.6481 - val_acc: 0.9033

SAVED!
Unfrozen Blocks: 6, Current lr: 7.864320000000001e-05, Trainable Params: 56402177


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 11 - loss : 2.0092 - acc: 0.9965



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 11 - val_loss : 6.7372 - val_acc: 0.9137

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 12 - loss : 1.5912 - acc: 0.9977



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 12 - val_loss : 6.0404 - val_acc: 0.9189

SAVED!
Unfrozen Blocks: 7, Current lr: 6.291456000000001e-05, Trainable Params: 63490049


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 13 - loss : 1.3100 - acc: 0.9984



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 13 - val_loss : 5.8097 - val_acc: 0.9230

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 14 - loss : 1.0894 - acc: 0.9991



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 14 - val_loss : 5.1302 - val_acc: 0.9321

SAVED!
Unfrozen Blocks: 8, Current lr: 5.0331648000000016e-05, Trainable Params: 70577921


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 15 - loss : 0.9347 - acc: 0.9992



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 15 - val_loss : 5.5233 - val_acc: 0.9217



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 16 - loss : 0.9086 - acc: 0.9996



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 16 - val_loss : 4.4655 - val_acc: 0.9362

SAVED!
Unfrozen Blocks: 9, Current lr: 4.026531840000002e-05, Trainable Params: 77665793


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 17 - loss : 0.7159 - acc: 0.9999



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 17 - val_loss : 4.2927 - val_acc: 0.9414

SAVED!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 18 - loss : 0.6362 - acc: 0.9998



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 18 - val_loss : 4.2925 - val_acc: 0.9453

SAVED!
Unfrozen Blocks: 10, Current lr: 3.221225472000002e-05, Trainable Params: 84753665


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 19 - loss : 0.6389 - acc: 0.9997



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 19 - val_loss : 4.5622 - val_acc: 0.9319



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 20 - loss : 0.5667 - acc: 0.9998



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 20 - val_loss : 4.6590 - val_acc: 0.9254

Unfrozen Blocks: 11, Current lr: 2.5769803776000016e-05, Trainable Params: 91841537


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 21 - loss : 0.5401 - acc: 0.9998



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 21 - val_loss : 3.8805 - val_acc: 0.9401



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 22 - loss : 0.6303 - acc: 0.9991



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 22 - val_loss : 4.4941 - val_acc: 0.9375

Unfrozen Blocks: 12, Current lr: 2.0615843020800013e-05, Trainable Params: 91841537


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 23 - loss : 0.5186 - acc: 0.9997



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 23 - val_loss : 4.0348 - val_acc: 0.9435



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 24 - loss : 0.4421 - acc: 0.9999



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 24 - val_loss : 3.6783 - val_acc: 0.9464

SAVED!
Unfrozen Blocks: 13, Current lr: 1.649267441664001e-05, Trainable Params: 91841537


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 25 - loss : 0.4184 - acc: 1.0000



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 25 - val_loss : 3.9668 - val_acc: 0.9425



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 26 - loss : 0.4113 - acc: 1.0000



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 26 - val_loss : 3.9590 - val_acc: 0.9398

Unfrozen Blocks: 14, Current lr: 1.319413953331201e-05, Trainable Params: 91841537


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 27 - loss : 0.3976 - acc: 1.0000



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 27 - val_loss : 3.8370 - val_acc: 0.9414



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 28 - loss : 0.3917 - acc: 1.0000



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 28 - val_loss : 3.8097 - val_acc: 0.9422

Unfrozen Blocks: 15, Current lr: 1.0555311626649608e-05, Trainable Params: 91841537


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))

Epoch : 29 - loss : 0.3875 - acc: 1.0000



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))

Epoch : 29 - val_loss : 0.5044 - val_acc: 0.1576