In [None]:
from typing import Any, Union

import plotly.graph_objects as go
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
def load_tokenizer(model_id: str, trust_remote_code: bool = False) -> AutoTokenizer:
    """
    Fetch the tokenizer that belongs to a Hugging Face model.

    Args:
        model_id (str): Model identifier passed to `AutoTokenizer.from_pretrained`.
        trust_remote_code (bool, optional): Forwarded unchanged to
            `AutoTokenizer.from_pretrained`. Defaults to False.

    Returns:
        AutoTokenizer: Whatever `AutoTokenizer.from_pretrained` returns for `model_id`.

    Example:
        >>> tokenizer = load_tokenizer("meta-llama/Llama-3.1-8B-Instruct")
    """
    # Thin wrapper: delegate directly and hand the result back to the caller
    return AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)


def load_model(
    model_name: str,
    dtype: Union[str, torch.dtype] = "bfloat16",
    device_map: Union[str, None] = "auto",
    low_cpu_mem_usage: bool = True,
    attn_implementation: str = "eager",
    trust_remote_code: bool = False
) -> AutoModelForCausalLM:
    """
    Load a causal language model from Hugging Face and switch it to evaluation mode.

    Args:
        model_name (str): The name of the model to load.
        dtype (Union[str, torch.dtype], optional): The data type to load the model with.
            A string other than "auto" is looked up as an attribute of `torch`
            (e.g. "bfloat16" -> torch.bfloat16). "auto" and torch.dtype objects are
            passed through unchanged. Defaults to "bfloat16".
        device_map (Union[str, None], optional): The device map to use for loading the model. Defaults to "auto".
        low_cpu_mem_usage (bool, optional): Whether to use low CPU memory usage. Defaults to True.
        attn_implementation (str, optional): The attention implementation to use. Defaults to "eager".
        trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.

    Returns:
        AutoModelForCausalLM: The loaded model, after `.eval()` has been called on it.

    Raises:
        ValueError: If `dtype` is a string that does not name a torch dtype.

    Example:
        >>> model = load_model("meta-llama/Llama-3.1-8B-Instruct")
    """
    # Convert a dtype string to the matching torch dtype ("auto" is forwarded as-is)
    if isinstance(dtype, str) and dtype != "auto":
        resolved_dtype = getattr(torch, dtype, None)
        # `torch` exposes many non-dtype attributes (e.g. "nn", "cuda"), so a plain
        # getattr is not enough: make sure the lookup really produced a dtype.
        if not isinstance(resolved_dtype, torch.dtype):
            raise ValueError(
                f"Unknown dtype string {dtype!r}: expected 'auto' or the name of a "
                "torch dtype such as 'bfloat16' or 'float16'."
            )
        dtype = resolved_dtype

    # Load the model from the Hugging Face model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device_map,
        low_cpu_mem_usage=low_cpu_mem_usage,
        attn_implementation=attn_implementation,
        trust_remote_code=trust_remote_code
    )

    # Set the model to evaluation mode before returning it
    return model.eval()


In [None]:
# Visualize attention maps of multiple decoder blocks using Plotly's slider
def visualize_attention_with_slider(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    input_ids: torch.Tensor,  # (batch_size, seq_len)
    head_idx: Union[int, str, None] = "mean",
    zmin: float = 0.0,
    zmax: float = 1.0,
    print_prompt: bool = False
) -> None:
    """
    Visualize attention maps of multiple decoder blocks using Plotly's animation (frames + slider) feature.
    This function is designed for decoder-only models (e.g., GPT, LLaMA).

    A single forward pass is run with `output_attentions=True` and `use_cache=False`;
    every returned layer becomes one heatmap frame (x: key tokens, y: query tokens).
    The figure is displayed with `fig.show()`; nothing is returned.

    Args:
        model (AutoModelForCausalLM): Pre-loaded model.
        tokenizer (AutoTokenizer): Pre-loaded tokenizer.
        input_ids (torch.Tensor): Input IDs of shape (batch_size, seq_len). Requires batch_size=1.
        head_idx (Union[int, str, None]):
            - int: Visualize attention of the specified head.
            - "mean"/None: Visualize the mean attention of all heads.
        zmin (float): Minimum value for the color scale.
        zmax (float): Maximum value for the color scale.
        print_prompt (bool): If True, prints the input prompt.

    Raises:
        ValueError: If input_ids is not 2-D with batch_size=1, if head_idx is invalid
            (wrong type or out of range), or if the model returns no attention weights.

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained('gpt2')
        >>> tokenizer = AutoTokenizer.from_pretrained('gpt2')
        >>> input_ids = tokenizer.encode("Hello, world!", return_tensors='pt')
        >>> visualize_attention_with_slider(model, tokenizer, input_ids)
    """
    # Fail fast on malformed arguments, before the (expensive) forward pass.
    # A 1-D tensor of length 1 would otherwise slip through the batch-size check.
    if input_ids.ndim != 2:
        raise ValueError(
            f"input_ids must have shape (1, seq_len), but got shape {tuple(input_ids.shape)}."
        )
    if input_ids.shape[0] != 1:
        raise ValueError("This implementation only supports batch_size=1.")

    # Anything other than None/"mean" must be usable as an integer index. Strings,
    # bools and objects without `__index__` (floats, lists, ...) are rejected here so
    # that callers get the documented ValueError instead of a TypeError later on.
    use_mean = head_idx is None or head_idx == "mean"
    if not use_mean and (
        isinstance(head_idx, (str, bool)) or not hasattr(head_idx, "__index__")
    ):
        raise ValueError(f"head_idx={head_idx!r} is not a valid head index.")

    if print_prompt:
        print(f"Input prompt: \n{tokenizer.batch_decode(input_ids)[0]}")

    # Create labels for the axes (format: "i:token"); the index prefix keeps
    # repeated tokens from collapsing into a single axis category
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    indexed_tokens = [f"{i}:{tok}" for i, tok in enumerate(tokens)]

    # Perform inference and get attentions from all layers
    with torch.no_grad():
        outputs = model(input_ids=input_ids, output_attentions=True, use_cache=False)
    all_attentions = outputs.attentions  # sequence: [layer_0, layer_1, ..., layer_(N-1)]

    # Some attention backends may not return weights at all (missing or None entries);
    # report that clearly instead of crashing while indexing.
    if not all_attentions or any(layer_attn is None for layer_attn in all_attentions):
        raise ValueError(
            "The model returned no attention weights. Load the model with "
            "attn_implementation='eager' so that output_attentions=True is honored."
        )

    # The title suffix does not depend on the layer, so build it once
    title_suffix = " (mean of all heads)" if use_mean else f" (head {head_idx})"

    # Create one frame per decoder block
    frames = []
    for block_idx, layer_attn in enumerate(all_attentions):
        # layer_attn is expected to be (batch_size=1, n_heads, seq_len, seq_len);
        # dropping the batch dimension leaves (n_heads, seq_len, seq_len)
        attn_block = layer_attn[0]

        # Select attention based on head_idx
        if use_mean:
            attn_selected = attn_block.mean(dim=0)
        else:
            if not (0 <= head_idx < attn_block.shape[0]):
                raise ValueError(f"head_idx={head_idx} is not a valid head index.")
            attn_selected = attn_block[head_idx]

        # Cast to float32 first: NumPy has no bfloat16, which is this notebook's default dtype
        attn_np = attn_selected.float().cpu().numpy()

        frames.append(
            go.Frame(
                data=[
                    go.Heatmap(
                        z=attn_np,
                        x=indexed_tokens,
                        y=indexed_tokens,
                        colorscale="YlGnBu",
                        zmin=zmin,
                        zmax=zmax,
                        zsmooth=False,
                        colorbar=dict(
                            title=dict(
                                text="Attention<br>Score<br>&nbsp;",
                                side="top",
                            )
                        )
                    )
                ],
                name=f"Block_{block_idx}",
                layout=go.Layout(
                    title=f"Decoder Block {block_idx} Attention {title_suffix}"
                ),
            )
        )

    # Initialize figure with the first frame
    fig = go.Figure(
        data=frames[0].data,
        layout=frames[0].layout,
        frames=frames
    )

    # Create slider steps: each step jumps immediately to the frame of one block
    steps = [
        dict(
            method="animate",
            args=[
                [frame.name],  # Frame to play
                {"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}},
            ],
            label=str(i),
        )
        for i, frame in enumerate(frames)
    ]

    # Slider configuration
    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Decoder Block: "},
            steps=steps,
            x=0.05,            # X position of the slider (0=left, 1=right)
            y=-0.30,           # Y position of the slider (negative to move below the plot)
            xanchor="left",
            yanchor="top",
            pad={"t": 50},     # Padding around the slider
        )
    ]

    # Play/Pause buttons configuration
    updatemenus = [
        dict(
            type="buttons",
            showactive=False,
            x=1.20,
            y=1.15,
            xanchor="right",
            yanchor="top",
            direction="left",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[
                        None,  # None = play all frames
                        {"frame": {"duration": 500, "redraw": True}, "fromcurrent": True, "mode": "immediate"},
                    ],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    args=[
                        [None],  # [None] = interrupt the running animation
                        {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
                    ],
                ),
            ],
        )
    ]

    # Finalize layout
    fig.update_layout(
        sliders=sliders,
        updatemenus=updatemenus,
        xaxis=dict(title="Key Tokens"),
        yaxis=dict(title="Query Tokens", autorange="reversed"),  # Reverse y-axis
        width=800,
        height=800,
        margin=dict(t=100, b=100),  # Increase top and bottom margins
    )

    fig.show()


In [None]:
# Specify the model path and device
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (much more slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model (the whole model is placed on `device`)
tokenizer = load_tokenizer(model_name)
model = load_model(model_name, device_map=device)


In [None]:
# The user prompt and tentative model response
prompt: str = "I want to go to Tokyo Tower. What station should I get off at?"
tentative_response: str = "Please get off at Akabanebashi Station on the Oedo Line."


# Prepare message list for chat template
messages: list[dict[str, str]] = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": tentative_response}
]

# Apply chat template (model-specific preprocessing).
# add_generation_prompt=False because the assistant turn is already part of `messages`.
prompt_with_chat_template: str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False
)

# Tokenize and move the whole batch to the target device in one step.
# The tokenizer output is a dict-like batch object rather than a plain dict
# (a plain dict has no `.to`), so it is left unannotated.
# add_special_tokens=False: presumably the chat template already inserted the
# special tokens - TODO confirm for other models.
inputs = tokenizer(
    text=prompt_with_chat_template, return_tensors="pt", add_special_tokens=False
).to(device)

# Already on `device` thanks to the `.to(device)` above; no second transfer needed
input_ids: torch.Tensor = inputs["input_ids"]


In [None]:
# Choose which attention head to visualize:
#   - "mean" (or None) averages the attention over all heads
#   - an integer such as 0 selects a single head
head_idx = "mean"

# Render the per-block attention maps with the decoder-block slider
visualize_attention_with_slider(
    model,
    tokenizer,
    input_ids,
    head_idx=head_idx,
    zmax=0.1,
    print_prompt=True,
)


<img src="../images/attention_map_with_decoder_slider.png" alt="Attention heatmap with a slider for selecting the decoder block" width="80%">

In [None]:
# Visualize attention maps of multiple decoder blocks using Plotly's slider (with font-size slider)
def visualize_attention_with_slider_and_fontsize(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    input_ids: torch.Tensor,  # (batch_size, seq_len)
    head_idx: Union[int, str, None] = "mean",
    zmin: float = 0.0,
    zmax: float = 1.0,
    print_prompt: bool = False
) -> None:
    """
    Visualizes attention maps of multiple decoder blocks with a slider to switch between blocks.
    Additionally, includes a slider to dynamically change the font size of axis labels.
    Assumes a decoder-only model (e.g., GPT, LLaMA).

    A single forward pass is run with `output_attentions=True` and `use_cache=False`;
    every returned layer becomes one heatmap frame (x: key tokens, y: query tokens).
    The figure is displayed with `fig.show()`; nothing is returned.

    Args:
        model (AutoModelForCausalLM): Pre-loaded model.
        tokenizer (AutoTokenizer): Pre-loaded tokenizer.
        input_ids (torch.Tensor): Input IDs of shape (batch_size, seq_len). Requires batch_size=1.
        head_idx (Union[int, str, None]):
            - int: Visualize attention of the specified head.
            - "mean"/None: Visualize mean attention of all heads.
        zmin (float): Minimum value for the color scale.
        zmax (float): Maximum value for the color scale.
        print_prompt (bool): If True, prints the input prompt before visualization.

    Raises:
        ValueError: If input_ids is not 2-D with batch_size=1, if head_idx is invalid
            (wrong type or out of range), or if the model returns no attention weights.

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained('gpt2')
        >>> tokenizer = AutoTokenizer.from_pretrained('gpt2')
        >>> input_ids = tokenizer.encode("Hello, world!", return_tensors='pt')
        >>> visualize_attention_with_slider_and_fontsize(model, tokenizer, input_ids)
    """

    # --------------------------------------------------
    # Argument validation (done before the expensive forward pass)
    # --------------------------------------------------
    # A 1-D tensor of length 1 would otherwise slip through the batch-size check.
    if input_ids.ndim != 2:
        raise ValueError(
            f"input_ids must have shape (1, seq_len), but got shape {tuple(input_ids.shape)}."
        )
    if input_ids.shape[0] != 1:
        raise ValueError("Current implementation only supports batch_size=1.")

    # Anything other than None/"mean" must be usable as an integer index. Strings,
    # bools and objects without `__index__` (floats, lists, ...) are rejected here so
    # that callers get the documented ValueError instead of a TypeError later on.
    use_mean = head_idx is None or head_idx == "mean"
    if not use_mean and (
        isinstance(head_idx, (str, bool)) or not hasattr(head_idx, "__index__")
    ):
        raise ValueError(f"head_idx={head_idx!r} is not a valid head number.")

    if print_prompt:
        print(f"Input Prompt: \n{tokenizer.decode(input_ids[0])}")

    # Create axis labels in "i:token" format to avoid overlapping categories
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    indexed_tokens = [f"{i}:{tok}" for i, tok in enumerate(tokens)]

    # --------------------------------------------------
    # Model inference to get attentions from all layers
    # --------------------------------------------------
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            output_attentions=True,
            use_cache=False  # Recommended to set False to fully visualize causal mask
        )
    all_attentions = outputs.attentions  # sequence: [layer_0, layer_1, ..., layer_(N-1)]

    # Some attention backends may not return weights at all (missing or None entries);
    # report that clearly instead of crashing while indexing.
    if not all_attentions or any(layer_attn is None for layer_attn in all_attentions):
        raise ValueError(
            "The model returned no attention weights. Load the model with "
            "attn_implementation='eager' so that output_attentions=True is honored."
        )

    # --------------------------------------------------
    # Create attention maps as frames (for animation)
    # --------------------------------------------------
    # The title suffix does not depend on the layer, so build it once
    title_suffix = " (mean of all heads)" if use_mean else f" (head {head_idx})"

    frames = []
    for block_idx, layer_attn in enumerate(all_attentions):
        # layer_attn is expected to be (batch_size=1, n_heads, seq_len, seq_len);
        # dropping the batch dimension leaves (n_heads, seq_len, seq_len)
        attn_block = layer_attn[0]

        # Select attention based on head_idx
        if use_mean:
            attn_selected = attn_block.mean(dim=0)
        else:
            if not (0 <= head_idx < attn_block.shape[0]):
                raise ValueError(f"head_idx={head_idx} is not a valid head number.")
            attn_selected = attn_block[head_idx]

        # Cast to float32 first: NumPy has no bfloat16, which is this notebook's default dtype
        attn_np = attn_selected.float().cpu().numpy()

        frames.append(
            go.Frame(
                data=[
                    go.Heatmap(
                        z=attn_np,
                        x=indexed_tokens,
                        y=indexed_tokens,
                        colorscale="YlGnBu",
                        zmin=zmin,
                        zmax=zmax,
                        zsmooth=False,
                        colorbar=dict(
                            title=dict(
                                text="Attention<br>Score<br>&nbsp;",
                                side="top",
                            )
                        )
                    )
                ],
                name=f"Block_{block_idx}",
                layout=go.Layout(
                    title=f"Decoder Block {block_idx} Attention {title_suffix}"
                ),
            )
        )

    # --------------------------------------------------
    # Set initial state (Block_0) to Figure
    # --------------------------------------------------
    fig = go.Figure(
        data=frames[0].data,
        layout=frames[0].layout,
        frames=frames
    )

    # --------------------------------------------------
    # 1) Slider for switching blocks (animation)
    # --------------------------------------------------
    block_steps = [
        dict(
            method="animate",
            args=[
                [frame.name],  # Frame name to play
                {
                    "mode": "immediate",
                    "frame": {"duration": 0, "redraw": True},
                    "transition": {"duration": 0},
                },
            ],
            label=str(i),
        )
        for i, frame in enumerate(frames)
    ]

    block_slider = dict(
        active=0,
        currentvalue={"prefix": "Decoder Block: "},
        steps=block_steps,
        x=0.05,
        y=-0.35,  # Position below the figure (adjust as needed)
        xanchor="left",
        yanchor="top",
        pad={"t": 50},
    )

    # --------------------------------------------------
    # 2) Slider for changing font size (layout update)
    # --------------------------------------------------
    font_sizes = [4, 5, 6, 7, 8, 9, 10, 15]
    # Single source of truth for the initial tick font size: the slider's active
    # position is derived from it, so the two can no longer drift apart.
    default_font_size = 10

    font_steps = [
        dict(
            method="relayout",
            args=[{"xaxis.tickfont.size": size, "yaxis.tickfont.size": size}],
            label=str(size),
        )
        for size in font_sizes
    ]

    font_slider = dict(
        active=font_sizes.index(default_font_size),  # Slider starts on the default size
        currentvalue={"prefix": "Font Size: "},
        steps=font_steps,
        x=0.05,
        y=-0.55,  # Position below block_slider
        xanchor="left",
        yanchor="top",
        pad={"t": 50},
    )

    # --------------------------------------------------
    # Play/Pause buttons (updatemenus) positioned at the top right
    # --------------------------------------------------
    updatemenus = [
        dict(
            type="buttons",
            showactive=False,
            x=1.10,
            y=1.15,
            xanchor="right",
            yanchor="top",
            direction="left",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[
                        None,  # None = play all frames
                        {
                            "frame": {"duration": 500, "redraw": True},
                            "fromcurrent": True,
                            "mode": "immediate",
                        },
                    ],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    args=[
                        [None],  # [None] = interrupt the running animation
                        {
                            "frame": {"duration": 0, "redraw": True},
                            "mode": "immediate",
                        },
                    ],
                ),
            ],
        )
    ]

    # --------------------------------------------------
    # Finalize layout (sliders, buttons, axes and figure size in one update)
    # --------------------------------------------------
    fig.update_layout(
        sliders=[block_slider, font_slider],
        updatemenus=updatemenus,
        xaxis=dict(title="Key Tokens", tickfont=dict(size=default_font_size)),
        yaxis=dict(
            title="Query Tokens",
            autorange="reversed",  # Reverse y-axis
            tickfont=dict(size=default_font_size),
        ),
        width=940,
        height=1200,
        margin=dict(t=120, b=120),
    )

    # Display
    fig.show()


In [None]:
# Choose which attention head to visualize:
#   - "mean" (or None) averages the attention over all heads
#   - an integer such as 0 selects a single head
head_idx = "mean"

# Render the attention maps with both the decoder-block slider and the font-size slider
visualize_attention_with_slider_and_fontsize(
    model,
    tokenizer,
    input_ids,
    head_idx=head_idx,
    zmax=0.1,
    print_prompt=True,
)


<img src="../images/attention_map_with_decoder_and_fontsize_slider.png" alt="Attention heatmap with sliders for selecting the decoder block and the axis-label font size" width="80%">