# LLM
Set up a basic LLM

In [1]:
# Install the `transformers` package that the next cell imports GPT2Tokenizer and GPT2LMHeadModel from.
!pip install transformers



In [3]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel #GPT2Model
import torch
import numpy as np

ModuleNotFoundError: No module named 'transformers'

In [None]:
# Initialize the tokenizer and model from the pretrained 'gpt2-medium' checkpoint.
# Both names are module-level globals that later cells read directly.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
# Get the maximum token ID.
# Assumes token ids run from 0 to vocab_size - 1 -- TODO confirm for this tokenizer.
vocab_size = tokenizer.vocab_size
max_token_id = vocab_size - 1

In [None]:
def get_next_word_distribution(input_text, tokenizer, model):
  """
  Return the model's probability distribution over the next token.

  Args:
      input_text (str): Text whose continuation is being predicted.
      tokenizer: Callable invoked as ``tokenizer(input_text, return_tensors='pt')``;
          its result is unpacked as keyword arguments into ``model``.
      model: Callable whose return value exposes a ``logits`` attribute.

  Returns:
      torch.Tensor: The softmax of the logits at the final sequence position of
      the first batch entry, i.e. one probability per vocabulary entry.
  """
  encoded_input = tokenizer(input_text, return_tensors='pt')

  # Inference only: skip autograd bookkeeping so no graph is kept alive and the
  # returned tensor can be modified in place without involving gradients.
  with torch.no_grad():
    output = model(**encoded_input)

  # logits[0] is the first sequence in the batch; [-1] is its last position
  # (a sequence position, not a model layer). Softmax is applied per position,
  # so normalising only this row gives the same result as normalising every
  # row and then selecting the last one.
  next_token_logits = output.logits[0][-1]
  return torch.softmax(next_token_logits, dim=-1)

In [None]:
def sample_token(probabilities):
  """
  Draw one token id from `probabilities` and return its decoded text.

  The id is drawn with torch.multinomial and decoded with the module-level
  `tokenizer`.
  """
  # Draw a single index according to the given weights.
  drawn = torch.multinomial(probabilities, num_samples=1)
  token_id = drawn.item()

  # Turn the sampled id back into text.
  return tokenizer.decode(token_id)

In [None]:
def get_next_word(input_text, tokenizer, model):
  """
  Sample the next token for `input_text` and return it as text.

  Args:
      input_text (str): Text whose continuation is being predicted.
      tokenizer: Callable invoked as ``tokenizer(input_text, return_tensors='pt')``;
          must also provide ``decode(token_id)``.
      model: Callable whose return value exposes a ``logits`` attribute.

  Returns:
      The result of ``tokenizer.decode`` on a token id sampled from the softmax
      of the logits at the final sequence position.
  """
  encoded_input = tokenizer(input_text, return_tensors='pt')

  # Inference only: no autograd graph is needed for sampling.
  with torch.no_grad():
    output = model(**encoded_input)

  # First batch entry, last sequence position; softmax is per position, so only
  # this row has to be normalised.
  probabilities = torch.softmax(output.logits[0][-1], dim=-1)

  #! this is where we would perturb the distribution

  # Sample a token id from the distribution and decode it to text.
  index = torch.multinomial(probabilities, 1).item()
  return tokenizer.decode(index)

# Cryptography
Set up and test cryptography tools
Need:
- prf

In [None]:
# Packages for the cryptography cells below: pynacl provides `nacl`, pycryptodome
# provides `Crypto`, bitstring provides `BitArray`.
# NOTE(review): the import cell also imports from `cryptography` (hazmat), which is
# not installed here -- presumably already present in the environment; confirm.
# NOTE(review): libnacl is installed but not imported by the import cell that follows.
!pip install pynacl
!pip install libnacl
!pip install pycryptodome
!pip install bitstring

In [None]:
import random
from bitstring import BitArray
import nacl.encoding
import nacl.hash
import nacl.secret
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
import hashlib
import math

In [None]:
#! Need the output of the prf to be at least the vocab_size
def prf(key, data):
    """
    Derive `vocab_size` pseudorandom bytes from `key` and `data`.

    `data` is padded to the AES block size and encrypted with AES-CBC under
    `key` using an all-zero 16-byte IV. The ciphertext is then stretched by
    hashing ciphertext || 8-byte big-endian counter with SHA-256 for
    counter = 0, 1, 2, ... until at least `vocab_size` bytes exist; the first
    `vocab_size` bytes are returned. `vocab_size` is the module-level global.
    """
    zero_iv = bytes(16)
    seed = AES.new(key, AES.MODE_CBC, zero_iv).encrypt(pad(data, AES.block_size))

    chunks = []
    produced = 0
    block_index = 0
    while produced < vocab_size:
        # Each round contributes one SHA-256 digest of seed || counter.
        chunk = hashlib.sha256(seed + block_index.to_bytes(8, 'big')).digest()
        chunks.append(chunk)
        produced += len(chunk)
        block_index += 1

    return b''.join(chunks)[:vocab_size]

In [None]:
def int_list_to_bytes(int_list):
    """
    Serialise a list of integers as one fixed-width, big-endian byte string.

    Every integer is written with the same width: the number of bytes needed
    to represent the module-level `max_token_id`. The encodings are
    concatenated in list order.
    """
    # Width (in bytes) large enough for the largest token id.
    width = (max_token_id.bit_length() + 7) // 8

    encoded = bytearray()
    for value in int_list:
        encoded += value.to_bytes(width, 'big')

    return bytes(encoded)

In [None]:
# c is the number of tokens to use
def PRF(key, salt, n_gram, c):
  """
  Compute one pseudorandom bit per vocabulary entry from the last `c` tokens.

  Args:
      key: Key passed through to `prf`.
      salt (int): Integer prepended to the token ids before hashing.
      n_gram (str): Text to tokenize with the module-level `tokenizer`.
      c (int): Number of trailing token ids of `n_gram` to use; must be >= 0.

  Returns:
      list[int]: `vocab_size` bits (0/1), as produced by `truncate_to_vocab_size`.

  Raises:
      ValueError: If `c` is negative.
  """
  if c < 0:
    raise ValueError("c must be non-negative, got {}".format(c))

  encoded_text = tokenizer(n_gram, return_tensors='pt')
  full_gram = encoded_text["input_ids"].tolist()[0]
  # full_gram[-0:] would be the whole list, so c == 0 is handled explicitly to
  # mean "use no context tokens".
  c_gram = full_gram[-c:] if c > 0 else []
  salted_bytes = [salt] + c_gram
  encoded_bytes = int_list_to_bytes(salted_bytes)

  # nacl.hash.sha256 hex-encodes its output by default, hence 64 bytes.
  digest = nacl.hash.sha256(encoded_bytes) # 64 bytes
  return truncate_to_vocab_size(prf(key, digest), vocab_size)



Change to counter mode and have the counter be the salt?

Need to look at how to use the output of the prf to perturb the distribution or color tokens

**We want to truncate the prf output to l bits where the vocab is l words. Then you have a 1 or 0 for each token.**

In [None]:
def truncate_to_vocab_size(data, vocab_size):
  """
  Truncate the output of a PRF to the size of the vocabulary (1 bit per vocab word).

  Args:
      data (bytes): PRF output.
      vocab_size (int): Number of leading bits to keep.

  Returns:
      list[int]: The first `vocab_size` bits of `data` as 0/1 integers, most
      significant bit of each byte first. Shorter if `data` holds fewer bits.
  """
  # Only the first ceil(vocab_size / 8) bytes can contribute to the result, so
  # the rest is dropped before expanding to bits. A negative vocab_size keeps
  # ordinary slice semantics and therefore needs every byte.
  if vocab_size >= 0:
    data = data[:(vocab_size + 7) // 8]

  # Expand each byte into its 8 bits, MSB first.
  bit_array = [(byte >> shift) & 1 for byte in bytes(data) for shift in range(7, -1, -1)]

  return bit_array[:vocab_size]

# truncate_to_vocab_size(prf(digest), vocab_size)

In [None]:
def perturb(p, r, delta):
  """
  Implements the perturb function from our paper.

  Args:
      p: Probability distribution (list or tensor), indexed by token id.
      r: PRF output, one 0/1 entry per token id.
      delta: Amount added to each eligible token whose r-bit is 1.

  Returns:
      `p` itself, updated in place to the perturbed distribution p'.

  Only indices in I = {i : p_i in [2*delta, 1 - 2*delta]} are modified. Within
  I, tokens with r_i == 1 gain `delta` and the remaining tokens each lose
  delta' = delta * w / (N' - w), where N' = |I| and w is the number of indices
  in I with r_i == 1. Counting w and N' over I (rather than over the whole
  vocabulary) makes the mass added equal the mass removed, so p' still sums
  to the same total as p.

  NOTE(review): N' = |I| is inferred from the sum-to-one requirement; confirm
  against the paper's definition.

  If the removal cannot be carried out -- nothing in I has r_i == 0, or delta'
  exceeds one of the probabilities that would be lowered -- p is returned
  unchanged so that it stays a valid distribution.
  NOTE(review): this fallback is a conservative choice, not taken from the paper.
  """
  # One bulk conversion to Python floats; avoids per-element tensor operations
  # while scanning the whole vocabulary.
  values = p.tolist() if hasattr(p, 'tolist') else list(p)

  # Build I: indices for which p_i is in [2*delta, 1 - 2*delta].
  lower, upper = 2 * delta, 1 - 2 * delta
  I = [i for i, p_i in enumerate(values) if lower <= p_i <= upper]

  raised = [j for j in I if r[j] == 1]
  lowered = [j for j in I if r[j] != 1]

  w = len(raised)
  if w == 0 or not lowered:
    # Nothing to add, or nothing to take the compensating mass from.
    return p

  delta_prime = (delta * w) / len(lowered)
  if any(values[j] < delta_prime for j in lowered):
    # Subtracting delta' would produce a negative probability.
    return p

  # Adjust probabilities in place; indices outside I are left untouched.
  for j in raised:
    p[j] += delta
  for j in lowered:
    p[j] -= delta_prime

  return p

# Encode

In [None]:
#! Lior is changing the sampling algorithm to be a random key each time instead of rotating through the keys.
#! This has the benefit of the attacker not being able to predict which key was used for each word that's sampled.
#! It does not break the system because the receiver runs the sum over all words, the attacker might have reordered or changed words
#! (so the receiver does not know which key was used with each word anyway), and in expectation enough words will be sampled with respect to each
#! key so that that key's respective watermark will be present in the overall text.
def encode(keys, h, m, delta, c):
  """
  Generate text that continues the history `h` while embedding message `m`.

  Args:
      keys: One key per message bit.
      h (list[str]): Conversation history; joined without a separator to form
          the starting text. Its length is used as the salt.
      m: Message bits; keys whose bit equals 1 are used for watermarking.
      delta: Hardness parameter passed to `perturb`.
      c: Number of prior tokens passed to `PRF`.

  Returns:
      str: The joined history followed by the sampled tokens.

  If `m` selects no key (no bit equals 1), tokens are sampled from the
  unperturbed distribution instead of failing while sampling a key from an
  empty list.
  """
  # get keys that correspond to 1s in message
  watermarking_keys = get_keys_to_use(m, keys)
  # get loop limit that is large enough that in expectation, each watermarking key is used enough
  limit = get_limit(len(watermarking_keys))
  # Compute the salt s from the history h
  s = len(h)

  #! for now, concatenate the entire history together as the starting text
  text = ''.join(h) #! what separator should we use?

  for j in range(limit):
    # Pick this step's key at random; None when the message selects no key.
    i = sample_key(watermarking_keys) if watermarking_keys else None
    # Apply the language model over previous tokens to get a probability distribution p over the tth token
    p = get_next_word_distribution(text, tokenizer, model)

    if i is not None:
      # compute r
      #! should only feed in the c prior tokens, not all of text
      r = PRF(i, s, text, c)
      print('i, s, text, c: ', i[0], s, text, c)
      print('j, i, r: ', j, i[0], r)
      p = perturb(p, r, delta)

    # sample next token with the (possibly perturbed) distribution
    token = sample_token(p)
    text += token

  return text

In [None]:
def get_limit(num_watermarks):
  """
  Return how many tokens `encode` should generate.

  `num_watermarks` is accepted but currently ignored: the result is a fixed
  placeholder value.
  """
  #! temporary constant
  fixed_limit = 20
  return fixed_limit

def get_keys_to_use(m, keys):
  """
  Return the keys whose corresponding entry in `m` equals 1.

  Keys and message entries are paired positionally; pairing stops at the
  shorter of the two sequences. Order is preserved.
  """
  selected = []
  for key, flag in zip(keys, m):
    if flag == 1:
      selected.append(key)
  return selected

def sample_key(keys):
  """
  Return one element of `keys`, chosen with np.random.randint over its indices.
  """
  return keys[np.random.randint(len(keys))]

In [None]:
def decode(keys, h, ct, z, c):
  """
  Count, for every key, how many generated tokens of `ct` have PRF bit 1.

  Args:
      keys: One key per message bit.
      h (list[str]): Conversation history; joined without a separator, it must
          be reproduced exactly by decoding a leading run of `ct`'s tokens.
      ct (str): Stegotext (history followed by generated tokens).
      z: Accepted but not used by this function.
      c: Number of prior tokens passed to `PRF`.

  Returns:
      list[int]: One counter per key.

  Raises:
      ValueError: If no leading run of tokens of `ct` decodes to the joined
          history, so the start of the generated text cannot be located.
  """
  # Compute the salt s from the history h
  s = len(h)
  # initialize counters for each bit in m (seeing if the threshold is crossed for that bit)
  counters = [0 for _ in range(len(keys))]

  # tokenize stegotext
  tokens = tokenizer(ct, return_tensors='pt')['input_ids'][0] #! do we need to start the tokens later to line up with h too?
  text = ''
  idx = 0
  goal = ''.join(h)
  # Consume leading tokens until the decoded prefix equals the history.
  while text != goal:
    # Once the prefix stops being a prefix of the history, or the tokens run
    # out, appending more tokens can never reach `goal`; fail with a clear
    # message instead of indexing past the end of `tokens`.
    if idx >= len(tokens) or not goal.startswith(text):
      raise ValueError(
        "stegotext tokens do not align with the history: no token prefix of ct decodes to ''.join(h)"
      )
    text += tokenizer.decode(tokens[idx])
    idx = idx + 1
  tokens = tokens[idx:]
  # text needs to start at h (encode's text starts at h)
  # when making a back and forth system, one side will probably have to call this function with h -1

  # for each token in stegotext
  for j in range(len(tokens)):
    # test each key against the context that preceded token j
    for i, key in enumerate(keys):
      r = PRF(key, s, text, c)
      print('i, s, text, c: ', i, s, text, c)
      print('j, i, r: ', j, i, r)
      current_token_index = tokens[j]
      if (r[current_token_index] == 1):
        counters[i] += 1
    # extend the context only after every key has been tested
    text += tokenizer.decode(tokens[j])

  print(counters)

  return counters

In [None]:
# End-to-end demo: encode a message into generated text, then run decode on it.
# Alternative message lengths (each needs a matching `keys` list below):
# m = [1, 0, 1]
# m = [1, 0]
m = [1]
# NOTE(review): `l` is not used anywhere else in this cell.
l = len(m)
# One 32-byte key per message bit.
# keys = [b'\0' * 32, b'\1' * 32, b'\2' * 32]
# keys = [b'\0' * 32, b'\1' * 32]
keys = [b'\0' * 32]
# Conversation history; encode/decode join it and use its length as the salt.
h = ['Hey! Want to get coffee?', 'Would Friday work for you?']
# Number of prior tokens fed to PRF.
c = 5

# keys = [get_random_bytes(32), get_random_bytes(32), get_random_bytes(32)]
delta = 0.05
ct = encode(keys, h, m, delta, c)
print('-------------------')
# decode returns one counter per key (not thresholded bits); `z` is passed as None.
recovered_message = decode(keys, h, ct, None, c)

print(ct)
print(recovered_message)

In [None]:
# Dump the run's parameters and the generated text, one value per line.
print(keys, h, m, delta, c, ct, sep='\n')

In [None]:
# T is the number of tokens generated
# s_g is the number of green list tokens
def detect(T, s_g, threshold=4):
  """
  Decide whether the green-token count is significantly above T / 2.

  Computes z = 2 * (s_g - T / 2) / sqrt(T), prints it, and compares it with
  `threshold`.

  Args:
      T: Number of tokens generated; must be positive.
      s_g: Number of green list tokens among them.
      threshold: Value z must exceed for a positive result (default 4).

  Returns:
      bool: True if z > threshold.

  Raises:
      ValueError: If T is not positive (the statistic is undefined).
  """
  if T <= 0:
    raise ValueError("T must be a positive number of tokens, got {}".format(T))

  z = (2 * (s_g - T / 2)) / math.sqrt(T)
  print(z)
  return z > threshold

In [None]:
# Example: all 40 tokens are green, so z = 2 * (40 - 20) / sqrt(40) ≈ 6.32, which is above the threshold of 4.
detect(40, 40)

In [None]:
# get_limit ignores its argument and returns the temporary constant 20.
get_limit(None)