# In [None]:
import os
import torch
import json
from tqdm import tqdm
from attention import llama_modify
from constants import INSTRUCTION_TEMPLATE, SYSTEM_MESSAGE
from llava.utils import disable_torch_init
from model_loader import ModelLoader
from transformers.generation.logits_process import LogitsProcessorList
from PIL import Image

class FileListCOCODataSet(torch.utils.data.Dataset):
    """Dataset over images named, one per line, in a text file.

    Each item is a dict ``{"img_id": int, "image": image}``. The integer id is
    parsed from the last ``_``-separated field of the file-name stem, e.g.
    ``COCO_val2014_000000000042.jpg`` -> ``42``.

    Args:
        file_list: Path to a UTF-8 text file with one image file name per
            line; surrounding whitespace is stripped and blank lines skipped.
        img_dir: Directory the listed file names are joined onto.
        trans: Optional callable applied to the RGB ``PIL.Image``; ``None``
            returns the image untransformed.
    """

    def __init__(self, file_list, img_dir, trans):
        with open(file_list, "r", encoding="utf-8") as f:
            self.img_names = [x.strip() for x in f if x.strip()]
        self.img_dir = img_dir
        self.trans = trans

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        # int() raises ValueError if the last "_" field of the stem is not an
        # integer literal.
        img_id = int(os.path.splitext(img_name)[0].split("_")[-1])
        img_path = os.path.join(self.img_dir, img_name)
        # Context manager releases the file handle deterministically instead
        # of leaving it to garbage collection; convert() returns an independent
        # copy, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.trans is not None:
            image = self.trans(image)
        return {"img_id": img_id, "image": image}


# ========== Run configuration (edit these values directly) ==========
model_name = "llava-1.5"      # passed to ModelLoader and used as the INSTRUCTION_TEMPLATE key
txt_file = "../../auto_cir/selected_images-160.txt"  # image file names, one per line
img_dir = "../../img-set/val2014"  # directory the listed file names are joined onto
output_jsonl = "pai_captions.jsonl"  # captions are appended here as JSON lines
gpu_id = "0"                  # value written to CUDA_VISIBLE_DEVICES
use_cfg = False               # if True, add the logits processor from model_loader.init_cfg_processor
use_attn = True               # forwarded to llama_modify

alpha = 0.5                   # forwarded to llama_modify
gamma = 1.2                   # forwarded to init_cfg_processor (only used when use_cfg)
beam = 5                      # num_beams for generate(); also forwarded to init_cfg_processor
sample = False                # do_sample for generate()
batch_size = 1                # DataLoader batch size
max_tokens = 128              # max_new_tokens for generate()
start_layer = 2               # layer range forwarded to llama_modify / init_cfg_processor
end_layer = 32



import warnings

# CUDA_VISIBLE_DEVICES only takes effect if CUDA has not yet been initialized
# in this process. In a notebook, an earlier cell (or one of the imports above)
# may already have initialized it, in which case the assignment below is
# silently ignored -- surface that instead of running on an unexpected GPU.
if torch.cuda.is_initialized():
    warnings.warn(
        f"CUDA is already initialized; setting CUDA_VISIBLE_DEVICES={gpu_id!r} "
        "has no effect in this process. Restart the kernel to change GPUs.",
        RuntimeWarning,
    )
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
disable_torch_init()

# ========== Dataset over an explicit image list ==========
# NOTE(review): a class with this same name is also defined earlier in this
# file. This later definition rebinds the name and is the one in effect for the
# rest of the file; keep only one of the two definitions.
class FileListCOCODataSet(torch.utils.data.Dataset):
    """Dataset over images named, one per line, in a text file.

    Each item is a dict ``{"img_id": int, "image": image}``. The integer id is
    parsed from the last ``_``-separated field of the file-name stem, e.g.
    ``COCO_val2014_000000000042.jpg`` -> ``42``.

    Args:
        file_list: Path to a UTF-8 text file with one image file name per
            line; surrounding whitespace is stripped and blank lines skipped.
        img_dir: Directory the listed file names are joined onto.
        trans: Optional callable applied to the RGB ``PIL.Image``; ``None``
            returns the image untransformed.
    """

    def __init__(self, file_list, img_dir, trans):
        with open(file_list, "r", encoding="utf-8") as f:
            self.img_names = [x.strip() for x in f if x.strip()]
        self.img_dir = img_dir
        self.trans = trans

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        # int() raises ValueError if the last "_" field of the stem is not an
        # integer literal.
        img_id = int(os.path.splitext(img_name)[0].split("_")[-1])
        img_path = os.path.join(self.img_dir, img_name)
        # `Image` comes from the module-level PIL import. The context manager
        # releases the file handle deterministically; convert() returns an
        # independent copy that stays usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.trans is not None:
            image = self.trans(image)
        return {"img_id": img_id, "image": image}

# ========== Load the model and build the image pipeline ==========
model_loader = ModelLoader(model_name)
# The model's own image processor is used as the per-image transform.
dataset = FileListCOCODataSet(
    txt_file,
    img_dir,
    model_loader.image_processor,
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,  # iterate in file-list order
    num_workers=4,
)

# ========== Prompt template ==========
# llava-1.5 and shikra templates get the shared system message prepended.
template = (
    SYSTEM_MESSAGE + INSTRUCTION_TEMPLATE[model_name]
    if model_name in ("llava-1.5", "shikra")
    else INSTRUCTION_TEMPLATE[model_name]
)
question = "Please help me describe the image in detail."

# ========== Main inference loop ==========
# The output file is opened once for the whole run (rather than once per batch)
# and flushed after every batch so finished captions reach disk even if a later
# batch fails.
# NOTE(review): append mode and no skip-if-already-done logic means re-running
# this cell writes a second copy of every record; delete `output_jsonl` first
# if a fresh run is intended.
with open(output_jsonl, "a", encoding="utf-8") as fout:
    for batch in tqdm(loader, desc="Generating captions"):
        img_ids = batch["img_id"]
        images = batch["image"]

        questions, kwargs = model_loader.prepare_inputs_for_model(
            template, [question] * len(img_ids), images
        )
        # Applied per batch and after prepare_inputs_for_model because it reads
        # model_loader.img_start_idx / img_end_idx, which presumably are set
        # while preparing the inputs -- TODO confirm before hoisting out of the loop.
        llama_modify(
            model_loader.llm_model,
            start_layer,
            end_layer,
            use_attn,
            alpha,
            use_cfg,
            model_loader.img_start_idx,
            model_loader.img_end_idx,
        )

        logits_processor = (
            model_loader.init_cfg_processor(questions, gamma, beam, start_layer, end_layer)
            if use_cfg else None
        )
        if logits_processor is not None:
            kwargs["logits_processor"] = LogitsProcessorList([logits_processor])

        with torch.inference_mode():
            outputs = model_loader.llm_model.generate(
                do_sample=sample,
                max_new_tokens=max_tokens,
                use_cache=True,
                num_beams=beam,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                **kwargs,
            )

        output_text = model_loader.decode(outputs)

        # One JSON object per line; int() also unwraps 0-d tensors produced by
        # the DataLoader's default collation of the integer ids.
        for img_id, caption in zip(img_ids, output_text):
            fout.write(json.dumps({"image_id": int(img_id), "caption": caption.strip()}, ensure_ascii=False) + "\n")
        fout.flush()
