# Experiment Data Collection Notebook

This notebook generates g-values, masks, and perplexities and saves them to disk for later interpretation. This is what we use for our experiments with the Frequentist scoring function.

# 1. Setup

In [None]:
# @title Install and import the required Python packages
#
# @markdown Running this cell may require you to restart your session.

! pip install synthid-text[notebook]

from collections.abc import Sequence
import enum
import gc

import datasets
import huggingface_hub
from synthid_text import detector_mean
from synthid_text import logits_processing
from synthid_text import synthid_mixin
from synthid_text import detector_bayesian
import tensorflow as tf
import torch
import tqdm
import transformers

In [None]:
#@title Login to Hugging Face Hub
# Interactive login; presumably needed so the later Hub downloads (model weights,
# tokenizer, dataset) are authenticated — confirm whether MODEL_NAME is gated.
huggingface_hub.notebook_login()

In [3]:
# @title Choose your model.

# @markdown Edit this cell to set the pre-trained model name (the Hub repo the weights are loaded from) and the model class. A generic AutoModel class will likely not work.
MODEL_NAME = 'meta-llama/Llama-3.1-8B-Instruct'
# Concrete causal-LM class; load_model() subclasses it with the SynthID mixin
# when watermarking is enabled.
MODEL_CLASS = transformers.LlamaForCausalLM

In [None]:
# @title Configure your device
#
# @markdown It's important that your model fits in GPU memory. If you use a very small model like GPT-2, you might be able to get away with using CPU. We used an A100 for most experiments.

# Prefer the first CUDA device; fall back to CPU when no GPU is visible.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEVICE

In [None]:
# @title Example watermarking config
#
# @markdown We use the default configuration, with m=30 tournament layers, sliding window size H=4, and a context history of 1024 tokens for context masking.

# Binds the library default dict itself (not a copy); the cells in this
# notebook only read from it.
CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
CONFIG

In [None]:
# @title Initialize the required constants, tokenizer, and logits processor

# Sampling hyperparameters -- feel free to experiment with these.
OUTPUTS_LEN = 1024  # Maximum number of new tokens generated per prompt.
TEMPERATURE = 0.5
TOP_K = 40
TOP_P = 0.99  # Only passed to generate(); the SynthID processor below takes top_k/temperature.

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token and pad on the left so each prompt in a batch
# ends right where generation begins.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Used after generation to recompute the EOS / context-repetition masks and the
# g-values, with the same top_k and temperature used for sampling.
logits_processor = logits_processing.SynthIDLogitsProcessor(
    temperature=TEMPERATURE,
    top_k=TOP_K,
    **CONFIG,
)

In [7]:
# @title Utility functions to load models, compute perplexity, and process prompts.

def load_model(
    model_name: str,
    expected_device: torch.device,
    enable_watermarking: bool = False,
) -> transformers.PreTrainedModel:
  """Loads a bfloat16 model, optionally with SynthID watermarking mixed in.

  Args:
    model_name: Name passed to `from_pretrained`.
    expected_device: Device the loaded model must report via `model.device`.
    enable_watermarking: When True, loads through a subclass of `MODEL_CLASS`
      that also inherits `synthid_mixin.SynthIDSparseTopKMixin`; otherwise
      `MODEL_CLASS` is used directly.

  Returns:
    The loaded model (placed with `device_map='auto'`).

  Raises:
    ValueError: If `model.device` does not match `expected_device`.
  """
  class SynthIDModelClass(synthid_mixin.SynthIDSparseTopKMixin, MODEL_CLASS):
    pass

  if enable_watermarking:
    model_cls = SynthIDModelClass
  else:
    model_cls = MODEL_CLASS

  model = model_cls.from_pretrained(
      model_name,
      torch_dtype=torch.bfloat16,
      device_map='auto',
  )

  # Compare string forms so a torch.device and the model's reported device
  # are matched by name.
  device_matches = str(model.device) == str(expected_device)
  if not device_matches:
    raise ValueError('Model device not as expected.')
  return model


def _compute_perplexity(
    outputs: torch.LongTensor,
    scores: Sequence[torch.Tensor],
    eos_token_mask: torch.LongTensor,
    watermarked: bool = False,
) -> torch.Tensor:
  """Returns the batch-summed mean per-token negative log-likelihood.

  Despite the name, no exponentiation is applied. For each row, the NLL of the
  generated tokens is averaged over the positions where `eos_token_mask` is
  non-zero. The per-row means are then summed over the batch. Divide by the
  batch size and/or exponentiate downstream to obtain a perplexity.

  Args:
    outputs: Full sequences (prompt + continuation), shape
      (batch, prompt_len + num_steps). The last `len(scores)` columns are the
      generated tokens.
    scores: One entry per generation step. When `watermarked` is False, each
      entry is a (batch, vocab) logits tensor. When True, each entry is used
      directly as a per-token NLL and is expected to have shape (batch, 1).
      TODO confirm against the SynthID mixin.
    eos_token_mask: (batch, num_steps) mask, non-zero for the tokens to score.
    watermarked: Whether `scores` already holds per-token NLLs.

  Returns:
    A 0-d tensor: the sum over the batch of each row's mean NLL.
  """
  num_steps = len(scores)
  if watermarked:
    # The watermarking model's step scores are consumed as per-token NLLs.
    token_nlls = scores
  else:
    # log_softmax is the numerically stable form of log(softmax(x)).
    token_nlls = [
        torch.gather(
            -torch.log_softmax(step_logits, dim=1),
            1,
            outputs[:, -num_steps + step, None],
        )
        for step, step_logits in enumerate(scores)
    ]
  # Stack to (batch, num_steps, 1), then squeeze to (batch, num_steps).
  nll = torch.squeeze(torch.stack(token_nlls, dim=1), dim=2)
  # Masked-out positions may hold +inf (e.g. a token filtered to a -inf logit).
  # inf * 0 = nan, which nan_to_num maps to 0; posinf=0 zeroes any remaining inf.
  nll = torch.nan_to_num(nll * eos_token_mask.long(), posinf=0)
  num_scored = eos_token_mask.sum(dim=1)
  # A row with nothing to score (e.g. EOS as the very first generated token)
  # would give 0 / 0 = nan and poison the batch sum. Count it as 0 instead.
  row_means = nll.sum(dim=1) / num_scored.clamp(min=1)
  return row_means.sum(dim=0)


def _process_raw_prompt(prompt: 'str | bytes') -> str:
  """Wraps a raw prompt in the tokenizer's chat template as a user turn.

  Args:
    prompt: The raw prompt text, as `str` or as bytes (decoded with the
      default UTF-8 codec). Surrounding double quotes are stripped.

  Returns:
    The templated prompt string (not tokenized), with the generation prompt
    appended (`add_generation_prompt=True`).
  """
  # Callers historically pass encoded bytes; plain strings are accepted too.
  if not isinstance(prompt, str):
    prompt = prompt.decode()
  return tokenizer.apply_chat_template(
      [{'role': 'user', 'content': prompt.strip('"')}],
      tokenize=False,
      add_generation_prompt=True,
  )

In [None]:
#@title Load ELI5 dataset

# The generation cell reads the 'title' column of this dataset's 'test' split
# as prompts.
eli5_prompts = datasets.load_dataset("Pavithree/eli5")

In [None]:
#@title Generate g-values and perplexities on ELI5 test

#@markdown We use 100 samples from ELI5 (NUM_BATCHES x BATCH_SIZE), using the largest batch size that fits in memory. You will likely need to tune these.
NUM_BATCHES = 20 # @param {"type":"integer"}
BATCH_SIZE = 5 # @param {"type":"integer"}
ENABLE_WATERMARKING = True # @param {"type":"boolean"}

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=ENABLE_WATERMARKING)
torch.manual_seed(0)

# Fetch the prompt column once instead of re-reading it on every batch.
eli5_test_titles = eli5_prompts['test']['title']

eli5_g_values = []
eli5_combined_mask = []
eli5_perplexities = []
for batch_id in tqdm.tqdm(range(NUM_BATCHES)):
  prompts = eli5_test_titles[batch_id * BATCH_SIZE:(batch_id + 1) * BATCH_SIZE]
  prompts = [_process_raw_prompt(prompt.encode()) for prompt in prompts]
  inputs = tokenizer(
      prompts,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
      return_dict_in_generate=True,
      output_scores=True,
      pad_token_id=tokenizer.eos_token_id,
  )

  # One entry per generated step. Without watermarking, each entry is a
  # (batch, vocab) logits tensor, so this can be large; it is freed below.
  scores = outputs.scores
  outputs = outputs.sequences
  # Continuation only (prompt stripped); the masks and g-values cover these tokens.
  generated_ids = outputs[:, inputs_len:]

  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=generated_ids,
      eos_token_id=tokenizer.eos_token_id,
  )

  eli5_perplexities.append(_compute_perplexity(outputs, scores, eos_token_mask, watermarked=ENABLE_WATERMARKING).cpu())

  # Drop the first ngram_len - 1 positions. This presumably aligns the mask with
  # the context-repetition mask and g-values, which (given the product below)
  # are shorter by that amount.
  eos_token_mask = eos_token_mask[:, CONFIG['ngram_len'] - 1 :]

  context_repetition_mask = logits_processor.compute_context_repetition_mask(
      input_ids=generated_ids,
  )

  combined_mask = context_repetition_mask * eos_token_mask

  g_values = logits_processor.compute_g_values(
      input_ids=generated_ids,
  )

  eli5_g_values.append(g_values.cpu())
  eli5_combined_mask.append(combined_mask.cpu())

  # Free this batch's device tensors, including `scores`, before the next
  # generate() call. Otherwise the previous batch's step logits stay alive while
  # the next batch is generated. Then release cached allocator memory every batch.
  del inputs, prompts, scores, generated_ids, eos_token_mask, context_repetition_mask, combined_mask, g_values, outputs
  gc.collect()
  torch.cuda.empty_cache()

def cat(l):
  """Right-pads tensors along dim 1 to a shared length, then concatenates on dim 0.

  3-D tensors (the g-values) are padded with `tokenizer.eos_token_id`. All other
  tensors (the masks) are padded with `False`, i.e. zero.

  NOTE(review): padding g-values with a token id is only harmless if every
  consumer applies the padded (zero) mask. Confirm for downstream scoring.

  Args:
    l: Non-empty list of tensors that agree on every dim except dim 1.

  Returns:
    One tensor whose dim 0 is the sum of the inputs' dim-0 sizes and whose
    dim 1 is the largest input dim-1 size.
  """
  target_len = max(t.shape[1] for t in l)
  padded = []
  for t in l:
    shortfall = target_len - t.shape[1]
    if len(t.shape) == 3:
      # F.pad takes (left, right) pairs starting from the LAST dim: the leading
      # (0, 0) leaves the last dim alone and (0, shortfall) pads dim 1.
      pad_spec, fill_value = (0, 0, 0, shortfall), tokenizer.eos_token_id
    else:
      pad_spec, fill_value = (0, shortfall), False
    padded.append(
        torch.nn.functional.pad(t, pad_spec, mode="constant", value=fill_value)
    )
  return torch.cat(padded)

# Per-batch tensors may differ in length along dim 1, so right-pad them to a
# common length before concatenating into one tensor per artifact.
padded_eli5_g_values = cat(eli5_g_values)
padded_eli5_combined_mask = cat(eli5_combined_mask)


In [None]:
#@title Save results to files
F_MODEL_NAME = MODEL_NAME.replace("/","_")
# Every artifact shares one run tag: model, sampling temperature, and whether
# the text was watermarked ('wm') or not ('uwm').
RUN_TAG = f"{F_MODEL_NAME}_t={TEMPERATURE}_{'wm' if ENABLE_WATERMARKING else 'uwm'}"
torch.save(padded_eli5_g_values, f"eli5_g_values_{RUN_TAG}.pt")
torch.save(padded_eli5_combined_mask, f"eli5_combined_mask_{RUN_TAG}.pt")
torch.save(eli5_perplexities, f"eli5_perplexities_{RUN_TAG}.pt")