In [1]:
import torch
import timm
import numpy as np
import os
from PIL import Image
from torchvision import transforms
import faiss  # Efficient similarity search
import random
from sklearn.decomposition import PCA

In [5]:
# ✅ Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Load a smaller ViT model for efficiency
model = timm.create_model("vit_tiny_r_s16_p8_224", pretrained=True)  # Smaller ViT
model.head = torch.nn.Identity()  # Remove classifier layer
model = model.to(device)
model.eval()

model.safetensors:   0%|          | 0.00/25.4M [00:00<?, ?B/s]

VisionTransformer(
  (patch_embed): HybridEmbed(
    (backbone): Sequential(
      (conv): StdConv2dSame(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
      (norm): GroupNormAct(
        32, 64, eps=1e-05, affine=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
      (pool): MaxPool2dSame(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
    )
    (proj): Conv2d(64, 192, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace

In [6]:
# ✅ Define image transformations (Keep at 224x224 to match ViT requirements)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def extract_features(image_folder, limit=5000, batch_size=32):
    """Extracts features from a subset of images using ViT."""
    features = []
    image_names = []
    
    image_list = os.listdir(image_folder)
    random.shuffle(image_list)  # Shuffle images for randomness
    image_list = image_list[:limit]  # Limit number of images
    
    with torch.no_grad(): 
         for i in range(0, len(image_list), batch_size):
            batch_imgs = []
            batch_names = []
            for img_name in image_list[i:i + batch_size]:
                img_path = os.path.join(image_folder, img_name)
                image = Image.open(img_path).convert("RGB")
                image = transform(image).unsqueeze(0)  # Process one image at a time
                batch_imgs.append(image)
                batch_names.append(img_name)
            
            batch_imgs = torch.cat(batch_imgs).to(device)  # Convert batch to tensor
            batch_feats = model(batch_imgs).cpu().numpy()  # Extract feature vectors
            features.append(batch_feats)
            image_names.extend(batch_names)
    
    return np.vstack(features), image_names # No gradient computation to save memory
       

In [7]:
gallery_features, gallery_names = extract_features("gallery_set", limit=5000)
query_features, query_names = extract_features("query_set", limit=500)

In [8]:
# ✅ Apply PCA to reduce dimensionality (Speeds up FAISS search)
pca = PCA(n_components=64)
gallery_features = pca.fit_transform(gallery_features)
query_features = pca.transform(query_features)


In [9]:
# ✅ Use FAISS IVF index for better efficiency
d = gallery_features.shape[1]  # Feature vector dimension
nlist = 100  # Number of clusters
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
index.train(gallery_features)
index.add(gallery_features)

In [10]:
# ✅ Find top-5 matches for each query
_, indices = index.search(query_features, k=5)

# ✅ Print some results
for i, query in enumerate(query_names[:5]):
    print(f"\nQuery Image: {query}")
    print("Top 5 Matches:", [gallery_names[idx] for idx in indices[i]])



Query Image: img_245.png
Top 5 Matches: ['img_16730.png', 'img_18322.png', 'img_8572.png', 'img_13384.png', 'img_46213.png']

Query Image: img_5098.png
Top 5 Matches: ['img_21122.png', 'img_27710.png', 'img_2616.png', 'img_20535.png', 'img_38021.png']

Query Image: img_7677.png
Top 5 Matches: ['img_34488.png', 'img_25734.png', 'img_7232.png', 'img_15097.png', 'img_42348.png']

Query Image: img_5248.png
Top 5 Matches: ['img_25510.png', 'img_4248.png', 'img_41735.png', 'img_41681.png', 'img_36160.png']

Query Image: img_545.png
Top 5 Matches: ['img_21685.png', 'img_32256.png', 'img_23148.png', 'img_5184.png', 'img_23003.png']
