# Install requirements

In [1]:
!pip install faiss-cpu transformers

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m59.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.11.0


# Import library

In [2]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import faiss
import gradio as gr
import pickle
from transformers import ViTFeatureExtractor, ViTModel

2025-07-14 15:09:53.031391: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752505793.226193      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752505793.281006      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Settings
IMAGE_FOLDER = "/kaggle/input/2017-2017/train2017/train2017"
FEATURES_PATH = "vit_image_features.npy"
PATHS_PATH = "vit_image_paths.pkl"
BATCH_SIZE = 16
TOP_K = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

The optimal batch size can differ between CNN-based and ViT-based models due to differences in their architecture, memory usage, and computational requirements.
CNNs (e.g., ResNet): Typically have fewer parameters and require less memory per image. They can often handle larger batch sizes on the same hardware.
ViTs (Vision Transformers): Use self-attention, which has memory and compute requirements that scale quadratically with image size and linearly with batch size. Each image is split into many patches, and attention is computed between all pairs of patches, consuming more memory.
As a result, ViTs often require smaller batch sizes to avoid running out of GPU/CPU memory.
Moreover, ViTs often use larger input sizes (e.g., 224x224 or 384x384) and more complex preprocessing, which can further increase memory usage.

# Load ViT model and feature extractor

In [4]:
# Load ViT model and feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
vit_model.eval()

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

# Feature extraction function

In [5]:

image_paths = [os.path.join(IMAGE_FOLDER, fname) for fname in os.listdir(IMAGE_FOLDER)
               if fname.lower().endswith(('.png', '.jpg', '.jpeg'))]

def extract_vit_features(image_paths, batch_size=BATCH_SIZE):
    features = []
    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting ViT features"):
        batch_paths = image_paths[i:i+batch_size]
        images = [Image.open(p).convert("RGB") for p in batch_paths]
        inputs = feature_extractor(images=images, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = vit_model(**inputs)
            batch_features = outputs.last_hidden_state[:, 0, :]  # CLS token
            batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
        features.append(batch_features.cpu().numpy())
    return np.concatenate(features, axis=0).astype("float32")

This function iterates through your collection of images in batches, preprocesses them using feature_extractor, feeds them through the vit_model to get their feature vectors, and then normalizes these features.

In [6]:
# Load or compute features
if os.path.exists(FEATURES_PATH) and os.path.exists(PATHS_PATH):
    image_features = np.load(FEATURES_PATH)
    with open(PATHS_PATH, "rb") as f:
        saved_paths = pickle.load(f)
    if set(saved_paths) != set(image_paths):
        print("Image set changed, re-extracting features...")
        image_features = extract_vit_features(image_paths)
        np.save(FEATURES_PATH, image_features)
        with open(PATHS_PATH, "wb") as f:
            pickle.dump(image_paths, f)
else:
    image_features = extract_vit_features(image_paths)
    np.save(FEATURES_PATH, image_features)
    with open(PATHS_PATH, "wb") as f:
        pickle.dump(image_paths, f)

Extracting ViT features: 100%|██████████| 7393/7393 [51:26<00:00,  2.39it/s]


# FAISS index
Initializes a FAISS index (IndexFlatIP) designed for cosine similarity search (because your features are L2-normalized) and then populates it with all the extracted features from image collection. 

In [7]:
# Build FAISS index
index = faiss.IndexFlatIP(image_features.shape[1])
index.add(image_features)

# Search functions and Gradio-based demo

In [8]:
def search_by_image_vit(query_image, top_k=TOP_K):
    image = query_image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = vit_model(**inputs)
        image_features_query = outputs.last_hidden_state[:, 0, :]
        image_features_query = image_features_query / image_features_query.norm(dim=-1, keepdim=True)
    image_features_query = image_features_query.cpu().numpy().astype("float32")
    D, I = index.search(image_features_query, top_k)
    return [image_paths[i] for i in I[0]]

def visual_search_vit(image_query):
    if image_query is not None:
        results = search_by_image_vit(image_query)
        return [Image.open(p) for p in results]
    else:
        return []

with gr.Blocks() as demo:
    gr.Markdown("# ViT Visual Search Engine (Image-to-Image)")
    image_input = gr.Image(type="pil", label="Upload an image to search")
    output_gallery = gr.Gallery(label="Top Results", columns=5, height="auto")
    search_btn = gr.Button("Search")
    search_btn.click(
        fn=visual_search_vit,
        inputs=image_input,
        outputs=output_gallery
    )

demo.launch()

* Running on local URL:  http://127.0.0.1:7860
It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

* Running on public URL: https://5e58bc1aae84fc220f.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


