# Install requirements

In [1]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.11.0


# Import libraries

In [2]:
import os
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np
import faiss
from tqdm import tqdm
import gradio as gr
import pickle

In [3]:
# --- Paths -------------------------------------------------------------
# Folder scanned for input images.
IMAGE_FOLDER = "/kaggle/input/2017-2017/train2017/train2017"
# On-disk cache: feature matrix and the matching list of image paths.
FEATURES_PATH = "cnn_image_features.npy"
PATHS_PATH = "cnn_image_paths.pkl"

# --- Tunables ----------------------------------------------------------
# Images per forward pass during feature extraction.
BATCH_SIZE = 32
# Default number of neighbours returned by a search.
TOP_K = 5

# --- Compute device ----------------------------------------------------
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

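# Load CNN model (ResNet-50)
Remove the last layer (the fully connected classification layer) so the network outputs feature vectors instead of class scores,
and define the image preprocessing pipeline.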

In [4]:
# Load ResNet-50 with its ImageNet weights.
# `pretrained=True` is deprecated in torchvision >= 0.13; the explicit
# IMAGENET1K_V1 enum selects the same checkpoint that `pretrained=True` loaded.
cnn_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Keep every child module except the last one (the fully connected classifier),
# so the network returns pooled features rather than class logits.
cnn_model = torch.nn.Sequential(*list(cnn_model.children())[:-1])

# Inference mode: freezes batch-norm statistics and disables dropout.
cnn_model.eval()
cnn_model.to(device)

# Preprocessing applied to every image before it is fed to the network:
# resize the short side to 256, take the central 224x224 crop, convert to a
# tensor and normalize each channel with the mean/std given below.
cnn_preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 187MB/s] 


# Feature extraction function
Iterates through image_paths in batches, loads and preprocesses each image, feeds it through the cnn_model to get its feature vector, and then normalizes these features.

In [5]:

# Collect every image file in IMAGE_FOLDER.
# os.listdir() returns entries in arbitrary order, so sort the names to make
# the row order of the feature matrix reproducible across kernel restarts.
image_paths = [os.path.join(IMAGE_FOLDER, fname) for fname in sorted(os.listdir(IMAGE_FOLDER))
               if fname.lower().endswith(('.png', '.jpg', '.jpeg'))]

def extract_cnn_features(image_paths, batch_size=BATCH_SIZE):
    """Compute L2-normalized CNN feature vectors for a list of image files.

    Args:
        image_paths: Sequence of image file paths that PIL can open.
        batch_size: Number of images sent through ``cnn_model`` per forward pass.

    Returns:
        A float32 NumPy array with one row per entry of ``image_paths``,
        in the same order as ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    # Fail early with a clear message instead of the opaque error that
    # np.concatenate raises when given an empty list.
    if len(image_paths) == 0:
        raise ValueError("image_paths is empty; there are no images to extract features from")

    features = []
    for start in tqdm(range(0, len(image_paths), batch_size), desc="Extracting CNN features"):
        batch_tensors = []
        for path in image_paths[start:start + batch_size]:
            # Close each file deterministically; the preprocessed tensor does
            # not depend on the file handle once it has been created.
            with Image.open(path) as img:
                batch_tensors.append(cnn_preprocess(img.convert("RGB")))
        images = torch.stack(batch_tensors).to(device)
        with torch.no_grad():
            # Drop the two trailing singleton dimensions left by the pooling layer.
            batch_features = cnn_model(images).squeeze(-1).squeeze(-1)
            # Scale every row to unit L2 norm.
            batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
        features.append(batch_features.cpu().numpy())
    return np.concatenate(features, axis=0).astype("float32")


In [6]:
# Load cached features when they still describe the current image set;
# otherwise (re)compute them and refresh the cache.
cache_is_valid = False
if os.path.exists(FEATURES_PATH) and os.path.exists(PATHS_PATH):
    image_features = np.load(FEATURES_PATH)
    # NOTE(security): pickle.load can execute arbitrary code. Only load a cache
    # file that this notebook wrote itself.
    with open(PATHS_PATH, "rb") as f:
        saved_paths = pickle.load(f)

    if len(saved_paths) != image_features.shape[0]:
        # The two cache files were written by different runs.
        print("Cached features and paths are out of sync, re-extracting features...")
    elif set(saved_paths) != set(image_paths):
        print("Image set changed, re-extracting features...")
    else:
        cache_is_valid = True
        # Row i of image_features belongs to saved_paths[i]. The current
        # directory listing may be ordered differently, so adopt the cached
        # order to keep search results pointing at the right files.
        image_paths = list(saved_paths)

if not cache_is_valid:
    image_features = extract_cnn_features(image_paths)
    np.save(FEATURES_PATH, image_features)
    with open(PATHS_PATH, "wb") as f:
        pickle.dump(image_paths, f)

Extracting CNN features: 100%|██████████| 3697/3697 [35:10<00:00,  1.75it/s]


# FAISS index
Initializes a FAISS index (`IndexFlatIP`) for inner-product search, which is equivalent to cosine similarity because the features are L2-normalized, and then populates it with all the features extracted from the image collection.

In [7]:
# Create a flat inner-product index sized to the feature dimension, then
# insert every extracted feature vector (one row per image).
feature_dim = image_features.shape[1]
index = faiss.IndexFlatIP(feature_dim)
index.add(image_features)

# Search functions and Gradio-based demo

In [8]:
def search_by_image_cnn(query_image, top_k=TOP_K):
    """Return the paths of the indexed images most similar to ``query_image``.

    Args:
        query_image: PIL image to use as the query.
        top_k: Maximum number of results to return.

    Returns:
        A list of at most ``top_k`` image paths, in the order returned by the
        FAISS search. Empty when the index holds no vectors or ``top_k`` is
        not positive.
    """
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with -1, which would silently select image_paths[-1].
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    image = cnn_preprocess(query_image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Same feature pipeline as the indexed images: squeeze the trailing
        # singleton dimensions, then scale to unit L2 norm.
        image_features_query = cnn_model(image).squeeze(-1).squeeze(-1)
        image_features_query = image_features_query / image_features_query.norm(dim=-1, keepdim=True)
    image_features_query = image_features_query.cpu().numpy().astype("float32")

    _, neighbor_ids = index.search(image_features_query, k)
    # Defensive filter: skip any -1 placeholder instead of indexing with it.
    return [image_paths[i] for i in neighbor_ids[0] if i >= 0]

def visual_search_cnn(image_query):
    """Gradio callback: return the images most similar to the uploaded one.

    Args:
        image_query: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A list of opened PIL images for the top matches, or an empty list
        when ``image_query`` is ``None``.
    """
    # Nothing uploaded: show an empty gallery.
    if image_query is None:
        return []

    matched_paths = search_by_image_cnn(image_query)
    gallery_images = []
    for path in matched_paths:
        gallery_images.append(Image.open(path))
    return gallery_images

# Assemble the demo UI. Components are created in display order:
# title, upload widget, results gallery, then the search button.
with gr.Blocks() as demo:
    gr.Markdown("# CNN Visual Search Engine (Image-to-Image)")
    image_input = gr.Image(type="pil", label="Upload an image to search")
    output_gallery = gr.Gallery(label="Top Results", columns=5, height="auto")
    search_btn = gr.Button("Search")
    # Clicking the button runs the search on the uploaded image and sends
    # the returned images to the gallery.
    search_btn.click(fn=visual_search_cnn, inputs=image_input, outputs=output_gallery)

# Start the Gradio server for the interface defined above.
demo.launch()

* Running on local URL:  http://127.0.0.1:7860
It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

* Running on public URL: https://abb2c4a19eeebb14b1.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


