# Instance Search Project

In [None]:
%load_ext autoreload
%autoreload 2

import os

from src.helper.InstanceSearch import InstanceSearch
from src.helper.FeatureExtractor import FeatureExtractor

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from scipy.spatial import distance

import cv2

In [None]:
config = {
    "query_path": "src/files/datasets/query/",
    "query_box_path": "src/files/datasets/query_txt/",
    "gallery_path": "src/files/datasets/gallery_4186/",
    "feature_path": "src/files/features/",
}

In [None]:
instanceSearch = InstanceSearch(config)

In [None]:
print(f"Loaded query files: {len(instanceSearch.query_ids)}")
print(f"Loaded query boxes: {len(instanceSearch.query_boxes)}")

## VGG16

In [None]:
from src.helper.ModelWrapper import VGG16Extractor

In [None]:
# Wrap the VGG16 backbone in the project's FeatureExtractor.
fe_vgg16 = FeatureExtractor(VGG16Extractor)

In [None]:
# Extract features for every gallery image into config['feature_path'].
# Presumably they go to a 'features_<model name>' sub-folder, which is where the
# search code in this notebook reads them from. Confirm in FeatureExtractor.
fe_vgg16.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
# Use VGG16 for the following instanceSearch calls (presumably sets
# instanceSearch.feature_extractor, which later cells read back).
instanceSearch.set_feature_extractor(fe_vgg16)

In [None]:
# Query '27' with k=40 and plotting enabled.
res = instanceSearch.search(query_id = '27', k=40, plot=True)

In [None]:
# Search for every query with default options; any scoring happens inside search_all.
res = instanceSearch.search_all()

## VGG19

In [None]:
from src.helper.ModelWrapper import VGG19Extractor

In [None]:
# Wrap the VGG19 backbone in the project's FeatureExtractor.
fe_vgg19 = FeatureExtractor(VGG19Extractor)

In [None]:
# Extract VGG19 features for every gallery image into config['feature_path'].
fe_vgg19.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_vgg19)

In [None]:
# Bare expression, so Jupyter echoes the model name. The search code uses it as
# the suffix of the 'features_<name>' folder it reads.
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with k=40 and plotting enabled.
res = instanceSearch.search(query_id = '27', k=40, plot=True)

In [None]:
# Search for every query with default options.
res = instanceSearch.search_all()

## CLIP

In [None]:
from src.helper.ModelWrapper import CLIPExtractor

In [None]:
# Wrap the CLIP backbone in the project's FeatureExtractor.
fe_clip = FeatureExtractor(CLIPExtractor)

In [None]:
# Extract CLIP features for every gallery image into config['feature_path'].
fe_clip.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_clip)

In [None]:
# Bare expression, so Jupyter echoes the active model name (the 'features_<name>' folder suffix).
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with a larger k=150 and plotting enabled.
res = instanceSearch.search(query_id = '27', k=150, plot=True)

In [None]:
# Search for every query with default options.
res = instanceSearch.search_all()

## CLIP 336

In [None]:
from src.helper.ModelWrapper import CLIPExtractor_336

In [None]:
# Wrap the CLIP 336 backbone in the project's FeatureExtractor.
fe_clip_336 = FeatureExtractor(CLIPExtractor_336)

In [None]:
# Extract CLIP 336 features for every gallery image into config['feature_path'].
fe_clip_336.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_clip_336)

In [None]:
# Bare expression, so Jupyter echoes the active model name (the 'features_<name>' folder suffix).
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with an explicit Euclidean metric, k=40, plotted.
res = instanceSearch.search(query_id = '27', k=40, distance='euclidean', plot=True)

In [None]:
# NOTE(review): does not pass distance='euclidean' like the cell above, so it uses
# search_all's own default metric. Confirm both cells evaluate the same setting.
res = instanceSearch.search_all()

## Google ViT

In [None]:
from src.helper.ModelWrapper import ViTExtractor

In [None]:
# Wrap the ViT backbone in the project's FeatureExtractor.
fe_vit = FeatureExtractor(ViTExtractor)

In [None]:
# Extract ViT features for every gallery image into config['feature_path'].
fe_vit.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_vit)

In [None]:
# Bare expression, so Jupyter echoes the active model name (the 'features_<name>' folder suffix).
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with Euclidean distance and query expansion, k=40, plotted.
res = instanceSearch.search(query_id = '27', k=40, plot=True, distance='euclidean', query_expansion=True)

In [None]:
# NOTE(review): runs with default options, i.e. without the distance and
# query_expansion settings used in the single-query cell above.
res = instanceSearch.search_all()

## DINOv2

In [None]:
from src.helper.ModelWrapper import DinoV2Extractor

In [None]:
# Wrap the DINOv2 backbone in the project's FeatureExtractor.
fe_dino = FeatureExtractor(DinoV2Extractor)

In [None]:
# Extract DINOv2 features for every gallery image into config['feature_path'].
fe_dino.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_dino)

In [None]:
# Bare expression, so Jupyter echoes the active model name (the 'features_<name>' folder suffix).
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with cosine distance and query augmentation, k=40, plotted.
res = instanceSearch.search(query_id = '27', k=40, plot=True, distance='cosine', query_augmentation=True)

In [None]:
# NOTE(review): runs with default options, i.e. without the distance and
# query_augmentation settings used in the single-query cell above.
res = instanceSearch.search_all()

## Reranking with local feature matching (SuperPoint + LightGlue)

In [None]:
from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch

In [None]:
# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
# (The previous one-liner picked "mps" whenever *CUDA* was available. That yields
# an unusable device on CUDA machines and never uses MPS on Macs.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# SuperPoint keypoint extractor (at most 1024 keypoints per image) and a LightGlue
# matcher configured for SuperPoint features, both in eval mode on `device`.
extractor = SuperPoint(max_num_keypoints=1024).eval().to(device)
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95).eval().to(device)

In [None]:
from tqdm import tqdm

# Pre-extracted CLIP gallery features. With the config above this resolves to
# src/files/features/features_clip_L14, matching the <feature_path>/features_<name>
# layout used elsewhere; the old path omitted the 'features/' level.
# Loaded once for all queries instead of re-reading every .npy file per query.
clip_feature_dir = os.path.join(config['feature_path'], 'features_clip_L14')
gallery_features = {
    os.path.splitext(name)[0] + '.jpg': np.load(os.path.join(clip_feature_dir, name))
    for name in sorted(os.listdir(clip_feature_dir))
    if name.endswith('.npy')
}

# For each query image, show the top-20 most similar images.
# Iterates query_ids (as the SIFT cell and the load-count cell do); ids are
# turned into file names by appending '.jpg'.
for query_id in tqdm(instanceSearch.query_ids):
    print("Query", query_id, "...")
    query_path = config['query_path'] + query_id + '.jpg'
    query_box = instanceSearch.query_boxes[query_id]

    # Stage 1: shortlist the 40 gallery images whose CLIP features are closest
    # (Euclidean) to the CLIP feature of the query's bounding box.
    query = fe_clip.extract(query_path, bounding_box=query_box)
    distances = {image: distance.euclidean(query, feature) for image, feature in gallery_features.items()}
    shortlist = sorted(distances, key=distances.get)[:40]

    # Stage 2: rerank the shortlist by the number of LightGlue matches (more is better).
    # The query crop and its SuperPoint features do not depend on the candidate,
    # so compute them once per query.
    x, y, w, h = query_box
    query_image = load_image(query_path)
    cropped_query_image = query_image[:, y:y+h, x:x+w]  # sliced channels-first: (C, H, W)
    feats_query = extractor.extract(cropped_query_image.to(device))

    reranked_distances = {}
    with torch.no_grad():  # inference only; avoid building autograd graphs in the matcher
        for image in tqdm(shortlist):
            feats_gallery = extractor.extract(load_image(config['gallery_path'] + image).to(device))
            # Only the match output is unbatched; feats_query keeps its batch
            # dimension because it is fed to the matcher again for the next candidate.
            matches01 = rbd(matcher({"image0": feats_gallery, "image1": feats_query}))
            reranked_distances[image] = len(matches01["matches"])

    reranked_distances = dict(sorted(reranked_distances.items(), key=lambda item: item[1], reverse=True))

    # Show the query (slot 1) and the top-20 reranked images (slots 3..22) of a 2x12 grid.
    plt.figure(figsize=(23, 4))
    plt.subplot(2, 12, 1)
    plt.imshow(Image.open(query_path))
    plt.title(f'Query Image {query_id}')
    plt.axis('off')

    for i, (filename, dist) in enumerate(reranked_distances.items()):
        if i >= 20:
            break
        img = Image.open(config['gallery_path'] + filename)
        plt.subplot(2, 12, i+3)
        plt.imshow(img)
        plt.title(f'{filename} ({dist:.4f})')
        plt.axis('off')

    plt.show()

## Reranking with local feature matching (SIFT)

In [None]:
from tqdm import tqdm

# Pre-extracted CLIP gallery features (<feature_path>/features_clip_L14), loaded
# once for all queries instead of re-reading every .npy file per query.
clip_feature_dir = os.path.join(config['feature_path'], 'features_clip_L14')
gallery_features = {
    os.path.splitext(name)[0] + '.jpg': np.load(os.path.join(clip_feature_dir, name))
    for name in sorted(os.listdir(clip_feature_dir))
    if name.endswith('.npy')
}

# SIFT detector and brute-force matcher (default L2 norm), reused for every image pair.
sift = cv2.SIFT_create()
bf = cv2.BFMatcher()

# For each query image, show the top-20 most similar images
for query_id in tqdm(instanceSearch.query_ids):
    print("Query", query_id, "...")
    query_path = config['query_path'] + query_id + '.jpg'
    query_box = instanceSearch.query_boxes[query_id]

    # Stage 1: shortlist the 40 gallery images closest (Euclidean) to the
    # CLIP feature of the query's bounding box.
    query = fe_clip.extract(query_path, bounding_box=query_box)
    distances = {image: distance.euclidean(query, feature) for image, feature in gallery_features.items()}
    shortlist = sorted(distances, key=distances.get)[:40]

    # Stage 2: rerank by SIFT matches. The query crop and its descriptors do not
    # depend on the candidate, so compute them once. Images are converted to
    # 8-bit grayscale with PIL; otherwise OpenCV would treat the RGB array as BGR
    # in its internal grayscale conversion.
    x, y, w, h = query_box
    cropped_query_image = np.array(Image.open(query_path).crop([x, y, x + w, y + h]).convert('L'))
    _, des_query = sift.detectAndCompute(cropped_query_image, None)

    reranked_distances = {}
    for image in tqdm(shortlist):
        img = np.array(Image.open(config['gallery_path'] + image).convert('L'))
        kp_gallery, des_gallery = sift.detectAndCompute(img, None)

        # detectAndCompute returns None descriptors when it finds no keypoints:
        # knnMatch would fail and the score below would divide by zero.
        if des_query is None or des_gallery is None or len(kp_gallery) == 0:
            reranked_distances[image] = 0.0
            continue

        # Lowe's ratio test. knnMatch can return fewer than 2 neighbours (e.g. a
        # query crop with a single descriptor); such entries cannot be tested.
        matches = bf.knnMatch(des_gallery, des_query, k=2)
        good = [
            pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
        ]

        # Score: fraction of gallery keypoints with a distinctive match in the query crop.
        reranked_distances[image] = len(good) / len(kp_gallery)

    reranked_distances = dict(sorted(reranked_distances.items(), key=lambda item: item[1], reverse=True))

    # Show the query (slot 1) and the top-20 reranked images (slots 3..22) of a 2x12 grid.
    plt.figure(figsize=(23, 4))
    plt.subplot(2, 12, 1)
    plt.imshow(Image.open(query_path))
    plt.title(f'Query Image {query_id}')
    plt.axis('off')

    for i, (filename, dist) in enumerate(reranked_distances.items()):
        if i >= 20:
            break
        img = Image.open(config['gallery_path'] + filename)
        plt.subplot(2, 12, i+3)
        plt.imshow(img)
        plt.title(f'{filename} ({dist:.4f})')
        plt.axis('off')

    plt.show()

In [None]:
def data_aug(img):
    """Return ``img`` followed by eight perspective-warped copies of it.

    Used to augment a query crop with simulated viewpoint changes. Each warp maps
    the image corners onto a quadrilateral in which one edge (right, left, bottom
    or top) is pulled inward and shortened. The first four warps move that edge
    by 20% of the width/height; the last four move it by 40%.

    Args:
        img: image as a NumPy array, either (H, W) or (H, W, C).

    Returns:
        list: ``[img, warp_1, ..., warp_8]``. Every warp has the same size as
        ``img``. Pixels outside the warped quadrilateral get
        ``cv2.warpPerspective``'s default constant (0) border.
    """
    # Only the spatial size is needed. The old `height, width, channel = img.shape`
    # unpacking rejected single-channel (H, W) images.
    height, width = img.shape[:2]
    corners = np.float32([[0, 0], [width, 0], [width, height], [0, height]])

    # Target corners, in the same order: top-left, top-right, bottom-right, bottom-left.
    targets = [
        # 20% squeeze of the right, left, bottom and top edge
        np.float32([[0, 0], [0.8 * width, height * 0.025], [0.8 * width, 0.975 * height], [0, height]]),
        np.float32([[0.2 * width, 0.025 * height], [width, 0], [width, height], [0.2 * width, 0.975 * height]]),
        np.float32([[0, 0], [width, 0], [0.975 * width, 0.8 * height], [0.025 * width, 0.8 * height]]),
        np.float32([[0.025 * width, 0.2 * height], [0.975 * width, 0.2 * height], [width, height], [0, height]]),
        # 40% squeeze of the right, left, bottom and top edge
        np.float32([[0, 0], [0.6 * width, height * 0.1], [0.6 * width, 0.9 * height], [0, height]]),
        np.float32([[0.4 * width, 0.1 * height], [width, 0], [width, height], [0.4 * width, 0.9 * height]]),
        np.float32([[0, 0], [width, 0], [0.9 * width, 0.6 * height], [0.1 * width, 0.6 * height]]),
        np.float32([[0.1 * width, 0.4 * height], [0.9 * width, 0.4 * height], [width, height], [0, height]]),
    ]

    transformed_imgs = [img]
    for target in targets:
        # Homography taking the image corners onto the target quadrilateral.
        M = cv2.getPerspectiveTransform(corners, target)
        transformed_imgs.append(cv2.warpPerspective(img, M, (width, height)))

    return transformed_imgs

In [None]:
img = Image.open(config['query_path'] + '27.jpg')

x, y, w, h = instanceSearch.query_boxes['27']
cropped_img = img.crop([x, y, x + w, y + h])

In [None]:
aug_imgs = data_aug(np.array(cropped_img))

for i in aug_imgs:
    plt.imshow(i)
    plt.axis('off')
    plt.show()

## CLIP + Query Expansion

In [None]:
from src.helper.ModelWrapper import CLIPExtractor

In [None]:
# Re-create the CLIP extractor (same CLIPExtractor class as the earlier CLIP section).
fe_clip = FeatureExtractor(CLIPExtractor)

In [None]:
# No extract_batch call in this section: it relies on the CLIP gallery features
# already written by the earlier CLIP section.
instanceSearch.set_feature_extractor(fe_clip)

In [None]:
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with query expansion enabled, k=150, plotted.
res = instanceSearch.search(query_id = '27', k=150, plot=True, query_expansion=True)

In [None]:
# All queries with default options.
res = instanceSearch.search_all()

In [None]:
# Min
# NOTE(review): this call is identical to the "Mean" cell below. Presumably the
# aggregation inside InstanceSearch was switched by hand between runs, so the
# result depends on library state at run time.
res = instanceSearch.search_all(query_expansion=True)

In [None]:
# Mean
res = instanceSearch.search_all(query_expansion=True)

In [None]:
# NOTE(review): this and the next cell make the same call. Presumably they were run
# under different InstanceSearch settings (e.g. min vs. mean); confirm.
res = instanceSearch.search_all(query_augmentation=True, query_expansion=True)

In [None]:
res = instanceSearch.search_all(query_augmentation=True, query_expansion=True)

## CLIP 336 + Query Augmentation (QA) + Query Expansion (QE)

In [None]:
from src.helper.ModelWrapper import CLIPExtractor_336

In [None]:
fe_clip_336 = FeatureExtractor(CLIPExtractor_336)

In [None]:
# No extract_batch call here: reuses the CLIP 336 gallery features written by the
# earlier CLIP 336 section.
instanceSearch.set_feature_extractor(fe_clip_336)

In [None]:
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with Euclidean distance, query expansion and query augmentation.
res = instanceSearch.search(query_id = '27', k=40, distance='euclidean', plot=True, query_expansion=True, query_augmentation=True)

In [None]:
# All queries with query expansion and augmentation.
# NOTE(review): distance is not passed here, unlike the cell above; confirm the default matches.
res = instanceSearch.search_all(query_expansion=True, query_augmentation=True)

## CLIP 336 + QA 2.0 + QE

In [None]:
from src.helper.ModelWrapper import CLIPExtractor_336

In [None]:
fe_clip_336 = FeatureExtractor(CLIPExtractor_336)

In [None]:
instanceSearch.set_feature_extractor(fe_clip_336)

In [None]:
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with augmentation, query expansion and the second expansion variant.
res = instanceSearch.search(query_id = '27', k=40, distance='euclidean', plot=True, query_augmentation=True, query_expansion=True, query_expansion_2=True)

In [None]:
# NOTE(review): unlike the single-query cell above, this does not pass
# query_expansion_2=True (nor distance), so the call is identical to the previous
# section's evaluation. Confirm whether search_all accepts query_expansion_2 and
# whether it should be set here.
res = instanceSearch.search_all(query_expansion=True, query_augmentation=True)

## VGG19 with Query Expansion

In [None]:
from src.helper.ModelWrapper import VGG19Extractor

In [None]:
fe_vgg19 = FeatureExtractor(VGG19Extractor)

In [None]:
# Re-extracts the VGG19 gallery features (the VGG19 section above already ran this call).
fe_vgg19.extract_batch(source_path=config['gallery_path'], target_path=config['feature_path'])

In [None]:
instanceSearch.set_feature_extractor(fe_vgg19)

In [None]:
instanceSearch.feature_extractor.model.name

In [None]:
# Query '27' with query expansion and query augmentation.
res = instanceSearch.search(query_id = '27', k=40, plot=True, query_expansion=True, query_augmentation=True)

In [None]:
# NOTE(review): despite the "with Query Expansion" section title, only
# query_augmentation is passed here; confirm this is intended.
res = instanceSearch.search_all(query_augmentation=True)

## Combined

In [None]:
from tqdm import tqdm

In [None]:
from src.helper.ModelWrapper import VGG19Extractor, CLIPExtractor_336
fe_clip_336 = FeatureExtractor(CLIPExtractor_336)
fe_vgg19 = FeatureExtractor(VGG19Extractor)

class InstanceSearchCombined(InstanceSearch):
    """InstanceSearch variant whose first-stage ranking fuses VGG19 and CLIP 336.

    ``search`` reads the module-level ``fe_vgg19`` and ``fe_clip_336`` extractors,
    plus their pre-extracted gallery features under
    ``<feature_path>/features_<model name>``, for the first stage. It uses
    ``self.feature_extractor`` for the optional query-expansion stage.
    """

    def search(self,
               query_id: str,
               k: int = 10,
               plot=False,
               distance='euclidean',
               query_augmentation: bool = False,
               query_expansion: bool = False,
               query_expansion_2: bool = False) -> list:
        """Return the ids of the ``k`` best-matching gallery images for a query.

        Args:
            query_id: query id. The image is ``<query_path><query_id>.jpg`` and
                its box is ``self.query_boxes[query_id]``.
            k: number of results to return (and to plot).
            plot: if True, show the query image followed by the top-``k`` results.
            distance: metric name passed to ``self.get_distance_calculator``;
                smaller distances rank first.
            query_augmentation: forwarded as ``aug`` to both extractors.
            query_expansion: if True, re-score every gallery image using features
                of the 10 best first-stage results.
            query_expansion_2: accepted for signature compatibility; not used here.

        Returns:
            list: gallery ids (feature file stems), best match first.
        """
        query_img_path = self.config['query_path'] + query_id + '.jpg'
        query_box = self.query_boxes[query_id]

        # Query descriptors per model. They are iterated below and each one is
        # compared with the gallery (presumably several when augmentation is on).
        queries_vgg19 = fe_vgg19.extract(img_path=query_img_path,
                                         bounding_box=query_box,
                                         aug=query_augmentation)
        queries_clip = fe_clip_336.extract(img_path=query_img_path,
                                           bounding_box=query_box,
                                           aug=query_augmentation)

        distance_calculator = self.get_distance_calculator(distance)

        feature_path_vgg19 = os.path.join(self.config['feature_path'], 'features_' + fe_vgg19.model.name)
        feature_path_clip = os.path.join(self.config['feature_path'], 'features_' + fe_clip_336.model.name)

        # Pair the two models' gallery features by file name. os.listdir() order is
        # arbitrary, so zipping two separate listings (as this code used to) could
        # fuse the VGG19 feature of one image with the CLIP feature of another.
        # Only .npy files present in both folders are scored.
        clip_files = set(os.listdir(feature_path_clip))
        gallery_files = sorted(
            name for name in os.listdir(feature_path_vgg19)
            if name.endswith('.npy') and name in clip_files
        )

        distances = {}
        for name in tqdm(gallery_files):
            feature_vgg19 = np.load(os.path.join(feature_path_vgg19, name))
            feature_clip = np.load(os.path.join(feature_path_clip, name))

            # Mean distance over all query descriptors per model, then the mean of both models.
            distance_vgg19 = np.mean([distance_calculator(query, feature_vgg19) for query in queries_vgg19])
            distance_clip = np.mean([distance_calculator(query, feature_clip) for query in queries_clip])
            distances[name.split('.')[0]] = (distance_vgg19 + distance_clip) / 2

        distances = dict(sorted(distances.items(), key=lambda item: item[1]))

        if query_expansion:
            # Expand the query with the 10 best first-stage results, re-extracted
            # with self.feature_extractor.
            top_ids = list(distances.keys())[:10]
            expanded_queries = []
            for gallery_id in top_ids:
                expanded_queries.extend(
                    self.feature_extractor.extract(img_path=self.config['gallery_path'] + gallery_id + '.jpg'))

            feature_path = os.path.join(self.config['feature_path'], 'features_' + self.feature_extractor.model.name)

            new_distances = {}
            for name in tqdm(os.listdir(feature_path)):
                stem = name.split('.')[0]
                # Skip non-feature files and images that have no first-stage score
                # (the lookup below would otherwise raise KeyError).
                if name.split('.')[-1] != 'npy' or stem not in distances:
                    continue
                feature = np.load(os.path.join(feature_path, name))

                # New score: mean of the distances to every expanded query together
                # with the first-stage fused distance.
                d = [distance_calculator(query, feature) for query in expanded_queries]
                d.append(distances[stem])
                new_distances[stem] = np.mean(d)

            distances = dict(sorted(new_distances.items(), key=lambda item: item[1]))

        if plot:
            num_cols = 12
            # The query takes slot 1 and results start at slot 3, so the last
            # result lands in slot k + 2; this row count always has room for it.
            num_rows = (k + 2) // num_cols + 1

            plt.figure(figsize=(num_cols * 2.25, num_rows * 2))
            plt.subplot(num_rows, num_cols, 1)
            plt.imshow(Image.open(query_img_path))
            plt.title(f'Query Image {query_id}')
            plt.axis('off')

            # Separate loop variable so the `query_id` argument is not overwritten.
            for i, (gallery_id, dist) in enumerate(distances.items()):
                if i >= k:
                    break
                plt.subplot(num_rows, num_cols, i + 3)
                plt.imshow(Image.open(self.config['gallery_path'] + gallery_id + '.jpg'))
                plt.title(f'{gallery_id} ({dist:.3f})')
                plt.axis('off')

            plt.show()

        return list(distances.keys())[:k]

In [None]:
instanceSearchCombined = InstanceSearchCombined(config)

In [None]:
# NOTE: rebinding the global fe_clip_336 creates another FeatureExtractor
# (presumably loading the model again). InstanceSearchCombined.search reads this
# global for its CLIP distances. It is also installed as the extractor that the
# query-expansion stage uses (via self.feature_extractor, presumably set here).
fe_clip_336 = FeatureExtractor(CLIPExtractor_336)
instanceSearchCombined.set_feature_extractor(fe_clip_336)

In [None]:
# Query '27' with cosine distance, query expansion and query augmentation.
res = instanceSearchCombined.search(query_id = '27', k=40, distance='cosine', plot=True, query_expansion=True, query_augmentation=True)

In [None]:
# NOTE(review): search_all is inherited from InstanceSearch; presumably it calls the
# overridden search() for each query (confirm). Neither cell below passes
# distance='cosine' as the single-query cell does.
res = instanceSearchCombined.search_all(query_augmentation=True)

In [None]:
res = instanceSearchCombined.search_all(query_augmentation=True, query_expansion=True)