# SynthID Text: Watermarking for Generated Text

This notebook demonstrates how to use the [SynthID Text library][synthid-code]
to apply and detect watermarks on generated text. It is divided into three major
sections and intended to be run end-to-end.

1.  **_Setup_**: Importing the SynthID Text library, choosing your model
    ([GPT-2][gpt2], [Gemma][gemma], or another Hugging Face causal language
    model listed in `ModelName`, such as Llama, OLMo, or TinyLlama) and device
    (either CPU or GPU, depending on your runtime), defining the watermarking
    configuration, and initializing some helper functions.
1.  **_Applying a watermark_**: Loading your selected model using the
    [Hugging Face Transformers][transformers] library, using that model to
    generate some watermarked text, and comparing the perplexity of the
    watermarked text to that of text generated by the base model.
1.  **_Detecting a watermark_**: Training a detector to recognize text generated
    with a specific watermarking configuration, and then using that detector to
    predict whether a set of examples were generated with that configuration.

As the reference implementation for the
[SynthID Text paper in _Nature_][synthid-paper], this library and notebook are
intended for research review and reproduction only. They should not be used in
production systems. For a production-grade implementation, check out the
official SynthID logits processor in [Hugging Face Transformers][transformers].

[gemma]: https://ai.google.dev/gemma/docs/model_card
[gpt2]: https://huggingface.co/openai-community/gpt2
[synthid-code]: https://github.com/google-deepmind/synthid-text
[synthid-paper]: https://www.nature.com/
[transformers]: https://huggingface.co/docs/transformers/en/index

# 1. Setup

In [None]:
!pip install -q -r ../requirements.txt

In [1]:
import sys
import os

# Add the parent directory (which contains the local `src` package) to the import path.
sys.path.append(os.path.abspath(".."))

In [4]:
from collections.abc import Sequence
import enum
import gc

import datasets
import huggingface_hub

from src.synthid_text import detector_mean
from src.synthid_text import logits_processing
from src.synthid_text import synthid_mixin
from src.synthid_text import detector_bayesian
import tensorflow as tf
import torch
import tqdm
import transformers
import accelerate

In [5]:
class ModelName(enum.Enum):
  GPT2 = 'gpt2'
  GEMMA_2B = 'google/gemma-2b-it'
  GEMMA_7B = 'google/gemma-7b-it'
  OLMO = 'allenai/OLMo-1B-0724-hf'
  LLAMA = 'meta-llama/Llama-3.2-1B'
  TINY_LLAMA = 'TinyLlama/TinyLlama_v1.1'

model_name = 'meta-llama/Llama-3.2-1B'
MODEL_NAME = ModelName(model_name)

if MODEL_NAME is not ModelName.GPT2:
  huggingface_hub.notebook_login()

In [6]:
DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
DEVICE

device(type='cpu')

In [7]:
CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
CONFIG

immutabledict({'ngram_len': 5, 'keys': [654, 400, 836, 123, 340, 443, 597, 160, 57, 29, 590, 639, 13, 715, 468, 990, 966, 226, 324, 585, 118, 504, 421, 521, 129, 669, 732, 225, 90, 960], 'sampling_table_size': 65536, 'sampling_table_seed': 0, 'context_history_size': 1024, 'device': device(type='cpu')})
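The configuration above also fixes the shape of the per-token watermark statistics (g-values) computed later in `generate_responses`. The helper below is a hypothetical sketch of that shape arithmetic for outputs of the full `OUTPUTS_LEN` tokens. It is based on the shape comments in that function and assumes that the g-value depth equals the number of entries in `keys`.

```python
def g_values_shape(
    batch_size: int, output_len: int, ngram_len: int, num_keys: int
) -> tuple[int, int, int]:
  """Hypothetical helper: expected g-value shape for a batch of outputs.

  The first ngram_len - 1 tokens of each output lack a full n-gram context,
  so they get no g-values. The depth is assumed to equal the number of keys.
  """
  if output_len < ngram_len - 1:
    raise ValueError('output_len must be at least ngram_len - 1')
  return (batch_size, output_len - (ngram_len - 1), num_keys)


# Notebook defaults: BATCH_SIZE = 8, OUTPUTS_LEN = 1024, and CONFIG above
# has ngram_len = 5 and 30 keys.
print(g_values_shape(batch_size=8, output_len=1024, ngram_len=5, num_keys=30))
# (8, 1020, 30)
```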

In [8]:
BATCH_SIZE = 8
NUM_BATCHES = 320
OUTPUTS_LEN = 1024
TEMPERATURE = 0.5
TOP_K = 40
TOP_P = 0.99

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

logits_processor = logits_processing.SynthIDLogitsProcessor(
    **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
)

In [9]:
def load_model(
    model_name: ModelName,
    expected_device: torch.device,
    enable_watermarking: bool = False,
) -> transformers.PreTrainedModel:
  """Loads `model_name`, optionally with the SynthID watermarking mixin."""
  if model_name == ModelName.GPT2:
    model_cls = (
        synthid_mixin.SynthIDGPT2LMHeadModel
        if enable_watermarking
        else transformers.GPT2LMHeadModel
    )
    model = model_cls.from_pretrained(model_name.value, device_map='auto')
  elif model_name in (ModelName.GEMMA_2B, ModelName.GEMMA_7B):
    model_cls = (
        synthid_mixin.SynthIDGemmaForCausalLM
        if enable_watermarking
        else transformers.GemmaForCausalLM
    )
    model = model_cls.from_pretrained(
        model_name.value,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )
  else:
    # Other causal LMs (OLMo, Llama, TinyLlama) are loaded via the Auto classes.
    model_cls = (
        synthid_mixin.SynthIDAutoModelForCausalLM
        if enable_watermarking
        else transformers.AutoModelForCausalLM
    )
    model = model_cls.from_pretrained(
        model_name.value,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )

  if str(model.device) != str(expected_device):
    raise ValueError(
        f'Model loaded on {model.device}, expected {expected_device}.'
    )

  return model


def _process_raw_prompt(prompt: bytes) -> str:
  """Decodes a raw byte-string prompt and applies the chat template if needed."""
  if MODEL_NAME == ModelName.GPT2:
    return prompt.decode().strip('"')
  else:
    return tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt.decode().strip('"')}],
        tokenize=False,
        add_generation_prompt=True,
    )
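The watermarking `logits_processor` created in the setup is configured with `top_k=TOP_K` and `temperature=TEMPERATURE`. For intuition about these two knobs only, here is a generic, standard-library sketch of top-k filtering followed by temperature scaling. It is not the internal code of `SynthIDLogitsProcessor`, which additionally performs the watermarked sampling step.

```python
import math


def top_k_temperature_probs(
    logits: list[float], top_k: int, temperature: float
) -> list[float]:
  """Returns sampling probabilities after top-k filtering and temperature scaling."""
  if top_k < 1 or temperature <= 0:
    raise ValueError('top_k must be >= 1 and temperature must be > 0')
  # Indices of the top_k largest logits; every other token gets probability 0.
  keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
  scaled = {i: logits[i] / temperature for i in keep}
  # Subtract the max before exponentiating for numerical stability.
  max_scaled = max(scaled.values())
  weights = {i: math.exp(v - max_scaled) for i, v in scaled.items()}
  total = sum(weights.values())
  return [weights.get(i, 0.0) / total for i in range(len(logits))]


probs = top_k_temperature_probs([1.0, 2.0, 3.0], top_k=2, temperature=0.5)
print([round(p, 3) for p in probs])  # [0.0, 0.119, 0.881]
```

Lowering the temperature concentrates probability on the highest-scoring tokens: with `temperature=1.0` the same call gives roughly `[0.0, 0.269, 0.731]`.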

# 2. Applying a watermark

In [10]:
gc.collect()
torch.cuda.empty_cache()

batch_size = 1
example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]
example_inputs = example_inputs * (batch_size // len(example_inputs) + 1)
example_inputs = example_inputs[:batch_size]

inputs = tokenizer(
    example_inputs,
    return_tensors='pt',
    padding=True,
).to(DEVICE)
print("Loading model")
model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)
print("Model loaded")
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    max_length=1024,
    top_k=40,
    pad_token_id=tokenizer.eos_token_id,
)
print("Generation finished")
print('Output:\n' + 100 * '-')
for output in outputs:
  print(tokenizer.decode(output, skip_special_tokens=True))
  print(100 * '-')

del inputs, outputs, model
gc.collect()
torch.cuda.empty_cache()

In [None]:
gc.collect()
torch.cuda.empty_cache()

batch_size = 1
example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]
example_inputs = example_inputs * (batch_size // len(example_inputs) + 1)
example_inputs = example_inputs[:batch_size]

inputs = tokenizer(
    example_inputs,
    return_tensors='pt',
    padding=True,
).to(DEVICE)
print("Loading model")
model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=False)
torch.manual_seed(0)
print("Model loaded")
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    max_length=1024,
    top_k=40,
    pad_token_id=tokenizer.eos_token_id,
)
print("Generation finished")
print('Output:\n' + 100 * '-')
for output in outputs:
  print(tokenizer.decode(output, skip_special_tokens=True))
  print(100 * '-')

del inputs, outputs, model
gc.collect()
torch.cuda.empty_cache()
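Both generation cells above fill a batch of `batch_size` prompts by repeating the example prompts and slicing the result. An equivalent, hypothetical helper based on `itertools` makes that intent explicit and works for any number of prompts:

```python
import itertools
from collections.abc import Sequence


def fill_batch(prompts: Sequence[str], batch_size: int) -> list[str]:
  """Hypothetical helper: cycles through `prompts` until `batch_size` are taken."""
  if not prompts:
    raise ValueError('prompts must be non-empty')
  return list(itertools.islice(itertools.cycle(prompts), batch_size))


prompts = ['a', 'b', 'c', 'd']
print(fill_batch(prompts, 6))  # ['a', 'b', 'c', 'd', 'a', 'b']
```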

# 3. Detecting a watermark

To detect the watermark, you have two options:

1.  Use the simple **Mean** scoring function. It is fast and requires no
    training.
1.  Use the more powerful **Bayesian** scoring function. It requires training
    and takes more time.

For a full explanation of these scoring functions, see the paper and its
Supplementary Materials.
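As a rough illustration of the Mean scoring function, the pure-Python sketch below averages the g-values over every unmasked token position and every layer (key). It assumes g-values of shape `[batch_size, num_positions, depth]` and a 0/1 mask of shape `[batch_size, num_positions]`, matching the shapes produced in this notebook; `detector_mean.mean_score` remains the reference implementation.

```python
def mean_score(
    g_values: list[list[list[float]]], mask: list[list[int]]
) -> list[float]:
  """Per-sequence mean of g-values over unmasked positions and all layers."""
  scores = []
  for seq_g, seq_mask in zip(g_values, mask):
    num_unmasked = sum(seq_mask)
    if num_unmasked == 0:
      raise ValueError('each sequence needs at least one unmasked position')
    depth = len(seq_g[0])
    # Sum every layer's g-value, but only at positions where the mask is 1.
    total = sum(sum(g_row) for g_row, m in zip(seq_g, seq_mask) if m)
    scores.append(total / (depth * num_unmasked))
  return scores


g = [[[1, 1], [0, 1], [1, 0]]]  # one sequence, 3 positions, depth 2
m = [[1, 1, 0]]  # the last position is masked out
print(mean_score(g, m))  # [0.75]
```

Assuming binary g-values that are unbiased for unwatermarked text, unwatermarked responses should score near 0.5, while watermarked responses tend to score higher.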


In [None]:
NUM_NEGATIVES = 10000
POS_BATCH_SIZE = 32
NUM_POS_BATCHES = 313
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 200
NEG_TRUNCATION_LENGTH = 200
# Pad truncated outputs to this length for equal shape across all batches.
MAX_PADDED_LENGTH = 1000
TEMPERATURE = 1.0
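The truncation and padding constants above can be illustrated with a hypothetical helper. It cuts a per-token sequence (for example one row of a mask) to `truncation_length` and then pads it with `pad_value` up to `padded_length`, so every batch ends up with the same shape. Right-padding is an assumption of this sketch. Padding a mask with 0 keeps the padded positions out of the score.

```python
from collections.abc import Sequence
from typing import TypeVar

T = TypeVar('T')


def truncate_and_pad(
    row: Sequence[T], truncation_length: int, padded_length: int, pad_value: T
) -> list[T]:
  """Hypothetical helper: truncate `row`, then right-pad it to `padded_length`."""
  if truncation_length > padded_length:
    raise ValueError('truncation_length must not exceed padded_length')
  truncated = list(row[:truncation_length])
  return truncated + [pad_value] * (padded_length - len(truncated))


# Toy lengths; the notebook uses 200 and MAX_PADDED_LENGTH = 1000.
print(truncate_and_pad([1, 1, 1, 1], truncation_length=2, padded_length=5, pad_value=0))
# [1, 1, 0, 0, 0]
```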

In [None]:
def generate_responses(example_inputs, enable_watermarking):
  inputs = tokenizer(
      example_inputs,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)

  gc.collect()
  torch.cuda.empty_cache()

  model = load_model(
      MODEL_NAME,
      expected_device=DEVICE,
      enable_watermarking=enable_watermarking,
  )
  torch.manual_seed(0)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
      pad_token_id=tokenizer.eos_token_id,
  )

  outputs = outputs[:, inputs_len:]

  # EOS mask: 1 before the first EOS token, 0 from it onwards. Drop the first
  # ngram_len - 1 positions, which have no full n-gram context to score.
  # eos_token_mask shape: [batch_size, output_len - (ngram_len - 1)].
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs,
      eos_token_id=tokenizer.eos_token_id,
  )[:, CONFIG['ngram_len'] - 1 :]

  # Context repetition mask: 0 where the preceding (ngram_len - 1)-token
  # context has already appeared earlier in the sequence.
  context_repetition_mask = logits_processor.compute_context_repetition_mask(
      input_ids=outputs,
  )
  # context_repetition_mask shape: [batch_size, output_len - (ngram_len - 1)].

  combined_mask = context_repetition_mask * eos_token_mask

  g_values = logits_processor.compute_g_values(
      input_ids=outputs,
  )
  # g values shape [batch_size, output_len - (ngram_len - 1), depth]

  return g_values, combined_mask


example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]

wm_g_values, wm_mask = generate_responses(
    example_inputs, enable_watermarking=True
)
uwm_g_values, uwm_mask = generate_responses(
    example_inputs, enable_watermarking=False
)
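To make the two masks in `generate_responses` concrete, here is a simplified pure-Python sketch for a single sequence. It assumes that the EOS mask is 1 before the first EOS token and 0 from it onwards, and that a position is masked when its preceding `ngram_len - 1` tokens have already been seen. Unlike the library, which compares hashed contexts over a bounded history (`context_history_size`), this sketch keeps an unbounded set of exact contexts.

```python
def eos_mask(tokens: list[int], eos_id: int) -> list[int]:
  """1 for tokens before the first EOS token, 0 from the first EOS onwards."""
  mask = []
  seen_eos = False
  for token in tokens:
    seen_eos = seen_eos or token == eos_id
    mask.append(0 if seen_eos else 1)
  return mask


def context_repetition_mask(tokens: list[int], ngram_len: int) -> list[int]:
  """0 where the preceding (ngram_len - 1)-token context was already seen."""
  context_len = ngram_len - 1
  seen = set()
  mask = []
  # One entry per token that has a full context: len(tokens) - context_len.
  for t in range(context_len, len(tokens)):
    context = tuple(tokens[t - context_len : t])
    mask.append(0 if context in seen else 1)
    seen.add(context)
  return mask


def combined_mask(tokens: list[int], eos_id: int, ngram_len: int) -> list[int]:
  """Product of both masks, aligned by dropping the first ngram_len - 1 EOS entries."""
  eos = eos_mask(tokens, eos_id)[ngram_len - 1 :]
  repetition = context_repetition_mask(tokens, ngram_len)
  return [a * b for a, b in zip(eos, repetition)]


tokens = [5, 6, 5, 6, 5, 9, 1, 1]  # toy token ids; 1 plays the role of EOS
print(combined_mask(tokens, eos_id=1, ngram_len=2))  # [1, 1, 0, 0, 0, 0, 0]
```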

## Option 1: Mean detector

In [None]:
# Watermarked responses tend to have higher Mean scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates.

wm_mean_scores = detector_mean.mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_mean_scores = detector_mean.mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Mean scores for watermarked responses: ', wm_mean_scores)
print('Mean scores for unwatermarked responses: ', uwm_mean_scores)

# You may find that the Weighted Mean scoring function gives better
# classification performance than the Mean scoring function (in particular,
# higher scores for watermarked responses). See the paper for full details.

wm_weighted_mean_scores = detector_mean.weighted_mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_weighted_mean_scores = detector_mean.weighted_mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print(
    'Weighted Mean scores for watermarked responses: ', wm_weighted_mean_scores
)
print(
    'Weighted Mean scores for unwatermarked responses: ',
    uwm_weighted_mean_scores,
)
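The Weighted Mean score and the threshold choice described in the comments above can be sketched in the same pure-Python style. The weighting here is an assumption (weights falling linearly from 10 to 1 across layers, rescaled to sum to `depth`), so check `detector_mean.weighted_mean_score` for the actual defaults. The hypothetical `threshold_for_fpr` helper picks the threshold from the scores of known-unwatermarked responses so that at most a target fraction of them (the false positive rate) would be flagged as watermarked.

```python
import math


def linear_weights(depth: int, start: float = 10.0, stop: float = 1.0) -> list[float]:
  """Assumed layer weights: linear from `start` to `stop`, rescaled to sum to depth."""
  if depth == 1:
    raw = [start]
  else:
    raw = [start + i * (stop - start) / (depth - 1) for i in range(depth)]
  scale = depth / sum(raw)
  return [w * scale for w in raw]


def weighted_mean_score(
    g_values: list[list[list[float]]], mask: list[list[int]]
) -> list[float]:
  """Masked mean of g-values in which earlier layers count for more."""
  scores = []
  for seq_g, seq_mask in zip(g_values, mask):
    num_unmasked = sum(seq_mask)
    if num_unmasked == 0:
      raise ValueError('each sequence needs at least one unmasked position')
    depth = len(seq_g[0])
    weights = linear_weights(depth)
    total = sum(
        sum(w * g for w, g in zip(weights, g_row))
        for g_row, m in zip(seq_g, seq_mask)
        if m
    )
    scores.append(total / (depth * num_unmasked))
  return scores


def threshold_for_fpr(negative_scores: list[float], target_fpr: float) -> float:
  """Hypothetical helper: flag score > threshold, flagging <= target_fpr of negatives."""
  if not negative_scores or not 0 <= target_fpr < 1:
    raise ValueError('need negative scores and 0 <= target_fpr < 1')
  ranked = sorted(negative_scores)
  # How many known-negative scores may lie strictly above the threshold.
  allowed = math.floor(target_fpr * len(ranked))
  return ranked[len(ranked) - 1 - allowed]


print([round(s, 3) for s in weighted_mean_score([[[1, 0]]], [[1]])])  # [0.909]
# Toy negative scores 0.1, 0.2, ..., 1.0 and a 10% target false positive rate.
print(threshold_for_fpr([i / 10 for i in range(1, 11)], target_fpr=0.1))  # 0.9
```

For the same single-position input, the unweighted Mean score would be 0.5, so this weighting rewards a g-value of 1 in the first layer more than one in the last.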