In [None]:
import json
import os

from tqdm import tqdm
from json_repair import repair_json

import torch
from datasets import load_from_disk
from PIL import Image
from torchmetrics import MeanAbsoluteError
from torchmetrics.text import EditDistance, BERTScore
from torchmetrics.classification import BinaryAccuracy
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from dataset_formatter import format_data

In [None]:
# Held-out split previously written with `Dataset.save_to_disk` (load_from_disk only reads that format).
# "~" is expanded explicitly so loading does not depend on whether the installed
# datasets/fsspec/pyarrow path handling expands it on its own.
dataset_id = os.path.expanduser("~/.cache/huggingface/hub/my_tmp_dataset_test")
test_dataset = load_from_disk(dataset_id)

In [None]:
# Convert every raw record into the {"messages": [...]} chat layout consumed by the cells below.
test_dataset_modified = list(map(format_data, tqdm(test_dataset)))

In [None]:
# Peek at the first formatted sample to check the message layout.
test_dataset_modified[0]

In [None]:
# Base Qwen2.5-VL checkpoint (bfloat16, placed on GPU 0) and its matching processor,
# which provides the chat template and image preprocessing.
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map=0
)

In [None]:
# Fine-tuned adapter weights (presumably produced by the TRL run the directory name refers to).
# "default" is the name load_adapter uses when none is given; spelled out so it matches set_adapter below.
adapter_path = "qwen2.5-7b-instruct-trl-watermarks"
model.load_adapter(adapter_path, adapter_name="default")

In [None]:
# Make the adapter loaded above the active one.
model.set_adapter(adapter_name="default")

In [None]:
# Switch between the fine-tuned model (True) and the base model (False) without reloading weights.
# Replaces a commented-out disable_adapters() line with an explicit, reviewable flag.
USE_ADAPTER = True
if USE_ADAPTER:
    model.enable_adapters()
else:
    model.disable_adapters()

In [None]:
def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
    """Run the vision-language model on one formatted sample and return its answer parsed as JSON.

    Args:
        model: generation model exposing ``generate`` (here Qwen2.5-VL, optionally with an adapter).
        processor: the matching processor (chat template + image preprocessing).
        sample: dict with a ``"messages"`` list in chat format. Any trailing assistant turn(s)
            are treated as the reference answer and are NOT shown to the model.
        max_new_tokens: generation budget.
        device: device the processed inputs are moved to; should match where the model lives.

    Returns:
        The decoded completion passed through ``repair_json`` and ``json.loads``. This is expected
        to be a dict, but any JSON value is possible, so callers should check the type.

    Raises:
        json.JSONDecodeError: if the text returned by ``repair_json`` is still not valid JSON.
    """
    # Strip the trailing assistant turn(s): they hold the ground-truth answer. Leaving them in
    # while also requesting a fresh generation (add_generation_prompt=True) leaks the reference
    # into the prompt and inflates every evaluation metric.
    prompt_messages = list(sample["messages"])
    while prompt_messages and prompt_messages[-1].get("role") == "assistant":
        prompt_messages.pop()

    # Render the prompt with the model's chat template and open an assistant turn.
    text_input = processor.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Collect the image(s) referenced by the prompt messages.
    image_inputs, _ = process_vision_info(prompt_messages)

    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors="pt",
    ).to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + completion per row; keep only the newly generated tokens.
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # The model is prompted/fine-tuned to answer in JSON; repair_json patches minor syntax slips.
    return json.loads(repair_json(output_text[0]))

In [None]:
# Spot-check one test sample end to end and display the parsed answer.
output = generate_text_from_sample(model=model, processor=processor, sample=test_dataset_modified[238])
output

In [None]:
# Inspect the formatted sample at index 1.
test_dataset_modified[1]

In [None]:
# Preview a local example image at the fixed 512x341 demo size.
image_path = "assets/should-you-watermark-images-5.jpg"
# Image.open is lazy and keeps the file handle open; the `with` block closes it once
# resize() has produced an independent in-memory copy.
with Image.open(image_path) as source_img:
    img1 = source_img.resize((512, 341))
# Same {"images", "texts"} structure the original cell exposed as `img`, kept for compatibility.
img = {"images": [img1], "texts": [{"assistant": ""}]}
# A bare PIL image as the last expression renders inline via its rich repr;
# Image.show() would instead try to open an external viewer on the kernel's host.
img1

In [None]:
# Metrics used by the scoring loop: exact-match accuracy and MAE on "watermarks",
# edit distance on "text", and a separate roberta-base BERTScore for each free-text field
# ("main object", "style") so their running states stay independent.
accuracy, mae_error, levenshtein = BinaryAccuracy(), MeanAbsoluteError(), EditDistance()
bertscore_object, bertscore_style = (BERTScore(model_name_or_path="roberta-base") for _ in range(2))

In [None]:
# Score the model on the first N_EVAL test samples.
# A sample is counted in `corrupted` (and skipped) when the model answer or the reference is not
# valid JSON, is not a dict with exactly `keys`, or has a non-numeric "watermarks" value.
N_EVAL = 1000  # number of test samples to score; lower it for a quick smoke run
keys = {'watermarks', "text", "main object", "style"}

# Start from clean metric state so re-running this cell never mixes in a previous run
# (`corrupted` is reset below for the same reason).
for metric in (accuracy, mae_error, levenshtein, bertscore_object, bertscore_style):
    metric.reset()

corrupted = 0
for sample in tqdm(test_dataset_modified[:N_EVAL]):
    try:
        output = generate_text_from_sample(model, processor, sample)
        # messages[2] holds the reference answer as JSON text.
        target = json.loads(repair_json(sample["messages"][2]["content"][0]["text"]))
    except json.JSONDecodeError:
        # Text repair_json could not turn into valid JSON: count it instead of aborting the run.
        corrupted += 1
        continue
    if not (
        isinstance(output, dict)
        and isinstance(target, dict)
        and set(output) == keys
        and set(target) == keys
    ):
        corrupted += 1
        continue
    try:
        pred_watermarks = float(output["watermarks"])
        true_watermarks = float(target["watermarks"])
    except (TypeError, ValueError):
        # e.g. null or a non-numeric string: neither accuracy nor MAE is meaningful here.
        corrupted += 1
        continue
    accuracy(torch.tensor([int(pred_watermarks == true_watermarks)]), torch.tensor([1]))
    mae_error(torch.tensor([pred_watermarks]), torch.tensor([true_watermarks]))
    # Pass preds and target in the same one-element-list form; str() guards against
    # non-string JSON values (numbers, lists, null) in the free-text fields.
    levenshtein([str(output["text"])], [str(target["text"])])
    bertscore_object(preds=[str(output["main object"])], target=[str(target["main object"])])
    bertscore_style(preds=[str(output["style"])], target=[str(target["style"])])
acc = accuracy.compute()
mae = mae_error.compute()
lev = levenshtein.compute()
bscore_ob = bertscore_object.compute()
bscore_st = bertscore_style.compute()
print("Watermarks accuracy:", acc)
print("Watermarks MAE:", mae)
print("Found text Levenshtein edit distance:", lev)
print("Main object BERTScore:", bscore_ob['f1'].mean())
print("Style BERTScore:", bscore_st['f1'].mean())

In [None]:
# Re-print the evaluation summary with every metric rounded to three decimals.
for label, value in (
    ("Watermarks accuracy:", acc),
    ("Watermarks MAE:", mae),
    ("Found text Levenshtein edit distance:", lev),
    ("Main object BERTScore:", bscore_ob['f1'].mean()),
    ("Style BERTScore:", bscore_st['f1'].mean()),
):
    print(label, "{:.3f}".format(value.item()))

In [None]:
# Clear accumulated metric state before another evaluation pass.
for metric in (accuracy, mae_error, levenshtein, bertscore_object, bertscore_style):
    metric.reset()

In [None]:
# Number of samples the scoring loop skipped as corrupted (model output or reference
# was not a dict with exactly the expected keys, or was otherwise unusable).
corrupted