# In [1]:
import sys
sys.path.append("..")
import json
import torch
from visualize.font_settings import FontSettings
from transformers import AutoModelForCausalLM, AutoTokenizer
from visualize.visualizer import DiscreteVisualizer
from visualize.legend_settings import DiscreteLegendSettings
from visualize.page_layout_settings import PageLayoutSettings
from visualize.color_scheme import ColorSchemeForDiscreteVisualization
from visualize.data_for_visualization import DataForVisualization

# Load the example records: one JSON object per line of ./example.jsonl.
# Blank / whitespace-only lines are skipped, because json.loads("") raises
# json.JSONDecodeError (e.g. on a stray trailing empty line).
with open("./example.jsonl", "r", encoding="utf-8") as f:
    data_list = [json.loads(line) for line in f if line.strip()]

# Tokenizer for the LLaDA-8B-Instruct model; downloaded files are cached under
# ./LLaDA-8B-Instruct. Left-side padding presumably only matters for batched
# calls -- the completions below are tokenized one string at a time.
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    padding_side="left",
    cache_dir="./LLaDA-8B-Instruct",
)

# Discrete (per-token) visualizer. Every settings object is built with its
# no-argument constructor, i.e. whatever defaults those classes define.
visualizer = DiscreteVisualizer(
    color_scheme=ColorSchemeForDiscreteVisualization(),
    font_settings=FontSettings(),
    page_layout_settings=PageLayoutSettings(),
    legend_settings=DiscreteLegendSettings(),
)

# Maximum number of leading tokens of each completion that get rendered.
MAX_VISUALIZED_TOKENS = 150

# Convert every raw record into a DataForVisualization: the completion is
# split into per-token strings (special tokens kept) and paired with the
# record's "token_color" values.
processed_data_list = []

for record in data_list:
    # Tokenize a single completion. Padding is pointless for a batch of one,
    # so it is not requested; [0] selects the only row of the batch.
    token_ids = tokenizer(record["completion"], return_tensors="pt")["input_ids"][0]
    # Truncate the ids *before* decoding so discarded tokens are never
    # decoded. batch_decode over a 1-D id tensor decodes each id on its own,
    # producing one string per token -- same result as decode-then-slice.
    decoded_tokens = tokenizer.batch_decode(
        token_ids[:MAX_VISUALIZED_TOKENS], skip_special_tokens=False
    )
    # NOTE(review): token_color is passed through untruncated; confirm it has
    # no more than MAX_VISUALIZED_TOKENS entries and lines up with
    # decoded_tokens (including any special token the tokenizer prepends).
    processed_data_list.append(
        DataForVisualization(
            decoded_tokens=decoded_tokens,
            highlight_values=record["token_color"],
        )
    )

# Render the first two records and save them as PNGs. The code treats record 0
# as the watermarked sample and record 1 as the unwatermarked one.
# NOTE(review): that ordering is an assumption about example.jsonl -- confirm.
if len(processed_data_list) < 2:
    # Fail with a clear message instead of a bare IndexError below.
    raise ValueError(
        "expected at least 2 records in ./example.jsonl "
        f"(watermarked, unwatermarked), got {len(processed_data_list)}"
    )

# Identical rendering options for both images.
visualize_options = {
    "show_text": True,
    "visualize_weight": True,
    "display_legend": True,
}

watermarked_img = visualizer.visualize(data=processed_data_list[0], **visualize_options)
unwatermarked_img = visualizer.visualize(data=processed_data_list[1], **visualize_options)

watermarked_img.save("watermarked.png")
unwatermarked_img.save("unwatermarked.png")