In [None]:
# %%
# VCD-based Open-Ended Caption Generation Script
# This notebook loads a VCD-enabled LVLM, processes a list of image IDs,
# and generates open-ended captions using the optimal settings reported in the VCD paper.

import os
import json
from PIL import Image
from tqdm import tqdm
import torch

# %%
# Add paths to the VCD repository and submodules
import sys
repo_root = os.path.abspath(os.getcwd())

sys.path.insert(0, os.path.join(repo_root, "experiments"))
sys.path.insert(0, os.path.join(repo_root, "vcd_utils"))

# %%
# Import LVLM and VCD utilities
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from vcd_utils.vcd_add_noise import add_diffusion_noise
from vcd_utils.vcd_sample import evolve_vcd_sampling

# Enable VCD's contrastive sampling before any generation happens.
# NOTE(review): presumably this patches the generation/sampling loop so that
# model.generate(...) accepts the images_cd / cd_alpha / cd_beta arguments
# passed later in this notebook — confirm in vcd_utils.vcd_sample.
evolve_vcd_sampling()  # initialize VCD sampling functions

# %%
# Configuration
# Every setting a reader may want to tune for a run lives in this cell.

# Text file listing image file names (one per line), the directory the names
# are resolved against, and the JSONL file the captions are written to.
txt_file = "../../auto_cir/selected_images-160.txt"
img_dir = "../../img-set/val2014"
output_jsonl = "gen_captions.jsonl"

# Local snapshot of llava-v1.5-7b. Set the VCD_MODEL_PATH environment variable
# to point elsewhere instead of editing this absolute path.
model_path = os.environ.get(
    "VCD_MODEL_PATH",
    "/root/.cache/huggingface/hub/models--liuhaotian--llava-v1.5-7b/snapshots/4481d270cc22fd5c4d1bb5df129622006ccd9234",
)
model_base = None       # if using LoRA weights, set base model here
conv_mode = "llava_v1"  # conversational template for captioning
use_cd = True           # enable VCD (build a noised image for the contrastive branch)

# VCD hyperparameters (as per Appendix A)
noise_step = 500  # T for LLaVA-Bench-style open-ended generation
cd_alpha = 1.0    # α
cd_beta = 0.1     # β

# Sampling parameters (forwarded unchanged to model.generate)
temperature = 1.0
top_p = 1.0
top_k = None

# Generation samples (do_sample=True), so fix the RNG seed to make captions
# reproducible under Restart & Run All. torch.manual_seed seeds the RNG on all
# devices (CPU and CUDA).
seed = 42
torch.manual_seed(seed)

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"



In [None]:
# %%
# Load the pretrained LVLM with VCD capabilities.
disable_torch_init()

# The model name is given explicitly because model_path ends in a snapshot
# hash rather than a model name.
# NOTE(review): presumably load_pretrained_model uses this string to choose
# the model class — verify against llava.model.builder.
model_name = "llava-1.5-7b"

# No 8-bit / 4-bit quantization; weights are placed on the configured device.
load_options = dict(load_8bit=False, load_4bit=False, device=device)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, **load_options
)
model.eval()

# %%
# Read the image file names listed in txt_file, one per line; surrounding
# whitespace is stripped and blank lines are skipped. Each name is later
# joined with img_dir, and its stem must end in "_<integer id>".
with open(txt_file, "r", encoding="utf-8") as listing:
    stripped_lines = (raw.strip() for raw in listing)
    image_files = [name for name in stripped_lines if name]
print(f"Total images: {len(image_files)}")



In [None]:
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

# Generate one caption per image and write them to output_jsonl, one JSON
# object per line: {"image_id": <int>, "caption": <str>}.

max_new_tokens = 128  # upper bound on generated caption length (tokens)

# The prompt holds only the image placeholder (no text instruction), so it is
# identical for every image: build and tokenize it once, outside the loop.
if model.config.mm_use_im_start_end:
    img_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n"
else:
    img_token = DEFAULT_IMAGE_TOKEN + "\n"
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], img_token)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(device)
prompt_len = input_ids.shape[1]

# Template separator that can trail a reply. It is stripped after decoding in
# case skip_special_tokens does not already remove it (e.g. for templates
# whose separator is plain text rather than a special token).
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

with open(output_jsonl, "w", encoding="utf-8") as out_file:
    for image_file in tqdm(image_files, desc="Generating captions"):
        # The numeric image ID is the last "_"-separated field of the file stem.
        stem = os.path.splitext(image_file)[0]
        image_id = int(stem.split("_")[-1])

        # Load the image; the context manager closes the underlying file handle.
        with Image.open(os.path.join(img_dir, image_file)) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

        # Noised copy of the image for VCD's contrastive branch; None disables it.
        noisy = add_diffusion_noise(pixel_values, noise_step) if use_cd else None

        with torch.inference_mode():
            outputs = model.generate(
                input_ids,
                images=pixel_values.unsqueeze(0).half().to(device),
                images_cd=noisy.unsqueeze(0).half().to(device) if noisy is not None else None,
                cd_alpha=cd_alpha,
                cd_beta=cd_beta,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
            )

        # Keep only the tokens after the prompt. NOTE(review): this assumes the
        # generate output echoes the prompt ids first — confirm for this llava build.
        gen_ids = outputs[0, prompt_len:].cpu().tolist()
        gen = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        if stop_str and gen.endswith(stop_str):
            gen = gen[: -len(stop_str)].strip()

        out_file.write(
            json.dumps({"image_id": image_id, "caption": gen}, ensure_ascii=False) + "\n"
        )

print("Done! Saved to", output_jsonl)