### But first — something cool: seeing how "model inference" really works, token by token, with Hugging Face models

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from typing import List, Dict
import math
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class TokenPredictor:
    """Greedy, token-by-token generator that records the top-k candidates per step.

    Wraps a Hugging Face causal language model so that the probability of each
    chosen token, and of the runner-up tokens it beat, can be inspected or plotted.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model weights for ``model_name``.

        Args:
            model_name: Hugging Face Hub id (or local path) of a causal LM,
                e.g. ``"meta-llama/Llama-3.2-1B-Instruct"``.
        """
        self.model_name = model_name
        # Preferred device; also decides the dtype below (fp16 on GPU, fp32 on CPU).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
        )

        # device_map="auto" delegates weight placement to Hugging Face, so the
        # device that inputs must be sent to is read from self.model.device.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
        ).eval()

    def _stop_token_ids(self) -> set:
        """Return the set of token ids that end generation for this model.

        Combines the tokenizer's ``eos_token_id`` with the model's
        ``generation_config.eos_token_id``. Either may be ``None``, a single id,
        or a list of ids (instruct models often define several).
        """
        stop_ids = set()
        generation_config = getattr(self.model, "generation_config", None)
        for candidate in (
            getattr(self.tokenizer, "eos_token_id", None),
            getattr(generation_config, "eos_token_id", None),
        ):
            if candidate is None:
                continue
            if isinstance(candidate, int):
                stop_ids.add(candidate)
            else:
                stop_ids.update(candidate)
        return stop_ids

    def predict_tokens(self, prompt: str, max_tokens: int = 100, top_k: int = 3) -> List[Dict]:
        """Greedily generate up to ``max_tokens`` tokens, recording top-k candidates.

        At every step the most likely next token is appended to the sequence
        (greedy decoding, no sampling). Generation stops early once an
        end-of-sequence token is produced; that token is still recorded.

        Args:
            prompt: Raw text passed to the tokenizer (no chat template is applied).
            max_tokens: Maximum number of tokens to generate.
            top_k: Candidates recorded per step: the chosen token plus
                ``top_k - 1`` alternatives. Must be >= 1.

        Returns:
            One dict per generated token with keys:
              ``"token"``: decoded text of the chosen token,
              ``"probability"``: its softmax probability (0..1),
              ``"alternatives"``: list of ``(token_text, probability)`` for the
              runner-up candidates, most likely first.

        Raises:
            ValueError: if ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        predictions: List[Dict] = []

        # Follow the model's actual placement rather than the preferred device.
        input_device = self.model.device
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(input_device)
        stop_ids = self._stop_token_ids()
        past_key_values = None

        with torch.no_grad():
            for _ in range(max_tokens):
                # With the KV cache only unseen tokens are fed in: the whole
                # prompt on the first step, then just the newly chosen token.
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = outputs.past_key_values

                # Distribution over the token following the last position;
                # computed in fp32 to avoid fp16 precision loss.
                probs = F.softmax(outputs.logits[:, -1, :].float(), dim=-1)
                top_probs, top_ids = torch.topk(probs, k=top_k)

                top_probs = top_probs[0].tolist()
                top_ids = top_ids[0].tolist()
                tokens = [self.tokenizer.decode(t) for t in top_ids]

                predictions.append({
                    "token": tokens[0],
                    "probability": top_probs[0],
                    "alternatives": list(zip(tokens[1:], top_probs[1:])),
                })

                next_id = top_ids[0]
                if next_id in stop_ids:
                    break
                input_ids = torch.tensor([[next_id]], device=input_device)

        return predictions


def create_token_graph(model_name: str, predictions: List[Dict]) -> nx.DiGraph:
    """Build a directed graph of the greedy token chain and its alternatives.

    Node ids:
      * ``"START"``: root node, labelled with ``model_name``.
      * ``"t{i}"``: the chosen token at step ``i``, chained START -> t0 -> t1 -> ...
      * ``"t{i}_alt{j}"``: the j-th runner-up at step ``i``. It is attached to
        the same parent as ``"t{i}"``, i.e. it is a sibling of the chosen token.
      * ``"END"``: terminal node, linked from the last chosen token (or from
        START when there are no predictions).

    Every node carries ``token``, ``prob`` (display string), ``color`` and
    ``size`` attributes.

    Args:
        model_name: Label of the START node.
        predictions: Dicts with ``"token"``, ``"probability"`` (0..1) and
            ``"alternatives"`` (list of ``(token, probability)`` pairs).

    Returns:
        The populated ``nx.DiGraph``.
    """
    G = nx.DiGraph()

    G.add_node("START", token=model_name, prob="START", color="lightgreen", size=4000)

    parent_id = "START"
    for i, pred in enumerate(predictions):
        token_id = f"t{i}"
        G.add_node(
            token_id,
            token=pred["token"],
            prob=f"{pred['probability'] * 100:.1f}%",
            color="lightblue",
            size=6000,
        )
        G.add_edge(parent_id, token_id)

        # Alternatives competed with t{i} for the same slot, so they share its parent.
        for j, (alt_token, alt_prob) in enumerate(pred["alternatives"]):
            alt_id = f"t{i}_alt{j}"
            G.add_node(
                alt_id,
                token=alt_token,
                prob=f"{alt_prob * 100:.1f}%",
                color="lightgray",
                size=6000,
            )
            G.add_edge(parent_id, alt_id)

        parent_id = token_id

    # END always follows the final chosen token (previously it was attached to the
    # second-to-last token, and omitted entirely when there were no alternatives).
    G.add_node("END", token="END", prob="100%", color="red", size=6000)
    G.add_edge(parent_id, "END")

    return G


def visualize_predictions(
    G: nx.DiGraph,
    figsize=(14, 80),
    title: str = "Token prediction (LLaMA 3.2 – Hugging Face)",
):
    """Draw a token-prediction graph top-to-bottom with alternatives beside it.

    Nodes whose id does not contain ``"_alt"`` form the main column and are
    stacked vertically in graph insertion order. Each alternative
    ``"t{i}_alt{j}"`` is placed on the same row as ``"t{i}"``. Alternatives
    alternate left/right and move further out as ``j`` grows: alt0 left,
    alt1 right, alt2 further left, and so on.

    Args:
        G: Graph whose nodes carry ``token``, ``prob``, ``color`` and ``size``
            attributes.
        figsize: Matplotlib figure size in inches.
        title: Figure title.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can call ``.show()`` or
        ``.savefig()`` on it.
    """
    plt.figure(figsize=figsize)

    pos = {}
    spacing_y = 5
    spacing_x = 5

    main_nodes = [n for n in G.nodes() if "_alt" not in n]
    for i, node in enumerate(main_nodes):
        pos[node] = (0, -i * spacing_y)

    for node in G.nodes():
        if "_alt" not in node:
            continue
        main_token, alt_suffix = node.split("_alt", 1)
        alt_num = int(alt_suffix)
        if main_token not in pos:
            continue
        # Fan out: even indices to the left, odd to the right, one column
        # further out for every pair, so more than two alternatives don't overlap.
        side = -1 if alt_num % 2 == 0 else 1
        x_offset = side * spacing_x * (alt_num // 2 + 1)
        pos[node] = (x_offset, pos[main_token][1] + 0.05)

    node_colors = [G.nodes[node]["color"] for node in G.nodes()]
    node_sizes = [G.nodes[node]["size"] for node in G.nodes()]

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes)
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20, alpha=0.7)

    labels = {
        node: f"{G.nodes[node]['token']}\n{G.nodes[node]['prob']}"
        for node in G.nodes()
    }
    nx.draw_networkx_labels(G, pos, labels, font_size=14)

    plt.title(title)
    plt.axis("off")

    margin = 8
    x_vals = [x for x, _ in pos.values()]
    y_vals = [y for _, y in pos.values()]
    plt.xlim(min(x_vals) - margin, max(x_vals) + margin)
    plt.ylim(min(y_vals) - margin, max(y_vals) + margin)

    return plt


In [None]:
# NOTE: predict_tokens() tokenizes this text as-is (no chat template), so the
# instruct model continues the text rather than answering it as a chat turn.
message = "In one sentence, describe the color orange to someone who has never been able to see"

model_name = "meta-llama/Llama-3.2-1B-Instruct"

predictor = TokenPredictor(model_name)
predictions = predictor.predict_tokens(message, max_tokens=30)

G = create_token_graph(model_name, predictions)
# visualize_predictions() returns the pyplot module; call show() on it directly
# instead of rebinding the global name `plt` that the plotting code relies on.
visualize_predictions(G).show()

In [None]:
import gc

# Drop references to the model and plotting data so their memory can be reclaimed.
# globals().pop(..., None) instead of `del` keeps this cell safe to re-run.
# `plt` is deliberately NOT removed: it is the matplotlib.pyplot module that
# visualize_predictions() uses, so deleting it breaks re-running the demo cells.
for _name in ("predictor", "predictions", "G"):
    globals().pop(_name, None)
del _name
plt.close("all")  # open figures keep references to everything drawn on them
gc.collect()

# torch.cuda.ipc_collect() initialises CUDA and may raise on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

In [None]:
# Report CUDA allocator statistics; skip on CPU-only machines, where querying
# CUDA state may raise instead of reporting.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory to report.")