# CONVERT TINYLLAMA MODEL TO ONNX WITH TOKEN-WISE DECODING
Kaggle Notebook [Link](https://www.kaggle.com/code/sahil112/tensorrt-v3/notebook?scriptVersionId=207733517). Please check Version 2.

In [54]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input/'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [55]:
# Installed to a local directory and now I keep adding the notebook to this work

!pip install tensorrt onnxruntime # --quiet
# !pip install --upgrade tensorrt onnxruntime --target=/kaggle/working/mysitepackages

  pid, fd = os.forkpty()




In [56]:
# imports
import torch
import transformers
# from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.onnx import OnnxConfig

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import onnx

In [57]:
import onnxruntime as ort
import tensorrt as trt

# Covert to onnx

In [58]:
# Important to do this step in Kaggle only while creating :/
! touch tiny_llama.onnx

In [59]:
# Load model and tokenizer
device='cpu' # 'cuda'

model_path = "/kaggle/input/m/mambagetout/tinyllama/pytorch/default/1/"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token # added because of generate thing

model = AutoModelForCausalLM.from_pretrained(model_path)
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): 

In [60]:
from transformers import StaticCache

In [61]:
# Create dummy inputs
batch_size = 1
sequence_length = 256  # Reduced for testing

# Get model configuration
num_layers = model.config.num_hidden_layers
num_attention_heads = model.config.num_attention_heads
hidden_size = model.config.hidden_size
head_dim = hidden_size // num_attention_heads

In [62]:
class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, input_ids, attention_mask, position_ids):
        # Initialize static cache
        past_key_values = StaticCache(
            config=self.model.config,
            batch_size=batch_size,
            max_cache_len= sequence_length, # 32,  # Match the sequence length
            device=input_ids.device,
            dtype=self.model.dtype
        )
        
        # Generate cache position
        cache_position = torch.arange(sequence_length, device=input_ids.device)
        
        # Forward pass with static cache
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position, # Can keep None also? TRYING FOR NOW
            use_cache=True,
            return_dict=False
        )
        
        return outputs[0]  # Return only logits for simplicity

# Wrap the model, keeping on gpu is essestial else 
# error two devices found cuda and cuda:0 error
# wrapped_model = ModelWrapper(model).cuda()
wrapped_model = ModelWrapper(model).to(device)

# Input tensors
dummy_input_ids = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long, device=device)
dummy_position_ids = torch.arange(sequence_length, device=device).unsqueeze(0).repeat(batch_size, 1)

# Define dynamic axes for variable sequence lengths
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'position_ids': {0: 'batch_size', 1: 'sequence_length'},
}

In [63]:
# # The previous model can be exported with dynamic shapes
# export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
# onnx_program = torch.onnx.dynamo_export(
#     model, 
#     *args, 
#     **kwargs, 
#     export_options=export_options
# )
# onnx_program.save("/kaggle/working/tiny_llama.onnx")

In [64]:
# Export the model

# if using cpu
torch.onnx.export(
    wrapped_model,
    (
        dummy_input_ids,
        dummy_attention_mask,
        dummy_position_ids,
    ),
    '/kaggle/working/tiny_llama.onnx',
    input_names=['input_ids', 'attention_mask', 'position_ids'],
    output_names=['logits'],
    dynamic_axes=dynamic_axes,
    opset_version=18, # DO NOT, I REPEAT DO NOT USE OPSET_VERSION 15. FULL PAIN IT ISSSS.
    do_constant_folding=False, # True, # CHECKING ANOTHER OPTION because later error
    verbose=False 
)

# if using gpu
# with torch.no_grad():
#     torch.onnx.export(
#         wrapped_model,
#         (
#             dummy_input_ids,
#             dummy_attention_mask,
#             dummy_position_ids,
#         ),
#         '/kaggle/working/tiny_llama.onnx',
#         input_names=['input_ids', 'attention_mask', 'position_ids'],
#         output_names=['logits'],
#         dynamic_axes=dynamic_axes,
#         opset_version=15,
#         do_constant_folding=True,
#         verbose=False
#     )

  if sequence_length != 1:


## **load the model**

In [65]:
for dirname, _, filenames in os.walk('/kaggle/working'):
    for filename in filenames:
        if '.onnx' in filename:
            print(os.path.join(dirname, filename))

/kaggle/working/tiny_llama.onnx


In [66]:
import onnx

onnx_model = onnx.load("/kaggle/working/tiny_llama.onnx")
# onnx_model = onnx.load("/kaggle/working/tiny_llama_basic.onnx")
# onnx_model = onnx.load("/kaggle/input/tensorrt-v3/tiny_llama.onnx")

In [67]:
import sys
sys.getsizeof(onnx_model)

96

In [68]:
# onnx.checker.check_model(onnx_model)
# Gives ValidationError: The model does not have an ir_version set properly.
# Fine only, mostly a bug

# Also might give
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 4402590179

## Run Inference using the ort InferenceSession

In [69]:
ort_session = ort.InferenceSession("/kaggle/working/tiny_llama.onnx", providers=["CPUExecutionProvider"])
# ort_session = ort.InferenceSession("/kaggle/input/tensorrt-v3/tiny_llama.onnx", providers=["CPUExecutionProvider"])

In [70]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(dummy_input_ids),
    ort_session.get_inputs()[1].name: to_numpy(dummy_attention_mask),
    ort_session.get_inputs()[2].name: to_numpy(dummy_position_ids)
}
ort_outs = ort_session.run(None, ort_inputs)

In [71]:
type(ort_outs), len(ort_outs), ort_outs[0].shape

(list, 1, (1, 256, 32000))

In [72]:
# check if outputs are close
np.testing.assert_allclose(
    to_numpy(wrapped_model(dummy_input_ids, dummy_attention_mask, dummy_position_ids)), 
    ort_outs[0], 
    rtol=1e-03, 
    atol=1e-04
)
# throws assertion error when atol is 1e-05
# So I think this is fair enough

## this works, so now work with text inputs

### Check the inputs

In [73]:
print(f'dummy_input_ids = {dummy_input_ids} \n dummy_attention_mask = {dummy_attention_mask} \n dummy_position_ids = {dummy_position_ids}')

dummy_input_ids = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) 
 dummy_attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       

In [74]:
def prepare_inputs(text_input, tokenizer, target_length=256, device='cuda'):
    """
    Prepare inputs by padding/truncating to target length
    """
    # Tokenize the input
    encoded = tokenizer.encode(text_input, add_special_tokens=True, return_tensors="pt")
    input_ids = encoded[0]  # Remove batch dimension
    
    # Get the actual sequence length
    seq_len = len(input_ids)
    
    if seq_len > target_length:
        # Truncate if longer than target_length
        input_ids = input_ids[:target_length]
        seq_len = target_length
    
    # Pad input_ids
    padding_length = target_length - seq_len
    padding_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = torch.cat([
        input_ids,
        torch.full((padding_length,), padding_id, dtype=torch.long)
    ])
    
    # Create attention mask (1 for real tokens, 0 for padding)
    attention_mask = torch.cat([
        torch.ones(seq_len, dtype=torch.long),  # Explicitly set dtype to long
        torch.zeros(padding_length, dtype=torch.long)  # Explicitly set dtype to long
    ])
    
    # Create position ids
    # position_ids = torch.arange(target_length)
    position_ids = torch.arange(target_length, dtype=torch.long)  # Explicitly set dtype to long

    
    # Add batch dimension and move to device
    input_ids = input_ids.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    position_ids = position_ids.unsqueeze(0).to(device)
    
    return input_ids, attention_mask, position_ids

In [75]:
def make_both_inferences(text_input: str):

    input_ids, attention_mask, position_ids = prepare_inputs(
        text_input, 
        tokenizer, 
        target_length=256,
        device=device
    )

    wrapped_model_outputs = to_numpy(wrapped_model(input_ids, attention_mask, position_ids))

    ort_inputs = { 
        ort_session.get_inputs()[0].name: to_numpy(input_ids),
        ort_session.get_inputs()[1].name: to_numpy(attention_mask),
        ort_session.get_inputs()[2].name: to_numpy(position_ids)
    }
    ort_outs = ort_session.run(None, ort_inputs)

    return wrapped_model_outputs, ort_outs

In [76]:
wrapped_model_outputs, ort_outs = make_both_inferences(text_input ='Describe Albert Einstein')

In [77]:
wrapped_model_outputs, ort_outs

(array([[[-7.396678  , -7.3966794 ,  8.475743  , ..., -4.358731  ,
          -7.0164227 , -0.75330424],
         [-6.6988673 , -6.6988735 ,  8.643362  , ..., -4.265688  ,
          -5.4994435 , -0.25938272],
         [-7.3818617 , -7.3818617 , 10.710121  , ..., -4.5778985 ,
          -5.3154926 , -1.0604138 ],
         ...,
         [-4.5339212 , -4.5339227 , 12.72688   , ..., -2.754523  ,
          -3.4932256 ,  0.8564618 ],
         [-4.582982  , -4.582982  , 12.599232  , ..., -2.5835834 ,
          -3.1678545 ,  0.604913  ],
         [-3.1293983 , -3.1294005 , 12.985866  , ..., -1.3228996 ,
          -1.9311588 ,  1.525508  ]]], dtype=float32),
 [array([[[-7.3966737 , -7.3966737 ,  8.475735  , ..., -4.3587313 ,
           -7.01641   , -0.7532996 ],
          [-6.698857  , -6.698863  ,  8.643356  , ..., -4.2656827 ,
           -5.49943   , -0.2593714 ],
          [-7.381863  , -7.3818655 , 10.7101145 , ..., -4.5778995 ,
           -5.3154883 , -1.060413  ],
          ...,
          [

In [78]:
# NOTE: ALL ARE STILL SINGLE BATCH SIZE, IT DOESNT SUPPORT LARGER BATCH SIZE YET
# model_op_new = Out[80][0]
# model_op_new.shape

## The token ids converted back to text aren't beautiful though. 

In [79]:
 # Convert logits to token IDs (if necessary)
wrapped_token_ids = torch.argmax(torch.from_numpy(wrapped_model_outputs), dim=-1)
ort_token_ids = torch.argmax(torch.from_numpy(ort_outs[0]), dim=-1)  # Assuming first output is relevant

# Decode token IDs back to text
wrapped_text_output = tokenizer.decode(wrapped_token_ids[0], skip_special_tokens=True)
ort_text_output = tokenizer.decode(ort_token_ids[0], skip_special_tokens=True)
wrapped_text_output, ort_text_output

('amedamed The And Be.ededed And And Andeded\n A A A A\n\n\n\n And And\n\n\n M And And And And And\n\n And And And And And And And\n\n\n And And A A\n\n\n\n And Andamed\n\n\n\n\n And And And And And And And And And And And And And And\n\n The A A\n\n\n The M And And And\n\n\n\n\n\n\n\n And And And\n\n The The M A A A\n\n And And And And And And And And And And And And And And And And\n\n\n And And And And\n\n The The A A A\n\n\n\n\n\n And M M A Aed\n And A And Andamed\n The M M A A A\n And And And And And And And And And And And\n\n The A A Aed  M M And And\n\n\n\n\n\n\n\n And And And And And And And And And And And And And And And And And And And And And And And And And And And And And\n\n\n\n And And And And\n\n The   A A A M ededed And And And And And\n\n And And',
 'amedamed The And Be.ededed And And Andeded\n A A A A\n\n\n\n And And\n\n\n M And And And And And\n\n And And And And And And And\n\n\n And And A A\n\n\n\n And Andamed\n\n\n\n\n And And And And And And And And And And An

---

---

# PART 2: Decode token by token
ref https://github.com/huggingface/transformers/issues/30670

In [80]:
from transformers import StaticCache

In [81]:
torch_device = 'cpu'

In [82]:
# use staticCache from huggingface
# ref https://github.com/huggingface/transformers/issues/30670

def decode_one_tokens(model, cur_token, input_pos, attention_mask, cache_position, past_key_values):
    logits = model(
        cur_token,
        attention_mask=attention_mask,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token


def generate(
    prompts: List[str], 
    model, 
    tokenizer, 
    num_tokens_to_generate: int = 20 # changed from 40
) ->  List[str]:
    
    global decode_one_tokens
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    batch_size, seq_length = inputs["input_ids"].shape

    attention_mask = inputs["attention_mask"]
    
    with torch.no_grad():
        past_key_values = StaticCache(
            config=model.config, 
            max_batch_size=batch_size, 
            max_cache_len=256,  # changed from 4096
            device=torch_device, 
            dtype=model.dtype
        )
        
        cache_position = torch.arange(seq_length, device=torch_device)
        generated_ids = torch.zeros(
            batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=torch_device
        )

        print(f'generated pask key values ...')
        
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
        print(f'generated generated_ids till cache position ...')
        
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # apparently this is outdated now
        # with torch.nn.attention.sdpa_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            logits = model(
                **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
            )[0]

        print(f'got logits from model ...')
        
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        generated_ids[:, seq_length] = next_token[:, 0]

        # Not using torch.compile to simplify debugging
        # decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
        cache_position = torch.tensor([seq_length + 1], device=torch_device)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) # update and pass every step

        print(f'got cache position and attention mask ...')

        for _ in range(1, num_tokens_to_generate):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                next_token = decode_one_tokens(
                    model = model, 
                    cur_token = next_token.clone(), 
                    input_pos = None, 
                    attention_mask = attention_mask, 
                    cache_position = cache_position, 
                    past_key_values = past_key_values
                )
                
                print(f'generated {_} next token in loop')
                generated_ids[:, cache_position] = next_token.int()
                
            cache_position += 1
            # position_ids = position_ids[:, -1:] + 1 # variable referenced before assignment
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], 
                dim=-1
            )
            print(f'\t cache posn, posn id, attn mask etc has been updated in loop ')

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [83]:
generate(
    prompts = ["Describe Albert Einstein ", "How does the sun burn? "], 
    model = model , # wrapped_model,
    tokenizer = tokenizer
)

generated pask key values ...
generated generated_ids till cache position ...
got logits from model ...
got cache position and attention mask ...
generated 1 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 2 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 3 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 4 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 5 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 6 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 7 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 8 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generated 9 next token in loop
	 cache posn, posn id, attn mask etc has been updated in loop 
generate

['Describe Albert Einstein  2019.\nThe 2019 edition of the annual event will be',
 "How does the sun burn? 2019\nThe 2019 edition of the World Economic Forum's"]

---

# PART 3
Try with the generate method directly

In [84]:
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

In [85]:
# Load model and tokenizer
device='cpu' # 'cuda'

model_path = "/kaggle/input/m/mambagetout/tinyllama/pytorch/default/1/"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token # added because of generate thing

# tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
# model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): 

In [86]:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [87]:
# model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# If not done torch.compile, then we don't get the error

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to(device)

prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
    device=model.device,
    dtype=model.dtype
)

print(f'created the past key values ... ')
outputs = model.generate(
    **input_ids, 
    # past_key_values=past_key_values
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']

created the past key values ... 
['The theory of special relativity states 2.0.0.0.0.0.0.0.']


**Output with and without past key values is not the same!**