# Token Tree Analysis with Gemma
This notebook demonstrates how to build and analyze a token decision tree from the Gemma 2 model. Starting from a prompt, the tree follows every next token whose probability exceeds a branching threshold and records the alternatives that exceed a lower candidate threshold, which shows where the model's token selection is confident and where it is split.

## Installation
Install necessary libraries for model inference, data manipulation, and visualization.

In [1]:
!pip install torch transformers pandas accelerate numpy huggingface-hub networkx matplotlib



## Authentication
Log in to Hugging Face to access the Gemma 2 model.

In [2]:
import os
from huggingface_hub import login

try:
    # Try to retrieve the token from Colab secrets.
    # The import lives inside the try block so the fallback also works outside Colab.
    from google.colab import userdata
    token = userdata.get('HF_TOKEN')
except Exception:
    # Not running in Colab, or the secret is unavailable: fall back to the environment
    token = os.environ.get('HF_TOKEN')

if token:
    login(token=token)
else:
    print("HF_TOKEN not found. Please log in manually.")
    login()

## Imports
Import required libraries for tensor manipulation, model handling, and data structures.

In [3]:
import json
import torch
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

## Model Loading
Load the Gemma 2 model and tokenizer. Ensure you have access to the model on Hugging Face.

In [4]:
model_id = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

`torch_dtype` is deprecated! Use `dtype` instead!


## Token Tree Building
Define the `TokenNode` class and `build_token_tree` function to explore multiple token generation paths based on probability thresholds. This creates a tree structure representing the model's decision space.

In [5]:
import copy
import gc
import sys
from collections import deque

import torch
import torch.nn.functional as F

# to_dict() recurses once per tree level, so leave headroom for deep trees
sys.setrecursionlimit(10_000)

class TokenNode:
    def __init__(self, text, prob, token_id, depth):
        self.text = text
        self.prob = prob
        self.token_id = token_id
        self.depth = depth
        self.children = []      # tokens that were followed from this node
        self.candidates = []    # alternatives above candidate_threshold
        self.is_finished = False

    def to_dict(self):
        return {
            "text": self.text,
            "prob": self.prob,
            "token_id": self.token_id,
            "depth": self.depth,
            "is_finished": self.is_finished,
            "candidates": self.candidates,
            "children": [child.to_dict() for child in self.children]
        }

def build_token_tree(
    model, tokenizer, prompt,
    branching_threshold=0.2,
    candidate_threshold=0.1,
    max_depth=5,
    max_completed_paths=3
):
    model.eval()
    device = model.device
    torch.cuda.empty_cache()

    # Token IDs that end a path: EOS plus Gemma's <end_of_turn> (ID 107)
    stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")}

    # Initialize root and cache
    initial_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    initial_input_ids = initial_inputs.input_ids

    with torch.no_grad():
        outputs = model(initial_input_ids, use_cache=True)
        next_token_logits = outputs.logits[0, -1, :]
        root_past_key_values = outputs.past_key_values
        root_probs = F.softmax(next_token_logits, dim=-1)

    root = TokenNode(prompt, 1.0, None, 0)
    state = {"completed_paths": 0, "nodes_count": 0}

    # Queue stores tuples: (Node, PastKeyValues, NextTokenProbs)
    queue = deque([(root, root_past_key_values, root_probs)])

    print(f"Building tree starting from: '{prompt}'...")

    while queue:
        if state["completed_paths"] >= max_completed_paths:
            break

        curr_node, curr_cache, curr_probs = queue.popleft()

        # Termination checks for the current node.
        # The stop-token check comes first so a path that ends exactly at
        # max_depth is still counted as finished.
        if curr_node.token_id in stop_ids:
            curr_node.is_finished = True
            state["completed_paths"] += 1
            del curr_cache, curr_probs
            continue

        if curr_node.depth >= max_depth:
            del curr_cache, curr_probs  # Free memory immediately
            continue

        # Record candidates, using the probs computed in the PREVIOUS step (stored in queue)
        cand_indices = torch.where(curr_probs > candidate_threshold)[0]
        cand_probs = curr_probs[cand_indices]
        sorted_indices = torch.argsort(cand_probs, descending=True)
        final_cand_indices = cand_indices[sorted_indices]

        for idx in final_cand_indices:
            curr_node.candidates.append({
                "text": tokenizer.decode(idx.item()),
                "prob": curr_probs[idx].item(),
                "token_id": idx.item()
            })

        # Determine branches
        branch_indices = torch.where(curr_probs > branching_threshold)[0]

        # If no token clears the threshold, follow the single most probable token
        if len(branch_indices) == 0:
            _, branch_indices = torch.topk(curr_probs, 1)

        # Sort branches by probability
        branch_probs = curr_probs[branch_indices]
        branch_sorted = torch.argsort(branch_probs, descending=True)
        final_branches = branch_indices[branch_sorted]

        for idx in final_branches:
            # Stop adding to queue if we already found enough paths
            if state["completed_paths"] >= max_completed_paths:
                break

            token_id = idx.item()
            token_prob = curr_probs[idx].item()
            token_text = tokenizer.decode(token_id)

            child_node = TokenNode(token_text, token_prob, token_id, curr_node.depth + 1)
            curr_node.children.append(child_node)
            state["nodes_count"] += 1

            # Run one step on a private copy of the parent's cache. Cache objects
            # can be extended in place by the forward pass, so sharing one cache
            # between sibling branches would leak tokens from one branch into another.
            next_input_id = torch.tensor([[token_id]], device=device)
            with torch.no_grad():
                outputs = model(
                    next_input_id,
                    past_key_values=copy.deepcopy(curr_cache),
                    use_cache=True
                )
                child_logits = outputs.logits[0, -1, :]
                child_probs = F.softmax(child_logits, dim=-1)
                child_cache = outputs.past_key_values

            # Add child to queue
            queue.append((child_node, child_cache, child_probs))

        del curr_cache
        del curr_probs
        if len(queue) % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    # Final cleanup
    del queue
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Done. Nodes: {state['nodes_count']}, Completed Paths: {state['completed_paths']}")
    return root
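
The two thresholds are easier to see on a small example. The sketch below applies the same selection rule as `build_token_tree` to a plain dictionary instead of a tensor. The helper name `select_tokens` and the probabilities are hypothetical, made up for illustration, and are not model output.

```python
def select_tokens(probs, branching_threshold=0.2, candidate_threshold=0.1):
    """Return (candidates, branches) for one next-token distribution.

    `probs` maps token text to probability. Candidates are recorded only;
    branches are the tokens that would be expanded into child nodes.
    """
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    candidates = [(tok, p) for tok, p in ranked if p > candidate_threshold]
    branches = [(tok, p) for tok, p in ranked if p > branching_threshold]
    if not branches:
        # Same fallback as build_token_tree: follow the most probable token
        branches = ranked[:1]
    return candidates, branches

# Hypothetical distribution, not taken from Gemma
toy = {" New": 0.62, " Delhi": 0.15, " the": 0.12, " a": 0.06, " located": 0.05}
candidates, branches = select_tokens(toy)
print("candidates:", candidates)
print("branches:  ", branches)
```

Here three tokens clear the 10% candidate threshold, but only one clears the 20% branching threshold, so the node would get three recorded candidates and a single child.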

## Tree Serialization
Define a function to save the generated token trees to JSON format for further analysis and visualization.

In [6]:
def save_tree_to_json(root_node, filename, prompt, params):
    data = {
        "metadata": {
            "prompt": prompt,
            "parameters": params
        },
        "tree": root_node.to_dict()
    }
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)
    print(f"Tree saved to {filename}")
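
Once a tree is on disk, it can be read back and flattened into root-to-leaf paths. The following is a minimal sketch that assumes the JSON layout written by `save_tree_to_json`. The helpers `load_tree`, `iter_paths`, and `toy_node` are hypothetical names, and the toy tree uses invented probabilities so the example runs without a model.

```python
import json

def load_tree(filename):
    """Read a file written by save_tree_to_json and return (metadata, root dict)."""
    with open(filename, encoding="utf-8") as f:
        data = json.load(f)
    return data["metadata"], data["tree"]

def iter_paths(node, prefix="", joint_prob=1.0):
    """Yield (text, joint probability, is_finished) for each root-to-leaf path."""
    text = prefix + node["text"]
    joint = joint_prob * node["prob"]
    if not node["children"]:
        yield text, joint, node["is_finished"]
        return
    for child in node["children"]:
        yield from iter_paths(child, text, joint)

def toy_node(text, prob, children=(), is_finished=False):
    # Hypothetical helper: a dict with the same keys as TokenNode.to_dict().
    # token_id and depth are placeholders here.
    return {"text": text, "prob": prob, "token_id": None, "depth": 0,
            "is_finished": is_finished, "candidates": [], "children": list(children)}

# Invented tree, not model output
toy_tree = toy_node("The capital of India is", 1.0, [
    toy_node(" New", 0.5, [toy_node(" Delhi", 0.5)]),
    toy_node(" Delhi", 0.25),
])

paths = list(iter_paths(toy_tree))
for text, prob, finished in paths:
    print(f"{prob:.3f}  finished={finished}  {text!r}")
```

For a saved file, `metadata, tree = load_tree("india_capital_tree.json")` followed by `iter_paths(tree)` would walk the real tree in the same way.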

## Experiment: Generate Token Trees
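Build and save a token tree for each prompt below. Each tree records which tokens were followed and which alternatives cleared the candidate threshold, so the model's token-level choices can be inspected afterwards. With `branching_threshold` set to 0.5, at most one token can exceed the threshold at any step, so each tree in this run is a single greedy path annotated with its alternatives.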
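
The branching threshold also bounds how wide a tree can get: if k tokens each have probability strictly above a threshold t, then k·t < 1. The sketch below computes that bound. The function name `max_branches` is hypothetical, and the bound assumes the probabilities sum to exactly one, which low-precision dtypes such as bfloat16 only approximate.

```python
import math
from fractions import Fraction

def max_branches(threshold):
    """Upper bound on how many tokens can have probability strictly above `threshold`."""
    t = Fraction(str(threshold))  # exact arithmetic avoids float rounding at 1/t
    return math.ceil(1 / t) - 1

for t in (0.5, 0.2, 0.1):
    print(t, max_branches(t))
# 0.5 1
# 0.2 4
# 0.1 9
```

Because `build_token_tree` falls back to the most probable token when nothing clears the threshold, every expanded node has at least one child, and at most `max(1, max_branches(branching_threshold))` children.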

In [7]:
prompts = [
    {"prompt": "The capital of India is", "file": "india_capital_tree.json"},
    {"prompt": "Roses are red, violets are", "file": "roses_violets_tree.json"},
    {"prompt": "The most important quality in a leader is", "file": "leader_quality_tree.json"},
    {"prompt": "The secret to happiness is", "file": "happiness_secret_tree.json"},
    {"prompt": "To sort the list of numbers efficiently, the best algorithm to use is", "file": "sort_algorithm_tree.json"},
    {"prompt": "As the alien spaceship landed on the White House lawn, the President decided to", "file": "president_decision_tree.json"},
    {"prompt": "If I'm being completely honest, I think", "file": "honest_opinion_tree.json"},
]

params = {
    "branching_threshold": 0.5,  # Follow paths > 50%
    "candidate_threshold": 0.10, # Record alternatives > 10%
    "max_depth": 50,
    "max_completed_paths": 2
}

for p in prompts:
    root = build_token_tree(model, tokenizer, p["prompt"], **params)
    save_tree_to_json(root, p["file"], p["prompt"], params)

Building tree starting from: 'The capital of India is'...
Done. Nodes: 13, Completed Paths: 1
Tree saved to india_capital_tree.json
Building tree starting from: 'Roses are red, violets are'...
Done. Nodes: 50, Completed Paths: 0
Tree saved to roses_violets_tree.json
Building tree starting from: 'The most important quality in a leader is'...
Done. Nodes: 50, Completed Paths: 0
Tree saved to leader_quality_tree.json
Building tree starting from: 'The secret to happiness is'...
Done. Nodes: 50, Completed Paths: 0
Tree saved to happiness_secret_tree.json
Building tree starting from: 'To sort the list of numbers efficiently, the best algorithm to use is'...
Done. Nodes: 50, Completed Paths: 0
Tree saved to sort_algorithm_tree.json
Building tree starting from: 'As the alien spaceship landed on the White House lawn, the President decided to'...
Done. Nodes: 50, Completed Paths: 0
Tree saved to president_decision_tree.json
Building tree starting from: 'If I'm being completely honest, I think'..