In [107]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import argparse
from functools import partial

import numpy as np
from sklearn.metrics import roc_auc_score
import scipy.stats as stats
from datasets import load_dataset
import collections
from math import sqrt
import scipy.stats
import torch
from torch import Tensor
import nltk
import ssl
from nltk.util import ngrams
from transformers import LogitsProcessor
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    LogitsProcessorList,
)
from nltk.corpus import wordnet as wn
from functools import lru_cache
from tqdm import tqdm
import itertools
import networkx as nx

try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('wordnet')

#########################
# Helper Functions
#########################

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    parser = argparse.ArgumentParser(
        description="A demo for watermarking using a revised green/red partition based on WordNet synonyms."
    )
    parser.add_argument("--demo_public", type=str2bool, default=False,
                        help="Expose the gradio demo publicly.")
    parser.add_argument("--model_name_or_path", type=str, default="facebook/opt-1.3b",
                        help="Identifier for the pretrained model from Hugging Face.")
    parser.add_argument("--prompt_max_length", type=int, default=None,
                        help="Truncation length for the prompt.")
    parser.add_argument("--max_new_tokens", type=int, default=200,
                        help="Maximum number of tokens to generate.")
    parser.add_argument("--generation_seed", type=int, default=123,
                        help="Seed for generation reproducibility.")
    parser.add_argument("--use_sampling", type=str2bool, default=True,
                        help="Use multinomial sampling for generation.")
    parser.add_argument("--sampling_temp", type=float, default=0.7,
                        help="Sampling temperature.")
    parser.add_argument("--n_beams", type=int, default=1,
                        help="Number of beams for beam search (if not sampling).")
    parser.add_argument("--use_gpu", type=str2bool, default=True,
                        help="Run inference on GPU if available.")
    parser.add_argument("--seeding_scheme", type=str, default="simple_1",
                        help="Seeding scheme for watermarking.")
    parser.add_argument("--gamma", type=float, default=0.25,
                        help="Target fraction of tokens for the green list.")
    parser.add_argument("--delta", type=float, default=2.0,
                        help="Bias to add to green list token logits.")
    parser.add_argument("--normalizers", type=str, default="",
                        help="Comma separated normalizer names for detection.")
    parser.add_argument("--ignore_repeated_bigrams", type=str2bool, default=False,
                        help="Use repeated bigram variant in detection.")
    parser.add_argument("--detection_z_threshold", type=float, default=4.0,
                        help="Z-score threshold for detection.")
    parser.add_argument("--select_green_tokens", type=str2bool, default=True,
                        help="Legacy option for selecting green tokens.")
    parser.add_argument("--skip_model_load", type=str2bool, default=False,
                        help="Skip model loading (for debugging).")
    parser.add_argument("--seed_separately", type=str2bool, default=True,
                        help="Seed separately for each generation call.")
    parser.add_argument("--load_fp16", type=str2bool, default=False,
                        help="Load model in FP16 mode.")
    args = parser.parse_args()
    # Convert normalizers into a list (if provided)
    args.normalizers = args.normalizers.split(",") if args.normalizers else []
    return args


#############################
# Preprocessing: Vocabulary & Matching
#############################

def get_vocabulary(tokenizer) -> list[str]:
    """
    Returns a list of tokens (where index corresponds to token ID)
    extracted from the tokenizer's vocabulary.
    """
    vocab_dict = tokenizer.get_vocab()
    vocab_list = [None] * len(vocab_dict)
    for token, idx in vocab_dict.items():
        vocab_list[idx] = token
    return vocab_list


@lru_cache(maxsize=None)
def get_lemma_set(token: str) -> set:
    """
    Given a token string, first strip any leading BPE marker (e.g. "Ġ")
    and return the set of lowercased lemma names from all its WordNet synsets.
    """
    # Strip off common prefix markers
    word = token.lstrip("Ġ")
    synsets = wn.synsets(word)
    return {lemma.lower() for s in synsets for lemma in s.lemma_names()}


def are_synonyms(token1: str, token2: str) -> bool:
    """
    Determines whether two tokens are synonyms by checking if token1 (after stripping)
    appears in token2's lemma set and vice versa.
    """
    word1 = token1.lstrip("Ġ")
    word2 = token2.lstrip("Ġ")
    lemmas1 = get_lemma_set(word1)
    lemmas2 = get_lemma_set(word2)
    if not lemmas1 or not lemmas2:
        return False
    return (word1.lower() in lemmas2) and (word2.lower() in lemmas1)


def filter_tokens_with_synonyms(vocab_list: list[str]) -> (list[int], list[int]):
    """
    Splits the vocabulary indices into:
      - unique_indices: indices of tokens that have no synonym in the vocabulary.
      - paired_indices: indices of tokens that have at least one synonym.
    Uses a progress bar via tqdm.
    """
    unique_indices = []
    paired_indices = []
    n = len(vocab_list)
    # Precompute lemma sets for each token (stripped of any leading marker)
    lemma_sets = [get_lemma_set(token) for token in vocab_list]
    for i in tqdm(range(n), desc="Filtering vocabulary"):
        token_i = vocab_list[i]
        lemmas_i = lemma_sets[i]
        has_synonym = False
        for j in range(n):
            if i == j:
                continue
            if (token_i.lstrip("Ġ")).lower() in lemma_sets[j] and (vocab_list[j].lstrip("Ġ")).lower() in lemmas_i:
                has_synonym = True
                break
        if has_synonym:
            paired_indices.append(i)
        else:
            unique_indices.append(i)
    return unique_indices, paired_indices


def construct_similarity_matrix(vocab_list: list[str], indices: list[int]) -> list[list[float]]:
    """
    Constructs an m x m similarity matrix for tokens specified by indices.
    Entry [i][j] is 1.0 if the tokens are synonyms, 0 otherwise.
    Uses nested loops with tqdm progress bar.
    """
    m = len(indices)
    C = [[0.0 for _ in range(m)] for _ in range(m)]
    # Precompute lemma sets for tokens in indices
    lemma_dict = {i: get_lemma_set(vocab_list[i].lstrip("Ġ")) for i in indices}
    for a in tqdm(range(m), desc="Constructing similarity matrix (outer loop)"):
        for b in range(a + 1, m):
            token_a = vocab_list[indices[a]].lstrip("Ġ")
            token_b = vocab_list[indices[b]].lstrip("Ġ")
            lemmas_a = lemma_dict[indices[a]]
            lemmas_b = lemma_dict[indices[b]]
            weight = 1.0 if (token_a.lower() in lemmas_b and token_b.lower() in lemmas_a) else 0.0
            C[a][b] = weight
            C[b][a] = weight
    return C


def find_perfect_matching(similarity_matrix: list[list[float]]) -> list[tuple[int, int]]:
    """
    Constructs an undirected graph from the similarity matrix (only edges with weight>0)
    and returns a maximum–weight matching (as a list of index pairs relative to the input list).
    Uses tqdm over the pairs.
    """
    m = len(similarity_matrix)
    G = nx.Graph()
    G.add_nodes_from(range(m))
    pairs = list(itertools.combinations(range(m), 2))
    for i, j in tqdm(pairs, total=len(pairs), desc="Building graph for matching"):
        weight = similarity_matrix[i][j]
        if weight > 0:
            G.add_edge(i, j, weight=weight)
    matching = nx.max_weight_matching(G, maxcardinality=True)
    pairing = [tuple(sorted(pair)) for pair in matching]
    return pairing


#############################
# Revised Watermark Processor Classes
#############################

class WatermarkBase:
    def __init__(
            self,
            vocab: list[int] = None,
            gamma: float = 0.25,
            delta: float = 2.0,
            seeding_scheme: str = "simple_1",
            hash_key: int = 15485863,
            select_green_tokens: bool = True,
            precomputed_pairing: list[tuple[int, int]] = None,
            unique_tokens: list[int] = None,
    ):
        self.vocab = vocab  # list of token IDs (usually 0, ..., n-1)
        self.vocab_size = len(vocab)
        self.gamma = gamma  # target fraction of tokens for the green list
        self.delta = delta  # bias added to green token logits
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens
        self.pairing = precomputed_pairing  # perfect matching (pairs) for tokens with synonyms
        self.unique_tokens = unique_tokens  # token IDs that have no synonyms

    #debug: replace input.device to device in _seed_rng, _get_greenlist_ids
    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """
        Seeds the RNG deterministically using the last token in input_ids.
        For the "simple_1" scheme, the seed is hash_key * (last token id).
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme
        if self.rng is None:
            self.rng = torch.Generator(device=device)
        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, "Input must have at least one token."
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
            #print(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Seeding scheme {seeding_scheme} not implemented.")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        """
        Returns the green list token IDs.
        If precomputed pairing and unique_tokens are provided, then:
          - All unique tokens (set A) are in the green list.
          - For each pair in the perfect matching, one token is chosen by a coin flip.
        Otherwise, falls back to a random permutation method.
        The final list is optionally truncated to a target size (gamma * vocab_size).
        """
        self._seed_rng(input_ids)
        
        if self.pairing is None or self.unique_tokens is None:
            # Fallback: use random permutation.
            greenlist_size = int(self.vocab_size * self.gamma)
            vocab_permutation = torch.randperm(self.vocab_size, device=device, generator=self.rng)
            if self.select_green_tokens:
                return vocab_permutation[:greenlist_size].tolist()
            else:
                return vocab_permutation[-greenlist_size:].tolist()
        else:
            greenlist_ids = self.unique_tokens.copy()
            for pair in self.pairing:
                coin_flip = (torch.rand(1, generator=self.rng, device=device).item() < 0.5)
                chosen = pair[0] if coin_flip else pair[1]
                greenlist_ids.append(chosen)
            # desired_size = int(self.vocab_size * self.gamma)
            # if len(greenlist_ids) > desired_size:
            #    perm = torch.randperm(len(greenlist_ids), generator=self.rng).tolist()
            #    indices = perm[:desired_size]
            #    greenlist_ids = [greenlist_ids[i] for i in indices]
            return greenlist_ids


class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        return green_tokens_mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor,
                               greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        batched_greenlist_ids = []
        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids.append(greenlist_ids)
        green_tokens_mask = self._calc_greenlist_mask(scores, batched_greenlist_ids)
        scores = self._bias_greenlist_logits(scores, green_tokens_mask, self.delta)
        #print(torch.max(scores))
        return scores


class WatermarkDetector(WatermarkBase):
    def __init__(
            self,
            *args,
            device: torch.device = None,
            tokenizer=None,
            z_threshold: float = 4.0,
            normalizers: list[str] = ["unicode"],
            ignore_repeated_bigrams: bool = True,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert device, "Device must be provided."
        assert tokenizer, "A tokenizer is required for detection."
        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)
        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Seeding scheme {self.seeding_scheme} not implemented.")
        self.normalizers = [normalization_strategy_lookup(norm) for norm in normalizers]
        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "Repeated bigram variant requires simple_1 seeding."

    def _compute_z_score(self, observed_count, T):
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        return numer / denom

    def _compute_p_value(self, z):
        return scipy.stats.norm.sf(z)

    def _score_sequence(
            self,
            input_ids: Tensor,
            return_num_tokens_scored: bool = True,
            return_num_green_tokens: bool = True,
            return_green_fraction: bool = True,
            return_green_token_mask: bool = False,
            return_z_score: bool = True,
            return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            # Repeated bigram variant: T = number of unique bigrams.
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for bigram in freq.keys():
                prefix = torch.tensor([bigram[0]], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = bigram[1] in greenlist_ids
            green_token_count = sum(bigram_table.values())
        else:
            # Standard variant: T = total tokens (after min_prefix_len)
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError("Not enough tokens to score.")
            green_token_count = 0
            green_token_mask = []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)
        # Debug prints:
        print(f"Total tokens scored (T): {num_tokens_scored}")
        print(f"Green token count: {green_token_count}")
        print(f"Green fraction: {green_token_count / num_tokens_scored:.2%}")

        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score:
            score_dict["z_score"] = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_p_value:
            z = score_dict.get("z_score", self._compute_z_score(green_token_count, num_tokens_scored))
            score_dict["p_value"] = self._compute_p_value(z)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask
        return score_dict

    def detect(
            self,
            text: str = None,
            tokenized_text: list[int] = None,
            return_prediction: bool = True,
            return_scores: bool = True,
            z_threshold: float = None,
            **kwargs,
    ) -> dict:
        assert (text is not None) ^ (tokenized_text is not None), "Provide either raw or tokenized text."
        if return_prediction:
            kwargs["return_p_value"] = True
        for normalizer in self.normalizers:
            text = normalizer(text)
        if tokenized_text is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(
                self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            if self.tokenizer is not None and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        if return_prediction:
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]
        return output_dict


#############################
# Demo Code (Generation & Detection)
#############################

def load_model(args):
    # Set model type attributes on args.
    args.is_seq2seq_model = any(model_type in args.model_name_or_path for model_type in ["t5", "T0"])
    args.is_decoder_only_model = any(model_type in args.model_name_or_path for model_type in ["gpt", "opt", "bloom"])

    if args.is_seq2seq_model:
        from transformers import AutoModelForSeq2SeqLM
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16,
                                                         device_map='auto')
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if not args.load_fp16:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    return model, tokenizer, device

def generate(prompt, args, model=None, device=None, tokenizer=None):
    # print(f"Generating with {args}")
    # Set model type attributes on args.
    args.is_seq2seq_model = any(model_type in args.model_name_or_path for model_type in ["t5", "T0"])
    args.is_decoder_only_model = any(model_type in args.model_name_or_path for model_type in ["gpt", "opt", "bloom"])

    # Instantiate the watermark processor with precomputed pairing/unique tokens.
    # Assume that in main() we precomputed these values (see below).
    watermark_processor = WatermarkLogitsProcessor(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        delta=args.delta,
        seeding_scheme=args.seeding_scheme,
        select_green_tokens=args.select_green_tokens,
        precomputed_pairing=args.precomputed_pairing,
        unique_tokens=args.unique_tokens
    )

    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
    if args.use_sampling:
        gen_kwargs.update(dict(
            do_sample=True,
            top_k=0,
            temperature=args.sampling_temp
        ))
    else:
        gen_kwargs.update(dict(
            num_beams=args.n_beams
        ))

    generate_without_watermark = partial(
        model.generate,
        **gen_kwargs
    )
    generate_with_watermark = partial(
        model.generate,
        logits_processor=LogitsProcessorList([watermark_processor]),
        **gen_kwargs
    )

    if not args.prompt_max_length:
        if hasattr(model.config, "max_position_embedding"):
            args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
        else:
            args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    torch.manual_seed(args.generation_seed)
    output_without_watermark = generate_without_watermark(**tokd_input)
    if args.seed_separately:
        torch.manual_seed(args.generation_seed)
    output_with_watermark = generate_with_watermark(**tokd_input)

    if args.is_decoder_only_model:
        output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
        output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

    return (
        redecoded_input, int(truncation_warning), decoded_output_without_watermark, decoded_output_with_watermark, args)


def format_names(s):
    s = s.replace("num_tokens_scored", "Tokens Counted (T)")
    s = s.replace("num_green_tokens", "# Tokens in Greenlist")
    s = s.replace("green_fraction", "Fraction of T in Greenlist")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "Prediction")
    s = s.replace("confidence", "Confidence")
    return s


def list_format_scores(score_dict, detection_threshold):
    lst_2d = []
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d


def detect(input_text, args, device=None, tokenizer=None):
    watermark_detector = WatermarkDetector(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        seeding_scheme=args.seeding_scheme,
        device=device,
        tokenizer=tokenizer,
        z_threshold=args.detection_z_threshold,
        normalizers=args.normalizers,
        ignore_repeated_bigrams=args.ignore_repeated_bigrams,
        select_green_tokens=args.select_green_tokens
    )
    if len(input_text) - 1 > watermark_detector.min_prefix_len:
        score_dict = watermark_detector.detect(input_text)
        output = list_format_scores(score_dict, watermark_detector.z_threshold)
    else:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]
    return output, args


def run_gradio(args, model=None, device=None, tokenizer=None):
    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect, device=device, tokenizer=tokenizer)
    with gr.Blocks() as demo:
        gr.Markdown("Gradio demo not shown in command-line mode.")
        demo.launch()


def compute_p_value(green_count: int, tokens_scored: int) -> (float, float):
    """
    Given the number of paired tokens that appear in the green list (green_count)
    out of tokens_scored (only paired tokens are scored), compute a z–score and p–value.
    Under the null hypothesis, each paired token is green with probability 0.5.
    """
    import math
    expected = 0.5 * tokens_scored
    std = math.sqrt(0.25 * tokens_scored)
    z = (green_count - expected) / std if std > 0 else 0.0
    p = scipy.stats.norm.sf(z)
    return z, p


def compute_perplexity(model, tokenizer, text: str, device):
    """
    Computes perplexity of the text using the provided model and tokenizer.
    Here we use the language-modeling loss computed by the model.
    """
    input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=True)["input_ids"].to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
    perplexity = torch.exp(loss).item()
    return perplexity

# === Cos similarity and greedy matching ===

import torch
import torch.nn.functional as F

def construct_similarity_matrix_cos(vocab_list: list[str], indices: list[int], embedding_matrix: torch.Tensor) -> torch.Tensor:
    """
    Constructs an m x m similarity matrix for tokens specified by indices,
    using cosine similarity between token embeddings.

    Args:
      vocab_list: List of all tokens (strings).
      indices: List of token indices corresponding to the non-unique tokens (set B).
      embedding_matrix: A torch.Tensor of shape (vocab_size, hidden_dim) representing the token embeddings.

    Returns:
      A torch.Tensor of shape (m, m) where each entry [i][j] is the cosine similarity between
      the embeddings of vocab_list[indices[i]] and vocab_list[indices[j]], with the diagonal entries set to 0.
    """
    # Extract embeddings for the selected tokens (non-unique tokens).
    selected_embeddings = embedding_matrix[indices]  # shape: (m, hidden_dim)

    # Normalize the embeddings along the feature dimension.
    norm_embeddings = F.normalize(selected_embeddings, p=2, dim=1)

    # Compute the cosine similarity matrix as the dot product between normalized embeddings.
    sim_matrix = torch.mm(norm_embeddings, norm_embeddings.t())

    # Set the diagonal entries to 0 (to ignore self-similarity).
    sim_matrix.fill_diagonal_(0)

    return sim_matrix

import math
import random

def find_perfect_matching_greedy_random(similarity_matrix: list[list[float]]) -> list[tuple[int, int]]:
    """
    Constructs a greedy matching from the similarity matrix using random sampling.
    In each iteration, a random token i is selected from the unmatched set.
    Then, approximately ceil(log2(n)) tokens (where n is the current number of unmatched tokens)
    are randomly sampled from the remaining tokens, and the token j with the highest similarity
    (i.e. highest value in similarity_matrix[i][j]) is selected as a match.

    The function returns a list of tuples (i, j) (with i < j) representing the matched token indices.
    Note: The similarity matrix should have 0 on its diagonal.
    """
    m = len(similarity_matrix)
    unmatched = list(range(m))
    matching = []

    pbar = tqdm(total=len(unmatched)//2, desc="Greedy random matching")
    while len(unmatched) > 1:
        n = len(unmatched)
        # Set sample size to ceil(log2(n)); ensure at least one candidate.
        sample_size = math.ceil(math.log(n, 2)) if n > 1 else 1
        # Randomly choose one token i from the unmatched set.
        i = random.choice(unmatched)
        # Build a list of candidates (all unmatched tokens except i).
        remaining = [x for x in unmatched if x != i]
        # Adjust sample_size if there are fewer candidates than sample_size.
        sample_size = min(sample_size, len(remaining))
        # Randomly sample sample_size candidates.
        candidates = random.sample(remaining, sample_size)
        # Find the candidate j with maximum similarity with i.
        best_j = candidates[0]
        best_weight = similarity_matrix[i][best_j]
        for j in candidates:
            w = similarity_matrix[i][j]
            if w > best_weight:
                best_weight = w
                best_j = j
        # Add the pair (min(i, best_j), max(i, best_j)) for consistency.
        matching.append((min(i, best_j), max(i, best_j)))
        # Remove both tokens from the unmatched set.
        unmatched.remove(i)
        if best_j in unmatched:
            unmatched.remove(best_j)
        pbar.update(1)
    pbar.close()
    return matching


import numpy as np
from sklearn.metrics import roc_auc_score
from tqdm import tqdm


# --- Helper: Check if both outputs meet length condition (≥195 tokens) ---
def valid_length(wm_text, nw_text, tokenizer, min_tokens=195):
    len_wm = len(tokenizer(wm_text)["input_ids"])
    len_nw = len(tokenizer(nw_text)["input_ids"])
    return (len_wm >= min_tokens) and (len_nw >= min_tokens)

# ---- Main Evaluation Function ----

import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc  # Instead of roc_auc_score

def evaluate_watermarking(truncated_texts, model, tokenizer, args, args_cos=None):
    """
    For each prompt in truncated_texts, generate completions using three methods:
      1. Default watermark method (our method)
      2. KGW method (via generate_kgw and detect_kgw)
      3. Cosine–similarity based method (via generate with args_cos and detect_cos)

    For each method, compute:
      - Number of tokens generated
      - Detection metrics (only on paired tokens): green token count, tokens scored, green fraction,
        z–score and p–value, judgement (watermarked vs. non-watermarked), perplexity.

    Only prompts where all outputs (for all three methods) have length ≥195 tokens are considered.
    We then compute aggregated metrics (average perplexity) and an ROC–AUC by computing
    the false-positive and true-positive rates (via roc_curve) and passing them to auc().

    Returns:
      results: a list of per–prompt result dictionaries.
      aggregated: a dictionary with aggregated metrics for each method.
    """

    results = []

    # Accumulators for default method:
    wm_z_default_acc, nw_z_default_acc, labels_default = [], [], []
    ppl_wm_default_acc, ppl_nw_default_acc = [], []
    green_counts_w_default_acc, tokens_scored_w_default_acc, props_w_default_acc = [], [], []

    # For KGW method:
    wm_z_kgw_acc, nw_z_kgw_acc, labels_kgw = [], [], []
    ppl_wm_kgw_acc, ppl_nw_kgw_acc = [], []
    green_counts_w_kgw_acc, tokens_scored_w_kgw_acc, props_w_kgw_acc = [], [], []

    # For Cosine-based method:
    wm_z_cos_acc, nw_z_cos_acc, labels_cos = [], [], []
    ppl_wm_cos_acc, ppl_nw_cos_acc = [], []
    green_counts_w_cos_acc, tokens_scored_w_cos_acc, props_w_cos_acc = [], [], []

    for prompt in tqdm(truncated_texts, desc="Evaluating prompts"):
        # --- Default Watermark Generation ---
        redecoded_input, truncation_warning, decoded_nw, decoded_wm, _ = generate(
            prompt, args, model=model, device=device, tokenizer=tokenizer
        )
        #print('generate finish')
        
        wm_processor_default = WatermarkLogitsProcessor(
            vocab=list(tokenizer.get_vocab().values()),
            gamma=args.gamma,
            delta=args.delta,
            seeding_scheme=args.seeding_scheme,
            select_green_tokens=args.select_green_tokens,
            precomputed_pairing=args.precomputed_pairing,
            unique_tokens=args.unique_tokens
        )

        # Count green tokens (only among paired tokens) for watermarked text:
        green_count_w, tokens_scored_w, prop_w = count_green_tokens_paired(
            tokenizer, wm_processor_default, decoded_wm
        )
        z_default, p_default = compute_p_value(green_count_w, tokens_scored_w)

        # Non-watermarked text:
        green_count_nw, tokens_scored_nw, prop_nw = count_green_tokens_paired(
            tokenizer, wm_processor_default, decoded_nw
        )
        z_nw_default, p_nw_default = compute_p_value(green_count_nw, tokens_scored_nw)

        judgement_default = (
            "LLM-generated (watermarked)" if z_default > args.detection_z_threshold else "Human-generated (non-watermarked)"
        )
        perplexity_wm = compute_perplexity(model, tokenizer, decoded_wm, device)
        perplexity_nw = compute_perplexity(model, tokenizer, decoded_nw, device)
        tokens_generated_default = len(tokenizer(decoded_wm)["input_ids"])

        #print('detect finish')

        # --- KGW Method ---
        redecoded_input_kgw, truncation_warning_kgw, decoded_nw_kgw, decoded_wm_kgw, _ = generate_kgw(
            prompt, args, model=model, device=device, tokenizer=tokenizer
        )
        #print('generate finish')
        detect_result_w_kgw = detect_kgw(decoded_wm_kgw, args, device=device, tokenizer=tokenizer)[1]
        if len(decoded_wm_kgw) < 195:
            continue
        z_kgw = detect_result_w_kgw["z_score"]
        green_count_w_kgw = detect_result_w_kgw["num_green_tokens"]
        tokens_scored_w_kgw_local = detect_result_w_kgw["num_tokens_scored"]
        prop_w_kgw = detect_result_w_kgw["green_fraction"]

        detect_result_nw_kgw = detect_kgw(decoded_nw_kgw, args, device=device, tokenizer=tokenizer)[1]
        print(decoded_nw_kgw)
        if len(decoded_nw_kgw) < 195:
            continue
        z_nw_kgw = detect_result_nw_kgw["z_score"]
        green_count_nw_kgw = detect_result_nw_kgw["num_green_tokens"]

        judgement_kgw = (
            "LLM-generated (watermarked)" if z_kgw > args.detection_z_threshold else "Human-generated (non-watermarked)"
        )
        perplexity_wm_kgw = compute_perplexity(model, tokenizer, decoded_wm_kgw, device)
        perplexity_nw_kgw = compute_perplexity(model, tokenizer, decoded_nw_kgw, device)
        tokens_generated_kgw = len(tokenizer(decoded_wm_kgw)["input_ids"])

        #print('detect finish')
        # --- Cosine-based Method ---
        redecoded_input_cos, truncation_warning_cos, decoded_nw_cos, decoded_wm_cos, _ = generate(
            prompt, args_cos, model=model, device=device, tokenizer=tokenizer
        )
        wm_processor_cos = WatermarkLogitsProcessor(
            vocab=list(tokenizer.get_vocab().values()),
            gamma=args_cos.gamma,
            delta=args_cos.delta,
            seeding_scheme=args_cos.seeding_scheme,
            select_green_tokens=args_cos.select_green_tokens,
            precomputed_pairing=args_cos.precomputed_pairing,
            unique_tokens=args_cos.unique_tokens
        )
        green_count_w_cos, tokens_scored_w_cos, prop_w_cos = count_green_tokens_paired(
            tokenizer, wm_processor_cos, decoded_wm_cos
        )
        z_cos, p_cos = compute_p_value(green_count_w_cos, tokens_scored_w_cos)

        green_count_nw_cos, tokens_scored_nw_cos, prop_nw_cos = count_green_tokens_paired(
            tokenizer, wm_processor_cos, decoded_nw_cos
        )
        z_nw_cos, p_nw_cos = compute_p_value(green_count_nw_cos, tokens_scored_nw_cos)
        judgement_cos = (
            "LLM-generated (watermarked)" if z_cos > args_cos.detection_z_threshold else "Human-generated (non-watermarked)"
        )
        perplexity_wm_cos = compute_perplexity(model, tokenizer, decoded_wm_cos, device)
        perplexity_nw_cos = compute_perplexity(model, tokenizer, decoded_nw_cos, device)
        tokens_generated_cos = len(tokenizer(decoded_wm_cos)["input_ids"])

        # Store per-prompt results
        result = {
            "prompt": redecoded_input,
            "default": {
                "decoded_wm": decoded_wm,
                "decoded_nw": decoded_nw,
                "green_count_w": green_count_w,
                "green_count_nw": green_count_nw,
                "tokens_scored_w": tokens_scored_w,
                "tokens_scored_nw": tokens_scored_nw,
                "prop_w": prop_w,
                "z_w": z_default,
                "z_nw": z_nw_default,
                "p_w": p_default,
                "judgement": judgement_default,
                "ppl_wm": perplexity_wm,
                "ppl_nw": perplexity_nw,
                "tokens_generated": tokens_generated_default
            },
            "kgw": {
                "decoded_wm": decoded_wm_kgw,
                "decoded_nw": decoded_nw_kgw,
                "green_count_w": green_count_w_kgw,
                "tokens_scored_w": tokens_scored_w_kgw_local,
                "prop_w": prop_w_kgw,
                "z_w": z_kgw,
                "p_w": detect_result_w_kgw.get("p_value", None),
                "judgement": judgement_kgw,
                "ppl_wm": perplexity_wm_kgw,
                "ppl_nw": perplexity_nw_kgw,
                "tokens_generated": tokens_generated_kgw
            },
            "cos": {
                "decoded_wm": decoded_wm_cos,
                "decoded_nw": decoded_nw_cos,
                "green_count_w": green_count_w_cos,
                "tokens_scored_w": tokens_scored_w_cos,
                "prop_w": prop_w_cos,
                "z_w": z_cos,
                "p_w": p_cos,
                "judgement": judgement_cos,
                "ppl_wm": perplexity_wm_cos,
                "ppl_nw": perplexity_nw_cos,
                "tokens_generated": tokens_generated_cos
            }
        }
        results.append(result)
        # Optional prompt-level printing:
        print(result)

        # Only consider prompts where all three methods produce outputs of length ≥ 195 tokens.
        if (valid_length(decoded_wm, decoded_nw, tokenizer) and
            valid_length(decoded_wm_kgw, decoded_nw_kgw, tokenizer) and
            valid_length(decoded_wm_cos, decoded_nw_cos, tokenizer)):

            # Accumulate Default method metrics.
            wm_z_default_acc.append(z_default)
            nw_z_default_acc.append(z_nw_default)
            labels_default = [1] * len(wm_z_default_acc) + [0] * len(nw_z_default_acc)
            ppl_wm_default_acc.append(perplexity_wm)
            ppl_nw_default_acc.append(perplexity_nw)
            green_counts_w_default_acc.append(green_count_w)
            tokens_scored_w_default_acc.append(tokens_scored_w)

            # KGW method accumulators.
            wm_z_kgw_acc.append(z_kgw)
            nw_z_kgw_acc.append(z_nw_kgw)
            labels_kgw = [1] * len(wm_z_kgw_acc) + [0] * len(nw_z_kgw_acc)
            ppl_wm_kgw_acc.append(perplexity_wm_kgw)
            ppl_nw_kgw_acc.append(perplexity_nw_kgw)
            green_counts_w_kgw_acc.append(green_count_w_kgw)
            tokens_scored_w_kgw_acc.append(tokens_scored_w_kgw_local)

            # Cosine-based method accumulators.
            wm_z_cos_acc.append(z_cos)
            nw_z_cos_acc.append(z_nw_cos)
            labels_cos = [1] * len(wm_z_cos_acc) + [0] * len(nw_z_cos_acc)
            ppl_wm_cos_acc.append(perplexity_wm_cos)
            ppl_nw_cos_acc.append(perplexity_nw_cos)
            green_counts_w_cos_acc.append(green_count_w_cos)
            tokens_scored_w_cos_acc.append(tokens_scored_w_cos)

            num_valid = len(ppl_wm_default_acc)
            curr_avg_ppl_wm_default = np.mean(ppl_wm_default_acc)
            curr_avg_ppl_nw_default = np.mean(ppl_nw_default_acc)
            curr_avg_ppl_wm_kgw = np.mean(ppl_wm_kgw_acc)
            curr_avg_ppl_nw_kgw = np.mean(ppl_nw_kgw_acc)
            curr_avg_ppl_wm_cos = np.mean(ppl_wm_cos_acc)
            curr_avg_ppl_nw_cos = np.mean(ppl_nw_cos_acc)
            curr_avg_z_wm_default = np.mean(wm_z_default_acc)
            curr_avg_z_nw_default = np.mean(nw_z_default_acc)
            curr_avg_z_wm_kgw = np.mean(wm_z_kgw_acc)
            curr_avg_z_nw_kgw = np.mean(nw_z_kgw_acc)
            curr_avg_z_wm_cos = np.mean(wm_z_cos_acc)
            curr_avg_z_nw_cos = np.mean(nw_z_cos_acc)

            # -- Use roc_curve() and auc() here instead of roc_auc_score() --
            try:
                # For default method:
                all_scores_default = wm_z_default_acc + nw_z_default_acc
                fpr_def, tpr_def, _ = roc_curve(labels_default, all_scores_default)
                curr_auc_default = auc(fpr_def, tpr_def)
                #print(all_scores_default)
                #print(wm_z_default_acc)
                #print(nw_z_default_acc)
            except Exception:
                curr_auc_default = float('nan')

            try:
                # For KGW method:
                all_scores_kgw = wm_z_kgw_acc + nw_z_kgw_acc
                fpr_kgw, tpr_kgw, _ = roc_curve(labels_kgw, all_scores_kgw)
                curr_auc_kgw = auc(fpr_kgw, tpr_kgw)
                #print(all_scores_kgw)
                #print(wm_z_kgw_acc)
                #print(nw_z_kgw_acc)
            except Exception:
                curr_auc_kgw = float('nan')

            try:
                # For Cosine-based method:
                all_scores_cos = wm_z_cos_acc + nw_z_cos_acc
                fpr_cos, tpr_cos, _ = roc_curve(labels_cos, all_scores_cos)
                curr_auc_cos = auc(fpr_cos, tpr_cos)
                #print(all_scores_cos)
                #print(wm_z_cos_acc)
                #print(nw_z_cos_acc)
            except Exception:
                curr_auc_cos = float('nan')

            print(f"\nAfter {num_valid} valid prompts:")
            print(" Default Method:    avg ppl (wm) = {0:.2f}, avg ppl (nw) = {1:.2f}, AUC = {2:.3f}, avg z (wm) = {3:.2f}, avg z (nw) = {4:.2f}".format(
                curr_avg_ppl_wm_default, curr_avg_ppl_nw_default, curr_auc_default, curr_avg_z_wm_default, curr_avg_z_nw_default))
            print(" KGW Method:        avg ppl (wm) = {0:.2f}, avg ppl (nw) = {1:.2f}, AUC = {2:.3f}, avg z (wm) = {3:.2f}, avg z (nw) = {4:.2f}".format(
                curr_avg_ppl_wm_kgw, curr_avg_ppl_nw_kgw, curr_auc_kgw, curr_avg_z_wm_kgw, curr_avg_z_nw_kgw))
            print(" Cosine-based Method: avg ppl (wm) = {0:.2f}, avg ppl (nw) = {1:.2f}, AUC = {2:.3f}, avg z (wm) = {3:.2f}, avg z (nw) = {4:.2f}\n".format(
                curr_avg_ppl_wm_cos, curr_avg_ppl_nw_cos, curr_auc_cos, curr_avg_z_wm_cos, curr_avg_z_nw_cos))

    # Final aggregated metrics
    aggregated = {}

    # Summaries for each method:
    def finalize_metrics(labels, wm_z_list, nw_z_list, ppl_wm_list, ppl_nw_list, method_name):
        out_dict = {}
        if len(ppl_wm_list) > 0:
            out_dict["avg_ppl_wm"] = np.mean(ppl_wm_list)
            out_dict["avg_ppl_nw"] = np.mean(ppl_nw_list)
            out_dict["avg_z_wm"] = np.mean(wm_z_list)
            out_dict["avg_z_nw"] = np.mean(nw_z_list)
            out_dict["num_valid"] = len(ppl_wm_list)
        else:
            out_dict["avg_ppl_wm"] = None
            out_dict["avg_ppl_nw"] = None
            out_dict["avg_z_wm"] = None
            out_dict["avg_z_nw"] = None
            out_dict["num_valid"] = 0

        if wm_z_list and nw_z_list:
            all_scores = wm_z_list + nw_z_list
            try:
                fpr, tpr, _ = roc_curve(labels, all_scores)
                out_dict["auc"] = auc(fpr, tpr)
            except Exception:
                out_dict["auc"] = None
        else:
            out_dict["auc"] = None

        aggregated[method_name] = out_dict

    # Default
    finalize_metrics(labels_default,
                     wm_z_default_acc, nw_z_default_acc,
                     ppl_wm_default_acc, ppl_nw_default_acc,
                     "default")

    # KGW
    finalize_metrics(labels_kgw,
                     wm_z_kgw_acc, nw_z_kgw_acc,
                     ppl_wm_kgw_acc, ppl_nw_kgw_acc,
                     "kgw")

    # Cosine
    finalize_metrics(labels_cos,
                     wm_z_cos_acc, nw_z_cos_acc,
                     ppl_wm_cos_acc, ppl_nw_cos_acc,
                     "cos")

    return results, aggregated


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [108]:

# === Helper functions for detection metrics ===

def count_green_tokens_paired(tokenizer, watermark_processor, text: str) -> (int, int, float):
    """
    Tokenizes the input text and for each token (after a minimum prefix),
    computes the green list based on the current prefix (using watermark_processor).
    Only paired tokens (i.e. tokens that have a synonym pair) are counted (unique tokens are excluded).

    Returns:
      green_count: number of paired tokens that appear in the green list.
      tokens_scored: number of tokens scored (only tokens that are paired).
      proportion: green_count / tokens_scored.
    """
    # Tokenize without special tokens
    input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
    total_tokens = len(input_ids)
    # Use a minimum prefix length; if not defined in the processor, default to 1.
    start_idx = getattr(watermark_processor, "min_prefix_len", 1)
    green_count = 0
    tokens_scored = 0

    # For each token position starting at start_idx, consider only if the token is a paired token.
    for idx in range(start_idx, total_tokens):
        token = input_ids[idx].item()
        greenlist_ids = watermark_processor._get_greenlist_ids(input_ids[:idx])
        #print(token)
        # Only count if token is NOT in the unique set (i.e. token is paired)
        if watermark_processor.unique_tokens is not None and token not in watermark_processor.unique_tokens:
            tokens_scored += 1
            #print(token in greenlist_ids)
            if token in greenlist_ids:
                green_count += 1
                
    proportion = green_count / tokens_scored if tokens_scored > 0 else 0.0
    return green_count, tokens_scored, proportion

In [None]:


# --- Example usage ---
if __name__ == "__main__":
    import sys
    sys.argv = [sys.argv[0]]
    args = parse_args()
    model, tokenizer, device = load_model(args)

    # Assume truncated_texts is a list of 500 prompt strings (each truncated to less than 200 words).
    # For example, you may load them from a file or sample from a dataset.
    # Here we assume truncated_texts is already defined.

    # --- Precompute Vocabulary and Perfect Matching via Dictionary ---
    tokenizer_for_vocab = AutoTokenizer.from_pretrained(args.model_name_or_path)
    vocab_list = get_vocabulary(tokenizer_for_vocab)
    print(f"Vocabulary size: {len(vocab_list)}")
    unique_indices, paired_indices = filter_tokens_with_synonyms(vocab_list)
    print(f"Unique tokens (set A): {len(unique_indices)}")
    print(f"Tokens with synonyms (set B): {len(paired_indices)}")
    similarity_matrix = construct_similarity_matrix(vocab_list, paired_indices)
    matching = find_perfect_matching(similarity_matrix)
    mapped_pairing = [(paired_indices[i], paired_indices[j]) for (i, j) in matching]
    args.precomputed_pairing = mapped_pairing
    args.unique_tokens = unique_indices

    # --- Precompute Vocabulary and Perfect Matching via Cosine Similarity ---
    args_cos = parse_args()
    embedding_matrix = model.get_input_embeddings().weight  # shape: (vocab_size, hidden_dim)


In [12]:
    similarity_matrix_cos = construct_similarity_matrix_cos(vocab_list, paired_indices,embedding_matrix)
    matching_cos = find_perfect_matching_greedy_random(similarity_matrix_cos)
    mapped_pairing_cos = [(paired_indices[i], paired_indices[j]) for (i, j) in matching_cos]
    args_cos.precomputed_pairing = mapped_pairing_cos
    args_cos.unique_tokens = unique_indices

Greedy random matching: 100%|██████████| 11078/11078 [00:09<00:00, 1185.56it/s]


In [None]:

    # --- Load the "realnewslike" subset of C4 (English) and Shuffle the dataset with a fixed seed for reproducibility ---
    c4_realnewslike = load_dataset("c4", "realnewslike", split="train", streaming=False, trust_remote_code=True)
    shuffled_dataset = c4_realnewslike.shuffle(seed=45)
    sampled_examples = shuffled_dataset.select(range(300))
    sampled_texts = [example["text"] for example in sampled_examples]
    print(f"Sampled {len(sampled_texts)} news-like texts from C4.")
    max_words = 150
    truncated_texts = []
    for text in sampled_texts:
        words = text.split()  # split text into words
        truncated_text = " ".join(words[:max_words])
        truncated_texts.append(truncated_text)


**Delta = 5**

In [111]:
tune_delta = 5
args.delta = tune_delta
args_cos.delta = tune_delta

In [None]:

    results, aggregated = evaluate_watermarking(truncated_texts, model, tokenizer, args, args_cos)

    # Print per-prompt results for the first 5 prompts.
    for r in results[:5]:
        print("=== Prompt ===")
        print(r["prompt"])
        print("--- Default Method ---")
        print("Watermarked Text:")
        print(r["default"]["decoded_wm"])
        print("Detection (Default):")
        print(f"  Green tokens (paired): {r['default']['green_count_w']} / {r['default']['tokens_scored_w']} ({r['default']['prop_w']:.2%})")
        print(f"  z–score: {r['default']['z_w']:.2f}, p–value: {r['default']['p_w']:.4f}")
        print(f"  Judgement: {r['default']['judgement']}")
        print(f"  Perplexity: {r['default']['ppl_wm']:.2f}")
        print("--- KGW Method ---")
        print("Watermarked Text:")
        print(r["kgw"]["decoded_wm"])
        print("Detection (KGW):")
        print(f"  Green tokens (paired): {r['kgw']['green_count_w']} / {r['kgw']['tokens_scored_w']} ({r['kgw']['prop_w']:.2%})")
        print(f"  z–score: {r['kgw']['z_w']:.2f}, p–value: {r['kgw']['p_w']:.4f}")
        print(f"  Judgement: {r['kgw']['judgement']}")
        print(f"  Perplexity: {r['kgw']['ppl_wm']:.2f}")
        print("--- Cosine-based Method ---")
        print("Watermarked Text:")
        print(r["cos"]["decoded_wm"])
        print("Detection (Cos):")
        print(f"  Green tokens (paired): {r['cos']['green_count_w']} / {r['cos']['tokens_scored_w']} ({r['cos']['prop_w']:.2%})")
        print(f"  z–score: {r['cos']['z_w']:.2f}, p–value: {r['cos']['p_w']:.4f}")
        print(f"  Judgement: {r['cos']['judgement']}")
        print(f"  Perplexity: {r['cos']['ppl_wm']:.2f}")
        print("\n")

    # Print aggregated metrics.
    print("=== Aggregated Metrics ===")
    print("Default Method:")
    print("  Average Perplexity (Watermarked):", aggregated["default"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["default"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["default"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["default"]["avg_z_nw"])
    print("  AUC:", aggregated["default"]["auc"])
    print("  Valid prompts:", aggregated["default"]["num_valid"])

    print("\nKGW Method:")
    print("  Average Perplexity (Watermarked):", aggregated["kgw"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["kgw"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["kgw"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["kgw"]["avg_z_nw"])
    print("  AUC:", aggregated["kgw"]["auc"])
    print("  Valid prompts:", aggregated["kgw"]["num_valid"])

    print("\nCosine-based Method:")
    print("  Average Perplexity (Watermarked):", aggregated["cos"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["cos"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["cos"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["cos"]["avg_z_nw"])
    print("  AUC:", aggregated["cos"]["auc"])
    print("  Valid prompts:", aggregated["cos"]["num_valid"])

Evaluating prompts:   0%|          | 0/300 [00:00<?, ?it/s]



The most basic rule in investing is to invest only what you can afford to lose, but you can't afford to lose a lot. Researching companies is a great way to understand them and make your investing decisions.

The most basic rule in investing is to invest only what you can afford to lose, but you can't afford to lose a lot. Researching companies is a great way to understand them and make your investing decisions.

Picking the winners and losers in this increasingly volatile market is a challenge. The market goes up, then it goes down, then it goes up again, and then it goes down. How can you tell the difference? Here are a few simple tips.

The most basic rule in investing is to invest only what you can afford to lose, but you can't afford to lose a lot. Researching companies is a great way to understand them and make your investing decisions.

This is the second in a series of articles that


Evaluating prompts:   0%|          | 1/300 [02:36<13:00:49, 156.69s/it]

{'prompt': 'Choose any of the INOV videos above to watch, by clicking the associated image or headline. These results are drawn from the library of videos produced here at Market News Video, that have been tagged by an editor with the inov symbol. The date of each video is listed underneath the headline. Beneath the listing of inov videos is a current stock quote for inov and performance chart. At the bottom of the page, you will find related articles mentioning inov. From all of us here at Market News Video, we hope you will enjoy these inov videos and articles.', 'default': {'decoded_wm': '\n\nThe most recent inov topic was created on February, 27, 2013.\n\nThe most recent inov stock quote and inov stock performance was created on February, 27, 2013.\n\nMarket News Video\n\nMarket News Video is the source for millions of videos, covering various market events, including news, market analysis, technical insights, and other topics. Inov videos are created and updated by the inov Team, 

**Delta = 10**

In [None]:
tune_delta = 10
args.delta = tune_delta
args_cos.delta = tune_delta

In [None]:

    results, aggregated = evaluate_watermarking(truncated_texts, model, tokenizer, args, args_cos)

    # Print per-prompt results for the first 5 prompts.
    for r in results[:5]:
        print("=== Prompt ===")
        print(r["prompt"])
        print("--- Default Method ---")
        print("Watermarked Text:")
        print(r["default"]["decoded_wm"])
        print("Detection (Default):")
        print(f"  Green tokens (paired): {r['default']['green_count_w']} / {r['default']['tokens_scored_w']} ({r['default']['prop_w']:.2%})")
        print(f"  z–score: {r['default']['z_w']:.2f}, p–value: {r['default']['p_w']:.4f}")
        print(f"  Judgement: {r['default']['judgement']}")
        print(f"  Perplexity: {r['default']['ppl_wm']:.2f}")
        print("--- KGW Method ---")
        print("Watermarked Text:")
        print(r["kgw"]["decoded_wm"])
        print("Detection (KGW):")
        print(f"  Green tokens (paired): {r['kgw']['green_count_w']} / {r['kgw']['tokens_scored_w']} ({r['kgw']['prop_w']:.2%})")
        print(f"  z–score: {r['kgw']['z_w']:.2f}, p–value: {r['kgw']['p_w']:.4f}")
        print(f"  Judgement: {r['kgw']['judgement']}")
        print(f"  Perplexity: {r['kgw']['ppl_wm']:.2f}")
        print("--- Cosine-based Method ---")
        print("Watermarked Text:")
        print(r["cos"]["decoded_wm"])
        print("Detection (Cos):")
        print(f"  Green tokens (paired): {r['cos']['green_count_w']} / {r['cos']['tokens_scored_w']} ({r['cos']['prop_w']:.2%})")
        print(f"  z–score: {r['cos']['z_w']:.2f}, p–value: {r['cos']['p_w']:.4f}")
        print(f"  Judgement: {r['cos']['judgement']}")
        print(f"  Perplexity: {r['cos']['ppl_wm']:.2f}")
        print("\n")

    # Print aggregated metrics.
    print("=== Aggregated Metrics ===")
    print("Default Method:")
    print("  Average Perplexity (Watermarked):", aggregated["default"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["default"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["default"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["default"]["avg_z_nw"])
    print("  AUC:", aggregated["default"]["auc"])
    print("  Valid prompts:", aggregated["default"]["num_valid"])

    print("\nKGW Method:")
    print("  Average Perplexity (Watermarked):", aggregated["kgw"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["kgw"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["kgw"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["kgw"]["avg_z_nw"])
    print("  AUC:", aggregated["kgw"]["auc"])
    print("  Valid prompts:", aggregated["kgw"]["num_valid"])

    print("\nCosine-based Method:")
    print("  Average Perplexity (Watermarked):", aggregated["cos"]["avg_ppl_wm"])
    print("  Average Perplexity (Non-watermarked):", aggregated["cos"]["avg_ppl_nw"])
    print("  Average z-score (Watermarked):", aggregated["cos"]["avg_z_wm"])
    print("  Average z-score (Non-watermarked):", aggregated["cos"]["avg_z_nw"])
    print("  AUC:", aggregated["cos"]["auc"])
    print("  Valid prompts:", aggregated["cos"]["num_valid"])

In [6]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial

import numpy 

import torch

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)



def str2bool(v):
    """Util function for user friendly boolean flag args"""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    """Command line argument specification"""

    parser = argparse.ArgumentParser(
        description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")

    #parser.add_argument(
    #    "--run_gradio",
    #    type=str2bool,
    #    default=False,
    #    help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    #)
    parser.add_argument(
        "--demo_public",
        type=str2bool,
        default=False,
        help="Whether to expose the gradio demo to the internet.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook-opt1.3b",#"gpt2-medium",#
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt_max_length",
        type=int,
        default=None,
        help="Truncation length for prompt, overrides model config's max length field.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximmum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=True,
        help="Whether to generate using multinomial sampling.",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )
    parser.add_argument(
        "--normalizers",
        type=str,
        default="",
        help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
    )
    parser.add_argument(
        "--ignore_repeated_bigrams",
        type=str2bool,
        default=False,
        help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
    )
    parser.add_argument(
        "--detection_z_threshold",
        type=float,
        default=4.0,
        help="The test statistic threshold for the detection hypothesis test.",
    )
    parser.add_argument(
        "--select_green_tokens",
        type=str2bool,
        default=True,
        help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
    )
    parser.add_argument(
        "--skip_model_load",
        type=str2bool,
        default=False,
        help="Skip the model loading to debug the interface.",
    )
    parser.add_argument(
        "--seed_separately",
        type=str2bool,
        default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
    )
    parser.add_argument(
        "--load_fp16",
        type=str2bool,
        default=False,
        help="Whether to run model in float16 precsion.",
    )
    args = parser.parse_args()
    return args


def load_model(args):
    """Load and return the model and tokenizer"""

    args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
    args.is_decoder_only_model = any(
        [(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16,
                                                         device_map='auto')
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if args.load_fp16:
            pass
        else:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    return model, tokenizer, device


def generate_kgw(prompt, args, model=None, device=None, tokenizer=None):
    """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
       and generate watermarked text by passing it to the generate method of the model
       as a logits processor. """

    #print(f"Generating with {args}")

    watermark_processor = WatermarkLogitsProcessor_kgw(vocab=list(tokenizer.get_vocab().values()),
                                                   gamma=args.gamma,
                                                   delta=args.delta,
                                                   seeding_scheme=args.seeding_scheme,
                                                   select_green_tokens=args.select_green_tokens)

    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)

    if args.use_sampling:
        gen_kwargs.update(dict(
            do_sample=True,
            top_k=0,
            temperature=args.sampling_temp
        ))
    else:
        gen_kwargs.update(dict(
            num_beams=args.n_beams
        ))

    generate_without_watermark = partial(
        model.generate,
        **gen_kwargs
    )
    generate_with_watermark = partial(
        model.generate,
        logits_processor=LogitsProcessorList([watermark_processor]),
        **gen_kwargs
    )
    if args.prompt_max_length:
        pass
    elif hasattr(model.config, "max_position_embedding"):
        args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
    else:
        args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    torch.manual_seed(args.generation_seed)
    output_without_watermark = generate_without_watermark(**tokd_input)

    # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
    if args.seed_separately:
        torch.manual_seed(args.generation_seed)
    output_with_watermark = generate_with_watermark(**tokd_input)

    if args.is_decoder_only_model:
        # need to isolate the newly generated tokens
        output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
        output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            decoded_output_without_watermark,
            decoded_output_with_watermark,
            args)
    # decoded_output_with_watermark)


def format_names(s):
    """Format names for the gradio demo interface"""
    s = s.replace("num_tokens_scored", "Tokens Counted (T)")
    s = s.replace("num_green_tokens", "# Tokens in Greenlist")
    s = s.replace("green_fraction", "Fraction of T in Greenlist")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "Prediction")
    s = s.replace("confidence", "Confidence")
    return s


def list_format_scores(score_dict, detection_threshold):
    """Format the detection metrics into a gradio dataframe input format"""
    lst_2d = []
    # lst_2d.append(["z-score threshold", f"{detection_threshold}"])
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d


def detect_kgw(input_text, args, device=None, tokenizer=None):
    """Instantiate the WatermarkDetection object and call detect on
        the input text returning the scores and outcome of the test"""
    watermark_detector = WatermarkDetector_kgw(vocab=list(tokenizer.get_vocab().values()),
                                           gamma=args.gamma,
                                           seeding_scheme=args.seeding_scheme,
                                           device=device,
                                           tokenizer=tokenizer,
                                           z_threshold=args.detection_z_threshold,
                                           normalizers=args.normalizers,
                                           ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                           select_green_tokens=args.select_green_tokens)
    if len(input_text) - 1 > watermark_detector.min_prefix_len:
        score_dict = watermark_detector.detect(input_text)
        # output = str_format_scores(score_dict, watermark_detector.z_threshold)
        output = list_format_scores(score_dict, watermark_detector.z_threshold)
    else:
        score_dict = None
        # output = (f"Error: string not long enough to compute watermark presence.")
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]
    return output, score_dict, args


def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface"""
    generate_partial = partial(generate_kgw, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_kgw, device=device, tokenizer=tokenizer)

    with gr.Blocks() as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Column(scale=9):
                gr.Markdown(
                    """
                    ## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
                    """
                )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
                    """
                )
            # with gr.Column(scale=2):
            #     pass
            # ![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_lm-watermarking) # buggy

        with gr.Accordion("Understanding the output metrics", open=False):
            gr.Markdown(
                """
                - `z-score threshold` : The cuttoff for the hypothesis test
                - `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm. 
                    The first token is ommitted in the simple, single token seeding scheme since there is no way to generate
                    a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm, 
                    described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
                - `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
                - `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
                - `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold` 
                    we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked
                - `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of 
                    observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
                    If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance.
                -  `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
                - `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent 
                    the confidence of the detection based on the unlikeliness of this `z-score` observation.
                """
            )

        with gr.Accordion("A note on model capability", open=True):
            gr.Markdown(
                """
                This demo uses open-source language models that fit on a single GPU. These models are less powerful than proprietary commercial tools like ChatGPT, Claude, or Bard. 

                Importantly, we use a language model that is designed to "complete" your prompt, and not a model this is fine-tuned to follow instructions. 
                For best results, prompt the model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph. 
                Some examples include the opening paragraph of a wikipedia article, or the first few sentences of a story. 
                Longer prompts that end mid-sentence will result in more fluent generations.
                """
            )
        gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)

        with gr.Tab("Generate and Detect"):

            with gr.Row():
                prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=2):
                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False, lines=14,
                                                          max_lines=14)
                with gr.Column(scale=1):
                    # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
                                                                      row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=2):
                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False, lines=14,
                                                       max_lines=14)
                with gr.Column(scale=1):
                    # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
                                                                   row_count=7, col_count=2)

            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                if truncation_warning:
                    return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
                else:
                    return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True, lines=14, max_lines=14)
                with gr.Column(scale=1):
                    # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
                                                    col_count=2)
            with gr.Row():
                detect_btn = gr.Button("Detect")

        # Parameter selection group
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### Generation Parameters")
                    with gr.Row():
                        decoding = gr.Radio(label="Decoding Method", choices=["multinomial", "greedy"],
                                            value=("multinomial" if args.use_sampling else "greedy"))
                    with gr.Row():
                        sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1,
                                                  value=args.sampling_temp, visible=True)
                    with gr.Row():
                        generation_seed = gr.Number(label="Generation Seed", value=args.generation_seed,
                                                    interactive=True)
                    with gr.Row():
                        n_beams = gr.Dropdown(label="Number of Beams", choices=list(range(1, 11, 1)),
                                              value=args.n_beams, visible=(not args.use_sampling))
                    with gr.Row():
                        max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10,
                                                   value=args.max_new_tokens)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Watermark Parameters")
                    with gr.Row():
                        gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                    with gr.Row():
                        delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                    gr.Markdown(f"#### Detector Parameters")
                    with gr.Row():
                        detection_z_threshold = gr.Slider(label="z-score threshold", minimum=0.0, maximum=10.0,
                                                          step=0.1, value=args.detection_z_threshold)
                    with gr.Row():
                        ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
                    with gr.Row():
                        normalizers = gr.CheckboxGroup(label="Normalizations",
                                                       choices=["unicode", "homoglyphs", "truecase"],
                                                       value=args.normalizers)
            # with gr.Accordion("Actual submitted parameters:",open=False):
            with gr.Row():
                gr.Markdown(
                    f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
            with gr.Row():
                current_parameters = gr.Textbox(label="Current Parameters", value=args)
            with gr.Accordion("Legacy Settings", open=False):
                with gr.Row():
                    with gr.Column(scale=1):
                        seed_separately = gr.Checkbox(label="Seed both generations separately",
                                                      value=args.seed_separately)
                    with gr.Column(scale=1):
                        select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition",
                                                          value=args.select_green_tokens)

        with gr.Accordion("Understanding the settings", open=False):
            gr.Markdown(
                """
                #### Generation Parameters:
    
                - Decoding Method : We can generate tokens from the model using either multinomial sampling or we can generate using greedy decoding.
                - Sampling Temperature : If using multinomial sampling we can set the temperature of the sampling distribution. 
                                    0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
                                    0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates while adding variety. Does not apply for greedy decoding.
                - Generation Seed : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy
                                    outputs reproducible. Does not apply for greedy decoding.
                - Number of Beams : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search. 
                                    This is not implemented/excluded from paper for multinomial sampling but may be added in future.
                - Max Generated Tokens : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens. 
                                        Note that the model is free to generate fewer tokens depending on the prompt. 
                                        Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
                                        and inputs will be truncated accordingly.
    
                #### Watermark Parameters:
    
                - gamma : The fraction of the vocabulary to be partitioned into the greenlist at each generation step. 
                         Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve 
                         a greater differentiation from human/unwatermarked text because it is preferentially sampling 
                         from a smaller green set making those tokens less likely to occur by chance.
                - delta : The amount of positive bias to add to the logits of every token in the greenlist 
                            at each generation step before sampling/choosing the next token. Higher delta values 
                            mean that the greenlist tokens are more heavily preferred by the watermarked model
                            and as the bias becomes very large the watermark transitions from "soft" to "hard". 
                            For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
                            generation quality, especially when there is not a lot of flexibility in the distribution.
    
                #### Detector Parameters:
    
                - z-score threshold : the z-score cuttoff for the hypothesis test. Higher thresholds (such as 4.0) make
                                    _false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely
                                    as a genuine human text with a significant number of tokens will almost never achieve 
                                    that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked
                                    texts will contain less green tokens and achive a lower z-score, but still pass the lower bar and 
                                    be flagged as "watermarked". However, a lowere threshold will increase the chance that human text 
                                    that contains a slightly higher than average number of green tokens is erroneously flagged. 
                                    4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
                - Ignore Bigram Repeats : This alternate detection algorithm only considers the unique bigrams in the text during detection, 
                                        computing the greenlists based on the first in each pair and checking whether the second falls within the list.
                                        This means that `T` is now the unique number of bigrams in the text, which becomes less than the total
                                        number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
                - Normalizations : we implement a few basic normaliations to defend against various adversarial perturbations of the
                                    text analyzed during detection. Currently we support converting all chracters to unicode, 
                                    replacing homoglyphs with a canonical form, and standardizing the capitalization. 
                                    See the paper for a detailed discussion of input normalization. 
                """
            )

        gr.HTML("""
                <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. 
                    Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
                <br/>
                <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
                <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
                <p/>
                """)

        # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag
        generate_btn.click(fn=generate_partial, inputs=[prompt, session_args],
                           outputs=[redecoded_input, truncation_warning, output_without_watermark,
                                    output_with_watermark, session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
                               outputs=[prompt, session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                        outputs=[without_watermark_detection_result, session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                     outputs=[with_watermark_detection_result, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args],
                         outputs=[detection_result, session_args])

        # State management logic
        # update callbacks that change the state dict
        def update_sampling_temp(session_state, value):
            session_state.sampling_temp = float(value); return session_state

        def update_generation_seed(session_state, value):
            session_state.generation_seed = int(value); return session_state

        def update_gamma(session_state, value):
            session_state.gamma = float(value); return session_state

        def update_delta(session_state, value):
            session_state.delta = float(value); return session_state

        def update_detection_z_threshold(session_state, value):
            session_state.detection_z_threshold = float(value); return session_state

        def update_decoding(session_state, value):
            if value == "multinomial":
                session_state.use_sampling = True
            elif value == "greedy":
                session_state.use_sampling = False
            return session_state

        def toggle_sampling_vis(value):
            if value == "multinomial":
                return gr.update(visible=True)
            elif value == "greedy":
                return gr.update(visible=False)

        def toggle_sampling_vis_inv(value):
            if value == "multinomial":
                return gr.update(visible=False)
            elif value == "greedy":
                return gr.update(visible=True)

        def update_n_beams(session_state, value):
            session_state.n_beams = value; return session_state

        def update_max_new_tokens(session_state, value):
            session_state.max_new_tokens = int(value); return session_state

        def update_ignore_repeated_bigrams(session_state, value):
            session_state.ignore_repeated_bigrams = value; return session_state

        def update_normalizers(session_state, value):
            session_state.normalizers = value; return session_state

        def update_seed_separately(session_state, value):
            session_state.seed_separately = value; return session_state

        def update_select_green_tokens(session_state, value):
            session_state.select_green_tokens = value; return session_state

        # registering callbacks for toggling the visibilty of certain parameters
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
        decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
        # registering all state update callbacks
        decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
        sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
        generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
        n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
        max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
        gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
        delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
        detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
                                     outputs=[session_args])
        ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
                                       outputs=[session_args])
        normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
        seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
        select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
                                   outputs=[session_args])
        # register additional callback on button clicks that updates the shown parameters window
        generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        # When the parameters change, display the update and fire detection, since some detection params dont change the model output.
        gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                     outputs=[without_watermark_detection_result, session_args])
        gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                     outputs=[with_watermark_detection_result, session_args])
        gamma.change(fn=detect_partial, inputs=[detection_input, session_args],
                     outputs=[detection_result, session_args])
        detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                     outputs=[without_watermark_detection_result, session_args])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                     outputs=[with_watermark_detection_result, session_args])
        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args],
                                     outputs=[detection_result, session_args])
        ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                       outputs=[without_watermark_detection_result, session_args])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                       outputs=[with_watermark_detection_result, session_args])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args],
                                       outputs=[detection_result, session_args])
        normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                           outputs=[without_watermark_detection_result, session_args])
        normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                           outputs=[with_watermark_detection_result, session_args])
        normalizers.change(fn=detect_partial, inputs=[detection_input, session_args],
                           outputs=[detection_result, session_args])
        select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                   outputs=[without_watermark_detection_result, session_args])
        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                   outputs=[with_watermark_detection_result, session_args])
        select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args],
                                   outputs=[detection_result, session_args])

    demo.queue(concurrency_count=3)

    if args.demo_public:
        demo.launch(share=True)  # exposes app to the internet via randomly generated link
    else:
        demo.launch()


def main(args):
    """Run a command line version of the generation and detection operations
        and optionally launch and serve the gradio demo"""
    # Initial arg processing and log
    args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
    print(args)

    if not args.skip_model_load:
        model, tokenizer, device = load_model(args)
    else:
        model, tokenizer, device = None, None, None

    # Generate and detect, report to stdout
    if not args.skip_model_load:
        input_text = "The United States has 63 national parks, which are congressionally designated protected areas operated by the National Park Service, an agency of the Department of the Interior.[1] National parks are designated for their natural beauty, unique geological features, diverse ecosystems, and recreational opportunities, typically 'because of some outstanding scenic feature or natural phenomena.'[2] While legislatively all units of the National Park System are considered equal with the same mission, national parks are generally larger and more of a destination, and hunting and extractive activities are prohibited.[3] National monuments, on the other hand, are also frequently protected for their historical or archaeological significance. Eight national parks (including six in Alaska) are paired with a national preserve, areas with different levels of protection that are administered together but considered separate units and whose areas are not included in the figures below. The 433 units of the National Park System can be broadly referred to as national parks, but most have other formal designations.[4]"

        model, tokenizer, device = load_model(args)

        args.default_prompt = input_text

        term_width = 80
        print("#" * term_width)
        print("Prompt:")
        print(input_text)

        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate_kgw(input_text,
                                                                                            args,
                                                                                            model=model,
                                                                                            device=device,
                                                                                            tokenizer=tokenizer)
        without_watermark_detection_result = detect_kgw(decoded_output_without_watermark,
                                                    args,
                                                    device=device,
                                                    tokenizer=tokenizer)
        with_watermark_detection_result = detect_kgw(decoded_output_with_watermark,
                                                 args,
                                                 device=device,
                                                 tokenizer=tokenizer)

        def compute_perplexity(model, tokenizer, text: str, device: torch.device) -> float:
            """
            Computes the perplexity of a given text using the provided model.
            Assumes a causal language model where providing labels yields the loss.
            """
            model.eval()
            # Tokenize the text with special tokens.
            inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
            with torch.no_grad():
                # When labels are the same as input_ids, the loss is the negative log-likelihood.
                outputs = model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss  # this is the average loss per token
            perplexity = torch.exp(loss).item()
            return perplexity

        # Compute perplexity for both watermarked and non-watermarked outputs:
        perplexity_nonwm = compute_perplexity(model, tokenizer, decoded_output_without_watermark, device)
        perplexity_wm = compute_perplexity(model, tokenizer, decoded_output_with_watermark, device)

        print("#" * term_width)
        print("Output without watermark:")
        print(decoded_output_without_watermark)
        print("-" * term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        pprint(without_watermark_detection_result)
        print("-" * term_width)
        print(f"Perplexity (non-watermarked): {perplexity_nonwm:.2f}")

        print("#" * term_width)
        print("Output with watermark:")
        print(decoded_output_with_watermark)
        print("-" * term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        print(with_watermark_detection_result[1][0])
        pprint(with_watermark_detection_result)
        print("-" * term_width)
        print(f"Perplexity (watermarked): {perplexity_wm:.2f}")

    # Launch the app to generate and detect interactively (implements the hf space demo)
    #if args.run_gradio:
    #    run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return

def normalization_strategy_lookup(strategy_name: str) -> object:
    if strategy_name == "unicode":
        return UnicodeSanitizer()
    elif strategy_name == "homoglyphs":
        return HomoglyphCanonizer()
    elif strategy_name == "truecase":
        return TrueCaser()

In [7]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial

import numpy  # for gradio hot reload

import torch

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)


def str2bool(v):
    """Util function for user friendly boolean flag args"""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    """Command line argument specification"""

    parser = argparse.ArgumentParser(
        description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")

    #parser.add_argument(
    #    "--run_gradio",
    #    type=str2bool,
    #    default=False,
    #    help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    #)
    parser.add_argument(
        "--demo_public",
        type=str2bool,
        default=False,
        help="Whether to expose the gradio demo to the internet.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-1.3b",#"gpt2-medium",#
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt_max_length",
        type=int,
        default=None,
        help="Truncation length for prompt, overrides model config's max length field.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximmum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=True,
        help="Whether to generate using multinomial sampling.",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )
    parser.add_argument(
        "--normalizers",
        type=str,
        default="",
        help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
    )
    parser.add_argument(
        "--ignore_repeated_bigrams",
        type=str2bool,
        default=False,
        help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
    )
    parser.add_argument(
        "--detection_z_threshold",
        type=float,
        default=4.0,
        help="The test statistic threshold for the detection hypothesis test.",
    )
    parser.add_argument(
        "--select_green_tokens",
        type=str2bool,
        default=True,
        help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
    )
    parser.add_argument(
        "--skip_model_load",
        type=str2bool,
        default=False,
        help="Skip the model loading to debug the interface.",
    )
    parser.add_argument(
        "--seed_separately",
        type=str2bool,
        default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
    )
    parser.add_argument(
        "--load_fp16",
        type=str2bool,
        default=False,
        help="Whether to run model in float16 precsion.",
    )
    args = parser.parse_args()
    return args


def load_model(args):
    """Load and return the model and tokenizer"""

    args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
    args.is_decoder_only_model = any(
        [(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16,
                                                         device_map='auto')
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if args.load_fp16:
            pass
        else:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    return model, tokenizer, device


def generate_kgw(prompt, args, model=None, device=None, tokenizer=None):
    """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
       and generate watermarked text by passing it to the generate method of the model
       as a logits processor. """

    #print(f"Generating with {args}")

    watermark_processor = WatermarkLogitsProcessor_kgw(vocab=list(tokenizer.get_vocab().values()),
                                                   gamma=args.gamma,
                                                   delta=args.delta,
                                                   seeding_scheme=args.seeding_scheme,
                                                   select_green_tokens=args.select_green_tokens)

    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)

    if args.use_sampling:
        gen_kwargs.update(dict(
            do_sample=True,
            top_k=0,
            temperature=args.sampling_temp
        ))
    else:
        gen_kwargs.update(dict(
            num_beams=args.n_beams
        ))

    generate_without_watermark = partial(
        model.generate,
        **gen_kwargs
    )
    generate_with_watermark = partial(
        model.generate,
        logits_processor=LogitsProcessorList([watermark_processor]),
        **gen_kwargs
    )
    if args.prompt_max_length:
        pass
    elif hasattr(model.config, "max_position_embedding"):
        args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
    else:
        args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    torch.manual_seed(args.generation_seed)
    output_without_watermark = generate_without_watermark(**tokd_input)

    # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
    if args.seed_separately:
        torch.manual_seed(args.generation_seed)
    output_with_watermark = generate_with_watermark(**tokd_input)

    if args.is_decoder_only_model:
        # need to isolate the newly generated tokens
        output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
        output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            decoded_output_without_watermark,
            decoded_output_with_watermark,
            args)
    # decoded_output_with_watermark)


def format_names(s):
    """Format names for the gradio demo interface"""
    s = s.replace("num_tokens_scored", "Tokens Counted (T)")
    s = s.replace("num_green_tokens", "# Tokens in Greenlist")
    s = s.replace("green_fraction", "Fraction of T in Greenlist")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "Prediction")
    s = s.replace("confidence", "Confidence")
    return s


def list_format_scores(score_dict, detection_threshold):
    """Format the detection metrics into a gradio dataframe input format"""
    lst_2d = []
    # lst_2d.append(["z-score threshold", f"{detection_threshold}"])
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d


def detect_kgw(input_text, args, device=None, tokenizer=None):
    """Instantiate the WatermarkDetection object and call detect on
        the input text returning the scores and outcome of the test"""
    watermark_detector = WatermarkDetector_kgw(vocab=list(tokenizer.get_vocab().values()),
                                           gamma=args.gamma,
                                           seeding_scheme=args.seeding_scheme,
                                           device=device,
                                           tokenizer=tokenizer,
                                           z_threshold=args.detection_z_threshold,
                                           normalizers=args.normalizers,
                                           ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                           select_green_tokens=args.select_green_tokens)
    if len(input_text) - 1 > watermark_detector.min_prefix_len:
        score_dict = watermark_detector.detect(input_text)
        # output = str_format_scores(score_dict, watermark_detector.z_threshold)
        output = list_format_scores(score_dict, watermark_detector.z_threshold)
    else:
        score_dict = None
        # output = (f"Error: string not long enough to compute watermark presence.")
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]
    return output, score_dict, args


def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface"""
    generate_partial = partial(generate_kgw, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_kgw, device=device, tokenizer=tokenizer)

    with gr.Blocks() as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Column(scale=9):
                gr.Markdown(
                    """
                    ## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
                    """
                )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
                    """
                )
            # with gr.Column(scale=2):
            #     pass
            # ![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_lm-watermarking) # buggy

        with gr.Accordion("Understanding the output metrics", open=False):
            gr.Markdown(
                """
                - `z-score threshold` : The cuttoff for the hypothesis test
                - `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm. 
                    The first token is ommitted in the simple, single token seeding scheme since there is no way to generate
                    a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm, 
                    described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
                - `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
                - `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
                - `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold` 
                    we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked
                - `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of 
                    observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
                    If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance.
                -  `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
                - `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent 
                    the confidence of the detection based on the unlikeliness of this `z-score` observation.
                """
            )

        with gr.Accordion("A note on model capability", open=True):
            gr.Markdown(
                """
                This demo uses open-source language models that fit on a single GPU. These models are less powerful than proprietary commercial tools like ChatGPT, Claude, or Bard. 

                Importantly, we use a language model that is designed to "complete" your prompt, and not a model this is fine-tuned to follow instructions. 
                For best results, prompt the model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph. 
                Some examples include the opening paragraph of a wikipedia article, or the first few sentences of a story. 
                Longer prompts that end mid-sentence will result in more fluent generations.
                """
            )
        gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)

        with gr.Tab("Generate and Detect"):

            with gr.Row():
                prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=2):
                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False, lines=14,
                                                          max_lines=14)
                with gr.Column(scale=1):
                    # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
                                                                      row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=2):
                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False, lines=14,
                                                       max_lines=14)
                with gr.Column(scale=1):
                    # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
                                                                   row_count=7, col_count=2)

            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                if truncation_warning:
                    return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
                else:
                    return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True, lines=14, max_lines=14)
                with gr.Column(scale=1):
                    # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
                                                    col_count=2)
            with gr.Row():
                detect_btn = gr.Button("Detect")

        # Parameter selection group
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### Generation Parameters")
                    with gr.Row():
                        decoding = gr.Radio(label="Decoding Method", choices=["multinomial", "greedy"],
                                            value=("multinomial" if args.use_sampling else "greedy"))
                    with gr.Row():
                        sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1,
                                                  value=args.sampling_temp, visible=True)
                    with gr.Row():
                        generation_seed = gr.Number(label="Generation Seed", value=args.generation_seed,
                                                    interactive=True)
                    with gr.Row():
                        n_beams = gr.Dropdown(label="Number of Beams", choices=list(range(1, 11, 1)),
                                              value=args.n_beams, visible=(not args.use_sampling))
                    with gr.Row():
                        max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10,
                                                   value=args.max_new_tokens)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Watermark Parameters")
                    with gr.Row():
                        gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                    with gr.Row():
                        delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                    gr.Markdown(f"#### Detector Parameters")
                    with gr.Row():
                        detection_z_threshold = gr.Slider(label="z-score threshold", minimum=0.0, maximum=10.0,
                                                          step=0.1, value=args.detection_z_threshold)
                    with gr.Row():
                        ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
                    with gr.Row():
                        normalizers = gr.CheckboxGroup(label="Normalizations",
                                                       choices=["unicode", "homoglyphs", "truecase"],
                                                       value=args.normalizers)
            # with gr.Accordion("Actual submitted parameters:",open=False):
            with gr.Row():
                gr.Markdown(
                    f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
            with gr.Row():
                current_parameters = gr.Textbox(label="Current Parameters", value=args)
            with gr.Accordion("Legacy Settings", open=False):
                with gr.Row():
                    with gr.Column(scale=1):
                        seed_separately = gr.Checkbox(label="Seed both generations separately",
                                                      value=args.seed_separately)
                    with gr.Column(scale=1):
                        select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition",
                                                          value=args.select_green_tokens)

        with gr.Accordion("Understanding the settings", open=False):
            gr.Markdown(
                """
                #### Generation Parameters:
    
                - Decoding Method : We can generate tokens from the model using either multinomial sampling or we can generate using greedy decoding.
                - Sampling Temperature : If using multinomial sampling we can set the temperature of the sampling distribution. 
                                    0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
                                    0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates while adding variety. Does not apply for greedy decoding.
                - Generation Seed : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy
                                    outputs reproducible. Does not apply for greedy decoding.
                - Number of Beams : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search. 
                                    This is not implemented/excluded from paper for multinomial sampling but may be added in future.
                - Max Generated Tokens : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens. 
                                        Note that the model is free to generate fewer tokens depending on the prompt. 
                                        Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
                                        and inputs will be truncated accordingly.
    
                #### Watermark Parameters:
    
                - gamma : The fraction of the vocabulary to be partitioned into the greenlist at each generation step. 
                         Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve 
                         a greater differentiation from human/unwatermarked text because it is preferentially sampling 
                         from a smaller green set making those tokens less likely to occur by chance.
                - delta : The amount of positive bias to add to the logits of every token in the greenlist 
                            at each generation step before sampling/choosing the next token. Higher delta values 
                            mean that the greenlist tokens are more heavily preferred by the watermarked model
                            and as the bias becomes very large the watermark transitions from "soft" to "hard". 
                            For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
                            generation quality, especially when there is not a lot of flexibility in the distribution.
    
                #### Detector Parameters:
    
                - z-score threshold : the z-score cuttoff for the hypothesis test. Higher thresholds (such as 4.0) make
                                    _false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely
                                    as a genuine human text with a significant number of tokens will almost never achieve 
                                    that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked
                                    texts will contain less green tokens and achive a lower z-score, but still pass the lower bar and 
                                    be flagged as "watermarked". However, a lowere threshold will increase the chance that human text 
                                    that contains a slightly higher than average number of green tokens is erroneously flagged. 
                                    4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
                - Ignore Bigram Repeats : This alternate detection algorithm only considers the unique bigrams in the text during detection, 
                                        computing the greenlists based on the first in each pair and checking whether the second falls within the list.
                                        This means that `T` is now the unique number of bigrams in the text, which becomes less than the total
                                        number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
                - Normalizations : we implement a few basic normaliations to defend against various adversarial perturbations of the
                                    text analyzed during detection. Currently we support converting all chracters to unicode, 
                                    replacing homoglyphs with a canonical form, and standardizing the capitalization. 
                                    See the paper for a detailed discussion of input normalization. 
                """
            )

        gr.HTML("""
                <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. 
                    Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
                <br/>
                <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
                <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
                <p/>
                """)

        # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag
        generate_btn.click(fn=generate_partial, inputs=[prompt, session_args],
                           outputs=[redecoded_input, truncation_warning, output_without_watermark,
                                    output_with_watermark, session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
                               outputs=[prompt, session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                        outputs=[without_watermark_detection_result, session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                     outputs=[with_watermark_detection_result, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args],
                         outputs=[detection_result, session_args])

        # State management logic
        # update callbacks that change the state dict
        def update_sampling_temp(session_state, value):
            session_state.sampling_temp = float(value); return session_state

        def update_generation_seed(session_state, value):
            session_state.generation_seed = int(value); return session_state

        def update_gamma(session_state, value):
            session_state.gamma = float(value); return session_state

        def update_delta(session_state, value):
            session_state.delta = float(value); return session_state

        def update_detection_z_threshold(session_state, value):
            session_state.detection_z_threshold = float(value); return session_state

        def update_decoding(session_state, value):
            if value == "multinomial":
                session_state.use_sampling = True
            elif value == "greedy":
                session_state.use_sampling = False
            return session_state

        def toggle_sampling_vis(value):
            if value == "multinomial":
                return gr.update(visible=True)
            elif value == "greedy":
                return gr.update(visible=False)

        def toggle_sampling_vis_inv(value):
            if value == "multinomial":
                return gr.update(visible=False)
            elif value == "greedy":
                return gr.update(visible=True)

        def update_n_beams(session_state, value):
            session_state.n_beams = value; return session_state

        def update_max_new_tokens(session_state, value):
            session_state.max_new_tokens = int(value); return session_state

        def update_ignore_repeated_bigrams(session_state, value):
            session_state.ignore_repeated_bigrams = value; return session_state

        def update_normalizers(session_state, value):
            session_state.normalizers = value; return session_state

        def update_seed_separately(session_state, value):
            session_state.seed_separately = value; return session_state

        def update_select_green_tokens(session_state, value):
            session_state.select_green_tokens = value; return session_state

        # registering callbacks for toggling the visibilty of certain parameters
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
        decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
        # registering all state update callbacks
        decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
        sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
        generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
        n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
        max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
        gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
        delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
        detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
                                     outputs=[session_args])
        ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
                                       outputs=[session_args])
        normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
        seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
        select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
                                   outputs=[session_args])
        # register additional callback on button clicks that updates the shown parameters window
        generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        # When the parameters change, display the update and fire detection, since some detection params dont change the model output.
        gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                     outputs=[without_watermark_detection_result, session_args])
        gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                     outputs=[with_watermark_detection_result, session_args])
        gamma.change(fn=detect_partial, inputs=[detection_input, session_args],
                     outputs=[detection_result, session_args])
        detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                     outputs=[without_watermark_detection_result, session_args])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                     outputs=[with_watermark_detection_result, session_args])
        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args],
                                     outputs=[detection_result, session_args])
        ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                       outputs=[without_watermark_detection_result, session_args])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                       outputs=[with_watermark_detection_result, session_args])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args],
                                       outputs=[detection_result, session_args])
        normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                           outputs=[without_watermark_detection_result, session_args])
        normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                           outputs=[with_watermark_detection_result, session_args])
        normalizers.change(fn=detect_partial, inputs=[detection_input, session_args],
                           outputs=[detection_result, session_args])
        select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                   outputs=[without_watermark_detection_result, session_args])
        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                   outputs=[with_watermark_detection_result, session_args])
        select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args],
                                   outputs=[detection_result, session_args])

    demo.queue(concurrency_count=3)

    if args.demo_public:
        demo.launch(share=True)  # exposes app to the internet via randomly generated link
    else:
        demo.launch()


def main(args):
    """Run a command line version of the generation and detection operations
        and optionally launch and serve the gradio demo"""
    # Initial arg processing and log
    args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
    print(args)

    if not args.skip_model_load:
        model, tokenizer, device = load_model(args)
    else:
        model, tokenizer, device = None, None, None

    # Generate and detect, report to stdout
    if not args.skip_model_load:
        input_text = "The United States has 63 national parks, which are congressionally designated protected areas operated by the National Park Service, an agency of the Department of the Interior.[1] National parks are designated for their natural beauty, unique geological features, diverse ecosystems, and recreational opportunities, typically 'because of some outstanding scenic feature or natural phenomena.'[2] While legislatively all units of the National Park System are considered equal with the same mission, national parks are generally larger and more of a destination, and hunting and extractive activities are prohibited.[3] National monuments, on the other hand, are also frequently protected for their historical or archaeological significance. Eight national parks (including six in Alaska) are paired with a national preserve, areas with different levels of protection that are administered together but considered separate units and whose areas are not included in the figures below. The 433 units of the National Park System can be broadly referred to as national parks, but most have other formal designations.[4]"

        model, tokenizer, device = load_model(args)

        args.default_prompt = input_text

        term_width = 80
        print("#" * term_width)
        print("Prompt:")
        print(input_text)

        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate_kgw(input_text,
                                                                                            args,
                                                                                            model=model,
                                                                                            device=device,
                                                                                            tokenizer=tokenizer)
        without_watermark_detection_result = detect_kgw(decoded_output_without_watermark,
                                                    args,
                                                    device=device,
                                                    tokenizer=tokenizer)
        with_watermark_detection_result = detect_kgw(decoded_output_with_watermark,
                                                 args,
                                                 device=device,
                                                 tokenizer=tokenizer)

        def compute_perplexity(model, tokenizer, text: str, device: torch.device) -> float:
            """
            Computes the perplexity of a given text using the provided model.
            Assumes a causal language model where providing labels yields the loss.
            """
            model.eval()
            # Tokenize the text with special tokens.
            inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
            with torch.no_grad():
                # When labels are the same as input_ids, the loss is the negative log-likelihood.
                outputs = model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss  # this is the average loss per token
            perplexity = torch.exp(loss).item()
            return perplexity

        # Compute perplexity for both watermarked and non-watermarked outputs:
        perplexity_nonwm = compute_perplexity(model, tokenizer, decoded_output_without_watermark, device)
        perplexity_wm = compute_perplexity(model, tokenizer, decoded_output_with_watermark, device)

        print("#" * term_width)
        print("Output without watermark:")
        print(decoded_output_without_watermark)
        print("-" * term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        pprint(without_watermark_detection_result)
        print("-" * term_width)
        print(f"Perplexity (non-watermarked): {perplexity_nonwm:.2f}")

        print("#" * term_width)
        print("Output with watermark:")
        print(decoded_output_with_watermark)
        print("-" * term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        print(with_watermark_detection_result[1][0])
        pprint(with_watermark_detection_result)
        print("-" * term_width)
        print(f"Perplexity (watermarked): {perplexity_wm:.2f}")

    # Launch the app to generate and detect interactively (implements the hf space demo)
    #if args.run_gradio:
    #    run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return

In [113]:
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import collections
from math import sqrt
import scipy.stats
import torch
from torch import Tensor
from transformers import LogitsProcessor
import nltk
import ssl
from nltk.util import ngrams

###############################################
# Revised Watermark Classes
###############################################

class WatermarkBase_kgw:
    def __init__(
            self,
            vocab: list[int] = None,
            gamma: float = 0.5,
            delta: float = 2.0,
            seeding_scheme: str = "simple_1",
            hash_key: int = 15485863,
            select_green_tokens: bool = True,
            precomputed_pairing: list[tuple[int, int]] = None,
            unique_tokens: list[int] = None,
    ):
        self.vocab = vocab  # list of token IDs (usually 0,...,n-1)
        self.vocab_size = len(vocab)
        self.gamma = gamma  # fraction of tokens to designate as green (target size)
        self.delta = delta  # bias to add to green tokens' logits
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens
        self.pairing = precomputed_pairing  # perfect matching on tokens with synonyms (indices in vocab)
        self.unique_tokens = unique_tokens  # list of token IDs that have no synonyms

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """
        Seeds the RNG deterministically using the last token in input_ids.
        For the "simple_1" scheme, seed = hash_key * (last token id).
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme
        # Ensure the RNG is initialized.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, "Input must have at least one token."
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
            #print(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Seeding scheme {seeding_scheme} not implemented.")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        """
        Returns a list of token IDs that form the green list.
        If a precomputed pairing exists, then:
          - All tokens from the unique set (those with no synonyms) are automatically in the green list.
          - For each pair in the perfect matching (on tokens with synonyms), a fair coin flip (using the seeded RNG)
            selects one token from the pair.
        Otherwise, falls back to a random permutation method.
        Optionally, the list may be truncated to a target size (gamma * vocab_size).
        """
        self._seed_rng(input_ids)
        # Fallback: use random permutation.
        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:
            return vocab_permutation[:greenlist_size].tolist()
        else:
            return vocab_permutation[-greenlist_size:].tolist()
        #if self.pairing is None or self.unique_tokens is None:
            # Fallback: use random permutation.
        #    greenlist_size = int(self.vocab_size * self.gamma)
        #    vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        #    if self.select_green_tokens:
        #        return vocab_permutation[:greenlist_size].tolist()
        #    else:
        #        return vocab_permutation[-greenlist_size:].tolist()
        #else:
            # Start with all unique tokens.
        #    greenlist_ids = self.unique_tokens.copy()
            # For tokens that have synonyms (precomputed pairing), randomly assign one from each pair.
        #    for pair in self.pairing:
        #        coin_flip = (torch.rand(1, generator=self.rng, device=input_ids.device) < 0.5).item()
        #        chosen = pair[0] if coin_flip == 1 else pair[1]
        #        greenlist_ids.append(chosen)
            # Optionally, enforce a maximum size (gamma * vocab_size)
        #    desired_size = int(self.vocab_size * self.gamma)
        #    if len(greenlist_ids) > desired_size:
        #        indices = torch.randperm(len(greenlist_ids), generator=self.rng, device=input_ids.device)[:desired_size].tolist()
        #        greenlist_ids = [greenlist_ids[i] for i in indices]
        #    return greenlist_ids


class WatermarkLogitsProcessor_kgw(WatermarkBase_kgw, LogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        return green_tokens_mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor,
                               greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        #print(input_ids.device)
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)
        batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids[b_idx] = greenlist_ids
        green_tokens_mask = self._calc_greenlist_mask(scores, batched_greenlist_ids)
        scores = self._bias_greenlist_logits(scores, green_tokens_mask, self.delta)
        #print(torch.max(scores))
        return scores


class WatermarkDetector_kgw(WatermarkBase_kgw):
    def __init__(
            self,
            *args,
            device: torch.device = None,
            tokenizer: Tokenizer = None,
            z_threshold: float = 4.0,
            normalizers: list[str] = ["unicode"],
            ignore_repeated_bigrams: bool = True,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert device, "Device must be provided."
        assert tokenizer, "A tokenizer is required for detection."
        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)
        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Seeding scheme {self.seeding_scheme} not implemented.")
        self.normalizers = [normalization_strategy_lookup(norm) for norm in normalizers]
        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "Repeated bigram variant requires simple_1 seeding."

    def _compute_z_score(self, observed_count, T):
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        return numer / denom

    def _compute_p_value(self, z):
        return scipy.stats.norm.sf(z)

    def _score_sequence(
            self,
            input_ids: Tensor,
            return_num_tokens_scored: bool = True,
            return_num_green_tokens: bool = True,
            return_green_fraction: bool = True,
            return_green_token_mask: bool = False,
            return_z_score: bool = True,
            return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for bigram in freq.keys():
                prefix = torch.tensor([bigram[0]], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
            green_token_count = sum(bigram_table.values())
        else:
            #print(1)
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                score_dict = None
                return score_dict #raise ValueError("Not enough tokens to score.")
            green_token_count = 0
            green_token_mask = []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                    #print(True)
                else:
                    green_token_mask.append(False)
                    #print(False)
        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score:
            score_dict["z_score"] = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_p_value:
            z = score_dict.get("z_score", self._compute_z_score(green_token_count, num_tokens_scored))
            score_dict["p_value"] = self._compute_p_value(z)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask
        return score_dict

    def detect(
            self,
            text: str = None,
            tokenized_text: list[int] = None,
            return_prediction: bool = True,
            return_scores: bool = True,
            z_threshold: float = None,
            **kwargs,
    ) -> dict:
        assert (text is not None) ^ (tokenized_text is not None), "Provide either raw or tokenized text."
        if return_prediction:
            kwargs["return_p_value"] = True
        for normalizer in self.normalizers:
            text = normalizer(text)
        if tokenized_text is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(
                self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            if self.tokenizer is not None and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        if return_prediction:
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]
        return output_dict

'''
###############################################
# Outline of the New Partitioning Method
###############################################
# 1. Obtain the vocabulary from the model's tokenizer using get_vocabulary(tokenizer).
# 2. Use filter_tokens_with_synonyms(vocab_list) to split the vocabulary indices into:
#       - unique_indices: tokens with no synonyms (set A)
#       - paired_indices: tokens with at least one synonym (set B)
# 3. Construct the similarity matrix for tokens in set B using construct_similarity_matrix.
# 4. Compute a perfect matching on set B using find_perfect_matching.
# 5. In _get_greenlist_ids, use the precomputed pairing (mapped back to original token IDs)
#    and return all tokens from set A plus one token per pair (chosen at random by a coin flip).
###############################################
# Full Code Example for the Revised Watermark Partition
###############################################

if __name__ == "__main__":
    # For testing purposes, use a small model like 'distilgpt2' which can run on CPU.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    vocab_list = get_vocabulary(tokenizer)
    print(f"Vocabulary size: {len(vocab_list)}")

    # Filter tokens: separate those with no synonyms (set A) and with synonyms (set B)
    unique_indices, paired_indices = filter_tokens_with_synonyms(vocab_list)
    print(f"Number of unique tokens (no synonyms, set A): {len(unique_indices)}")
    print(f"Number of tokens with synonyms (set B): {len(paired_indices)}")

    # Construct similarity matrix for tokens in set B.
    similarity_matrix = construct_similarity_matrix(vocab_list, paired_indices)
    # Find a perfect matching on the tokens in set B.
    matching = find_perfect_matching(similarity_matrix)
    # Map matching indices (relative to paired_indices) back to the original vocabulary indices.
    mapped_pairing = [(paired_indices[i], paired_indices[j]) for (i, j) in matching]
    print("Computed perfect matching (pairs) on tokens with synonyms (set B):")
    print(mapped_pairing)

    # Initialize WatermarkLogitsProcessor with precomputed pairing and unique tokens.
    wm_processor = WatermarkLogitsProcessor(
        vocab=list(range(len(vocab_list))),
        gamma=0.25,
        delta=2.0,
        seeding_scheme="simple_1",
        select_green_tokens=True,
        precomputed_pairing=mapped_pairing,
        unique_tokens=unique_indices
    )

    # Test _get_greenlist_ids with a sample prompt.
    sample_prompt = "This is a good day."
    input_ids = torch.tensor(tokenizer.encode(sample_prompt))
    greenlist_ids = wm_processor._get_greenlist_ids(input_ids)
    print("Greenlist token IDs for the prompt:")
    print(greenlist_ids)
    green_tokens = [vocab_list[tok] for tok in greenlist_ids]
    print("Greenlist tokens (strings):")
    print(green_tokens)'''


'''# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import collections
from math import sqrt

import scipy.stats

import torch
from torch import Tensor
from tokenizers import Tokenizer
from transformers import LogitsProcessor

from nltk.util import ngrams

from normalizers import normalization_strategy_lookup


class WatermarkBase:
    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",  # mostly unused/always default
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        select_green_tokens: bool = True,
    ):

        # watermarking parameters
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        # can optionally override the seeding scheme,
        # but uses the instance attr by default
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        # seed the rng using the previous tokens/prefix
        # according to the seeding_scheme
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:  # directly
            greenlist_ids = vocab_permutation[:greenlist_size]  # new
        else:  # select green via red
            greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :]  # legacy behavior
        return greenlist_ids


class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        # TODO lets see if we can lose this loop
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

        # this is lazy to allow us to colocate on the watermarked model's device
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        # NOTE, it would be nice to get rid of this batch loop, but currently,
        # the seed and partition operations are not tensor/vectorized, thus
        # each sequence in the batch needs to be treated separately.
        batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]

        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids[b_idx] = greenlist_ids

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)

        scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
        return scores


class WatermarkDetector(WatermarkBase):
    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = ["unicode"],  # or also: ["unicode", "homoglyphs", "truecase"]
        ignore_repeated_bigrams: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # also configure the metrics returned/preprocessing options
        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        self.normalizers = []
        for normalization_strategy in normalizers:
            self.normalizers.append(normalization_strategy_lookup(normalization_strategy))

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        # count refers to number of green tokens, T is total number of tokens
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        z = numer / denom
        return z

    def _compute_p_value(self, z):
        p_value = scipy.stats.norm.sf(z)
        return p_value

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            # Method that only counts a green/red hit once per unique bigram.
            # New num total tokens scored (T) becomes the number unique bigrams.
            # We iterate over all unqiue token bigrams in the input, computing the greenlist
            # induced by the first token in each, and then checking whether the second
            # token falls in that greenlist.
            assert return_green_token_mask is False, "Can't return the green/red mask when ignoring repeats."
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for idx, bigram in enumerate(freq.keys()):
                prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
            green_token_count = sum(bigram_table.values())
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError(
                    (
                        f"Must have at least {1} token to score after "
                        f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
                    )
                )
            # Standard method.
            # Since we generally need at least 1 token (for the simplest scheme)
            # we start the iteration over the token sequence with a minimum
            # num tokens as the first prefix for the seeding scheme,
            # and at each step, compute the greenlist induced by the
            # current prefix and check if the current token falls in the greenlist.
            green_token_count, green_token_mask = 0, []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)

        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
        if return_z_score:
            score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
        if return_p_value:
            z_score = score_dict.get("z_score")
            if z_score is None:
                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_token_mask))

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:

        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # to return the "confidence":=1-p of positive detections

        # run optional normalizers on text
        for normalizer in self.normalizers:
            text = normalizer(text)
        if len(self.normalizers) > 0:
            print(f"Text after normalization:\n\n{text}\n")

        if tokenized_text is None:
            assert self.tokenizer is not None, (
                "Watermark detection on raw string ",
                "requires an instance of the tokenizer ",
                "that was used at generation time.",
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            # try to remove the bos_tok at beginning if it's there
            if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
                tokenized_text = tokenized_text[1:]

        # call score method
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        # if passed return_prediction then perform the hypothesis test and return the outcome
        if return_prediction:
            z_threshold = z_threshold if z_threshold else self.z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict'''


'# coding=utf-8\n# Copyright 2023 Authors of "A Watermark for Large Language Models"\n# available at https://arxiv.org/abs/2301.10226\n#\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an "AS IS" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\nimport collections\nfrom math import sqrt\n\nimport scipy.stats\n\nimport torch\nfrom torch import Tensor\nfrom tokenizers import Tokenizer\nfrom transformers import LogitsProcessor\n\nfrom nltk.util import ngrams\n\nfrom normalizers import normalization_strategy_looku

In [1]:
pip install numpy

Collecting numpy
  Downloading numpy-2.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: numpy
Successfully installed numpy-2.2.4
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install transformers

Collecting transformers
  Downloading transformers-4.50.0-py3-none-any.whl (10.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m0m
Collecting regex!=2019.12.17
  Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m781.7/781.7 KB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting safetensors>=0.4.3
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 KB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting tokenizers<0.22,>=0.21
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m2.5 MB/s[0m eta [36m0:00:

In [3]:
pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 KB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting pandas
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting dill<0.3.9,>=0.3.0
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 KB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m
Collecting fsspec[http]<=2024.12.0,>=2023.1.0
  Downloading fsspec-2024.12.0-py3-none-any.whl (183 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 KB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0
  Downloading pyarrow-19.0.1-cp310-cp310-manylinux_2_28_x86_64.whl (42.1 MB)
[2K

In [4]:
pip install scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.5/13.5 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting joblib>=1.2.0
  Downloading joblib-1.4.2-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.8/301.8 KB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting scipy>=1.6.0
  Downloading scipy-1.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.6/37.6 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.4.2 scikit-learn-1.6.1 scipy-1.15.2 threadpoolctl-3.6.0
Note: y

In [5]:
pip install torch

Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl (766.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.7/766.7 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting nvidia-cuda-runtime-cu12==12.4.127
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 KB[0m [31m139.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting nvidia-nccl-cu12==2.21.5
  Downloading nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.7/188.7 MB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting triton==3.2.0
  Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.1/253.1 MB[0m [31m19.5 MB/s[0m eta [36m0:

In [6]:
pip install nltk

Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting click
  Downloading click-8.1.8-py3-none-any.whl (98 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.2/98.2 KB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: click, nltk
Successfully installed click-8.1.8 nltk-3.9.1
Note: you may need to restart the kernel to use updated packages.
