# A Watermark for Large Language Models

This is the Python notebook for our project. It generally follows the flow of the original paper, A Watermark for Large Language Models (Kirchenbauer et al. 2023). Much of the text and many of the images in each section are taken directly from the paper; we have added our own code to implement the paper's watermarking strategy and analysis.

In [None]:
# Install block - put any necessary pip installs here
!pip install datasets
!pip install torch
!pip install transformers

In [None]:
# Import block - put any necessary imports here
from datasets import load_dataset, Dataset
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
from functools import partial
import json
import math

# Set up device

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device set to {device}")

# Load the Dataset

The paper uses the RealNewsLike subset of the C4 dataset.

In [None]:
dataset_name = "c4"
dataset_config_name = "realnewslike"
dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=True)

# Load the OPT-1.3B tokenizer and model

In [None]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")

# Load the model
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
model = model.to(device)

# Helpers
def tokenize(sequence):
    return tokenizer(sequence, return_tensors="pt").to(model.device)

def seed_rng(seed=42):
    # Seed the CPU and CUDA RNGs with the requested seed (not a hard-coded 42)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# Algorithm 1: Text Generation with Hard Red List

![images/algorithm_1.png](images/algorithm_1.png)

## Define the Hard Red List Logits Processor

In [None]:
class HardRedListWatermark(LogitsProcessor):

    def __init__(self, tokenizer, device, hash_key=15485863):
        self.tokenizer = tokenizer
        self.device = device
        self.vocab_size = tokenizer.vocab_size
        self.green_list_size = self.vocab_size // 2
        # Large prime number to be used for seed
        self.hash_key = hash_key
        self.generator = torch.Generator(device=device)

    def __call__(self, input_ids, scores):
        # Compute hash of previous token and set it as seed
        prev_token = int(input_ids[0, -1].item())
        self.generator.manual_seed(self.hash_key * prev_token)

        # Shuffle the vocabulary and get red list ids
        permuted_vocab = torch.randperm(self.vocab_size, generator=self.generator, device=input_ids.device)
        red_list = permuted_vocab[self.green_list_size:] # continue where green list left off

        # Set red list logits to -infinity
        scores[:, red_list] = -float("inf")
        return scores

    def compute_green_and_red_lists(self):
        permuted_vocab = torch.randperm(self.vocab_size, generator=self.generator, device=self.device)
        green_list = permuted_vocab[:self.green_list_size]
        red_list = permuted_vocab[self.green_list_size:]

        return green_list, red_list

    def produce_color_map(self, prompt_ids, generated_token_ids):
        # Set seed for reproducibility
        seed_rng()

        # Get last prompt_id
        prev_token_id = int(prompt_ids['input_ids'][0, -1].item())

        color_map = []

        for i, token_id in enumerate(generated_token_ids['input_ids'][0]):
            token_id = int(token_id.item())
            self.generator.manual_seed(self.hash_key * prev_token_id)

            green_list, _ = self.compute_green_and_red_lists()

            if token_id in green_list:
                color_map.append(1) # green
            else:
                color_map.append(0) # red

            # Update previous token before next iteration
            prev_token_id = token_id

        return color_map
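Detection needs no model: anyone who knows the hash key and the previous token can regenerate the same partition. The sketch below illustrates this on a toy 100-token vocabulary. It is an assumption-laden stand-in, not the notebook's implementation: it swaps `torch.randperm` for Python's `random.Random.shuffle` so it runs without PyTorch (the actual lists therefore differ from those `HardRedListWatermark` produces), `toy_green_red_split`, `toy_color_map`, and `toy_hard_generate` are hypothetical helpers, and the "model" is a uniform sampler restricted to the green list.

```python
import random

HASH_KEY = 15485863  # same prime as the default hash_key above

def toy_green_red_split(prev_token, vocab_size, gamma=0.5, hash_key=HASH_KEY):
    # Seed a private RNG with hash_key * prev_token, shuffle the vocabulary,
    # and split it: the first gamma * vocab_size ids are green, the rest red.
    # (random.Random is a stand-in for torch.randperm with a torch.Generator.)
    rng = random.Random(hash_key * prev_token)
    permuted = list(range(vocab_size))
    rng.shuffle(permuted)
    green_size = int(gamma * vocab_size)
    return set(permuted[:green_size]), set(permuted[green_size:])

def toy_color_map(prev_token, token_ids, vocab_size, gamma=0.5):
    # 1 = green, 0 = red; each token is judged against the list seeded by the
    # token that precedes it, mirroring produce_color_map
    colors = []
    for tok in token_ids:
        green, _ = toy_green_red_split(prev_token, vocab_size, gamma)
        colors.append(1 if tok in green else 0)
        prev_token = tok
    return colors

def toy_hard_generate(prompt_last_token, n_tokens, vocab_size, seed=0):
    # Hypothetical stand-in "model": uniform sampling over the current green
    # list only, i.e. red tokens get zero probability, as the -inf mask ensures
    sampler = random.Random(seed)
    prev, out = prompt_last_token, []
    for _ in range(n_tokens):
        green, _ = toy_green_red_split(prev, vocab_size)
        tok = sampler.choice(sorted(green))
        out.append(tok)
        prev = tok
    return out

tokens = toy_hard_generate(prompt_last_token=3, n_tokens=20, vocab_size=100)
print(toy_color_map(3, tokens, vocab_size=100))  # twenty 1s: no red list violations
```

Because the split depends only on the hash key and the previous token, the detector recovers every green list exactly, and hard-watermarked text colors entirely green.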

## Implement Algorithm 1 using the Hard Red List Logits Processor

In [None]:
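# delta is not used by the hard red list (blocking red tokens outright is the
# delta -> infinity limit), and HardRedListWatermark fixes gamma = 0.5
delta = 2.0
gamma = 0.5

hard_watermark = HardRedListWatermark(tokenizer, model.device)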

def algorithm_1(tokenizer, model, prompt_ids):
    # Set seed for reproducibility
    seed_rng()

    # Instantiate the hard red list logits processor
    hard_red_list_lp = LogitsProcessorList([hard_watermark])

    # Generate using the hard red list logits processor
    algorithm_1_generate = partial(
        model.generate,
        logits_processor=hard_red_list_lp,
        max_new_tokens=200,
        min_new_tokens=1,
        do_sample=True,
        top_k=0,
        temperature=0.7
    )

    # Generate output ids
    output_ids = algorithm_1_generate(**prompt_ids)

    # Decode and return the string
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

## Running Algorithm 1

In [None]:
results = []
for i, prompts in enumerate(dataset):
    prompt = prompts["text"]
    # Tokenize prompt into ids
    prompt_ids = tokenize(prompt)
    # Generate output
    output = algorithm_1(tokenizer, model, prompt_ids)
    continuation = output[len(prompt):]
    # Re-tokenize the continuation WITHOUT special tokens: by default the OPT
    # tokenizer prepends a BOS token that was never generated, which would be
    # colored and would shift every green/red list by one position
    generated_token_ids = tokenizer(
        continuation, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    # Given a continuation, compute which tokens are in the red list and which are in the green list
    color_map = hard_watermark.produce_color_map(
        prompt_ids=prompt_ids,
        generated_token_ids=generated_token_ids,
    )
    continuation_token_count = len(color_map)
    green_token_count = sum(color_map)

    # Collect results
    result_dict = {
        "prompt": prompt,
        "continuation": continuation,
        "full_output": output,
        "continuation_token_count": continuation_token_count,
        "green_token_count": green_token_count,
    }
    results.append(result_dict)

    # For now, stop after the first result since generation takes a while
    if i >= 0:
        break

## Detecting the Hard Watermark

While producing watermarked text requires access to the language model, detecting the watermark does not. A third party with knowledge of the hash function and random number generator can reproduce the red list for each token and count how many times the red list rule is violated. We can detect the watermark by testing the following null hypothesis:

$H_0$: The text sequence is generated with no knowledge of the red list rule.

Because the red list is chosen at random, a natural writer is expected to violate the red list rule with half of their tokens, while the watermarked model produces no violations. The probability that a natural source produces $T$ tokens without violating the red list rule is only $1/(2^T)$, which is vanishingly small even for short text fragments with a dozen words. This enables detection of the watermark (rejection of $H_0$) for, e.g., a synthetic tweet.
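The $1/(2^T)$ figure is easy to tabulate. The sketch below assumes, as $H_0$ does, that each token independently lands in the green list with probability $\gamma = 1/2$; `prob_no_violations` and `min_tokens_for` are hypothetical helpers for illustration.

```python
import math

def prob_no_violations(T, gamma=0.5):
    # Under H0 each token is green with probability gamma, independently,
    # so all T tokens avoid the red list with probability gamma ** T
    return gamma ** T

def min_tokens_for(max_false_positive, gamma=0.5):
    # Smallest T for which gamma ** T <= max_false_positive
    return math.ceil(math.log(max_false_positive) / math.log(gamma))

print(prob_no_violations(12))  # 0.000244140625, i.e. 1 / 4096
print(min_tokens_for(1e-6))    # 20
```

So a dozen tokens with zero violations already has roughly a 1-in-4000 chance of arising naturally, and 20 tokens push it below one in a million.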

A more robust detection approach uses a one-proportion z-test to evaluate the null hypothesis. If the null hypothesis is true, then the number of green list tokens, denoted $|s|_G$, has expected value $T/2$ and variance $T/4$. The z-statistic for this test is:

![images/eq_2.png](images/eq_2.png)

We reject the null hypothesis and detect the watermark if $z$ is above a chosen threshold. Suppose we choose to reject the null hypothesis if $z > 4$. In this case, the probability of a false positive is $3 \times 10^{-5}$, which is the one-sided p-value corresponding to $z > 4$. At the same time, we will detect any watermarked sequence with 16 or more tokens (the minimum value of $T$ that produces $z = 4$ when $|s|_G=T$).
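Both numbers in this paragraph can be checked with the standard library alone: `math.erfc` gives the standard normal tail, $P(Z > z) = \tfrac{1}{2}\,\mathrm{erfc}(z/\sqrt{2})$. The helper names `hard_z` and `one_sided_p` are ours, for illustration only.

```python
import math

def hard_z(green_count, T):
    # z = 2 * (|s|_G - T/2) / sqrt(T), the gamma = 1/2 z-statistic
    return 2 * (green_count - T / 2) / math.sqrt(T)

def one_sided_p(z):
    # P(Z > z) for a standard normal variable Z
    return 0.5 * math.erfc(z / math.sqrt(2))

print(hard_z(16, 16))     # 4.0
print(one_sided_p(4.0))   # about 3.17e-05

# When every token is green, z = sqrt(T), so z reaches 4 first at T = 16
min_T = next(T for T in range(1, 100) if hard_z(T, T) >= 4)
print(min_T)              # 16
```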

In [None]:
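def compute_z_score(green_token_count, total_length, gamma):
    # The hard watermark fixes gamma = 1/2, so z = 2 * (|s|_G - T/2) / sqrt(T);
    # the gamma argument is accepted for a uniform signature but not used here
    numerator = green_token_count - (total_length / 2)
    denominator = math.sqrt(total_length)
    return 2 * numerator / denominator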

# Calculate z-score for a given result
first_result = results[0]
z_score = compute_z_score(first_result["green_token_count"], first_result["continuation_token_count"], gamma)
first_result["z_score"] = z_score
print(json.dumps(first_result, indent=1))

# Algorithm 2: Text Generation with Soft Red List

![images/algorithm_2.png](images/algorithm_2.png)

## Define the Soft Red List Logits Processor

In [None]:
class SoftRedListWatermark(LogitsProcessor):
    def __init__(self, tokenizer, device, gamma=0.5, hash_key=15485863, delta=2.0):
        super().__init__()
        self.tokenizer = tokenizer
        self.device = device
        self.vocab_size = tokenizer.vocab_size
        self.gamma = gamma
        self.hash_key = hash_key
        self.delta = delta
        self.green_list_size = int(self.gamma * self.vocab_size)
        self.generator = torch.Generator(device=device)

    def __call__(self, input_ids, scores):
        # Compute hash of previous token and set it as seed
        prev_token = int(input_ids[0, -1].item())
        self.generator.manual_seed(self.hash_key * prev_token)

        # Shuffle the vocabulary and get green list ids
        permuted_vocab = torch.randperm(self.vocab_size, generator=self.generator, device=input_ids.device)
        green_list = permuted_vocab[:self.green_list_size]

        # Add delta to green list logits
        scores[:, green_list] += self.delta
        return scores

    def compute_green_and_red_lists(self):
        permuted_vocab = torch.randperm(self.vocab_size, generator=self.generator, device=self.device)
        green_list = permuted_vocab[:self.green_list_size]
        red_list = permuted_vocab[self.green_list_size:]

        return green_list, red_list

    def produce_color_map(self, prompt_ids, generated_token_ids):
        # Set seed for reproducibility
        seed_rng()

        # Get last prompt_id
        prev_token_id = int(prompt_ids['input_ids'][0, -1].item())

        color_map = []

        for i, token_id in enumerate(generated_token_ids['input_ids'][0]):
            token_id = int(token_id.item())
            self.generator.manual_seed(self.hash_key * prev_token_id)

            green_list, _ = self.compute_green_and_red_lists()

            if token_id in green_list:
                color_map.append(1) # green
            else:
                color_map.append(0) # red

            # Update previous token before next iteration
            prev_token_id = token_id

        return color_map
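Unlike the hard red list, the soft watermark only tilts the distribution, so its effect depends on how spread out the model's distribution already is. The sketch below reproduces the `scores[:, green_list] += self.delta` step on a toy 10-token vocabulary with plain Python lists; the helpers `softmax` and `green_mass`, and the two example distributions, are hypothetical.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a plain list of floats
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def green_mass(logits, green_ids, delta):
    # Add delta to green-list logits (as SoftRedListWatermark.__call__ does),
    # then return the total probability assigned to green tokens
    boosted = [x + delta if i in green_ids else x for i, x in enumerate(logits)]
    probs = softmax(boosted)
    return sum(probs[i] for i in green_ids)

green_ids = set(range(5))       # toy vocabulary of 10 tokens, gamma = 0.5

flat = [0.0] * 10               # high entropy: every token equally likely
print(green_mass(flat, green_ids, delta=2.0))    # ~0.881, i.e. e^2 / (e^2 + 1)

peaked = [0.0] * 9 + [10.0]     # low entropy: red token 9 dominates
print(green_mass(peaked, green_ids, delta=2.0))  # ~0.0017, the red token still wins
```

With $\delta = 2$, a flat distribution moves from 50% to about 88% green, while a near-deterministic red token is barely affected; this is why low-entropy text carries a weaker soft watermark.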

## Implement Algorithm 2 using the Soft Red List Logits Processor

In [None]:
delta = 2.0
gamma = 0.5

# Pass the notebook-level parameters so changing them above takes effect
soft_watermark = SoftRedListWatermark(tokenizer, model.device, gamma=gamma, delta=delta)

def algorithm_2(tokenizer, model, prompt_ids):
    # Set seed for reproducibility
    seed_rng()

    # Instantiate the soft red list logits processor
    soft_red_list_lp = LogitsProcessorList([soft_watermark])

    # Generate using the soft red list logits processor
    algorithm_2_generate = partial(
        model.generate,
        logits_processor=soft_red_list_lp,
        max_new_tokens=200,
        min_new_tokens=1,
        do_sample=True,
        top_k=0,
        temperature=0.7
    )

    # Generate output ids
    output_ids = algorithm_2_generate(**prompt_ids)

    # Decode and return the string
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

## Running Algorithm 2

In [None]:
# Helpers

def print_colored_terminal(text, color_map, tokenizer):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    token_texts = [tokenizer.decode([token]) for token in tokens]

    for i, token_text in enumerate(token_texts):
        if i < len(color_map):
            if color_map[i] == 1:
                # Green text in terminal
                print(f"\033[92m{token_text}\033[0m", end="")
            else:
                # Red text in terminal
                print(f"\033[91m{token_text}\033[0m", end="")
        else:
            print(token_text, end="")
    print()

In [None]:
results = []
for i, prompts in enumerate(dataset):
    prompt = prompts["text"]
    # Tokenize prompt into ids
    prompt_ids = tokenize(prompt)
    # Generate output
    output = algorithm_2(tokenizer, model, prompt_ids)
    continuation = output[len(prompt):]
    # Re-tokenize the continuation WITHOUT special tokens: by default the OPT
    # tokenizer prepends a BOS token that was never generated, which would be
    # colored and would shift every green/red list by one position
    generated_token_ids = tokenizer(
        continuation, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    # Given a continuation, compute which tokens are in the red list and which are in the green list
    color_map = soft_watermark.produce_color_map(
        prompt_ids=prompt_ids,
        generated_token_ids=generated_token_ids,
    )
    continuation_token_count = len(color_map)
    green_token_count = sum(color_map)

    # Collect results
    result_dict = {
        "prompt": prompt,
        "continuation": continuation,
        "full_output": output,
        "continuation_token_count": continuation_token_count,
        "green_token_count": green_token_count,
    }
    results.append(result_dict)

    # For now, stop after the first result since generation takes a while
    if i >= 0:
        break

# Uncomment these lines to print the colored continuation in a terminal
# print("\nColored continuation:")
# print_colored_terminal(continuation, color_map, tokenizer)

## Detecting the Soft Watermark

The process for detecting the soft watermark is identical to that for the hard watermark. We assume the null hypothesis:

$H_0$: The text sequence is generated with no knowledge of the red list rule.

and compute a z-statistic using:

![images/eq_2.png](images/eq_2.png)

We reject the null hypothesis and detect the watermark if z is greater than a threshold. For arbitrary $\gamma$ we have

![images/eq_3.png](images/eq_3.png)

Consider again the case in which we detect the watermark for $z > 4$. Just as with the hard watermark, the false positive rate is $3 \times 10^{-5}$. With the hard watermark, we could detect any watermarked sequence of 16 or more tokens, regardless of the properties of the text. With the soft watermark, however, our ability to detect synthetic text depends on the entropy of the sequence: high-entropy sequences are detected with relatively few tokens, while low-entropy sequences require more tokens for detection. The paper analyzes this detection sensitivity and its dependence on entropy rigorously; the spike entropy defined below is the measure of entropy that analysis relies on.
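Rearranging the z-statistic for arbitrary $\gamma$ makes the token-count trade-off concrete. If a fraction $f$ of the $T$ tokens is green, then $z = (f - \gamma)\sqrt{T} / \sqrt{\gamma(1-\gamma)}$, so $z \ge 4$ once $T \ge 16\,\gamma(1-\gamma)/(f-\gamma)^2$. In the sketch below, the green fraction $f$ stands in for entropy (high-entropy text lets the $\delta$ boost push $f$ well above $\gamma$, low-entropy text keeps it near $\gamma$); `tokens_needed` is a hypothetical helper that takes $f$ as given rather than deriving it from the entropy of the text.

```python
import math

def tokens_needed(green_fraction, gamma=0.5, z_threshold=4.0):
    # Solve (f - gamma) * sqrt(T) / sqrt(gamma * (1 - gamma)) >= z_threshold for T
    if green_fraction <= gamma:
        return math.inf  # no excess of green tokens: never detected
    t = z_threshold ** 2 * gamma * (1 - gamma) / (green_fraction - gamma) ** 2
    return math.ceil(t - 1e-9)  # small tolerance guards against float round-up

for f in (1.0, 0.75, 0.6):
    print(f, tokens_needed(f))  # 16, 64, 400
```

With $f = 1$ and $\gamma = 0.5$ this recovers the hard watermark's 16 tokens; a green fraction of 0.6 needs 400.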

In [None]:
def compute_z_score(green_token_count, total_length, gamma):
    numerator = green_token_count - (gamma * total_length)
    denominator = math.sqrt(total_length * gamma * (1 - gamma))
    return numerator / denominator

# Calculate z-score for a given result
first_result = results[0]
z_score = compute_z_score(first_result["green_token_count"], first_result["continuation_token_count"], gamma)
first_result["z_score"] = z_score
print(json.dumps(first_result, indent=1))

## Define Spike Entropy

![images/spike_entropy.png](images/spike_entropy.png)

In [None]:
import torch

def compute_spike_entropy(p, z):
    denom = 1.0 + z * p
    denom = torch.clamp(denom, min=1e-9) # Prevent division by zero/very small numbers
    return torch.sum(p / denom, dim=-1) # Sum across the vocabulary dimension
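As a sanity check on `compute_spike_entropy`, here is the same sum $S(p, z) = \sum_k p_k / (1 + z p_k)$ restated in plain Python for a single distribution, so it runs without PyTorch. The helper `spike_entropy_list`, the example distributions, and the modulus value $z = 1$ are illustrative choices, not values from the paper.

```python
def spike_entropy_list(p, z):
    # S(p, z) = sum_k p_k / (1 + z * p_k), the same sum as compute_spike_entropy
    return sum(pk / (1 + z * pk) for pk in p)

uniform = [0.25] * 4            # mass spread evenly over 4 tokens
one_hot = [1.0, 0.0, 0.0, 0.0]  # all mass on a single token

print(spike_entropy_list(uniform, z=1.0))  # ~0.8, i.e. N / (N + z) with N = 4
print(spike_entropy_list(one_hot, z=1.0))  # 0.5, i.e. 1 / (1 + z), the minimum
```

Spread-out distributions score higher (a uniform distribution over $N$ tokens gives $N/(N+z)$, approaching 1), while a deterministic one bottoms out at $1/(1+z)$, matching the intuition that spiky distributions leave little room for the watermark.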

# Algorithm 3: Robust Private Watermarking

![images/algorithm_3.png](images/algorithm_3.png)

In [None]:
# This is the code block for Algorithm 3

def algorithm_3():
    # TODO: Implement Algorithm 3
    raise NotImplementedError()