In [1]:
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
import numpy as np
import json

## 2.3 How perplexed can you get?

In [2]:
# Base (non-instruct) Llama 3.1 8B wrapped in a text-generation pipeline.
model_id = "meta-llama/Llama-3.1-8B"

# The same id is used for both EOS and padding so generation has an explicit
# pad id. NOTE(review): presumably Llama 3's <|end_of_text|> token — confirm
# with pipe.tokenizer.convert_ids_to_tokens(128001).
END_OF_TEXT_ID = 128001

pipe = pipeline(
    "text-generation",
    model=model_id,
    eos_token_id=END_OF_TEXT_ID,
    pad_token_id=END_OF_TEXT_ID,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


In [3]:
def calculate_perplexity_for_generated(full_text, prompt, model, tokenizer):
    """
    Calculate per-token and global perplexity for the generated part of text.

    The full text is tokenized once and scored in a single forward pass
    (teacher forcing): every token after the prompt is scored with the
    log-probability the model assigned to it given all preceding tokens.

    Args:
        full_text (str): The complete text (prompt + generated). Expected to
            start with ``prompt``, as the pipeline output used in this
            notebook does.
        prompt (str): The original prompt.
        model: A causal LM called as ``model(input_ids)``; its output must
            expose ``.logits``, indexed here as (batch, position, vocab).
        tokenizer: The tokenizer matching ``model``.

    Returns:
        dict | None: ``None`` when there are no generated tokens to score.
        Otherwise a dict with:
            'generated_tokens'     - decoded text of each scored token,
            'generated_text'       - ``full_text`` with the prompt prefix removed,
            'per_token_log_probs'  - np.ndarray (float64) of log p(token | prefix),
            'per_token_perplexity' - np.ndarray, exp(-log p) per token,
            'global_perplexity'    - exp(-mean log p) over the scored tokens,
            'prompt', 'full_text'  - the inputs, echoed back.
    """
    full_input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(model.device)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()

    # The prompt boundary is the number of leading tokens shared by the two
    # tokenizations. Normally that is the whole prompt, but BPE may merge the
    # prompt's tail with the first generated characters (e.g. a prompt ending
    # in a space); the merged token is then scored as generated instead of the
    # boundary silently shifting by one.
    prompt_length = 0
    for prompt_tok, full_tok in zip(prompt_ids, full_input_ids[0].tolist()):
        if prompt_tok != full_tok:
            break
        prompt_length += 1

    # Logits at position i-1 predict token i, so position 0 can never be
    # scored; starting at 1 avoids wrapping around to index -1.
    start = max(prompt_length, 1)
    target_ids = full_input_ids[0, start:]
    if target_ids.numel() == 0:
        return None  # No generated tokens

    with torch.no_grad():
        logits = model(full_input_ids).logits

    # Only the rows that predict scored tokens are normalized, and in float32:
    # a log-softmax computed in bf16/fp16 (if the model runs in one of those
    # dtypes) visibly distorts the resulting perplexities.
    pred_logits = logits[0, start - 1:-1].float()
    log_probs = F.log_softmax(pred_logits, dim=-1)
    token_log_probs = log_probs.gather(
        -1, target_ids.to(log_probs.device).unsqueeze(-1)
    ).squeeze(-1)

    generated_token_log_probs = token_log_probs.double().cpu().numpy()

    # Per-token perplexity and the global perplexity (exp of mean NLL).
    per_token_perplexity = np.exp(-generated_token_log_probs)
    global_perplexity = np.exp(-np.mean(generated_token_log_probs))

    generated_tokens = [tokenizer.decode([tok]) for tok in target_ids.tolist()]

    # Extract just the generated text
    generated_text = full_text[len(prompt):]

    return {
        'generated_tokens': generated_tokens,
        'generated_text': generated_text,
        'per_token_log_probs': generated_token_log_probs,
        'per_token_perplexity': per_token_perplexity,
        'global_perplexity': global_perplexity,
        'prompt': prompt,
        'full_text': full_text
    }

# Score with the very model/tokenizer the pipeline already loaded, so the
# perplexities below refer to the same weights that produced the text.
model, tokenizer = pipe.model, pipe.tokenizer

In [4]:
# Quick sanity check of the pipeline using its default generation settings
# (no decoding arguments are passed here).
pipe("The capital of France is")

[{'generated_text': 'The capital of France is one of the most popular tourist destinations in the world, and it’s not hard to see why.'}]

In [5]:
# Prompts hypothesised to have a predictable (low-perplexity) continuation.
low_prompt1, low_prompt2, low_prompt3 = (
    "The capital of France is",
    "Thank you for your email. I am currently out of the office and will",
    "Error 404:",
)

# Prompts hypothesised to invite a less predictable (high-perplexity) one.
high_prompt1, high_prompt2, high_prompt3 = (
    "In my dream last night, the color blue tasted like",
    "How are",
    "The last person on Earth",
)

In [6]:
# Generate text and calculate perplexity for generated sequences
prompts = {
    'Low Perplexity': [low_prompt1, low_prompt2, low_prompt3],
    'High Perplexity': [high_prompt1, high_prompt2, high_prompt3]
}

results = {}

print("=== GENERATED TEXT PERPLEXITY ANALYSIS ===\n")

for category, prompt_list in prompts.items():
    print(f"{category.upper()} PROMPTS:")
    results[category] = []
    
    for i, prompt in enumerate(prompt_list, 1):
        # Generate text
        generation = pipe(prompt, max_new_tokens=64, do_sample=False, num_beams=1)
        full_text = generation[0]['generated_text']
        
        # Calculate perplexity for the generated part
        result = calculate_perplexity_for_generated(full_text, prompt, model, tokenizer)
        
        if result is not None:
            results[category].append(result)
            
            print(f"\n{i}. Prompt: '{prompt}'")
            print(f"   Generated: '{result['generated_text']}'")
            print(f"   Global Perplexity: {result['global_perplexity']:.4f}")
        else:
            print(f"\n{i}. Prompt: '{prompt}' - No tokens generated")
    
    print("\n" + "="*60)

=== GENERATED TEXT PERPLEXITY ANALYSIS ===

LOW PERPLEXITY PROMPTS:





1. Prompt: 'The capital of France is'
   Generated: ' a city of many faces. It is a city of history, culture, and art. It is a city of fashion, food, and wine. It is a city of romance, passion, and adventure. It is a city that has something to offer everyone.
Paris is a city that is steeped in history.'
   Global Perplexity: 1.9427

2. Prompt: 'Thank you for your email. I am currently out of the office and will'
   Generated: ' be back on 10th April 2019. If your email is urgent, please contact my office on 020 7219 7020.'
   Global Perplexity: 2.7863

3. Prompt: 'Error 404:'
   Generated: ' Page not found. Sorry, but the page you are looking for does not exist. Please check the URL for proper spelling and capitalization. If you're having trouble locating a destination on Penn State Live, please try visiting the sitemap.'
   Global Perplexity: 1.6961

HIGH PERPLEXITY PROMPTS:

1. Prompt: 'In my dream last night, the color blue tasted like'
   Generated: ' a sweet, creamy, vanilla ice 

In [7]:
# JSON-friendly view of `results`: numpy arrays become plain lists, and the
# arithmetic mean of the per-token perplexities is recorded alongside the
# global perplexity.
formatted_results = {
    group: [
        {
            'prompt': entry['prompt'],
            'generated_text': entry['generated_text'],
            'global_perplexity': entry['global_perplexity'],
            'per_token_perplexity': entry['per_token_perplexity'].tolist(),
            'generated_tokens': entry['generated_tokens'],
            'avg_per_token_perplexity': float(np.mean(entry['per_token_perplexity'])),
        }
        for entry in entries
    ]
    for group, entries in results.items()
}

In [8]:
# Persist the formatted results (relative path: the current working directory).
with open("perplexity_results.json", "w") as json_file:
    json_file.write(json.dumps(formatted_results, indent=2))

In [9]:
# Human-readable summary of every scored prompt, grouped by category.
for group, entries in formatted_results.items():
    print(f"\n{group.upper()} PROMPTS DETAILED RESULTS:")
    for rank, entry in enumerate(entries, start=1):
        print("\n".join([
            f"\n{rank}. Prompt: '{entry['prompt']}'",
            f"   Generated: '{entry['generated_text']}'",
            f"   Global Perplexity: {entry['global_perplexity']:.4f}",
            f"   Average Per-Token Perplexity: {entry['avg_per_token_perplexity']:.4f}",
        ]))
    print("\n" + "=" * 60)


LOW PERPLEXITY PROMPTS DETAILED RESULTS:

1. Prompt: 'The capital of France is'
   Generated: ' a city of many faces. It is a city of history, culture, and art. It is a city of fashion, food, and wine. It is a city of romance, passion, and adventure. It is a city that has something to offer everyone.
Paris is a city that is steeped in history.'
   Global Perplexity: 1.9427
   Average Per-Token Perplexity: 2.4700

2. Prompt: 'Thank you for your email. I am currently out of the office and will'
   Generated: ' be back on 10th April 2019. If your email is urgent, please contact my office on 020 7219 7020.'
   Global Perplexity: 2.7863
   Average Per-Token Perplexity: 6.6761

3. Prompt: 'Error 404:'
   Generated: ' Page not found. Sorry, but the page you are looking for does not exist. Please check the URL for proper spelling and capitalization. If you're having trouble locating a destination on Penn State Live, please try visiting the sitemap.'
   Global Perplexity: 1.6961
   Average Per

In [10]:
# Full record (per-token perplexities and tokens) for the second scored
# low-perplexity prompt — the out-of-office reply in the output below.
formatted_results['Low Perplexity'][1]

{'prompt': 'Thank you for your email. I am currently out of the office and will',
 'generated_text': ' be back on 10th April 2019. If your email is urgent, please contact my office on 020 7219 7020.',
 'global_perplexity': 2.786281929260581,
 'per_token_perplexity': [3.849644694489727,
  2.8508591804911467,
  2.195663156426126,
  3.4060082199050283,
  27.041567040558693,
  3.1996512838248377,
  9.418717264809843,
  2.8625995797040673,
  1.1216551092306364,
  1.7614440859161942,
  2.0104832294895387,
  2.3807405945617703,
  2.0220501727307085,
  3.2358653201368193,
  1.2105000000687076,
  1.5542683678611078,
  1.8838506266306723,
  1.0841472151303213,
  1.5837806209268046,
  9.04368226357122,
  5.073446339505259,
  1.5771913219565126,
  1.179865126698548,
  2.048208231059669,
  1.692588484969701,
  2.502174797210619,
  1.039388896054288,
  1.0044697524242978,
  99.9017308888125,
  4.959717963651696,
  2.2619619633674506],
 'generated_tokens': [' be',
  ' back',
  ' on',
  ' ',
  '10',
 

## 2.4 Beam Search Puzzle

In [11]:
# Switch to a small Qwen3 chat model for the beam-search experiment.
# NOTE: this rebinds the global `tokenizer` and `model` that section 2.3 used,
# while `pipe` still wraps the Llama model — re-running the 2.3 scoring cells
# after this one would score Llama generations with the Qwen model.
model_name = "Qwen/Qwen3-1.7B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
# Ask the chat model to echo a made-up word, then compare how the different
# beams tokenize (and misspell) it.
# Alternative prompt tried: "Print the string ‘break’ and string 'fast' without white space! DON'T GENERATE ANYTHING ELSE"
prompt = "give me the string of: 'hunanexpress'. DON'T RETURN ANYTHING ELSE."
chat_text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
model_inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)

# Beam search with 10 beams, returning all 10 candidate sequences.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
    num_beams=10,
    num_return_sequences=10,
    early_stopping=False,
)

# Answer text starts after the last occurrence of this id (if present).
# NOTE(review): presumably Qwen3's </think> token — confirm with
# tokenizer.convert_ids_to_tokens(151668).
THINK_END_ID = 151668
prompt_len = len(model_inputs.input_ids[0])

for beam in generated_ids:
    completion_ids = beam[prompt_len:].tolist()
    print(completion_ids)
    answer_start = (
        len(completion_ids) - completion_ids[::-1].index(THINK_END_ID)
        if THINK_END_ID in completion_ids
        else 0
    )
    print(tokenizer.decode(completion_ids[answer_start:], skip_special_tokens=True).strip("\n"))

[71, 64634, 13788, 151645, 151643, 151643, 151643, 151643, 151643]
hunanexpress
[71, 64634, 13788, 151668, 271, 71, 64634, 13788, 151645]
hunanexpress
[71, 359, 309, 327, 1726, 288, 151645, 151643, 151643]
hunamexprees
[71, 359, 309, 327, 1726, 325, 151645, 151643, 151643]
hunamexprese
[71, 359, 309, 327, 1726, 778, 151645, 151643, 151643]
hunamexpress
[71, 359, 309, 327, 1726, 325, 417, 151645, 151643]
hunamexpresept
[71, 359, 276, 13788, 151645, 151643, 151643, 151643, 151643]
hunanexpress
[71, 359, 309, 327, 1726, 325, 267, 151645, 151643]
hunamexpresest
[71, 359, 309, 327, 1873, 151645, 151643, 151643, 151643]
hunamexpress
[71, 359, 309, 327, 649, 778, 151645, 151643, 151643]
hunamexprss


In [19]:
# Token 276 — from the beam that splits the word as 71, 359, 276, 13788.
(decoded_token := tokenizer.convert_ids_to_tokens(276))

'an'

In [None]:
# Token 359 — the second token in most of the misspelling beams (71, 359, ...).
(decoded_token := tokenizer.convert_ids_to_tokens(359))

'un'

In [20]:
# Token 64634 — used by the beams that spell the word correctly (71, 64634, 13788).
(decoded_token := tokenizer.convert_ids_to_tokens(64634))

'unan'

In [14]:
# Token 8180 — not among the beam outputs shown above (presumably from an
# earlier run of the prompt cell).
(decoded_token := tokenizer.convert_ids_to_tokens(8180))

'Ġeat'

In [15]:
# Token 624 — not among the beam outputs shown above (presumably from an
# earlier run of the prompt cell).
(decoded_token := tokenizer.convert_ids_to_tokens(624))

'.Ċ'