In [None]:
!pip install -U datasets
!pip install transformers_stream_generator

In [8]:
from datasets import load_dataset
# Load the "main" configuration of GSM8K, restricted to the first ten training
# rows; downloaded files are cached under ./hf_cache.
dataset = load_dataset(
    "gsm8k",
    "main",
    split="train[:10]",
    cache_dir="./hf_cache",
)

# Uncomment to preview the question/answer pairs:
# for example in dataset:
#     print("Q:", example["question"], "A:", example["answer"])


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint under analysis. Other candidates, currently disabled:
#   "Qwen/Qwen2.5-7B-Instruct-1M"
#   "Qwen/Qwen-1_8B-Chat"
model_name = "gpt2"

# NOTE(review): trust_remote_code=True is passed to both loaders; presumably it
# is there for the Qwen checkpoints listed above — confirm it is still wanted
# when loading "gpt2".
hf_load_options = {"trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(model_name, **hf_load_options)
model = AutoModelForCausalLM.from_pretrained(model_name, **hf_load_options)

# Put the model in evaluation mode; kept as the final expression of the cell.
model.eval()


In [4]:
import torch
import torch.nn.functional as F
import numpy as np

class TopKSamplerWithEntropy:
    """Top-k sampler that records entropy diagnostics at every generation step.

    For each generated token it records:
      * the Shannon entropy (in nats) of the model's next-token distribution;
      * for every tensor in ``outputs.hidden_states``, the entropy of a softmax
        taken over the hidden vector at the last sequence position.

    Args:
        model: Callable invoked as ``model(input_ids, output_hidden_states=True)``
            whose result exposes ``.logits`` and ``.hidden_states``.
        tokenizer: Object providing ``encode(prompt, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``.
        k: Number of highest-probability tokens to sample from at each step.
            Values larger than the vocabulary are clamped to the vocabulary size.
        max_length: Number of new tokens to generate (the prompt is not counted).
    """

    def __init__(self, model, tokenizer, k=10, max_length=50):
        self.model = model
        self.tokenizer = tokenizer
        self.k = k
        self.max_length = max_length

    def compute_entropy(self, probs):
        """Return H(p) = -sum_x p(x) log p(x) for one probability distribution.

        Args:
            probs: Tensor of probabilities; the sum runs over the last dimension
                and must leave a single element (e.g. shape ``[1, vocab]`` or
                ``[vocab]``).

        Returns:
            The entropy as a Python float (natural log, i.e. nats).
        """
        # xlogy(p, p) evaluates p * log(p) but yields 0 where p == 0, so entries
        # that underflowed to exactly zero no longer turn the sum into NaN
        # (plain `p * p.log()` gives 0 * -inf = NaN there).
        return -torch.xlogy(probs, probs).sum(dim=-1).item()

    def compute_layerwise_entropy(self, hidden_states):
        """Compute a softmax entropy for each layer's last-position hidden vector.

        Each hidden vector is passed through a softmax over the hidden dimension
        and the entropy of the result is taken. Hidden states are not
        probability distributions, so this is only a heuristic signal of how
        concentrated a layer's activation is.

        Args:
            hidden_states: Iterable of tensors indexed as
                ``[batch, position, hidden_dim]``; only batch element 0 and the
                last position are used.

        Returns:
            A list of floats, one entropy value per element of ``hidden_states``.
        """
        return [
            # Reuse the NaN-safe helper: large-magnitude activations routinely
            # drive some softmax outputs to exactly zero.
            self.compute_entropy(F.softmax(layer[0, -1, :], dim=-1))
            for layer in hidden_states
        ]

    def sample(self, prompt):
        """Generate ``max_length`` tokens with top-k sampling, logging entropies.

        Generation always runs for exactly ``max_length`` steps; it does not stop
        early on an end-of-sequence token.

        Args:
            prompt: Text to condition on.

        Returns:
            A tuple ``(generated_text, output_entropies, layerwise_entropies)``:
            the decoded prompt plus continuation, one output entropy per step,
            and one list of per-layer entropies per step (``[step][layer]``).
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

        # Keep the ids on the same device as the model when the model exposes one;
        # otherwise a model moved to an accelerator would fail on CPU inputs.
        device = getattr(self.model, "device", None)
        if device is not None:
            input_ids = input_ids.to(device)
        output_ids = input_ids.clone()

        output_entropies = []
        layerwise_entropies = []

        for _ in range(self.max_length):
            with torch.no_grad():
                outputs = self.model(output_ids, output_hidden_states=True)
                logits = outputs.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)

                # Entropy of the full next-token distribution (before truncation).
                output_entropies.append(self.compute_entropy(probs))

                # One entropy per hidden-state tensor returned by the model.
                layerwise_entropies.append(
                    self.compute_layerwise_entropy(outputs.hidden_states)
                )

                # Top-k sampling: keep the k most likely tokens, renormalise,
                # draw one. k is clamped because torch.topk raises when asked
                # for more entries than the vocabulary holds.
                k = min(self.k, probs.size(-1))
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
                topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(topk_probs, num_samples=1)
                # multinomial returns a position within the top-k slice; map it
                # back to the vocabulary id.
                next_token_id = topk_indices.gather(-1, next_token)
                output_ids = torch.cat([output_ids, next_token_id], dim=-1)

        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return generated_text, output_entropies, layerwise_entropies


In [None]:
sampler = TopKSamplerWithEntropy(model, tokenizer, k=10, max_length=30)

# Sample a continuation for every loaded question and dump the raw entropy
# traces. NOTE(review): no RNG seed is set in the visible cells, so the sampled
# text and entropies will presumably differ between runs.
for example_number, example in enumerate(dataset, start=1):
    print(f"\nExample {example_number}")
    prompt = example["question"]
    print("Prompt:", prompt)

    generated, output_entropies, layerwise_entropies = sampler.sample(prompt)
    print("Generated:", generated)
    print("Output entropy:", output_entropies)
    print("Layerwise entropy (per step):")
    # One line per generation step, each listing that step's per-layer values.
    for step_number, step_entropies in enumerate(layerwise_entropies, start=1):
        print(f"Step {step_number}: {step_entropies}")


In [6]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def visualize_entropies(prompt, generated_text, output_entropies, layerwise_entropies, save_path=None):
    """Plot per-step output entropy and a layer-by-step entropy heatmap.

    Args:
        prompt: Prompt text; accepted for interface compatibility, not drawn.
        generated_text: Generated text; accepted for interface compatibility,
            not drawn.
        output_entropies: One entropy value per generation step.
        layerwise_entropies: One sequence per generation step, each holding one
            entropy value per layer (indexed ``[step][layer]``).
        save_path: Optional file path; when given, the figure is also saved there.

    Returns:
        None. The figure is left open for the notebook backend to render.
    """
    num_steps = len(output_entropies)

    fig, axs = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [1, 2]})
    fig.suptitle("Top-k Sampling and Entropy Visualization", fontsize=16)

    # Top panel: output entropy over generation steps (1-based on the x axis).
    line_ax = axs[0]
    line_ax.plot(range(1, num_steps + 1), output_entropies, marker='o', label="Output Entropy")
    line_ax.set_title("Token-wise Output Entropy")
    line_ax.set_xlabel("Generation Step")
    line_ax.set_ylabel("Entropy")
    line_ax.grid(True)
    line_ax.legend()

    # Bottom panel: layer-wise entropy heatmap. The input is [step][layer];
    # transpose so rows are layers and columns are steps, which is what the
    # axis labels below describe.
    heat_ax = axs[1]
    entropy_matrix = np.asarray(layerwise_entropies).T  # shape: [num_layers, num_steps]
    sns.heatmap(entropy_matrix, ax=heat_ax, cmap="viridis", xticklabels=True, yticklabels=True)
    heat_ax.set_title("Layerwise Entropy Heatmap")
    heat_ax.set_xlabel("Generation Step")
    heat_ax.set_ylabel("Layer")

    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # Honour save_path, which was previously accepted but silently ignored.
    if save_path is not None:
        fig.savefig(save_path)



In [None]:
# Smoke test: sample once from the first dataset example and plot its entropies.
example = dataset[0]
prompt = example["question"]

generated, output_entropies, layerwise_entropies = sampler.sample(prompt)

# Echo the inputs and raw numbers before drawing the figure.
for label, value in (
    ("Prompt:", prompt),
    ("Generated:", generated),
    ("Output Entropies:", output_entropies),
):
    print(label, value)

visualize_entropies(prompt, generated, output_entropies, layerwise_entropies)
