## Import Libraries

In [None]:
import os
import time
import argparse
import random
import timm
import numpy as np
from PIL import Image
# from tqdm.notebook import tqdm
from tqdm import tqdm
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from utils import *
from model import *

In [None]:
# Seed every RNG used during training so runs are reproducible.
seed = 42

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
# cuDNN autotuning (benchmark=True) may select nondeterministic kernels,
# which defeats `deterministic=True`; keep it off for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Hyperparameters

In [None]:
# os.environ['CUDA_VISIBLE_DEVICES']='0'
# device = "cpu" 
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Optimisation schedule
num_epochs, batch_size = 30, 32
lr = 3e-4           # initial Adam learning rate
gamma = 0.7         # StepLR factor (the scheduler is currently disabled)
unfreeze_after = 2  # unfreeze more transformer blocks every 2 epochs
lr_decay = 0.8      # multiplicative lr decay applied at each unfreeze step
lmbd = 8            # forwarded to LATransformer as `lmbd`

## Load Data

In [None]:
# Training pipeline: resize to 224x224, random horizontal flip, tensor
# conversion, then normalisation with the standard ImageNet mean/std.
# interpolation=3 is PIL's BICUBIC filter.
transform_train_list = [
    transforms.Resize(size=(224, 224), interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
# Validation pipeline: same as training, minus the random flip.
transform_val_list = [
    transforms.Resize(size=(224, 224), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
data_transforms = dict(
    train=transforms.Compose(transform_train_list),
    val=transforms.Compose(transform_val_list),
)

In [None]:
# data_dir = "/home/shubham/CVP/TrainData_split/"
# data_dir = "/home/shubham/CVP/data/"

# image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
#                                           data_transforms['train'])
# image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
#                                           data_transforms['val'])

In [None]:
train_dir = "/home/shubham/CVP/data/train/"
# val_dir = "/home/shubham/CVP/data/val/all_imgs/"  (validation loading is disabled)

image_datasets = {
    'train': datasets.ImageFolder(train_dir, data_transforms['train']),
}
train_loader = DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True)

# Class labels come from the sub-folder names, e.g. '001', '003', ...
class_names = image_datasets['train'].classes
print(len(class_names))
# print(len(image_datasets['val'].classes)) # '001','003', etc

## Model

## Load Model

In [None]:
# Pretrained ViT-B/16 (224px) backbone from timm, placed directly on the training device.
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751).to(device)

In [None]:
# Create LA Transformer; size the classifier from the dataset's class folders
# rather than a hard-coded count, so it cannot silently mismatch the data.
model = LATransformer(ViT=vit_base, lmbd=lmbd, num_classes=len(class_names)).to(device)

### Utilities

In [None]:
def freeze_all_blocks(model):
    """Disable gradients for every transformer block in ``model.model.blocks``.

    Exactly 12 blocks are expected; the assertion fails for any other count.
    """
    blocks = model.model.blocks
    assert len(blocks) == 12
    block_params = (p for blk in blocks for p in blk.parameters())
    for p in block_params:
        p.requires_grad = False

def unfreeze_block(model, block_num = 1):
    """Re-enable gradients for the trailing transformer blocks of ``model.model``.

    Unfreezes ``blocks[last_index - block_num:]``. For 12 blocks this is
    ``blocks[11 - block_num:]``, i.e. ``block_num=1`` unfreezes the final two
    blocks. The start index is clamped at 0, so a large ``block_num`` unfreezes
    every block instead of wrapping around to the end of the list.

    Args:
        model: Module exposing ``model.model.blocks`` (an indexable sequence
            of sub-modules).
        block_num: Controls how many trailing blocks are unfrozen.

    Returns:
        The same ``model`` instance, modified in place.
    """
    blocks = model.model.blocks
    # Without the clamp, block_num >= len(blocks) gives a negative start, and
    # the slice would then select only the last few blocks.
    start = max(len(blocks) - 1 - block_num, 0)
    for block in blocks[start:]:
        for param in block.parameters():
            param.requires_grad = True
    return model

def save_network(network, model_dir, name):
    """Save ``network``'s state_dict to ``<model_dir>/<name>.pth`` as CPU tensors.

    The network is temporarily moved to the CPU so that the checkpoint can be
    loaded without a GPU. Afterwards it is moved back to the device it was on,
    even if saving raises.
    """
    import itertools

    save_path = os.path.join(model_dir, name + ".pth")
    # Remember where the model lives so it can be restored exactly. This avoids
    # unconditionally moving it to the default CUDA device.
    first_tensor = next(itertools.chain(network.parameters(), network.buffers()), None)
    original_device = first_tensor.device if first_tensor is not None else None
    try:
        torch.save(network.cpu().state_dict(), save_path)
    finally:
        if original_device is not None:
            network.to(original_device)

## Train

In [None]:
def train_one_epoch(epoch, model, loader, optimizer, loss_fn):
    """Train ``model`` for one pass over ``loader`` and return epoch metrics.

    ``model(data)`` must return a mapping of name -> logits of shape
    (batch, num_classes); presumably one classifier head per local part
    (confirm against LATransformer). Predictions are the argmax of the
    softmax probabilities summed over all entries. The optimized loss is the
    sum of every loss function applied to every entry.

    Args:
        epoch: Current epoch index (not used inside; kept for call-site symmetry).
        model: Module whose forward returns a mapping of logits tensors.
        loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: A single loss callable or an iterable of them; each one is
            called as ``fn(logits, target)``.

    Returns:
        OrderedDict with ``train_loss`` (batch losses weighted by batch size,
        averaged over all samples) and ``train_accuracy`` (fraction of samples
        predicted correctly). Both are 0.0 for an empty loader.
    """
    model.train()
    # Accept either a single loss callable or a collection of them.
    loss_fns = [loss_fn] if callable(loss_fn) else list(loss_fn)
    softmax = nn.Softmax(dim=1)
    running_loss, total_samples, correct_predictions = 0.0, 0, 0

    for data, target in tqdm(loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        # Ensemble prediction: class probabilities summed over every output head.
        with torch.no_grad():
            score = sum(softmax(logits) for logits in output.values())
            preds = score.argmax(dim=1)

        # Backpropagate through the whole ensemble: every loss on every head.
        loss = sum(fn(logits, target) for fn in loss_fns for logits in output.values())
        loss.backward()
        optimizer.step()

        n = data.size(0)
        # Assuming each loss is a batch mean (true for the criteria defined in
        # this notebook), weighting by n yields a per-sample epoch average.
        running_loss += loss.item() * n
        correct_predictions += preds.eq(target).sum().item()
        total_samples += n

    if total_samples == 0:
        return OrderedDict([('train_loss', 0.0), ("train_accuracy", 0.0)])
    return OrderedDict([
        ('train_loss', running_loss / total_samples),
        ("train_accuracy", correct_predictions / total_samples),
    ])

In [None]:
# def eval_one_epoch(epoch, model, loader, loss_fn):
#     model.eval()
#     epoch_accuracy, epoch_loss = 0, 0
#     total_samples, correct_predictions = 0, 0
#     with torch.no_grad():
#         for data, target in tqdm(loader):
#             data, target = data.to(device), target.to(device)

#             # predictions
#             output = model(data)
#             score = 0.0
#             sm = nn.Softmax(dim=1)
#             for k, v in output.items():
#                 score += sm(output[k])
#             _, preds = torch.max(score.data, 1)

#             # backpropagation through ensemble
#             loss = 0.0
#             for k,v in output.items():
#                 loss += loss_fn(output[k], target)

#             epoch_loss += (loss.item()/data.shape[0])
#             correct_predictions += (preds.eq(target.data).sum().item())
#             total_samples += data.size(0)
#             epoch_accuracy = correct_predictions/total_samples
#             # print(f"Epoch : {epoch}; loss : {epoch_loss:.4f}; acc: {epoch_accuracy:.4f}", end="\r")

#     # print("total_samples", total_samples, "correct", correct_predictions)
#     epoch_loss /= len(loader)
#     return OrderedDict([('val_loss', epoch_loss), ("val_accuracy", epoch_accuracy)])

In [None]:
model_name = "la-tf++_final"
model_dir = "/home/shubham/CVP/model/"
# makedirs also creates missing parent directories and tolerates an existing
# directory. The former exists-check + os.mkdir pair handled neither case.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Begin with every transformer block frozen; the training loop unfreezes
# trailing blocks progressively, and this counter tracks how far it has gone.
freeze_all_blocks(model)
unfreeze_block_id = 0

In [None]:
class TripletLoss(nn.Module):
    """Triplet loss with batch-hard positive/negative mining.

    For each anchor in the batch, the farthest same-label sample (hardest
    positive) and the closest different-label sample (hardest negative) are
    selected. A margin ranking loss then pushes the negative at least
    ``margin`` farther away than the positive.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
    Args:
        margin (float): margin for triplet.
        mutual_flag (bool): if True, ``forward`` also returns the pairwise
            distance matrix.
    """
    def __init__(self, margin=0.3, mutual_flag = False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.mutual = mutual_flag

    def forward(self, inputs, targets):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size,)

        Returns:
            Scalar loss, or ``(loss, dist)`` when ``mutual_flag`` is set, where
            ``dist`` is the (batch_size, batch_size) Euclidean distance matrix.
            Anchors with no negative in the batch contribute zero loss.
        """
        n = inputs.size(0)
        # Pairwise Euclidean distance: ||a||^2 + ||b||^2 - 2 a.b
        sq_norms = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = sq_norms + sq_norms.t()
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2) overload is deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # mask[i, j] is True when samples i and j share a label (diagonal included).
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        # Hardest positive: largest same-label distance. Row i always contains i
        # itself, so the result is finite.
        dist_ap = dist.masked_fill(~mask, float('-inf')).max(dim=1).values
        # Hardest negative: smallest different-label distance, or +inf if the
        # batch holds no negative for this anchor (the hinge then clamps to 0).
        dist_an = dist.masked_fill(mask, float('inf')).min(dim=1).values
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        if self.mutual:
            return loss, dist
        return loss

In [None]:
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes K in the smoothing term; should
            equal the number of logits per sample.
        epsilon (float): smoothing weight.
        use_gpu (bool): kept for backward compatibility only. The smoothed
            targets are always created on the same device as the logits.
    """
    def __init__(self, num_classes=62, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels (int64) with shape (batch_size,)

        Returns:
            Scalar loss: the batch mean of the smoothed-target cross entropy.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot targets directly on the logits' device. The old code
        # round-tripped through the CPU and then called `.cuda()`, which crashed
        # on CPU-only machines.
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (- smoothed * log_probs).mean(0).sum()
        return loss

In [None]:
# loss function
# criterion = nn.CrossEntropyLoss()
# criterion = TripletLoss()
# criterion = CrossEntropyLabelSmooth()
# Label-smoothed CE: K matches the dataset's class count, and the GPU path is
# only requested when training actually runs on CUDA (the default use_gpu=True
# would call .cuda() on CPU-only machines).
criterion = [
    CrossEntropyLabelSmooth(num_classes=len(class_names), use_gpu=(device == "cuda")),
    TripletLoss(),
]

# optimizer
optimizer = optim.Adam(model.parameters(), weight_decay=5e-4, lr=lr)

# # scheduler
# scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [None]:
print("training...")
for epoch in range(num_epochs):
    # if epoch == num_epochs//2:
    #    criterion = TripletLoss()

    # Every `unfreeze_after` epochs (starting at epoch 0), unfreeze further
    # trailing transformer blocks and decay the learning rate.
    if epoch % unfreeze_after == 0: # and epoch != 0:
        unfreeze_block_id += 1
        model = unfreeze_block(model, unfreeze_block_id)
        # Decay every parameter group, not only the first one.
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # print(f"Unfrozen Blocks: {unfreeze_block_id}, Current lr: {optimizer.param_groups[0]['lr']}, Trainable Params: {trainable_params}")

    train_metrics = train_one_epoch(epoch, model, train_loader, optimizer, criterion)
    # val_metrics = eval_one_epoch(epoch, model, valid_loader, criterion)
    ta = train_metrics['train_accuracy']
    tl = train_metrics['train_loss']
    # va = val_metrics['val_accuracy']
    # vl = val_metrics['val_loss']

    print(f"Epoch : {epoch}; trainacc : {ta:.4f}")
    # print(f"Epoch : {epoch}; trainacc : {ta:.4f}; valacc: {va:.4f}", end="\r")
    # print(f"Epoch : {epoch}; trainacc : {ta:.4f}; valacc: {va:.4f}", end="\r")

In [None]:
# Persist the trained weights and report where they were written.
save_network(model, model_dir, model_name)
print(f"{model_name} saved at {model_dir}")

### Appendix

In [None]:
# vit_base.head.requires_grad_

In [None]:
# x,y = next(iter(train_loader))
# print(x.shape, y.shape)

In [None]:
# print(x.shape)
# x = vit_base.patch_embed(x)
# print(x.shape)
# print()

# print(vit_base.cls_token.shape, vit_base.pos_embed.shape)
# cls_token = vit_base.cls_token.expand(x.shape[0], -1, -1) 
# print(cls_token.shape)
# x = torch.cat((cls_token, x), dim=1)
# print(x.shape)
# x = vit_base.pos_drop(x + vit_base.pos_embed)
# print(x.shape)
# print()

# # Feed forward the x = (patch_embeddings+position_embeddings) through transformer blocks
# # for i in range(12):
# x = vit_base.blocks(x)
# x = vit_base.norm(x) # layer normalization
# print(x.shape)

In [None]:
# # extract the cls token
# cls_token_out = x[:, 0].unsqueeze(1)
# print(cls_token_out.shape)

# # Average pool
# avgpool = nn.AdaptiveAvgPool2d(output_size = (14, 768))
# print(x.shape)
# x = avgpool(x[:, 1:]) # input is 32,196,768
# print(x.shape)

In [None]:
# t = torch.tensor([1.0,2])
# print(t)
# t.requires_grad_ = True
# print(t)
# t.requires_grad = True
# print(t)