In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
! pip install open_clip_torch matplotlib

Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)
Using cached pyparsing-3.2.5-py3-none-any.whl (113 kB)
Installing collected packages: pyparsing
Successfully installed pyparsing-3.2.5


In [2]:
import open_clip
import torch
import torch.nn.functional as F
from PIL import Image
import json, os
from pathlib import Path
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Device / helpers ---
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_enabled = (device == "cuda")

In [None]:
# --- Setup (paths) ---
ANN_PATH = "/nocap_val_4500_captions.json"
IMG_DIR = "/data/images"
CAPTION_JSON = "../data/caption.json"
CAPTION_CLIPSCORE_JSON = "../data/captions_with_clipscores.json"
CAPTION_HYBRID_JSON = "../data/captions_hybrid_scored.json"

In [10]:
def clean_caption(tokens):
    # open_clip.decode -> string with special tokens; strip them
    s = open_clip.decode(tokens).split("<end_of_text>")[0]
    s = s.replace("<start_of_text>", "").strip()
    # normalize whitespace
    return " ".join(s.split())

# --- Model variants you requested ---
VARIANTS = [
    ("coca_ViT-B-32", "laion2b_s13b_b90k"),
    ("coca_ViT-B-32", "mscoco_finetuned_laion2b_s13b_b90k"),
    ("coca_ViT-L-14", "laion2b_s13b_b90k"),
    ("coca_ViT-L-14", "mscoco_finetuned_laion2b_s13b_b90k"),
]

# --- Load all models & transforms once (to avoid reloading per image) ---
models = []
for model_name, ckpt in VARIANTS:
    model, _, transform = open_clip.create_model_and_transforms(
        model_name=model_name,
        pretrained=ckpt
    )
    model = model.to(device)
    model.eval()
    models.append((model_name, ckpt, model, transform))

print(f"Loaded {len(models)} CoCa variants on {device}.")

Loaded 4 CoCa variants on cuda.


In [11]:
# --- Read nocaps json ---
LIMIT = 10

with open(ANN_PATH, "r") as f:
    nocaps = json.load(f)

images = nocaps.get("images", [])
if LIMIT is not None:
    images = images[:LIMIT]

print(f"Will caption {len(images)} image(s).")

Will caption 10 image(s).


In [12]:
# --- Generate captions ---
results = []
missing_files = 0
failed = 0

for img_info in tqdm(images, desc="Captioning"):
    fname = img_info["file_name"]
    fpath = Path(IMG_DIR) / fname
    if not fpath.exists():
        missing_files += 1
        continue

    captions = []
    try:
        with Image.open(fpath).convert("RGB") as pil:
            for (model_name, ckpt, model, transform) in models:
                # transform and move to device
                pixel = transform(pil).unsqueeze(0).to(device)
                with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp_enabled):
                    tokens = model.generate(pixel)
                cap = clean_caption(tokens[0])
                captions.append(cap)
                print(cap)
    except Exception as e:
        failed += 1
        # keep going; skip this image
        continue

    results.append({
        "file_name": fname,
        "captions": captions
    })

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp_enabled):
Captioning: 100%|██████████| 10/10 [00:00<00:00, 76.24it/s]


In [None]:
# --- Save ---
with open(CAPTION_JSON, "w") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print(f"\nDone. Saved {len(results)} items to: {CAPTION_JSON}")
print(f"Missing files: {missing_files} | Failed during generation: {failed}")



Done. Saved 0 items to: ../data/caption.json
Missing files: 0 | Failed during generation: 10


: 

In [None]:
# Pick a CLIP backbone + weights (common options shown below)
CLIP_MODEL   = "ViT-B-32"
CLIP_CKPT    = "openai"  # e.g., "openai" or "laion2b_s34b_b79k"

In [None]:
# ===== Load CLIP (NOT CoCa) =====
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    model_name=CLIP_MODEL,
    pretrained=CLIP_CKPT
)
clip_model = clip_model.to(device)
clip_model.eval()

In [None]:
# ===== Helpers =====
def clipscore_image_text(pil_img, captions, batch_size=8):
    """
    Returns cosine similarities (list of floats) between one image and N captions.
    """
    # Encode image
    with torch.no_grad():
        img = clip_preprocess(pil_img).unsqueeze(0).to(device)
        image_features = clip_model.encode_image(img)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Encode text (batched)
    sims = []
    for i in range(0, len(captions), batch_size):
        batch_caps = captions[i:i+batch_size]
        with torch.no_grad():
            tok = open_clip.tokenize(batch_caps).to(device)
            text_features = clip_model.encode_text(tok)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # cosine similarity = dot product since both are L2-normalized
            sim = (image_features @ text_features.T).squeeze(0)  # shape [batch]
            sims.extend(sim.tolist())
    return sims  # each in roughly [-1, 1]


In [None]:
# ===== Read previous captions =====
with open(CAPTION_JSON, "r") as f:
    items = json.load(f)

results = []
missing, failed = 0, 0

for it in tqdm(items, desc="Scoring with CLIP"):
    fname = it["file_name"]
    fpath = Path(IMG_DIR) / fname
    if not fpath.exists():
        missing += 1
        continue

    caps = it.get("captions", [])
    if not caps:
        continue

    try:
        with Image.open(fpath).convert("RGB") as pil:
            scores = clipscore_image_text(pil, caps, batch_size=8)
    except Exception as e:
        failed += 1
        continue

    # Find best caption
    best_idx   = int(np.argmax(scores))
    best_cap   = caps[best_idx]
    best_score = float(scores[best_idx])

    results.append({
        "file_name": fname,
        "captions": caps,
        "clipscores": [float(s) for s in scores],  # cosine similarity per caption
        "best_index": best_idx,
        "best_caption": best_cap,
        "best_score": best_score
    })

In [None]:
# ===== Save =====
with open(CAPTION_CLIPSCORE_JSON, "w") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print(f"Saved {len(results)} items -> {CAPTION_CLIPSCORE_JSON}")
print(f"Missing images: {missing} | Failed during scoring: {failed}")

# (Optional) quick peek
if results:
    print("\nExample:")
    print(json.dumps(results[0], indent=2, ensure_ascii=False))

In [None]:
# ==== Hybrid scoring (CoCa log-likelihood + α * CLIP cosine) ====

# --- Imports (safe to re-run) ---
import json
from pathlib import Path
from contextlib import nullcontext

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import open_clip
from tqdm import tqdm
coca_ViT-L-14
# --- Setup (paths) ---
ANN_PATH = "/content/drive/MyDrive/data/nocap_val_4500_captions.json"
IMG_DIR = "/content/drive/MyDrive/data/selected_images"
CAPTION_JSON = "/content/drive/MyDrive/data/caption.json"
CAPTION_CLIPSCORE_JSON = "/content/drive/MyDrive/data/captions_with_clipscores.json"
CAPTION_HYBRID_JSON = "/content/drive/MyDrive/data/captions_hybrid_scored.json"

# --- Models/config ---
# CoCa used as the *scoring* LM
COCA_MODEL = "coca_ViT-L-14"
COCA_CKPT  = "mscoco_finetuned_laion2b_s13b_b90k"

# CLIP used for alignment score
CLIP_MODEL = "ViT-B-32"
CLIP_CKPT  = "openai"

ALPHA      = 15.0   # weight for CLIPScore in hybrid score
LEN_NORM   = True   # normalize CoCa log-prob by token count
BATCH_CAPS = 16     # batch size for CLIP & CoCa scoring

device  = "cuda" if torch.cuda.is_available() else "cpu"
USE_AMP = torch.cuda.is_available()

print(f"Device: {device} | AMP: {USE_AMP}")
print(f"open_clip version: {getattr(open_clip, '__version__', 'unknown')}")

# ===== Load models / preprocessors =====
# CoCa as LM scorer
coca_model, _, coca_preproc = open_clip.create_model_and_transforms(
    model_name=COCA_MODEL,
    pretrained=COCA_CKPT
)
coca_model = coca_model.to(device).eval()
# ✅ Use CoCa's tokenizer (do NOT use open_clip.tokenize here)
coca_tokenize = open_clip.get_tokenizer(COCA_MODEL)

# CLIP for alignment score
clip_model, _, clip_preproc = open_clip.create_model_and_transforms(
    model_name=CLIP_MODEL,
    pretrained=CLIP_CKPT
)
clip_model = clip_model.to(device).eval()

# ===== Helpers =====
def _extract_coca_lm_logits(forward_out: object) -> torch.Tensor:
    """
    Return LM logits as float tensor. Works across open_clip return variants.
    """
    if isinstance(forward_out, torch.Tensor):
        logits = forward_out
    elif isinstance(forward_out, (list, tuple)) and len(forward_out) > 0 and isinstance(forward_out[0], torch.Tensor):
        logits = forward_out[0]
    elif isinstance(forward_out, dict):
        for k in ("logits", "lm_logits", "text_logits", "decoder_logits"):
            if k in forward_out and isinstance(forward_out[k], torch.Tensor):
                logits = forward_out[k]
                break
        else:
            raise RuntimeError("Could not find CoCa LM logits in model output (dict).")
    else:
        raise RuntimeError(f"Unexpected CoCa forward output type: {type(forward_out)}")
    return logits.float()

def _to_B_T_V(logits: torch.Tensor, tok_len: int) -> torch.Tensor:
    """
    Ensure logits are [B, T, V]. Some builds return [B, V, T].
    Heuristic: if last dim equals tok_len, transpose.
    """
    if logits.dim() != 3:
        raise RuntimeError(f"Expected 3D logits, got shape {tuple(logits.shape)}")
    B, A, C = logits.shape
    if C == tok_len and A != tok_len:   # [B, V, T] -> [B, T, V]
        return logits.transpose(1, 2).contiguous()
    return logits  # assume already [B, T, V] or already aligned

def _sanitize_captions(caps):
    return [str(c).strip() for c in caps if c is not None and str(c).strip()]

def coca_logprob_for_many(pil_img, captions, batch_size=BATCH_CAPS):
    """
    Teacher-forced log-likelihoods for a list of captions given one image.
    Returns:
      sum_logprobs: [N] floats (sum over non-pad tokens)
      token_counts: [N] ints   (non-pad tokens)
    """
    captions = _sanitize_captions(captions)
    if not captions:
        return [], []

    with torch.no_grad():
        img = coca_preproc(pil_img).unsqueeze(0).to(device)

    sum_logprobs, token_counts = [], []
    amp_ctx = torch.amp.autocast("cuda", enabled=USE_AMP) if device == "cuda" else nullcontext()

    for i in range(0, len(captions), batch_size):
        caps = captions[i:i+batch_size]
        with torch.no_grad():
            tok = coca_tokenize(caps).to(device)  # ✅ CoCa tokenizer
            tok_len = tok.shape[1]

            with amp_ctx:
                out = coca_model(image=img, text=tok)

            logits = _extract_coca_lm_logits(out)          # [B, T, V] or [B, V, T]
            logits = _to_B_T_V(logits, tok_len)            # -> [B, T*, V]

            # Predict next token; align logits with target = tokens[:, 1:]
            target = tok[:, 1:]                            # [B, T-1]
            T_logits = logits.shape[1]
            if T_logits == tok_len:
                logits_aligned = logits[:, :-1, :]         # drop last step to match T-1
            elif T_logits == tok_len - 1:
                logits_aligned = logits                    # already shifted
            else:
                # Fallback: crop to min length
                min_T = min(T_logits, tok_len - 1)
                logits_aligned = logits[:, :min_T, :]
                target = target[:, :min_T]

            log_probs = F.log_softmax(logits_aligned, dim=-1)                         # [B, T-1, V]
            token_lp  = torch.gather(log_probs, 2, target.unsqueeze(-1)).squeeze(-1)  # [B, T-1]

            pad_id = 0  # CoCa pad token id in open_clip tokenizers
            mask = (target != pad_id)                                                 # [B, T-1]
            # Avoid NaNs if a row is all pads
            mask_row_sum = mask.sum(dim=1)
            safe_mask = torch.where(mask_row_sum[:, None] > 0, mask, torch.zeros_like(mask))

            sum_lp  = (token_lp * safe_mask).sum(dim=1)  # [B]
            tok_cnt = safe_mask.sum(dim=1)               # [B]

        sum_logprobs.extend(sum_lp.detach().cpu().float().tolist())
        token_counts.extend(tok_cnt.detach().cpu().long().tolist())

    return sum_logprobs, token_counts

def clip_scores_for_many(pil_img, captions, batch_size=BATCH_CAPS):
    """
    Cosine similarities (≈[-1,1]) between one image and N captions.
    """
    captions = _sanitize_captions(captions)
    if not captions:
        return []

    with torch.no_grad():
        img = clip_preproc(pil_img).unsqueeze(0).to(device)
        amp_ctx = torch.amp.autocast("cuda", enabled=USE_AMP) if device == "cuda" else nullcontext()
        with amp_ctx:
            img_feats = clip_model.encode_image(img)
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)

    sims = []
    for i in range(0, len(captions), batch_size):
        caps = captions[i:i+batch_size]
        with torch.no_grad():
            tok = open_clip.tokenize(caps).to(device)  # CLIP tokenizer is correct for CLIP
            with amp_ctx:
                txt_feats = clip_model.encode_text(tok)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
            sim = (img_feats @ txt_feats.T).squeeze(0)  # [batch]
            sims.extend(sim.detach().cpu().float().tolist())
    return sims

# ===== Load captions =====
src_json = CAPTION_CLIPSCORE_JSON if Path(CAPTION_CLIPSCORE_JSON).exists() else CAPTION_JSON
print(f"Reading captions from: {src_json}")
with open(src_json, "r") as f:
    items = json.load(f)

results = []
missing, failed = 0, 0

for it in tqdm(items, desc="Hybrid scoring"):
    fname = it.get("file_name")
    if not fname:
        continue
    fpath = Path(IMG_DIR) / fname
    if not fpath.exists():
        missing += 1
        continue

    caps = it.get("captions", [])
    caps = _sanitize_captions(caps)
    if not caps:
        continue

    try:
        with Image.open(fpath).convert("RGB") as pil:
            # 1) CLIP cosine scores (recompute here to be consistent with chosen CLIP backbone)
            clip_s = clip_scores_for_many(pil, caps)

            # 2) CoCa log-likelihoods (teacher-forced)
            coca_sum, coca_len = coca_logprob_for_many(pil, caps)

            # 3) Length-normalized LL if enabled
            if LEN_NORM:
                coca_ll = []
                for s, L in zip(coca_sum, coca_len):
                    L = int(L)
                    coca_ll.append(s / L if L > 0 else float("-inf"))
            else:
                coca_ll = coca_sum

            # 4) Hybrid score
            hybrid = [ ll + ALPHA * cs for ll, cs in zip(coca_ll, clip_s) ]

            # 5) Pick best (handle all -inf case)
            if all([not np.isfinite(h) for h in hybrid]):
                best_idx = 0
            else:
                best_idx = int(np.nanargmax(hybrid))

            results.append({
                "file_name": fname,
                "captions": caps,
                "clipscores": [float(x) for x in clip_s],
                "coca_logprobs_sum": [float(x) for x in coca_sum],
                "coca_token_counts": [int(x) for x in coca_len],
                "coca_logprobs_norm": [float(x) for x in coca_ll],  # avg log prob per token if LEN_NORM
                "alpha": float(ALPHA),
                "hybrid_scores": [float(x) for x in hybrid],
                "best_index": best_idx,
                "best_caption": caps[best_idx],
                "best_score": float(hybrid[best_idx]),
                "len_normalized": bool(LEN_NORM),
                "coca_model": f"{COCA_MODEL}::{COCA_CKPT}",
                "clip_model": f"{CLIP_MODEL}::{CLIP_CKPT}"
            })

    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        failed += 1
        print(f"[OOM] Skipped: {fname}")
        continue
    except Exception as e:
        failed += 1
        print(f"[Error] {fname}: {e}")
        continue

with open(CAPTION_HYBRID_JSON, "w") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print(f"Saved {len(results)} items -> {CAPTION_HYBRID_JSON}")
print(f"Missing images: {missing} | Failed: {failed}")

# Quick peek
if results:
    from pprint import pprint
    print("\nExample:")
    pprint(results[0])


In [None]:
print(results)