## Extract CLIP Image and Text Features from the Captioned Dataset

In [1]:
import os
import json
import torch
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms

# === CONFIG ===
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dataset root. Each split is read from <DATASET_DIR>/<split>/images and
# <DATASET_DIR>/<split>/captions.json by the extraction loop.
DATASET_DIR = "dataset_cap"
SPLITS = ["train", "val"]

# Destination for the per-split "<split>_features.pt" files.
OUTPUT_DIR = "clip_features"

# Create the destination up front; a pre-existing directory is not an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# === LOAD CLIP ===
# Single source of truth for the checkpoint id, so the model weights and the
# processor that prepares text/images for them can never drift apart.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
# This notebook only runs inference (the forward pass is wrapped in
# torch.no_grad()), so make evaluation mode explicit rather than implied.
model.eval()
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
def load_image(path):
    """Open the image at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path of the image file to read.

    Returns:
        The result of ``Image.convert("RGB")`` on the opened image.

    Raises:
        Whatever ``Image.open`` raises for a missing or unreadable file.
    """
    # Image.open is lazy and keeps the file open. Pillow only closes it
    # automatically for single-frame images, so multi-frame files would leak
    # one handle per call. The context manager closes the file in every case;
    # convert() has already produced the image we return by then.
    with Image.open(path) as source:
        return source.convert("RGB")

In [4]:
# For every split, compute one CLIP image embedding and one CLIP text embedding
# per captioned image and save them, together with labels and filenames, to
# <OUTPUT_DIR>/<split>_features.pt.
for split in SPLITS:
    print(f"Processing split: {split}")
    image_dir = os.path.join(DATASET_DIR, split, "images")
    caption_path = os.path.join(DATASET_DIR, split, "captions.json")
    output_path = os.path.join(OUTPUT_DIR, f"{split}_features.pt")

    # captions.json maps an image filename to a record holding "captions"
    # (a string or a list of strings) and optionally "category_id".
    # Read it as UTF-8 explicitly so non-ASCII captions do not depend on the
    # platform's default encoding.
    with open(caption_path, "r", encoding="utf-8") as f:
        captions_data = json.load(f)

    img_features = []
    txt_features = []
    labels = []
    filenames = []
    missing_files = []  # entries listed in captions.json with no image on disk

    for fname, info in tqdm(captions_data.items(), desc=split):
        img_path = os.path.join(image_dir, fname)
        if not os.path.exists(img_path):
            # Still skipped, but remembered so the omission is reported below.
            missing_files.append(fname)
            continue

        image = load_image(img_path)

        # When several captions are available only the first one is embedded.
        caption = info["captions"][0] if isinstance(info["captions"], list) else info["captions"]
        # -1 marks entries that carry no category.
        category_id = info.get("category_id", -1)

        # truncation=True clips over-long captions to the tokenizer's maximum
        # length instead of letting the text encoder fail on them; captions
        # that already fit are tokenized exactly as before.
        inputs = processor(
            text=caption,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs)
            # Batch size is 1, so take row 0 and move it off the device so the
            # accumulated features live in host memory.
            img_feat = outputs.image_embeds[0].cpu()
            txt_feat = outputs.text_embeds[0].cpu()

        img_features.append(img_feat)
        txt_features.append(txt_feat)
        labels.append(category_id)
        filenames.append(fname)

    if missing_files:
        print(f"Skipped {len(missing_files)} of {len(captions_data)} entries with no image file in {image_dir}")

    # torch.stack rejects an empty list with an opaque message; fail with the
    # same exception type but say what actually went wrong.
    if not img_features:
        raise RuntimeError(
            f"No image listed in {caption_path} was found under {image_dir}; "
            f"nothing to save for split '{split}'."
        )

    # Save tensors
    torch.save({
        "image_features": torch.stack(img_features),
        "text_features": torch.stack(txt_features),
        "labels": torch.tensor(labels),
        "filenames": filenames
    }, output_path)

    print(f"Saved features for {split} to {output_path}")

Processing split: train


100%|██████████| 7919/7919 [02:27<00:00, 53.63it/s]


Saved features for train to clip_features/train_features.pt
Processing split: val


100%|██████████| 1985/1985 [00:36<00:00, 54.34it/s]

Saved features for val to clip_features/val_features.pt



