
# NoCaps Validation — CoCa Custom Sampling + **Hybrid Reranking** (OpenCLIP)

Pipeline:
1. Load **NoCaps validation** annotations + images.  
2. Load **OpenCLIP CoCa** (caption generation).  
3. **Custom sampling** loop (temperature, top-k, top-p, no-repeat-ngram) for *N* diverse candidates.  
4. Load **CLIP ViT-B/32** (scoring).  
5. **Hybrid reranking**:  
   \[ Score(c) = \log P_{\text{CoCa}}(c\mid I) + \alpha \cdot \text{CLIPScore}(I,c) \]  
6. Evaluate with BLEU, METEOR, ROUGE_L, CIDEr (SPICE skipped).  
7. Show qualitative examples.  


In [None]:

%pip install --upgrade pip
%pip install open_clip_torch pillow tqdm torchvision pycocotools
%pip install git+https://github.com/salaniz/pycocoevalcap


In [1]:

import os, json, random
from pathlib import Path
from collections import defaultdict

import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

import open_clip

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

# === paths ===
ANN_PATH = "data/nocap_val_4500_captions.json"
IMG_DIR  = "data/validation"

# hyperparams
N_CANDIDATES = 5
SEQ_LEN = 28
TEMP = 1.1
TOP_K = 50
TOP_P = 0.9
NO_REPEAT_N = 3

ALPHA = 2.0
LEN_NORM = True

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [2]:

with open(ANN_PATH, "r") as f:
    ann = json.load(f)

id2file = {img["id"]: img["file_name"] for img in ann["images"]}
caps_by_id = defaultdict(list)
for a in ann["annotations"]:
    caps_by_id[a["image_id"]].append(a["caption"])

print("Images:", len(ann["images"]))
print("Example refs:", caps_by_id[ann["images"][0]["id"]][:3])


Images: 4500
Example refs: ['A baby is standing in front of a house.', 'A little girl in a white jacket and sandals.', 'A young child stands in front of a house.']


In [3]:

coca_model, _, coca_preprocess = open_clip.create_model_and_transforms("coca_ViT-L-14", pretrained="mscoco_finetuned_laion2b_s13b_b90k")
coca_model = coca_model.to(device).eval()
coca_tokenizer = open_clip.get_tokenizer("coca_ViT-L-14")

clip_model, _, clip_preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_tokenizer = open_clip.get_tokenizer("ViT-B-32")
clip_model = clip_model.to(device).eval()

from open_clip import tokenizer as openclip_tok_mod




In [4]:

def _detect_special_ids(tok):
    bos_id, eos_id, pad_id = 49406, 49407, 0
    return bos_id, eos_id, pad_id

BOS_ID, EOS_ID, PAD_ID = _detect_special_ids(coca_tokenizer)
print("Special IDs:", BOS_ID, EOS_ID, PAD_ID)


Special IDs: 49406 49407 0


In [5]:

def top_k_top_p_filtering(logits, top_k=None, top_p=None):
    filtered = logits.clone()
    if top_k and top_k > 0:
        kth_vals, _ = torch.topk(filtered, top_k)
        thresh = kth_vals[..., -1].unsqueeze(-1)
        filtered[filtered < thresh] = -float("inf")
    if top_p and 0.0 < top_p < 1.0:
        sorted_vals, sorted_idx = torch.sort(filtered, descending=True)
        probs = torch.softmax(sorted_vals, dim=-1)
        cumprobs = torch.cumsum(probs, dim=-1)
        mask = cumprobs > top_p
        mask[..., 0] = False
        sorted_vals[mask] = -float("inf")
        unsort = torch.empty_like(sorted_idx)
        unsort.scatter_(0, sorted_idx, torch.arange(sorted_idx.numel(), device=sorted_idx.device))
        filtered = sorted_vals[unsort]
    return filtered

def violates_no_repeat(ids, next_id, n):
    if not n or n <= 0 or len(ids) < n-1:
        return False
    ngram = ids[-(n-1):] + [next_id]
    for i in range(len(ids) - (n-1)):
        if ids[i:i+n] == ngram:
            return True
    return False


In [6]:

@torch.no_grad()
def coca_sample_once(pil_img, max_len=30, temperature=1.0, top_k=50, top_p=0.9, no_repeat_ngram_size=3):
    img = coca_preprocess(pil_img).unsqueeze(0).to(device)
    seq = torch.tensor([[BOS_ID]], device=device)
    for _ in range(max_len):
        logits = coca_model(image=img, text=seq)
        next_logits = logits[:, -1, :].squeeze(0)
        next_logits = next_logits / temperature
        next_logits = top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
        ids_list = seq.squeeze(0).tolist()
        if no_repeat_ngram_size:
            for tok_id in range(next_logits.numel()):
                if violates_no_repeat(ids_list, tok_id, no_repeat_ngram_size):
                    next_logits[tok_id] = -float("inf")
        probs = torch.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, 1)
        seq = torch.cat([seq, next_id.view(1,1)], dim=1)
        if int(next_id.item()) == EOS_ID:
            break
    ids = seq.squeeze(0).tolist()
    if ids and ids[0] == BOS_ID: ids = ids[1:]
    if EOS_ID in ids: ids = ids[:ids.index(EOS_ID)]
    return openclip_tok_mod.decode(torch.tensor(ids))


In [7]:

def generate_n_candidates(pil_img, N=5):
    caps = [coca_sample_once(pil_img, max_len=SEQ_LEN, temperature=TEMP, top_k=TOP_K, top_p=TOP_P, no_repeat_ngram_size=NO_REPEAT_N) for _ in range(N)]
    uniq = list(dict.fromkeys(caps))
    return uniq


In [8]:

@torch.no_grad()
def coca_loglik(pil_img, caption):
    img = coca_preprocess(pil_img).unsqueeze(0).to(device)
    toks = coca_tokenizer([caption]).to(device)
    input_ids, target_ids = toks[:, :-1], toks[:, 1:]
    logits = coca_model(image=img, text=input_ids)
    logp = F.log_softmax(logits, dim=-1)
    token_ll = logp.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    mask = (target_ids != 0)
    ll_sum = (token_ll * mask).sum().item()
    tokens_kept = int(mask.sum().item())
    return ll_sum / tokens_kept if LEN_NORM else ll_sum

@torch.no_grad()
def hybrid_rerank(pil_img, candidates, alpha=2.0):
    img = clip_preprocess(pil_img).unsqueeze(0).to(device)
    img_feat = clip_model.encode_image(img); img_feat /= img_feat.norm(dim=-1, keepdim=True)
    toks = clip_tokenizer(candidates).to(device)
    txt_feat = clip_model.encode_text(toks); txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
    clip_sims = (img_feat @ txt_feat.T).squeeze(0).cpu().tolist()
    rows = []
    for c, cs in zip(candidates, clip_sims):
        ll = coca_loglik(pil_img, c)
        score = ll + alpha * cs
        rows.append((c, ll, cs, score))
    rows.sort(key=lambda x: x[-1], reverse=True)
    return rows[0][0], rows


In [9]:

test_path = Path(IMG_DIR) / id2file[ann["images"][0]["id"]]
pil = Image.open(test_path).convert("RGB")
cands = generate_n_candidates(pil, N=5)
best, ranked = hybrid_rerank(pil, cands, alpha=ALPHA)
print("Candidates:", cands)
print("Best:", best)


KeyError: (slice(None, None, None), -1, slice(None, None, None))

In [None]:

preds = []
for img_info in tqdm(ann["images"][:50], desc="Demo subset"):
    fpath = Path(IMG_DIR) / img_info["file_name"]
    if not fpath.exists(): continue
    pil = Image.open(fpath).convert("RGB")
    cands = generate_n_candidates(pil, N=N_CANDIDATES)
    best, _ = hybrid_rerank(pil, cands, alpha=ALPHA)
    preds.append({"image_id": img_info["id"], "caption": best})

OUT_JSON = "preds_demo.json"
with open(OUT_JSON, "w") as f:
    json.dump(preds, f)

coco = COCO(ANN_PATH)
cocoRes = coco.loadRes(OUT_JSON)
evaluator = COCOEvalCap(coco, cocoRes)
evaluator.scorers = [(Bleu(4), ["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]), (Meteor(),"METEOR"), (Rouge(),"ROUGE_L"), (Cider(),"CIDEr")]
evaluator.evaluate()
