## Interactive Lookahead Text Generator

LLMs generate text one token at a time, where a token is a word or a sub-word. This makes it hard for users to see the different ways a text could continue at any given point, which limits their ability to interact with branching options, especially in creative writing. Our goal for this project was to create an interface that lets the user explore multiple 'lookahead' completions interactively.

The standard Hugging Face `.generate()` API does not support this kind of lookahead branching directly. It abstracts away the low-level, token-by-token generation loop, so we cannot intervene after each token to branch into several candidate next tokens and continue each of them. For this reason, we implemented the lookahead inference pipeline ourselves.

Lookahead generation is a technique for exploring several possible next-token continuations of a prompt. In our implementation, the model's top-k candidates for the next token each start a branch, and each branch is then extended by a few more tokens. The user can see and choose from several potential paths instead of a single prediction, which lets them steer the direction of the piece they are writing.
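To make the idea concrete, here is a minimal sketch of lookahead on a toy model. The hypothetical bigram table `TOY_SCORES` stands in for an LLM's next-token logits; the branching logic (take the top-k first tokens, then extend each one greedily for a few steps) mirrors what the functions later in this notebook do with a real model, but the vocabulary, scores, and helper names here are illustrative assumptions only.

```python
# Hypothetical bigram scores standing in for an LLM's next-token logits.
TOY_SCORES = {
    "careful": {"consideration": 3.0, "review": 2.5, "thought": 2.0, "planning": 1.0},
    "consideration": {"of": 2.0, ",": 1.5},
    "review": {"of": 2.2, ",": 1.0},
    "thought": {",": 1.8, "about": 1.2},
    "planning": {",": 1.1},
    "of": {"the": 2.0},
    ",": {"we": 1.5},
}


def toy_next_token_scores(token):
    """Return {candidate: score} for the token following `token` (toy model)."""
    return TOY_SCORES.get(token, {"<eos>": 0.0})


def lookahead(last_token, n_branch_tokens=3, n_steps=2):
    """Branch on the top-k next tokens, then extend each branch greedily for n_steps."""
    scores = toy_next_token_scores(last_token)
    # The top-k candidates become the roots of the branches
    branch_tokens = sorted(scores, key=scores.get, reverse=True)[:n_branch_tokens]
    branches = []
    for token in branch_tokens:
        seq = [token]
        for _ in range(n_steps):
            nxt = toy_next_token_scores(seq[-1])
            seq.append(max(nxt, key=nxt.get))  # greedy (argmax) continuation
        branches.append(seq)
    return branches


for branch in lookahead("careful"):
    print(" ".join(branch))
# consideration of the
# review of the
# thought , we
```

Raising `n_steps` is all it takes to look further ahead per branch, which is the same knob as the "more than 2 next tokens" goal described below.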

Technical Approach: Write a low-level custom replacement for the Hugging Face `.generate()` loop, giving us full manual control over internal steps such as key-value caching and branching.

Technical Goal: We started from Professor Arnold's existing lookahead generation code, in which each branch was limited to 2 next tokens. One of our main technical goals was to extend this backend logic to support more than 2 next tokens per branch and to evaluate and validate its accuracy.

Real-World Goal: To make text generation from an LLM more collaborative, interactive and exploratory for the user.



```python
%pip install torch transformers --quiet
```

You may need to restart the kernel to use the updated packages.

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
```


### Comparing Cached and Non-Cached Forward-Pass Functions

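Because we implement generation manually, we also have to manage the cache of past key-value pairs ourselves. The cache stores each layer's attention keys and values for the tokens already processed, so each new step only needs to run the model on the newest token. Both the accuracy and the efficiency of our custom pipeline depend on handling this cache correctly: a mistake such as a wrong position index can change the logits without raising any error.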

To verify our caching logic, we compare the model's output logits with and without caching. Specifically, we define two functions, one that uses the cache and one that re-runs the full sequence at every step. If both produce the same logits (within a small numerical tolerance) at every step for the same inputs, we can be confident that the caching is implemented correctly.
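The same check can be shown in miniature. The sketch below uses a hypothetical single-head attention layer with made-up 2-dimensional weights, in plain Python rather than the real model: for each new token it computes the attention output once by recomputing keys and values for the whole sequence and once by reusing a key-value cache, then compares the two with a tolerance, much like `torch.allclose` in the evaluation below. It illustrates why correct caching should not change the outputs; it is not Gemma's actual attention implementation.

```python
import math

# Hypothetical projection weights for one 2-dimensional attention head
W_Q = [[0.5, -0.2], [0.1, 0.9]]
W_K = [[0.3, 0.8], [-0.4, 0.2]]
W_V = [[1.0, 0.0], [0.5, -0.5]]


def matvec(w, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def last_output_no_cache(embeddings):
    """Recompute keys/values for the whole sequence, then attend from the last token."""
    keys = [matvec(W_K, x) for x in embeddings]
    values = [matvec(W_V, x) for x in embeddings]
    return attend(matvec(W_Q, embeddings[-1]), keys, values)


class KVCache:
    """Keeps keys/values of processed tokens so each step only handles the newest one."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        self.keys.append(matvec(W_K, x))
        self.values.append(matvec(W_V, x))
        return attend(matvec(W_Q, x), self.keys, self.values)


embeddings = [[1.0, 0.0], [0.2, 0.7], [-0.5, 0.4], [0.9, -0.3]]
cache = KVCache()
matches = []
for t in range(len(embeddings)):
    cached = cache.step(embeddings[t])
    uncached = last_output_no_cache(embeddings[: t + 1])
    matches.append(all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(cached, uncached)))
print(all(matches))  # True
```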

Below, you'll find both functions. They perform the same forward pass but differ in how they manage cache.

#### Cached Forward Pass

```python
def get_lookahead_sequences_with_cache(model, tokenizer, hypotheses, n_branch_tokens=5,
                                       n_steps=2, device='cuda'):
    """Branch on the top-k next tokens, then extend each branch greedily, reusing a KV cache.

    `tokenizer` is unused but kept so existing calls still work. Returns the branch
    sequences of shape (n_branch_tokens, n_steps + 1), the logits at the last prompt
    position, and a list of per-step logits of shape (n_branch_tokens, 1, vocab_size).
    """
    assert len(hypotheses.shape) == 2 and hypotheses.shape[0] == 1, "Expected input shape (1, seq_len)"
    n_tokens_so_far = hypotheses.shape[1]  # length of the prompt
    hypotheses = hypotheses.to(device)
    past_key_values = DynamicCache()  # holds each layer's keys/values

    # Run the full prompt once; this fills the cache
    with torch.no_grad():
        outputs = model(hypotheses, past_key_values=past_key_values, use_cache=True)

    # The top-k candidates at the last prompt position become the branch roots
    branched_output_logits = outputs.logits[0, -1]
    branch_tokens = branched_output_logits.topk(n_branch_tokens).indices
    assert branch_tokens.shape == (n_branch_tokens,)

    # Copy the prompt's cache once per branch (batch dim 1 -> n_branch_tokens).
    # Relies on DynamicCache exposing key_cache/value_cache lists, as in the
    # transformers version we used.
    for i in range(len(past_key_values.key_cache)):
        past_key_values.key_cache[i] = past_key_values.key_cache[i].repeat(n_branch_tokens, 1, 1, 1)
        past_key_values.value_cache[i] = past_key_values.value_cache[i].repeat(n_branch_tokens, 1, 1, 1)

    # Identity reorder (indices 0..k-1); leaves the cache contents unchanged
    past_key_values.reorder_cache(torch.arange(n_branch_tokens, device=device))

    sequences = branch_tokens.unsqueeze(1)  # (n_branch_tokens, 1)
    position_id = n_tokens_so_far  # the first branch token sits right after the prompt
    cached_logits = []

    for _ in range(n_steps):
        # Feed only the newest token of each branch; earlier tokens come from the cache
        current_input = sequences[:, -1:]  # (n_branch_tokens, 1)
        cache_position = torch.tensor([position_id], device=device)
        # Mask covers the cached tokens plus the new one; all ones since there is no padding
        attention_mask = torch.ones((n_branch_tokens, position_id + 1), dtype=torch.long, device=device)

        with torch.no_grad():
            model_outs = model(
                current_input,
                past_key_values=past_key_values,
                use_cache=True,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )

        next_token_logits = model_outs.logits[:, -1]  # (n_branch_tokens, vocab_size)
        next_tokens = next_token_logits.argmax(dim=-1)  # greedy continuation
        assert next_tokens.shape == (n_branch_tokens,)
        sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)

        cached_logits.append(model_outs.logits)  # (n_branch_tokens, 1, vocab_size)
        position_id += 1

    return sequences, branched_output_logits, cached_logits
```

#### Non-Cached Forward Pass

```python
def get_lookahead_sequences_without_cache(model, tokenizer, hypotheses, n_branch_tokens=5,
                                          n_steps=2, device='cuda'):
    """Same lookahead as the cached version, but re-run the full sequence at every step.

    `tokenizer` is unused but kept so existing calls still work. Unlike the cached
    version, the returned sequences include the prompt: shape
    (n_branch_tokens, seq_len + 1 + n_steps).
    """
    assert len(hypotheses.shape) == 2 and hypotheses.shape[0] == 1, "Expected input shape (1, seq_len)"
    hypotheses = hypotheses.to(device)

    # Logits for the first (branched) token, computed without a cache
    with torch.no_grad():
        outputs = model(hypotheses, use_cache=False)
    branched_output_logits = outputs.logits[0, -1]
    branch_tokens = branched_output_logits.topk(n_branch_tokens).indices
    assert branch_tokens.shape == (n_branch_tokens,)

    # Each branch starts as the full prompt followed by its branch token
    sequences = torch.cat([hypotheses.repeat(n_branch_tokens, 1), branch_tokens.unsqueeze(1)], dim=1)

    no_cache_logits = []
    for _ in range(n_steps):
        next_tokens = []
        # Process each branch independently, re-running its entire sequence
        for sequence in sequences:
            with torch.no_grad():
                model_outs = model(sequence.unsqueeze(0), use_cache=False)  # add batch dim
            next_token_logits = model_outs.logits[0, -1]  # (vocab_size,)
            no_cache_logits.append(next_token_logits)
            next_tokens.append(next_token_logits.argmax(dim=-1))  # greedy continuation

        sequences = torch.cat([sequences, torch.stack(next_tokens).unsqueeze(1)], dim=1)

    return sequences, branched_output_logits, no_cache_logits
```

#### Evaluation 

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "Alina3234/gemma-lookahead"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Put the model on the same device as the inputs
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()

input_text = "After careful"
input_ids = tokenizer(input_text, return_tensors='pt').input_ids

cached_seqs, branched_logits, loop_logits = get_lookahead_sequences_with_cache(
    model, tokenizer, input_ids, device=device)
uncached_seqs, branched_token_logit_2, all_logits = get_lookahead_sequences_without_cache(
    model, tokenizer, input_ids, device=device)
```

```python
# Flatten the cached logits so they line up with the non-cached ones:
# both lists are ordered step by step, and branch by branch within each step.
loop_logits_list = []
for group in loop_logits:
    # group has shape (n_branch_tokens, 1, vocab_size); squeeze the middle dimension
    squeezed = group.squeeze(1)  # (n_branch_tokens, vocab_size)
    loop_logits_list.extend(list(squeezed))  # one (vocab_size,) tensor per branch

are_equal = (
    len(loop_logits_list) == len(all_logits) and
    all(torch.allclose(a, b, atol=1e-4) for a, b in zip(loop_logits_list, all_logits))
)
print(are_equal)
```

True


We compared the output logits at every generation step of the cached and non-cached methods, and every pair matched within an absolute tolerance of 1e-4.
On this small example, the cached method was also about 4 times faster (wall time 1.49 s vs. 5.92 s):

- Cached: CPU times: user 2.78 s, sys: 25.2 ms, total: 2.81 s; wall time: 1.49 s
- Non-cached: CPU times: user 11.7 s, sys: 28.3 ms, total: 11.7 s; wall time: 5.92 s
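A rough count of the work each method does helps explain the gap. The hypothetical helper `tokens_processed` below counts how many token positions each variant feeds through the model, and the last line computes the speedup implied by the wall times above. The 3-token prompt length is an assumption (for example, a BOS token plus two words), and token counts are only a proxy: the non-cached version also makes `n_branch_tokens * n_steps` separate batch-1 calls instead of `n_steps` batched ones, and attention cost grows with sequence length, so wall time does not scale exactly with these numbers.

```python
def tokens_processed(prompt_len, n_branch_tokens, n_steps):
    """Count token positions run through the model by each lookahead variant."""
    # Both variants start with one forward pass over the prompt.
    # Cached: afterwards, one new token per branch per step.
    cached = prompt_len + n_branch_tokens * n_steps
    # Non-cached: step s re-runs each branch's full sequence
    # (prompt + branch token + s generated tokens).
    uncached = prompt_len + sum(n_branch_tokens * (prompt_len + 1 + s) for s in range(n_steps))
    return cached, uncached


# Hypothetical 3-token prompt with the notebook's defaults (5 branches, 2 extra tokens)
print(tokens_processed(prompt_len=3, n_branch_tokens=5, n_steps=2))  # (13, 48)

# Speedup implied by the reported wall times
print(round(5.92 / 1.49, 2))  # 3.97
```

The gap widens as prompts get longer, since the non-cached version re-processes the whole prompt for every branch at every step.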


## What we learned

- We gained a practical understanding of tokenization and learned how much the shapes of the outputs (batch, sequence, and vocabulary dimensions) matter for evaluating them correctly.
- We found that caching significantly reduces computation, and presumably energy use, because each step runs the model only on the newest token instead of reprocessing the whole sequence; on our small example it was about 4 times faster.
- We also learned that running the model on a GPU can improve speed further, although performance on a CPU was still reasonably good.

## Future Direction

- Branch at the second and third token positions as well, not only at the first.
- Test the limits of the cached method: how many additional tokens can it predict successfully?
- Run the whole process on a GPU to save time and, potentially, energy.
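Branching at later token positions as well turns each lookahead into a tree. A minimal sketch, assuming a hypothetical scorer `uniform_scores` that always offers the same five ranked candidates, shows how quickly the number of leaves grows (k to the power of the branching depth); this is one reason to keep either k or the number of branched positions small when extending the approach.

```python
def expand_tree(prefix, next_scores, k, depth):
    """Branch on the top-k candidates at each of the next `depth` positions."""
    if depth == 0:
        return [prefix]
    scores = next_scores(prefix)
    top_k = sorted(scores, key=scores.get, reverse=True)[:k]
    leaves = []
    for token in top_k:
        leaves.extend(expand_tree(prefix + [token], next_scores, k, depth - 1))
    return leaves


def uniform_scores(prefix):
    """Hypothetical scorer: the same five candidates, ranked t0 > t1 > ... > t4."""
    return {f"t{i}": float(-i) for i in range(5)}


for depth in (1, 2, 3):
    print(depth, len(expand_tree([], uniform_scores, k=5, depth=depth)))
# 1 5
# 2 25
# 3 125
```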

## Supporting Material


This project is based on Professor Ken Arnold's initial implementation of lookahead generation.
https://huggingface.co/spaces/CalvinU/writing-prototypes/blob/main/custom_llm_inference.py#L66