In [None]:
# ==============================================================================
# Caption Loading Utility
# ==============================================================================
from pathlib import Path
import json

def _norm(s: str) -> str:
    """Normalize for dedupe: collapse spaces + lowercase."""
    return " ".join(s.split()).strip().casefold()

def load_captions(source: str | None = None, folder: str | None = None) -> list[str]:
    """
    Load captions from a single JSON file (source) or from all JSON files in a folder.
    Dedupe while preserving original text & order of first occurrence.
    """
    if not source and not folder:
        raise ValueError("Provide either `source` (file) or `folder` (directory).")

    captions: list[str] = []

    if source:
        p = Path(source)
        if not p.exists():
            raise FileNotFoundError(f"Captions file not found: {p.resolve()}")
        with p.open("r", encoding="utf-8") as f:
            captions.extend(json.load(f))

    if folder:
        p = Path(folder)
        if not p.is_dir():
            raise FileNotFoundError(f"Folder not found: {p.resolve()}")
        for jf in sorted(p.glob("*.json")):
            with jf.open("r", encoding="utf-8") as f:
                captions.extend(json.load(f))

    # Dedupe by normalized key
    seen = set()
    unique = []
    for c in captions:
        if not isinstance(c, str):
            continue
        key = _norm(c)
        if key not in seen:
            seen.add(key)
            unique.append(c.strip())

    if not unique:
        raise ValueError("No captions loaded (file(s) empty?).")

    return unique

In [None]:
# ==============================================================================
# Define Path to Your Captions File
# ==============================================================================
# Point this to your JSON file containing the list of candidate captions.
# The demo cell passes it to load_captions(source=CAPTIONS_PATH), which reads
# the file as UTF-8 JSON and skips entries that are not strings. A relative
# path is resolved against the notebook's current working directory.
CAPTIONS_PATH = "captions_set_01.json"




In [None]:
# ==============================================================================
# Core Functions for Embedding and Randomized Diverse Selection
# ==============================================================================
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import torch
import random

def pick_caption_weighted_topk(image_features_np, caption_features_np, captions, top_k=5, temperature=1.0):
    """
    Pick one caption from the `top_k` captions most similar to the image.

    Candidates are ranked by cosine similarity to the image embedding. With a
    positive `temperature` the pick is sampled from a softmax over the
    candidates' similarities divided by `temperature`; with `temperature <= 0`
    the most similar caption is returned deterministically.

    Args:
        image_features_np: 2-D array of image embeddings; only the first row is used.
        caption_features_np: 2-D array of caption embeddings, indexed in
            parallel with `captions`.
        captions: Sequence of caption texts.
        top_k: Number of best-scoring captions to sample from. Values below 1
            are treated as 1; values above the number of captions use them all.
        temperature: Softmax temperature; larger values flatten the weights.

    Returns:
        `(chosen_caption, top_indices, top_scores)` where the two lists are
        ordered from most to least similar and index into `captions`.

    Raises:
        ValueError: If `captions` is empty.
    """
    if len(captions) == 0:
        raise ValueError("No captions to choose from.")

    # A non-positive k would turn the `[-k:]` slice below into "everything"
    # (`[-0:]` is `[0:]`), so clamp it to a single candidate.
    k = max(top_k, 1)

    sims = cosine_similarity(image_features_np, caption_features_np)[0]
    top_indices = np.argsort(sims)[-k:][::-1]
    top_scores = sims[top_indices]

    if temperature <= 0:
        # Greedy choice: dividing by a non-positive temperature is meaningless.
        chosen_idx = top_indices[0]
    else:
        scaled_scores = top_scores / temperature
        # Subtract the max before exponentiating for numerical stability.
        scaled_scores -= np.max(scaled_scores)
        exp_scores = np.exp(scaled_scores)
        probs = exp_scores / np.sum(exp_scores)
        chosen_idx = np.random.choice(top_indices, p=probs)
    return captions[int(chosen_idx)], top_indices.tolist(), top_scores.tolist()

def pick_diverse_caption(image_features_np, caption_features_np, captions, top_k=5, temperature=1.5, pool_size=30, diversity_threshold=0.97):
    """
    Select a randomized, mutually diverse set of captions, then sample one of them.

    Steps:
      1. Rank all captions by cosine similarity to the image and keep the best
         `pool_size` as the candidate pool.
      2. Seed the selection with the single best caption.
      3. Visit the rest of the pool in random order, keeping a candidate only
         if its cosine similarity to every caption kept so far is below
         `diversity_threshold`, until `top_k` captions are kept.
      4. Sort the kept captions by image similarity and sample one with a
         temperature-scaled softmax over those similarities.

    Randomness comes from both `random.shuffle` and `np.random.choice`, so
    both generators must be seeded for a reproducible result.

    Args:
        image_features_np: 2-D array of image embeddings; only the first row is used.
        caption_features_np: 2-D array of caption embeddings, indexed in
            parallel with `captions`.
        captions: Sequence of caption texts.
        top_k: Maximum number of diverse captions to keep (the best caption is
            always kept, so at least one is returned).
        temperature: Softmax temperature; larger values flatten the weights.
            A value `<= 0` returns the most similar kept caption deterministically.
        pool_size: Size of the initial candidate pool. Values below 1 are treated as 1.
        diversity_threshold: A candidate is rejected when its similarity to any
            already-kept caption is at or above this value.

    Returns:
        `(chosen_caption, indices, scores)` where the two lists describe the
        kept captions from most to least similar and index into `captions`.
        If there are no captions, `("No captions found", [], [])`.
    """
    # Check for empty input before the similarity call so the sentinel return is reachable.
    if len(captions) == 0 or len(caption_features_np) == 0:
        return "No captions found", [], []

    sims = cosine_similarity(image_features_np, caption_features_np)[0]

    # 1. Get a large pool of initial candidates.
    #    Clamp to >= 1: `[-0:]` would otherwise select the whole list.
    pool = max(pool_size, 1)
    initial_indices = np.argsort(sims)[-pool:][::-1]

    # 2. Start the diverse list with the absolute best caption
    diverse_indices = [initial_indices[0]]

    # 3. Shuffle the rest of the candidates to introduce randomness
    remaining_indices = list(initial_indices[1:])
    random.shuffle(remaining_indices)

    # 4. Iterate through the shuffled list to find other diverse candidates
    for idx in remaining_indices:
        if len(diverse_indices) >= top_k:
            break

        current_embedding = caption_features_np[idx].reshape(1, -1)
        selected_embeddings = caption_features_np[diverse_indices]

        # Check similarity against captions already selected
        similarity_to_selected = cosine_similarity(current_embedding, selected_embeddings)[0]

        if np.max(similarity_to_selected) < diversity_threshold:
            diverse_indices.append(idx)

    # 5. Re-sort the final diverse list by similarity for display
    final_scores = sims[diverse_indices]
    sorted_order = np.argsort(final_scores)[::-1]
    final_diverse_indices = np.array(diverse_indices)[sorted_order]
    final_diverse_scores = final_scores[sorted_order]

    # 6. Perform temperature-based weighted choice on the final sorted DIVERSE set
    if temperature <= 0:
        # Dividing by a non-positive temperature would yield inf/NaN weights;
        # fall back to the greedy pick instead.
        chosen_idx = final_diverse_indices[0]
    else:
        scaled_scores = final_diverse_scores / temperature
        # Subtract the max before exponentiating for numerical stability.
        scaled_scores -= np.max(scaled_scores)
        exp_scores = np.exp(scaled_scores)
        probs = exp_scores / np.sum(exp_scores)
        chosen_idx = np.random.choice(final_diverse_indices, p=probs)

    return captions[int(chosen_idx)], final_diverse_indices.tolist(), final_diverse_scores.tolist()

def compute_caption_embeddings(captions, clip_model, processor):
    """Embed every caption with the CLIP text tower and return the result as a numpy array."""
    clip_model.eval()

    # Tokenize all captions as one padded/truncated batch of PyTorch tensors.
    encoded = processor(text=captions, return_tensors="pt", padding=True, truncation=True)

    # Move each tensor onto the model's device.
    batch = {}
    for name, tensor in encoded.items():
        batch[name] = tensor.to(clip_model.device)

    with torch.no_grad():
        embeddings = clip_model.get_text_features(**batch)

    return embeddings.detach().cpu().numpy()

def compute_image_embedding(image, clip_model, processor):
    """Embed one PIL image with the CLIP vision tower and return the result as a numpy array."""
    clip_model.eval()

    # Preprocess into PyTorch tensors, then move the batch to the model's device.
    batch = processor(images=image, return_tensors="pt")
    batch = batch.to(clip_model.device)

    # Cast the pixels to the model's dtype so they match its weights.
    pixels = batch['pixel_values']
    batch['pixel_values'] = pixels.to(clip_model.dtype)

    with torch.no_grad():
        features = clip_model.get_image_features(**batch)

    return features.detach().cpu().numpy()

In [None]:
# ==============================================================================
# Main Demo - With Diversity Filtering
# ==============================================================================
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# 1) Define the model ID
MODEL_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# 2) Set up device and data type (half precision only when a GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# 3) Load the model and processor
print(f"Loading model: {MODEL_ID} onto {device.upper()}")
clip_model = CLIPModel.from_pretrained(MODEL_ID, torch_dtype=torch_dtype).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
print("Model loaded successfully.")

# 4) Load captions
candidate_captions = load_captions(source=CAPTIONS_PATH)

# 5) Compute caption embeddings
print("Computing caption embeddings...")
caption_embeds = compute_caption_embeddings(candidate_captions, clip_model, processor)
print(f"{len(caption_embeds)} caption embeddings computed.")

# 6) Selection settings -- tune these to trade relevance against variety
TOP_K = 9                   # maximum number of diverse candidates to keep
TEMPERATURE = 1.5           # higher values flatten the sampling weights
DIVERSITY_THRESHOLD = 0.95  # lower values force candidates to differ more from each other

# 7) Load your image
image_path = "tree.jpg"  # <-- Make sure this is your image file
try:
    image = Image.open(image_path).convert("RGB")
except FileNotFoundError:
    # Only the image load is guarded, so a missing file elsewhere in the
    # pipeline is not misreported as a missing image.
    print(f"\nERROR: Image file not found at '{image_path}'. Please check the path and try again.")
else:
    image_embed = compute_image_embedding(image, clip_model, processor)

    # 8) Pick a caption using the diversity-aware function
    chosen, top_idx, top_scores = pick_diverse_caption(
        image_embed,
        caption_embeds,
        candidate_captions,
        top_k=TOP_K,
        temperature=TEMPERATURE,
        diversity_threshold=DIVERSITY_THRESHOLD,
    )

    # 9) Display results
    print(f"\nChosen caption (diverse selection):\n-> {chosen}")

    # Report the number of candidates actually returned; it can be fewer than
    # TOP_K when the pool does not contain enough mutually diverse captions.
    print(f"\nTop-{len(top_idx)} DIVERSE candidates (sorted by similarity):")
    # Note: the indices in top_idx refer to positions in candidate_captions
    diverse_captions = [candidate_captions[i] for i in top_idx]
    for rank, (caption, score) in enumerate(zip(diverse_captions, top_scores), start=1):
        print(f"{rank}. {caption} (sim={score:.4f})")

Loading model: laion/CLIP-ViT-H-14-laion2B-s32B-b79K onto CPU
Model loaded successfully.
Computing caption embeddings...
600 caption embeddings computed.

Chosen caption (diverse selection):
-> Energy speaks.

Top-5 DIVERSE candidates (sorted by similarity):
1. Living between to‑do and ta‑da. (sim=0.2189)
2. Peace looks like this. (sim=0.1930)
3. Little by little becomes a lot. (sim=0.1868)
4. Keeping it real and really kind. (sim=0.1863)
5. Sky is the limit (sim=0.1853)
6. Energy speaks. (sim=0.1851)
7. Proof of life: this post. (sim=0.1845)
8. Creating space for serendipity. (sim=0.1808)
9. Optimism looks good on everyone. (sim=0.1806)
