In [174]:
import torch

In [175]:
from torchvision import transforms

In [176]:
from PIL import Image

In [177]:
import numpy as np

In [178]:
import matplotlib.pyplot as plt

In [179]:
import pickle

In [180]:
from timm import models

In [181]:
from collections import OrderedDict

In [182]:
device = torch.device("cuda")
print(device)

cuda


In [183]:
# Load model checkpoint
checkpoint_path = "/home/tun78940/tcam/tcam_training/traffickcam_model_training/models/latest_checkpoint.pth.tar"
model = models.vit_base_patch16_224_in21k(pretrained=True)
model.head = torch.nn.Identity()
model.cuda()

model = torch.nn.DataParallel(model)

checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['state_dict']

model.load_state_dict(state_dict)

model.eval()

DataParallel(
  (module): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=3072

In [184]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [185]:
with open("/home/tun78940/tcam/tcam_training/traffickcam_model_training/src/gallery_imgs.dat", 'rb') as f:
    gallery_paths = pickle.load(f)
    f.close()

In [None]:
gallery_embeddings = []
for path in gallery_paths:
    img = Image.open(path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model(img_tensor)
    gallery_embeddings.append(embedding.squeeze().cpu().numpy())

gallery_embeddings = np.array(gallery_embeddings)

In [None]:
# Take a query image in the validation set and embed
query_path = "/shared/data/Traffickcam/full-apr20/03/0374/2619950_37452.jpg"

In [None]:
query_img = Image.open(query_path).convert('RGB')
query_img_tensor = transform(query_img).unsqueeze(0).to(device)
with torch.no_grad():
    query_embedding = model(query_img_tensor).squeeze().cpu().numpy()

In [None]:
# Perform nearest neighbor search
distances = np.linalg.norm(gallery_embeddings - query_embedding, axis=1)
nearest_indices = np.argsort(distances)[:5]  # Change the number 5 to the desired number of neighbors

In [None]:
# Visual inspection of query and nearest neighbors
fig, axes = plt.subplots(1, 6, figsize=(15, 5))
axes[0].imshow(query_img)
axes[0].set_title("Query")

In [None]:
for i, index in enumerate(nearest_indices):
    img = Image.open(gallery_paths[index])
    axes[i+1].imshow(img)
    axes[i+1].set_title(f"Neighbor {i+1}")

plt.tight_layout()
plt.show()