This notebook demonstrates a simple implementation of the MultiDecode algorithm. The attention masks and position_ids are built by hand and then passed to `mdgen`, which generates tokens iteratively. The following simple examples are demonstrated:
- beam search
- parallel questions
- writing in the margins

In [8]:
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import copy
import time
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoConfig
import torch
import transformers
import torch.nn.functional as F
from functools import partial
import pandas as pd
import matplotlib.pyplot as plt

In [9]:
# Log in to Hugging Face

from huggingface_hub import login
try:
    # Colab path: read the token through google.colab.userdata.
    from google.colab import userdata
    hf_token=userdata.get('huggingface')
except Exception:
    # Best-effort fallback when the Colab path is unavailable for any ordinary
    # reason (module missing, secret not readable): load the token from ../.env.
    # `except Exception` keeps that fallback but no longer swallows
    # KeyboardInterrupt / SystemExit the way a bare `except:` does.
    import os
    import dotenv
    dotenv.load_dotenv("../.env")
    hf_token=os.getenv('HUGGINGFACE')


login(token=hf_token)

In [10]:
def mdgen(model, input_ids,positions=None,mask=None,gen_len=10,n_branch=2,greedy=False,branch_locations=None,past_key_values=None,temperature=0.7):
    """
    Generate tokens for several branches in parallel using the multidecode technique.

    The context is prefilled once. Then, at every step, one new token per branch is
    appended to the shared sequence. Each branch only attends to the context columns
    visible from its branch location plus its own previously generated tokens; this is
    enforced by extending the attention mask with an identity block on every step.

    Args:
        model: Causal LM exposing `.device` and `.forward(input_ids=, position_ids=,
            attention_mask=, past_key_values=, use_cache=)`, whose output provides
            `logits` and `past_key_values`.
        input_ids (torch.Tensor): Token IDs of shape (batch_size, ctx_len). When
            `past_key_values` is given this is still the FULL sequence; only the part
            after the cached length is fed to the model.
        positions (torch.Tensor, optional): Position ids, same shape as `input_ids`.
            Defaults to [0, 1, ..., ctx_len-1].
        mask (torch.Tensor, optional): Additive 4-D attention mask (0 = attend,
            -inf = blocked) of shape (1, 1, ctx_len - past_len, ctx_len). Defaults to a
            causal (lower-triangular) mask.
        gen_len (int, optional): Number of tokens to generate per branch. Defaults to 10.
        n_branch (int, optional): Number of parallel branches. Defaults to 2.
        greedy (bool, optional): If True take the arg-max token, otherwise sample from
            the temperature-scaled distribution. Defaults to False.
        branch_locations (list[int], optional): Index (into `input_ids`) of the token each
            branch continues from; must have `n_branch` entries. Defaults to the last
            context token for every branch.
        past_key_values (optional): Cache object (must provide `get_seq_length()`) that
            already covers a prefix of `input_ids`.
        temperature (float, optional): Softmax temperature used when sampling; must be
            > 0. Defaults to 0.7.

    Returns:
        dict with keys:
            - 'branch_ids': generated ids reshaped to (batch_size, n_branch, gen_len).
            - 'mask': the branch rows of the attention mask, extended over all generated tokens.
            - 'output_ids': generated ids in generation order (step-major, branch-minor).
            - 'input_ids': the input ids followed by 'output_ids'.
            - 'n_branch': the number of branches.
            - 'initial_length': ctx_len of the call.
            - 'positions': position ids of the context followed by those of every generated token.
            - 'past_key_values': the cache after the last forward pass.

    Raises:
        AssertionError: If `positions` does not match `input_ids` in shape, if
            `branch_locations` does not contain `n_branch` entries, or if a branch
            location lies inside the cached prefix.

    Example:
        causal = torch.triu(torch.full((3, 3), float('-inf')), diagonal=1)[None, None]
        output = mdgen(
            model=my_model,
            input_ids=torch.tensor([[1, 2, 3]]),
            mask=causal,
            gen_len=5,
            n_branch=2,
            greedy=True
        )
        print(output['branch_ids'])
    """

    past_len=0 if past_key_values is None else past_key_values.get_seq_length()

    if positions is not None:
        assert input_ids.shape == positions.shape,"positions.shape must match input_ids.shape"

    batch_size,ctx_len=input_ids.shape
    device=model.device

    # if branch location is not specified, assume all branches start at the end of the context
    if branch_locations is None:
        branch_locations=[ctx_len-1]*n_branch
    branch_locations=[int(bl) for bl in branch_locations]

    assert len(branch_locations)==n_branch,"len(branch_locations) must equal n_branch"
    # index past_len is the first token that is NOT already in the cache, so it is a valid branch point
    assert all(bl>=past_len for bl in branch_locations),"Branches must start with new input_ids, not from past_key_values."

    # default to plain causal attention over the tokens that are not cached yet
    if mask is None:
        causal=torch.triu(torch.full((ctx_len,ctx_len),float('-inf')),diagonal=1)
        mask=causal[past_len:].unsqueeze(0).unsqueeze(0)

    # every cycle we add n_branch more tokens, so the end of the 4D attention mask is a diagonal. Create here and reuse below
    gen_mask = torch.where(torch.eye(n_branch) == 1, torch.tensor(0.0), torch.tensor(float('-inf'))).unsqueeze(0).unsqueeze(0).to(device)

    # position information of the initial context input_ids. If None assume 0..ctx_len
    if positions is None:
        positions=torch.arange(ctx_len,dtype=torch.int).unsqueeze(0)
    positions=positions.to(device)
    position_history=positions.clone()

    branch_idx=torch.tensor(branch_locations,dtype=torch.long,device=device)

    # the first generated token of a branch sits ONE position after the token it branches from
    gen_positions=positions[0,branch_idx].to(torch.long).unsqueeze(0)+1

    # we will accumulate the generated tokens into output_ids
    output_ids=torch.empty((batch_size,0),dtype=torch.int).to(device)

    # move remaining tensors to model.device
    mask=mask.to(device)
    input_ids=input_ids.to(device)
    initial_length=input_ids.shape[1]
    pkv=past_key_values

    with torch.no_grad():
        # first step is to prefill the context and generate pkv
        output=model.forward(input_ids=input_ids[:,past_len:],position_ids=positions[:,past_len:] ,attention_mask=mask, use_cache=True,past_key_values=pkv)
        pkv = output.past_key_values

        # branch locations are relative to the full input sequence, so subtract the
        # cached length to index into the logits / mask rows of the prefill call
        logits=output['logits'][:,branch_idx-past_len,:]
        mask = mask[:,:,branch_idx-past_len,:]

        for _ in range(gen_len):
            # select tokens, greedy or not; result always has shape (batch_size, n_branch)
            if greedy:
                # arg-max is invariant under temperature scaling and softmax
                tokens = torch.argmax(logits,dim=-1)
            else:
                next_token_probs = F.softmax(logits / temperature, dim=-1)
                tokens = torch.multinomial(next_token_probs.view(-1,next_token_probs.shape[-1]), num_samples=1, replacement=True).view(batch_size,n_branch)

            # save the generated tokens and let each branch see its own new token only
            output_ids=torch.cat([output_ids,tokens],dim=-1)
            mask=torch.cat([mask,gen_mask],dim=-1)

            # Feed the n_branch new tokens; this also yields the logits for the next step.
            output=model.forward(input_ids=tokens,position_ids=gen_positions ,attention_mask=mask, past_key_values=pkv, use_cache=True)
            logits=output['logits']
            pkv = output['past_key_values']

            # record the positions actually used, then advance for the next token
            position_history=torch.cat([position_history,gen_positions],dim=-1)
            gen_positions=gen_positions+1

    # restructure the results to have n_branch sequences: (batch, step, branch) -> (batch, branch, step)
    branch_ids=output_ids.view(batch_size,gen_len,n_branch).transpose(1,2)
    full_ids=torch.cat([input_ids,output_ids],dim=-1)
    return {'branch_ids':branch_ids,'mask':mask,'output_ids':output_ids,'input_ids':full_ids,
            'n_branch':n_branch,'initial_length':initial_length,'positions':position_history,'past_key_values':pkv}




In [11]:
#Helpful utilities
def print_branches(branch_ids):
    """Print the decoded text of every branch, one line each, labelled `<outer>.<inner>`.

    `branch_ids` is iterated over its first two dimensions; the innermost
    sequence of token ids is decoded token by token with the notebook-level
    `tokenizer` and joined into one string.
    """
    for sidx, branches in enumerate(branch_ids.cpu()):
        for bidx, ids in enumerate(branches):
            text = ''.join(tokenizer.batch_decode(ids, skip_special_tokens=True))
            print(f"{sidx}.{bidx}: {text}")

def print_mask(mask):
    """Print the [0, 0] slice of a 4-D mask as text, one line per row.

    An entry equal to 0 is drawn as '*'; any other value is drawn as '.'.
    """
    n_rows, n_cols = mask.shape[2], mask.shape[3]
    for row in range(n_rows):
        line = ''.join('*' if mask[0,0,row,col]==0. else '.' for col in range(n_cols))
        print(line)

def print_full(output):
    """Print, for every row of `output['mask']`, the text of the tokens that row can see.

    The token sequence is `output['input_ids']` followed by `output['output_ids']`;
    only the first `mask.shape[3]` entries of it are ever indexed. A token is kept
    for a row when that row's mask entry equals 0. Tokens are decoded one by one
    with the notebook-level `tokenizer` and joined.
    """
    mask = output['mask']
    full_ids = torch.cat([output['input_ids'], output['output_ids']], dim=-1)
    n_rows, n_cols = mask.shape[2], mask.shape[3]
    for b in range(n_rows):
        visible_ids = [int(full_ids[0,p]) for p in range(n_cols) if mask[0,0,b,p] == 0.0]
        text = ''.join(tokenizer.batch_decode(visible_ids, skip_special_tokens=True))
        print(f"{b}:{text}")


def print_args(input_ids=None,positions=None,mask=None,branch_locations=None):
    """Print a short summary of the arguments about to be passed to `mdgen`.

    Every argument is optional and is skipped when it is None:
        input_ids: only its `.shape` is printed.
        positions: printed in full.
        branch_locations: printed in full.
        mask: rendered with `print_mask`.
    """
    print()
    print("Arguments:")
    # input_ids defaults to None, so guard the attribute access like the other arguments
    if input_ids is not None:
        print(f"{input_ids.shape=}")
    if positions is not None:
        print(f"{positions=}")
    if branch_locations is not None:
        print(f"{branch_locations=}")
    print()
    if mask is not None:
        print_mask(mask)
    print()

def print_results(output):
    """Print a report for a result dictionary as returned by `mdgen`.

    Sections, in order: the generated ids decoded in generation order ("raw"),
    the per-row reconstruction from `print_full`, the position ids, and the
    final attention mask rendered by `print_mask`.
    """
    print()
    print("Results")
    print("raw")
    raw_text = ''.join(tokenizer.batch_decode(output['output_ids'], skip_special_tokens=True))
    print(raw_text)
    print()

    print("Reformated")
    print_full(output)
    print()

    position_ids = output['positions']
    print(f"positions {position_ids}")
    print()

    print_mask(output['mask'])
    print()

def strmask(*args):
    """Build an additive attention mask from one pattern string per row.

    Each positional argument describes one row: the characters '1' and '*'
    mark an entry as 0 (visible); every other character leaves it at -inf.
    The width is taken from the first pattern.

    Returns:
        torch.Tensor of shape (1, 1, len(args), len(args[0])).
    """
    rows, cols = len(args), len(args[0])
    grid = torch.full((rows, cols), fill_value=float('-inf'))

    for row, pattern in enumerate(args):
        for col, ch in enumerate(pattern):
            if ch in ('1', '*'):
                grid[row, col] = 0
    return grid[None, None]

def lut_attn(n):
    '''
    Returns a lower triangle array with dimensions and values suitable for an attention mask
    dimension: [1,1,n,n]
    values: 0 in lower triangle and diagonal
            -inf in upper triangle
    '''
    # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes everything else
    blocked_above = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
    return blocked_above[None, None]

In [12]:
# Initialize the tokenizer and model used by the examples in this notebook
model_name = "meta-llama/Llama-3.2-1B"

# left padding; the EOS token doubles as the padding token
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token_id = tokenizer.eos_token_id

# Use the GPU if available, otherwise the CPU.
# (the author left attn_implementation="flex_attention" disabled on the from_pretrained call)
model = AutoModelForCausalLM.from_pretrained(model_name).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

In [13]:
# simple beam generation: five sampled continuations of one shared prompt
input_ids = tokenizer("Once upon a time", return_tensors="pt", padding=True, truncation=True)['input_ids']
input_ids = input_ids.to(model.device)

# plain causal attention over the prompt
mask = lut_attn(input_ids.shape[1])

print_args(input_ids=input_ids, mask=mask)

# positions and branch locations are left at their defaults; tokens are sampled
output = mdgen(model, input_ids, mask=mask, n_branch=5)

print_results(output)



Arguments:
input_ids.shape=torch.Size([1, 5])

*....
**...
***..
****.
*****


Results
raw
, in,, there there the there an was were German was idea a twoic a was little companion kingdoms man born boy animals there called in who who was John the lived were a. heart in more king He of the than who had a heart

Reformated
0:Once upon a time, there were two companion animals who were more than
1:Once upon a time in the Germanic kingdoms there was a king who
2:Once upon a time, there was a man called John. He had
3:Once upon a time, an idea was born in the heart of a
4:Once upon a time there was a little boy who lived in the heart

positions tensor([[ 0,  1,  2,  3,  4,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 11,
         11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 14, 14, 14, 14,
         14]], device='cuda:0')

******....*....*....*....*....*....*....*....*....*....
*****.*....*....*....

In [14]:
# Multi-question
# One shared context followed by two questions; each question is answered in
# its own branch, and both branches are decoded in a single mdgen call.
context_ids=tokenizer("The house is red. The grass is green. The bike is purple. ", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
question1_ids=tokenizer("What color is the bike?", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
question2_ids=tokenizer("What color is the grass?", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
context_len=context_ids.shape[1]
question1_len=question1_ids.shape[1]
question2_len=question2_ids.shape[1]

# physical layout of the sequence: [context | question 1 | question 2]
input_ids=torch.cat([context_ids,question1_ids,question2_ids],dim=-1)

# start from a causal mask over the whole sequence
mask=lut_attn(input_ids.shape[1])
# mask out the first question from the view of the second question
mask[:,:,context_len+question1_len:,context_len:context_len+question1_len]=float('-inf')

# question 1 continues the numbering of the context; question 2 restarts at
# context_len, so both questions are positioned directly after the context
positions=torch.cat([torch.arange(context_len+question1_len),torch.arange(context_len,context_len+question2_len)]).unsqueeze(0)
# one branch per question, each starting at that question's last token
branch_locations=[context_len+question1_len-1,context_len+question1_len+question2_len-1]

print_args(input_ids,positions,mask)

# n_branch is left at its default of 2, which matches the two branch locations
output=mdgen(model, input_ids,positions=positions,mask=mask,branch_locations=branch_locations,greedy=True)

print_results(output)



Arguments:
input_ids.shape=torch.Size([1, 31])
positions=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 17, 18, 19, 20, 21, 22, 23]])

*..............................
**.............................
***............................
****...........................
*****..........................
******.........................
*******........................
********.......................
*********......................
**********.....................
***********....................
************...................
*************..................
**************.................
***************................
****************...............
*****************..............
******************.............
*******************............
********************...........
*********************..........
**********************.........
***********************........
************************.......
*****************.......*....

In [15]:
# Writing in the margins
# The same question is asked twice: once seeing only the first context segment
# and once seeing both segments. Both answers are decoded in one mdgen call.
context1_ids=tokenizer("The house was red. ", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
context2_ids=tokenizer("The grass was green. The bike was blue. ", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
question1_ids=tokenizer("What color is the bike?", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
question2_ids=tokenizer("What color is the bike?", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
context1_len=context1_ids.shape[1]
context2_len=context2_ids.shape[1]
question1_len=question1_ids.shape[1]
question2_len=question2_ids.shape[1]
print(f"{context1_len=} {context2_len=} {question1_len=} {question2_len=}")

# physical layout of the sequence: [context 1 | context 2 | question 1 | question 2]
input_ids=torch.cat([context1_ids,context2_ids,question1_ids,question2_ids],dim=-1)

# mask out the second context from the first question, and the first question from the view of the second question
mask=lut_attn(input_ids.shape[1])
mask[:,:,context1_len+context2_len:context1_len+context2_len+question1_len,context1_len:context1_len+context2_len]=float('-inf')
mask[:,:,context1_len+context2_len+question1_len:,context1_len+context2_len:context1_len+context2_len+question1_len]=float('-inf')


# make question 1 follow context1 and question 2 follow context2
# (question 1 is numbered from context1_len, question 2 from context1_len+context2_len)
positions=torch.cat([
    torch.arange(context1_len+context2_len),
    torch.arange(context1_len,context1_len+question1_len),
    torch.arange(context1_len+context2_len,context1_len+context2_len+question2_len)]).unsqueeze(0)

print(f"{positions=}")
print(f"{input_ids.shape=} {positions.shape=}")
# one branch per question, each starting at that question's last token
branch_locations=[context1_len+context2_len+question1_len-1,context1_len+context2_len+question1_len+question2_len-1]

print_args(input_ids,positions,mask)

# n_branch is left at its default of 2, which matches the two branch locations
output=mdgen(model, input_ids,positions=positions,mask=mask,branch_locations=branch_locations,greedy=True)

print_results(output)


context1_len=7 context2_len=12 question1_len=7 question2_len=7
positions=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18,  7,  8,  9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25]])
input_ids.shape=torch.Size([1, 33]) positions.shape=torch.Size([1, 33])

Arguments:
input_ids.shape=torch.Size([1, 33])
positions=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18,  7,  8,  9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25]])

*................................
**...............................
***..............................
****.............................
*****............................
******...........................
*******..........................
********.........................
*********........................
**********.......................
***********......................
************.....................
*************....................
**************...................
***************.

In [16]:
# Ted's example part 1: baseline, tokens in their natural order with a causal mask
input_ids = tokenizer("5 - 3 =", return_tensors="pt", padding=True, truncation=True)['input_ids']
input_ids = input_ids.to(model.device)

mask = lut_attn(input_ids.shape[1])

print_args(input_ids=input_ids, mask=mask)

# a single greedy branch, two tokens long
output = mdgen(model, input_ids, mask=mask, gen_len=2, n_branch=1, greedy=True)

print_results(output)
print_full(output)



Arguments:
input_ids.shape=torch.Size([1, 6])

*.....
**....
***...
****..
*****.
******


Results
raw
 2

Reformated
0:5 - 3 = 2

positions tensor([[0, 1, 2, 3, 4, 5, 6, 7]], device='cuda:0')

********

0:5 - 3 = 2


In [17]:
# Teds example part 2
# Same prompt as part 1, but the tokens are stored in a different physical
# order ("5 = - 3"). The position ids and the attention mask below are written
# by hand for this permuted layout.
original_ids=tokenizer("5 - 3 =", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
print(f"{original_ids.cpu()=}")
input_ids=tokenizer("5 = - 3", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
print(f"{input_ids.cpu()=}")


# one pattern string per token of input_ids: row 2 (the token stored third)
# sees every column, while rows 3-5 do not see column 2
mask=strmask("*.....","**....","******","**.*..","**.**.","**.***")

# NOTE(review): `order` is not referenced again in this cell
order=torch.tensor([0, 1, 3, 4, 5, 2])

# the token stored at index 2 gets position 5; the tokens after it get 2, 3, 4
positions=torch.tensor([[0, 1, 5, 2,3,4]],dtype=torch.int)
# generation branches from the token stored at index 2
branch_locations=[2]

print_args(input_ids,mask=mask,positions=positions)

output=mdgen(model, input_ids,positions=positions,mask=mask,branch_locations=branch_locations,n_branch=1,gen_len=2,greedy=True)

print_results(output)
print_full(output)

original_ids.cpu()=tensor([[128000,     20,    482,    220,     18,    284]])
input_ids.cpu()=tensor([[128000,     20,    284,    482,    220,     18]])

Arguments:
input_ids.shape=torch.Size([1, 6])
positions=tensor([[0, 1, 5, 2, 3, 4]], dtype=torch.int32)

*.....
**....
******
**.*..
**.**.
**.***


Results
raw
 2

Reformated
0:5 = - 3 2

positions tensor([[0, 1, 5, 2, 3, 4, 6, 7]], device='cuda:0')

********

0:5 = - 3 2


In [18]:
def select_branch(output,selected_branch):
    """
    Selects a specific branch from the output of the `mdgen` function.

    Builds the inputs needed to continue generation from a single branch: the
    initial context followed by only that branch's generated tokens, the matching
    position ids, an attention mask for the branch tokens, and the cache cropped
    back to the initial context.

    Args:
        output (dict): The output dictionary from the `mdgen` function. The keys read are:
            - 'input_ids' (torch.Tensor): Context followed by all generated tokens,
              interleaved across branches (step-major, branch-minor).
            - 'positions' (torch.Tensor): Position ids aligned with 'input_ids'.
            - 'mask' (torch.Tensor): Final attention mask, one row per branch.
            - 'initial_length' (int): Length of the context before generation.
            - 'n_branch' (int): Number of branches that were generated.
            - 'past_key_values': Cache object providing `crop()` and `get_seq_length()`.
        selected_branch (int): The index of the branch to select, in [0, n_branch).

    Returns:
        tuple: A tuple containing:
            - input2_ids (torch.Tensor): Context plus the selected branch's tokens.
            - position2_ids (torch.Tensor): The position ids for those tokens.
            - mask2 (torch.Tensor): Mask of shape (batch, 1, branch_len, initial_length + branch_len);
              every row sees the context as the branch did and is causal over the branch tokens.
            - pkv: The cache from `output`, cropped IN PLACE to `initial_length` entries.

    Raises:
        IndexError: If the specified branch index is out of range.

    Example:
        input2_ids, position2_ids, mask2, pkv = select_branch(output, selected_branch=1)
    """
    pkv=output['past_key_values']
    o_positions=output['positions']
    o_mask=output['mask']
    o_input_ids=output['input_ids']
    o_initial_len=output['initial_length']
    # take the branch count from the result itself rather than from a notebook global
    n_branch=output['n_branch']
    o_input_ids_len=o_input_ids.shape[1]

    if not 0 <= selected_branch < n_branch:
        raise IndexError(f"selected_branch {selected_branch} is out of range for {n_branch} branches")

    # context indexes, then every n_branch-th generated token starting at the branch's
    # first one; the stop value is the sequence length so no index runs past the end
    device=o_input_ids.device
    input_indexes=torch.cat([
        torch.arange(o_initial_len,device=device),
        torch.arange(o_initial_len+selected_branch,o_input_ids_len,n_branch,device=device)],dim=-1)

    input2_ids=o_input_ids[:,input_indexes]
    position2_ids=o_positions[:,input_indexes]
    selected_len=(o_input_ids_len-o_initial_len)//n_branch

    # every branch token sees the context exactly as the branch row of the old mask did ...
    mask2=o_mask[:,:,selected_branch,:o_initial_len].unsqueeze(2).repeat([1,1,selected_len,1])
    # ... and attends causally to the branch tokens themselves
    causal=torch.triu(torch.full((selected_len,selected_len),float('-inf'),dtype=o_mask.dtype,device=o_mask.device),diagonal=1)
    causal=causal.unsqueeze(0).unsqueeze(0).expand(mask2.shape[0],mask2.shape[1],-1,-1)
    mask2=torch.cat([mask2,causal],dim=-1)

    # drop the cached entries of all generated tokens; the selected ones are re-fed by the caller
    pkv.crop(o_initial_len)

    print(f"{input2_ids.shape=} {position2_ids.shape=} {mask2.shape=} {pkv.get_seq_length()=}")
    return input2_ids,position2_ids, mask2, pkv

In [19]:
# Example for using the select branch function
# Stage 1 samples n_branch continuations of a prompt; stage 2 keeps one of
# them and samples n_branch new continuations from its end.

input_ids=tokenizer("Once upon a time", return_tensors="pt", padding=True, truncation=True)['input_ids'].to(model.device)
mask=lut_attn(input_ids.shape[1])

print_args(input_ids,mask=mask)

input_len=input_ids.shape[1]
# n_branch is defined at notebook scope and passed to both mdgen calls below
n_branch=3
output=mdgen(model, input_ids,positions=None,mask=mask,n_branch=n_branch,greedy=False)
print_results(output)
selected_branch=1



# extract the chosen branch's ids, positions and mask, plus the cache returned by select_branch
input2_ids,position2_ids,mask2, pkv=select_branch(output,selected_branch)

print_args(input2_ids,positions=position2_ids,mask=mask2)

# continue from the chosen branch, reusing the returned cache; `output` is rebound to the new result
output=mdgen(model, input2_ids,positions=position2_ids,mask=mask2,n_branch=n_branch,greedy=False,past_key_values=pkv)

print_results(output)




Arguments:
input_ids.shape=torch.Size([1, 5])

*....
**...
***..
****.
*****


Results
raw
,,, there there there was was lived a a a boy small happy who island, lived, peaceful in a and a small very small country productive

Reformated
0:Once upon a time, there was a boy who lived in a small
1:Once upon a time, there was a small island, a small country
2:Once upon a time, there lived a happy, peaceful and very productive

positions tensor([[ 0,  1,  2,  3,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,
          9,  9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14]],
       device='cuda:0')

******..*..*..*..*..*..*..*..*..*..
*****.*..*..*..*..*..*..*..*..*..*.
*****..*..*..*..*..*..*..*..*..*..*

input2_ids.shape=torch.Size([1, 15]) position2_ids.shape=torch.Size([1, 15]) mask2.shape=torch.Size([1, 1, 10, 15]) pkv.get_seq_length()=5

Arguments:
input_ids.shape=torch.Size([1, 15])
positions=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]],
 