# **DetectGPT**: Identifying AI-generated text
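This notebook implements the **DetectGPT** method from Mitchell et al. (2023) [1], which helps determine whether a given text is AI-generated. The approach perturbs the text with a mask-filling model (T5) and compares the log-probability of the original text with the average log-probability of its perturbations: text sampled from a model tends to lie near a local maximum of that model's log-probability, so perturbing it lowers the log-probability more than perturbing human-written text does.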

In [1]:
import re
import json
import torch
import random
import matplotlib.pyplot as plt
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## I- **Model setup**

This section sets up the transformer-based models needed to:
1. produce the AI-generated text - ``generation_model``
2. compute the log-probabilities - ``computation_model``
3. perturb the text with the T5 perturbation - ``t5_model``

### 1. *Text generation*

In [None]:
CACHE_DIR = "/tmp/huggingface"
GENERATION_MODEL_NAME = "EleutherAI/gpt-j-6B"

# Model list (all tested)
# gpt2
# gpt2-large
# EleutherAI/gpt-j-6B
# EleutherAI/gpt-neox-20b

# Load model (same download cache as the computation and perturbation models)
generation_model = AutoModelForCausalLM.from_pretrained(GENERATION_MODEL_NAME, cache_dir=CACHE_DIR)

# Load tokenizer
generation_tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL_NAME, cache_dir=CACHE_DIR)

# Set model to evaluation mode (disables dropout)
generation_model.eval()

generation_model.to(device)
print(device)

  return torch.load(checkpoint_file, map_location=map_location)


KeyboardInterrupt: 

### 2. *Computation*

In [None]:
CACHE_DIR = "/tmp/huggingface"
COMPUTATION_MODEL_NAME = "openai-community/gpt2-large"

# float16 halves GPU memory; on CPU use float32, since many PyTorch CPU kernels
# have historically lacked float16 support. (torch.bfloat16 is an alternative on recent GPUs.)
COMPUTATION_DTYPE = torch.float16 if device == "cuda" else torch.float32

# Load model
computation_model = AutoModelForCausalLM.from_pretrained(
    COMPUTATION_MODEL_NAME, torch_dtype=COMPUTATION_DTYPE, cache_dir=CACHE_DIR
)

# Load tokenizer; reuse EOS as the padding token so batches can be padded
computation_tokenizer = AutoTokenizer.from_pretrained(COMPUTATION_MODEL_NAME, cache_dir=CACHE_DIR)
computation_tokenizer.pad_token = computation_tokenizer.eos_token

# Set model to evaluation mode (ensures stable log prob estimation + disables dropout)
computation_model.eval()

computation_model.to(device)
print(device)

### 3. *Perturbation*

In [None]:
CACHE_DIR = "/tmp/huggingface"
PERTURBATION_MODEL_NAME = "t5-large"

# float16 halves GPU memory; on CPU use float32, since many PyTorch CPU kernels
# have historically lacked float16 support.
PERTURBATION_DTYPE = torch.float16 if device == "cuda" else torch.float32

# Load model
t5_model = T5ForConditionalGeneration.from_pretrained(
    PERTURBATION_MODEL_NAME, torch_dtype=PERTURBATION_DTYPE, cache_dir=CACHE_DIR
)

# Load tokenizer
t5_tokenizer = T5Tokenizer.from_pretrained(PERTURBATION_MODEL_NAME, cache_dir=CACHE_DIR)

# Set to evaluation mode (disables dropout)
t5_model.eval()

t5_model.to(device)
print(device)

## II- **Code setup**

### 1. *T5 perturbation*

In [None]:
def batch_mask_text(texts, mask_ratio=0.15, max_words=370):
    """Replace random two-word spans in each text with T5 sentinel tokens.

    Each selected span start becomes ``<extra_id_i>`` (numbered in order) and
    the following word is removed.

    Args:
        texts (list[str]): texts to mask. Must be a list of strings; a bare
            string would be iterated character by character.
        mask_ratio (float): fraction of the (truncated) words used as span starts.
        max_words (int): each text is truncated to at most this many words first.

    Returns:
        tuple[list[str], list[list[int]]]: the masked texts, and for each text
        the sorted word indices at which a masked span starts.
    """
    masked_texts = []
    mask_indices_list = []

    for text in texts:
        words = text.split()[:max_words]
        num_masks = int(len(words) * mask_ratio)

        # Span starts exclude the last word, so idx + 1 is always a valid index
        mask_indices = sorted(random.sample(range(len(words) - 1), num_masks))
        mask_indices_list.append(mask_indices)

        for i, idx in enumerate(mask_indices):
            words[idx] = f"<extra_id_{i}>"
            words[idx + 1] = ""  # second word of the 2-word span is dropped

        # Skip dropped words so only single spaces remain: a double space would
        # survive into the perturbed text and change how the scoring model tokenizes it.
        masked_texts.append(" ".join(word for word in words if word))

    return masked_texts, mask_indices_list

def batch_replace_masks(texts, batch_size=8, max_length=150, top_p=0.9):
    """Sample T5 span fills for masked texts, in batches.

    Args:
        texts (list[str]): masked texts containing ``<extra_id_N>`` sentinels.
        batch_size (int): number of texts per ``generate`` call.
        max_length (int): maximum length (in tokens) of each generated sequence.
        top_p (float): nucleus-sampling threshold passed to ``generate``.

    Returns:
        list[str]: raw decoded T5 outputs (special tokens kept), in input order.
    """
    all_outputs = []

    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]

        # Generation stops at the sentinel that follows the last mask of the
        # most-masked text in the batch.
        n_expected = max(text.count("<extra_id_") for text in batch_texts)
        stop_id = t5_tokenizer.encode(f"<extra_id_{n_expected}>")[0]

        tokens = t5_tokenizer(batch_texts, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = t5_model.generate(
                input_ids=tokens["input_ids"].to(t5_model.device),
                attention_mask=tokens["attention_mask"].to(t5_model.device),
                max_length=max_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=stop_id,
            )

        # Move outputs back to CPU before decoding to free GPU memory
        all_outputs.extend(t5_tokenizer.batch_decode(outputs.cpu(), skip_special_tokens=False))

    return all_outputs

def batch_extract_fills(texts):
    """Extract the generated span fills from raw T5 outputs.

    Args:
        texts (list[str]): raw decoded T5 outputs, e.g.
            ``"<pad> <extra_id_0> fill one <extra_id_1> fill two <extra_id_2></s>"``.

    Returns:
        list[list[str]]: for each output, the stripped text following each
        sentinel, in order of appearance (text after a trailing stop sentinel is "").
    """
    # DOTALL lets a fill span a line break; without it such a fill fails to
    # match, is silently dropped, and every later fill shifts into the wrong slot.
    fill_pattern = re.compile(r"<extra_id_\d+>\s*(.*?)\s*(?=<extra_id_\d+>|$)", re.DOTALL)

    extracted_fills = []
    for text in texts:
        text = text.replace("<pad>", "").replace("</s>", "").strip()
        extracted_fills.append([fill.strip() for fill in fill_pattern.findall(text)])

    return extracted_fills

def batch_apply_extracted_fills(masked_texts, extracted_fills):
    """Replace mask tokens in the masked texts with generated fills.

    The i-th fill of a text replaces its ``<extra_id_i>`` sentinel. Whitespace in
    the result is collapsed to single spaces.

    Args:
        masked_texts (list[str]): texts containing ``<extra_id_N>`` sentinels.
        extracted_fills (list[list[str]]): fills per text, in sentinel order.

    Returns:
        list[str]: the filled texts (pairs beyond the shorter input are ignored).
    """
    filled_texts = []

    for masked_text, fills in zip(masked_texts, extracted_fills):
        filled_text = masked_text
        for i, fill in enumerate(fills):
            filled_text = filled_text.replace(f"<extra_id_{i}>", fill, 1)

        # Empty fills (or blanked span words) leave runs of spaces; collapse them
        # so the scoring model does not see tokenization artifacts.
        # NOTE(review): sentinels without a matching fill are left in the text.
        filled_texts.append(" ".join(filled_text.split()))

    return filled_texts



In [None]:
def t5_perturbation(text: str) -> str:
    """
    T5 perturbation of a single text, built on the batch helpers.

    Args:
        text (str): the input text to be perturbed

    Returns:
        str: the perturbed text
    """
    # Step 1: Mask. The batch helper expects a list; a bare string would be
    # iterated character by character.
    masked_texts, _ = batch_mask_text([text])

    # Step 2: Generate replacements
    raw_fills = batch_replace_masks(masked_texts)

    # Step 3: Extract fills
    extracted_fills = batch_extract_fills(raw_fills)

    # Step 4: Apply fills
    return batch_apply_extracted_fills(masked_texts, extracted_fills)[0]

### 2. *Main function*

In [None]:
def batch_average_log_prob(texts, batch_size=8):
    """Calculate the average per-token log probability of each text, in batches.

    Args:
        texts (list[str]): texts scored with ``computation_model``.
        batch_size (int): number of texts per forward pass.

    Returns:
        list[float]: one value per text, in input order. Each value is the mean
        log probability of the text's real (non-padding) tokens after the first.
        Texts with no such token get 0.0.
    """
    all_log_probs = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]

        inputs = computation_tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # No `labels`: the model would then return a single loss averaged over
        # the whole batch (padding included), giving every text the same score.
        with torch.no_grad():
            logits = computation_model(input_ids, attention_mask=attention_mask).logits

        # Position t predicts token t + 1
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]

        # Score a target only if both it and its context position are real
        # tokens, whichever side the tokenizer pads on.
        target_mask = (attention_mask[:, 1:] * attention_mask[:, :-1]).float()

        token_nll = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        ).view(shift_labels.shape).float()  # accumulate in float32 even for fp16 models

        n_tokens = target_mask.sum(dim=1)
        summed_log_prob = -(token_nll * target_mask).sum(dim=1)
        avg_log_prob = torch.where(
            n_tokens > 0,
            summed_log_prob / n_tokens.clamp(min=1),
            torch.zeros_like(summed_log_prob),
        )
        all_log_probs.extend(avg_log_prob.tolist())

    return all_log_probs

In [None]:
# TODO: include correct and max_length option

# Main optimized processing loop
def optimized_processing(texts, num_samples=50, num_perturbations=25, batch_size=8, max_length=50):
    """Score texts and their T5 perturbations with the computation model.

    Args:
        texts (list): records to score. Each is either a plain string (as
            returned by ``generate_dataset``) or a dict with a ``"text"`` key
            (as loaded from the JSONL files).
        num_samples (int): number of texts, from the start of ``texts``, to
            process; capped at ``len(texts)``.
        num_perturbations (int): number of perturbed copies scored per text.
        batch_size (int): batch size for T5 generation and for scoring.
        max_length (int): each text is truncated to this many words first.

    Returns:
        tuple[list[float], list[list[float]]]: the base log probability of each
        text, and for each text the list of its perturbed log probabilities.
    """
    num_samples = min(num_samples, len(texts))

    original_texts = []
    for j in range(num_samples):
        record = texts[j]
        raw_text = record if isinstance(record, str) else record["text"]
        original_texts.append(" ".join(raw_text.split()[:max_length]))

    base_log_probs = batch_average_log_prob(original_texts, batch_size)
    log_probs_per_text_transformed = [[] for _ in range(num_samples)]

    # Each round perturbs every text once, so the batch helpers see all texts together
    for _ in range(num_perturbations):
        masked_texts = batch_mask_text(original_texts)[0]
        raw_fills = batch_replace_masks(masked_texts, batch_size)
        extracted_fills = batch_extract_fills(raw_fills)
        perturbed_texts = batch_apply_extracted_fills(masked_texts, extracted_fills)

        perturbed_log_probs = batch_average_log_prob(perturbed_texts, batch_size)
        for per_text, log_prob in zip(log_probs_per_text_transformed, perturbed_log_probs):
            per_text.append(log_prob)

    return base_log_probs, log_probs_per_text_transformed

In [None]:
# TODO: add saving option

In [None]:
def compute_detectgpt_discrepancy(log_probs_per_text_base, log_probs_per_text_transformed, normalize=False):
    """
    Compute the DetectGPT discrepancy metric for each text.

    The discrepancy is the original log probability minus the mean log
    probability of the text's perturbations. DetectGPT treats larger values as
    evidence of machine-generated text.

    Args:
        log_probs_per_text_base (list): original log probability of each text
        log_probs_per_text_transformed (list): one list per text, holding the
            log probabilities of its perturbed versions
        normalize (bool): if True, divide each discrepancy by the (population)
            standard deviation of that text's perturbed log probabilities; left
            unscaled when that deviation is 0

    Returns:
        discrepancy_scores (list): discrepancy value (d) for each text

    Raises:
        ValueError: if a text has no perturbed log probabilities
    """
    discrepancy_scores = []

    for i, original_log_prob in enumerate(log_probs_per_text_base):
        perturbed_log_probs = log_probs_per_text_transformed[i]
        num_perturbations = len(perturbed_log_probs)
        if num_perturbations == 0:
            raise ValueError(f"text {i} has no perturbed log probabilities")

        # Mean log probability of the perturbed texts
        mu = sum(perturbed_log_probs) / num_perturbations
        d = original_log_prob - mu

        if normalize:
            variance = sum((lp - mu) ** 2 for lp in perturbed_log_probs) / num_perturbations
            sigma = variance ** 0.5
            if sigma > 0:
                d /= sigma

        discrepancy_scores.append(d)

    return discrepancy_scores

### 3. *Utility functions*

In [None]:
# Memory management utilities
def clear_cuda_cache():
    """Release unused cached GPU memory held by PyTorch; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# Add caching for tokenization
@lru_cache(maxsize=1024)
def cached_tokenize(text, is_t5=False):
    """Tokenize ``text`` with the T5 or the computation tokenizer, memoizing the result.

    NOTE(review): the same returned encoding object is shared by every caller
    that hits the cache, so callers should not mutate it.
    """
    if not is_t5:
        return computation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    return t5_tokenizer(text, return_tensors="pt", padding=True)

## III- **Data loading**

### 1. *Human text*

In [None]:
file_path = "subtaskB_train.jsonl"
with open(file_path, "r", encoding="utf-8") as file:
    # Parse each non-blank line once, then keep only human-written records
    records = (json.loads(line) for line in file if line.strip())
    data_human = [record for record in records if record.get("model") == "human"]

# Print first 3 records
for record in data_human[:3]:
    print(record)

### 2. *AI-generated text*

In [None]:
# NOTE(review): the AI data path was left empty; this assumes the machine-generated
# records are in the same JSONL file as the human ones — confirm.
file_path = "subtaskB_train.jsonl"
with open(file_path, "r", encoding="utf-8") as file:
    # Parse each non-blank line once, then keep records labelled with a model other than "human"
    records = (json.loads(line) for line in file if line.strip())
    data_ai = [record for record in records if record.get("model") not in (None, "human")]

# Print first 3 records
for record in data_ai[:3]:
    print(record)

In [None]:
def generate_text(prompt: str, max_length: int) -> str:
    """
    Generate AI text from a given prompt.

    Args:
        prompt (str): Prompt to generate text.
        max_length (int): Value passed as ``max_length`` to ``generate``; it
            bounds the prompt tokens plus the generated tokens.

    Returns:
        str: The generated continuation (prompt excluded), with whitespace
        collapsed, escaped quotes unescaped and a leading/trailing double quote removed.
    """
    # Inputs must be on the same device as the model, which was moved to `device` at load time
    inputs = generation_tokenizer(prompt, return_tensors="pt").to(generation_model.device)
    with torch.no_grad():
        output = generation_model.generate(**inputs, max_length=max_length, do_sample=True, temperature=0.7)

    # Decode only the newly generated tokens, so the prompt is removed exactly
    # even when decoding would not reproduce the prompt string verbatim.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = generation_tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    # Clean up formatting
    cleaned_text = " ".join(generated_text.split())  # collapse whitespace (also trims the ends)
    cleaned_text = cleaned_text.replace('\\"', '"')  # Fix escaped quotes
    cleaned_text = cleaned_text.replace("\\'", "'")  # Fix escaped single quotes

    # Remove leading and trailing quotes if they exist
    if cleaned_text.startswith('"'):
        cleaned_text = cleaned_text[1:].strip()
    if cleaned_text.endswith('"'):
        cleaned_text = cleaned_text[:-1].strip()

    return cleaned_text


def generate_dataset(N: int, max_length: int, prompt: str = "Write a random excerpt from an unpublished novel:") -> list:
    """
    Generates a dataset of N AI-generated texts

    Args:
        N (int): number of AI-generated texts
        max_length (int): maximum length of each generation, forwarded to ``generate_text``
        prompt (str): prompt used for every generation

    Returns:
        data_ai (list): dataset - list of N generated texts (plain strings)
    """
    return [generate_text(prompt, max_length) for _ in range(N)]

In [None]:
# TODO: add saving option

In [None]:
# Generate a small AI-written corpus with the generation model
N = 5
max_length = 100
data_ai = generate_dataset(N=N, max_length=max_length)

In [None]:
# Display the generated samples as this cell's output
data_ai

## IV- **Example usage**

In [None]:
# Score the first `num_samples` human texts and their T5 perturbations,
# then turn the scores into one DetectGPT discrepancy per text.
texts = data_human
num_samples = 100
num_perturbations = 25
batch_size = 128

log_probs_base, log_probs_transformed = optimized_processing(
    texts,
    num_samples=num_samples,
    num_perturbations=num_perturbations,
    batch_size=batch_size,
)
# TODO: add saving option
discrepancy_scores = compute_detectgpt_discrepancy(log_probs_base, log_probs_transformed)