In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# from collections.abc import Sequence
from typing import Sequence
import torch
import numpy as np
import time

# --- Experiment configuration ------------------------------------------------
# Checkpoint of the small model (used in main() to draft the token tree) and of
# the larger model (the one whose forward pass main() times).
_SSM_NAME = "JackFram/llama-160m"
_LLM_NAME = 'openlm-research/open_llama_3b_v2'
device = "cuda"
# Number of timed forward passes averaged for each measurement.
N_ITERATIONS = 1000

# Everything below is placed on the GPU, so fail fast when none is present.
assert torch.cuda.is_available()

# The tokenizer is loaded from the small model's checkpoint; both models are
# moved to `device` right after loading.
tokenizer = AutoTokenizer.from_pretrained(_SSM_NAME)
ssm = AutoModelForCausalLM.from_pretrained(_SSM_NAME).to(device)
llm = AutoModelForCausalLM.from_pretrained(_LLM_NAME).to(device)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


In [2]:
def _create_token_tree(
    expansion_config: Sequence[int],
    prompt: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
):
    """Create token tree following Figure 3 in the paper.

    We don't need "real" tokens for our experiments - just
    random integers would work too - but might as well.

    Figure 3 illustrates the <k1, k2, ...> expansion approach they
    use to create token trees. We can use each of the top_k tokens from
    a single model to create the same tree structure.

    Args:
        expansion_config: A sequence of integers representing how much to
            branch at each generation step.
        prompt: Initial prompt.
        tokenizer: HF tokenizer.
        model: HF generative model.
    """
    assert expansion_config
    current_tree = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    for k in expansion_config:
        output = model.generate(
            current_tree,
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Take the top_k tokens from the 1 generation step we've done
        top_k = torch.topk(output.scores[-1], k=k, dim=-1).indices.reshape(-1, 1)
        current_tree = torch.repeat_interleave(current_tree, k, dim=0)
        # Join the top_k tokens to the current tree
        current_tree = torch.cat((current_tree, top_k), dim=-1)

    return current_tree

In [6]:
def construct_tree_model_inputs(sequences, device=None):
    """Merge a batch of token sequences into one de-duplicated token tree.

    Rows that share a prefix share the tree nodes of that prefix. A node is
    identified by its parent node together with its token id, so the same
    token id occurring in different branches (or at different depths) yields
    distinct nodes instead of being collapsed into one.

    Args:
        sequences: 2-D integer tensor with one candidate sequence per row.
        device: Device for the returned tensors. Defaults to
            ``sequences.device``.

    Returns:
        A tuple ``(input_ids, attention_mask, position_ids)`` where, for a
        tree of ``n`` nodes listed in order of first appearance:

        * ``input_ids`` is an int64 tensor of shape ``(1, n)`` holding the
          token id of every node;
        * ``attention_mask`` is an int32 tensor of shape ``(1, 1, n, n)``
          whose entry ``[0, 0, i, j]`` is 1 when node ``j`` is node ``i``
          itself or one of its ancestors, and 0 otherwise;
        * ``position_ids`` is an int64 tensor of shape ``(1, n)`` holding the
          depth of every node, i.e. its index within its sequence.

    Raises:
        ValueError: If ``sequences`` is not 2-D.
    """
    if sequences.dim() != 2:
        raise ValueError(
            f"sequences must be 2-D (one sequence per row), got {sequences.dim()}-D"
        )
    if device is None:
        device = sequences.device

    node_of = {}      # (parent node index, token id) -> node index
    tokens = []       # token id of each node
    positions = []    # depth of each node
    visible = []      # for each node: indices of its ancestors plus itself

    for row in sequences.tolist():
        parent = -1   # sentinel parent shared by all root nodes
        path = []
        for depth, tok in enumerate(row):
            key = (parent, tok)
            node = node_of.get(key)
            if node is None:
                node = len(tokens)
                node_of[key] = node
                tokens.append(tok)
                positions.append(depth)
                visible.append(path + [node])
            path.append(node)
            parent = node

    n_nodes = len(tokens)
    mask = torch.zeros((n_nodes, n_nodes), dtype=torch.int32)
    for node, attends_to in enumerate(visible):
        mask[node, attends_to] = 1
    # Add the two leading singleton dimensions in front of the (n, n) mask.
    mask = mask.unsqueeze(0).unsqueeze(0).to(device)

    input_ids = torch.tensor([tokens], dtype=torch.int64, device=device)
    position_ids = torch.tensor([positions], dtype=torch.int64, device=device)
    return (input_ids, mask, position_ids)

In [7]:
def time_normal(input, N_iterations, model: AutoModelForCausalLM):
    """Return the mean wall-clock seconds of one plain forward pass.

    Runs ``model(input_ids=input)`` ``N_iterations`` times under
    ``torch.no_grad()`` and averages the elapsed time of the individual calls.

    Args:
        input: Tensor of token ids passed to the model as ``input_ids``.
        N_iterations: Number of timed forward passes; must be positive.
        model: Callable model whose output exposes ``.logits``.

    Returns:
        Average duration of a single forward pass, in seconds.

    Raises:
        ValueError: If ``N_iterations`` is not positive.
    """
    if N_iterations <= 0:
        raise ValueError(f"N_iterations must be positive, got {N_iterations}")

    # CUDA kernels are launched asynchronously: without a synchronize before
    # each clock read the host timer can stop while the GPU is still busy.
    on_gpu = input.is_cuda
    total_time = 0.0
    with torch.no_grad():
        for _ in range(N_iterations):
            if on_gpu:
                torch.cuda.synchronize(input.device)
            # perf_counter is monotonic and higher resolution than time.time.
            start = time.perf_counter()
            _ = model(input_ids=input).logits
            if on_gpu:
                torch.cuda.synchronize(input.device)
            total_time += time.perf_counter() - start
    return total_time / N_iterations

def time_tree(input, mask, position_ids, N_iterations, model: AutoModelForCausalLM):
    """Return the mean wall-clock seconds of one tree-masked forward pass.

    Runs the model ``N_iterations`` times under ``torch.no_grad()`` with an
    explicit attention mask and explicit position ids, and averages the
    elapsed time of the individual calls.

    Args:
        input: Tensor of token ids passed to the model as ``input_ids``.
        mask: Tensor passed to the model as ``attention_mask``.
        position_ids: Tensor passed to the model as ``position_ids``.
        N_iterations: Number of timed forward passes; must be positive.
        model: Callable model whose output exposes ``.logits``.

    Returns:
        Average duration of a single forward pass, in seconds.

    Raises:
        ValueError: If ``N_iterations`` is not positive.
    """
    if N_iterations <= 0:
        raise ValueError(f"N_iterations must be positive, got {N_iterations}")

    # CUDA kernels are launched asynchronously: without a synchronize before
    # each clock read the host timer can stop while the GPU is still busy.
    on_gpu = input.is_cuda
    total_time = 0.0
    with torch.no_grad():
        for _ in range(N_iterations):
            if on_gpu:
                torch.cuda.synchronize(input.device)
            # perf_counter is monotonic and higher resolution than time.time.
            start = time.perf_counter()
            # Call the model itself rather than .forward() so that any
            # registered module hooks run as in normal use.
            _ = model(
                input_ids=input,
                attention_mask=mask,
                position_ids=position_ids,
            ).logits
            if on_gpu:
                torch.cuda.synchronize(input.device)
            total_time += time.perf_counter() - start
    return total_time / N_iterations

In [8]:
def main():
    """Build a small token tree and time two ways of scoring it with the LLM.

    The tree is drafted by the small model from the prompt "The" with a
    <2, 1, 2> expansion. It is then scored by the large model twice: once as
    a plain batch of sequences and once as a single flattened tree with an
    explicit attention mask and position ids. Both average timings are
    printed.
    """
    candidates = _create_token_tree(
        prompt="The",
        expansion_config=(2, 1, 2),
        model=ssm,
        tokenizer=tokenizer,
    )
    print(candidates)

    # (input ids, attention mask, position ids) for tree decoding
    tree_inputs = construct_tree_model_inputs(candidates)
    print(*tree_inputs)

    # Keep this order: the batched run is measured first, the tree run second.
    sequential_time = time_normal(candidates, N_ITERATIONS, llm)
    tree_time = time_tree(*tree_inputs, N_ITERATIONS, llm)
    print("Sequential Time: ", sequential_time)
    print("Tree Time: ", tree_time)


if __name__ == "__main__":
    main()

tensor([[    1,   450, 29871, 29896, 29900],
        [    1,   450, 29871, 29896, 29929],
        [    1,   450,   937,  2655,   366],
        [    1,   450,   937,  2655,   306]], device='cuda:0')
tensor([[    1,   450, 29871, 29896, 29900, 29929,   937,  2655,   366,   306]],
       device='cuda:0') tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 0, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 1, 0],
          [1, 1, 0, 0, 0, 0, 1, 1, 0, 1]]]], device='cuda:0',
       dtype=torch.int32) tensor([[0, 1, 2, 3, 4, 4, 2, 3, 4, 4]], device='cuda:0')
Sequential Time:  0.20251182007789612
Tree Time:  0.07197877502441406
