# **DetectGPT**: Identifying AI-generated text
This notebook implements the DetectGPT algorithm from Mitchell et al. (2023) [1], which helps determine whether a given text is AI-generated. The approach involves perturbing the text and analyzing its log probabilities.

In [83]:
import re
import torch
import random
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer

## I- Code setup

### 1. Model loading

In [84]:
device = "cuda" if torch.cuda.is_available() else "cpu"

Load the model for text generation and probabilities/perplexity computation 

In [85]:
MODEL_NAME = "gpt2-large" 

# Model list (all tested)
# gpt2
# gpt2-large
# EleutherAI/gpt-j-6B
# EleutherAI/gpt-neox-20b

# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Load tokenizer 
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set to evaluation mode
model.eval()

model.to(device)
print(device)

cpu


Load the model for perturbation implementation

In [86]:
PERTURBATION_MODEL_NAME = "t5-large"

# Load model
t5_model = T5ForConditionalGeneration.from_pretrained(PERTURBATION_MODEL_NAME,torch_dtype=torch.float16,device_map="auto")

# Load tokenizer 
t5_tokenizer = T5Tokenizer.from_pretrained("t5-large")

# Set to evaluation mode
t5_model.eval()

model.to(device)
print(device)

cpu


### 2. Useful functions

In [65]:
def generate_text(prompt: str, max_length: int) -> str:
    """
    Output AI-generated text using the chosen model

    Args:
        prompt (str): prompt to generate text
        max_length (int): the maximum length (~ number of words) of the generated text

    Returns:
        str: generated text
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_length=max_length, do_sample=True, temperature=0.7)
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [66]:
def compute_log_prob(text: str) -> float:
    """
    Compute the log prob of a given text under the chosen model

    Args:
        text (str): input text for which to compute the log prob

    Returns:
        float: input text log prob
    """
    tokens = tokenizer(text, return_tensors="pt")
    input_ids = tokens["input_ids"]
    n_tokens = input_ids.shape[1]
    
    with torch.no_grad():
        outputs = model(**tokens, labels=input_ids)
        # negative of the NLL per token = log prob
        log_prob = -outputs.loss.item() * n_tokens # total NLL
    return log_prob

In [67]:
def compute_perplexity(text: str) -> float:
    """
    Compute the perplexity score of a given text using the chosen model

    Args:
        text (str): input text for which to compute perplexity

    Returns:
        float: text perplexity score 
    """
    tokens = tokenizer(text, return_tensors="pt")
    input_ids = tokens["input_ids"]

    with torch.no_grad():
        outputs = model(**tokens, labels=input_ids)
        log_prob = outputs.loss # NLL per token
        
    perplexity = torch.exp(log_prob) if log_prob < 100 else float("inf") # overflow possible

    return perplexity

### 3. Perturbation
We define a perturbation function that slightly modifies the text. This allows us to analyze how variations in text influence log probabilities. Different perturbation methods can be used, such as:
- `word_swap_perturbation`: basic word swap function
- `t5_perturbation`: a more sophisticated perturbation method using a transformer

#### Word swap perturbation

In [69]:
def word_swap_perturbation(text: str) -> str:
    """
    Randomly swaps two adjacent words to create a perturbed version of the text

    Args:
        text (str): the input text to be perturbed

    Returns:
        str: the perturbed text 
    """
    words = text.split()
    if len(words) > 3:
        i = random.randint(0, len(words)-2)
        words[i], words[i + 1] = words[i + 1], words[i]
    return " ".join(words)

#### T5 perturbation

In [70]:
def mask_text(text, mask_ratio=0.15, max_words=370):
    words = text.split()

    # Truncate text
    if len(words) > max_words:
        words = words[:max_words]

    num_masks = int(len(words) * mask_ratio)

    # Randomly select spans to mask
    mask_indices = sorted(random.sample(range(len(words) - 1), num_masks))
    for i, idx in enumerate(mask_indices):
        words[idx] = f"<extra_id_{i}>"
        if idx + 1 < len(words): # Ensure a 2-word span
            words[idx + 1] = ""

    return " ".join(words)

def replace_masks(texts):
    """Generate T5 model outputs for masked texts."""
    n_expected = [text.count("<extra_id_") for text in texts]
    stop_id = t5_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]

    tokens = t5_tokenizer(texts, return_tensors="pt", padding=True)

    # Move input tensors to model's device just before passing to model
    with torch.no_grad():
        outputs = t5_model.generate(
            input_ids=tokens["input_ids"].to(t5_model.device),  
            attention_mask=tokens["attention_mask"].to(t5_model.device),  
            max_length=150,
            do_sample=True,
            top_p=0.9,
            num_return_sequences=1,
            eos_token_id=stop_id
        )
        
    outputs = outputs.detach().cpu() # Move tensors to CPU and detach

    return t5_tokenizer.batch_decode(outputs, skip_special_tokens=False)

def extract_fills(texts):
    """Extract the generated fills from T5's output."""
    extracted_fills = []
    for text in texts:
        text = text.replace("<pad>", "").replace("</s>", "").strip()
        
        # Use regex to extract text inside <extra_id_X> tokens
        fills = re.findall(r"<extra_id_\d+>\s*(.*?)\s*(?=<extra_id_\d+>|$)", text)

        # Clean extracted tokens
        extracted_fills.append([fill.strip() for fill in fills])

    return extracted_fills

def apply_extracted_fills(masked_texts, extracted_fills):
    """Replace mask tokens in the masked texts with generated fills."""
    filled_texts = []
    
    for masked_text, fills in zip(masked_texts, extracted_fills):
        if not fills:
            filled_texts.append(masked_text)
            continue

        # Iterate through expected mask positions and replace them
        for i, fill in enumerate(fills):
            masked_text = masked_text.replace(f"<extra_id_{i}>", fill, 1)

        filled_texts.append(masked_text)

    return filled_texts

In [71]:
def t5_perturbation(text: str) -> str:
    """
    T5 perturbation

    Args:
        text (str): the input text to be perturbed

    Returns:
        str: the perturbed text 
    """
    masked_text = mask_text(text)
    raw_fills = replace_masks([masked_text])
    extracted_fills = extract_fills(raw_fills)
    perturbed_text = apply_extracted_fills([masked_text], extracted_fills)[0]
    return perturbed_text

### 4. Paper algo
The DetectGPT algorithm works by computing:
- The mean log probability of perturbed texts
- The difference between the original text's log probability and the mean perturbed probability
The final score indicates whether the text is likely AI-generated.

In [None]:
def detectgpt_score(text: str, num_perturbations: int, perturbation_function) -> bool:
    """
    Implement DetectGPT algorithm 1

    Args:
        text (str): input text to be analyzed
        num_perturbations (int): number of perturbed versions of the text to generate
        perturbation_function (function): function to generate perturbed versions of the text (word swap or T5)

    Returns:
        bool: true if the text is likely model-generated
    """
    original_log_prob = compute_log_prob(text) # log prob of the original text

    # Generate perturbed texts + compute their log prob
    perturbed_texts = [perturbation_function(text) for _ in range(num_perturbations)]
    perturbed_log_probs = [compute_log_prob(pt) for pt in perturbed_texts]

    mu = sum(perturbed_log_probs) / num_perturbations # mean log probability of the perturbed texts

    d = original_log_prob - mu # estimate perturbation discrepancy d

    variance = sum((log_prob - mu) ** 2 for log_prob in perturbed_log_probs) / (num_perturbations - 1) # variance of the log probabilities
    sigma = variance ** 0.5 # standard deviation

    score = d / sigma if sigma > 0 else 0 

    return score

## II-Experiments

### 1. Simple use

In [73]:
# Example usage

max_length = 60
prompt = "In a faraway galaxy, where no humans exist"

ai_text = generate_text(prompt,max_length)

# Human text from CNN
human_text = "But Bhaduri found it increasingly hard to secure work after more women began partaking in jatra productions in the 1960s and 1970s. By the time he met Kishore, who was running a theater publication at the time, the actor was in his 60s and only performing a handful of times a year for the equivalent of $1 a night."

print("AI text:", ai_text)
print("human text:", human_text)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


AI text: In a faraway galaxy, where no humans exist, the world is in a terrible state. The space station, which was originally built to serve the planet, has broken down and the crew is trapped in the ship. The only hope for survival is a young girl known as "Sue" who
human text: But Bhaduri found it increasingly hard to secure work after more women began partaking in jatra productions in the 1960s and 1970s. By the time he met Kishore, who was running a theater publication at the time, the actor was in his 60s and only performing a handful of times a year for the equivalent of $1 a night.


In [None]:
num_perturbations = 100
perturbation = t5_perturbation

ai_text_score = detectgpt_score(ai_text,num_perturbations,perturbation)
human_text_score = detectgpt_score(human_text,num_perturbations,perturbation)

print("AI text DetectGPT score:", ai_text_score)
print("Human text DetectGPT score:", human_text_score)

AI text DetectGPT score: 5.055410139882069
Human text DetectGPT score: 3.9946030108610695


### 2. Perplexity score

In [77]:
# Compute perplexity for AI-generated text and human text
perturbation = t5_perturbation

max_length = 60
prompt = "In a faraway galaxy, where no humans exist"
ai_text = generate_text(prompt, max_length)

human_text = "But Bhaduri found it increasingly hard to secure work after more women began partaking in jatra productions in the 1960s and 1970s. By the time he met Kishore, who was running a theater publication at the time, the actor was in his 60s and only performing a handful of times a year for the equivalent of $1 a night."

# Compute perplexity for AI-generated text
perplexity_ai = compute_perplexity(ai_text)
perplexity_ai_perturbed = compute_perplexity(perturbation(ai_text))

# Compute perplexity for human-written text
perplexity_human = compute_perplexity(human_text)
perplexity_human_perturbed = compute_perplexity(perturbation(human_text))

print(f"Perplexity AI: {perplexity_ai}")
print(f"Perplexity AI (perturbed): {perplexity_ai_perturbed}")
print(f"Perplexity human: {perplexity_human}")
print(f"Perplexity human (perturbed): {perplexity_human_perturbed}")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Perplexity AI: 11.486611366271973
Perplexity AI (perturbed): 54.88127517700195
Perplexity human: 22.94808578491211
Perplexity human (perturbed): 42.53069305419922


In [None]:
# Scatter plot analysis: perplexity before and after perturbation for a wide variety of sample

perturbation = t5_perturbation

num_samples = 100
max_length = 60
prompts = ["In a faraway galaxy, where no humans exist" for _ in range(num_samples)]

# Generate AI-generated texts based on the prompts
ai_texts = [generate_text(prompt, max_length) for prompt in prompts]

# Compute perplexity before and after perturbation
perplexities_before = [compute_perplexity(text) for text in ai_texts]
perplexities_after = [compute_perplexity(perturbation(text)) for text in ai_texts]

In [None]:
# Create scatter plot
plt.scatter(perplexities_before, perplexities_after,marker='x',c='r')

plt.xlabel("Perplexity before perturbation")
plt.ylabel("Perplexity after perturbation")
plt.plot([0,30],[0,30],'b--')
plt.title(f"Perplexity before vs after perturbation for {num_samples} AI-generated texts")
plt.grid(True)
plt.show()

## References
[1] E. Mitchell, C. Lin, A. Bosselut, and C. D. Manning, "DetectGPT: Zero-Shot Machine-Generated Text Detection using Probability Curvature" *arXiv preprint*, 2023. Available at: [arXiv:2301.11305](https://arxiv.org/abs/2301.11305)