In [67]:
from __future__ import print_function

import os
import time
import glob
import random
import zipfile
from itertools import chain

import timm
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from LATransformer.model import ClassBlock, LATransformer, LATransformerTest
from LATransformer.utils import save_network, update_summary, get_id
from LATransformer.metrics import rank1, rank5, rank10, calc_map

os.environ['CUDA_VISIBLE_DEVICES']='0'
device = 'cuda'

## Config Parameters

In [68]:
batch_size = 8
gamma = 0.7
seed = 42

## Load Model

In [72]:
# Load ViT
vit_base = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=751)
vit_base= vit_base.to(device)

# Create La-Transformer
model = LATransformerTest(vit_base, lmbd=8).to(device)

# Load LA-Transformer
name = "occluded_duke_wo_augm"
save_path = os.path.join('./weights',name,'model_state_dict_small.pth')
model.load_state_dict(torch.load(save_path), strict=False)
model.eval()

LATransformerTest(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU()
          (d

# Calculation of GFLOPs

In [77]:
import torch
import torchvision.models as models
from thop import profile

# Load the pretrained ResNet-50 model
#model = models.resnet101(pretrained=True)

# Set the input size (224x224) and the batch size
input_size = 224
batch_size = 1

# Create a random input tensor with the appropriate size
input_tensor = torch.randn(batch_size, 3, input_size, input_size).to(device)

# Calculate the number of FLOPs and parameters in the model
flops, params = profile(model, inputs=(input_tensor,))
print("Parameters: ",params)
# Convert FLOPs to GFLOPS (billions of floating-point operations per second)
gflops = flops / 1e9

print(f"Total FLOPs: {flops:.2f} FLOPs")
print(f"GFLOPS: {gflops:.2f} GFLOPS")

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
Parameters:  21589632.0
Total FLOPs: 12745222272.00 FLOPs
GFLOPS: 12.75 GFLOPS




### DataLoader

In [4]:
transform_query_list = [
    transforms.Resize((224,224), interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
transform_gallery_list = [
    transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC
    transforms.ToTensor(),
    transforms.RandomErasing(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
data_transforms = {
'query': transforms.Compose( transform_query_list ),
'gallery': transforms.Compose(transform_gallery_list),
}



In [5]:
image_datasets = {}
data_dir = " " #import data folder

image_datasets['query'] = datasets.ImageFolder(os.path.join(data_dir, 'query'),
                                          data_transforms['query'])
image_datasets['gallery'] = datasets.ImageFolder(os.path.join(data_dir, 'gallery'),
                                          data_transforms['gallery'])
query_loader = DataLoader(dataset = image_datasets['query'], batch_size=batch_size, shuffle=False )
gallery_loader = DataLoader(dataset = image_datasets['gallery'], batch_size=batch_size, shuffle=False)

class_names = image_datasets['query'].classes
print(len(class_names))

750


###  Extract Features

In [6]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [7]:
def extract_feature(model,dataloaders):
    
    features =  torch.FloatTensor()
    count = 0
    idx = 0
    for data in tqdm(dataloaders):
        img, label = data
        img, label = img.to(device), label.to(device)

        output = model(img)

        n, c, h, w = img.size()
        
        count += n
        features = torch.cat((features, output.detach().cpu()), 0)
        idx += 1
    return features

In [8]:
# Extract Query Features
query_feature= extract_feature(model, query_loader)

# Extract Gallery Features
gallery_feature = extract_feature(model, gallery_loader)

  0%|          | 0/421 [00:00<?, ?it/s]

  0%|          | 0/1990 [00:00<?, ?it/s]

In [9]:
# Retrieve labels
gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)

## Concat Averaged GELTs

In [10]:
concatenated_query_vectors = []
for query in tqdm(query_feature):
   
    fnorm = torch.norm(query, p=2, dim=1, keepdim=True)*np.sqrt(14)
   
    query_norm = query.div(fnorm.expand_as(query))
    
    concatenated_query_vectors.append(query_norm.view((-1))) # 14*768 -> 10752

concatenated_gallery_vectors = []
for gallery in tqdm(gallery_feature):
   
    fnorm = torch.norm(gallery, p=2, dim=1, keepdim=True) *np.sqrt(14)
   
    gallery_norm = gallery.div(fnorm.expand_as(gallery))
    
    concatenated_gallery_vectors.append(gallery_norm.view((-1))) # 14*768 -> 10752
  

  0%|          | 0/3368 [00:00<?, ?it/s]

  0%|          | 0/15913 [00:00<?, ?it/s]

## Calculate Similarity using FAISS

In [11]:
import faiss
import numpy as np

In [12]:
#index = faiss.IndexIDMap(faiss.IndexFlatIP(10752))

index = faiss.IndexIDMap(faiss.IndexFlatIP(10752))
index2 = faiss.IndexIDMap(index)


encoded_data = np.asarray([t.numpy() for t in concatenated_gallery_vectors]).astype('float32')
print(encoded_data.shape)
ids = np.array(gallery_label)
print(ids.shape)
ids = np.asarray(ids.astype('int64'))
print(ids.shape)
#"IDMap,Flat"
index.add_with_ids(encoded_data,ids)
#print("encoded_data:", ids)

(15913, 2688)
(15913,)
(15913,)


AssertionError: 

In [14]:
#np.asarray(encoded_data).astype('float32')
def search(query: str, k=1):
    encoded_query = query.unsqueeze(dim=0).numpy()
    top_k = index.search(encoded_query, k)
    return top_k

In [15]:
rank1_score = 0
rank5_score = 0
rank10_score = 0
ap = 0
count = 0
for query, label in zip(concatenated_query_vectors, query_label):
    count += 1
    label = label
    output = search(query, k=10)
    rank1_score += rank1(label, output) 
    rank5_score += rank5(label, output) 
    rank10_score += rank10(label, output) 
    print("Correct: {}, Total: {}, Incorrect: {}".format(rank1_score, count, count-rank1_score), end="\r")
    ap += calc_map(label, output)

print("Rank1: {}, Rank5: {}, Rank10: {}, mAP: {}".format(rank1_score/len(query_feature), 
                                                         rank5_score/len(query_feature), 
                                                         rank10_score/len(query_feature), ap/len(query_feature)))    

AssertionError: 