# Phase 0-4: Feature Extraction

Extract CLIP features for all images in the dataset and save them to features.pkl.

In [None]:
import os
import pickle
from pathlib import Path
from datetime import datetime

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch

# Sanity check: report the resolved working directory and whether PyTorch
# can see a CUDA device before any heavy work starts.
print(f"Imports OK. CWD: {Path().resolve()}")
print(f"GPU available: {torch.cuda.is_available()}")

In [None]:
# CLIP import helper.
# The import is allowed to fail here: when the package is missing, `clip` is
# bound to None and install instructions are printed instead of raising.
# load_clip_model() checks for None and raises ImportError at the point where
# the model is actually needed.
try:
    import clip  # OpenAI CLIP: pip install git+https://github.com/openai/CLIP.git
except ImportError:
    clip = None
    print("\n[warning] `clip` package not found. Install with:")
    print("  pip install git+https://github.com/openai/CLIP.git")

In [None]:
# Paths are machine-specific. Adjust PROJECT_ROOT before running on a new machine.
PROJECT_ROOT = "/Users/tyreecruse/Desktop/CS230/Project/Data/Original"
# Or just use the current directory:
# PROJECT_ROOT = os.getcwd()

# Settings read by run_feature_extraction() and the verification cell.
CONFIG = {
    # Dataset root; images are expected in an "images" sub-directory.
    "dataset_path": os.path.join(PROJECT_ROOT, "master_dataset_pool"),
    # Destination of the pickled feature dictionary.
    "output_path": os.path.join(PROJECT_ROOT, "features.pkl"),

    # Model name handed to clip.load().
    "model_name": "ViT-B/32",
    # Number of image paths handed to each extract_features_for_batch() call.
    "batch_size": 512,
    # When True, each feature row is divided by its L2 norm.
    "normalize": True,
}

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Config")
print("------")
print(f"dataset_path: {CONFIG['dataset_path']}")
print(f"output_path : {CONFIG['output_path']}")
print(f"model_name  : {CONFIG['model_name']}")
print(f"batch_size  : {CONFIG['batch_size']}")
print(f"device      : {device}")

dataset_path = Path(CONFIG["dataset_path"])
images_dir = dataset_path / "images"

if not images_dir.exists():
    print(f"\n[warning] images/ not found under {dataset_path}")
else:
    # Count only files with a recognized image extension (the same set that
    # list_image_files() accepts) so that stray files such as .DS_Store do not
    # inflate the reported number of images.
    _image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    n_images = sum(
        1
        for p in images_dir.iterdir()
        if p.is_file() and p.suffix.lower() in _image_exts
    )
    print(f"\nFound {n_images:,} image files under {images_dir}")

In [None]:
def list_image_files(images_dir):
    """Collect the image files located directly inside ``images_dir``.

    Only regular files whose suffix (compared case-insensitively) is one of
    .jpg, .jpeg, .png, .bmp, .tif or .tiff are kept. Sub-directories are not
    descended into.

    Args:
        images_dir: Directory to scan, as a string or ``Path``.

    Returns:
        A sorted list of ``Path`` objects.
    """
    allowed_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
    matches = []
    for entry in Path(images_dir).iterdir():
        if entry.suffix.lower() not in allowed_suffixes:
            continue
        if entry.is_file():
            matches.append(entry)
    matches.sort()
    return matches

In [None]:
def load_clip_model(model_name, device="cpu"):
    """Load a CLIP model together with its preprocessing transform.

    Args:
        model_name: Model identifier forwarded to ``clip.load``.
        device: Device forwarded to ``clip.load``.

    Returns:
        ``(model, preprocess)`` with the model switched to eval mode.

    Raises:
        ImportError: If the ``clip`` package could not be imported earlier.
    """
    if clip is not None:
        net, transform = clip.load(model_name, device=device)
        # Inference only: put the model in eval mode before handing it back.
        net.eval()
        return net, transform
    raise ImportError("clip package is not installed.")

In [None]:
def extract_features_for_batch(model, preprocess, device, paths):
    """Extract image features for one batch of image paths.

    Each path is opened with PIL, converted to RGB and passed through
    ``preprocess``. Images that cannot be opened or decoded are skipped with a
    printed warning, so the returned paths may be a subset of ``paths``.

    Args:
        model: Object exposing ``encode_image(batch)``.
        preprocess: Callable turning a PIL image into a tensor; the results
            are combined with ``torch.stack``.
        device: Device the stacked batch is moved to before encoding.
        paths: Iterable of image paths for this batch.

    Returns:
        ``(feats, valid_paths)``. ``feats`` is a NumPy array holding the
        encoder output for the images that loaded, or ``None`` when no image
        could be loaded. ``valid_paths`` lists the corresponding paths as
        strings, in the same order.
    """
    images = []
    valid_paths = []
    for p in paths:
        try:
            # The context manager closes the underlying file promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(p) as raw:
                img = raw.convert("RGB")
        except Exception as exc:
            # Best effort: one unreadable file must not abort the batch, but
            # report it so missing rows can be traced back to a file.
            print(f"[warning] skipping unreadable image {p}: {exc}")
            continue
        images.append(preprocess(img))
        valid_paths.append(str(p))

    if not images:
        return None, []

    batch = torch.stack(images).to(device)
    with torch.no_grad():
        feats = model.encode_image(batch)
        if feats.ndim > 2:
            # Collapse dim 1 so there is a single vector per image.
            # NOTE(review): presumably a token/patch axis - confirm.
            feats = feats.mean(dim=1)
        feats = feats.cpu().numpy()

    return feats, valid_paths

In [None]:
def process_all_images(image_files, model, preprocess, device, batch_size, normalize=True):
    """Run feature extraction over ``image_files`` in consecutive batches.

    Args:
        image_files: Sliceable sequence of image paths.
        model: Passed through to ``extract_features_for_batch``.
        preprocess: Passed through to ``extract_features_for_batch``.
        device: Passed through to ``extract_features_for_batch``.
        batch_size: Number of paths taken per batch.
        normalize: When true, divide every feature row by its L2 norm.

    Returns:
        ``(features, paths)``: a float32 array with the per-batch results
        concatenated along axis 0, and the list of paths that were actually
        processed. If no batch yields features, returns an empty ``(0, 0)``
        float32 array and an empty list.
    """
    feature_chunks = []
    kept_paths = []

    total = len(image_files)
    batch_starts = range(0, total, batch_size)
    for offset in tqdm(batch_starts, desc="batches"):
        # Slicing clips at the end of the sequence, so the last batch may be short.
        chunk = image_files[offset:offset + batch_size]
        chunk_feats, chunk_paths = extract_features_for_batch(model, preprocess, device, chunk)
        if chunk_feats is not None:
            feature_chunks.append(chunk_feats)
            kept_paths.extend(chunk_paths)

    if len(feature_chunks) == 0:
        return np.zeros((0, 0), dtype=np.float32), []

    features = np.concatenate(feature_chunks, axis=0).astype(np.float32)

    if normalize and features.size > 0:
        row_norms = np.linalg.norm(features, axis=1, keepdims=True)
        # The small epsilon keeps an all-zero row from dividing by zero.
        features = features / (row_norms + 1e-8)

    return features, kept_paths

In [None]:
def run_feature_extraction():
    """Run the CLIP extraction pipeline end to end and pickle the result.

    Reads the module-level ``images_dir``, ``CONFIG`` and ``device``: lists
    the image files, loads the model, extracts features for every image and
    writes a dictionary with the features and metadata to
    ``CONFIG["output_path"]``.

    Returns:
        The dictionary that was written to disk.

    Raises:
        FileNotFoundError: If ``images_dir`` does not exist.
        RuntimeError: If ``images_dir`` contains no image files.
    """
    if not images_dir.exists():
        raise FileNotFoundError(f"images/ directory not found at {images_dir}")

    image_files = list_image_files(images_dir)
    if len(image_files) == 0:
        raise RuntimeError(f"No image files found under {images_dir}")

    print(f"\nPreparing to extract features for {len(image_files):,} images.")

    model_name = CONFIG["model_name"]
    model, preprocess = load_clip_model(model_name, device=device)
    print("\nModel loaded.", "----------------", sep="\n")
    print(model)

    use_normalize = CONFIG["normalize"]
    features, paths = process_all_images(
        image_files,
        model,
        preprocess,
        device,
        CONFIG["batch_size"],
        normalize=use_normalize,
    )

    print(f"\nFeature matrix shape: {features.shape}")
    has_features = features.size > 0
    if has_features:
        # Quick sanity check on the first few rows.
        head_norms = np.linalg.norm(features[:5], axis=1)
        print("Sample norms (first 5 rows):", [f"{value:.3f}" for value in head_norms])

    if has_features:
        feature_dim = int(features.shape[1])
    else:
        feature_dim = 0

    data = {
        "features": features,
        "paths": paths,
        "model_name": model_name,
        "normalized": use_normalize,
        "device": device,
        "n_images": len(paths),
        "feature_dim": feature_dim,
        "created_at": datetime.now().isoformat(timespec="seconds"),
    }

    output_path = Path(CONFIG["output_path"])
    # Create the destination directory if it does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with output_path.open("wb") as sink:
        pickle.dump(data, sink)

    print(f"\nSaved features to: {output_path}")
    return data

In [None]:
# Run feature extraction when executed as a script.
# The returned dictionary is kept in `feature_data` at module level.
# NOTE(review): in a Jupyter kernel `__name__` is typically "__main__" as well,
# so running this cell presumably starts the full extraction - confirm that
# this is intended.
if __name__ == "__main__":
    feature_data = run_feature_extraction()

In [None]:
# Verification / quick inspection cell
# Reloads the pickle written by run_feature_extraction() and prints a short
# summary: array shape, number of paths, basic statistics and a few paths.
# NOTE: pickle.load can execute arbitrary code contained in the file; only
# load a features file that you produced yourself.
output_path = Path(CONFIG["output_path"])

if not output_path.exists():
    print(f"\nNo features file found at {output_path}")
else:
    with output_path.open("rb") as f:
        loaded = pickle.load(f)

    feats = np.asarray(loaded.get("features"))
    paths = loaded.get("paths", [])

    print(
        "\nVerification",
        "------------",
        f"features shape : {feats.shape}",
        f"#paths         : {len(paths)}",
        sep="\n",
    )

    if feats.size > 0:
        print(f"mean: {feats.mean():.4f}, std: {feats.std():.4f}")
        # Norms of the first few rows only.
        norms = np.linalg.norm(feats[:5], axis=1)
        print("sample norms   :", [f"{value:.3f}" for value in norms])

    print("\nSample image paths:")
    for p in paths[:3]:
        print(" -", p)