# **Dictionary Similarity + Perfect Matching**

In [3]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import collections
from math import sqrt
import scipy.stats
import torch
from torch import Tensor
from transformers import LogitsProcessor
import nltk
import ssl

try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('wordnet')
from nltk.util import ngrams
import networkx as nx

# Import WordNet from NLTK (make sure you have run: nltk.download('wordnet'))
from nltk.corpus import wordnet as wn
from normalizers import normalization_strategy_lookup

###############################################
# Helper Functions for Vocabulary & Matching
###############################################

def get_vocabulary(tokenizer) -> list[str]:
    """
    Extract the vocabulary from a Hugging Face tokenizer.
    Returns a list of tokens where the index corresponds to the token ID.
    """
    vocab_dict = tokenizer.get_vocab()  # mapping: token -> id
    vocab_list = [None] * len(vocab_dict)
    for token, idx in vocab_dict.items():
        vocab_list[idx] = token
    return vocab_list


from nltk.corpus import wordnet as wn
from functools import lru_cache
from tqdm import tqdm


@lru_cache(maxsize=None)
def get_lemma_set(token: str) -> set:
    """
    Given a token, first strip off any leading special characters (e.g. the "Ġ" in GPT-2 tokenization)
    to obtain the underlying word, and then return a set of all lemma names (in lowercase)
    from all its synsets. The result is cached to avoid repeated WordNet lookups.
    """
    synsets = wn.synsets(token)
    return {lemma.lower() for s in synsets for lemma in s.lemma_names()}


def are_synonyms(token1: str, token2: str) -> bool:
    """
    Determines whether two tokens are synonyms by checking if token1 appears in the lemma set
    of token2 and vice versa.
    """
    # Remove common BPE prefix marker (like "Ġ") used by some tokenizers.
    word1 = token1.lstrip("Ġ")
    word2 = token2.lstrip("Ġ")
    lemmas1 = get_lemma_set(word1)
    lemmas2 = get_lemma_set(word2)
    if not lemmas1 or not lemmas2:
        return False
    return (word1.lower() in lemmas2) and (word2.lower() in lemmas1)


def filter_tokens_with_synonyms(vocab_list: list[str]) -> (list[int], list[int]):
    """
    Splits the vocabulary (list of tokens) into two sets:
      - unique_indices: indices for tokens that have no synonym in the vocabulary.
      - paired_indices: indices for tokens that have at least one synonym.

    This implementation precomputes the lemma sets for all tokens and uses tqdm to show a progress bar.
    """
    unique_indices = []
    paired_indices = []
    n = len(vocab_list)

    # Precompute lemma sets for all tokens in the vocabulary.
    lemma_sets = [get_lemma_set(token.lstrip("Ġ")) for token in vocab_list]

    for i in tqdm(range(n), desc="Processing vocabulary"):
        token_i = vocab_list[i]
        lemmas_i = lemma_sets[i]
        has_synonym = False
        for j in range(n):
            if i == j:
                continue
            # Check if token_i is in token_j's lemma set and vice versa.
            if (token_i.lstrip("Ġ")).lower() in lemma_sets[j] and vocab_list[j].lower() in lemmas_i:
                has_synonym = True
                break
        if has_synonym:
            paired_indices.append(i)
        else:
            unique_indices.append(i)
    return unique_indices, paired_indices


import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed
import networkx as nx


# Assume are_synonyms is defined elsewhere (and ideally optimized or cached)

def _check_pair(a: int, b: int, vocab_list: list[str], indices: list[int]) -> tuple[int, int, float]:
    """
    Checks a pair of tokens (using the indices into vocab_list provided by indices)
    and returns (a, b, weight) where weight is 1.0 if the tokens are synonyms, 0.0 otherwise.
    """
    token_a = vocab_list[indices[a]]
    token_b = vocab_list[indices[b]]
    weight = 1.0 if are_synonyms(token_a, token_b) else 0.0
    return (a, b, weight)

def construct_similarity_matrix(vocab_list: list[str], indices: list[int]) -> list[list[float]]:
    """
    Constructs an m x m similarity matrix for tokens specified by indices.
    Each entry C[i][j] is 1.0 if vocab_list[indices[i]] and vocab_list[indices[j]] are synonyms,
    0 otherwise.
    Uses a ThreadPoolExecutor and tqdm progress bars.
    """
    m = len(indices)
    C = [[0.0 for _ in range(m)] for _ in range(m)]
    
    # Precompute lemma sets for tokens in 'indices'
    lemma_dict = {i: get_lemma_set(vocab_list[i].lstrip("Ġ")) for i in indices}
    
    # Iterate over pairs (a, b) with a < b
    for a in tqdm(range(m), desc="Constructing similarity matrix (outer loop)"):
        for b in range(a+1, m):
            token_a = vocab_list[indices[a]].lstrip("Ġ")
            token_b = vocab_list[indices[b]].lstrip("Ġ")
            lemmas_a = lemma_dict[indices[a]]
            lemmas_b = lemma_dict[indices[b]]
            weight = 1.0 if (token_a.lower() in lemmas_b and token_b.lower() in lemmas_a) else 0.0
            C[a][b] = weight
            C[b][a] = weight
    return C


def find_perfect_matching(similarity_matrix: list[list[float]]) -> list[tuple[int, int]]:
    """
    Given a similarity matrix, builds an undirected graph (nodes correspond to indices in the matrix)
    and finds a maximum–weight matching (i.e. a pairing).

    This version uses itertools.combinations to iterate over all unique pairs and wraps the loop with tqdm.
    Only edges with weight > 0 are added.
    Returns a list of tuples (i, j) representing matched indices (relative to the input list).
    """
    m = len(similarity_matrix)
    G = nx.Graph()
    G.add_nodes_from(range(m))

    # Create all unique pairs (i, j) for i < j.
    pairs = list(itertools.combinations(range(m), 2))

    for i, j in tqdm(pairs, total=len(pairs), desc="Building graph for matching"):
        weight = similarity_matrix[i][j]
        if weight > 0:
            G.add_edge(i, j, weight=weight)

    matching = nx.max_weight_matching(G, maxcardinality=True)
    pairing = [tuple(sorted(pair)) for pair in matching]
    return pairing

###############################################
# Revised Watermark Classes
###############################################

class WatermarkBase:
    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",
        hash_key: int = 15485863,
        select_green_tokens: bool = True,
        precomputed_pairing: list[tuple[int, int]] = None,
        unique_tokens: list[int] = None,
    ):
        self.vocab = vocab                      # list of token IDs (usually 0,...,n-1)
        self.vocab_size = len(vocab)
        self.gamma = gamma                      # fraction of tokens to designate as green (target size)
        self.delta = delta                      # bias to add to green tokens' logits
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens
        self.pairing = precomputed_pairing      # perfect matching on tokens with synonyms (indices in vocab)
        self.unique_tokens = unique_tokens        # list of token IDs that have no synonyms

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """
        Seeds the RNG deterministically using the last token in input_ids.
        For the "simple_1" scheme, seed = hash_key * (last token id).
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme
        # Ensure the RNG is initialized.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, "Input must have at least one token."
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Seeding scheme {seeding_scheme} not implemented.")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        """
        Returns a list of token IDs that form the green list.
        If a precomputed pairing exists, then:
          - All tokens from the unique set (those with no synonyms) are automatically in the green list.
          - For each pair in the perfect matching (on tokens with synonyms), a fair coin flip (using the seeded RNG)
            selects one token from the pair.
        Otherwise, falls back to a random permutation method.
        Optionally, the list may be truncated to a target size (gamma * vocab_size).
        """
        self._seed_rng(input_ids)
        if self.pairing is None or self.unique_tokens is None:
            # Fallback: use random permutation.
            greenlist_size = int(self.vocab_size * self.gamma)
            vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
            if self.select_green_tokens:
                return vocab_permutation[:greenlist_size].tolist()
            else:
                return vocab_permutation[-greenlist_size:].tolist()
        else:
            # Start with all unique tokens.
            greenlist_ids = self.unique_tokens.copy()
            # For tokens that have synonyms (precomputed pairing), randomly assign one from each pair.
            for pair in self.pairing:
                coin_flip = (torch.rand(1, generator=self.rng) < 0.5).item()
                chosen = pair[0] if coin_flip == 1 else pair[1]
                greenlist_ids.append(chosen)
            # Optionally, enforce a maximum size (gamma * vocab_size)
            desired_size = int(self.vocab_size * self.gamma)
            if len(greenlist_ids) > desired_size:
                indices = torch.randperm(len(greenlist_ids), generator=self.rng)[:desired_size].tolist()
                greenlist_ids = [greenlist_ids[i] for i in indices]
            return greenlist_ids


class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        return green_tokens_mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids[b_idx] = greenlist_ids
        green_tokens_mask = self._calc_greenlist_mask(scores, batched_greenlist_ids)
        scores = self._bias_greenlist_logits(scores, green_tokens_mask, self.delta)
        return scores


class WatermarkDetector(WatermarkBase):
    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = ["unicode"],
        ignore_repeated_bigrams: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert device, "Device must be provided."
        assert tokenizer, "A tokenizer is required for detection."
        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)
        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Seeding scheme {self.seeding_scheme} not implemented.")
        self.normalizers = [normalization_strategy_lookup(norm) for norm in normalizers]
        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "Repeated bigram variant requires simple_1 seeding."

    def _compute_z_score(self, observed_count, T):
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        return numer / denom

    def _compute_p_value(self, z):
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for bigram in freq.keys():
                prefix = torch.tensor([bigram[0]], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
            green_token_count = sum(bigram_table.values())
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError("Not enough tokens to score.")
            green_token_count = 0
            green_token_mask = []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)
        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score:
            score_dict["z_score"] = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_p_value:
            z = score_dict.get("z_score", self._compute_z_score(green_token_count, num_tokens_scored))
            score_dict["p_value"] = self._compute_p_value(z)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask
        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        assert (text is not None) ^ (tokenized_text is not None), "Provide either raw or tokenized text."
        if return_prediction:
            kwargs["return_p_value"] = True
        for normalizer in self.normalizers:
            text = normalizer(text)
        if tokenized_text is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            if self.tokenizer is not None and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        if return_prediction:
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]
        return output_dict


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to /Users/zhzhu1/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
###############################################
# Outline of the New Partitioning Method
###############################################
# 1. Obtain the vocabulary from the model's tokenizer using get_vocabulary(tokenizer).
# 2. Use filter_tokens_with_synonyms(vocab_list) to split the vocabulary indices into:
#       - unique_indices: tokens with no synonyms (set A)
#       - paired_indices: tokens with at least one synonym (set B)
# 3. Construct the similarity matrix for tokens in set B using construct_similarity_matrix.
# 4. Compute a perfect matching on set B using find_perfect_matching.
# 5. In _get_greenlist_ids, use the precomputed pairing (mapped back to original token IDs)
#    and return all tokens from set A plus one token per pair (chosen at random by a coin flip).
###############################################
# Full Code Example for the Revised Watermark Partition
###############################################

if __name__ == "__main__":
    # For testing purposes, use a small model like 'distilgpt2' which can run on CPU.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    vocab_list = get_vocabulary(tokenizer)
    print(f"Vocabulary size: {len(vocab_list)}")

    # Filter tokens: separate those with no synonyms (set A) and with synonyms (set B)
    unique_indices, paired_indices = filter_tokens_with_synonyms(vocab_list)
    print(f"Number of unique tokens (no synonyms, set A): {len(unique_indices)}")
    print(f"Number of tokens with synonyms (set B): {len(paired_indices)}")

    # Construct similarity matrix for tokens in set B.
    similarity_matrix = construct_similarity_matrix(vocab_list, paired_indices)
    # Find a perfect matching on the tokens in set B.
    matching = find_perfect_matching(similarity_matrix)
    # Map matching indices (relative to paired_indices) back to the original vocabulary indices.
    mapped_pairing = [(paired_indices[i], paired_indices[j]) for (i, j) in matching]
    print("Computed perfect matching (pairs) on tokens with synonyms (set B):")
    print(mapped_pairing)

Vocabulary size: 50265


Processing vocabulary: 100%|██████████| 50265/50265 [03:20<00:00, 251.04it/s]


Number of unique tokens (no synonyms, set A): 33446
Number of tokens with synonyms (set B): 16819


Constructing similarity matrix (outer loop): 100%|██████████| 16819/16819 [00:35<00:00, 472.41it/s] 
Building graph for matching: 100%|██████████| 141430971/141430971 [00:11<00:00, 11848921.94it/s]


Computed perfect matching (pairs) on tokens with synonyms (set B):
[(2852, 22073), (7050, 25182), (41796, 47294), (20707, 29783), (15729, 22877), (1166, 18107), (1786, 48025), (21284, 43870), (5494, 34017), (2435, 44178), (4152, 18175), (17267, 20118), (1457, 47587), (7266, 11433), (5400, 36154), (1178, 37695), (7526, 30181), (2692, 38384), (5823, 38818), (4453, 36994), (19451, 44565), (130, 46814), (6929, 36870), (25255, 44137), (2703, 41483), (7094, 44165), (860, 42876), (1515, 41259), (15696, 19562), (275, 38060), (22636, 27957), (11644, 24483), (9394, 24777), (5654, 39088), (25512, 36142), (4615, 9149), (1911, 34571), (7368, 38566), (25715, 47711), (3017, 17577), (7960, 33166), (10465, 16583), (3151, 11316), (958, 45918), (23022, 34502), (9867, 45532), (8781, 39249), (11689, 34329), (7331, 29886), (2237, 4208), (465, 10118), (2228, 16111), (8780, 29530), (3360, 32764), (11005, 39205), (20078, 45375), (24670, 33092), (2365, 34799), (11350, 29383), (2198, 16123), (17751, 30639), (182

In [3]:
# Initialize WatermarkLogitsProcessor with precomputed pairing and unique tokens.
wm_processor = WatermarkLogitsProcessor(
    vocab=list(range(len(vocab_list))),
    gamma=0.25,
    delta=2.0,
    seeding_scheme="simple_1",
    select_green_tokens=True,
    precomputed_pairing=mapped_pairing,
    unique_tokens=unique_indices
)

# Test _get_greenlist_ids with a sample prompt.
sample_prompt = "This is a good day."
input_ids = torch.tensor(tokenizer.encode(sample_prompt))
greenlist_ids = wm_processor._get_greenlist_ids(input_ids)
print("Greenlist token IDs for the prompt:")
print(greenlist_ids)
green_tokens = [vocab_list[tok] for tok in greenlist_ids]
print("Greenlist tokens (strings):")
print(green_tokens)

Greenlist token IDs for the prompt:
[15185, 46264, 39650, 44828, 40670, 15632, 40129, 45840, 15779, 37640, 30965, 8303, 39467, 31106, 17943, 11256, 7108, 10209, 44006, 27350, 22024, 20391, 49546, 26531, 43495, 41035, 4476, 43981, 17662, 29447, 36385, 8894, 24933, 32512, 11704, 48145, 21803, 16961, 15445, 19327, 7039, 44816, 18806, 24152, 35653, 38971, 46473, 38229, 14252, 5697, 42934, 7935, 20950, 28401, 41414, 30126, 6685, 31169, 10676, 10282, 19921, 40443, 39258, 1678, 12816, 49017, 43420, 21837, 39443, 10197, 21408, 22711, 3139, 49441, 49040, 2622, 34993, 49463, 13244, 20581, 47848, 39850, 42913, 17319, 16195, 20452, 41167, 15560, 46678, 8921, 26357, 44937, 25671, 9938, 44758, 13100, 36638, 8622, 29835, 21281, 31849, 10280, 5644, 3257, 48682, 46209, 10741, 3907, 11965, 31522, 43898, 40392, 6913, 41771, 33157, 48573, 37805, 9294, 12789, 19148, 2469, 35144, 47050, 30343, 24026, 43006, 32169, 33287, 29604, 10851, 34576, 10717, 28419, 619, 5521, 48864, 31336, 46407, 36902, 35639, 7706, 

## Watermark Full Pipeline Example 

In [7]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import os
import argparse
from pprint import pprint
from functools import partial
from normalizers import normalization_strategy_lookup
import numpy  # for gradio hot reload
import gradio as gr
import torch

import collections
from math import sqrt
import scipy.stats
import torch
from torch import Tensor
from transformers import LogitsProcessor
import nltk
import ssl

try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('wordnet')
from nltk.util import ngrams
import networkx as nx

from transformers import LogitsProcessor
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    LogitsProcessorList,
)

# Import our revised watermark processor which now accepts precomputed pairing info.
from watermark_processor import WatermarkLogitsProcessor_kgw, WatermarkDetector_kgw


#########################
# Helper Functions
#########################

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    parser = argparse.ArgumentParser(
        description="A demo for watermarking using a revised green/red partition based on WordNet synonyms."
    )
    parser.add_argument("--demo_public", type=str2bool, default=False,
                        help="Expose the gradio demo publicly.")
    parser.add_argument("--model_name_or_path", type=str, default="facebook/opt-1.3b",
                        help="Identifier for the pretrained model from Hugging Face.")
    parser.add_argument("--prompt_max_length", type=int, default=None,
                        help="Truncation length for the prompt.")
    parser.add_argument("--max_new_tokens", type=int, default=200,
                        help="Maximum number of tokens to generate.")
    parser.add_argument("--generation_seed", type=int, default=123,
                        help="Seed for generation reproducibility.")
    parser.add_argument("--use_sampling", type=str2bool, default=True,
                        help="Use multinomial sampling for generation.")
    parser.add_argument("--sampling_temp", type=float, default=0.7,
                        help="Sampling temperature.")
    parser.add_argument("--n_beams", type=int, default=1,
                        help="Number of beams for beam search (if not sampling).")
    parser.add_argument("--use_gpu", type=str2bool, default=True,
                        help="Run inference on GPU if available.")
    parser.add_argument("--seeding_scheme", type=str, default="simple_1",
                        help="Seeding scheme for watermarking.")
    parser.add_argument("--gamma", type=float, default=0.25,
                        help="Target fraction of tokens for the green list.")
    parser.add_argument("--delta", type=float, default=2.0,
                        help="Bias to add to green list token logits.")
    parser.add_argument("--normalizers", type=str, default="",
                        help="Comma separated normalizer names for detection.")
    parser.add_argument("--ignore_repeated_bigrams", type=str2bool, default=False,
                        help="Use repeated bigram variant in detection.")
    parser.add_argument("--detection_z_threshold", type=float, default=4.0,
                        help="Z-score threshold for detection.")
    parser.add_argument("--select_green_tokens", type=str2bool, default=True,
                        help="Legacy option for selecting green tokens.")
    parser.add_argument("--skip_model_load", type=str2bool, default=False,
                        help="Skip model loading (for debugging).")
    parser.add_argument("--seed_separately", type=str2bool, default=True,
                        help="Seed separately for each generation call.")
    parser.add_argument("--load_fp16", type=str2bool, default=False,
                        help="Load model in FP16 mode.")
    args = parser.parse_args()
    # Convert normalizers into a list (if provided)
    args.normalizers = args.normalizers.split(",") if args.normalizers else []
    return args


#############################
# Preprocessing: Vocabulary & Matching
#############################

def get_vocabulary(tokenizer) -> list[str]:
    """
    Returns a list of tokens (where index corresponds to token ID)
    extracted from the tokenizer's vocabulary.
    """
    vocab_dict = tokenizer.get_vocab()
    vocab_list = [None] * len(vocab_dict)
    for token, idx in vocab_dict.items():
        vocab_list[idx] = token
    return vocab_list


import nltk
from nltk.corpus import wordnet as wn
from functools import lru_cache
from tqdm import tqdm
import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed
import networkx as nx


@lru_cache(maxsize=None)
def get_lemma_set(token: str) -> set:
    """
    Given a token string, first strip any leading BPE marker (e.g. "Ġ")
    and return the set of lowercased lemma names from all its WordNet synsets.
    """
    # Strip off common prefix markers
    word = token.lstrip("Ġ")
    synsets = wn.synsets(word)
    return {lemma.lower() for s in synsets for lemma in s.lemma_names()}


def are_synonyms(token1: str, token2: str) -> bool:
    """
    Determines whether two tokens are synonyms by checking if token1 (after stripping)
    appears in token2's lemma set and vice versa.
    """
    word1 = token1.lstrip("Ġ")
    word2 = token2.lstrip("Ġ")
    lemmas1 = get_lemma_set(word1)
    lemmas2 = get_lemma_set(word2)
    if not lemmas1 or not lemmas2:
        return False
    return (word1.lower() in lemmas2) and (word2.lower() in lemmas1)


def filter_tokens_with_synonyms(vocab_list: list[str]) -> (list[int], list[int]):
    """
    Splits the vocabulary indices into:
      - unique_indices: indices of tokens that have no synonym in the vocabulary.
      - paired_indices: indices of tokens that have at least one synonym.
    Uses a progress bar via tqdm.
    """
    unique_indices = []
    paired_indices = []
    n = len(vocab_list)
    # Precompute lemma sets for each token (stripped of any leading marker)
    lemma_sets = [get_lemma_set(token) for token in vocab_list]
    for i in tqdm(range(n), desc="Filtering vocabulary"):
        token_i = vocab_list[i]
        lemmas_i = lemma_sets[i]
        has_synonym = False
        for j in range(n):
            if i == j:
                continue
            if (token_i.lstrip("Ġ")).lower() in lemma_sets[j] and (vocab_list[j].lstrip("Ġ")).lower() in lemmas_i:
                has_synonym = True
                break
        if has_synonym:
            paired_indices.append(i)
        else:
            unique_indices.append(i)
    return unique_indices, paired_indices


def construct_similarity_matrix(vocab_list: list[str], indices: list[int]) -> list[list[float]]:
    """
    Constructs an m x m similarity matrix for tokens specified by indices.
    Entry [i][j] is 1.0 if the tokens are synonyms, 0 otherwise.
    Uses nested loops with tqdm progress bar.
    """
    m = len(indices)
    C = [[0.0 for _ in range(m)] for _ in range(m)]
    # Precompute lemma sets for tokens in indices
    lemma_dict = {i: get_lemma_set(vocab_list[i].lstrip("Ġ")) for i in indices}
    for a in tqdm(range(m), desc="Constructing similarity matrix (outer loop)"):
        for b in range(a + 1, m):
            token_a = vocab_list[indices[a]].lstrip("Ġ")
            token_b = vocab_list[indices[b]].lstrip("Ġ")
            lemmas_a = lemma_dict[indices[a]]
            lemmas_b = lemma_dict[indices[b]]
            weight = 1.0 if (token_a.lower() in lemmas_b and token_b.lower() in lemmas_a) else 0.0
            C[a][b] = weight
            C[b][a] = weight
    return C


def find_perfect_matching(similarity_matrix: list[list[float]]) -> list[tuple[int, int]]:
    """
    Constructs an undirected graph from the similarity matrix (only edges with weight>0)
    and returns a maximum–weight matching (as a list of index pairs relative to the input list).
    Uses tqdm over the pairs.
    """
    m = len(similarity_matrix)
    G = nx.Graph()
    G.add_nodes_from(range(m))
    pairs = list(itertools.combinations(range(m), 2))
    for i, j in tqdm(pairs, total=len(pairs), desc="Building graph for matching"):
        weight = similarity_matrix[i][j]
        if weight > 0:
            G.add_edge(i, j, weight=weight)
    matching = nx.max_weight_matching(G, maxcardinality=True)
    pairing = [tuple(sorted(pair)) for pair in matching]
    return pairing


#############################
# Revised Watermark Processor Classes
#############################

class WatermarkBase:
    def __init__(
            self,
            vocab: list[int] = None,
            gamma: float = 0.5,
            delta: float = 2.0,
            seeding_scheme: str = "simple_1",
            hash_key: int = 15485863,
            select_green_tokens: bool = True,
            precomputed_pairing: list[tuple[int, int]] = None,
            unique_tokens: list[int] = None,
    ):
        self.vocab = vocab  # list of token IDs (usually 0, ..., n-1)
        self.vocab_size = len(vocab)
        self.gamma = gamma  # target fraction of tokens for the green list
        self.delta = delta  # bias added to green token logits
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens
        self.pairing = precomputed_pairing  # perfect matching (pairs) for tokens with synonyms
        self.unique_tokens = unique_tokens  # token IDs that have no synonyms

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """
        Seeds the RNG deterministically using the last token in input_ids.
        For the "simple_1" scheme, the seed is hash_key * (last token id).
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, "Input must have at least one token."
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Seeding scheme {seeding_scheme} not implemented.")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        """
        Returns the green list token IDs.
        If precomputed pairing and unique_tokens are provided, then:
          - All unique tokens (set A) are in the green list.
          - For each pair in the perfect matching, one token is chosen by a coin flip.
        Otherwise, falls back to a random permutation method.
        The final list is optionally truncated to a target size (gamma * vocab_size).
        """
        self._seed_rng(input_ids)
        if self.pairing is None or self.unique_tokens is None:
            # Fallback: use random permutation.
            greenlist_size = int(self.vocab_size * self.gamma)
            vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
            if self.select_green_tokens:
                return vocab_permutation[:greenlist_size].tolist()
            else:
                return vocab_permutation[-greenlist_size:].tolist()
        else:
            greenlist_ids = self.unique_tokens.copy()
            for pair in self.pairing:
                coin_flip = (torch.rand(1, generator=self.rng).item() < 0.5)
                chosen = pair[0] if coin_flip else pair[1]
                greenlist_ids.append(chosen)
            #desired_size = int(self.vocab_size * self.gamma)
            #if len(greenlist_ids) > desired_size:
            #    perm = torch.randperm(len(greenlist_ids), generator=self.rng).tolist()
            #    indices = perm[:desired_size]
            #    greenlist_ids = [greenlist_ids[i] for i in indices]
            return greenlist_ids


class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        return green_tokens_mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor,
                               greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        batched_greenlist_ids = []
        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids.append(greenlist_ids)
        green_tokens_mask = self._calc_greenlist_mask(scores, batched_greenlist_ids)
        scores = self._bias_greenlist_logits(scores, green_tokens_mask, self.delta)
        return scores


class WatermarkDetector(WatermarkBase):
    def __init__(
            self,
            *args,
            device: torch.device = None,
            tokenizer=None,
            z_threshold: float = 4.0,
            normalizers: list[str] = ["unicode"],
            ignore_repeated_bigrams: bool = True,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert device, "Device must be provided."
        assert tokenizer, "A tokenizer is required for detection."
        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)
        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Seeding scheme {self.seeding_scheme} not implemented.")
        self.normalizers = [normalization_strategy_lookup(norm) for norm in normalizers]
        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "Repeated bigram variant requires simple_1 seeding."

    def _compute_z_score(self, observed_count, T):
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        return numer / denom

    def _compute_p_value(self, z):
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            # Repeated bigram variant: T = number of unique bigrams.
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for bigram in freq.keys():
                prefix = torch.tensor([bigram[0]], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = bigram[1] in greenlist_ids
            green_token_count = sum(bigram_table.values())
        else:
            # Standard variant: T = total tokens (after min_prefix_len)
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError("Not enough tokens to score.")
            green_token_count = 0
            green_token_mask = []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)
        # Debug prints:
        print(f"Total tokens scored (T): {num_tokens_scored}")
        print(f"Green token count: {green_token_count}")
        print(f"Green fraction: {green_token_count/num_tokens_scored:.2%}")

        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score:
            score_dict["z_score"] = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_p_value:
            z = score_dict.get("z_score", self._compute_z_score(green_token_count, num_tokens_scored))
            score_dict["p_value"] = self._compute_p_value(z)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask
        return score_dict
    
    def detect(
            self,
            text: str = None,
            tokenized_text: list[int] = None,
            return_prediction: bool = True,
            return_scores: bool = True,
            z_threshold: float = None,
            **kwargs,
    ) -> dict:
        assert (text is not None) ^ (tokenized_text is not None), "Provide either raw or tokenized text."
        if return_prediction:
            kwargs["return_p_value"] = True
        for normalizer in self.normalizers:
            text = normalizer(text)
        if tokenized_text is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(
                self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            if self.tokenizer is not None and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        if return_prediction:
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]
        return output_dict


#############################
# Demo Code (Generation & Detection)
#############################

def load_model(args):
    args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
    args.is_decoder_only_model = any(
        [(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16,
                                                         device_map='auto')
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")
    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if not args.load_fp16:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    return model, tokenizer, device


def generate(prompt, args, model=None, device=None, tokenizer=None):
    #print(f"Generating with {args}")

    # Instantiate the watermark processor with precomputed pairing/unique tokens.
    # Assume that in main() we precomputed these values (see below).
    watermark_processor = WatermarkLogitsProcessor(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        delta=args.delta,
        seeding_scheme=args.seeding_scheme,
        select_green_tokens=args.select_green_tokens,
        precomputed_pairing=args.precomputed_pairing,
        unique_tokens=args.unique_tokens
    )

    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
    if args.use_sampling:
        gen_kwargs.update(dict(
            do_sample=True,
            top_k=0,
            temperature=args.sampling_temp
        ))
    else:
        gen_kwargs.update(dict(
            num_beams=args.n_beams
        ))

    generate_without_watermark = partial(
        model.generate,
        **gen_kwargs
    )
    generate_with_watermark = partial(
        model.generate,
        logits_processor=LogitsProcessorList([watermark_processor]),
        **gen_kwargs
    )

    if not args.prompt_max_length:
        if hasattr(model.config, "max_position_embedding"):
            args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
        else:
            args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    torch.manual_seed(args.generation_seed)
    output_without_watermark = generate_without_watermark(**tokd_input)
    if args.seed_separately:
        torch.manual_seed(args.generation_seed)
    output_with_watermark = generate_with_watermark(**tokd_input)

    if args.is_decoder_only_model:
        output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
        output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

    return (
    redecoded_input, int(truncation_warning), decoded_output_without_watermark, decoded_output_with_watermark, args)


def format_names(s):
    s = s.replace("num_tokens_scored", "Tokens Counted (T)")
    s = s.replace("num_green_tokens", "# Tokens in Greenlist")
    s = s.replace("green_fraction", "Fraction of T in Greenlist")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "Prediction")
    s = s.replace("confidence", "Confidence")
    return s


def list_format_scores(score_dict, detection_threshold):
    lst_2d = []
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d


def detect(input_text, args, device=None, tokenizer=None):
    watermark_detector = WatermarkDetector(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        seeding_scheme=args.seeding_scheme,
        device=device,
        tokenizer=tokenizer,
        z_threshold=args.detection_z_threshold,
        normalizers=args.normalizers,
        ignore_repeated_bigrams=args.ignore_repeated_bigrams,
        select_green_tokens=args.select_green_tokens
    )
    if len(input_text) - 1 > watermark_detector.min_prefix_len:
        score_dict = watermark_detector.detect(input_text)
        output = list_format_scores(score_dict, watermark_detector.z_threshold)
    else:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]
    return output, args


def run_gradio(args, model=None, device=None, tokenizer=None):
    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect, device=device, tokenizer=tokenizer)
    with gr.Blocks() as demo:
        gr.Markdown("Gradio demo not shown in command-line mode.")
        demo.launch()

[nltk_data] Downloading package wordnet to /Users/zhzhu1/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [7]:
import sys
sys.argv = [sys.argv[0]]  # Remove any extra arguments from the notebook environment

args = parse_args()
print(args)

# Load model and tokenizer if not skipping
if not args.skip_model_load:
    model, tokenizer, device = load_model(args)
else:
    model, tokenizer, device = None, None, None

# Precompute the vocabulary and pairing.
from transformers import AutoTokenizer
tokenizer_for_vocab = AutoTokenizer.from_pretrained(args.model_name_or_path)
vocab_list = get_vocabulary(tokenizer_for_vocab)
print(f"Vocabulary size: {len(vocab_list)}")
unique_indices, paired_indices = filter_tokens_with_synonyms(vocab_list)
print(f"Unique tokens (set A): {len(unique_indices)}")
print(f"Tokens with synonyms (set B): {len(paired_indices)}")
similarity_matrix = construct_similarity_matrix(vocab_list, paired_indices)
matching = find_perfect_matching(similarity_matrix)
mapped_pairing = [(paired_indices[i], paired_indices[j]) for (i, j) in matching]
#print("Computed perfect matching (pairs) on tokens with synonyms (set B):")
#print(mapped_pairing)

# Store the precomputed pairing and unique set in args so that generate() can use them.
args.precomputed_pairing = mapped_pairing
args.unique_tokens = unique_indices

# For testing, use a sample prompt.
input_text = (
    "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
    "species of turtle native to the brackish coastal tidal marshes of the "
    "Northeastern and southern United States, and in Bermuda.[6] It belongs "
    "to the monotypic genus Malaclemys. It has one of the largest ranges of "
    "all turtles in North America, stretching as far south as the Florida Keys "
    "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
    "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
    "British English and American English. The species is"
)
args.default_prompt = input_text

term_width = 80
print("#" * term_width)
print("Prompt:")
print(input_text)

Namespace(demo_public=False, model_name_or_path='facebook/opt-1.3b', prompt_max_length=None, max_new_tokens=200, generation_seed=123, use_sampling=True, sampling_temp=0.7, n_beams=1, use_gpu=True, seeding_scheme='simple_1', gamma=0.25, delta=2.0, normalizers=[], ignore_repeated_bigrams=False, detection_z_threshold=4.0, select_green_tokens=True, skip_model_load=False, seed_separately=True, load_fp16=False)
Vocabulary size: 50265


Filtering vocabulary: 100%|██████████| 50265/50265 [02:09<00:00, 387.38it/s]


Unique tokens (set A): 28109
Tokens with synonyms (set B): 22156


Constructing similarity matrix (outer loop): 100%|██████████| 22156/22156 [01:06<00:00, 332.41it/s] 
Building graph for matching: 100%|██████████| 245433090/245433090 [00:32<00:00, 7660681.21it/s]


Computed perfect matching (pairs) on tokens with synonyms (set B):
[(11355, 42824), (2340, 43106), (13930, 34255), (7774, 30861), (7567, 45224), (11173, 49483), (13133, 41829), (19120, 40556), (28836, 35901), (351, 20311), (7795, 8691), (5312, 43747), (17132, 29381), (3490, 45739), (3864, 14563), (18169, 27361), (3107, 7237), (10395, 47603), (20029, 44391), (34640, 44188), (37656, 45379), (45106, 45178), (12496, 40711), (15040, 15311), (9297, 17976), (6750, 8783), (284, 44578), (3476, 48928), (1407, 8096), (46756, 46816), (2501, 18197), (15851, 27561), (11601, 33214), (35825, 42353), (28040, 36526), (6574, 33499), (4900, 29333), (3619, 43999), (190, 19620), (2178, 46430), (914, 44804), (11253, 31790), (9208, 33789), (40312, 48233), (19407, 37092), (2402, 40642), (4347, 28887), (39881, 47935), (1275, 46725), (6690, 42983), (15005, 17366), (20870, 22550), (4076, 37657), (6283, 27451), (44676, 46674), (7712, 22641), (7357, 37155), (12506, 43177), (27394, 39546), (17906, 35518), (12185, 21

In [65]:
args.delta = 2
args.detection_z_threshold = 4.0
args.model_name_or_path = "facebook/opt-1.3b"

model, tokenizer, device = load_model(args)

In [75]:
from tqdm import tqdm
import scipy.stats as stats

def count_green_tokens_paired(tokenizer, watermark_processor, text: str) -> (int, int, float):
    """
    Tokenizes the input text and, for each token (after a minimum prefix),
    if the token belongs to the paired set (i.e. has at least one synonym),
    computes the green list from the current prefix (using the watermark processor)
    and counts how many of those tokens are in the green list.
    
    Returns:
      green_count: Number of paired tokens that are green.
      tokens_scored: Total number of paired tokens considered.
      proportion: green_count / tokens_scored.
    """
    # Tokenize text without special tokens.
    input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
    total_tokens = len(input_ids)
    # Use min_prefix_len from the processor if available; otherwise, default to 1.
    start_idx = getattr(watermark_processor, "min_prefix_len", 1)
    green_count = 0
    tokens_scored = 0
    # The paired set is the complement of unique tokens.
    if watermark_processor.unique_tokens is None:
        paired_set = set(range(watermark_processor.vocab_size))
    else:
        paired_set = set(range(watermark_processor.vocab_size)) - set(watermark_processor.unique_tokens)
    
    for idx in tqdm(range(start_idx, total_tokens), desc="Counting paired tokens"):
        token = input_ids[idx].item()
        # Only count tokens that belong to the paired set.
        if token not in paired_set:
            continue
        tokens_scored += 1
        prefix = input_ids[:idx]
        greenlist_ids = watermark_processor._get_greenlist_ids(prefix)
        if token in greenlist_ids:
            green_count += 1
    proportion = green_count / tokens_scored if tokens_scored > 0 else 0.0
    return green_count, tokens_scored, proportion


def compute_p_value(green_count: int, tokens_scored: int) -> (float, float):
    """
    Under the null hypothesis for paired tokens, the probability of a token being green is 0.5.
    Computes the z–score and one–tailed p–value:
    
       z = (green_count - 0.5 * tokens_scored) / sqrt(tokens_scored * 0.5 * 0.5)
    """
    if tokens_scored == 0:
        return None, None
    p0 = 0.5
    expected = p0 * tokens_scored
    std = (tokens_scored * p0 * (1 - p0)) ** 0.5
    z_score = (green_count - expected) / std
    p_value = stats.norm.sf(z_score)  # one-tailed test (right tail)
    return z_score, p_value

def compute_perplexity(model, tokenizer, text: str, device: torch.device) -> float:
    """
    Computes the perplexity of a given text using the provided model.
    Assumes a causal language model where providing labels yields the loss.
    """
    model.eval()
    # Tokenize the text with special tokens.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
    with torch.no_grad():
        # When labels are the same as input_ids, the loss is the negative log-likelihood.
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss  # this is the average loss per token
    perplexity = torch.exp(loss).item()
    return perplexity


In [77]:
from datasets import load_dataset

# Load the "realnewslike" subset of C4 (English)
c4_realnewslike = load_dataset("c4", "realnewslike", split="train", streaming=False, trust_remote_code=True)

# Shuffle the dataset with a fixed seed for reproducibility
shuffled_dataset = c4_realnewslike.shuffle(seed=42)

# Select the first 500 examples from the shuffled dataset
sampled_examples = shuffled_dataset.select(range(500))

# Extract the text from each example (assuming the field is "text")
sampled_texts = [example["text"] for example in sampled_examples]

print(f"Sampled {len(sampled_texts)} news-like texts from C4.")



Sampled 500 news-like texts from C4.


In [82]:
max_words = 200
truncated_texts = []
for text in sampled_texts:
    words = text.split()  # split text into words
    truncated_text = " ".join(words[:max_words])
    truncated_texts.append(truncated_text)

In [88]:

# --- Example usage after text generation ---
# Assume that:
#   - decoded_output_with_watermark is the generated watermarked text.
#   - decoded_output_without_watermark is the generated non-watermarked text.
#   - wm_processor is the WatermarkLogitsProcessor instance used during generation.
#   - tokenizer is the tokenizer used by the model.
#   - args.detection_z_threshold contains the threshold for deciding watermarked text.

input_text = truncated_texts[4] #"The United States has 63 national parks, which are congressionally designated protected areas operated by the National Park Service, an agency of the Department of the Interior.[1] National parks are designated for their natural beauty, unique geological features, diverse ecosystems, and recreational opportunities, typically 'because of some outstanding scenic feature or natural phenomena.'[2] While legislatively all units of the National Park System are considered equal with the same mission, national parks are generally larger and more of a destination, and hunting and extractive activities are prohibited.[3] National monuments, on the other hand, are also frequently protected for their historical or archaeological significance. Eight national parks (including six in Alaska) are paired with a national preserve, areas with different levels of protection that are administered together but considered separate units and whose areas are not included in the figures below. The 433 units of the National Park System can be broadly referred to as national parks, but most have other formal designations.[4]"

# Generate outputs as before
redecoded_input, truncation_warning, decoded_output_without_watermark, decoded_output_with_watermark, args = generate(
    input_text, args, model=model, device=device, tokenizer=tokenizer)

wm_processor = WatermarkLogitsProcessor(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=args.gamma,
    delta=args.delta,
    seeding_scheme=args.seeding_scheme,
    select_green_tokens=args.select_green_tokens,
    precomputed_pairing=args.precomputed_pairing,
    unique_tokens=args.unique_tokens
)

# Compute detection metrics for watermarked text.
green_count_w, tokens_scored_w, proportion_w = count_green_tokens_paired(tokenizer, wm_processor, decoded_output_with_watermark)
z_w, p_w = compute_p_value(green_count_w, tokens_scored_w)
judgement_w = "LLM-generated (watermarked)" if z_w > args.detection_z_threshold else "Human-generated (non-watermarked)"

# Compute detection metrics for non-watermarked text.
green_count_nw, tokens_scored_nw, proportion_nw = count_green_tokens_paired(tokenizer, wm_processor, decoded_output_without_watermark)
z_nw, p_nw = compute_p_value(green_count_nw, tokens_scored_nw)
judgement_nw = "LLM-generated (watermarked)" if z_nw > args.detection_z_threshold else "Human-generated (non-watermarked)"

# Compute perplexity for both watermarked and non-watermarked outputs:
perplexity_nonwm = compute_perplexity(model, tokenizer, decoded_output_without_watermark, device)
perplexity_wm = compute_perplexity(model, tokenizer, decoded_output_with_watermark, device)


# Print the outputs.
print("=== Detection Results ===")
print("Original Prompt:")
print(redecoded_input)
print("\n--- Watermarked Text ---")
print(decoded_output_with_watermark)
print("Detection on Watermarked Text:")
print(f"  Green tokens (paired only): {green_count_w} / {tokens_scored_w} ({proportion_w:.2%})")
print(f"  z–score: {z_w:.2f}, p–value: {p_w:.4f}")
print(f"  Judgement: {judgement_w}")
print(f"Perplexity (watermarked): {perplexity_wm:.2f}")

print("\n--- Non-watermarked Text ---")
print(decoded_output_without_watermark)
print("Detection on Non-watermarked Text:")
print(f"  Green tokens (paired only): {green_count_nw} / {tokens_scored_nw} ({proportion_nw:.2%})")
print(f"  z–score: {z_nw:.2f}, p–value: {p_nw:.4f}")
print(f"  Judgement: {judgement_nw}")
print(f"Perplexity (non-watermarked): {perplexity_nonwm:.2f}")

Counting paired tokens: 100%|██████████| 199/199 [00:01<00:00, 164.66it/s]
Counting paired tokens: 100%|██████████| 199/199 [00:01<00:00, 152.68it/s]


=== Detection Results ===
Original Prompt:
At least four Republican senators are talking privately about opposing the current version of the tax reform bill, which the House voted to pass early Thursday, on the grounds that it would balloon the national deficit too much. Arizona Sen. Jeff Flake and Oklahoma Sen. James Lankford are among the four — enough to stop a bill that can only spare two Republican defections — who have concerns about a tax reform bill that was estimated to hike the deficit by $1.5 trillion over 10 years. The other two senators have not publicly confirmed their concerns. In an interview with TIME on Tuesday, two days before the House voted to pass its version of the tax reform bill, Flake said that he believes the bill is larded with temporary gimmicks that will ultimately add even more than that to the deficit. Their concerns show the tricky needle that Republicans will have to thread to get the tax reform bill through the upper chamber using reconciliation, a pa

# **Other Similarity Metric**

**LLM based**

In [None]:
import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import openai  # make sure to set openai.api_key appropriately

def query_llm_similarity(prompt: str) -> float:
    """
    Query an LLM (e.g. OpenAI API) to rate the similarity between two words.
    The prompt instructs the LLM to output a number among {0, 0.5, 1}.
    
    Returns:
      A float: 0.0, 0.5, or 1.0.
    """
    try:
        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            max_tokens=10,
            temperature=0.0,  # use deterministic behavior
            n=1,
            stop=None
        )
        answer = response.choices[0].text.strip()
        value = float(answer)  # try to convert the answer to a float
    except Exception as e:
        value = 0.0  # default similarity on error

    # Clamp value to one of 0.0, 0.5, or 1.0
    if value < 0.25:
        return 0.0
    elif value < 0.75:
        return 0.5
    else:
        return 1.0

def construct_similarity_matrix_llm(vocab_list: list[str], indices: list[int]) -> list[list[float]]:
    """
    Constructs an m x m similarity matrix for tokens specified by indices.
    For each unique pair (i, j) (with i < j), it queries an LLM with a prompt to rate the similarity
    between the two tokens. The similarity is clamped to one of {0.0, 0.5, 1.0}.
    
    This function uses a ThreadPoolExecutor to parallelize API calls and displays a progress bar via tqdm.
    """
    m = len(indices)
    # Initialize an m x m matrix with zeros.
    C = [[0.0 for _ in range(m)] for _ in range(m)]
    
    # Create all unique pairs (i, j) for indices with i < j.
    pairs = [(i, j) for i in range(m) for j in range(i + 1, m)]
    
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_pair = {}
        for (i, j) in pairs:
            token_a = vocab_list[indices[i]].lstrip("Ġ")
            token_b = vocab_list[indices[j]].lstrip("Ġ")
            prompt = (f"Rate the similarity between the words '{token_a}' and '{token_b}'. "
                      "Answer with one of the numbers: 0, 0.5, 1.")
            future = executor.submit(query_llm_similarity, prompt)
            future_to_pair[future] = (i, j)
        
        for future in tqdm(as_completed(future_to_pair), total=len(future_to_pair), desc="LLM similarity queries"):
            i, j = future_to_pair[future]
            try:
                similarity = future.result()
            except Exception as e:
                similarity = 0.0
            C[i][j] = similarity
            C[j][i] = similarity

    return C

# --- Example usage ---
if __name__ == "__main__":
    from transformers import AutoTokenizer
    # For example, use a small model (e.g. 'distilgpt2') to extract the vocabulary.
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    vocab_dict = tokenizer.get_vocab()  # mapping: token -> id
    # Build a list such that the index corresponds to the token ID.
    vocab_list = [None] * len(vocab_dict)
    for token, idx in vocab_dict.items():
        vocab_list[idx] = token

    # For demonstration, we use all tokens.
    indices = list(range(len(vocab_list)))
    print(f"Vocabulary size: {len(vocab_list)}")
    
    # Construct the similarity matrix using the LLM prompt–based method.
    similarity_matrix = construct_similarity_matrix_llm(vocab_list, indices)
    print("Similarity matrix constructed for a subset of tokens.")

  from .autonotebook import tqdm as notebook_tqdm


Vocabulary size: 50257


**Cos similarity**

In [None]:
import torch
import torch.nn.functional as F

def construct_full_similarity_matrix(embeddings: torch.Tensor) -> torch.Tensor:
    """
    Given a token embedding matrix of shape (n, d) (for n tokens),
    compute the full cosine similarity matrix (n x n) where each entry is
    mapped from [-1,1] to [0,1] using (cos_sim + 1)/2.
    
    Args:
        embeddings: Tensor of shape (n, d) representing token embeddings.
        
    Returns:
        sim_matrix: Tensor of shape (n, n) with cosine similarity values in [0, 1].
    """
    # Normalize each token embedding (L2 norm)
    normalized = F.normalize(embeddings, p=2, dim=1)
    # Compute cosine similarity for all pairs: (n x n) matrix
    cos_sim = torch.mm(normalized, normalized.t())
    # Map cosine similarity from [-1,1] to [0,1]
    sim_matrix = (cos_sim + 1) / 2
    return sim_matrix

# --- Example usage ---
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Load a model and its tokenizer (choose a small model for CPU)
    model_name = "distilgpt2"  # For example, a small GPT-2 variant.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = "cpu"  # or "cuda" if available
    model.to(device)

    # Get the embedding matrix from the model's input embeddings.
    embedding_matrix = model.get_input_embeddings().weight  # shape: (vocab_size, d)
    vocab_size = embedding_matrix.shape[0]
    print(f"Vocabulary size: {vocab_size}")

    # Compute the full similarity matrix.
    sim_matrix = construct_full_similarity_matrix(embedding_matrix)
    print("Full similarity matrix computed.")
    # Optionally, convert to a Python list of lists:
    sim_matrix_list = sim_matrix.tolist()
    print("Converted similarity matrix to list format.")

Vocabulary size: 50257


# **Parameter selection**

# **Difference Decoding method**

# **Text length**

# **Subtitution Attack**

# **Paraphrase Attack**