# Plant Image Retrieval Query (CLIP ViT-B/32)

Name: Zihan Yin

## Step 1 - Imports and Path Config


In [7]:
from pathlib import Path
import json
import numpy as np
from PIL import Image, ImageOps

import torch
from torch import nn
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor

# ======== Path Config ========
INDEX_PATH = Path("index/embeddings_fp16.npz")   # vector file generated by indexing
META_PATH  = Path("index/meta.json")             # meta generated by indexing
MODEL_DIR  = Path("clip-vit-b32")                # local CLIP model directory
MODEL_ID   = "openai/clip-vit-base-patch32"      # fallback if local not found
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


## Step 2 — Utilities (image preprocessing + TTA)

In [8]:
def exif_correct(img: Image.Image) -> Image.Image:
    """Fix orientation via EXIF and convert to RGB."""
    img = ImageOps.exif_transpose(img)
    if img.mode == "RGBA":
        bg = Image.new("RGB", img.size, (255, 255, 255))
        bg.paste(img, mask=img.split()[-1])
        return bg
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img

def tta_pipeline():
    """Light test-time augmentation (avoid heavy transforms)."""
    return transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, translate=(0.06, 0.06), scale=(0.95, 1.05)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    ])

## Step 3 - Load Index and CLIP Model

In [9]:
# Load index
npz = np.load(INDEX_PATH)
db_emb = npz["embeddings"].astype(np.float32)  # (N, D)
plant_ids = npz["plant_ids"]

with open(META_PATH, "r", encoding="utf-8") as f:
    meta = json.load(f)

print("Number of classes in index:", db_emb.shape[0])
print("Embedding dimension:", db_emb.shape[1])

# Load model (prefer on-line model)
try:
    model = CLIPModel.from_pretrained(MODEL_DIR)
    processor = CLIPProcessor.from_pretrained(MODEL_DIR)
    print("Loaded local model successfully.")
except Exception:
    print("Failed to load local model; falling back to online.")
    model = CLIPModel.from_pretrained(MODEL_ID)
    processor = CLIPProcessor.from_pretrained(MODEL_ID)

model.eval().to(DEVICE)

Number of classes in index: 1319
Embedding dimension: 512
Failed to load local model; falling back to online.


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

## Step 4 — Encode Query Image (From User)

In [10]:
@torch.no_grad()
def encode_query(img: Image.Image, tta_num=8):
    """Encode the uploaded image with TTA and return a single vector."""
    img = exif_correct(img)
    tta = tta_pipeline()
    pil_list = [img] + [tta(img) for _ in range(max(tta_num-1, 0))]

    inputs = processor(images=pil_list, return_tensors="pt", do_center_crop=True, padding=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.autocast(device_type=("cuda" if DEVICE=="cuda" else "cpu"),
                        dtype=torch.float16, enabled=(DEVICE=="cuda")):
        feats = model.get_image_features(**inputs)
    feats = nn.functional.normalize(feats, p=2, dim=-1)
    q = feats.mean(dim=0, keepdim=False)  # (D,)
    return q.cpu().numpy().astype(np.float32)

## Step 5 — Compute Cosine Similarity and Take Top-K

In [11]:
# Cosine Similarity Top-K
def topk_cosine(query_emb, db_emb, plant_ids, k=10):
    """
    query_emb: (D,)
    db_emb: (N, D) already L2-normalized
    Return Top-K as (plant_id, score)
    """
    sims = db_emb @ query_emb  # cosine similarity
    idx = np.argpartition(-sims, kth=min(k, len(sims)-1))[:k]
    idx = idx[np.argsort(-sims[idx])]
    return [(int(plant_ids[i]), float(sims[i])) for i in idx]

## Step 6 — Example Query

In [12]:
# Allowed suffixes: .jpg, .webp, .JPG

# Replace with the image path you want to query
QUERY_IMG = Path("test_images/fdghjmhgferghthjkhgf.jpg")  

img = Image.open(QUERY_IMG)
q_emb = encode_query(img, tta_num=8)

results = topk_cosine(q_emb, db_emb, plant_ids, k=10)
for pid, score in results:
    print(f"plant_id={pid}\tscore={score:.4f}")

plant_id=342	score=0.8903
plant_id=343	score=0.8866
plant_id=341	score=0.8801
plant_id=250	score=0.8756
plant_id=540	score=0.8733
plant_id=340	score=0.8713
plant_id=402	score=0.8682
plant_id=348	score=0.8681
plant_id=742	score=0.8677
plant_id=236	score=0.8664
