# 植物检索建库（CLIP ViT-B/32）

## Cell 1 — 安装依赖（已装可跳过）

In [None]:
# # 如果你的环境里已经有这些库，可以跳过本单元
# # 注意：PyTorch 请按照你的 CUDA 版本安装（https://pytorch.org/），这里给出通用命令
# %pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
# %pip install -q transformers pillow numpy tqdm safetensors

## Cell 2 — 导入与配置

In [44]:
from pathlib import Path
import re, json, time, random, math
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm

import torch
from torch import nn
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor

# ======== Paths and parameters ========
# Scripts have __file__; Jupyter kernels do not, so fall back to the CWD there.
if "__file__" in globals():
    NOTEBOOK_DIR = Path(__file__).parent
else:
    NOTEBOOK_DIR = Path().resolve()
DATA_DIR = (NOTEBOOK_DIR / "../../01_data_wrangling/01_raw_data/05_thumbnail_image").resolve()

OUT_DIR = Path("index")                    # output dir for embeddings + meta (relative to CWD)
MODEL_ID = "openai/clip-vit-base-patch32"  # CLIP ViT-B/32
MODEL_DIR = Path("clip-vit-b32")           # local copy for offline use (relative to CWD)
NUM_AUGS = 20                              # augmented copies per class, original excluded
BATCH_SIZE = 32                            # images per encoder batch
SEED = 42

OUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Thumbnail file-name pattern; capture group 1 is the numeric plant_id.
FILE_RE = re.compile(
    r"plant_species_thumbnail_image_(\d+)\.jpg",
    flags=re.IGNORECASE,
)

def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

set_seed(SEED)
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # display the selected device as the cell output

'cuda'

In [45]:
# Quick sanity check of the thumbnail directory before indexing.
print("目录存在？", DATA_DIR.exists())
_jpgs = list(DATA_DIR.glob("*.jpg"))  # glob once, reuse for count and preview
print("找到的jpg数：", len(_jpgs))
print("前5个文件：", [jpg.name for jpg in _jpgs[:5]])


目录存在？ True
找到的jpg数： 541
前5个文件： ['plant_species_thumbnail_image_1.jpg', 'plant_species_thumbnail_image_10.jpg', 'plant_species_thumbnail_image_100.jpg', 'plant_species_thumbnail_image_101.jpg', 'plant_species_thumbnail_image_102.jpg']


## Cell 3 — 工具函数（EXIF 修正、轻量增强、文件扫描）

In [48]:
def exif_correct(img: Image.Image) -> Image.Image:
    """Return *img* rotated per its EXIF orientation and converted to RGB.

    Images that carry transparency are composited onto a white background.
    This covers "RGBA", "LA", and palette ("P") images with a
    ``transparency`` entry. Transparent regions therefore become white
    instead of whatever colour the hidden pixel data holds. A plain
    ``convert("RGB")`` would simply drop the alpha channel for those modes.
    All other non-RGB modes are converted with ``convert("RGB")``.
    """
    img = ImageOps.exif_transpose(img)
    has_alpha = img.mode in ("RGBA", "LA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if has_alpha:
        # Normalise every alpha-bearing mode to RGBA so the last band is alpha.
        rgba = img.convert("RGBA")
        bg = Image.new("RGB", rgba.size, (255, 255, 255))
        bg.paste(rgba, mask=rgba.split()[-1])
        return bg
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img

def light_aug_pipeline():
    """
    Build a deliberately light augmentation pipeline.

    Augmentations are kept mild so the averaged class prototype does not
    drift away from the original image:
    - random horizontal flip (p=0.5)
    - small brightness / contrast / saturation / hue jitter
    - slight rotation (±15°), translation (±10%) and scaling (0.9-1.1)
    - occasional (p=0.2) mild Gaussian blur
    No ``fill`` is passed to RandomAffine, so uncovered border pixels use
    torchvision's default fill.
    """
    mild_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
    steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02),
        transforms.RandomAffine(degrees=15, translate=(0.10, 0.10), scale=(0.9, 1.1)),
        transforms.RandomApply([mild_blur], p=0.2),
    ]
    return transforms.Compose(steps)

def parse_plant_images(data_dir: Path):
    """Collect ``(plant_id, path)`` pairs for thumbnails in *data_dir*.

    Only regular files whose whole name matches ``FILE_RE`` are kept.
    ``fullmatch`` is used because ``match`` only anchors at the start and
    would also accept names such as "..._3.jpg.bak". Results are in sorted
    path (file-name) order, not numeric plant_id order.
    """
    pairs = []
    for path in sorted(data_dir.glob("*")):
        if not path.is_file():
            continue
        m = FILE_RE.fullmatch(path.name)
        if m:
            pairs.append((int(m.group(1)), path))
    return pairs

pairs = parse_plant_images(DATA_DIR)
print(f"发现 {len(pairs)} 个类别（每类 1 张缩略图）")
# Raise explicitly rather than `assert`: asserts are stripped under `python -O`,
# and an empty index must never be built silently.
if not pairs:
    raise RuntimeError("未发现符合命名规则的图片，请检查目录与文件名格式。")

发现 541 个类别（每类 1 张缩略图）


## Cell 4 — 加载 CLIP 模型与处理器（并离线保存）

In [49]:
print("加载 CLIP 模型与处理器…")
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model.eval().to(device)

# Save a local copy for offline use / deployment. meta.json records MODEL_DIR as
# "model_local_dir", so the weights must actually live there, not only the
# processor files. The (large) weight write is skipped when a previous run
# already saved the model config there.
if not (MODEL_DIR / "config.json").exists():
    model.save_pretrained(MODEL_DIR)
processor.save_pretrained(MODEL_DIR)

embedding_dim = model.config.projection_dim  # 512 for ViT-B/32
embedding_dim

加载 CLIP 模型与处理器…


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


512

## Cell 5 — 批量编码函数（GPU 混合精度）

In [50]:
@torch.no_grad()
def encode_images(pil_list, batch_size=32, amp_dtype=torch.float16):
    """Encode PIL images into L2-normalised CLIP image embeddings.

    Preprocessing is delegated to the module-level ``processor``
    (CLIPProcessor) with center-cropping forced on. The index therefore uses
    the same resize / center-crop / normalize steps as official CLIP
    inference. The forward pass uses the module-level ``model`` on
    ``device``. Autocast with ``amp_dtype`` is enabled only when
    ``device == "cuda"``.

    Args:
        pil_list: sequence of PIL images (presumably RGB, e.g. after exif_correct).
        batch_size: images per forward pass; must be >= 1.
        amp_dtype: autocast dtype used on CUDA.

    Returns:
        A float32 CPU tensor of shape (N, D) with unit-norm rows, one per
        image. An empty ``pil_list`` yields a (0, D) tensor.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    embs = []
    for start in range(0, len(pil_list), batch_size):
        batch = pil_list[start:start + batch_size]
        # `padding` is a text-tokenizer option; image-only calls do not need it.
        inputs = processor(images=batch, return_tensors="pt", do_center_crop=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.autocast(device_type=("cuda" if device == "cuda" else "cpu"),
                            dtype=amp_dtype, enabled=(device == "cuda")):
            feats = model.get_image_features(**inputs)  # (B, D)
        # Under CUDA autocast the features come back in float16. Upcast before
        # normalising so the norm is accurate and the output dtype is float32
        # on every device.
        feats = nn.functional.normalize(feats.float(), p=2, dim=-1)
        embs.append(feats.cpu())
    if not embs:
        # torch.cat([]) raises; return an empty (0, D) matrix instead.
        return torch.empty((0, model.config.projection_dim), dtype=torch.float32)
    return torch.cat(embs, dim=0)  # (N, D)

## Cell 6 — 构建均值原型并缓存（主循环）

In [51]:
aug = light_aug_pipeline()

# Collect into lists instead of preallocating: unreadable images are skipped,
# so the final number of classes is not known up front.
proto_list = []
plant_id_list = []

pbar = tqdm(pairs, desc="编码各类别（含增强）")

for plant_id, img_path in pbar:
    # Open AND decode inside the try. Image.open() is lazy (it parses only the
    # header), so truncated or corrupt pixel data would otherwise raise later
    # in exif_correct() and abort the whole build instead of being skipped.
    # The `with` block also releases the file handle promptly.
    try:
        with Image.open(img_path) as raw:
            img = exif_correct(raw)
            img.load()
    except Exception as e:  # best-effort: log and skip any unreadable image
        tqdm.write(f"[WARN] 跳过损坏图片 {img_path}: {e}")
        continue

    # The original image plus NUM_AUGS augmented views of it.
    pil_list = [img] + [aug(img) for _ in range(NUM_AUGS)]

    # Average in float32 (the encoder may return float16 under CUDA autocast),
    # then re-normalise so the prototype has unit length.
    feats = encode_images(pil_list, batch_size=BATCH_SIZE).float()  # (N, D)
    mean_proto = feats.mean(dim=0, keepdim=True)
    mean_proto = nn.functional.normalize(mean_proto, p=2, dim=-1).cpu().numpy()[0]  # (D,)

    proto_list.append(mean_proto.astype(np.float32))
    plant_id_list.append(plant_id)

# np.vstack([]) fails with an opaque error; report the real problem instead.
if not proto_list:
    raise RuntimeError("没有成功编码任何图片，无法构建索引。")

# Assemble the (C, D) prototype matrix and the aligned id vector.
embeddings = np.vstack(proto_list).astype(np.float32)  # (C, D)
plant_ids = np.array(plant_id_list, dtype=np.int64)

print("embeddings 形状：", embeddings.shape)
print("示例 L2 范数（应接近 1）：", np.linalg.norm(embeddings[0]))

编码各类别（含增强）: 100%|██████████| 541/541 [04:40<00:00,  1.93it/s]

embeddings 形状： (541, 512)
示例 L2 范数（应接近 1）： 0.99983567





## Cell 7 — 保存索引（fp16 压缩）与 Meta

In [52]:
# Store embeddings as float16 to halve disk usage; readers are expected to
# cast back to float32 before computing similarities.
emb_fp16 = embeddings.astype(np.float16)
index_file = OUT_DIR / "embeddings_fp16.npz"
meta_file = OUT_DIR / "meta.json"

np.savez_compressed(index_file, embeddings=emb_fp16, plant_ids=plant_ids)

# Build metadata describing how the index was produced (key order is preserved in the JSON).
meta = dict(
    model_id=MODEL_ID,
    model_local_dir=MODEL_DIR.as_posix(),
    embedding_dim=int(embedding_dim),
    num_classes=int(embeddings.shape[0]),
    num_augs_per_class=NUM_AUGS,
    built_at=time.strftime("%Y-%m-%d %H:%M:%S"),
    device_used=device,
    seed=SEED,
    preprocess=dict(
        center_crop=True,
        note="与 CLIPProcessor 默认预处理保持一致；在线查询需完全一致",
    ),
    similarity="cosine (via dot on L2-normalized vectors)",
)
meta_file.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")

for saved in (index_file, meta_file):
    print("✅ 已保存：", saved.as_posix())
print("✅ 模型保存在：", MODEL_DIR.as_posix())

✅ 已保存： embeddings_fp16.npz
✅ 已保存： meta.json
✅ 模型保存在： clip-vit-b32


## Cell 8 — 快速自检（可选）

In [53]:
# Reload the index just written and sanity-check shapes and norms.
# np.load returns an NpzFile that keeps the archive open (and locked on
# Windows) until closed, so scope it with `with`; the arrays read inside
# stay valid after the file is closed.
with np.load(OUT_DIR / "embeddings_fp16.npz") as npz:
    E = npz["embeddings"].astype(np.float32)  # (C, D), widened from stored float16
    P = npz["plant_ids"]

print("载入 embeddings：", E.shape, " 载入 plant_ids：", P.shape)
print("平均范数：", np.linalg.norm(E, axis=1).mean())
print("展示前 5 个 plant_id：", P[:5])

载入 embeddings： (541, 512)  载入 plant_ids： (541,)
平均范数： 0.999998
展示前 5 个 plant_id： [  1  10 100 101 102]
