In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# from collections.abc import Sequence
from typing import Sequence
import torch
import numpy as np
import time

# Hugging Face checkpoint ids: _SSM_NAME is loaded below as `ssm`, _LLM_NAME as `llm`.
# NOTE(review): presumably "SSM" is the small speculative model and "LLM" the
# larger target model — confirm against the paper the notebook follows.
_SSM_NAME = "JackFram/llama-160m"
_LLM_NAME = 'openlm-research/open_llama_3b_v2'
# Device string for tensors the notebook creates itself.
device = "cuda"

# Fail fast: both models are moved to the GPU with .cuda() right below.
assert torch.cuda.is_available()
# Only the SSM checkpoint's tokenizer is loaded; main() feeds ids produced with it
# to `llm` as well.
# NOTE(review): assumes both checkpoints share a vocabulary — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(_SSM_NAME)
ssm = AutoModelForCausalLM.from_pretrained(_SSM_NAME).cuda()
llm = AutoModelForCausalLM.from_pretrained(_LLM_NAME).cuda()
# Number of timed forward passes that main() hands to time_normal / time_tree.
N_ITERATIONS = 1000

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def _create_token_tree(
    expansion_config: Sequence[int],
    prompt: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
):
    """Build a batch of branching token sequences, as in Figure 3 of the paper.

    Real tokens are not strictly required for the experiments (random
    integers would do), but a real model is used anyway.

    The tree follows the <k1, k2, ...> expansion scheme: at every step a
    single new token is generated for each sequence held so far, and each
    sequence is then forked into its ``k`` highest-scoring continuations.

    Args:
        expansion_config: Non-empty sequence of integers; entry ``i`` is the
            branching factor applied at generation step ``i``.
        prompt: Initial prompt text.
        tokenizer: HF tokenizer used to encode ``prompt``.
        model: HF generative model used to score next tokens.

    Returns:
        A 2-D tensor of token ids. Each step multiplies the number of rows
        by its branching factor and appends one column, so every row is the
        encoded prompt followed by one token per expansion step.
    """
    assert expansion_config
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    branches = encoded_prompt.input_ids.cuda()
    for fanout in expansion_config:
        step = model.generate(
            branches,
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Scores of the single generated step: one row per current branch.
        step_scores = step.scores[-1]
        best_ids = torch.topk(step_scores, k=fanout, dim=-1).indices
        # One candidate token per output row, ordered branch by branch.
        new_column = best_ids.reshape(-1, 1)
        # Duplicate every branch `fanout` times so the rows line up with new_column.
        forked = torch.repeat_interleave(branches, fanout, dim=0)
        branches = torch.cat((forked, new_column), dim=-1)

    return branches

In [7]:
def construct_tree_model_inputs(sequences, device=None):
    """Flatten a batch of branching sequences into single-pass tree inputs.

    Rows of ``sequences`` that share a prefix share the corresponding tree
    nodes. A node is identified by its parent node plus its token id, so the
    same token id appearing under a different parent (or at a different
    depth) is kept as a separate node instead of being merged.

    Args:
        sequences: 2-D integer tensor; each row is one root-to-leaf path.
        device: Device for the returned tensors. Defaults to the device of
            ``sequences``.

    Returns:
        A tuple ``(input_ids, attention_mask, position_ids)``:
          * ``input_ids``: int64 tensor of shape ``(1, n_nodes)`` holding the
            token id of every node, in first-visit order.
          * ``attention_mask``: int32 tensor of shape
            ``(1, 1, n_nodes, n_nodes)``; entry ``[i][j]`` is 1 when node
            ``j`` is node ``i`` itself or one of its ancestors.
          * ``position_ids``: int64 tensor of shape ``(1, n_nodes)`` holding
            each node's depth, i.e. its index within its row.

    Raises:
        ValueError: If ``sequences`` is not 2-D.
    """
    if sequences.dim() != 2:
        raise ValueError(
            f"sequences must be 2-D (paths x tokens), got {sequences.dim()} dimension(s)"
        )
    target_device = sequences.device if device is None else device

    node_of = {}    # (parent node index, token id) -> node index
    tokens = []     # token id of each node
    depths = []     # depth of each node within its path
    parents = []    # parent node index of each node, -1 for a root

    # One host transfer up front instead of touching the tensor element by element.
    for path in sequences.tolist():
        parent = -1
        for depth, token in enumerate(path):
            key = (parent, token)
            node = node_of.get(key)
            if node is None:
                node = len(tokens)
                node_of[key] = node
                tokens.append(token)
                depths.append(depth)
                parents.append(parent)
            parent = node

    n_nodes = len(tokens)
    mask = torch.zeros((n_nodes, n_nodes), dtype=torch.int32)
    # A parent is always created before its children, so its row is already
    # complete when a child copies it.
    for node, parent in enumerate(parents):
        if parent >= 0:
            mask[node] = mask[parent]
        mask[node, node] = 1

    input_ids = torch.tensor([tokens], dtype=torch.int64, device=target_device)
    attention_mask = mask.unsqueeze(0).unsqueeze(0).to(target_device)
    position_ids = torch.tensor([depths], dtype=torch.int64, device=target_device)
    return (input_ids, attention_mask, position_ids)

In [8]:
def time_normal(input, N_iterations, model: AutoModelForCausalLM):
    """Return the mean wall-clock seconds of one plain forward pass.

    CUDA work is queued asynchronously, so the device is synchronized right
    before each clock read. Without that the timer can stop while the GPU is
    still executing and the result mostly reflects kernel-launch overhead.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        N_iterations: Number of timed forward passes; must be positive.
        model: Model to call.

    Returns:
        Average seconds per forward pass, as a float.

    Raises:
        ValueError: If ``N_iterations`` is not positive.
    """
    if N_iterations <= 0:
        raise ValueError(f"N_iterations must be positive, got {N_iterations}")

    use_cuda = torch.cuda.is_available()
    total_time = 0.0
    with torch.no_grad():
        for _ in range(N_iterations):
            if use_cuda:
                # Drain previously queued work so it is not billed to this pass.
                torch.cuda.synchronize()
            start = time.perf_counter()
            model(input_ids=input)
            if use_cuda:
                # Wait for the forward pass to actually finish on the device.
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start
    return total_time / N_iterations

def time_tree(input, mask, position_ids, N_iterations, model: AutoModelForCausalLM):
    """Return the mean wall-clock seconds of one forward pass with tree inputs.

    CUDA work is queued asynchronously, so the device is synchronized right
    before each clock read. Without that the timer can stop while the GPU is
    still executing and the result mostly reflects kernel-launch overhead.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        mask: Passed to the model as ``attention_mask``.
        position_ids: Passed to the model as ``position_ids``.
        N_iterations: Number of timed forward passes; must be positive.
        model: Model to call.

    Returns:
        Average seconds per forward pass, as a float.

    Raises:
        ValueError: If ``N_iterations`` is not positive.
    """
    if N_iterations <= 0:
        raise ValueError(f"N_iterations must be positive, got {N_iterations}")

    use_cuda = torch.cuda.is_available()
    total_time = 0.0
    with torch.no_grad():
        for _ in range(N_iterations):
            if use_cuda:
                # Drain previously queued work so it is not billed to this pass.
                torch.cuda.synchronize()
            start = time.perf_counter()
            # Call the module itself (not .forward) so any registered hooks run.
            model(input_ids=input, attention_mask=mask, position_ids=position_ids)
            if use_cuda:
                # Wait for the forward pass to actually finish on the device.
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start
    return total_time / N_iterations

In [9]:
from torch.profiler import profile, ProfilerActivity, schedule

# Guide: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html

# Total profiler steps per run, split into wait / warm-up / active phases.
_N_ITERATIONS = 10
_WAIT_STEPS = 1
_WARMUP_STEPS = 1
# Whatever is left after the wait and warm-up steps is recorded as "active".
schedule_params = dict(
    wait=_WAIT_STEPS,
    warmup=_WARMUP_STEPS,
    active=_N_ITERATIONS - (_WAIT_STEPS + _WARMUP_STEPS),
)
# Keyword arguments shared by every `profile(...)` call in this notebook.
profiler_kwargs = dict(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    schedule=schedule(**schedule_params),
)

def print_normal_profile_stats(input, model):
    """Profile repeated plain forward passes and print a summary table.

    Calls ``model(input_ids=input)`` ``_N_ITERATIONS`` times under
    ``torch.inference_mode()`` and a profiler built from the module-level
    ``profiler_kwargs``, stepping the profiler after every call. It then
    prints up to ten rows of the averaged statistics, sorted by
    ``self_cuda_memory_usage``.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        model: Model to profile.
    """
    with torch.inference_mode():
        with profile(**profiler_kwargs) as profiler:
            remaining = _N_ITERATIONS
            while remaining > 0:
                model(input_ids=input)
                # Advance the profiler's wait / warm-up / active schedule.
                profiler.step()
                remaining -= 1
    averaged = profiler.key_averages()
    report = averaged.table(sort_by="self_cuda_memory_usage", row_limit=10)
    print(report)

def print_tree_profile_stats(input, mask, position_ids, model):
    """Profile repeated forward passes with tree inputs and print a summary table.

    Calls the model ``_N_ITERATIONS`` times with an explicit attention mask
    and position ids, under ``torch.inference_mode()`` and a profiler built
    from the module-level ``profiler_kwargs``, stepping the profiler after
    every call. It then prints up to ten rows of the averaged statistics,
    sorted by ``self_cuda_memory_usage``.

    Args:
        input: Token ids passed to the model as ``input_ids``.
        mask: Passed to the model as ``attention_mask``.
        position_ids: Passed to the model as ``position_ids``.
        model: Model to profile.
    """
    forward_kwargs = {
        "input_ids": input,
        "attention_mask": mask,
        "position_ids": position_ids,
    }
    with torch.inference_mode():
        with profile(**profiler_kwargs) as profiler:
            remaining = _N_ITERATIONS
            while remaining > 0:
                model(**forward_kwargs)
                # Advance the profiler's wait / warm-up / active schedule.
                profiler.step()
                remaining -= 1
    averaged = profiler.key_averages()
    report = averaged.table(sort_by="self_cuda_memory_usage", row_limit=10)
    print(report)

In [10]:
def main():
    """Build a small token tree, then time and profile both decoding layouts.

    The tree is generated with ``ssm`` from the prompt "The" using the
    (2, 1, 2) expansion. The batched sequences and the flattened tree inputs
    are each run through ``llm``: first timed over ``N_ITERATIONS`` forward
    passes, then profiled.
    """
    branches = _create_token_tree(
        prompt="The",
        expansion_config=(2, 1, 2),
        model=ssm,
        tokenizer=tokenizer,
    )
    print(branches)

    # Flattened ids, ancestor mask and position ids for tree decoding.
    tree_inputs = construct_tree_model_inputs(branches)
    print(*tree_inputs)

    sequential_seconds = time_normal(branches, N_ITERATIONS, llm)
    tree_seconds = time_tree(*tree_inputs, N_ITERATIONS, llm)
    print("Sequential Time: ", sequential_seconds)
    print("Tree Time: ", tree_seconds)

    print_normal_profile_stats(branches, llm)
    print_tree_profile_stats(*tree_inputs, llm)

if __name__ == "__main__":
    main()

tensor([[    1,   450, 29871, 29896, 29900],
        [    1,   450, 29871, 29896, 29929],
        [    1,   450,   937,  2655,   366],
        [    1,   450,   937,  2655,   306]], device='cuda:0')
tensor([[    1,   450, 29871, 29896, 29900, 29929,   937,  2655,   366,   306]],
       device='cuda:0') tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 1, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 0, 1]]]], device='cuda:0',
       dtype=torch.int32) tensor([[0, 1, 2, 3, 4, 4, 2, 3, 4, 4]], device='cuda:0')
Sequential Time:  0.029008917570114135
Tree Time:  0.029660537242889404


STAGE:2024-05-13 00:01:29 1994710:1994710 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-13 00:01:29 1994710:1994710 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-13 00:01:29 1994710:1994710 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm         8.77%      22.648ms        13.99%      36.143ms      24.688us     195.666ms        41.24%     200.700ms     137.090us           0 b           0 b     547.66 Mb     547.66 M

STAGE:2024-05-13 00:01:33 1994710:1994710 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-13 00:01:33 1994710:1994710 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-13 00:01:33 1994710:1994710 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm         8.17%      23.143ms        13.29%      37.639ms      25.710us     201.013ms        39.95%     201.013ms     137.304us           0 b           0 b     275.06 Mb     275.06 M