In [1]:
import json
import os
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm.notebook import tqdm

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.models import load_preprocess

# ========= Configuration ==========
# Eval-config YAML for each supported model. Every entry follows the
# "eval_configs/<model>_eval.yaml" naming scheme, so build them from the keys.
MODEL_EVAL_CONFIG_PATH = {
    model_key: f"eval_configs/{model_key}_eval.yaml"
    for model_key in ("minigpt4", "instructblip", "lrv_instruct", "shikra", "llava-1.5")
}

# Prompt template per model: "<ImageHere>" marks where the image goes and
# "<question>" is replaced by the actual question text in the inference cell.
INSTRUCTION_TEMPLATE = {
    "minigpt4": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",      # Vicuna-style chat
    "instructblip": "<ImageHere><question>",                                      # bare image + question
    "lrv_instruct": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",  # same as minigpt4
    "shikra": "USER: <im_start><ImageHere><im_end> <question> ASSISTANT:",
    "llava-1.5": "USER: <ImageHere> <question> ASSISTANT:",
}

def get_image_id(fname):
    """Parse the integer image id out of a COCO-style file name.

    The id is whatever follows the last underscore once the extension is
    removed, e.g. ``COCO_val2014_000000382584.jpg`` -> ``382584``. A name
    without underscores is parsed whole. Raises ``ValueError`` when that text
    is not an integer.
    """
    stem, _ext = os.path.splitext(fname)
    id_text = stem.rpartition('_')[2]
    return int(id_text)

# ====== Run settings (edit directly) =======
model_name = "llava-1.5"                              # key into MODEL_EVAL_CONFIG_PATH / INSTRUCTION_TEMPLATE
gpu_id = "0"                                          # exported as CUDA_VISIBLE_DEVICES below
txt_file = "../../auto_cir/selected_images-160.txt"   # image file names, one per line
img_dir = "../../img-set/val2014"                     # directory the names in txt_file are joined onto
output_jsonl = "gen_captions.jsonl"                   # one {"image_id", "caption"} JSON object per line
use_opera = False                                     # whether to use OPERA decoding

# Restrict this process to the selected GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": gpu_id})

# ========== Load model ============
cfg_path = MODEL_EVAL_CONFIG_PATH[model_name]


class Args:
    """Minimal stand-in for the argparse-style namespace that ``Config`` reads."""


args = Args()
args.model = model_name
args.cfg_path = cfg_path
args.options = None  # no command-line overrides of the YAML config
args.run_cfg = type('', (), {})()
args.run_cfg.seed = 42
cfg = Config(args)

# Setting ``args.run_cfg.seed`` alone does not seed any RNG in this notebook,
# so apply it explicitly (offset by the process rank from ``get_rank``) before
# the model is built. Beam search with do_sample=False is deterministic anyway;
# this keeps any stochastic path (e.g. enabling sampling later) reproducible.
seed = args.run_cfg.seed + get_rank()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = cfg.model_cfg
# CUDA_VISIBLE_DEVICES (set from ``gpu_id`` in the settings cell) exposes only
# that one GPU, which this process therefore sees as CUDA index 0. Pass the int
# index 0 rather than the physical-id string.
# NOTE(review): assumes device_8bit is a device index used for 8-bit loading —
# confirm against the model class's from_config.
model_config.device_8bit = 0
model_cls = registry.get_model_class(model_config.arch)

model = model_cls.from_config(model_config)
# The inference cell sends image tensors to ``device``, so the weights must be
# there too. Skip this in low-resource (8-bit) mode, where placement is
# presumably handled by the loader and .to() may be rejected — TODO confirm.
if not getattr(model_config, "low_resource", False):
    model = model.to(device)
model.eval()

# Build the processors with normalization disabled; the same normalization is
# applied manually with ``norm`` in the inference cell.
processor_cfg = cfg.get_config().preprocess
processor_cfg.vis_processor.eval.do_normalize = False
vis_processors, txt_processors = load_preprocess(processor_cfg)

# ----- Normalization (consistent with the official example) -----
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

norm = transforms.Normalize(mean, std)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_mo

In [3]:
# ========== Read image file names ============
with open(txt_file, encoding="utf-8") as name_file:
    # One file name per line; blank lines are ignored.
    img_names = [line.strip() for line in name_file if line.strip()]

template = INSTRUCTION_TEMPLATE[model_name]
question = "Please describe this image in detail."

# The prompt and decoding settings do not depend on the image, so build them
# once instead of on every loop iteration.
prompt = template.replace("<question>", question)
generate_kwargs = dict(
    use_nucleus_sampling=False,
    num_beams=5,
    max_new_tokens=128,
    do_sample=False,
    penalty_weights=1,
)
if use_opera:
    # Extra options for OPERA decoding.
    generate_kwargs.update({
        "output_attentions": True,
        "opera_decoding": True,
        "scale_factor": 50,
        "threshold": 15.0,
        "num_attn_candidates": 5,
    })


def _extract_caption(output):
    """Return the first generated caption from ``model.generate`` output as a stripped str.

    Accepts a dict with a ``"text"`` entry, a list/tuple of generations, or any
    other object (converted with ``str``). A dict's ``"text"`` may itself be a
    list. An empty list/tuple yields ``""``.
    """
    if isinstance(output, dict) and "text" in output:
        output = output["text"]
    if isinstance(output, (list, tuple)):
        output = output[0] if output else ""
    return str(output).strip()


# ========== Run inference and write JSONL ==========
with open(output_jsonl, "w", encoding="utf-8") as fout:
    for img_name in tqdm(img_names, desc="Generating captions"):
        image_path = os.path.join(img_dir, img_name)
        try:
            with Image.open(image_path) as img:
                raw_image = img.convert("RGB")
        except Exception as e:
            # Best effort: report unreadable images and continue with the rest.
            print(f"Error loading {image_path}: {e}")
            continue
        image = vis_processors["eval"](raw_image).unsqueeze(0)
        image = norm(image)  # aligned with the official example (processors built with do_normalize=False)
        image = image.to(device)

        # inference_mode already disables autograd; wrapping no_grad too is redundant.
        with torch.inference_mode():
            output = model.generate(
                {"image": image, "prompt": prompt},
                **generate_kwargs
            )

        entry = {
            "image_id": get_image_id(img_name),
            "caption": _extract_caption(output),
        }
        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
        # Flush per line so finished captions can be inspected during the run
        # and survive a killed or crashed kernel.
        fout.flush()


Generating captions:   0%|          | 0/160 [00:00<?, ?it/s]