In [2]:
import os
import time
import random

import timm
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
# from tqdm import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from model import ClassBlock, LATransformer, LATransformerTest
from utils import save_network, update_summary, get_id

## Hyperparameters

In [3]:
seed = 42

# Seed every random number generator the notebook touches so that repeated
# runs start from the same state.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels. Benchmark mode must be off as well:
# with benchmarking enabled cuDNN times several algorithms and may pick a
# different one from run to run, which defeats the deterministic setting.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (Set CUDA_VISIBLE_DEVICES before starting the kernel to pin a specific GPU.)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# batch_size feeds the two DataLoaders and lmbd is handed to LATransformer;
# lr and gamma are not used by the evaluation cells visible in this notebook.
batch_size, lr, gamma, lmbd = 8, 3e-4, 0.7, 8

## DataLoader

In [5]:
# Both splits are resized to 224x224 with bicubic resampling, converted to
# tensors and normalized with the same per-channel mean / std.
# The interpolation is given as an InterpolationMode member: the integer
# form (3 == bicubic) is deprecated and triggers a torchvision warning.
transform_query_list = [
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    # NOTE(review): a random flip at evaluation time makes the query features
    # depend on the RNG state - confirm this is intentional.
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
transform_gallery_list = [
    transforms.Resize(size=(224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
# One composed pipeline per split, keyed by the split name.
data_transforms = {
    'query': transforms.Compose(transform_query_list),
    'gallery': transforms.Compose(transform_gallery_list),
}

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [6]:
# Root folder holding one sub-folder per split ('query' and 'gallery').
data_dir = "/home/shubham/CVP/data/val"
# data_dir = "/home/shubham/CVP/test"

# One ImageFolder per split, each paired with the transform of the same name.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('query', 'gallery')
}

# Keep the on-disk order (no shuffling) so features line up with `.imgs`.
query_loader = DataLoader(image_datasets['query'], batch_size=batch_size, shuffle=False)
gallery_loader = DataLoader(image_datasets['gallery'], batch_size=batch_size, shuffle=False)

# Number of class folders found in the query split.
class_names = image_datasets['query'].classes
print(len(class_names))

12


## Model

In [7]:
class LATransformer(nn.Module):
    """Locally aware head on top of a ViT backbone.

    The backbone encodes the image into one cls token plus patch tokens. The
    patch tokens are average-pooled into ``num_rows`` local tokens; in
    training mode each local token is blended with the cls token and sent to
    its own classifier.

    NOTE(review): this definition shadows the ``LATransformer`` imported from
    ``model`` at the top of the notebook - confirm which one is intended.

    Args:
        ViT: backbone module exposing ``patch_embed``, ``cls_token``,
            ``pos_embed``, ``pos_drop``, ``blocks``, ``norm`` and ``head``.
        lmbd: weight given to a local token relative to the cls token when
            the two are blended, ``(cls + lmbd * local) / (1 + lmbd)``.
        num_classes: number of classes passed to every per-row classifier.
        test: when True no classifiers are created and ``forward`` returns
            the pooled local tokens instead of predictions.
    """

    def __init__(self, ViT, lmbd, num_classes=751, test=False):
        super(LATransformer, self).__init__()
        self.test = test
        self.class_num = num_classes  # output number of classes

        # ViT model
        self.model = ViT
        # Freeze the backbone's own classification head; it is never used by
        # forward(). requires_grad_ is a method and has to be *called* -
        # assigning False to it only replaced the method with a bool and left
        # the head trainable.
        self.model.head.requires_grad_(False)
        # Aliases of the backbone parameters (same objects, not copies).
        self.cls_token = self.model.cls_token  # 1, 1, 768
        self.pos_embed = self.model.pos_embed  # 1, 197, 768

        # There are 196 patches in each image; they form a 14 x 14 grid.
        self.num_rows = 14
        self.num_cols = 14

        # Locally aware network: pools the 196 patch tokens down to
        # num_rows tokens of width 768.
        self.avgpool = nn.AdaptiveAvgPool2d((self.num_rows, 768))
        self.lmbd = lmbd

        if not self.test:
            # Ensemble of classifiers, one per local token, registered as
            # classifier0 ... classifier13.
            # NOTE(review): FC_Classifier is neither defined nor imported in
            # the visible cells (only ClassBlock is imported from model) -
            # confirm it is in scope before building with test=False.
            for i in range(self.num_rows):
                name = 'classifier' + str(i)
                setattr(self, name, FC_Classifier(input_dim=768, num_classes=self.class_num, droprate=0.5, num_bottleneck=256, return_features=False))

    def forward(self, x):
        """Encode a batch of images.

        Shapes in the comments below are for a batch of B images with the
        vit_base_patch16_224 backbone.

        Returns:
            In test mode, the pooled local tokens (B, 14, 768). Otherwise a
            dict mapping the row index ``i`` to the output of ``classifier<i>``.
        """
        # Patch embeddings with the cls token prepended and position
        # embeddings added.
        tokens = self.model.patch_embed(x)  # B, 196, 768
        cls_token = self.cls_token.expand(tokens.shape[0], -1, -1)  # B, 1, 768
        tokens = torch.cat((cls_token, tokens), dim=1)  # B, 197, 768
        tokens = self.model.pos_drop(tokens + self.pos_embed)

        # Transformer blocks followed by the final layer normalization.
        encoded = self.model.norm(self.model.blocks(tokens))  # B, 197, 768

        # Global cls token, kept 3-D so it broadcasts over the local tokens
        # for any batch size (including 1).
        cls_token_out = encoded[:, 0:1]  # B, 1, 768

        # Average-pool the patch tokens into num_rows local tokens.
        L = self.avgpool(encoded[:, 1:])  # B, 14, 768

        if self.test:
            return L

        # Blend the global cls token into every local token. Done as one
        # broadcast expression rather than writing into L row by row.
        L = (cls_token_out + self.lmbd * L) / (1 + self.lmbd)

        # Locally aware network: each local token goes to its own classifier.
        predict = {}
        for i in range(self.num_rows):
            classifier = getattr(self, 'classifier' + str(i))
            predict[i] = classifier(L[:, i, :])  # input is B, 768
        return predict

## Load Model

In [8]:
# Load ViT
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
vit_base = vit_base.to(device)

# Create LA-Transformer in test mode (no per-row classifiers are built).
model = LATransformer(vit_base, lmbd=lmbd, num_classes=123, test=True).to(device)

# Load LA-Transformer weights.
save_path = "/home/shubham/CVP/model/la-tranformer_best.pth"
# map_location keeps the load working on a CPU-only machine even when the
# checkpoint was written from CUDA tensors; load_state_dict then copies the
# values into the parameters on `device`.
state_dict = torch.load(save_path, map_location="cpu")
# strict=False tolerates keys that do not line up between the checkpoint and
# this model. Keys the model needs but the checkpoint lacks are reported
# instead of being dropped silently.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("Warning: checkpoint is missing keys:", load_result.missing_keys)
model.eval()

LATransformer(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
      

## Extract Features

In [9]:
def extract_feature(model, dataloaders):
    """Run ``model`` over every batch and return all outputs on the CPU.

    Args:
        model: callable applied to each image batch.
        dataloaders: iterable yielding ``(images, labels)`` batches; the
            labels are ignored here.

    Returns:
        The per-batch outputs concatenated along dimension 0, or an empty
        ``FloatTensor`` when ``dataloaders`` yields nothing.
    """
    batch_features = []
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for img, _ in tqdm(dataloaders):
            output = model(img.to(device))
            batch_features.append(output.detach().cpu())

    if not batch_features:
        return torch.FloatTensor()
    # Concatenate once at the end instead of re-copying the growing tensor
    # on every iteration.
    return torch.cat(batch_features, 0)

In [10]:
# Features for every query image, in loader order.
query_feature = extract_feature(model=model, dataloaders=query_loader)

# Features for every gallery image, in loader order.
gallery_feature = extract_feature(model=model, dataloaders=gallery_loader)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/23 [00:00<?, ?it/s]

In [11]:
# (path, class index) pairs of both splits; get_id() below unpacks them.
query_path = image_datasets['query'].imgs
gallery_path = image_datasets['gallery'].imgs

In [12]:
# gallery_path

In [13]:
def get_id(img_path):
    """Split ``(path, label)`` pairs into camera ids and labels.

    The camera id is the integer in front of the first underscore of the
    file name (the part of the path after the last "/").

    NOTE(review): this redefines the ``get_id`` imported from ``utils`` at
    the top of the notebook.

    Args:
        img_path: iterable of ``(path, label)`` pairs.

    Returns:
        ``(camera_id, labels)`` - two lists of ints in input order.
    """
    camera_id, labels = [], []
    for path, label in img_path:
        filename = path.rsplit("/", 1)[-1]
        prefix, _, _ = filename.partition("_")
        cam_id = int(prefix)
        labels.append(int(label))
        camera_id.append(cam_id)
    return camera_id, labels

# Camera ids and labels for each split, in dataset order.
# NOTE(review): the camera ids are not used by the evaluation cells visible
# in this notebook.
query_cam, query_label = get_id(img_path=query_path)
gallery_cam, gallery_label = get_id(img_path=gallery_path)

## Concat Averaged GELTs

In [14]:
def _normalize_and_flatten(features):
    """L2-normalize each local token and flatten every sample to one vector.

    Args:
        features: tensor iterated sample by sample; each sample is a
            (num_parts, dim) matrix of local tokens (14 x 768 here).

    Returns:
        A list with one 1-D tensor of length num_parts * dim per sample
        (14 * 768 -> 10752 here).
    """
    flattened = []
    for parts in tqdm(features):
        # Each row is divided by its own L2 norm times sqrt(num_parts), so
        # the flattened vector as a whole has unit length.
        fnorm = torch.norm(parts, p=2, dim=1, keepdim=True) * np.sqrt(parts.shape[0])
        flattened.append(parts.div(fnorm.expand_as(parts)).view(-1))
    return flattened


# The same normalization is applied to both splits.
concatenated_query_vectors = _normalize_and_flatten(query_feature)
concatenated_gallery_vectors = _normalize_and_flatten(gallery_feature)

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/181 [00:00<?, ?it/s]

## Calculate Similarity using FAISS

In [15]:
import faiss
import numpy as np

# Gallery vectors as one float32 matrix and their labels as int64 ids, passed
# with explicit dtypes so the result does not depend on the platform's
# default integer width.
gallery_matrix = np.stack([t.numpy() for t in concatenated_gallery_vectors]).astype(np.float32, copy=False)
gallery_ids = np.asarray(gallery_label, dtype=np.int64)

# Inner-product index whose stored ids are the gallery *labels*, so a search
# returns the labels of the nearest gallery images directly. The dimension is
# taken from the data (14 * 768 = 10752 for this model).
index = faiss.IndexIDMap(faiss.IndexFlatIP(gallery_matrix.shape[1]))
index.add_with_ids(gallery_matrix, gallery_ids)

def search(query: torch.Tensor, k=1):
    """Look up the ``k`` nearest gallery entries for one query vector.

    Args:
        query: 1-D feature tensor for a single query image.
        k: number of neighbours to retrieve.

    Returns:
        Whatever ``index.search`` returns for a one-row batch; the metric
        functions in this notebook read the ids from ``result[1][0]``.
    """
    # The index expects a 2-D numpy batch, so add a leading batch dimension.
    # detach()/cpu() are no-ops for the CPU features produced above but keep
    # .numpy() valid for tensors that still carry grad or live on the GPU.
    encoded_query = query.detach().cpu().unsqueeze(dim=0).numpy()
    return index.search(encoded_query, k)

## Evaluation Metrics

In [16]:
def rank1(label, output):
    """Return True when the first retrieved id equals ``label``.

    ``output`` is a search result whose ids are read from ``output[1][0]``.
    """
    best_match = output[1][0][0]
    return bool(label == best_match)

def rank5(label, output):
    """Return True when ``label`` is among the first five retrieved ids.

    ``output`` is a search result whose ids are read from ``output[1][0]``.
    """
    top_five = output[1][0][:5]
    return label in top_five

def rank10(label, output):
    """Return True when ``label`` is among the first ten retrieved ids.

    ``output`` is a search result whose ids are read from ``output[1][0]``.
    """
    top_ten = output[1][0][:10]
    return label in top_ten

def calc_map(label, output):
    """Average precision of one query over its retrieved ids.

    Walks the ids in ``output[1][0]`` in rank order; at every position that
    matches ``label`` the precision so far (hits / position) is added. The
    sum is divided by the number of hits found in the retrieved list, so the
    score only reflects the results that were actually returned.

    Returns:
        The averaged precision, or 0 when no retrieved id matches.
    """
    hits = 0
    precision_total = 0
    for position, retrieved in enumerate(output[1][0], start=1):
        if retrieved == label:
            hits += 1
            precision_total += hits / position
    return precision_total / hits if hits else 0

In [17]:
# Running totals over all queries.
rank1_score = rank5_score = rank10_score = ap = count = 0

# Each query retrieves its 10 nearest gallery labels, which are then scored.
for count, (query, label) in enumerate(zip(concatenated_query_vectors, query_label), start=1):
    output = search(query, k=10)
    rank1_score += rank1(label, output)
    rank5_score += rank5(label, output)
    rank10_score += rank10(label, output)
    # Progress line, rewritten in place via the carriage return.
    print("Correct: {}, Total: {}, Incorrect: {}".format(rank1_score, count, count-rank1_score), end="\r")
    ap += calc_map(label, output)

# Average every accumulated score over the number of query images.
num_queries = len(query_feature)
print("Rank1: {}, Rank5: {}, Rank10: {}, mAP: {}".format(rank1_score/num_queries,
                                                         rank5_score/num_queries,
                                                         rank10_score/num_queries, ap/num_queries))

Correct: 0, Total: 1, Incorrect: 1Correct: 1, Total: 2, Incorrect: 1Correct: 2, Total: 3, Incorrect: 1Correct: 3, Total: 4, Incorrect: 1Correct: 4, Total: 5, Incorrect: 1Correct: 5, Total: 6, Incorrect: 1Correct: 6, Total: 7, Incorrect: 1Correct: 7, Total: 8, Incorrect: 1Correct: 8, Total: 9, Incorrect: 1Correct: 9, Total: 10, Incorrect: 1Correct: 10, Total: 11, Incorrect: 1Correct: 11, Total: 12, Incorrect: 1Correct: 12, Total: 13, Incorrect: 1Correct: 13, Total: 14, Incorrect: 1Correct: 14, Total: 15, Incorrect: 1Correct: 15, Total: 16, Incorrect: 1Correct: 16, Total: 17, Incorrect: 1Correct: 17, Total: 18, Incorrect: 1Correct: 18, Total: 19, Incorrect: 1Correct: 19, Total: 20, Incorrect: 1Correct: 20, Total: 21, Incorrect: 1Correct: 21, Total: 22, Incorrect: 1Correct: 22, Total: 23, Incorrect: 1Correct: 22, Total: 24, Incorrect: 2Correct: 22, Total: 25, Incorrect: 3Correct: 23, Total: 26, Incorrect: 3Correct: 24, Total: 27, Incorrect: 3Correct: 25, Total: 

### Experimental results

1. seed:42, epochs:10, Rank1: 0.9285714285714286, Rank5: 0.9285714285714286, Rank10: 0.9642857142857143, mAP: 0.9170920335726883
2. seed:42, epochs:15, Rank1: 0.9285714285714286, Rank5: 0.9642857142857143, Rank10: 0.9642857142857143, mAP: 0.9152374822283411
3. seed:0, epochs:15, Rank1: 0.8928571428571429, Rank5: 0.9285714285714286, Rank10: 1.0, mAP: 0.8852698817622288
