# Decoding Simulator for Language Models

This notebook implements a decoding simulator to compare **greedy search**, **beam search**, **top-k sampling**, and **top-p (nucleus) sampling** with **temperature scaling**. It visualizes token paths, computes **entropy** and **diversity** metrics, and analyzes the **diversity vs. coherence** tradeoff, inspired by the nucleus-sampling paper *The Curious Case of Neural Text Degeneration* (Holtzman et al., 2019).

## Setup
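Install dependencies and load the GPT-2 model and tokenizer. The token-path plots rely on the Graphviz `dot` layout through `pygraphviz`, which is not in the `pip install` list below and has to be installed separately.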

In [None]:
!pip install transformers torch numpy networkx matplotlib ipywidgets
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from IPython.display import display
import ipywidgets as widgets

# Initialize model and tokenizer
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the pretrained 'gpt2' checkpoint and its matching tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
# Put the model in evaluation mode for inference.
model.eval()

## Helper Functions
Define functions for probability, entropy, diversity, and visualization.

In [None]:
# Compute log probability
def get_log_prob(logits, token_id):
    """Return the log-probability of ``token_id`` under ``softmax(logits)``.

    Args:
        logits: Tensor of unnormalized scores; the softmax is taken over the
            last dimension.
        token_id: Index of the token to look up (an int or an integer tensor
            selecting a single element).

    Returns:
        The natural-log probability of the selected token as a Python float.
    """
    # log_softmax evaluates log(softmax(x)) in one numerically stable step.
    # Computing softmax first and taking the log afterwards underflows to
    # -inf for tokens whose probability rounds to zero.
    log_probabilities = torch.nn.functional.log_softmax(logits, dim=-1)
    return log_probabilities[token_id].item()

# Compute entropy
def compute_entropy(logits):
    """Return the Shannon entropy (in nats) of ``softmax(logits)``.

    A small constant (1e-10) is added inside the logarithm so that
    zero-probability entries do not produce ``log(0)``.
    """
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-10)
    weighted = probs * log_probs
    return -weighted.sum().item()

# Compute diversity (unique bigrams)
def compute_diversity(texts):
    """Count the distinct adjacent token pairs found across all ``texts``.

    Each text is split with the module-level ``tokenizer``; every pair of
    consecutive tokens is collected into one shared set, so a bigram that
    occurs in several texts is counted once.
    """
    seen_pairs = set()
    for text in texts:
        pieces = tokenizer.tokenize(text)
        # Pair each token with its right-hand neighbour.
        seen_pairs.update(zip(pieces, pieces[1:]))
    return len(seen_pairs)

# Visualize token paths
def plot_graph(graph, length, title):
    """Draw the token graph with nodes coloured by token probability.

    Args:
        graph: networkx graph whose nodes carry the attributes ``'token'``
            (label text, optionally suffixed with ``_<step>``),
            ``'tokenscore'`` (probability in percent) and ``'entropy'``.
            Nodes whose ``'token'`` is ``None`` are not drawn.
        length: Unused; kept so existing callers keep working.
        title: Title shown above the plot.
    """
    plt.figure(figsize=(8, 6), dpi=300, facecolor='white')
    ax = plt.gca()

    try:
        pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")
    except ImportError:
        # graphviz_layout needs the optional pygraphviz package; fall back to
        # a pure-networkx layout (fixed seed for reproducible pictures)
        # instead of failing before anything is drawn.
        pos = nx.spring_layout(graph, seed=0)

    # Build the node list and the colour list together so they always line up.
    nodes = [node for node, data in graph.nodes(data=True) if data['token'] is not None]
    scores = [graph.nodes[node]['tokenscore'] for node in nodes]
    norm = mcolors.Normalize(vmin=min(scores), vmax=max(scores))
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    nx.draw_networkx_nodes(graph, pos, nodelist=nodes, node_size=2000, node_shape='o',
                           alpha=1, linewidths=4, node_color=scores, cmap=cmap,
                           vmin=norm.vmin, vmax=norm.vmax, ax=ax)
    # Pass the node size so edge end points are placed relative to the
    # enlarged nodes rather than the default node size.
    nx.draw_networkx_edges(graph, pos, node_size=2000, ax=ax)
    # rsplit drops only the trailing "_<step>" suffix, so tokens that contain
    # an underscore themselves are not truncated.
    labels = {
        node: (
            f"{graph.nodes[node]['token'].rsplit('_', 1)[0]}\n"
            f"{graph.nodes[node]['tokenscore']:.2f}%\n"
            f"H={graph.nodes[node]['entropy']:.2f}"
        )
        for node in nodes
    }
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8, ax=ax)

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # The mappable is not attached to any Axes, so tell the colorbar
    # explicitly which Axes to take space from.
    plt.colorbar(sm, ax=ax, orientation='vertical', label='Token probability (%)')
    ax.set_title(title)
    ax.set_frame_on(False)
    plt.show()

## Decoding Strategies
Implement greedy search, beam search, top-k sampling, and top-p sampling.

In [None]:
# Greedy Search
def greedy_search(input_ids, node, graph, length=5):
    """Extend ``input_ids`` by ``length`` tokens, always taking the argmax.

    At every step the first successor of ``node`` in ``graph`` is annotated
    with the chosen token (suffixed with the remaining step count), its
    probability in percent and the entropy of the next-token distribution.

    Args:
        input_ids: Token ids passed to the model; the new token is appended
            along the last dimension.
        node: Graph node whose first successor receives this step's result.
        graph: Directed graph with at least one successor per step.
        length: Number of tokens still to generate.

    Returns:
        The token-id tensor including the generated tokens.
    """
    if length <= 0:
        return input_ids
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)
    token_score = get_log_prob(logits, token_id)
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"
    graph.nodes[current_node]['entropy'] = compute_entropy(logits)
    return greedy_search(new_input_ids, current_node, graph, length - 1)

# Beam Search
def beam_search(input_ids, length=5, num_beams=3):
    """Run beam search and record every expanded candidate in a graph.

    Each step expands every live beam with its ``num_beams`` highest-scoring
    next tokens, adds all candidates to the graph, and keeps the
    ``num_beams`` candidates with the highest cumulative log-probability.

    Args:
        input_ids: Prompt token ids passed to the model.
        length: Number of tokens to generate.
        num_beams: Beam width, also the number of candidates expanded per beam.

    Returns:
        A tuple ``(best_ids, graph)``: the token ids of the highest-scoring
        beam and the ``networkx.DiGraph`` of all explored candidates (node 0
        is the ``'Start'`` root).
    """
    beams = [(input_ids, 0.0, [0])]
    graph = nx.DiGraph()
    graph.add_node(0, token='Start', tokenscore=100, entropy=0.0)

    for step in range(length):
        new_beams = []
        for beam_ids, beam_score, node_path in beams:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                logits = model(beam_ids).logits[0, -1, :]
            # These depend only on the beam, not on the candidate token, so
            # compute them once per beam instead of once per candidate.
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            step_entropy = compute_entropy(logits)
            top_indices = torch.topk(logits, num_beams, dim=-1).indices
            top_log_probs = log_probs[top_indices].tolist()
            parent = node_path[-1]
            for token_id, token_log_prob in zip(top_indices, top_log_probs):
                new_ids = torch.cat([beam_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
                new_node = len(graph.nodes)
                next_token = tokenizer.decode(token_id, skip_special_tokens=True)
                graph.add_node(new_node, token=next_token + f"_{length-step}",
                               tokenscore=np.exp(token_log_prob) * 100,
                               entropy=step_entropy)
                graph.add_edge(parent, new_node)
                new_beams.append((new_ids, beam_score + token_log_prob, node_path + [new_node]))
        # Keep only the best-scoring candidates for the next step.
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:num_beams]

    return beams[0][0], graph

# Top-k Sampling
def top_k_sampling(input_ids, node, graph, length=5, k=5, temperature=1.0):
    """Extend ``input_ids`` by sampling each token from the top-``k`` candidates.

    The logits are divided by ``temperature``, the ``k`` largest are
    renormalized with a softmax, and one token is drawn from them. The first
    successor of ``node`` in ``graph`` is annotated with the token (suffixed
    with the remaining step count), its probability in percent under the
    temperature-scaled full-vocabulary distribution, and that distribution's
    entropy.

    Args:
        input_ids: Token ids passed to the model; the new token is appended
            along the last dimension.
        node: Graph node whose first successor receives this step's result.
        graph: Directed graph with at least one successor per step.
        length: Number of tokens still to generate.
        k: Number of highest-scoring tokens to sample from.
        temperature: Positive divisor applied to the logits.

    Returns:
        The token-id tensor including the generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if length <= 0:
        return input_ids
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :] / temperature
    # topk cannot return more entries than the vocabulary holds.
    top_k = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
    probs = torch.nn.functional.softmax(top_k.values, dim=-1)
    token_idx = torch.multinomial(probs, 1)
    # Indexing with the (1,)-shaped sample already yields a (1,)-shaped id.
    # An extra unsqueeze here would make the tensor 3-D after the unsqueeze
    # below and torch.cat with the 2-D input_ids would fail.
    token_id = top_k.indices[token_idx]
    token_score = get_log_prob(logits, token_id)
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"
    graph.nodes[current_node]['entropy'] = compute_entropy(logits)
    return top_k_sampling(new_input_ids, current_node, graph, length - 1, k, temperature)

# Top-p Sampling
def top_p_sampling(input_ids, node, graph, length=5, p=0.9, temperature=1.0):
    """Extend ``input_ids`` by nucleus (top-``p``) sampling.

    The logits are divided by ``temperature`` and sorted. The nucleus is the
    shortest prefix of the sorted tokens whose cumulative probability reaches
    ``p``; one token is drawn from the renormalized nucleus. The first
    successor of ``node`` in ``graph`` is annotated with the token (suffixed
    with the remaining step count), its probability in percent under the
    temperature-scaled full-vocabulary distribution, and that distribution's
    entropy.

    Args:
        input_ids: Token ids passed to the model; the new token is appended
            along the last dimension.
        node: Graph node whose first successor receives this step's result.
        graph: Directed graph with at least one successor per step.
        length: Number of tokens still to generate.
        p: Cumulative-probability threshold that defines the nucleus.
        temperature: Positive divisor applied to the logits.

    Returns:
        The token-id tensor including the generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if length <= 0:
        return input_ids
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :] / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Keep a token while the probability mass *before* it is still below p.
    # This includes the token that crosses the threshold, so the nucleus is
    # never empty even when the most likely token alone exceeds p.
    mass_before = cumulative_probs - sorted_probs
    nucleus = mass_before < p
    nucleus[0] = True  # always keep the most likely token, even for p <= 0
    nucleus_indices = sorted_indices[nucleus]
    nucleus_probs = torch.nn.functional.softmax(sorted_logits[nucleus], dim=-1)
    token_idx = torch.multinomial(nucleus_probs, 1)
    # Indexing with the (1,)-shaped sample already yields a (1,)-shaped id.
    # An extra unsqueeze here would make the tensor 3-D after the unsqueeze
    # below and torch.cat with the 2-D input_ids would fail.
    token_id = nucleus_indices[token_idx]
    token_score = get_log_prob(logits, token_id)
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"
    graph.nodes[current_node]['entropy'] = compute_entropy(logits)
    return top_p_sampling(new_input_ids, current_node, graph, length - 1, p, temperature)

## Interactive Simulator
Use widgets to adjust parameters and run the simulator.

In [None]:
# Interactive widgets
# Free-text prompt that seeds every decoding strategy.
prompt_widget = widgets.Text(description='Prompt:', value='The cat is')

# Number of tokens to generate.
length_widget = widgets.IntSlider(description='Length:', min=3, max=10, value=5)

# Sampling controls: candidate count, nucleus threshold and temperature.
k_widget = widgets.IntSlider(description='Top-k:', min=1, max=20, value=5)
p_widget = widgets.FloatSlider(
    description='Top-p:', min=0.1, max=1.0, step=0.1, value=0.9,
)
temp_widget = widgets.FloatSlider(
    description='Temperature:', min=0.1, max=2.0, step=0.1, value=1.0,
)

# Beam width for beam search.
beams_widget = widgets.IntSlider(description='Beams:', min=1, max=5, value=3)

# Button that triggers a simulator run.
run_button = widgets.Button(description='Run Simulator')

# Everything the button callback prints or plots is captured by this widget,
# so results show up below the controls regardless of the notebook front end.
output_area = widgets.Output()


def _new_chain_graph(prompt, length):
    """Return a linear chain of ``length + 1`` nodes with placeholder attributes."""
    graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())
    for node in graph.nodes:
        graph.nodes[node]['tokenscore'] = 100
        graph.nodes[node]['token'] = prompt
        graph.nodes[node]['entropy'] = 0.0
    return graph


def _decode_text(output_ids):
    """Turn a tensor of generated token ids back into text."""
    return tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)


def _average_entropy(graph):
    """Mean of the positive node entropies in ``graph`` (0.0 when there are none)."""
    entropies = [data['entropy'] for _, data in graph.nodes(data=True) if data['entropy'] > 0]
    return float(np.mean(entropies)) if entropies else 0.0


def run_simulator(b):
    """Run all four decoding strategies with the current widget settings.

    Reads the prompt and parameters from the widgets, generates one
    continuation per strategy, then prints each output with its average
    entropy, plots its token graph, and reports the number of unique bigrams
    across all outputs.

    Args:
        b: The clicked button (supplied by ``on_click``; not used).
    """
    prompt = prompt_widget.value
    length = length_widget.value
    k = k_widget.value
    p = p_widget.value
    temperature = temp_widget.value
    num_beams = beams_widget.value

    with output_area:
        # Replace the previous run's results instead of appending to them.
        output_area.clear_output(wait=True)

        if not prompt.strip():
            # An empty prompt encodes to zero tokens, which the model cannot process.
            print("Please enter a non-empty prompt.")
            return

        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        outputs = []

        # Greedy Search
        graph = _new_chain_graph(prompt, length)
        output_ids = greedy_search(input_ids, 0, graph, length)
        outputs.append(('Greedy', _decode_text(output_ids), graph))

        # Beam Search
        output_ids, graph = beam_search(input_ids, length, num_beams)
        outputs.append(('Beam', _decode_text(output_ids), graph))

        # Top-k Sampling
        graph = _new_chain_graph(prompt, length)
        output_ids = top_k_sampling(input_ids, 0, graph, length, k, temperature)
        outputs.append((f'Top-k (k={k}, T={temperature})', _decode_text(output_ids), graph))

        # Top-p Sampling
        graph = _new_chain_graph(prompt, length)
        output_ids = top_p_sampling(input_ids, 0, graph, length, p, temperature)
        outputs.append((f'Top-p (p={p}, T={temperature})', _decode_text(output_ids), graph))

        # Compute diversity
        diversity = compute_diversity([text for _, text, _ in outputs])

        # Display results
        print(f"Prompt: {prompt}")
        for name, text, graph in outputs:
            print(f"\n{name} Output: {text}")
            print(f"Average Entropy: {_average_entropy(graph):.2f}")
            plot_graph(graph, length, name)
        print(f"\nDiversity (unique bigrams): {diversity}")


run_button.on_click(run_simulator)
display(prompt_widget, length_widget, k_widget, p_widget, temp_widget, beams_widget,
        run_button, output_area)

## Analysis
- **Greedy Search**: High coherence, low diversity, low entropy.
- **Beam Search**: Improved coherence, moderate diversity.
- **Top-k/Top-p Sampling**: High diversity, potential incoherence with high temperature.
- **Tradeoffs**: Sampling methods excel in creative tasks; search methods suit factual tasks.