# Token Tree Analysis with Gemma
This notebook demonstrates how to build and analyze a token decision tree from the Gemma 2 model. The token tree explores multiple generation paths based on probability thresholds, providing insights into the model's token selection process.

## Installation
Install necessary libraries for model inference, data manipulation, and visualization.

In [1]:
!pip install torch transformers pandas accelerate numpy huggingface-hub networkx matplotlib



## Authentication
Log in to Hugging Face to access the Gemma 2 model.

In [2]:
import os
import sys
from huggingface_hub import login

if 'google.colab' in sys.modules or os.environ.get('COLAB_RESEARCH_RUNTIME') == 'true':
    from google.colab import userdata

    try:
        # Try to retrieve token from Colab secrets
        token = userdata.get('HF_TOKEN')
    except Exception:
        # Fall back to the environment variable if the Colab secret is unavailable
        token = os.environ.get('HF_TOKEN')
else:
    token = os.environ.get('HF_TOKEN')

if token:
    login(token=token)
else:
    print("HF_TOKEN not found. Please login manually.")
    login()

## Imports
Import required libraries for tensor manipulation, model handling, and data structures.

In [3]:
import json
import torch
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

## Model Loading
Load the Gemma 2 model and tokenizer. Ensure you have access to the model on Hugging Face.

In [4]:
model_id = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

`torch_dtype` is deprecated! Use `dtype` instead!


## Token Tree Building
Define the `TokenNode` class and `build_token_tree` function to explore multiple token generation paths based on probability thresholds. This creates a tree structure representing the model's decision space.

In [5]:
import gc
import sys
from collections import deque
from transformers import DynamicCache

# to_dict() and the path helpers recurse once per tree level, so deep trees need a higher limit
sys.setrecursionlimit(20_000)

class TokenNode:
    def __init__(self, text, prob, token_id, depth, cumulative_prob=1.0):
        self.text = text
        self.prob = prob
        self.cumulative_prob = cumulative_prob
        self.token_id = token_id
        self.depth = depth
        self.children = []
        self.candidates = []
        self.is_finished = False
        self.full_text = ""

    def to_dict(self):
        return {
            "text": self.text,
            "full_text": self.full_text,
            "prob": round(self.prob, 6),
            "cumulative_prob": round(self.cumulative_prob, 8),
            "token_id": self.token_id,
            "depth": self.depth,
            "is_finished": self.is_finished,
            "candidates": self.candidates,
            "children": [child.to_dict() for child in self.children]
        }


def clone_kv_cache(past_key_values):
    """Deep clone KV cache to prevent corruption across branches."""
    if past_key_values is None:
        return None

    # Handle both DynamicCache and tuple formats
    if isinstance(past_key_values, DynamicCache):
        # Use the to_legacy_cache() method to get tuples, then clone
        legacy_cache = past_key_values.to_legacy_cache()
        cloned_cache = DynamicCache()
        for layer_idx, (key, value) in enumerate(legacy_cache):
            cloned_cache.update(key.clone(), value.clone(), layer_idx)
        return cloned_cache
    elif isinstance(past_key_values, tuple):
        # Legacy tuple format - convert to DynamicCache
        cloned_cache = DynamicCache()
        for layer_idx, (key, value) in enumerate(past_key_values):
            cloned_cache.update(key.clone(), value.clone(), layer_idx)
        return cloned_cache
    else:
        # Unknown format, try generic approach
        return tuple(
            tuple(kv.clone() for kv in layer_cache)
            for layer_cache in past_key_values
        )

def build_token_tree(
    model,
    tokenizer,
    prompt,
    min_branch_threshold=0.1,
    max_branch_threshold=0.5,
    soft_queue_limit=30,
    candidate_threshold=0.01,
    max_depth=1000,
    max_completed_paths=2,
    verbose=True
):
    """
    Build token tree with dynamic branching for long-form generation.

    The algorithm adapts branching based on queue size:
    - Few active paths → explore more alternatives (lower threshold)
    - Many active paths → be more selective (higher threshold)
    """
    model.eval()
    device = model.device

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Initial forward pass
    initial_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    initial_input_ids = initial_inputs.input_ids

    if verbose:
        print(f"\n{'='*80}")
        print(f"🌳 Building tree for: '{prompt}'")
        print(f"📊 Params: branch [{min_branch_threshold:.2f}-{max_branch_threshold:.2f}], "
              f"queue_limit={soft_queue_limit}, max_paths={max_completed_paths}")
        print(f"{'='*80}\n")

    with torch.no_grad():
        outputs = model(initial_input_ids, use_cache=True)
        next_token_logits = outputs.logits[0, -1, :]
        root_past_key_values = outputs.past_key_values

        # Convert to DynamicCache if it's a tuple (for compatibility)
        if isinstance(root_past_key_values, tuple):
            converted_cache = DynamicCache()
            for layer_idx, (key, value) in enumerate(root_past_key_values):
                converted_cache.update(key, value, layer_idx)
            root_past_key_values = converted_cache

        root_probs = F.softmax(next_token_logits, dim=-1)

    root = TokenNode(prompt, 1.0, None, 0, cumulative_prob=1.0)
    root.full_text = prompt

    stats = {
        "total_nodes": 1,
        "completed_paths": 0,
        "max_depth_reached": 0,
        "branches_created": 0
    }

    # Queue: (node, kv_cache, next_token_probs, parent_full_text)
    queue = deque([(root, root_past_key_values, root_probs, prompt)])

    while queue and stats["completed_paths"] < max_completed_paths:
        # Calculate dynamic threshold based on queue saturation
        curr_q_len = len(queue)
        saturation_ratio = min(1.0, curr_q_len / soft_queue_limit)
        current_threshold = min_branch_threshold + saturation_ratio * (
            max_branch_threshold - min_branch_threshold
        )

        curr_node, curr_cache, curr_probs, parent_text = queue.popleft()

        # Check termination conditions
        if curr_node.depth >= max_depth:
            curr_node.is_finished = True
            stats["max_depth_reached"] = max(stats["max_depth_reached"], curr_node.depth)
            del curr_cache, curr_probs
            if verbose:
                print(f"✓ Path {stats['completed_paths']}: Reached max_depth ({max_depth})")
            continue

        # Check for EOS tokens
        eos_tokens = []
        if tokenizer.eos_token_id is not None:
            eos_tokens.append(tokenizer.eos_token_id)

        # Gemma-specific
        if 107 not in eos_tokens:
            eos_tokens.append(107) # 107 is Gemma specific, <end_of_turn>

        if curr_node.token_id in eos_tokens:
            curr_node.is_finished = True
            stats["completed_paths"] += 1
            stats["max_depth_reached"] = max(stats["max_depth_reached"], curr_node.depth)
            del curr_cache, curr_probs
            if verbose:
                print(f"✓ Path {stats['completed_paths']}: EOS at depth {curr_node.depth}")
                print(f"   Text: {curr_node.full_text[:100]}...")
            continue

        # Store top candidates for visualization
        cand_indices = torch.where(curr_probs > candidate_threshold)[0]
        if len(cand_indices) > 0:
            cand_probs = curr_probs[cand_indices]
            sorted_idx = torch.argsort(cand_probs, descending=True)
            top_cand_indices = cand_indices[sorted_idx[:10]]

            for idx in top_cand_indices:
                token_id = idx.item()
                prob = curr_probs[idx].item()
                try:
                    text = tokenizer.decode([token_id])
                except Exception:
                    text = f"[tok_{token_id}]"

                curr_node.candidates.append({
                    "text": text,
                    "prob": round(prob, 4),
                    "token_id": token_id
                })

        # Determine branches using dynamic threshold
        branch_indices = torch.where(curr_probs > current_threshold)[0]

        # Safety: always take at least 1 (greedy)
        if len(branch_indices) == 0:
            _, branch_indices = torch.topk(curr_probs, 1)

        # Sort by probability
        branch_probs = curr_probs[branch_indices]
        sorted_idx = torch.argsort(branch_probs, descending=True)
        final_branches = branch_indices[sorted_idx]

        # Progress logging
        if verbose and curr_node.depth % 50 == 0 and curr_node.depth > 0:
            print(f"Depth {curr_node.depth}: {len(final_branches)} branches, "
                  f"threshold={current_threshold:.3f}, queue={curr_q_len}, "
                  f"completed={stats['completed_paths']}/{max_completed_paths}")

        # Expand branches
        for idx in final_branches:
            if stats["completed_paths"] >= max_completed_paths:
                break

            token_id = idx.item()
            token_prob = curr_probs[idx].item()

            try:
                token_text = tokenizer.decode([token_id])
            except Exception:
                token_text = f"[tok_{token_id}]"

            cumulative_prob = curr_node.cumulative_prob * token_prob

            child_node = TokenNode(
                text=token_text,
                prob=token_prob,
                token_id=token_id,
                depth=curr_node.depth + 1,
                cumulative_prob=cumulative_prob
            )
            child_node.full_text = parent_text + token_text

            curr_node.children.append(child_node)
            stats["total_nodes"] += 1
            stats["branches_created"] += 1

            # Clone cache (CRITICAL for independent branches)
            child_cache = clone_kv_cache(curr_cache)

            # Forward pass
            next_input_id = torch.tensor([[token_id]], device=device)

            with torch.no_grad():
                outputs = model(
                    next_input_id,
                    past_key_values=child_cache,
                    use_cache=True
                )
                child_logits = outputs.logits[0, -1, :]
                child_probs = F.softmax(child_logits, dim=-1)
                child_cache = outputs.past_key_values

            queue.append((child_node, child_cache, child_probs, child_node.full_text))

        # Cleanup
        del curr_cache, curr_probs

        if stats["total_nodes"] % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final cleanup
    del queue
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if verbose:
        print(f"\n{'='*80}")
        print(f"✅ Tree Complete!")
        print(f"   Total nodes: {stats['total_nodes']}")
        print(f"   Branches created: {stats['branches_created']}")
        print(f"   Completed paths: {stats['completed_paths']}")
        print(f"   Max depth reached: {stats['max_depth_reached']}")
        print(f"{'='*80}\n")

    return root
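
The threshold schedule used inside `build_token_tree` can be examined on its own. The sketch below reproduces the same linear interpolation as a standalone helper and applies it to a toy distribution. The names `dynamic_threshold` and `select_branches` and the probabilities in `toy_probs` are hypothetical, introduced only for illustration, and plain Python lists stand in for the tensors the notebook uses.

```python
def dynamic_threshold(queue_len, min_threshold=0.1, max_threshold=0.5, soft_queue_limit=30):
    """Linearly interpolate the branching threshold from queue saturation."""
    saturation = min(1.0, queue_len / soft_queue_limit)
    return min_threshold + saturation * (max_threshold - min_threshold)


def select_branches(probs, threshold):
    """Return token indices whose probability exceeds the threshold, most likely first.

    Falls back to the single most likely token (greedy) when nothing qualifies.
    """
    picked = [i for i, p in enumerate(probs) if p > threshold]
    if not picked:
        picked = [max(range(len(probs)), key=lambda i: probs[i])]
    return sorted(picked, key=lambda i: probs[i], reverse=True)


# Made-up next-token distribution (assumption, not model output)
toy_probs = [0.45, 0.32, 0.13, 0.06, 0.04]

for queue_len in (1, 15, 30):
    t = dynamic_threshold(queue_len)
    print(queue_len, round(t, 3), select_branches(toy_probs, t))
```

With these made-up numbers, three tokens qualify when the queue is nearly empty, two at half saturation, and only the greedy fallback remains once the queue reaches the soft limit.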

## Tree Serialization
Define helper functions that save the generated token trees to JSON for later analysis and visualization, and that print a summary of each tree.

In [6]:
from pathlib import Path

def save_tree_to_json(root_node, filename, prompt, params, output_dir="trees"):
    """Save tree structure to JSON file."""
    Path(output_dir).mkdir(exist_ok=True)
    filepath = Path(output_dir) / filename

    data = {
        "metadata": {
            "prompt": prompt,
            "parameters": params
        },
        "tree": root_node.to_dict()
    }
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)

    print(f"💾 Saved to: {filepath}")
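
A saved file can be read back and inspected without the model. The sketch below assumes the layout written by `save_tree_to_json` (a `metadata` object plus a nested `tree` of node dicts) and walks it iteratively. The helper name `summarize_tree` and the tiny hand-written `example` dict are hypothetical stand-ins, not part of the notebook.

```python
import json
from collections import deque


def summarize_tree(tree):
    """Walk a tree dict (as produced by TokenNode.to_dict) without recursion."""
    total, finished, max_depth = 0, 0, 0
    stack = deque([tree])
    while stack:
        node = stack.pop()
        total += 1
        finished += bool(node.get("is_finished"))
        max_depth = max(max_depth, node.get("depth", 0))
        stack.extend(node.get("children", []))
    return {"total_nodes": total, "finished_paths": finished, "max_depth": max_depth}


# Tiny hand-written stand-in for a saved file; real files come from save_tree_to_json.
example = {
    "metadata": {"prompt": "Hi", "parameters": {}},
    "tree": {
        "text": "Hi", "depth": 0, "is_finished": False,
        "children": [
            {"text": " there", "depth": 1, "is_finished": True, "children": []},
            {"text": "!", "depth": 1, "is_finished": False, "children": [
                {"text": "<eos>", "depth": 2, "is_finished": True, "children": []},
            ]},
        ],
    },
}

# Round-trip through JSON text to mimic reading the file back from disk.
loaded = json.loads(json.dumps(example))
print(summarize_tree(loaded["tree"]))
```

An explicit stack is used here so that the walk itself does not depend on the interpreter's recursion limit.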

In [7]:
def get_all_complete_paths(node, current_path="", paths=None):
    """Extract all completed generation paths."""
    if paths is None:
        paths = []

    if node.depth > 0:
        current_path += node.text

    if node.is_finished:
        paths.append({
            "text": current_path,
            "cumulative_prob": node.cumulative_prob,
            "depth": node.depth
        })

    for child in node.children:
        get_all_complete_paths(child, current_path, paths)

    return paths

In [8]:
def print_tree_summary(root):
    """Print summary of generated paths."""
    paths = get_all_complete_paths(root)

    print(f"\n{'='*80}")
    print(f"📊 GENERATION SUMMARY")
    print(f"{'='*80}")
    print(f"Total completed paths: {len(paths)}\n")

    paths.sort(key=lambda x: x['cumulative_prob'], reverse=True)

    for i, path in enumerate(paths, 1):
        prob_pct = path['cumulative_prob'] * 100
        print(f"\n{i}. Path (Prob: {prob_pct:.4f}%, Depth: {path['depth']} tokens)")
        print(f"{'─'*80}")
        # Show first 500 chars
        display_text = path['text'][:500]
        if len(path['text']) > 500:
            display_text += "..."
        print(display_text)
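
Because `cumulative_prob` is a product of per-token probabilities and `to_dict` rounds it to 8 decimal places, long paths can end up stored as 0. One common workaround, shown here as an assumption rather than something the notebook does, is to sum log-probabilities instead. The helper names and the made-up 200-token path below are hypothetical.

```python
import math


def path_log_prob(token_probs):
    """Sum of natural-log token probabilities; equals the log of their product."""
    return sum(math.log(p) for p in token_probs)


def mean_log_prob(token_probs):
    """Length-normalized score, so paths of different depth can be compared."""
    return path_log_prob(token_probs) / len(token_probs)


# Made-up per-token probabilities for a 200-token path (assumption, not model output)
probs = [0.9] * 200

product = math.prod(probs)      # direct product, as cumulative_prob is computed
log_prob = path_log_prob(probs)

print(f"product: {product:.3e}")
print(f"rounded to 8 decimals: {round(product, 8)}")
print(f"log-prob: {log_prob:.3f}")
print(f"mean log-prob per token: {mean_log_prob(probs):.4f}")
```

The rounded product collapses to zero for this path even though every token is highly likely, while the log-probability keeps the paths distinguishable.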

## Experiment: Generate Token Trees
Build and save token trees for a range of prompts. Each tree follows every continuation whose probability clears the current branching threshold, which shows where the model's next-token choices diverge and how confident it is at each step.

In [10]:
prompts = [
    {"prompt": "The secret to happiness is", "file": "happiness_secret_tree.json"},
    {"prompt": "The capital of India is", "file": "india_capital_tree.json"},
    {"prompt": "Roses are red, violets are", "file": "roses_violets_tree.json"},
    {"prompt": "The most important quality in a leader is", "file": "leader_quality_tree.json"},
    {"prompt": "To sort the list of numbers efficiently, the best algorithm to use is", "file": "sort_algorithm_tree.json"},
    {"prompt": "As the alien spaceship landed on the White House lawn, the President decided to", "file": "president_decision_tree.json"},
    {"prompt": "If I'm being completely honest, I think", "file": "honest_opinion_tree.json"},
    {"prompt": "She looked at the scales and saw", "file": "scales_reading_tree.json"},
]

params = {
    "min_branch_threshold": 0.1,    # Explore alternatives > 10%
    "max_branch_threshold": 0.5,    # Cap at 50% when queue is full
    "soft_queue_limit": 30,         # Target ~30 active branches
    "candidate_threshold": 0.01,    # Record alternatives > 1%
    "max_depth": 1000,              # Allow long sequences
    "max_completed_paths": 5        # Generate 5 complete paths
}

# Build trees
for p in prompts:
    root = build_token_tree(model, tokenizer, p["prompt"], **params)
    save_tree_to_json(root, p["file"], p["prompt"], params)
    print_tree_summary(root)
    print("\n")


🌳 Building tree for: 'The secret to happiness is'
📊 Params: branch [0.10-0.50], queue_limit=30, max_paths=5

Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
Depth 50: 1 branches, threshold=0.500, queue=30, completed=0/5
De