In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# from collections.abc import Sequence
from typing import Sequence
import torch
import numpy as np
import time

# Checkpoints used throughout the notebook: the smaller one (`ssm`) is used to
# draft token trees, the larger one (`llm`) is the model being benchmarked.
_SSM_NAME = "JackFram/llama-160m"
_LLM_NAME = 'openlm-research/open_llama_3b_v2'
# Single source of truth for where models and helper tensors are placed.
device = "cuda"

assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU."
# The tokenizer is loaded from the smaller checkpoint and reused for both models.
tokenizer = AutoTokenizer.from_pretrained(_SSM_NAME)
# Move the models with `.to(device)` so they follow the `device` constant above
# instead of a separately hard-coded `.cuda()`.
ssm = AutoModelForCausalLM.from_pretrained(_SSM_NAME).to(device)
llm = AutoModelForCausalLM.from_pretrained(_LLM_NAME).to(device)

In [None]:
def _create_token_tree(
    expansion_config: Sequence[int],
    prompt: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    has_kv_cache: bool = False,
):
    """Create token tree following Figure 3 in the paper.

    We don't need "real" tokens for our experiments - just
    random integers would work too - but might as well.

    Figure 3 illustrates the <k1, k2, ...> expansion approach they
    use to create token trees. We can use each of the top_k tokens from
    a single model to create the same tree structure.

    Args:
        expansion_config: A sequence of integers representing how much to
            branch at each generation step.
        prompt: Initial prompt.
        tokenizer: HF tokenizer.
        model: HF generative model.
        has_kv_cache: If there are preceding tokens this is generating after, in which
            case we exclude the begin-of-sequence token.
    """
    assert expansion_config
    current_tree = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if has_kv_cache:
        assert tokenizer.add_bos_token
        current_tree = current_tree[:, 1:]
    for k in expansion_config:
        output = model.generate(
            current_tree,
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Take the top_k tokens from the 1 generation step we've done
        top_k = torch.topk(output.scores[-1], k=k, dim=-1).indices.reshape(-1, 1)
        current_tree = torch.repeat_interleave(current_tree, k, dim=0)
        # Join the top_k tokens to the current tree
        current_tree = torch.cat((current_tree, top_k), dim=-1)

    return current_tree

In [None]:
def _invert_4d_attention_mask(attention_mask: torch.Tensor, kv_cache_num_tokens: int=0) -> torch.Tensor:
    """For 4D masks, new HF requires us to invert the mask so it doesn't modify it at all."""
    # The attention mask must have last 2 dims shape [current seq len, KV cache size + current seq len]
    # So we prepend a tensor of 1s to allow attending to the full KV cache
    assert attention_mask.dim() == 4
    if kv_cache_num_tokens > 0:
        attention_mask = torch.cat(
            (
                torch.ones(
                    attention_mask.shape[0],
                    attention_mask.shape[1],
                    attention_mask.shape[2],
                    kv_cache_num_tokens,
                ).to(device),
                attention_mask,
            ),
            dim=-1,
        )
    # Invert the mask: 0s to -inf and 1s to 0 (0 means attention allowed)
    min_dtype = torch.finfo(torch.float32).min
    attention_mask.masked_fill_(attention_mask == 0, min_dtype)
    attention_mask.masked_fill_(attention_mask == 1, 0.0)
    return attention_mask

def construct_tree_model_inputs(sequences):
    """Flatten a batch of token-tree branches into single-pass tree inputs.

    Each row of ``sequences`` is one root-to-leaf branch. Branches that share
    a prefix share the corresponding tree nodes, so every node is emitted
    once. A node is identified by (parent node, token id) rather than by its
    token id alone: the same token id appearing in different branches or at
    different depths therefore stays a separate node with its own position
    and attention row.

    Args:
        sequences: 2D integer tensor (or nested list) of shape
            [num_branches, branch_length].

    Returns:
        A tuple ``(input_ids, mask, position_ids)``:
          * input_ids: int64 tensor [1, n] with the token of each node, in
            order of first appearance (row by row, left to right).
          * mask: float32 tensor [1, 1, n, n]; ``mask[0, 0, i, j] == 1`` iff
            node j is node i itself or one of its ancestors.
          * position_ids: int64 tensor [1, n] with each node's depth.
        All outputs are placed on the device of ``sequences``.
    """
    sequences = torch.as_tensor(sequences)
    target_device = sequences.device
    # One host transfer up front; the loops below work on plain Python ints.
    branches = sequences.tolist()

    node_of = {}  # (parent node index, token id) -> node index
    tokens = []
    positions = []
    parents = []
    for branch in branches:
        parent = -1  # -1 marks "no parent", i.e. a root node
        for depth, tok in enumerate(branch):
            key = (parent, tok)
            node = node_of.get(key)
            if node is None:
                node = len(tokens)
                node_of[key] = node
                tokens.append(tok)
                positions.append(depth)
                parents.append(parent)
            parent = node

    n = len(tokens)
    mask = np.zeros((n, n), dtype=np.float32)
    # A parent always has a smaller index than its children, so its row is
    # complete by the time a child copies it.
    for node, parent in enumerate(parents):
        if parent >= 0:
            mask[node] = mask[parent]
        mask[node, node] = 1

    input_1 = torch.tensor([tokens], dtype=torch.int64, device=target_device)
    mask_1 = torch.tensor(mask, device=target_device).unsqueeze(0).unsqueeze(0)
    position_ids_1 = torch.tensor([positions], dtype=torch.int64, device=target_device)
    return (input_1, mask_1, position_ids_1)

In [None]:
def _create_dummy_kv_cache(
    kv_cache_num_tokens: int,
    batch_size: int,
    num_attention_heads: int,
    hidden_size: int,
    num_layers: int,
):
    k = torch.rand(
        batch_size,
        num_attention_heads,
        kv_cache_num_tokens,
        hidden_size // num_attention_heads,
    ).to(device)
    v = torch.rand(
        batch_size,
        num_attention_heads,
        kv_cache_num_tokens,
        hidden_size // num_attention_heads,
    ).to(device)
    return tuple((k, v) for _ in range(num_layers))

In [None]:
def time_normal(input_ids, model: AutoModelForCausalLM, kv_cache=None):
    """Run a single plain forward pass of ``model`` under inference mode.

    Args:
        input_ids: Token ids passed to the model as ``input_ids``.
        model: Callable model; its output is discarded.
        kv_cache: Optional past key/values. Caching is enabled in the forward
            call only when a cache is supplied.
    """
    forward_kwargs = {
        "input_ids": input_ids,
        "past_key_values": kv_cache,
        "use_cache": kv_cache is not None,
    }
    with torch.inference_mode():
        model(**forward_kwargs)

def time_tree(input_ids, mask, position_ids, model: AutoModelForCausalLM, kv_cache=None):
    """Run a single tree-attention forward pass of ``model`` under inference mode.

    Args:
        input_ids: Token ids passed to the model as ``input_ids``.
        mask: Attention mask passed as ``attention_mask``.
        position_ids: Position ids passed as ``position_ids``.
        model: Callable model; its output is discarded.
        kv_cache: Optional past key/values. Caching is enabled in the forward
            call only when a cache is supplied.
    """
    forward_kwargs = {
        "input_ids": input_ids,
        "attention_mask": mask,
        "position_ids": position_ids,
        "past_key_values": kv_cache,
        "use_cache": kv_cache is not None,
    }
    with torch.inference_mode():
        model(**forward_kwargs)

In [None]:
from torch.profiler import profile, ProfilerActivity, schedule

# Guide: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html

# Total number of profiled forward passes, and how many of the first ones are
# handed to the profiler schedule as "wait" and "warmup" steps.
_N_ITERATIONS = 10
_WAIT_STEPS = 1
_WARMUP_STEPS = 1

# Every step that is neither a wait nor a warmup step is an active one.
schedule_params = dict(
    wait=_WAIT_STEPS,
    warmup=_WARMUP_STEPS,
    active=_N_ITERATIONS - (_WAIT_STEPS + _WARMUP_STEPS),
)

# Keyword arguments shared by every `profile(...)` context in this notebook.
profiler_kwargs = dict(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    schedule=schedule(**schedule_params),
    record_shapes=True,
)

def print_normal_profile_stats(input, model):
    """Profile repeated plain forward passes and print the top-10 summary.

    The model is called ``_N_ITERATIONS`` times on ``input`` inside a profiler
    configured by ``profiler_kwargs``; the profiler is stepped after every
    call. The resulting table is sorted by self CUDA memory usage.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        model: Callable model to profile.
    """
    with torch.inference_mode():
        with profile(**profiler_kwargs) as prof:
            for _iteration in range(_N_ITERATIONS):
                model(input_ids=input)
                prof.step()
    summary_table = prof.key_averages().table(
        sort_by="self_cuda_memory_usage", row_limit=10
    )
    print(summary_table)

def print_tree_profile_stats(input, mask, position_ids, model):
    """Profile repeated tree-attention forward passes and print the top-10 summary.

    The model is called ``_N_ITERATIONS`` times inside a profiler configured
    by ``profiler_kwargs``; the profiler is stepped after every call. The
    resulting table is sorted by self CUDA memory usage.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        mask: Attention mask passed as ``attention_mask``.
        position_ids: Position ids passed as ``position_ids``.
        model: Callable model to profile.
    """
    with torch.inference_mode():
        with profile(**profiler_kwargs) as prof:
            for _iteration in range(_N_ITERATIONS):
                model(
                    input_ids=input,
                    attention_mask=mask,
                    position_ids=position_ids,
                )
                prof.step()
    summary_table = prof.key_averages().table(
        sort_by="self_cuda_memory_usage", row_limit=10
    )
    print(summary_table)

In [None]:
# Measure max memory allocated
import gc

def reset_memory():
    """Release unused memory and restart CUDA peak-memory tracking.

    Runs the Python garbage collector, empties the CUDA caching allocator,
    resets the peak memory statistics and finally waits for outstanding CUDA
    work, so that a following ``torch.cuda.max_memory_allocated()`` reflects
    only what happens after this call.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated and merely forwards to this
    # function; call it directly to avoid the deprecation warning.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

def end_memory_collection():
    """Return the peak CUDA memory allocated so far, in units of 1024**3 bytes.

    Waits for outstanding CUDA work before reading the allocator's peak
    counter.
    """
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / 1024**3

# Peak-memory comparison on a <2, 2, 2> token tree, without any KV cache.
token_tree = _create_token_tree(
    expansion_config=[2, 2, 2],
    prompt="The",
    tokenizer=tokenizer,
    model=ssm,
)

# Baseline: the branches are fed to the model as rows of an ordinary batch.
reset_memory()
time_normal(input_ids=token_tree, model=llm, kv_cache=None)
seq_max_mem_gb = end_memory_collection()

# Tree attention: the tree is flattened and paired with a custom 4D mask.
tree_input, tree_mask, tree_position_ids = construct_tree_model_inputs(token_tree)
tree_mask = _invert_4d_attention_mask(tree_mask)
reset_memory()
time_tree(tree_input, tree_mask, tree_position_ids, llm, kv_cache=None)
tree_max_mem_gb = end_memory_collection()

print(f"Normal: {seq_max_mem_gb:.2f} GB")
print(f"Tree: {tree_max_mem_gb:.2f} GB")

In [None]:
import torch.utils.benchmark as benchmark
import numpy as np

# Number of timed forward passes per measurement.
N_ITERATIONS = 32

tree_widths = range(1, 4)
sequential_times = []
tree_times = []

# Length of the dummy KV cache that precedes the benchmarked tokens.
kv_cache_num_tokens = 128

for tree_width in tree_widths:
    # Branch `tree_width` ways at the first step, then extend each branch
    # by one token per step.
    expansion_config = (tree_width, 1, 1, 1, 1)

    token_tree = _create_token_tree(
        expansion_config=expansion_config,
        prompt="The",
        tokenizer=tokenizer,
        model=ssm,
    )
    print(token_tree)

    # Sequential baseline: one batch row (and one KV cache row) per branch.
    # np.prod returns a NumPy scalar; convert it to a plain int for the shape.
    batch_size = int(np.prod(expansion_config))
    kv_cache_sequential = _create_dummy_kv_cache(
        kv_cache_num_tokens=kv_cache_num_tokens,
        batch_size=batch_size,
        num_attention_heads=llm.config.num_attention_heads,
        hidden_size=llm.config.hidden_size,
        num_layers=llm.config.num_hidden_layers
    )

    sequential_timer = benchmark.Timer(
        stmt="time_normal(input_ids, model, kv_cache)",
        setup="from __main__ import time_normal",
        num_threads=torch.get_num_threads(),
        globals={
            'input_ids': token_tree,
            'model': llm,
            'kv_cache': kv_cache_sequential
        },
        label="Sequential"
    )
    sequential_measurement = sequential_timer.timeit(N_ITERATIONS)
    # Mean time per forward pass, in seconds.
    sequential_times.append(sequential_measurement.mean)

    # construct inputs for tree decoding: a single row shared by all branches
    kv_cache_tree = _create_dummy_kv_cache(
        kv_cache_num_tokens=kv_cache_num_tokens,
        batch_size=1,
        num_attention_heads=llm.config.num_attention_heads,
        hidden_size=llm.config.hidden_size,
        num_layers=llm.config.num_hidden_layers
    )
    tree_input, tree_mask, tree_position_ids = construct_tree_model_inputs(token_tree)
    # The tree tokens come after the cached tokens, so their positions must
    # start at the cache length rather than at 0 (the sequential baseline,
    # which passes no position_ids, is expected to be offset by the cache
    # length inside the model).
    tree_position_ids = tree_position_ids + kv_cache_num_tokens
    print(tree_input, tree_mask, tree_position_ids)
    # Required for 4D mask support in new HF
    tree_mask = _invert_4d_attention_mask(tree_mask, kv_cache_num_tokens)

    tree_timer = benchmark.Timer(
        stmt="time_tree(input_ids, mask, position_ids, model, kv_cache)",
        setup="from __main__ import time_tree",
        num_threads=torch.get_num_threads(),
        globals={
            'input_ids': tree_input,
            'mask': tree_mask,
            'position_ids': tree_position_ids,
            'model': llm,
            'kv_cache': kv_cache_tree
        },
        label="Tree"
    )
    tree_measurement = tree_timer.timeit(N_ITERATIONS)
    # Mean time per forward pass, in seconds.
    tree_times.append(tree_measurement.mean)

    # print_normal_profile_stats(token_tree, llm)
    # print_tree_profile_stats(tree_input, tree_mask, tree_position_ids, llm)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart: for every tree width, sequential latency next to tree latency.
width = 0.35
plt.figure(figsize=(10, 6))

x_data = tree_widths
y_sequential = np.array(sequential_times) * 1000 # scale to ms
plt.bar(x_data, y_sequential, label="Sequential", width=width)  # Plot the first list as the y-axis values
y_tree = np.array(tree_times) * 1000 # scale to ms
plt.bar([pos + width for pos in x_data], y_tree, label="Tree", width=width)  # Plot the second list as the y-axis values

# Put one tick per tree width, centred between the two bars of its group.
plt.xticks([pos + width / 2 for pos in x_data], list(x_data))
plt.xlabel("Tree Width")
plt.ylabel("Latency (ms)")
plt.legend()

# Save BEFORE showing: once the figure has been shown it may be closed, in
# which case a later savefig() writes an empty image.
plt.savefig("tree_vs_sequential.png")
plt.show()

In [10]:
import metrics
from transformers import AutoConfig

# Model dimensions are read from the published Mistral-7B configuration.
mistral_7b_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# NOTE(review): `expansion_config` is not defined in this cell; it is whatever
# value an earlier cell left behind - confirm that this is intended.
# The running product gives the node count per tree level; their sum is the
# number of tokens in the batch.
tree_token_count = np.sum(np.cumprod(expansion_config))

metrics.identify_compute_memory_bound(
    gpu=metrics.T4,
    token_batch=torch.ones(1, tree_token_count),
    dtype=torch.float32,
    num_layers=mistral_7b_config.num_hidden_layers,
    d_model=mistral_7b_config.hidden_size,
    n_head=mistral_7b_config.num_attention_heads,
    vocab_size=mistral_7b_config.vocab_size,
    kv_cache_token_count=0
)

Memory-bound: arithmetic intensity 7.448872654272583 < 27.0


