# Query

In [None]:
I have converted a tinyllama model to onnx format like this. Your tasks are the following - 
1. Check the code below and point out if there are inaccuracies. I want a static cache and support for dynamic inputs and then move forward to converting the model to onnx
2. Once you have validated the ONNX conversion, I want to get the model's outputs. I have shared my code below for getting outputs from the ONNX model, but the outputs aren't normal; they are garbled. I am therefore sharing my current inference code and a reference implementation that I want you to follow or improve upon, so that I can try token-by-token generation.

<<  Code to wrap the model with static cache etc>>
class ModelWrapper(torch.nn.Module):
    """Adapt a causal LM to a ``(input_ids, attention_mask, position_ids) -> logits`` call for ONNX export.

    NOTE(review): ``batch_size`` and ``sequence_length`` are not defined in this
    class; they are read from module-level globals at call time, so the cache
    size and the cache positions do not follow the actual input shape.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # wrapped model; its config and dtype size the cache in forward()
        
    def forward(self, input_ids, attention_mask, position_ids):
        """Run one forward pass over a freshly allocated StaticCache and return the logits.

        NOTE(review): a new, empty cache is built on every call and is not
        returned, so nothing is carried over between calls — token-by-token
        reuse of the cache is not possible through this wrapper.
        """
        # Initialize static cache
        past_key_values = StaticCache(
            config=self.model.config,
            batch_size=batch_size,  # global, not input_ids.shape[0]
            max_cache_len= sequence_length, # 32,  # Match the sequence length
            device=input_ids.device,
            dtype=self.model.dtype
        )
        
        # Generate cache position
        # NOTE(review): always 0..sequence_length-1 from the global, independent of
        # input_ids.shape[1]; inputs of another length are likely to mismatch.
        cache_position = torch.arange(sequence_length, device=input_ids.device)
        
        # Forward pass with static cache
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
            return_dict=False
        )
        
        return outputs[0]  # Return only logits for simplicity


# Wrap the model and build dummy tensors for tracing.
# NOTE(review): model, device, batch_size and sequence_length are not defined in
# this snippet; presumably they are set in an earlier cell.
wrapped_model = ModelWrapper(model).to(device)

# Input tensors
dummy_input_ids = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_position_ids = torch.arange(sequence_length, device=device).unsqueeze(0).repeat(batch_size, 1)

# Define dynamic axes for variable sequence lengths
# NOTE(review): the 'logits' output is not listed, so its batch/sequence dims are
# not declared dynamic.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'position_ids': {0: 'batch_size', 1: 'sequence_length'},
}

<<  Code to convert the model to onnx >>
# Trace wrapped_model with the dummy tensors and write the ONNX graph.
# The three positional tensors are bound to input_names in order; the single
# returned tensor is exported as 'logits'.
torch.onnx.export(
    wrapped_model,
    (
        dummy_input_ids,
        dummy_attention_mask,
        dummy_position_ids,
    ),
    '/kaggle/working/tiny_llama.onnx',
    input_names=['input_ids', 'attention_mask', 'position_ids'],
    output_names=['logits'],
    dynamic_axes=dynamic_axes,
    opset_version=18, 
    do_constant_folding=False,
    verbose=False 
)

<<My inference code>>
# ONNX Runtime session over the exported graph, restricted to the CPU provider.
ort_session = ort.InferenceSession("/kaggle/working/tiny_llama.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    """Return *tensor* as a host NumPy array, detaching it first if it requires grad."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def prepare_inputs(text_input, tokenizer, target_length=256, device='cuda'):
    """
    Prepare inputs by right-padding/truncating the tokenized text to target_length.

    Returns (input_ids, attention_mask, position_ids), each shaped
    (1, target_length) and moved to *device*. The mask is 1 over real tokens and
    0 over padding; position ids count 0..target_length-1 across the whole row.

    NOTE(review): padding is appended on the right, so for prompts shorter than
    target_length the last position (index -1) is a pad token. Its logits are
    not the next-token prediction for the prompt; read index seq_len - 1 instead.
    """
    # Tokenize the input
    encoded = tokenizer.encode(text_input, add_special_tokens=True, return_tensors="pt")
    input_ids = encoded[0]  # Remove batch dimension
    
    # Get the actual sequence length
    seq_len = len(input_ids)
    
    if seq_len > target_length:
        # Truncate if longer than target_length
        input_ids = input_ids[:target_length]
        seq_len = target_length
    
    # Pad input_ids
    padding_length = target_length - seq_len
    padding_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = torch.cat([
        input_ids,
        torch.full((padding_length,), padding_id, dtype=torch.long)
    ])
    
    # Create attention mask (1 for real tokens, 0 for padding)
    attention_mask = torch.cat([
        torch.ones(seq_len, dtype=torch.long),  # Explicitly set dtype to long
        torch.zeros(padding_length, dtype=torch.long)  # Explicitly set dtype to long
    ])
    
    # Create position ids
    # position_ids = torch.arange(target_length)
    position_ids = torch.arange(target_length, dtype=torch.long)  # Explicitly set dtype to long

    
    # Add batch dimension and move to device
    input_ids = input_ids.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    position_ids = position_ids.unsqueeze(0).to(device)
    
    return input_ids, attention_mask, position_ids


def make_both_inferences(text_input: str):
    """Run the same padded input through the PyTorch wrapper and the ONNX session.

    Uses the globals ``tokenizer``, ``device``, ``wrapped_model`` and ``ort_session``.

    Returns:
        (wrapped_model_outputs, ort_outs): the PyTorch logits as a NumPy array,
        and the list returned by ``ort_session.run`` (logits first).
    """

    input_ids, attention_mask, position_ids = prepare_inputs(
        text_input, 
        tokenizer, 
        target_length=256,
        device=device
    )

    wrapped_model_outputs = to_numpy(wrapped_model(input_ids, attention_mask, position_ids))

    # Inputs are matched by position in get_inputs(); this assumes the graph keeps
    # the export order (input_ids, attention_mask, position_ids).
    ort_inputs = { 
        ort_session.get_inputs()[0].name: to_numpy(input_ids),
        ort_session.get_inputs()[1].name: to_numpy(attention_mask),
        ort_session.get_inputs()[2].name: to_numpy(position_ids)
    }
    ort_outs = ort_session.run(None, ort_inputs)

    return wrapped_model_outputs, ort_outs


<<  Reference Code that I saw to perform token by token conversion, you have to think along these lines to give me my inference code >>
# use staticCache from huggingface
# ref https://github.com/huggingface/transformers/issues/30670

def decode_one_tokens(model, cur_token, input_pos, attention_mask, cache_position, past_key_values):
    """Run one decoding step and return the greedy (argmax) next token.

    ``past_key_values`` is handed to the model unchanged; the caller in this
    reference reuses the same cache object on every step.

    Returns:
        Next-token ids shaped ``(batch, 1)``, assuming the model's logits are
        ``(batch, seq, vocab)``.
    """
    logits = model(
        cur_token,
        attention_mask=attention_mask,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token


def generate(
    prompts: List[str], 
    model: LlamaForCausalLM, 
    tokenizer: LlamaTokenizer, 
    num_tokens_to_generate: int = 40
) ->  List[str]:
    """Greedy batch generation with a Hugging Face StaticCache (reference code).

    Runs one prefill pass over the padded prompts, then decodes one token per
    step through ``decode_one_tokens``, reusing a single StaticCache of
    length 4096.

    NOTE(review): ``torch_device`` is a global that is not defined here;
    presumably it matches ``model.device``.
    """
    
    global decode_one_tokens  # only needed for the commented-out torch.compile rebinding below
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    batch_size, seq_length = inputs["input_ids"].shape

    attention_mask = inputs["attention_mask"]
    
    with torch.no_grad():
        past_key_values = StaticCache(
            config=model.config, 
            max_batch_size=batch_size, 
            max_cache_len=4096, device=torch_device, dtype=model.dtype
        )
        cache_position = torch.arange(seq_length, device=torch_device)
        generated_ids = torch.zeros(
            batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=torch_device
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
        # Force the math SDPA backend (flash and memory-efficient kernels disabled).
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            logits = model(
                **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
            )[0]
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        generated_ids[:, seq_length] = next_token[:, 0]

        # Not using torch.compile to simplify debugging
        # decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
        # NOTE(review): next_token was stored at index seq_length above, but it is
        # decoded at cache position seq_length + 1, so cache slot seq_length is
        # never written. torch.tensor([seq_length]) looks like the intended value.
        cache_position = torch.tensor([seq_length + 1], device=torch_device)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) # update and pass every step
        for _ in range(1, num_tokens_to_generate):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                next_token = decode_one_tokens(model, next_token.clone(), None, attention_mask, cache_position, past_key_values)
                generated_ids[:, cache_position] = next_token.int()
            cache_position += 1
            # NOTE(review): position_ids is never assigned before this line, so the
            # first iteration raises UnboundLocalError; the value is also unused
            # (decode_one_tokens receives input_pos=None).
            position_ids = position_ids[:, -1:] + 1
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Resp

In [None]:
Let's address your tasks step-by-step:

1. Checking the ONNX Conversion Code
Issues in the Model Wrapper:
Static Cache Initialization: The StaticCache class is not defined in your provided code. Ensure it's correctly imported or defined.
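Cache Position: cache_position must say where the current tokens go in the cache on every call. In the provided wrapper it is always torch.arange(sequence_length), built from a global. It therefore ignores the actual input length and always starts at slot 0, which cannot work for token-by-token generation.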
Batch Size and Sequence Length: In the wrapper these are read from module-level globals rather than from the input shapes. Their export-time values are therefore frozen into the cache size and cache positions of the traced graph.
Revised Model Wrapper:

import torch
from transformers import PreTrainedModel

class ModelWrapper(torch.nn.Module):
    """Return the logits plus the model's KV cache as plain tensors so both can be exported.

    NOTE: this variant passes ``past_key_values`` straight to the model, so with
    ``None`` the model uses its default dynamic (growing) cache, not a
    StaticCache.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None, cache_position=None):
        """Run one forward pass with ``use_cache=True``.

        Returns:
            ``(logits, cache)``. If the model returns a ``Cache`` object, it is
            converted with ``to_legacy_cache()`` into nested ``(key, value)``
            tensor tuples, which the ONNX tracer can flatten into graph outputs.
            A cache that is already tuples is returned unchanged.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
            return_dict=False
        )
        logits, cache = outputs[0], outputs[1]
        # Cache objects (e.g. DynamicCache) are neither tensors nor tuples of
        # tensors, so torch.onnx.export cannot emit them as outputs.
        if hasattr(cache, "to_legacy_cache"):
            cache = cache.to_legacy_cache()
        return logits, cache


# model, device, batch_size and sequence_length are expected from earlier cells.
wrapped_model = ModelWrapper(model).to(device)
wrapped_model.eval()

# Input tensors
dummy_input_ids = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_position_ids = torch.arange(sequence_length, device=device).unsqueeze(0).repeat(batch_size, 1)

# The exporter flattens the returned cache into one output per layer and per
# key/value (layer 0 key, layer 0 value, layer 1 key, ...); name them explicitly.
num_layers = model.config.num_hidden_layers
present_names = [f"present.{layer}.{kind}" for layer in range(num_layers) for kind in ("key", "value")]

# Define dynamic axes for variable batch sizes and sequence lengths
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'position_ids': {0: 'batch_size', 1: 'sequence_length'},
    'logits': {0: 'batch_size', 1: 'sequence_length'},
}
# Assumes present tensors are (batch, kv_heads, seq, head_dim) — confirm for your model.
dynamic_axes.update({name: {0: 'batch_size', 2: 'sequence_length'} for name in present_names})

# Export only real tensors. None cannot become an ONNX graph input, so listing
# 'past_key_values' / 'cache_position' in input_names while passing None would
# misalign names and inputs (or fail). This graph therefore takes no past
# cache: it always runs the full sequence it is given.
with torch.no_grad():
    torch.onnx.export(
        wrapped_model,
        (
            dummy_input_ids,
            dummy_attention_mask,
            dummy_position_ids,
        ),
        '/kaggle/working/tiny_llama.onnx',
        input_names=['input_ids', 'attention_mask', 'position_ids'],
        output_names=['logits', *present_names],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=False,
        verbose=False,
    )
2. Inference Code for Token-by-Token Generation
Issues in Your Inference Code:
Cache Handling: You need to handle the cache properly for each token generation step.
Dynamic Inputs: Ensure that inputs are handled dynamically for each token.
Revised Inference Code:

import onnxruntime as ort
import torch

# One CPU-only ONNX Runtime session over the exported graph, shared by the helpers below.
ort_session = ort.InferenceSession(
    "/kaggle/working/tiny_llama.onnx",
    providers=["CPUExecutionProvider"],
)

def to_numpy(tensor):
    """Move *tensor* to host memory and return it as a NumPy array.

    Tensors that require grad are detached first, since ``.numpy()`` refuses them.
    """
    host_tensor = tensor.detach() if tensor.requires_grad else tensor
    return host_tensor.cpu().numpy()

def prepare_inputs(text_input, tokenizer, target_length=256, device='cuda'):
    """Tokenize *text_input* and right-pad or truncate it to exactly *target_length* tokens.

    Returns:
        ``(input_ids, attention_mask, position_ids)``, each shaped
        ``(1, target_length)`` and moved to *device*. The mask is 1 over real
        tokens and 0 over padding; position ids run ``0..target_length-1``
        across the whole row, padding included.
    """
    token_row = tokenizer.encode(text_input, add_special_tokens=True, return_tensors="pt")[0][:target_length]
    real_len = len(token_row)
    pad_len = target_length - real_len

    pad_value = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    padded_ids = torch.cat([token_row, torch.full((pad_len,), pad_value, dtype=torch.long)])
    mask = torch.cat([
        torch.ones(real_len, dtype=torch.long),
        torch.zeros(pad_len, dtype=torch.long),
    ])
    positions = torch.arange(target_length, dtype=torch.long)

    # Add the batch dimension and move everything to the target device.
    return tuple(t.unsqueeze(0).to(device) for t in (padded_ids, mask, positions))

def generate_tokens(prompt: str, tokenizer, num_tokens_to_generate: int = 40, device='cuda'):
    """Greedily generate text with the exported ONNX model, one token per step.

    The exported graph has no past-key/value inputs, so no cache can be fed back
    between ``ort_session.run`` calls. Each step therefore re-runs the whole
    (unpadded) sequence from position 0 and reads the logits of its last
    position. Without padding, that position is the last real token; with
    right padding it would be a pad token, which garbles the output.

    Only inputs the session actually declares are fed: ONNX Runtime rejects
    unknown input names and ``None`` values.

    Args:
        prompt: Text to continue.
        tokenizer: Hugging Face-style tokenizer providing ``encode``/``decode``.
        num_tokens_to_generate: Maximum number of new tokens. Generation stops
            early at ``tokenizer.eos_token_id`` when that attribute is set.
        device: Unused; kept for backward compatibility. ONNX Runtime runs on
            the providers ``ort_session`` was created with.

    Returns:
        The decoded prompt plus its continuation.
    """
    import numpy as np

    declared = {node.name for node in ort_session.get_inputs()}
    token_ids = list(tokenizer.encode(prompt, add_special_tokens=True))
    eos_id = getattr(tokenizer, "eos_token_id", None)

    for _ in range(num_tokens_to_generate):
        seq_len = len(token_ids)
        positions = np.arange(seq_len, dtype=np.int64)
        candidates = {
            "input_ids": np.asarray([token_ids], dtype=np.int64),
            "attention_mask": np.ones((1, seq_len), dtype=np.int64),
            "position_ids": positions[None, :],
            "cache_position": positions,
        }
        ort_inputs = {name: value for name, value in candidates.items() if name in declared}
        logits = ort_session.run(None, ort_inputs)[0]  # first output is the logits

        next_id = int(np.argmax(logits[0, -1]))
        token_ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break

    return tokenizer.decode(token_ids, skip_special_tokens=True)

# Example usage
# Example usage: greedily generate 40 tokens after the prompt and print the decoded text.
prompt = "Once upon a time"
generated_text = generate_tokens(prompt, tokenizer, num_tokens_to_generate=40, device='cuda')
print(generated_text)
This revised code should help you generate text token by token using the ONNX model. Make sure to adjust paths and parameters according to your specific setup.

# Resp 2

In [None]:
Certainly! Let's clarify the concepts and ensure that the implementation aligns with your goal of using a static cache for faster inference.

Understanding Static Cache vs. Dynamic Cache
Dynamic Cache:
Purpose: The default Hugging Face cache. It is reused across generation steps, but its key/value tensors grow by one position per generated token, so tensor shapes change on every step.
Usage: Convenient for eager PyTorch generation. The changing shapes make it a poor fit for graph compilers and fixed-shape exports.
Static Cache:
Purpose: Pre-allocates key/value buffers of a fixed maximum length (max_cache_len) and writes each new token into the slot given by cache_position, so every tensor shape stays constant across steps.
Usage: You must choose the batch size and max_cache_len up front. You must also pass cache_position (and, for ONNX, the cache buffers themselves) on every call.
Using Static Cache in Hugging Face Transformers
The StaticCache class you mentioned is part of Hugging Face's transformers library (import it with `from transformers import StaticCache`). An ONNX Runtime session keeps no state between `run` calls. After export, the cache buffers must therefore be graph inputs and outputs that your inference loop passes back in on every step.

Revised Implementation
Let's revise the implementation to ensure that the static cache is used correctly both in the PyTorch model and during ONNX inference.

1. Model Wrapper with Static Cache
First, ensure that the StaticCache class is correctly defined and used. Here’s how you can modify the ModelWrapper to use StaticCache.


import torch
from transformers import PreTrainedModel, LlamaForCausalLM, LlamaTokenizer
from transformers.generation.utils import StaticCache

class ModelWrapper(torch.nn.Module):
    """Wrap a causal LM so a fixed-size (static) KV cache can cross the ONNX boundary.

    An ONNX Runtime session keeps no state between ``run`` calls. A cache held
    on this Python module would be frozen into the exported graph as constants
    and would never persist between steps. To get a usable static cache after
    export, pass the key/value buffers in as ``past_keys`` / ``past_values``.
    Then read the updated buffers back from the outputs and feed them into the
    next call.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.static_cache = None

    def initialize_cache(self, batch_size, max_cache_len, device, dtype):
        """Allocate the StaticCache that ``forward`` passes to the model."""
        # NOTE(review): the batch-size keyword differs across transformers releases
        # (``max_batch_size`` vs ``batch_size``); confirm against the installed version.
        self.static_cache = StaticCache(
            config=self.model.config,
            max_batch_size=batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype
        )

    def forward(self, input_ids, attention_mask, position_ids, cache_position=None,
                past_keys=None, past_values=None):
        """Run one prefill or decode step over the static cache.

        Args:
            input_ids, position_ids: ``(batch, n)`` new tokens and their positions.
            attention_mask: ``(batch, max_cache_len)`` mask, 1 for every cache
                slot written so far or in this call.
            cache_position: ``(n,)`` cache slots for the new tokens. Required
                together with explicit buffers; for export it must be a real
                input, not derived from the cache contents.
            past_keys, past_values: optional ``(num_layers, batch, kv_heads,
                max_cache_len, head_dim)`` buffers (assumed layout — confirm
                against ``initialize_cache``) that replace the cache contents.

        Returns:
            ``logits`` alone when no buffers are given (eager use, the cache is
            kept on ``self.static_cache``). Otherwise
            ``(logits, present_keys, present_values)``, where the presents have
            the same shape as the inputs.

        Raises:
            ValueError: if the cache was not initialized, only one of the buffers
                is given, or buffers are given without ``cache_position``.
        """
        if self.static_cache is None:
            raise ValueError("Static cache has not been initialized.")
        if (past_keys is None) != (past_values is None):
            raise ValueError("past_keys and past_values must be given together.")

        explicit_buffers = past_keys is not None
        if explicit_buffers:
            if cache_position is None:
                raise ValueError("cache_position is required when passing cache buffers.")
            # Seed every layer from the caller's buffers. The clones keep the
            # model's in-place cache writes from mutating the caller's tensors.
            # NOTE(review): assumes StaticCache keeps per-layer lists
            # ``key_cache`` / ``value_cache`` (transformers 4.4x); newer releases
            # keep the buffers on ``cache.layers[i]`` — adapt if this fails.
            self.static_cache.key_cache = [k.clone() for k in past_keys.unbind(0)]
            self.static_cache.value_cache = [v.clone() for v in past_values.unbind(0)]

        # The StaticCache object itself is the past_key_values argument; the
        # model writes the new keys/values into it during the forward pass, so
        # no manual cache update is needed afterwards.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.static_cache,
            cache_position=cache_position,
            use_cache=True,
            return_dict=False
        )
        logits = outputs[0]
        if not explicit_buffers:
            return logits
        present_keys = torch.stack(list(self.static_cache.key_cache))
        present_values = torch.stack(list(self.static_cache.value_cache))
        return logits, present_keys, present_values


# Initialize the model and tokenizer
model = LlamaForCausalLM.from_pretrained('path_to_your_model').to(device)
model.eval()  # disable dropout so the traced graph is deterministic
tokenizer = LlamaTokenizer.from_pretrained('path_to_your_model')

# Initialize the model wrapper
wrapped_model = ModelWrapper(model).to(device)

# Initialize the static cache. Batch size and max_cache_len are FIXED in the
# exported graph: they size the cache buffers and the attention mask.
batch_size = 1  # Adjust as needed
max_cache_len = 256  # Adjust as needed
wrapped_model.initialize_cache(batch_size, max_cache_len, device, model.dtype)

config = model.config
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
cache_shape = (config.num_hidden_layers, batch_size, config.num_key_value_heads, max_cache_len, head_dim)

# Dummy inputs: a 2-token prefill at positions 0-1 into an empty cache.
prefill_len = 2
dummy_input_ids = torch.ones(batch_size, prefill_len, dtype=torch.long, device=device)
dummy_attention_mask = torch.zeros(batch_size, max_cache_len, dtype=torch.long, device=device)
dummy_attention_mask[:, :prefill_len] = 1
dummy_position_ids = torch.arange(prefill_len, device=device).unsqueeze(0).repeat(batch_size, 1)
dummy_cache_position = torch.arange(prefill_len, dtype=torch.long, device=device)
dummy_past_keys = torch.zeros(cache_shape, dtype=model.dtype, device=device)
dummy_past_values = torch.zeros(cache_shape, dtype=model.dtype, device=device)

# Only the number of new tokens per call varies (prompt prefill vs. 1-token
# decode). Batch, the attention mask length and the cache buffers stay static.
dynamic_axes = {
    'input_ids': {1: 'sequence_length'},
    'position_ids': {1: 'sequence_length'},
    'cache_position': {0: 'sequence_length'},
    'logits': {1: 'sequence_length'},
}

# Export to ONNX
with torch.no_grad():
    torch.onnx.export(
        wrapped_model,
        (
            dummy_input_ids,
            dummy_attention_mask,
            dummy_position_ids,
            dummy_cache_position,
            dummy_past_keys,
            dummy_past_values,
        ),
        '/kaggle/working/tiny_llama.onnx',
        input_names=['input_ids', 'attention_mask', 'position_ids', 'cache_position', 'past_keys', 'past_values'],
        output_names=['logits', 'present_keys', 'present_values'],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=False,
        verbose=False,
    )
2. Inference Code with Static Cache
Now, let's modify the inference code to handle the static cache correctly.


import onnxruntime as ort
import torch

# One CPU-only ONNX Runtime session over the exported graph, shared by the helpers below.
ort_session = ort.InferenceSession(
    "/kaggle/working/tiny_llama.onnx",
    providers=["CPUExecutionProvider"],
)

def to_numpy(tensor):
    """Move *tensor* to host memory and return it as a NumPy array.

    Tensors that require grad are detached first, since ``.numpy()`` refuses them.
    """
    host_tensor = tensor.detach() if tensor.requires_grad else tensor
    return host_tensor.cpu().numpy()

def prepare_initial_inputs(text_input, tokenizer, device='cuda'):
    """Tokenize *text_input* without padding and build the matching mask and positions.

    Returns:
        ``(input_ids, attention_mask, position_ids)``, each shaped ``(1, n)`` for
        an ``n``-token prompt and moved to *device*. The mask is all ones and
        the positions are ``0..n-1``.
    """
    token_row = tokenizer.encode(text_input, add_special_tokens=True, return_tensors="pt")[0]
    mask = torch.ones_like(token_row, dtype=torch.long)
    positions = torch.arange(len(token_row), dtype=torch.long)

    # Add the batch dimension and move everything to the target device.
    return tuple(t.unsqueeze(0).to(device) for t in (token_row, mask, positions))

def generate_tokens(prompt: str, tokenizer, num_tokens_to_generate: int = 40, device='cuda'):
    """Greedily generate text with the exported ONNX model, one token per step.

    ONNX Runtime keeps no state between ``run`` calls, so a cache can only be
    reused if the graph takes it as an input and returns it as an output:

    * If the session declares ``past_keys`` / ``past_values`` inputs
      (fixed-size static-cache buffers), the prompt is run once (prefill). Each
      later step feeds only the newest token at its ``cache_position`` and
      carries the returned ``present_keys`` / ``present_values`` forward.
    * Otherwise each step re-runs the whole unpadded sequence from position 0.
      This is correct, but quadratic in the output length.

    Only inputs the session actually declares are fed; ONNX Runtime rejects
    unknown input names and ``None`` values.

    Args:
        prompt: Text to continue.
        tokenizer: Hugging Face-style tokenizer providing ``encode``/``decode``.
        num_tokens_to_generate: Maximum number of new tokens. Generation stops
            early at ``tokenizer.eos_token_id`` when that attribute is set.
        device: Unused; kept for backward compatibility. ONNX Runtime runs on
            the providers ``ort_session`` was created with.

    Returns:
        The decoded prompt plus its continuation.

    Raises:
        ValueError: If the static cache is too short for the prompt plus the
            requested tokens.
    """
    import numpy as np

    declared = {node.name: node for node in ort_session.get_inputs()}
    token_ids = list(tokenizer.encode(prompt, add_special_tokens=True))
    eos_id = getattr(tokenizer, "eos_token_id", None)

    use_cache = "past_keys" in declared and "past_values" in declared
    if use_cache:
        # Assumes fully static buffer dims (num_layers, batch, kv_heads,
        # max_cache_len, head_dim), as produced by a fixed-size export.
        cache_shape = tuple(declared["past_keys"].shape)
        cache_dtype = {
            "tensor(float16)": np.float16,
            "tensor(double)": np.float64,
        }.get(declared["past_keys"].type, np.float32)
        max_cache_len = cache_shape[3]
        # The last run writes cache slots up to prompt_len + num_tokens_to_generate - 2.
        if num_tokens_to_generate > 0 and len(token_ids) + num_tokens_to_generate - 1 > max_cache_len:
            raise ValueError(
                f"Prompt ({len(token_ids)} tokens) plus {num_tokens_to_generate} new tokens "
                f"does not fit in the static cache of length {max_cache_len}."
            )
        past_keys = np.zeros(cache_shape, dtype=cache_dtype)
        past_values = np.zeros(cache_shape, dtype=cache_dtype)

    num_cached = 0  # tokens already written into the cache buffers
    for _ in range(num_tokens_to_generate):
        total_len = len(token_ids)
        start = num_cached if use_cache else 0
        positions = np.arange(start, total_len, dtype=np.int64)
        if use_cache:
            # The static mask spans the whole buffer: 1 for every slot written so far or now.
            attention_mask = np.zeros((1, max_cache_len), dtype=np.int64)
            attention_mask[0, :total_len] = 1
        else:
            attention_mask = np.ones((1, total_len), dtype=np.int64)

        feed = {
            "input_ids": np.asarray([token_ids[start:]], dtype=np.int64),
            "attention_mask": attention_mask,
            "position_ids": positions[None, :],
            "cache_position": positions,
        }
        if use_cache:
            feed["past_keys"] = past_keys
            feed["past_values"] = past_values
        feed = {name: value for name, value in feed.items() if name in declared}

        if use_cache:
            logits, past_keys, past_values = ort_session.run(
                ["logits", "present_keys", "present_values"], feed
            )
            num_cached = total_len
        else:
            logits = ort_session.run(None, feed)[0]  # first output is the logits

        next_id = int(np.argmax(logits[0, -1]))
        token_ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break

    return tokenizer.decode(token_ids, skip_special_tokens=True)

# Example usage
# prompt = "Once upon a time"
# generated_text = generate_tokens(prompt, tokenizer, num_tokens_to_generate=40, device='cuda')
# print(generated_text)
# Key Points

# Static Cache Initialization:
# In PyTorch, the StaticCache is allocated once and updated in place across steps.
# In ONNX Runtime it only persists if its buffers are graph inputs/outputs that the
# inference loop feeds back on every step.
# Dynamic Inputs:
# The input_ids, attention_mask, and position_ids are dynamically updated for each token generation step.
# The cache_position is incremented to reflect the current position in the sequence.
# ONNX Export:
# The ONNX model is exported with the necessary inputs (input_ids, attention_mask, position_ids, cache_position) and outputs (logits).
# The dynamic axes ensure that the model can handle variable sequence lengths.
# By following this approach, you should be able to achieve efficient token-by-token generation using a static cache in both PyTorch and ONNX