A script that embeds the text prompt and the images of every concept with CLIP and saves the resulting embeddings to the project's Google Drive folder.

In [2]:
import os, numpy as np, torch
from PIL import Image

# Helper
def mean_pool(folder, model, preprocess, *, dim=None, batch_size=64):
    """Return the unit-length mean CLIP image embedding of the JPEGs in *folder*.

    Every ``.jpg``/``.jpeg`` file (case-insensitive) in *folder* is passed
    through *preprocess* and ``model.encode_image``. Each embedding is
    L2-normalised, the embeddings are averaged, and the mean is L2-normalised
    again.

    Args:
        folder: Directory containing the images (``str`` or ``Path``).
        model: CLIP model exposing ``encode_image``.
        preprocess: Callable mapping a PIL image to an image tensor, e.g. the
            transform returned alongside *model* by ``clip.load``.
        dim: Length of the all-zeros vector returned when *folder* holds no
            images. Defaults to ``model.visual.output_dim`` if the model exposes
            it, else 512.
        batch_size: Maximum number of images encoded per forward pass.

    Returns:
        A float32 numpy vector. It is all zeros (length *dim*) when no images
        are found, otherwise the re-normalised mean embedding.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    from pathlib import Path

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    folder = Path(folder)
    # Sorted so the (floating-point) mean is reproducible across filesystems.
    names = sorted(
        f for f in os.listdir(folder) if f.lower().endswith((".jpg", ".jpeg"))
    )
    if not names:
        if dim is None:
            # A hard-coded width breaks np.stack downstream when the model is
            # not 512-d (ViT-L/14 is 768-d); ask the model for its width.
            dim = getattr(getattr(model, "visual", None), "output_dim", 512)
        return np.zeros(int(dim), dtype=np.float32)

    # Inputs must live on the same device as the model's weights.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(names), batch_size):
            tensors = []
            for name in names[start:start + batch_size]:
                # Context manager releases the file handle after preprocessing.
                with Image.open(folder / name) as img:
                    tensors.append(preprocess(img))
            batch = torch.stack(tensors).to(device)
            # encode_image may return fp16 (e.g. on CUDA); normalise in fp32.
            vecs = model.encode_image(batch).float()
            chunks.append(vecs / vecs.norm(dim=-1, keepdim=True))

    mean = torch.cat(chunks).mean(dim=0)
    return (mean / mean.norm()).cpu().numpy().astype(np.float32)


In [None]:
from pathlib import Path
import numpy as np, torch, clip
from tqdm import tqdm

# -------------------------------------------------------------------------
drive_root = Path(r"G:/")
PROJ_DIR   = drive_root / ".shortcut-targets-by-id/1CwmFOsYFnq6t33KAzpvw0gaOTQXbcozs/brain-decoder-files"

# Explicit raise (unlike `assert`) still fires when Python runs with -O.
if not PROJ_DIR.is_dir():
    raise FileNotFoundError(f"{PROJ_DIR} does not exist – check the path!")

images_dir   = PROJ_DIR / "experiment-images"
text_npz     = PROJ_DIR / "clip_text_embeddings.npz"
image_npz    = PROJ_DIR / "clip_image_embeddings.npz"
concepts_txt = PROJ_DIR / "concepts.txt"
TEXT_BATCH   = 256  # prompts per encode_text call; bounds peak GPU memory
# -------------------------------------------------------------------------

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14@336px", device=device)

# One concept per line; text after '#' is treated as a comment, as
# np.genfromtxt did. genfromtxt also split on whitespace, so a multi-word
# concept ("ice cream") became several columns and broke both the prompt and
# the image-folder path; a one-line file yielded an un-iterable 0-d array.
concepts = [
    line.split("#", 1)[0].strip()
    for line in concepts_txt.read_text(encoding="utf-8").splitlines()
]
concepts = [c for c in concepts if c]
if not concepts:
    raise ValueError(f"No concepts found in {concepts_txt}")

# ---- text embeddings -----------------------------------------------------
# Encode in chunks so a long concept list cannot exhaust GPU memory; cast to
# fp32 before normalising because encode_text may return fp16 on CUDA.
txt_chunks = []
with torch.no_grad():
    for start in range(0, len(concepts), TEXT_BATCH):
        toks = clip.tokenize(
            [f"A photo of {c}" for c in concepts[start:start + TEXT_BATCH]]
        ).to(device)
        emb = model.encode_text(toks).float()
        txt_chunks.append(emb / emb.norm(dim=-1, keepdim=True))
txt = torch.cat(txt_chunks)
# Concept names are stored next to the vectors so row order is
# self-describing; readers that only load "data" are unaffected.
np.savez_compressed(
    text_npz,
    data=txt.cpu().numpy().astype(np.float32),
    concepts=np.array(concepts),
)
print("Text vectors saved to", text_npz)

# ---- image embeddings ----------------------------------------------------
imgs = [mean_pool(images_dir / c, model, preprocess) for c in tqdm(concepts, desc="images")]
# mean_pool returns an all-zeros vector for folders without images; report
# them rather than silently saving meaningless rows.
missing = [c for c, v in zip(concepts, imgs) if not np.any(v)]
if missing:
    print(f"Warning: {len(missing)} concept(s) have no images:", missing)
# np.stack(..., dtype=...) requires NumPy >= 1.24; astype works everywhere.
np.savez_compressed(
    image_npz,
    data=np.stack(imgs).astype(np.float32),
    concepts=np.array(concepts),
)
print("Image vectors saved to", image_npz)
