# Interactive Lookahead Text Generator

LLMs generate text one token at a time, so a user normally sees only a single continuation of the prompt. This makes it hard to compare several possible continuations and limits how much the user can steer the text through branching choices, which matters especially in creative writing. Our goal for this project was to build an interface that lets the user explore multiple 'lookahead' completions interactively.

The standard Hugging Face `.generate()` API does not directly support the kind of lookahead branching we wanted. It runs the whole decoding loop internally, so we cannot easily stop after the first token, branch on several candidate next tokens, and continue each branch separately. For this reason, we implemented the inference loop ourselves.

Lookahead generation is a technique for exploring several possible continuations of a prompt: instead of committing to the single most likely next token, we take the top few candidates and extend each one a few tokens further. The user can then see and choose from several potential paths instead of a single prediction, which lets them steer the direction of the text.
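The sketch below illustrates the idea on a toy scale. It assumes a hypothetical bigram table (`TOY_SCORES`) in place of the model's next-token logits: the first step keeps the `n_branch_tokens` highest-scoring candidates, and each branch is then extended greedily, mirroring the two extra steps used later in the notebook. It is an illustration of the technique, not the notebook's model code.

```python
from typing import Dict, List

# Hypothetical bigram "model": maps the last token to a score for each possible
# next token, playing the role of one row of next-token logits.
TOY_SCORES: Dict[str, Dict[str, float]] = {
    "the": {"cat": 2.0, "dog": 1.5, "sun": 1.0, "end": 0.1},
    "cat": {"sat": 2.0, "ran": 1.0, "end": 0.5},
    "dog": {"ran": 2.0, "sat": 0.5, "end": 0.4},
    "sun": {"rose": 2.0, "end": 0.3},
    "sat": {"down": 1.0, "end": 0.9},
    "ran": {"away": 1.2, "end": 0.2},
    "rose": {"early": 1.1, "end": 0.2},
}


def next_scores(tokens: List[str]) -> Dict[str, float]:
    """Scores for the next token given the sequence so far (toy bigram lookup)."""
    return TOY_SCORES.get(tokens[-1], {"end": 1.0})


def top_k(scores: Dict[str, float], k: int) -> List[str]:
    """Return the k highest-scoring tokens, best first."""
    return sorted(scores, key=scores.get, reverse=True)[:k]


def lookahead(prompt: List[str], n_branch_tokens: int = 3, extra_steps: int = 2) -> List[List[str]]:
    """Branch on the top-k first tokens, then extend each branch greedily."""
    branches = [[tok] for tok in top_k(next_scores(prompt), n_branch_tokens)]
    for _ in range(extra_steps):
        for branch in branches:
            scores = next_scores(prompt + branch)
            branch.append(max(scores, key=scores.get))  # greedy (argmax) step
    return branches


for branch in lookahead(["the"]):
    print(" ".join(branch))  # cat sat down / dog ran away / sun rose early
```

Only the first step branches; later steps are greedy, which keeps the number of sequences fixed at `n_branch_tokens`.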

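Real world goal: help writers, especially in creative writing, explore several candidate continuations at each point and steer the text by picking one.

Technical goal: build a custom inference loop that branches on the top-k next tokens, extends every branch using a shared key/value cache, and verify that its logits match an uncached baseline.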

Approach: write a low-level replacement for the `generate` API so that we have full control over the internal process, including the key/value cache, branching, and computing logits for evaluation.



In [1]:
%pip install streamlit



In [2]:
import streamlit as st
st.set_page_config(page_title="Interactive Text Generator", layout="centered")



In [4]:
%pip install torch transformers




In [5]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache



In [6]:
# Workaround: Streamlit's file watcher can crash when it inspects torch.classes.__path__,
# so we manually set it to an empty list.
# Reference: https://discuss.streamlit.io/t/message-error-about-torch/90886/6
torch.classes.__path__ = []

# --- 1. Setup ---
@st.cache_resource
def load_model():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = "Alina3234/gemma-lookahead"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)

    return model, tokenizer, device

model, tokenizer, device = load_model()



## Lookahead Token Logic

In [None]:
def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens=5, device='cuda'):
    """Branch on the top-k next tokens, then extend each branch greedily by 2 tokens.

    Returns the token sequences (shape (n_branch_tokens, 3)), the next-token logits
    for the prompt, and the per-step logits (each (n_branch_tokens, 1, vocab_size)).
    """
    assert len(hypotheses.shape) == 2 and hypotheses.shape[0] == 1, "Expected input shape (1, seq_len)"
    n_tokens_so_far = hypotheses.shape[1]
    hypotheses = hypotheses.to(device)
    past_key_values = DynamicCache()  # filled in by the prompt forward pass

    # Run the whole prompt once (batch size 1) to populate the key/value cache.
    with torch.no_grad():
        outputs = model(hypotheses, output_hidden_states=True, past_key_values=past_key_values)

    # The top-k candidates for the next token become the starts of the branches.
    branch_tokens = outputs.logits[0, -1].topk(n_branch_tokens).indices.to(device)
    assert branch_tokens.shape == (n_branch_tokens,)

    # Copy the batch-1 cache once per branch so every branch reuses the prompt's keys/values.
    for i in range(len(past_key_values.key_cache)):
        past_key_values.key_cache[i] = past_key_values.key_cache[i].repeat(n_branch_tokens, 1, 1, 1).to(device)
        past_key_values.value_cache[i] = past_key_values.value_cache[i].repeat(n_branch_tokens, 1, 1, 1).to(device)

    past_key_values.reorder_cache(torch.arange(n_branch_tokens, device=device))

    sequences = branch_tokens.unsqueeze(1)  # (n_branch_tokens, 1)
    position_id = n_tokens_so_far  # position of the next token fed to the model
    loop_output_logits = []

    for step in range(2):  # extend every branch by 2 more tokens
        cache_position_tensor = torch.tensor([position_id], device=device)
        attention_mask = torch.ones((n_branch_tokens, 1), dtype=torch.long, device=device)

        with torch.no_grad():
            current_input = sequences[:, -1:]  # only the newest token of each branch
            model_outs = model(
                current_input,
                past_key_values=past_key_values,
                output_hidden_states=True,
                use_cache=True,
                cache_position=cache_position_tensor,
                attention_mask=attention_mask
            )

        next_token_logits = model_outs.logits[:, -1]
        next_tokens = next_token_logits.argmax(dim=-1)  # greedy choice per branch
        sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)
        loop_output_logits.append(model_outs.logits)
        position_id += 1

    return sequences, outputs.logits[0, -1], loop_output_logits
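To make the bookkeeping in `get_lookahead_sequences` easier to follow, the sketch below traces the expected shapes without running a model. It assumes each layer's cache tensor has the usual `(batch, kv_heads, seq_len, head_dim)` layout; `trace_lookahead` is a hypothetical helper, and the prompt length, head count, and head size in the example are made up. The last generated token of each branch is never fed back, so after two extension steps each cache holds `prompt_len + 2` positions.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, int, int, int]


def trace_lookahead(prompt_len: int, n_branch_tokens: int, extra_steps: int,
                    kv_heads: int, head_dim: int) -> Tuple[Dict[str, Shape], List[int], Tuple[int, int]]:
    """Per-layer key (or value) cache shape at each stage, the cache_position fed
    at each extension step, and the final shape of `sequences`."""
    shapes: Dict[str, Shape] = {
        "prompt": (1, kv_heads, prompt_len, head_dim),  # one pass over the prompt
        "repeated": (n_branch_tokens, kv_heads, prompt_len, head_dim),  # copied per branch
    }
    cache_positions: List[int] = []
    for step in range(extra_steps):
        cache_positions.append(prompt_len + step)  # position of the token fed at this step
        # Each step appends one key/value entry per branch.
        shapes[f"step {step + 1}"] = (n_branch_tokens, kv_heads, prompt_len + step + 1, head_dim)
    sequences_shape = (n_branch_tokens, 1 + extra_steps)  # branch token + greedy tokens
    return shapes, cache_positions, sequences_shape


# Hypothetical sizes: a 7-token prompt, 1 key/value head of size 256.
shapes, cache_positions, sequences_shape = trace_lookahead(7, 5, 2, kv_heads=1, head_dim=256)
for stage, shape in shapes.items():
    print(stage, shape)
print("cache_position per step:", cache_positions)  # [7, 8]
print("sequences:", sequences_shape)                # (5, 3)
```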

In [None]:
def generate_lookahead_text(model, tokenizer, sequence, n_branch_tokens=5, device='cuda'):
    sequences, _, _ = get_lookahead_sequences(model, tokenizer, sequence, n_branch_tokens, device)
    return tokenizer.batch_decode(sequences, skip_special_tokens=True)

## Generating a lookahead sequence using the whole prompt

In [7]:
def generate_initial_lookahead(prompt, n_branch_tokens=5):
    """Generate initial lookahead with full prompt tokenization"""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    return generate_lookahead_text(model, tokenizer, input_ids, n_branch_tokens, device), input_ids

## Generating a lookahead sequence based on the new token
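This function is intended to reuse work from the previous step, but as written it re-tokenizes and processes the full prompt (which already ends with the newly added text) on every call.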

In [8]:
def generate_incremental_lookahead(new_token, n_branch_tokens=5):
    # `new_token` is currently unused: the full prompt already ends with it,
    # and the whole prompt is re-tokenized and processed from scratch.
    full_prompt = st.session_state.prompt
    full_input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)

    return generate_lookahead_text(model, tokenizer, full_input_ids, n_branch_tokens, device), full_input_ids
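A truly incremental version would keep the cache from the previous step and run only the tokens that were added. Below is a minimal sketch of that bookkeeping, assuming token ids are compared as plain Python lists; `PrefixTracker` is a hypothetical helper, not part of the notebook. It compares ids instead of assuming the new text was simply appended, because re-tokenizing the longer prompt can change the tokens at the boundary.

```python
from typing import List, Tuple


class PrefixTracker:
    """Hypothetical helper (not part of the notebook): remembers which token ids
    are already in the key/value cache, so a longer prompt only needs its new
    suffix run through the model."""

    def __init__(self) -> None:
        self.cached_ids: List[int] = []

    def plan(self, input_ids: List[int]) -> Tuple[int, List[int]]:
        """Return (number of cached positions to keep, token ids to feed the model)."""
        if not input_ids:
            raise ValueError("input_ids must not be empty")
        n_common = 0
        for old, new in zip(self.cached_ids, input_ids):
            if old != new:
                break
            n_common += 1
        # Always re-run at least the last token: its logits choose the next branches.
        n_common = min(n_common, len(input_ids) - 1)
        return n_common, input_ids[n_common:]

    def commit(self, input_ids: List[int]) -> None:
        """Record that the cache now covers exactly these token ids."""
        self.cached_ids = list(input_ids)


# Made-up token ids, for illustration only.
tracker = PrefixTracker()
tracker.commit([2, 10, 11, 12])
print(tracker.plan([2, 10, 11, 12, 13, 14]))  # (4, [13, 14])
print(tracker.plan([2, 10, 11, 12]))          # (3, [12])
print(tracker.plan([2, 10, 99]))              # (2, [99])
```

Applying a plan to the real cache is not shown here: the prompt cache would have to be kept separately from the branched copy (branching repeats and extends it), trimmed to the kept length (recent transformers versions provide `DynamicCache.crop(max_length)` for this), and then fed the remaining ids with matching `cache_position` values.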

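## Streamlit UI

The UI stores its state in `st.session_state`, so it must be launched with `streamlit run`; session state does not function when the script runs in a plain notebook kernel.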

In [9]:
st.title("✍️ Interactive Lookahead Text Generator")
st.markdown("This app shows potential continuations as you write. Select a suggestion to continue your text.")

# Initialize session state variables
if "prompt" not in st.session_state:
    st.session_state.prompt = ""

if "suggestions" not in st.session_state:
    st.session_state.suggestions = []

if "last_token_added" not in st.session_state:
    st.session_state.last_token_added = ""

if "input_ids" not in st.session_state:
    st.session_state.input_ids = None

if "regenerate" not in st.session_state:
    st.session_state.regenerate = False

# Function to handle suggestion selection
def select_suggestion(suggestion):
    # Add the suggestion to the prompt
    st.session_state.prompt += " " + suggestion

    # Store the newly added token for incremental generation
    st.session_state.last_token_added = suggestion

    # Flag that we need to regenerate
    st.session_state.regenerate = True

    # Clear the current suggestions as they're no longer relevant
    st.session_state.suggestions = []

# Prompt input area
prompt_input = st.text_area(
    "Your text:",
    value=st.session_state.prompt,
    height=150,
    key="prompt_area"
)

# Check if the user has manually edited the prompt
if prompt_input != st.session_state.prompt:
    # Update the prompt and clear any cached state
    st.session_state.prompt = prompt_input
    st.session_state.input_ids = None
    st.session_state.suggestions = []
    st.session_state.last_token_added = ""

# Generate button
if st.button("Generate Completions", type="primary"):
    if st.session_state.prompt.strip():
        with st.spinner("Generating suggestions..."):
            try:
                # Generate initial suggestions based on the full prompt
                st.session_state.suggestions, st.session_state.input_ids = generate_initial_lookahead(
                    st.session_state.prompt
                )
                st.session_state.regenerate = False
            except Exception as e:
                st.error(f"Error during generation: {str(e)}")
    else:
        st.warning("Please enter some text to begin.")

# Display suggestions
if st.session_state.suggestions:
    st.markdown("### ✨ Top Branching Completions:")

    # Create columns for better layout (5 suggestions per row)
    cols = st.columns(5)

    for i, suggestion in enumerate(st.session_state.suggestions):
        col_idx = i % 5
        with cols[col_idx]:
            suggestion_text = suggestion.strip()
            if st.button(f"{suggestion_text}", key=f"sugg_{i}"):
                select_suggestion(suggestion_text)
                st.rerun()

# Auto-regenerate after selecting a suggestion
if st.session_state.regenerate and st.session_state.prompt.strip():
    st.session_state.regenerate = False

    with st.spinner("Generating new suggestions..."):
        try:
            # Generate suggestions based on the incremental token
            st.session_state.suggestions, st.session_state.input_ids = generate_incremental_lookahead(
                st.session_state.last_token_added
            )
            st.rerun()
        except Exception as e:
            st.error(f"Error during generation: {str(e)}")

# Show text stats
if st.session_state.prompt:
    # Display character and word count
    char_count = len(st.session_state.prompt)
    word_count = len(st.session_state.prompt.split())
    st.caption(f"Characters: {char_count} | Words: {word_count}")

# Add a clear button
if st.button("Clear All"):
    st.session_state.prompt = ""
    st.session_state.suggestions = []
    st.session_state.input_ids = None
    st.session_state.last_token_added = ""
    st.rerun()
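The suggestion buttons rely on a two-run pattern: clicking a button appends the text and sets `regenerate`, `st.rerun()` starts a fresh script run, and the flag then triggers exactly one regeneration. The sketch below simulates that flow with a plain dict standing in for `st.session_state` and a hypothetical `suggest` function standing in for `generate_incremental_lookahead`; it is an illustration, not Streamlit code.

```python
from typing import Any, Callable, Dict, List


def click_suggestion(state: Dict[str, Any], suggestion: str) -> None:
    """First script run: mirrors select_suggestion() in the app."""
    state["prompt"] += " " + suggestion
    state["last_token_added"] = suggestion
    state["regenerate"] = True
    state["suggestions"] = []  # old suggestions no longer apply


def maybe_regenerate(state: Dict[str, Any], suggest: Callable[[str], List[str]]) -> None:
    """Next script run (after st.rerun()): the flag triggers one regeneration."""
    if state["regenerate"] and state["prompt"].strip():
        state["regenerate"] = False
        state["suggestions"] = suggest(state["prompt"])


state = {"prompt": "Once upon", "suggestions": ["a time there"],
         "regenerate": False, "last_token_added": ""}
click_suggestion(state, "a time there")
maybe_regenerate(state, lambda prompt: [prompt.split()[-1].upper()])
print(state["prompt"])       # Once upon a time there
print(state["suggestions"])  # ['THERE']
maybe_regenerate(state, lambda prompt: ["not used"])
print(state["suggestions"])  # ['THERE'] (flag was already cleared)
```

Clearing the flag before generating means a later rerun (for example, one triggered by another widget) does not regenerate again for the same suggestion.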



## Evaluation

### Cached Prediction

In [None]:
def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens=5, device='cuda'):

  assert len(hypotheses.shape) == 2 and hypotheses.shape[0] == 1, "Expected input shape (1, seq_len)"
  # stores how long the prompt is
  n_tokens_so_far = hypotheses.shape[1]
  hypotheses = hypotheses.to(device)
  past_key_values = DynamicCache() # hold key/value

  with torch.no_grad():
      outputs = model(hypotheses, output_hidden_states=True, past_key_values=past_key_values)

  # Get top-k tokens from last position
  branch_tokens = outputs.logits[0, -1].topk(n_branch_tokens).indices.to(device)
  branched_output_logits = outputs.logits[0, -1]
  print(tokenizer.decode(branch_tokens))
  print("Branch tokens shape:", branch_tokens.shape)  # Expected: (5,)
  assert branch_tokens.shape == (n_branch_tokens,)

  # Repeat past_key_values for each branch
  for i in range(len(past_key_values.key_cache)):
      past_key_values.key_cache[i] = past_key_values.key_cache[i].repeat(n_branch_tokens, 1, 1, 1).to(device)
      past_key_values.value_cache[i] = past_key_values.value_cache[i].repeat(n_branch_tokens, 1, 1, 1).to(device)

  # Identity reorder (index_select with arange): the repeated cache is left unchanged
  past_key_values.reorder_cache(torch.arange(n_branch_tokens, device=device))

  # Start sequences from the branch tokens
  sequences = branch_tokens.unsqueeze(1)
  print("Initial sequences shape:", sequences.shape)  # Expected: (5, 1)
  assert sequences.shape == (n_branch_tokens, 1)

  position_id = n_tokens_so_far
  cached_logits = []

  for step in range(2):  # Generate 2 more tokens
      print(f"\n--- Step {step + 1} ---")
      print("Current sequences shape before generation:", sequences.shape)

      cache_position_tensor = torch.tensor([position_id], device=device)  # Convert to tensor
      # All-ones mask: no position is treated as padding, so each branch attends to its full history
      attention_mask = torch.ones((n_branch_tokens,1), dtype=torch.long, device=device)
      print("Before generation:")
      print("past_key_values key shape:", past_key_values.key_cache[0].shape)  # Should start as (5, ..., ..., ...)
      print("attention_mask shape:", attention_mask.shape)                     # Expected: (5, 1)


      try:
          with torch.no_grad():
              current_input = sequences[:, -1:]
              print("Input to model (last token):", current_input.shape)  # Expected: (5, 1)
              assert current_input.shape == (n_branch_tokens, 1)

              model_outs = model(
                  current_input,
                  past_key_values=past_key_values,
                  output_hidden_states=True,
                  use_cache=True,
                  cache_position=cache_position_tensor, #cache_position
                  attention_mask=attention_mask
              )
              print("model_outs type:", type(model_outs))
              print("model_outs logits shape:", model_outs.logits.shape)
              loop_model_logits = model_outs.logits
              print("model_outs past_key_values shapes:")
              if getattr(model_outs, "past_key_values", None) is not None:
                  # A DynamicCache (or legacy tuple) returns (key, value) for a layer index
                  print("First layer k/v shapes:",
                        model_outs.past_key_values[0][0].shape,
                        model_outs.past_key_values[0][1].shape)
      except Exception as e:
          print("Error during model forward pass:", e)
          raise

      next_token_logits = model_outs.logits[:, -1]
      print(next_token_logits)
      print("Next token logits shape:", next_token_logits.shape)  # Expected: (5, vocab_size)
      assert next_token_logits.shape[0] == n_branch_tokens

      next_tokens = next_token_logits.argmax(dim=-1)
      print("Next tokens shape:", next_tokens.shape)  # Expected: (5,)
      assert next_tokens.shape == (n_branch_tokens,)

      sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)
      print("Updated sequences shape:", sequences.shape)  # Should grow (5, 2), then (5, 3)

      cached_logits.append(loop_model_logits)
      position_id += 1

  print(sequences)
  return sequences, branched_output_logits, cached_logits  # Final shape: (5, 3)

### Step-by-Step Implementation (No Cache)

In [None]:
def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens=5, device='cuda'):
  assert len(hypotheses.shape) == 2 and hypotheses.shape[0] == 1, "Expected input shape (1, seq_len)"

  # Get the initial sequence from the input
  original_sequence = hypotheses[0].tolist()
  hypotheses = hypotheses.to(device)

  # Get the logits for the next token without using cache
  with torch.no_grad():
      outputs = model(hypotheses, output_hidden_states=True)

  # Get top-k tokens from last position
  branch_tokens = outputs.logits[0, -1].topk(n_branch_tokens).indices.to(device)
  branched_token_logit_2 = outputs.logits[0,-1]
  print("Top-k branch tokens:", tokenizer.decode(branch_tokens))
  print("Branch tokens shape:", branch_tokens.shape)  # Expected: (5,)
  assert branch_tokens.shape == (n_branch_tokens,)

  # Create initial sequences for each branch
  all_sequences = []
  for branch_token in branch_tokens:
      # Each sequence starts with the original prompt + the branch token
      sequence = original_sequence + [branch_token.item()]
      all_sequences.append(sequence)

  # Convert to tensor for easier manipulation
  sequences = torch.tensor(all_sequences, device=device)
  print("Initial sequences shape:", sequences.shape)  # Expected: (5, seq_len+1)

  no_cache_logits = []
  # Generate additional tokens step by step
  for step in range(2):  # Generate 2 more tokens
      print(f"\n--- Step {step + 1} ---")
      print("Current sequences shape before generation:", sequences.shape)

      next_tokens = []

      # Process each sequence independently
      for seq_idx, sequence in enumerate(sequences):
          # Create input for model (full sequence up to now)
          current_input = sequence.unsqueeze(0)  # Add batch dimension
          print(f"Sequence {seq_idx} input shape:", current_input.shape)

          try:
              with torch.no_grad():
                  # Forward pass without cache or position_ids
                  model_outs = model(
                      current_input,
                      output_hidden_states=True,
                      use_cache=False
                  )

                  # Get prediction for next token
                  next_token_logits = model_outs.logits[0, -1]
                  no_cache_logits.append(next_token_logits)
                  print(next_token_logits)
                  next_token = next_token_logits.argmax(dim=-1)
                  next_tokens.append(next_token)

                  print(f"Sequence {seq_idx} next token:", tokenizer.decode(next_token))

          except Exception as e:
              print(f"Error processing sequence {seq_idx}:", e)
              raise

      # Stack the next tokens
      next_tokens = torch.stack(next_tokens)
      print("Next tokens shape:", next_tokens.shape)  # Expected: (5,)

      # Add new tokens to sequences
      sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)
      print("Updated sequences shape:", sequences.shape)

  # Print the final token sequences
  for i, seq in enumerate(sequences):
      print(f"Sequence {i}:", tokenizer.decode(seq))

  return sequences, branched_token_logit_2, no_cache_logits

We compared the output logits at every generation step between the cached and the uncached (step-by-step) methods, and all corresponding logits matched within a tolerance of `atol=1e-4`.
On one small example, the cached method was also about 4 times faster:

Cached: CPU times: user 2.78 s, sys: 25.2 ms, total: 2.81 s; Wall time: 1.49 s
No cache: CPU times: user 11.7 s, sys: 28.3 ms, total: 11.7 s; Wall time: 5.92 s

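# cached_logits: one (n_branch_tokens, 1, vocab_size) tensor per step
# no_cache_logits: one (vocab_size,) tensor per branch per step, in step-major order
cached_flat = [step_logits[i, -1] for step_logits in cached_logits for i in range(step_logits.shape[0])]

are_equal = (
    len(cached_flat) == len(no_cache_logits)
    and all(torch.allclose(a, b, atol=1e-4) for a, b in zip(cached_flat, no_cache_logits))
)

print(are_equal)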
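A rough way to see why the cached method is faster is to count how many token positions each version pushes through the model. The sketch below assumes the settings used in the notebook (5 branches, 2 extension steps) and a hypothetical 10-token prompt; `positions_processed` is a hypothetical helper. The count ignores attention cost, batching (the cached version processes all branches in one call, while the uncached version loops over them), and fixed per-call overhead, so it is not expected to reproduce the measured speedup.

```python
def positions_processed(prompt_len: int, n_branch_tokens: int = 5, extra_steps: int = 2) -> dict:
    """Rough count of token positions each method runs through the model.

    Cached: the prompt is processed once, then every step feeds one new token
    per branch. Uncached: every step re-runs each branch's full sequence.
    """
    cached = prompt_len + n_branch_tokens * extra_steps
    uncached = prompt_len + sum(
        n_branch_tokens * (prompt_len + 1 + step) for step in range(extra_steps)
    )
    return {"cached": cached, "uncached": uncached}


counts = positions_processed(prompt_len=10)
print(counts)                                           # {'cached': 20, 'uncached': 125}
print(round(counts["uncached"] / counts["cached"], 2))  # 6.25
```

The cached cost grows only with the number of new tokens, while the uncached cost grows with the prompt length at every step, so the gap widens for longer prompts.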

## What we learned

We gained a practical understanding of the tokenization process and learned how important it is to track the shapes of inputs, outputs, and the cache in order to evaluate the results correctly.
We found that reusing the cached keys and values substantially reduces computation: on our small example, generation was about 4 times faster than reprocessing every sequence from scratch at each step, which should also reduce energy use.
We also learned that running the model on a GPU can further improve speed, although performance on a CPU was still reasonably good.

## Future Direction


- Branch on multiple candidates at the second and third tokens as well, not only the first.
- Test the limits of the cached method: how many tokens ahead can it still predict well?
- Run the whole pipeline on a GPU to reduce generation time and energy use.

## Supporting Material


This project is based on Professor Ken Arnold's initial implementation of lookahead generation.
https://huggingface.co/spaces/CalvinU/writing-prototypes/blob/main/custom_llm_inference.py#L66