# SynthID Text: Watermarking for Generated Text

This notebook demonstrates how to use the [SynthID Text library][synthid-code]
to apply and detect watermarks on generated text. It is divided into three major
sections and intended to be run end-to-end.

1.  **_Setup_**: Importing the SynthID Text library, choosing your model (either
    [Gemma][gemma] or [GPT-2][gpt2]) and device (either CPU or GPU, depending
    on your runtime), defining the watermarking configuration, and initializing
    some helper functions.
1.  **_Applying a watermark_**: Loading your selected model using the
    [Hugging Face Transformers][transformers] library and using that model to
    generate some watermarked text.
1.  **_Detecting a watermark_**: Computing g-values and masks for watermarked
    and unwatermarked responses, then scoring them with the Mean and Weighted
    Mean scoring functions, which require no training.

[gemma]: https://ai.google.dev/gemma/docs/model_card
[gpt2]: https://huggingface.co/openai-community/gpt2
[synthid-code]: https://github.com/google-deepmind/synthid-text
[synthid-paper]: https://www.nature.com/
[transformers]: https://huggingface.co/docs/transformers/en/index

# 1. Setup

In [None]:
!pip install synthid-text[notebook] datasets huggingface_hub tensorflow tqdm transformers accelerate

In [None]:
from collections.abc import Sequence
import enum
import gc

import datasets
import huggingface_hub
from synthid_text import detector_mean
from synthid_text import logits_processing
from synthid_text import synthid_mixin
from synthid_text import detector_bayesian
import tensorflow as tf
import torch
import tqdm
import transformers

In [None]:
class ModelName(enum.Enum):
  GPT2 = 'gpt2'
  GEMMA_2B = 'google/gemma-2b-it'
  GEMMA_7B = 'google/gemma-7b-it'


model_name = 'google/gemma-2b-it' # @param ['gpt2', 'google/gemma-2b-it', 'google/gemma-7b-it']
MODEL_NAME = ModelName(model_name)

if MODEL_NAME is not ModelName.GPT2:
  huggingface_hub.notebook_login()

In [None]:
DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
DEVICE

In [None]:
CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
CONFIG

In [None]:
BATCH_SIZE = 8
NUM_BATCHES = 320
OUTPUTS_LEN = 1024
TEMPERATURE = 0.5
TOP_K = 40
TOP_P = 0.99

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

logits_processor = logits_processing.SynthIDLogitsProcessor(
    **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
)
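The logits processor is built with `top_k` and `temperature`, presumably so that it works from the same candidate distribution the sampler uses. As a generic illustration of what those two parameters do to a logit vector (this is the standard sampling step, not SynthID's internal code), here is a small NumPy sketch; the function name `temperature_top_k_probs` is hypothetical.

```python
import numpy as np


def temperature_top_k_probs(
    logits: np.ndarray, temperature: float, top_k: int
) -> np.ndarray:
  """Sampling probabilities after temperature scaling and top-k filtering.

  Args:
    logits: Unnormalized scores, shape [vocab_size].
    temperature: Positive divisor; values below 1 sharpen the distribution.
    top_k: Number of highest-scoring tokens that stay eligible.
  """
  if temperature <= 0:
    raise ValueError('temperature must be positive')
  scaled = logits / temperature
  # Find the k-th largest scaled logit; ties at that value are also kept.
  k = min(top_k, scaled.shape[-1])
  kth_largest = np.sort(scaled)[-k]
  masked = np.where(scaled >= kth_largest, scaled, -np.inf)
  # Numerically stable softmax: exp(-inf) is exactly 0 for filtered tokens.
  exp = np.exp(masked - masked.max())
  return exp / exp.sum()


probs = temperature_top_k_probs(
    np.array([2.0, 1.0, 0.5, -1.0]), temperature=0.5, top_k=2
)  # approximately [0.881, 0.119, 0.0, 0.0]
```

With `temperature=0.5` the gap between the two surviving logits doubles, so almost 90% of the probability goes to the top token, and every token outside the top 2 gets probability 0.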

In [None]:
def load_model(
    model_name: ModelName,
    expected_device: torch.device,
    enable_watermarking: bool = False,
) -> transformers.PreTrainedModel:
  if model_name == ModelName.GPT2:
    model_cls = (
        synthid_mixin.SynthIDGPT2LMHeadModel
        if enable_watermarking
        else transformers.GPT2LMHeadModel
    )
    model = model_cls.from_pretrained(model_name.value, device_map='auto')
  else:
    model_cls = (
        synthid_mixin.SynthIDGemmaForCausalLM
        if enable_watermarking
        else transformers.GemmaForCausalLM
    )
    model = model_cls.from_pretrained(
        model_name.value,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )

  if str(model.device) != str(expected_device):
    raise ValueError('Model device not as expected.')

  return model


def _process_raw_prompt(prompt: bytes) -> str:
  """Decodes a raw prompt and, for Gemma, applies the chat template."""
  if MODEL_NAME == ModelName.GPT2:
    return prompt.decode().strip('"')
  else:
    return tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt.decode().strip('"')}],
        tokenize=False,
        add_generation_prompt=True,
    )
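Because `_process_raw_prompt` calls `.decode()`, each prompt must arrive as `bytes` rather than `str` (for example, string features read from a TensorFlow dataset; where the prompts would come from is an assumption, since the function is not called in this notebook). A minimal sketch of the GPT-2 branch, under the hypothetical name `decode_raw_prompt`, shows what the decoding and quote stripping do:

```python
def decode_raw_prompt(prompt: bytes) -> str:
  """Hypothetical stand-in for the GPT-2 branch of `_process_raw_prompt`."""
  # bytes.decode() uses UTF-8 by default; strip('"') then removes every
  # leading and trailing double-quote character, not just one pair.
  return prompt.decode().strip('"')


print(decode_raw_prompt(b'"I am from New York"'))  # I am from New York
print(decode_raw_prompt(b'""Hi""'))  # Hi
```

Passing a plain `str` instead fails with `AttributeError`, since Python 3 strings have no `decode` method.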

# 2. Applying a watermark

In [None]:
gc.collect()
torch.cuda.empty_cache()

batch_size = 1
example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]
example_inputs = example_inputs * (int(batch_size / 4) + 1)
example_inputs = example_inputs[:batch_size]

inputs = tokenizer(
    example_inputs,
    return_tensors='pt',
    padding=True,
).to(DEVICE)
print("Loading model")
model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)
print("Model loaded")
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    max_length=1024,
    top_k=40,
    pad_token_id=tokenizer.eos_token_id,
)
print("Generation finished")
print('Output:\n' + 100 * '-')
for output in outputs:
  print(tokenizer.decode(output, skip_special_tokens=True))
  print(100 * '-')

del inputs, outputs, model
gc.collect()
torch.cuda.empty_cache()
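The cell above repeats the four example prompts `int(batch_size / 4) + 1` times and then truncates, so any `batch_size` yields a full batch. A small generalization that works for any number of prompts (the helper name `fill_batch` is hypothetical):

```python
def fill_batch(prompts: list[str], batch_size: int) -> list[str]:
  """Cycles through `prompts` until exactly `batch_size` items are collected."""
  if not prompts:
    raise ValueError('prompts must be non-empty')
  # Enough repetitions to reach at least batch_size, then truncate.
  repeats = batch_size // len(prompts) + 1
  return (prompts * repeats)[:batch_size]


prompts = ['a', 'b', 'c', 'd']
print(fill_batch(prompts, 1))  # ['a']
print(fill_batch(prompts, 6))  # ['a', 'b', 'c', 'd', 'a', 'b']
```

For four prompts this matches the cell's arithmetic exactly, since `int(n / 4)` and `n // 4` agree for non-negative integers.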

# 3. Detecting a watermark

To detect the watermark, this notebook currently uses the simple **Mean**
scoring function and its **Weighted Mean** variant. Both are fast to compute
and require no training.
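The Mean score averages a response's g-values over all unmasked positions and all watermarking layers (the depth axis). The NumPy sketch below is written from that description and the shape comments in `generate_responses`, not copied from `detector_mean`; the name `mean_score_sketch` and the toy arrays are hypothetical.

```python
import numpy as np


def mean_score_sketch(g_values: np.ndarray, mask: np.ndarray) -> np.ndarray:
  """Average of the unmasked g-values of each response.

  Args:
    g_values: Shape [batch_size, seq_len, depth].
    mask: Shape [batch_size, seq_len]; 1 keeps a position, 0 ignores it.

  Returns:
    Scores of shape [batch_size].
  """
  depth = g_values.shape[-1]
  num_unmasked = mask.sum(axis=1)
  # Broadcast the mask over the depth axis and sum the kept g-values.
  kept_sum = (g_values * mask[:, :, None]).sum(axis=(1, 2))
  return kept_sum / (depth * num_unmasked)


g_values = np.array([
    [[1, 1], [1, 0], [0, 0]],  # response 0
    [[0, 1], [0, 0], [1, 0]],  # response 1
], dtype=np.float64)
mask = np.array([[1, 1, 0], [1, 1, 1]], dtype=np.float64)
scores = mean_score_sketch(g_values, mask)  # [0.75, 0.333...]
```

Response 0 keeps two positions with three 1s out of four g-values (0.75); response 1 keeps all three positions with two 1s out of six (about 0.333). A response whose mask is all zeros would divide by zero and produce NaN.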

In [None]:
NUM_NEGATIVES = 10000
POS_BATCH_SIZE = 32
NUM_POS_BATCHES = 313
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 200
NEG_TRUNCATION_LENGTH = 200
# Pad truncated outputs to this length so all batches share the same shape.
MAX_PADDED_LENGTH = 1000
TEMPERATURE = 1.0
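`POS_TRUNCATION_LENGTH`, `NEG_TRUNCATION_LENGTH`, and `MAX_PADDED_LENGTH` are not used by the Mean scoring in this notebook; their comments suggest they prepare g-values for training a detector (presumably one like `detector_bayesian`, which is imported but unused here). Assuming truncation keeps the first positions and padding appends masked-out zeros, a hypothetical `truncate_and_pad` helper could look like this:

```python
import numpy as np


def truncate_and_pad(
    g_values: np.ndarray,
    mask: np.ndarray,
    truncation_length: int,
    padded_length: int,
) -> tuple[np.ndarray, np.ndarray]:
  """Keeps the first `truncation_length` positions, then zero-pads.

  Args:
    g_values: Shape [batch_size, seq_len, depth].
    mask: Shape [batch_size, seq_len].
    truncation_length: Number of leading positions to keep.
    padded_length: Final length of the position axis.
  """
  if truncation_length > padded_length:
    raise ValueError('truncation_length must not exceed padded_length')
  g = g_values[:, :truncation_length]
  m = mask[:, :truncation_length]
  pad = padded_length - g.shape[1]
  # np.pad defaults to constant zeros; a zero mask marks padding as ignored.
  g = np.pad(g, ((0, 0), (0, pad), (0, 0)))
  m = np.pad(m, ((0, 0), (0, pad)))
  return g, m


g_padded, m_padded = truncate_and_pad(
    np.ones((2, 300, 3)), np.ones((2, 300)),
    truncation_length=200, padded_length=1000,
)
```

Because padded positions carry mask value 0, a masked scorer ignores them, so padding changes array shapes without changing scores.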

In [None]:
def generate_responses(example_inputs, enable_watermarking):
  inputs = tokenizer(
      example_inputs,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)

  # Free memory held by previously loaded models before loading a new one.
  gc.collect()
  torch.cuda.empty_cache()

  model = load_model(
      MODEL_NAME,
      expected_device=DEVICE,
      enable_watermarking=enable_watermarking,
  )
  torch.manual_seed(0)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
      pad_token_id=tokenizer.eos_token_id,
  )

  outputs = outputs[:, inputs_len:]

  # Compute the EOS mask (shape [batch_size, output_len]), then skip the first
  # ngram_len - 1 positions so it aligns with the context repetition mask.
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs,
      eos_token_id=tokenizer.eos_token_id,
  )[:, CONFIG['ngram_len'] - 1 :]

  # Compute the context repetition mask.
  # Shape: [batch_size, output_len - (ngram_len - 1)].
  context_repetition_mask = logits_processor.compute_context_repetition_mask(
      input_ids=outputs,
  )

  combined_mask = context_repetition_mask * eos_token_mask

  g_values = logits_processor.compute_g_values(
      input_ids=outputs,
  )
  # g values shape [batch_size, output_len - (ngram_len - 1), depth]
  del model, inputs

  return g_values, combined_mask


example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]

wm_g_values, wm_mask = generate_responses(
    example_inputs, enable_watermarking=True
)
uwm_g_values, uwm_mask = generate_responses(
    example_inputs, enable_watermarking=False
)
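To see why the EOS mask is sliced by `ngram_len - 1` before being multiplied with the context repetition mask, the NumPy sketch below rebuilds both masks on a toy sequence. It assumes that `compute_eos_token_mask` zeroes every position from the first EOS token onward and that `compute_context_repetition_mask` marks a position 0 when its preceding `ngram_len - 1` tokens have already appeared as a context; both are assumptions based on the method names and shape comments, not on the library's source. The helper names are hypothetical. Since this notebook passes `pad_token_id=tokenizer.eos_token_id`, responses that finish early are padded with EOS, so the EOS mask also discards that padding.

```python
import numpy as np


def eos_mask_sketch(tokens: np.ndarray, eos_id: int) -> np.ndarray:
  """1 before the first EOS token in each row, 0 from the first EOS onward."""
  is_eos = tokens == eos_id
  # The running count of EOS tokens is 0 only before the first EOS.
  return (np.cumsum(is_eos, axis=1) == 0).astype(np.int64)


def context_repetition_mask_sketch(
    tokens: np.ndarray, ngram_len: int
) -> np.ndarray:
  """1 where the preceding (ngram_len - 1)-token context is new, else 0.

  Output shape is [batch_size, seq_len - (ngram_len - 1)]; position j scores
  token j + ngram_len - 1 using the context tokens[j : j + ngram_len - 1].
  """
  batch_size, seq_len = tokens.shape
  ctx_len = ngram_len - 1
  out = np.zeros((batch_size, seq_len - ctx_len), dtype=np.int64)
  for b in range(batch_size):
    seen = set()  # Unbounded here; the library may cap its context history.
    for j in range(seq_len - ctx_len):
      context = tuple(tokens[b, j : j + ctx_len].tolist())
      out[b, j] = 0 if context in seen else 1
      seen.add(context)
  return out


ngram_len = 3
eos_id = 0
tokens = np.array([[5, 6, 7, 5, 6, 8, 0, 0]])
eos_mask = eos_mask_sketch(tokens, eos_id)[:, ngram_len - 1 :]
repetition_mask = context_repetition_mask_sketch(tokens, ngram_len)
combined_mask = repetition_mask * eos_mask
print(combined_mask)  # [[1 1 1 0 0 0]]
```

Both masks end up with 8 - (3 - 1) = 6 positions. The fourth position is dropped because the context `(5, 6)` repeats, and the last two because they score EOS tokens.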

# Mean detector

In [None]:
# Watermarked responses tend to have higher Mean scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates.

wm_mean_scores = detector_mean.mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_mean_scores = detector_mean.mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Mean scores for watermarked responses: ', wm_mean_scores)
print('Mean scores for unwatermarked responses: ', uwm_mean_scores)

# We find that the Weighted Mean scoring function gives better
# classification performance than the Mean scoring function (in particular,
# higher scores for watermarked responses). See the SynthID Text paper for
# full details.

wm_weighted_mean_scores = detector_mean.weighted_mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_weighted_mean_scores = detector_mean.weighted_mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print(
    'Weighted Mean scores for watermarked responses: ', wm_weighted_mean_scores
)
print(
    'Weighted Mean scores for unwatermarked responses: ',
    uwm_weighted_mean_scores,
)
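Two details that the comments above leave open can be sketched in NumPy. First, as I understand `detector_mean.weighted_mean_score`, it multiplies each watermarking layer's g-values by a weight (by default decreasing linearly from 10 to 1, rescaled to sum to the depth) before averaging; treat that default as an assumption. Second, one common way to choose the score threshold mentioned earlier, assumed here rather than prescribed by the library, is a quantile of scores from text known to be unwatermarked. Both helper names and all toy values are hypothetical.

```python
import numpy as np


def weighted_mean_score_sketch(g_values, mask, weights=None):
  """Mean score with a per-layer weight on the depth axis.

  Args:
    g_values: Shape [batch_size, seq_len, depth].
    mask: Shape [batch_size, seq_len]; 1 keeps a position, 0 ignores it.
    weights: Optional non-negative weights, shape [depth].
  """
  depth = g_values.shape[-1]
  if weights is None:
    # Assumed default: weights falling linearly from 10 to 1 across layers.
    weights = np.linspace(10.0, 1.0, num=depth)
  weights = np.asarray(weights, dtype=np.float64)
  # Rescale so the weights sum to `depth`; uniform weights then give Mean.
  weights = weights * depth / weights.sum()
  weighted_g = g_values * weights[None, None, :]
  num_unmasked = mask.sum(axis=1)
  kept_sum = (weighted_g * mask[:, :, None]).sum(axis=(1, 2))
  return kept_sum / (depth * num_unmasked)


def threshold_for_fpr(negative_scores, target_fpr):
  """Threshold exceeded by roughly `target_fpr` of known-negative scores."""
  return float(np.quantile(negative_scores, 1.0 - target_fpr))


g_values = np.array([
    [[1, 1], [1, 0], [0, 0]],
    [[0, 1], [0, 0], [1, 0]],
], dtype=np.float64)
mask = np.array([[1, 1, 0], [1, 1, 1]], dtype=np.float64)

uniform = weighted_mean_score_sketch(g_values, mask, weights=np.ones(2))
default_weighted = weighted_mean_score_sketch(g_values, mask)
threshold = threshold_for_fpr([0.40, 0.45, 0.50, 0.55, 0.60], target_fpr=0.2)
```

With uniform weights the weighted score reduces to the Mean score (0.75 and about 0.333 for this toy input), while the default weights raise response 0 to about 0.95 because its first layer, which gets the largest weight, is 1 at every kept position. A response would then be flagged when its score exceeds `threshold` (0.56 here). With only the four unwatermarked responses in this notebook, such a quantile would be very noisy; many more known-unwatermarked samples are needed for a usable estimate.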
