In [1]:
import torch, torch.nn.functional as F
import open_clip
from torchvision import transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- CoCa bits (replace with your integration) ----
# Assume you have:
#   coca.generate(image) -> List[ { "tokens": LongTensor, "logprob": float } ]  # K candidates
#   coca.tokenizer.decode(ids) -> str
#   coca.preprocess(Image) -> Tensor
#     NOTE(review): generate_candidates_with_scores calls .unsqueeze(0) on this result,
#     which implies a single (C, H, W) image rather than (B, C, H, W) — confirm.

# ---- OpenCLIP for CLIPScore ----
CLIP_ARCH = "ViT-B-32"
CLIP_PRETRAINED = "openai"

# create_model_and_transforms returns (model, train_transform, eval_transform).
# It does NOT return a tokenizer. The train transform applies random augmentation,
# so the deterministic eval transform is the one to use when scoring captions.
clip_model, _clip_train_preprocess, clip_preprocess = open_clip.create_model_and_transforms(
    CLIP_ARCH, pretrained=CLIP_PRETRAINED
)
# The text tokenizer matching the architecture has to be fetched separately.
clip_tokenizer = open_clip.get_tokenizer(CLIP_ARCH)
clip_model = clip_model.to(device).eval()

@torch.no_grad()
def clip_image_embed(pil_img):
    """Encode a single PIL image with CLIP and return its L2-normalised embedding.

    The image goes through `clip_preprocess`, gains a leading batch axis, and is
    moved to `device` before `clip_model.encode_image`. The batch axis is dropped
    again, so the result is a 1-D tensor of shape (D,).
    """
    batch = clip_preprocess(pil_img)[None].to(device)
    feats = clip_model.encode_image(batch)
    unit = feats / feats.norm(dim=-1, keepdim=True)
    return unit[0]  # (D,)

@torch.no_grad()
def clip_text_embed(texts):
    """Tokenise `texts` for CLIP and return one L2-normalised embedding per text.

    Returns a (B, D) tensor on `device`, where B is the number of input texts.
    """
    tokens = clip_tokenizer(texts).to(device)
    feats = clip_model.encode_text(tokens)
    norms = feats.norm(dim=-1, keepdim=True)
    return feats / norms

def clipscore(img_feat, captions):
    """Cosine similarity between one image embedding and each caption.

    img_feat: unit-norm 1-D image embedding (D,), as produced by clip_image_embed.
    captions: list of K caption strings.

    Returns a (K,) tensor of raw cosine similarities in [-1, 1]. Note that this is
    not the rescaled/clipped CLIPScore metric. Only the relative ordering matters
    here, because rerank() z-normalises these values.
    """
    text_feats = clip_text_embed(captions)  # (K, D), unit-norm rows
    column = img_feat.unsqueeze(1)          # (D, 1)
    return torch.matmul(text_feats, column).squeeze(1)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def length_normalize(logprob, length, length_penalty=0.7):
    """Divide a sequence log-probability by the GNMT-style length penalty.

    lp(L) = ((5 + L) ** alpha) / ((5 + 1) ** alpha), where alpha = length_penalty.

    alpha = 0 makes lp == 1, leaving `logprob` unchanged. For alpha > 0 and L > 1,
    lp > 1. Dividing a negative log-prob by it moves the score toward 0, which
    offsets the natural bias of raw log-probs toward short sequences.
    """
    numerator = (5 + length) ** length_penalty
    denominator = (5 + 1) ** length_penalty
    return logprob / (numerator / denominator)

In [3]:
@torch.no_grad()
def generate_candidates_with_scores(coca, pil_img, K=10, beam_size=5, top_p=None, temperature=1.0):
    """Ask the CoCa wrapper for K caption candidates for one image.

    This preprocesses the image with `coca.preprocess`, adds a batch axis, and
    moves it to `device`. It then forwards the decoding options to `coca.generate`,
    requesting per-sequence log-probabilities and EOS-inclusive token lists.

    Returns exactly what `coca.generate` returns. That is expected to be a list of
    {"tokens": LongTensor[L], "logprob": float} dicts (the contract of the user's
    wrapper, not enforced here).
    """
    batch = coca.preprocess(pil_img).unsqueeze(0).to(device)
    gen_kwargs = dict(
        num_candidates=K,
        beam_size=beam_size,
        top_p=top_p,
        temperature=temperature,
        return_logprobs=True,  # rerank() needs a per-sequence log-probability
        include_eos=True,      # NOTE(review): EOS then counts toward len(tokens) in rerank()
    )
    return coca.generate(images=batch, **gen_kwargs)

In [4]:
import numpy as np

def zscore(arr):
    """Standardise `arr` to zero mean and (approximately) unit std.

    Uses the population std (NumPy's default ddof=0). 1e-6 is added to the std,
    so a constant input maps to all zeros instead of dividing by zero.
    """
    values = np.asarray(arr)
    centred = values - values.mean()
    return centred / (values.std() + 1e-6)

def rerank(cands, img_feat, alpha=1.0, len_pen=0.7, tokenizer=None):
    """Pick the best caption among CoCa candidates using a hybrid score.

    hybrid_i = z(ll_i) + alpha * z(clip_i)

    ll_i is candidate i's log-probability after length_normalize, and clip_i is its
    CLIP cosine similarity to the image. z() standardises each signal within this
    image's candidate set, so the two terms are on a comparable scale before mixing.

    Args:
        cands: non-empty list of {"tokens": ..., "logprob": ...} dicts, as returned
            by generate_candidates_with_scores.
        img_feat: CLIP image embedding passed straight to clipscore().
        alpha: weight of the CLIP term relative to the likelihood term.
        len_pen: length-penalty exponent passed to length_normalize().
        tokenizer: object with a ``decode(ids) -> str`` method. When omitted, this
            falls back to ``coca.tokenizer`` from the notebook-global ``coca`` (the
            original behaviour). Pass it explicitly to avoid depending on hidden
            notebook state.

    Returns:
        (best_caption, debug) where debug holds "chosen_idx" and the per-candidate
        "captions", "lls" (length-normalised log-probs), "clips" and "hybrid" scores.

    Raises:
        ValueError: if ``cands`` is empty.
    """
    if not cands:
        raise ValueError("rerank() needs at least one candidate caption")
    if tokenizer is None:
        # Legacy fallback: only touch the notebook-global `coca` when no tokenizer is given.
        tokenizer = coca.tokenizer

    captions = []
    lls = []
    for d in cands:
        # NOTE(review): with include_eos=True the token list may contain special tokens;
        # whether decode() strips them (and whether they should count toward the
        # length) depends on the CoCa wrapper — confirm.
        captions.append(tokenizer.decode(d["tokens"]))
        # float() accepts both plain floats and 0-d tensors for the log-probability.
        lls.append(length_normalize(float(d["logprob"]), len(d["tokens"]), len_pen))
    lls = np.array(lls, dtype=np.float32)

    clips = clipscore(img_feat, captions).detach().float().cpu().numpy()

    # Per-image z-normalization
    lls_z = zscore(lls)
    clips_z = zscore(clips)

    hybrid = lls_z + alpha * clips_z
    idx = int(np.argmax(hybrid))  # ties resolve to the earliest candidate
    return captions[idx], {
        "chosen_idx": idx,
        "captions": captions,
        "lls": lls.tolist(),
        "clips": clips.tolist(),
        "hybrid": hybrid.tolist()
    }


In [5]:
from pathlib import Path
import json

def run_inference(dataset, coca, alpha=0.8, K=10, beam_size=5, out_jsonl="preds.jsonl", len_pen=0.7):
    """Caption every image in `dataset` and stream the results to a JSONL file.

    For each (img_id, img_path) yielded by ``dataset.iter_images()``:
      1. load the image as RGB,
      2. embed it once with CLIP,
      3. generate K candidates with CoCa,
      4. pick one with rerank() (hybrid likelihood + CLIP score).

    Each record {"image_id", "caption", "dbg"} is written as one UTF-8 JSON line to
    `out_jsonl` as soon as it is produced, and is also collected in the returned list.

    Args:
        dataset: object exposing ``iter_images()`` that yields (img_id, img_path) pairs.
        coca: CoCa wrapper passed to generate_candidates_with_scores.
        alpha: CLIP-term weight forwarded to rerank().
        K, beam_size: decoding options forwarded to the generator.
        out_jsonl: output path for the per-image JSON lines.
        len_pen: length-penalty exponent forwarded to rerank(); the default keeps
            the previously hard-coded 0.7.

    Returns:
        List of the written records, in dataset order.
    """
    out = []
    with open(out_jsonl, "w", encoding="utf-8") as f:
        for img_id, img_path in dataset.iter_images():
            # Image.open is lazy; the context manager guarantees the source file is
            # closed once convert() has decoded the pixels (Pillow only auto-closes
            # after load() for single-frame images).
            with Image.open(img_path) as raw:
                pil = raw.convert("RGB")

            # cache CLIP image embedding
            img_feat = clip_image_embed(pil)

            cands = generate_candidates_with_scores(coca, pil, K=K, beam_size=beam_size)
            # NOTE(review): rerank() decodes tokens via the notebook-global `coca`, not
            # the `coca` argument of this function — make sure both are the same model.
            best_caption, dbg = rerank(cands, img_feat, alpha=alpha, len_pen=len_pen)

            out_rec = {
                "image_id": img_id,
                "caption": best_caption,
                "dbg": dbg
            }
            f.write(json.dumps(out_rec, ensure_ascii=False) + "\n")
            out.append(out_rec)
    return out


In [6]:
def to_coco_dump(preds, out_path):
    """Write predictions as a COCO-style results JSON array.

    Only "image_id" and "caption" are kept from each prediction, in that key
    order. Extra keys such as "dbg" are dropped. The file is written with
    json.dump's defaults (ASCII-escaped, no indentation).
    """
    records = []
    for pred in preds:
        records.append({"image_id": pred["image_id"], "caption": pred["caption"]})
    with open(out_path, "w") as fh:
        json.dump(records, fh)

# Then call the nocaps official evaluator on your json vs references.
