# Plant Image Retrieval Indexing (CLIP ViT-B/32)

Name: Zihan Yin

## Step 1 — Install Dependencies

In [None]:
# %pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
# %pip install -q transformers pillow numpy tqdm safetensors

## Step 2 - Imports and Config

In [None]:
from pathlib import Path
import re, json, time, random, math
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm

import torch
from torch import nn
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor

# ======== Paths & Params ========
NOTEBOOK_DIR = Path(__file__).parent if "__file__" in globals() else Path().resolve()
DATA_DIR = (NOTEBOOK_DIR / "../../01_data_wrangling/01_raw_data/05_thumbnail_image").resolve()
OUT_DIR    = Path("index")                            # output embeddings & meta
MODEL_ID   = "openai/clip-vit-base-patch32"          # CLIP ViT-B/32
MODEL_DIR  = Path("clip-vit-b32")                    # local model cache
NUM_AUGS   = 20                                      # augmentations per class (excluding original)
BATCH_SIZE = 32                                      # encoding batch size
SEED       = 42

OUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Filename regex: capture plant_id from filename
FILE_RE = re.compile(r"plant_species_thumbnail_image_(\d+)\.jpg", re.IGNORECASE)

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

set_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

Quick Dataset Peek

In [2]:
print("Directory exists?", DATA_DIR.exists())
print("Number of JPGs found:", len(list(DATA_DIR.glob("*.jpg"))))
print("First 5 files:", [p.name for p in DATA_DIR.glob("*.jpg")][:5])

Directory exists? True
Number of JPGs found: 1319
First 5 files: ['plant_species_thumbnail_image_1.jpg', 'plant_species_thumbnail_image_10.jpg', 'plant_species_thumbnail_image_100.jpg', 'plant_species_thumbnail_image_1000.jpg', 'plant_species_thumbnail_image_1001.jpg']


## Step 3 — Utils Tools (EXIF fix, light augment, file scan)

In [3]:
def exif_correct(img: Image.Image) -> Image.Image:
    # Normalize orientation; convert RGBA to RGB (white background)
    img = ImageOps.exif_transpose(img)
    if img.mode == "RGBA":
        bg = Image.new("RGB", img.size, (255, 255, 255))
        bg.paste(img, mask=img.split()[-1])
        return bg
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img

def light_aug_pipeline():
    """
    Light augmentations to avoid prototype drift:
    - Random horizontal flip
    - Mild brightness/contrast/saturation/hue jitter
    - Small rotation/translation + moderate scale jitter
    - Low-probability mild blur
    """
    return transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02),
        transforms.RandomAffine(degrees=15, translate=(0.10, 0.10), scale=(0.9, 1.1)),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.2),
    ])

def parse_plant_images(data_dir: Path):
    pairs = []
    for p in sorted(data_dir.glob("*")):
        m = FILE_RE.match(p.name)
        if m:
            plant_id = int(m.group(1))
            pairs.append((plant_id, p))
    return pairs

pairs = parse_plant_images(DATA_DIR)
print(f"Found {len(pairs)} classes (one thumbnail each)")
assert len(pairs) > 0, "No images found matching the naming pattern; check directory and filenames."

Found 1319 classes (one thumbnail each)


## Step 4 — Load CLIP Model and Processor (cache locally)

In [4]:
print("Loading CLIP model and processor…")
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model.eval().to(device)

# Optionally cache locally for offline/deployment
# model.save_pretrained(MODEL_DIR) ###############
processor.save_pretrained(MODEL_DIR)

embedding_dim = model.config.projection_dim  # 512
embedding_dim

Loading CLIP model and processor…


512

## Step 5 — Batch Encoding (AMP on GPU)

In [5]:
@torch.no_grad()
def encode_images(pil_list, batch_size=32, amp_dtype=torch.float16):
    """
    Use CLIPProcessor for canonical preprocessing (resize/center crop/normalize).
    Return L2-normalized image embeddings (B, D).
    """
    embs = []
    for i in range(0, len(pil_list), batch_size):
        batch = pil_list[i:i+batch_size]
        inputs = processor(images=batch, return_tensors="pt", do_center_crop=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"),
                            dtype=amp_dtype, enabled=(device=="cuda")):
            feats = model.get_image_features(**inputs)  # (B, D)
        feats = nn.functional.normalize(feats, p=2, dim=-1)
        embs.append(feats.cpu())
    return torch.cat(embs, dim=0)  # (N, D)

## Step 6 — Build Mean Prototypes and Cache (main loop)

In [6]:
aug = light_aug_pipeline()

# Lists for robust collection (skip corrupted images)
proto_list = []
plant_id_list = []

pbar = tqdm(pairs, desc="Encoding classes (with augmentation)")

for plant_id, img_path in pbar:
    # Open & correct
    try:
        img = Image.open(img_path)
    except Exception as e:
        print(f"[WARN] Skip corrupted image {img_path}: {e}")
        continue
    img = exif_correct(img)

    # Build augmented samples (include original)
    pil_list = [img] + [aug(img) for _ in range(NUM_AUGS)]

    # Encode
    feats = encode_images(pil_list, batch_size=BATCH_SIZE)  # (N, D)
    mean_proto = feats.mean(dim=0, keepdim=True)
    mean_proto = nn.functional.normalize(mean_proto, p=2, dim=-1).cpu().numpy()[0]  # (D,)

    # Cache
    proto_list.append(mean_proto.astype(np.float32))
    plant_id_list.append(plant_id)

# Assemble matrices
embeddings = np.vstack(proto_list).astype(np.float32)  # (C, D)
plant_ids  = np.array(plant_id_list, dtype=np.int64)

print("Embeddings shape:", embeddings.shape)
print("Example L2 norm (≈1):", np.linalg.norm(embeddings[0]))

Encoding classes (with augmentation): 100%|██████████| 1319/1319 [11:35<00:00,  1.90it/s]

Embeddings shape: (1319, 512)
Example L2 norm (≈1): 0.99983567





## Step 7 — Save Index (fp16) and Meta

In [7]:
# Step 8 - Persist Artifacts
# Store embeddings in fp16 (space-efficient); queries can upcast to fp32
emb_fp16 = embeddings.astype(np.float16)
np.savez_compressed(OUT_DIR / "embeddings_fp16.npz",
                    embeddings=emb_fp16,
                    plant_ids=plant_ids)

meta = {
    "model_id": MODEL_ID,
    "model_local_dir": str(MODEL_DIR.as_posix()),
    "embedding_dim": int(embedding_dim),
    "num_classes": int(embeddings.shape[0]),
    "num_augs_per_class": NUM_AUGS,
    "built_at": time.strftime("%Y-%m-%d %H:%M:%S"),
    "device_used": device,
    "seed": SEED,
    "preprocess": {
        "center_crop": True,
        "note": "Aligned with CLIPProcessor defaults; online queries must match"
    },
    "similarity": "cosine (via dot on L2-normalized vectors)"
}
with open(OUT_DIR / "meta.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)

print("Saved:", (OUT_DIR / "embeddings_fp16.npz").as_posix())
print("Saved:", (OUT_DIR / "meta.json").as_posix())
print("Model cached at:", MODEL_DIR.as_posix())

Saved: index/embeddings_fp16.npz
Saved: index/meta.json
Model cached at: clip-vit-b32


## Step 8 - Reload and Sanity Check

In [8]:

# Load the saved index and check shapes/norms
npz = np.load(OUT_DIR / "embeddings_fp16.npz")
E = npz["embeddings"].astype(np.float32)  # (C, D)
P = npz["plant_ids"]

print("Loaded embeddings:", E.shape, " Loaded plant_ids:", P.shape)
print("Mean L2 norm:", np.linalg.norm(E, axis=1).mean())
print("First 5 plant_ids:", P[:5])

Loaded embeddings: (1319, 512)  Loaded plant_ids: (1319,)
Mean L2 norm: 1.0000056
First 5 plant_ids: [   1   10  100 1000 1001]
