# Watermarking Experiment Setup

In [1]:
import torch

# Environment sanity check: say whether PyTorch can see a CUDA device.
# Index 0 of torch.cuda.mem_get_info() is the *free* device memory in bytes
# (index 1 is the total), so the "VRAM" figure is what is currently available.
if not torch.cuda.is_available():
    print("Cuda is not available")
else:
    print(f"Cuda is available! {torch.cuda.get_device_name()} || VRAM: {torch.cuda.mem_get_info()[0] // 1e9: .0f} GB")

Cuda is available! NVIDIA GeForce RTX 3070 Ti Laptop GPU || VRAM:  7 GB


In [2]:
import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial

import numpy # for gradio hot reload

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)

from extended_watermark_processor import WatermarkLogitsProcessor, WatermarkDetector, EntropiesLogitsProcessor
from sweet import SweetDetector
from html_css_base import CSS_BASE, SPINNER
import ipywidgets as widgets
from IPython.display import display, HTML
import copy
import json

## Watermark Hyper-Parameters

In [3]:
# Interactive controls for the experiment. Every function below reads the
# current `.value` of these widgets at call time, so changing a control takes
# effect on the next generate/detect run without re-running any cell.

# Hugging Face model identifier passed to `from_pretrained`.
model_name = widgets.Text(
    value='microsoft/Phi-3-mini-4k-instruct',
    placeholder='Enter something',
    description='Model Name:',
    disabled=False,
)

# Number of tokens to generate. transformers rejects max_new_tokens <= 0,
# so the slider starts at 1 instead of 0.
max_new_tokens = widgets.IntSlider(
    value=200,
    min=1,
    max=8096,
    step=1,
    description='Max New Tokens:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Seed handed to torch.manual_seed before sampling. The bound is an int
# literal (was the float 1e5) because IntSlider traits are integers.
gen_seed = widgets.IntSlider(
    value=123,
    min=0,
    max=100_000,
    step=1,
    description='Seed:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Watermark strength: passed as `delta` to WatermarkLogitsProcessor.
delta = widgets.FloatSlider(
    value=2.,
    min=0,
    max=4.0,
    step=0.1,
    description='Delta:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# Green-list fraction: passed as `gamma` to the processor and the detectors.
# step is 0.05 (was 0.1) so the 0.25 default sits on the slider grid and can
# be selected again after the slider has been moved.
# NOTE(review): the end points 0 and 1 are presumably degenerate for the
# detector's statistics - confirm against extended_watermark_processor.
gamma = widgets.FloatSlider(
    value=.25,
    min=0,
    max=1.0,
    step=0.05,
    description='Gamma:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# z-score above which text is reported as watermarked.
z_threshold = widgets.FloatSlider(
    value=4.,
    min=0,
    max=10.0,
    step=0.1,
    description='Z:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# Name of the seeding scheme shared by the processor and the detector.
seeding_scheme = widgets.Text(
    value='simple_1',
    placeholder='Enter something',
    description='Seeding Scheme:',
    disabled=False,
)

# Sampling temperature. generate() always samples (do_sample=True), and
# transformers requires a strictly positive temperature in that mode, so the
# slider starts at 0.1 instead of 0.
sample_temp = widgets.FloatSlider(
    value=.7,
    min=0.1,
    max=2.,
    step=0.1,
    description='Temp:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)
display(model_name, max_new_tokens, gen_seed, delta, gamma, z_threshold,seeding_scheme,sample_temp)

Text(value='microsoft/Phi-3-mini-4k-instruct', description='Model Name:', placeholder='Enter something')

IntSlider(value=200, continuous_update=False, description='Max New Tokens:', max=8096)

IntSlider(value=123, continuous_update=False, description='Seed:', max=100000)

FloatSlider(value=2.0, continuous_update=False, description='Delta:', max=4.0)

FloatSlider(value=0.25, continuous_update=False, description='Gamma:', max=1.0)

FloatSlider(value=4.0, continuous_update=False, description='Z:', max=10.0)

Text(value='simple_1', description='Seeding Scheme:', placeholder='Enter something')

FloatSlider(value=0.7, continuous_update=False, description='Temp:', max=2.0, readout_format='.1f')

In [4]:
def load_model(device=None):
    """Load the causal LM and tokenizer named by the ``model_name`` widget.

    Args:
        device: Torch device string to place the model on. Defaults to
            ``"cpu"``. The previous expression
            ``"cpu" if torch.cuda.is_available() else "cpu"`` evaluated to
            ``"cpu"`` on every machine, so the default keeps that behaviour
            while making the choice explicit; pass ``"cuda"`` to opt in to
            the GPU when the model fits in device memory.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is moved to
        ``device`` and switched to eval mode.
    """
    if device is None:
        device = "cpu"

    # Optional access token, read from the environment rather than hard-coded.
    hf_token = os.getenv("HUG_TOKEN")

    print(f"Loading {model_name.value}")
    model = AutoModelForCausalLM.from_pretrained(model_name.value, token=hf_token)
    model = model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name.value, token=hf_token)
    print(f"Tokenizer Vocab Size: {tokenizer.vocab_size}")
    return model, tokenizer, device

model, tokenizer, device = load_model()

Loading microsoft/Phi-3-mini-4k-instruct


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Tokenizer Vocab Size: 32000


In [5]:
def generate(prompt, model=None, device=None, tokenizer=None, watermark=False, system_prompt=None):
    """Sample a completion for ``prompt``, optionally with the watermark applied.

    Sampling settings (max new tokens, temperature, seed) and watermark
    settings (gamma, delta, seeding scheme) are read from the notebook's
    widgets at call time. An ``EntropiesLogitsProcessor`` is always attached;
    the ``WatermarkLogitsProcessor`` is added after it only when ``watermark``
    is true.

    Args:
        prompt: User prompt text.
        model: Model exposing ``generate`` and ``config``.
        device: Device the tokenized prompt is moved to.
        tokenizer: Tokenizer matching ``model``.
        watermark: Whether to add the watermark logits processor.
        system_prompt: Optional text prepended to ``prompt`` (separated by a
            blank line).

    Returns:
        A 6-tuple ``(redecoded_input, truncation_warning, decoded_output,
        output, entropies, topks)``: the prompt as re-decoded from its token
        ids, ``1`` if the prompt length equals the prompt budget (else ``0``),
        the decoded completion, the completion's token ids, and the
        ``entropies`` / ``topks`` attributes of the entropy processor.
    """
    entropy_processor = EntropiesLogitsProcessor()
    processors = [entropy_processor]
    if watermark:
        processors.append(
            WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                     gamma=gamma.value,
                                     delta=delta.value,
                                     seeding_scheme=seeding_scheme.value,
                                     select_green_tokens=True)
        )

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens.value,
        do_sample=True,
        top_k=0,
        temperature=sample_temp.value,
    )

    # The attribute is `max_position_embeddings` (plural). The previous check
    # looked for "max_position_embedding", which never exists, so the 2048
    # fallback was used for every model.
    context_window = getattr(model.config, "max_position_embeddings", None) or 2048
    # Keep the budget positive even when the requested completion is longer
    # than the context window.
    prompt_max_length = max(1, context_window - max_new_tokens.value)

    full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt

    tokd_input = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=prompt_max_length).to(device)
    prompt_length = tokd_input["input_ids"].shape[-1]
    # A prompt that exactly fills the budget is reported as (possibly) truncated.
    truncation_warning = prompt_length == prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=False)[0]

    # Seed right before sampling so runs with the same settings are repeatable.
    torch.manual_seed(gen_seed.value)

    output = model.generate(
        **tokd_input,
        logits_processor=LogitsProcessorList(processors),
        **gen_kwargs
    )

    # need to isolate the newly generated tokens
    output = output[:, prompt_length:]

    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            decoded_output,
            output,
            entropy_processor.entropies,
            entropy_processor.topks
            )

def color_decode_html(output, tokenizer, mask_map, sep=""):
    """Return ``output`` as HTML with each located token highlighted green or red.

    Tokens are looked up left to right: each decoded token is searched for in
    ``output`` starting just after the previous match. Tokens whose decoded
    text cannot be found are skipped. All text taken from ``output`` is
    HTML-escaped, so markup characters in generated or pasted text are shown
    literally instead of being interpreted by the browser.

    Args:
        output: Plain text to colour.
        tokenizer: Object with ``decode(token_id, skip_special_tokens=True)``.
        mask_map: Iterable of ``(token_id, green)`` pairs; a truthy ``green``
            selects the green highlight, otherwise red.
        sep: Unused; kept so existing callers keep working.

    Returns:
        The HTML string.
    """
    from html import escape  # local import keeps the block self-contained

    green_color, red_color = "green", "red"
    pieces = []
    cursor = 0  # index in `output` just past the last highlighted token

    for token_id, green in mask_map:
        decoded_token = tokenizer.decode(token_id, skip_special_tokens=True)
        token_start = output.find(decoded_token, cursor)
        if token_start == -1:
            continue

        color = green_color if green else red_color
        # Text between the previous match and this one is passed through as-is
        # (escaped), then the token itself is wrapped in a <mark>.
        pieces.append(escape(output[cursor:token_start], quote=False))
        pieces.append(f"<mark style=\"background-color:{color}\">{escape(decoded_token, quote=False)}</mark>")
        cursor = token_start + len(decoded_token)

    pieces.append(escape(output[cursor:], quote=False))
    # Joining once avoids rebuilding the whole string for every token.
    return "".join(pieces)

def color_decode_html_entropy(output, tokens, tokenizer, entropies, threshold = None, sep=""):
    """Return ``output`` as HTML with tokens shaded yellow by their entropy.

    Without a ``threshold`` the entropies are min-max normalised to [0, 1],
    squared, and mapped onto an opacity in [0.1, 1.0]. With a ``threshold``
    each token is either fully opaque (entropy >= threshold) or fully
    transparent. Tokens are located left to right in ``output``; tokens whose
    decoded text cannot be found are skipped. Text taken from ``output`` is
    HTML-escaped.

    Args:
        output: Plain text to shade.
        tokens: Indexable whose first element holds the token ids, i.e.
            ``tokens[0][i]`` is the id belonging to ``entropies[i]``.
        tokenizer: Object with ``decode(token_id, skip_special_tokens=True)``.
        entropies: One number per token.
        threshold: Optional cut-off. ``None`` selects the continuous shading;
            any number, including ``0``, selects the binary shading.
        sep: Unused; kept so existing callers keep working.

    Returns:
        The HTML string.
    """
    from html import escape  # local import keeps the block self-contained

    entropies = list(entropies)

    if threshold is not None:
        # `is not None` so that an explicit threshold of 0 is honoured.
        scaled_ents = [1. if ent >= threshold else 0. for ent in entropies]
    elif entropies:
        lowest, highest = min(entropies), max(entropies)
        spread = highest - lowest
        if spread > 0:
            # Subtract the minimum before dividing so the result stays in
            # [0, 1]; dividing the raw value by the spread could exceed 1 and
            # produce an invalid opacity.
            scaled_ents = [0.1 + 0.9 * ((ent - lowest) / spread) ** 2 for ent in entropies]
        else:
            # All entropies equal: nothing to contrast, use the base opacity.
            scaled_ents = [0.1 for _ in entropies]
    else:
        scaled_ents = []

    pieces = []
    cursor = 0  # index in `output` just past the last shaded token

    for idx, entropy in enumerate(scaled_ents):
        decoded_token = tokenizer.decode(tokens[0][idx], skip_special_tokens=True)
        token_start = output.find(decoded_token, cursor)
        if token_start == -1:
            continue

        pieces.append(escape(output[cursor:token_start], quote=False))
        pieces.append(f"<mark style=\"background-color:rgba(255,255,0,{entropy:.2f})\">{escape(decoded_token, quote=False)}</mark>")
        cursor = token_start + len(decoded_token)

    pieces.append(escape(output[cursor:], quote=False))
    return "".join(pieces)
    
def detect(input_text, device=None, tokenizer=None, window_size=None, window_stride=None):
    """Score ``input_text`` with a freshly built ``WatermarkDetector``.

    Gamma, seeding scheme and z-threshold are read from the notebook widgets
    and should match the settings used at generation time.

    Returns:
        A pair ``(rows, g_masks)``: the detector's score dict formatted by
        ``list_format_scores``, and the detector's ``g_masks`` attribute.
    """
    detector_settings = dict(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=gamma.value,                    # same value as used for generation
        seeding_scheme=seeding_scheme.value,  # same value as used for generation
        device=device,                        # must match the original rng device type
        tokenizer=tokenizer,
        z_threshold=z_threshold.value,
        normalizers='',
        ignore_repeated_ngrams=True,
        select_green_tokens=True,
    )
    watermark_detector = WatermarkDetector(**detector_settings)

    scores = watermark_detector.detect(
        input_text,
        window_size=window_size,
        window_stride=window_stride,
    )
    formatted_rows = list_format_scores(scores, watermark_detector.z_threshold)
    return formatted_rows, watermark_detector.g_masks

def format_names(s):
    """Replace the detector's internal metric keys in ``s`` with display labels.

    The substitutions are applied one after another, in the order listed.
    """
    display_labels = (
        ("num_tokens_scored", "Tokens Counted (T)"),
        ("num_green_tokens", "# Tokens in Greenlist"),
        ("green_fraction", "Fraction of T in Greenlist"),
        ("z_score", "z-score"),
        ("p_value", "p value"),
        ("prediction", "Prediction"),
        ("confidence", "Confidence"),
    )
    for raw_key, label in display_labels:
        s = s.replace(raw_key, label)
    return s
    
def list_format_scores(score_dict, detection_threshold):
    """Turn a detector score dict into ``[label, text]`` rows.

    Each entry becomes one row, labelled via ``format_names``. Values are
    rendered as a percentage for ``green_fraction`` (1 decimal) and
    ``confidence`` (3 decimals), with 3 significant digits for other floats,
    as "Watermarked" / "Human/Unwatermarked" for booleans, and with ``str``
    formatting otherwise. A "z-score Threshold" row is then inserted at
    index -2 when a ``confidence`` entry is present, at index -1 otherwise.
    """
    def _render(key, value):
        # Key-specific formats take precedence over type-based ones.
        if key == 'green_fraction':
            return f"{value:.1%}"
        if key == 'confidence':
            return f"{value:.3%}"
        if isinstance(value, float):
            return f"{value:.3g}"
        if isinstance(value, bool):
            return "Watermarked" if value else "Human/Unwatermarked"
        return f"{value}"

    rows = [[format_names(key), _render(key, value)] for key, value in score_dict.items()]

    threshold_slot = -2 if "confidence" in score_dict else -1
    rows.insert(threshold_slot, ["z-score Threshold", f"{detection_threshold}"])
    return rows

def change_to_dict_w_keys(ll, suffix):
    """Convert rows into a dict keyed by ``"<row[0]> (<suffix>)"`` with value ``row[1]``.

    Later rows overwrite earlier ones that produce the same key.
    """
    relabelled = {}
    for row in ll:
        relabelled[f"{row[0]} ({suffix})"] = row[1]
    return relabelled

def display_process_msg(msg):
    """Show a centred spinner with ``msg`` in bold underneath it.

    Uses the module-level ``CSS_BASE`` and ``SPINNER`` HTML fragments.
    ``msg`` is inserted into the HTML as-is.
    """
    spinner_html = f"<div style=\"text-align: center\">{SPINNER}</div>"
    # <b> now sits inside the <div>; the tags used to be interleaved
    # (<b><div>...</b></div>), which is not well-formed HTML.
    message_html = f"<div style=\"text-align: center\"><b>{msg}</b></div>"
    display(HTML(f"{CSS_BASE}\n<hr>{spinner_html}{message_html}"))


## Prompt

In [6]:
from extended_watermark_processor import WatermarkLogitsProcessor, WatermarkDetector, WatermarkBase

In [8]:
# Prompt text sent to the model.
prompt = widgets.Textarea(
    value='Who is Robert Downey Jr?',
    placeholder='Prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='70%', height='150px', margin='10px 0px 10px 0px')
)

# Toggles the watermark logits processor for the next generation.
watermark = widgets.Checkbox(
    value=False,
    description='Watermark',
    disabled=False,
    indent=False
)

# Create a Button widget with layout properties
button = widgets.Button(
    description='Generate',
    button_style='primary',
    layout=widgets.Layout(width='200px', margin='10px 0px 10px 0px')
)

# Area that the click handler renders progress messages and results into.
output = widgets.Output()

# Wrap the imported stylesheet in a <style> element exactly once. Without the
# guard, every re-run of this cell would wrap the already wrapped value again
# and nest <style> tags inside each other.
if not CSS_BASE.lstrip().startswith("<style>"):
    CSS_BASE = "<style>" + CSS_BASE + "</style>"
def on_button_click(b):
    """Handle a click on "Generate": generate, detect, and render the results.

    Steps, all rendered into the ``output`` widget:
      1. Generate a completion for the prompt widget's text (watermarked when
         the checkbox is ticked).
      2. Write the completion and per-token top-k data to ``data.json`` as a
         JavaScript ``const RAW_DATA = ...;`` assignment.
      3. Run ``detect`` on the completion and a ``SweetDetector`` on the
         completion's tokens and entropies.
      4. Display the green/red colouring, the raw text, the detection rows,
         two entropy shadings and the SWEET result.

    Args:
        b: Argument supplied by the button's click callback; unused.
    """
    from html import escape  # local import; only needed for rendering

    # One cut-off shared by the thresholded entropy shading and the SWEET
    # detector, so the two views cannot drift apart.
    entropy_threshold = 0.5

    with output:
        output.clear_output()
        display_process_msg(f"Generating using {model_name.value} {'(Watermarked)' if watermark.value else ''}")
        inp, _trunc, out, tokens, entropies, topks = generate(prompt.value, model, device,tokenizer, watermark.value)
        entropies = [ent.item() for ent in entropies]

        # Per generated token: [entropy, decoded token, [(id, value, decoded candidate), ...]]
        topks = [
            [entropies[idx], tokenizer.decode(tokens[0][idx], skip_special_tokens=True), [
                (ind, val, tokenizer.decode(ind, skip_special_tokens=True)) for ind, val in token
            ]] for idx, token in enumerate(topks)
        ]

        # `with` closes the file deterministically; the bare open(...).write(...)
        # left the handle open until garbage collection.
        with open("data.json", "w") as data_file:
            data_file.write("const RAW_DATA = " + json.dumps({"text": out, "probs":topks}) + ";")

        output.clear_output()
        display_process_msg("Running Detection")
        result_wm, g_masks = detect(out, device, tokenizer)

        color_coded = color_decode_html(out, tokenizer, g_masks)
        entropy_coded = color_decode_html_entropy(out, tokens, tokenizer, entropies)
        entropy_coded_th = color_decode_html_entropy(out, tokens, tokenizer, entropies, threshold=entropy_threshold)

        # sweet detector
        watermark_detector = SweetDetector(
                                        device=device,
                                        vocab=list(tokenizer.get_vocab().values()),
                                        gamma=gamma.value,
                                        tokenizer=tokenizer,
                                        z_threshold=z_threshold.value,
                                        ignore_repeated_bigrams=False,
                                        entropy_threshold=entropy_threshold)

        tokd_input = tokenizer(inp, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)

        detection_result = watermark_detector.detect( # no batch
                        tokenized_prefix=tokd_input[0],
                        tokenized_text=tokens[0],
                        entropy=entropies,
                    )

        output.clear_output()
        # The raw completion is escaped so that markup characters produced by
        # the model are shown literally rather than interpreted as HTML.
        display(HTML(f"<pre>{color_coded}</pre><hr><pre>{escape(out, quote=False)}</pre><hr>{result_wm}<hr><pre>{entropy_coded}</pre><hr><pre>{entropy_coded_th}</pre><hr>{detection_result}"))

button.on_click(on_button_click)

display(prompt, watermark, button, output)

Textarea(value='Who is Robert Downey Jr?', description='Prompt:', layout=Layout(height='150px', margin='10px 0…

Checkbox(value=False, description='Watermark', indent=False)

Button(button_style='primary', description='Generate', layout=Layout(margin='10px 0px 10px 0px', width='200px'…

Output()

In [None]:
# Free-text box holding the text to run watermark detection on.
test_text = widgets.Textarea(
    value="",
    placeholder="Text here...",
    description="Text:",
    layout=widgets.Layout(
        width="70%",
        height="150px",
        margin="10px 0px 10px 0px",
    ),
)

# Area that the detect handler renders into.
output2 = widgets.Output()

# Starts a detection run on the text above.
button2 = widgets.Button(
    description="Detect",
    button_style="primary",
    layout=widgets.Layout(
        width="200px",
        margin="10px 0px 10px 0px",
    ),
)

def on_button_click2(b):
    """Handle a click on "Detect": score the text box content and render it.

    Runs ``detect`` on the current text, then shows the text with green/red
    token colouring followed by the detection rows, all inside ``output2``.
    When the box is empty or whitespace-only there is nothing to score, so a
    short hint is shown instead of starting a detection run.

    Args:
        b: Argument supplied by the button's click callback; unused.
    """
    with output2:
        output2.clear_output()

        # Read the widget once so detection and colouring see the same text.
        text = test_text.value
        if not text.strip():
            display(HTML("<i>Enter some text to run detection on.</i>"))
            return

        display_process_msg("Running Detection")
        result_wm, g_masks = detect(text, device, tokenizer)

        output2.clear_output()
        color_coded = color_decode_html(text, tokenizer, g_masks)

        display(HTML(f"<pre>{color_coded}</pre><hr>{result_wm}"))

button2.on_click(on_button_click2)

display(test_text, button2, output2)