## Step 2: Image Embedding with CLIP
This notebook generates image embeddings with the CLIP model (`openai/clip-vit-base-patch32`). It picks one image URL per item, downloads and preprocesses the images, and extracts an embedding for each one for downstream tasks such as similarity search and clustering.

- Load review data and item metadata, and pick one image URL per item
- Preprocess images for CLIP
- Generate and save image embeddings
- Discuss best practices and troubleshooting tips

## CLIP Model Setup & Image Selection
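Load the CLIP model and processor and select the device (GPU if available). Then load the review and metadata tables and use the imported `pick_best_image_from_images_field` helper to select the best image URL for each row.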

In [4]:
import io, time, math, requests, sys, os
from PIL import Image
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))
from src.data_utils import preprocess_df
from src.clip_processing import pick_best_image_from_images_field

# Single source of truth for the checkpoint: the model and its processor must
# come from the same checkpoint, so both are loaded from this one constant.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

# Prefer the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
# eval() puts the model in inference mode (disables training-only layers such as dropout).
model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(device).eval()
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)

# Raw review and metadata tables; preprocess_df combines them.
# NOTE(review): the `images_y` column used below suggests preprocess_df does a
# pandas merge of reviews with metadata — confirm in src.data_utils.
df_review = pd.read_parquet("../data/CDs_and_Vinyl_reviews.parquet")
df_meta = pd.read_parquet("../data/CDs_and_Vinyl_meta.parquet")
df = preprocess_df(df_review, df_meta)

# One candidate image URL per row, chosen from the metadata `images_y` field.
df['img_url'] = df['images_y'].apply(pick_best_image_from_images_field)

Using device: cuda


## Image Downloading, Embedding, and Saving
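Download the images, convert them to PIL, embed them with CLIP, and save the embeddings in chunked parquet parts. Parts that already exist on disk are skipped, so re-running the cell resumes an interrupted job. Finally, merge all parts and join the embeddings back onto the main DataFrame. Note: this cell reloads `df` from `../data_CDs_and_Vinyl_prepared.parquet` instead of reusing the `df` built in the previous cell, so that file must already exist.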

In [None]:
import os, io, gc, requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))
from src.clip_processing import get_session, fetch_image_bytes, to_pil, embed_pil_batch

# ================================================================
# CONFIGURATION
# ================================================================
SAMPLE = 50000      # max unique items to embed; set to None to embed every item
CHUNK = 5000        # items per saved part file (also the resume granularity)
N_WORKERS = 24      # concurrent download threads
CLIP_BATCH = 32     # images per embed_pil_batch call
IMG_CACHE = "../img_cache"   # cache directory handed to fetch_image_bytes
SAVE_DIR = "../emb_parts"    # where the per-chunk parquet parts are written
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(IMG_CACHE, exist_ok=True)
CATEGORY = "CDs_and_Vinyl"

# ================================================================
# REQUEST SESSION
# ================================================================
# Shared by all download threads.
# NOTE(review): presumably a pooled requests.Session — confirm in src.clip_processing.
SESSION = get_session()

# ================================================================
# PREPARE DATAFRAME SUBSET
# ================================================================
df = pd.read_parquet(f"../data_{CATEGORY}_prepared.parquet")
mask = df["img_url"].notna()
# One row per item: embeddings are computed per parent_asin, not per review.
work = df.loc[mask, ["parent_asin", "img_url"]].drop_duplicates("parent_asin").copy()
if SAMPLE is not None and len(work) > SAMPLE:
    # Fixed seed keeps the sample (and therefore each part's contents) stable
    # across runs, which the skip-existing-parts resume logic relies on.
    work = work.sample(SAMPLE, random_state=42)
# Always use a clean 0..n-1 index, whether or not we sampled.
work = work.reset_index(drop=True)
print("Total images to process:", len(work))

# ================================================================
# MAIN PROCESS LOOP
# ================================================================
def part_path(pi):
    """Return the path of the parquet file that stores embedding part ``pi``.

    Files live in SAVE_DIR and carry a zero-padded three-digit index,
    e.g. ``clip_img_emb_parent.part007.parquet`` for ``pi=7``.
    """
    filename = f"clip_img_emb_parent.part{pi:03d}.parquet"
    return os.path.join(SAVE_DIR, filename)

# Number of CHUNK-sized parts (ceiling division); the last part may be shorter.
num_parts = (len(work) + CHUNK - 1) // CHUNK

# Each part goes download -> PIL -> embed -> save and is written to its own
# parquet file, so re-running the cell resumes an interrupted job: parts whose
# file already exists are skipped.
for pi in range(num_parts):
    out_path = part_path(pi)
    if os.path.exists(out_path):
        print(f"[SKIP] Part {pi} exists -> {out_path}")
        continue

    start, stop = pi * CHUNK, min(len(work), (pi + 1) * CHUNK)
    sub = work.iloc[start:stop].reset_index(drop=True)
    print(f"\n[Part {pi}] rows {start}:{stop} ({len(sub)})")

    # ---------- Download ----------
    # Results are stored by position so bytes_list[i] lines up with sub row i,
    # regardless of the order in which downloads complete.
    bytes_list = [None] * len(sub)
    n_failed, first_error = 0, None
    with ThreadPoolExecutor(max_workers=N_WORKERS) as ex:
        futs = {
            ex.submit(fetch_image_bytes, url, session=SESSION,
                      fname=sub.loc[i, "parent_asin"], cache_dir=IMG_CACHE): i
            for i, url in enumerate(sub["img_url"])
        }
        for fut in tqdm(as_completed(futs), total=len(futs), desc=f"Downloading P{pi}"):
            i = futs[fut]
            try:
                bytes_list[i] = fut.result()
            except Exception as err:
                # fut.result() re-raises whatever the worker raised. One bad
                # URL must not abort the whole part: leave None (dropped
                # downstream like any other failed image) and report below.
                n_failed += 1
                if first_error is None:
                    first_error = repr(err)
    # Every Future keeps a reference to its result bytes; release them now
    # rather than holding a whole chunk of images until the next iteration.
    del futs
    if n_failed:
        print(f"[Part {pi}] {n_failed} downloads raised and were skipped (first: {first_error})")

    # ---------- Convert to PIL ----------
    pil_list = [to_pil(b) for b in bytes_list]
    ok_rate = np.mean([p is not None for p in pil_list])
    print(f"[Part {pi}] PIL OK rate: {ok_rate:.1%}")

    # ---------- Embed ----------
    # NOTE(review): assumes embed_pil_batch returns one entry per input, with
    # None where no embedding was produced — confirm in src.clip_processing.
    embs = embed_pil_batch(pil_list, processor, model, device, batch_size=CLIP_BATCH)

    # ---------- Save partial parquet ----------
    rows = [(str(sub.loc[i, "parent_asin"]), e.tolist()) for i, e in enumerate(embs) if e is not None]
    if rows:
        part_df = pd.DataFrame(rows, columns=["parent_asin", "clip_img_emb"])
        part_df.to_parquet(out_path, index=False, compression="zstd")
        print(f"[Part {pi}] Saved {len(part_df)} embeddings -> {out_path}")
        del part_df
    else:
        # An empty part file would satisfy the skip check above forever, so a
        # transient failure (e.g. a network outage) would never be retried.
        # Write nothing so the next run attempts this part again.
        print(f"[Part {pi}] No embeddings produced; not saving so the part is retried next run")

    # ---------- Cleanup ----------
    del bytes_list, pil_list, embs, rows, sub
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n✅ All parts done (or skipped).")

# ================================================================
# MERGE ALL PARTS
# ================================================================
parts = [part_path(pi) for pi in range(num_parts) if os.path.exists(part_path(pi))]
if parts:
    merged_emb = pd.concat([pd.read_parquet(p) for p in parts], ignore_index=True)
    merged_emb = merged_emb.drop_duplicates("parent_asin")
    merged_emb.to_parquet("../clip_img_emb_parent.parquet", index=False, compression="zstd")
    print("Final merged embeddings:", len(merged_emb))

    # ================================================================
    # MERGE BACK TO MAIN DF
    # ================================================================
    # Cast both join keys to str so a dtype mismatch cannot make the join miss.
    df["parent_asin"] = df["parent_asin"].astype(str)
    merged_emb["parent_asin"] = merged_emb["parent_asin"].astype(str)
    # Left join keeps every row of df; items without an embedding get NaN.
    df = df.drop(columns=["clip_img_emb"], errors="ignore").merge(merged_emb, on="parent_asin", how="left")

    # Embeddings read back from parquet arrive as numpy arrays (pyarrow turns
    # list columns into ndarrays), not Python lists. Checking only for `list`
    # would replace every real embedding with zeros, so accept any of these.
    _EMB_TYPES = (list, tuple, np.ndarray)
    emb_dim = model.config.projection_dim

    def _as_emb_vector(v):
        """Return ``v`` as a float32 vector, or a zero vector of the CLIP projection size if ``v`` is missing."""
        if isinstance(v, _EMB_TYPES):
            return np.array(v, dtype=np.float32)
        return np.zeros(emb_dim, dtype=np.float32)

    # Derive the flag with the same predicate, so it always agrees with
    # which rows kept a real vector.
    df["has_img_emb"] = df["clip_img_emb"].apply(lambda v: isinstance(v, _EMB_TYPES))
    df["clip_img_emb"] = df["clip_img_emb"].apply(_as_emb_vector)

    df.to_parquet(f"../reviews_with_img_emb_{CATEGORY}.parquet", index=False, engine="pyarrow", compression="zstd")
    print(" Saved final df:", f"../reviews_with_img_emb_{CATEGORY}.parquet")
else:
    print("No embedding parts were created.")