In [5]:
import json
import os
from datetime import datetime, timedelta

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from scipy.spatial.distance import cdist
from matplotlib import pyplot as plt

In [2]:
class VeriWildModel(nn.Module):
    def __init__(self):
        super(VeriWildModel, self).__init__()
        self.backbone = models.resnet50(pretrained=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_features, 1000) 
    
    def forward(self, x):
        x = self.backbone(x)
        return x

model_embeddings = VeriWildModel()
model_weights_path = '/home/abhijithganesh/bmc/bmc24/models/veriwild.pth'
checkpoint = torch.load(model_weights_path)
state_dict = checkpoint['model']
model_embeddings.load_state_dict(state_dict, strict=False)
model_embeddings.eval()

  checkpoint = torch.load(model_weights_path)


VeriWildModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
   

In [7]:
def open_image(image_path):
    image = Image.open(image_path)
    image = image.convert('RGB')
    return np.asarray(image)

folder_path = '/home/abhijithganesh/bmc/cars/car_data/train/Chevrolet Silverado 1500 Extended Cab 2012/'
query_image_path = '/home/abhijithganesh/bmc/cars/car_data/test/Mercedes-Benz S-Class Sedan 2012/00008.jpg'

In [16]:
def get_vehicle_embeddings(image):
    img_tensor = torch.from_numpy(image).float().permute(2, 0, 1)
    with torch.no_grad():
        embeddings = model_embeddings(img_tensor)
    return embeddings.cpu().numpy().flatten().tolist()

In [19]:
get_vehicle_embeddings(open_image(query_image_path))

ValueError: expected 4D input (got 3D input)

In [None]:
def find_similar_images(query_image_path, folder_path, model, top_k=5):
    query_features = model.extract_features(query_image_path)

    similarities = []
    for filename in os.listdir(folder_path):
        if filename.endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(folder_path, filename)
            image_features = model.extract_features(open_image(image_path))
            similarity = 1 - cdist(query_features, image_features, 'cosine')[0][0]
            similarities.append((filename, similarity))

    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]

# Function to display images
def display_images(query_image_path, similar_images, folder_path):
    plt.figure(figsize=(15, 5))
    plt.subplot(1, len(similar_images) + 1, 1)
    plt.imshow(query_image_path)
    plt.title("Query Image")
    plt.axis('off')

    for i, (image, similarity) in enumerate(similar_images, start=2):
        image_path = os.path.join(folder_path, image)
        img = Image.open(image_path)
        plt.subplot(1, len(similar_images) + 1, i)
        plt.imshow(img)
        plt.title(f"Sim: {similarity:.2f}")
        plt.axis('off')

    plt.show()

In [None]:
similar_images = find_similar_images(query_image_path, folder_path, model)
print("Top similar images:")
for image, similarity in similar_images:
    print(f"{image}: {similarity:.4f}")

In [None]:
display_images(query_image_path, similar_images, folder_path)