# In [None]:
import torch
from datasets import Dataset
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    AutoTokenizer
)
import json
from sklearn.metrics import classification_report, confusion_matrix
from bert_score import score
import evaluate


def process_func(example):
    """Convert one conversation record into tensors for supervised fine-tuning.

    The first conversation turn must embed the image path between
    ``<|vision_start|>`` and ``<|vision_end|>``; the second turn is the target
    answer. The prompt (image plus a fixed question) is encoded with the
    module-level ``processor``, the answer with the module-level
    ``tokenizer``, and both are concatenated followed by one
    ``tokenizer.pad_token_id`` that acts as the end-of-answer token.

    Args:
        example: Mapping with a ``"conversations"`` list whose entries carry a
            ``"value"`` field (index 0: prompt, index 1: answer).

    Returns:
        Dict with 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors
        (each cut to at most ``MAX_LENGTH`` entries) together with the
        processor's ``pixel_values`` and ``image_grid_thw`` (leading dimension
        squeezed).

    Raises:
        IndexError: If the prompt contains no ``<|vision_start|>`` marker.
    """
    MAX_LENGTH = 8192
    conversation = example["conversations"]
    input_content = conversation[0]["value"]
    output_content = conversation[1]["value"]
    # Extract the image path embedded between the vision markers.
    file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": file_path,
                    "resized_height": 448,
                    "resized_width": 448,
                },
                {"type": "text", "text": "Is this image manipulated or synthesized?"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Only the token-level fields have to become Python lists (for the
    # concatenation below). The image tensors are kept as tensors instead of
    # being round-tripped through nested Python lists.
    prompt_ids = inputs["input_ids"][0].tolist()
    prompt_mask = inputs["attention_mask"][0].tolist()

    response = tokenizer(str(output_content), add_special_tokens=False)
    end_token = [tokenizer.pad_token_id]

    input_ids = prompt_ids + response["input_ids"] + end_token
    attention_mask = prompt_mask + response["attention_mask"] + [1]
    # Prompt positions are set to -100 so that only the answer (and the end
    # token) carry a label.
    labels = [-100] * len(prompt_ids) + response["input_ids"] + end_token

    # Slicing is a no-op for sequences that are already short enough.
    input_ids = input_ids[:MAX_LENGTH]
    attention_mask = attention_mask[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
        "pixel_values": inputs["pixel_values"],
        # Drop the leading dimension of the grid descriptor.
        "image_grid_thw": inputs["image_grid_thw"].squeeze(0),
    }


def predict(messages, model):
    """Generate the model's answer for a single chat-style request.

    The messages are rendered with the module-level ``processor`` chat
    template, the referenced vision inputs are loaded through
    ``process_vision_info``, and up to 128 new tokens are generated.

    Args:
        messages: Chat messages in the format accepted by
            ``processor.apply_chat_template`` and ``process_vision_info``.
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.

    Returns:
        The decoded text of the first generated sequence, with the prompt
        tokens removed and special tokens skipped.
    """
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Place the inputs on the model's own device rather than hard-coding
    # "cuda", so the function also works on CPU-only hosts. Fall back to the
    # previous behaviour when the model does not report a usable device.
    device = getattr(model, "device", None)
    if device is None or getattr(device, "type", None) == "meta":
        device = "cuda"
    inputs = inputs.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]

def extract_label(text):
    """Map a free-text answer to ``"fake"``, ``"real"`` or ``"unknown"``.

    Matching is case-insensitive. The phrases are tried in order, so a text
    containing both yields ``"fake"``.
    """
    lowered = text.lower()
    phrase_to_label = (
        ("has been manipulated", "fake"),
        ("has not been manipulated", "real"),
    )
    for phrase, label in phrase_to_label:
        if phrase in lowered:
            return label
    return "unknown"


if __name__ == "__main__":
    # Evaluation script: load the base model plus a LoRA checkpoint, run it on
    # every record of the dataset file and report classification, ROUGE and
    # BERTScore metrics.

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    # `tokenizer` and `processor` are module-level names used by
    # process_func() and predict().
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_id)

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # NOTE(review): presumably a leftover from the training setup; confirm
    # whether it is still needed for inference-only runs.
    model.enable_input_require_grads()

    # ====================Test Mode===================
    # NOTE(review): these hyper-parameters presumably mirror the ones used to
    # train ./checkpoint-600 — confirm against the checkpoint's adapter config.
    val_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=True,
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
    )

    val_peft_model = PeftModel.from_pretrained(model, "./checkpoint-600", config=val_config)

    # NOTE(review): the evaluation set is read from the training directory —
    # confirm this is the intended split.
    with open("./train/train.json", "r", encoding="utf-8") as f:
        test_dataset = json.load(f)

    rouge = evaluate.load("rouge")

    y_true = []
    y_pred = []
    preds = []
    refs = []
    for item in test_dataset:
        gt_response = item["conversations"][1]["value"].strip()
        gt_label = gt_response.lower()  # ground truth
        input_image_prompt = item["conversations"][0]["value"]
        origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]

        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": origin_image_path
                },
                {
                    "type": "text",
                    "text": "Is this image manipulated or synthesized?"
                }
            ]}]

        response = predict(messages, val_peft_model)
        pred_label = extract_label(response)
        y_true.append(extract_label(gt_label))
        y_pred.append(pred_label)
        refs.append(gt_response)
        preds.append(response)
        messages.append({"role": "assistant", "content": f"{response}"})
        print(messages[-1])
        print(f"GT: {gt_label} | Pred: {response} -> Label: {pred_label}")
        print(f"\nGT: {gt_response}\nPred: {response}")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=["real", "fake"]))
    scores = rouge.compute(predictions=preds, references=refs)
    print("\nROUGE Scores:")
    for key in scores:
        rouge_value = scores[key]
        # The `evaluate` ROUGE metric returns plain floats when aggregating;
        # accessing `.mid.fmeasure` on them raises AttributeError. Aggregate
        # score objects (older metric API) are still unwrapped for
        # compatibility.
        if hasattr(rouge_value, "mid"):
            rouge_value = rouge_value.mid.fmeasure
        print(f"{key}: {rouge_value:.4f}")
    P, R, F1 = score(preds, refs, lang="en")
    print(f"\nBERTScore - F1: {F1.mean().item():.4f}")