In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import hashlib
import scipy.stats
from math import sqrt
from transformers import AutoModelForCausalLM, AutoTokenizer

class WatermarkForCode:
    def __init__(self, model, tokenizer, entropy_threshold, gamma, delta):
        """
        Initialize the SWEET (Selective WatErmarking via Entropy Thresholding) watermarker.

        Args:
        - model: The language model used for text generation (and, for detection,
          to recompute per-token entropies)
        - tokenizer: The tokenizer corresponding to the model
        - entropy_threshold: Threshold, in bits, for high-entropy tokens (τ in the algorithm)
        - gamma: Proportion of vocabulary to be considered as "green" tokens (γ in the
          algorithm); must be strictly between 0 and 1 because the detection z-test
          divides by gamma * (1 - gamma)
        - delta: Logit increase for green tokens (δ in the algorithm)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.entropy_threshold = entropy_threshold
        self.gamma = gamma
        self.delta = delta

    def calculate_entropy(self, probs):
        """Calculate the entropy (in bits) of a probability distribution.

        Sums over every element of ``probs``, so pass a single distribution. The
        1e-10 term keeps log2 finite for zero-probability entries.
        """
        return -torch.sum(probs * torch.log2(probs + 1e-10))

    def _green_list(self, context, vocab_size):
        """Return the pseudo-random green token ids keyed on the preceding token(s).

        Generation and detection both call this, so they derive identical green lists.
        A one-token context is hashed as ``str(token)``; a longer context is hashed as
        ``str(list_of_ints)``. A local ``RandomState`` produces the same draw as
        ``np.random.seed(s); np.random.choice(...)`` without mutating NumPy's
        global RNG state.
        """
        context = [int(tok) for tok in context]
        key = str(context[0]) if len(context) == 1 else str(context)
        seed = int(hashlib.sha256(key.encode()).hexdigest(), 16) % (2**32)
        green_size = int(self.gamma * vocab_size)
        return np.random.RandomState(seed).choice(vocab_size, green_size, replace=False)

    def generate_with_watermark(self, prompt, max_length=200):
        """
        Generate text with SWEET watermarking applied.

        This method implements Algorithm 1 from the SWEET paper. At each step the
        entropy of the model's next-token distribution is measured. Only when it
        exceeds ``entropy_threshold`` is ``delta`` added to the logits of the green
        list seeded by the previous token.

        Args:
        - prompt: text to continue
        - max_length: maximum number of new tokens to sample (stops early at EOS)

        Returns:
        - generated_text: decoded prompt + continuation
        - generated_tokens: token strings for prompt + continuation
        - watermarked_tokens: one bool per sampled token, True when the watermark
          bias was applied at that step
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        generated_ids = input_ids.clone()
        watermarked_tokens = []
        past = None

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, past_key_values=past)
                logits = outputs.logits[:, -1, :]
                past = outputs.past_key_values

            # Entropy is measured on the unbiased distribution, before watermarking.
            entropy = self.calculate_entropy(F.softmax(logits, dim=-1))

            if entropy > self.entropy_threshold:
                prev_token = generated_ids[0, -1].item()
                green_indices = self._green_list([prev_token], logits.shape[-1])
                logits[0, green_indices] += self.delta
                watermarked_tokens.append(True)
            else:
                watermarked_tokens.append(False)

            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

            # Append to generated sequence; with the KV cache only the new token is fed next.
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            input_ids = next_token

            if next_token.item() == self.tokenizer.eos_token_id:
                break

        generated_text = self.tokenizer.decode(generated_ids[0])
        generated_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids[0])

        return generated_text, generated_tokens, watermarked_tokens

    def detect_watermark(self, text, n_grams=2):
        """
        Detect the presence of a SWEET watermark in the given text.

        Detection mirrors generation:
        - The model's next-token entropy is recomputed at every position.
        - Only tokens whose entropy exceeds ``entropy_threshold`` are counted. The
          low-entropy tokens were never watermarked and would only dilute the test.
        - Each counted token is checked against the green list seeded by its
          ``n_grams - 1`` preceding tokens, using the model's output vocabulary
          size, exactly as in generation.

        The default ``n_grams=2`` (one context token) is the context
        ``generate_with_watermark`` uses.

        Args:
        - text: text to test (re-tokenized with ``self.tokenizer``)
        - n_grams: n-gram size, i.e. the scored token plus ``n_grams - 1`` context tokens

        Returns:
        - z_score: one-proportion z statistic of the green fraction against gamma
          (0.0 when no token was counted)
        - p_value: one-sided p-value P(Z >= z_score) (1.0 when no token was counted)
        - observed_wl_frac: fraction of counted tokens that are green (nan when none)
        - gamma: expected green fraction under the no-watermark hypothesis

        Raises:
        - ValueError: if ``n_grams < 2`` (no context token to seed the green list)
        """
        if n_grams < 2:
            raise ValueError("n_grams must be >= 2 (at least one context token)")

        tokens = self.tokenizer.encode(text)
        context_width = n_grams - 1
        green_counts = []

        if len(tokens) > context_width:
            # One forward pass: logits[j - 1] is the distribution the model assigns to tokens[j].
            with torch.no_grad():
                logits = self.model(input_ids=torch.tensor([tokens])).logits[0]
            vocab_size = logits.shape[-1]  # same size generation draws green lists from

            for j in range(context_width, len(tokens)):
                entropy = self.calculate_entropy(F.softmax(logits[j - 1], dim=-1))
                if entropy > self.entropy_threshold:
                    green_set = set(
                        self._green_list(tokens[j - context_width:j], vocab_size).tolist()
                    )
                    green_counts.append(1 if int(tokens[j]) in green_set else 0)

        # Total number of tokens checked
        T = len(green_counts)
        if T == 0:
            # Nothing was testable, so there is no evidence of a watermark (avoids division by zero).
            return 0.0, 1.0, float('nan'), self.gamma

        # Observed green ratio (observed watermark fraction) and its one-proportion z-score.
        observed_wl_frac = float(np.mean(green_counts))
        z_score = (observed_wl_frac - self.gamma) / sqrt(self.gamma * (1 - self.gamma) / T)

        # One-sided test: only an excess of green tokens is evidence of a watermark.
        p_value = float(scipy.stats.norm.sf(z_score))

        return z_score, p_value, observed_wl_frac, self.gamma

    def interpret_watermark_detection(self, z_score, p_value, observed_ratio, expected_ratio, significance_level=0.05):
        """
        Interpret the results of watermark detection.

        The text is reported as watermarked only when the observed green ratio
        exceeds the expected ratio AND ``p_value`` is below ``significance_level``.
        A higher ratio alone, without significance, is not evidence of a watermark.

        Returns a human-readable summary string including the statistics.
        """
        details = (
            f"(z-score: {z_score:.4f}, p-value: {p_value:.4f}, "
            f"observed ratio: {observed_ratio:.4f}, expected ratio: {expected_ratio:.4f})"
        )
        if observed_ratio > expected_ratio and p_value < significance_level:
            return f"The text is likely watermarked {details}"
        return f"The text is not conclusively watermarked {details}"

# Example usage
model_name = "Salesforce/codegen-350M-mono"  # Code-specific model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Initialize SWEET watermarker with parameters tuned for code generation
watermarker = WatermarkForCode(
    model=model,
    tokenizer=tokenizer,
    entropy_threshold=3.0,  # Entropy threshold τ (bits) for selective watermarking
    gamma=0.4,  # Proportion γ of vocabulary to be considered as green tokens
    # NOTE(review): δ=0.1 scales green-token odds by only exp(0.1) ≈ 1.1, so
    # detection power will be low — confirm this value is intended.
    delta=0.1,  # Logit increase δ for green tokens
)

# Generate watermarked code
prompt = "def factorial(n):\n    # Implement Factorial function\n    "
generated_text, generated_tokens, watermarked_tokens = watermarker.generate_with_watermark(prompt, max_length=200)

print("Prompt:")
print(prompt)
print("\nGenerated Text:")
print(generated_text)

n_watermarked = sum(watermarked_tokens)
n_sampled = len(watermarked_tokens)
print("\nWatermarking Statistics:")
print(f"{n_watermarked} out of {n_sampled} tokens watermarked")
# max(..., 1) guards against an empty generation (e.g. max_length=0).
print(f"Watermarking rate: {n_watermarked / max(n_sampled, 1):.2%}")

# generated_tokens starts with the prompt tokens; skip them so the remaining
# tokens line up one-to-one with watermarked_tokens.
prompt_length = len(tokenizer.encode(prompt))
print("\nWatermarked Tokens:")
for token, is_watermarked in zip(generated_tokens[prompt_length:], watermarked_tokens):
    if is_watermarked:
        print(tokenizer.convert_tokens_to_string([token]))

# Detect and interpret watermark
z_score, p_value, observed_ratio, expected_ratio = watermarker.detect_watermark(generated_text)
print("\nWatermark Detection:")
print(watermarker.interpret_watermark_detection(z_score, p_value, observed_ratio, expected_ratio))

Prompt:
def factorial(n):
    # Implement Factorial function
    

Generated Text:
def factorial(n):
    # Implement Factorial function
    # Be careful
    if n<0:
        raise ValueError

  
    
    

def is_palindrome(number):
    is_positive = True
    number = str(number)
        
    ### Initializing the loop variables i,j where i is a negative number
    # We always begin it with 0 at the first iteration
    # i = 0
    i = 0
    j = 0
    
          
    # The condition corresponding to 4 without $ signs
    if len(number)==1: 
        is_one = True
    else:
          
        j=len(number)-1
          
        while j>=0 or is_positive:
            is_number = True
            if j>=0 and is_negative and number[j]=='$':
                number = number[0:j]+number[j+1:]

Watermarking Statistics:
49 out of 200 tokens watermarked
Watermarking rate: 24.50%

Watermarked Tokens:
#
 Be








def
 is
pal
number
number
###
 Initial
izing
 the
 loop
 variables
 where
 i
 is
 a
 neg