In [1]:
!pip install -q transformers onnx onnxruntime-gpu torch


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import onnx
import onnxruntime as ort
import numpy as np
import time
import os

In [3]:
print("\nLoading GPT-2 model from Hugging Face...")

model_name = "gpt2"  # Small model: 124M parameters
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set pad token

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

print(f"Model loaded: {model_name}")
print(f"Parameters: ~124M")

# Prepare input
PROMPT = "Explain quantum computing"
SEQ_LEN = 32
BATCH_SIZE = 1

inputs = tokenizer(
    PROMPT,
    return_tensors="pt",
    padding="max_length",
    max_length=SEQ_LEN,
    truncation=True
)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

print(f"Input: '{PROMPT}'")
print(f"Input shape: {input_ids.shape}")
print(" Complete\n")


Loading GPT-2 model from Hugging Face...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Model loaded: gpt2
Parameters: ~124M
Input: 'Explain quantum computing'
Input shape: torch.Size([1, 32])
 Complete



In [4]:
print(" Measuring PyTorch inference time...")

# Warmup
for _ in range(5):
    with torch.no_grad():
        _ = model(input_ids=input_ids, attention_mask=attention_mask)

# Benchmark
pytorch_times = []
for _ in range(20):
    start = time.time()
    with torch.no_grad():
        pytorch_output = model(input_ids=input_ids, attention_mask=attention_mask)
    pytorch_times.append(time.time() - start)

pytorch_logits = pytorch_output.logits
pytorch_mean_time = np.mean(pytorch_times) * 1000
pytorch_std_time = np.std(pytorch_times) * 1000

print(f"PyTorch inference time: {pytorch_mean_time:.2f} ms ± {pytorch_std_time:.2f} ms")
print(f"Output shape: {pytorch_logits.shape}")
print("Complete\n")


 Measuring PyTorch inference time...
PyTorch inference time: 354.79 ms ± 115.15 ms
Output shape: torch.Size([1, 32, 50257])
Complete



In [5]:
!pip install -q transformers onnx onnxruntime

In [6]:
import os
import time
import numpy as np
import torch
import onnx
import onnxruntime as ort

from transformers import GPT2Tokenizer, GPT2LMHeadModel


In [7]:
model_name = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

model.eval()
model = model.cpu()


In [8]:
class GPT2ONNXWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=True
        )
        return outputs.logits


In [9]:
onnx_model = GPT2ONNXWrapper(model)


In [10]:
BATCH_SIZE = 1
SEQ_LEN = 16

dummy_input_ids = torch.ones(
    (BATCH_SIZE, SEQ_LEN),
    dtype=torch.long
)

dummy_attention_mask = torch.ones(
    (BATCH_SIZE, SEQ_LEN),
    dtype=torch.long
)


In [11]:
!pip install -q onnxscript


In [13]:
onnx_path = "gpt2_simple.onnx"

with torch.no_grad():
    torch.onnx.export(
        onnx_model,
        (dummy_input_ids, dummy_attention_mask),
        onnx_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        opset_version=14,
        do_constant_folding=True
    )

print("ONNX export completed")


W0121 10:49:36.376000 1914 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 14 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[torch.onnx] Obtain model graph for `GPT2ONNXWrapper([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GPT2ONNXWrapper([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...




[torch.onnx] Translate the graph into ONNX... ✅


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/no_previous_version.h:26: adapt: Assertion `

Applied 14 of general pattern rewrite rules.
ONNX export completed


In [14]:
onnx_model_proto = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model_proto)

model_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
print(f" ONNX model verified")
print(f" Model size: {model_size_mb:.2f} MB")


 ONNX model verified
 Model size: 1.05 MB


In [16]:
text = "Hello, how are you today?"

tokenizer.pad_token = tokenizer.eos_token

encoded = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=SEQ_LEN
)

input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]


In [17]:
session = ort.InferenceSession(
    onnx_path,
    providers=["CPUExecutionProvider"]
)

onnx_inputs = {
    "input_ids": input_ids.numpy(),
    "attention_mask": attention_mask.numpy()
}

outputs = session.run(None, onnx_inputs)
logits = outputs[0]

print("ONNX inference successful")
print("Logits shape:", logits.shape)


ONNX inference successful
Logits shape: (1, 16, 50257)


In [18]:
# Warm-up
for _ in range(5):
    _ = session.run(None, onnx_inputs)

# Timing
times = []
for _ in range(20):
    start = time.time()
    _ = session.run(None, onnx_inputs)
    times.append(time.time() - start)

mean_time = np.mean(times) * 1000
std_time = np.std(times) * 1000

print(f"Mean inference time: {mean_time:.2f} ms ± {std_time:.2f} ms")


Mean inference time: 77.01 ms ± 2.67 ms


In [32]:
class GPT2StaticKVONNX(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        past_key_0, past_value_0,
        past_key_1, past_value_1,
        past_key_2, past_value_2,
        past_key_3, past_value_3,
        past_key_4, past_value_4,
        past_key_5, past_value_5,
        past_key_6, past_value_6,
        past_key_7, past_value_7,
        past_key_8, past_value_8,
        past_key_9, past_value_9,
        past_key_10, past_value_10,
        past_key_11, past_value_11,
    ):
        past = [
            (past_key_0, past_value_0),
            (past_key_1, past_value_1),
            (past_key_2, past_value_2),
            (past_key_3, past_value_3),
            (past_key_4, past_value_4),
            (past_key_5, past_value_5),
            (past_key_6, past_value_6),
            (past_key_7, past_value_7),
            (past_key_8, past_value_8),
            (past_key_9, past_value_9),
            (past_key_10, past_value_10),
            (past_key_11, past_value_11),
        ]
        past_key_values_tuple = tuple(past)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values_tuple,
            use_cache=True,
            return_dict=True
        )

        logits = outputs.logits
        new_past = outputs.past_key_values

        return (
            logits,
            new_past[0][0], new_past[0][1],
            new_past[1][0], new_past[1][1],
            new_past[2][0], new_past[2][1],
            new_past[3][0], new_past[3][1],
            new_past[4][0], new_past[4][1],
            new_past[5][0], new_past[5][1],
            new_past[6][0], new_past[6][1],
            new_past[7][0], new_past[7][1],
            new_past[8][0], new_past[8][1],
            new_past[9][0], new_past[9][1],
            new_past[10][0], new_past[10][1],
            new_past[11][0], new_past[11][1],
        )

In [29]:
NUM_LAYERS = model.config.num_hidden_layers
NUM_HEADS = model.config.num_attention_heads
HEAD_DIM = model.config.hidden_size // model.config.num_attention_heads
MAX_SEQ_LEN = SEQ_LEN # Using previously defined SEQ_LEN for the past KV length

def make_past():
    return torch.zeros(
        (BATCH_SIZE, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM),
        dtype=torch.float32
    )

past = [make_past() for _ in range(NUM_LAYERS * 2)]

In [30]:
input_ids = torch.ones((BATCH_SIZE, 1), dtype=torch.long)
attention_mask = torch.ones((BATCH_SIZE, MAX_SEQ_LEN), dtype=torch.long)


In [33]:
onnx_path = "gpt2_static_kv.onnx"

inputs = [input_ids, attention_mask] + past
input_names = ["input_ids", "attention_mask"] + [
    f"past_key_{i}" if i % 2 == 0 else f"past_value_{i//2}"
    for i in range(len(past))
]

output_names = ["logits"] + [
    f"present_key_{i}" if i % 2 == 0 else f"present_value_{i//2}"
    for i in range(len(past))
]

with torch.no_grad():
    torch.onnx.export(
        GPT2StaticKVONNX(model),
        tuple(inputs),
        onnx_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=18,
        do_constant_folding=True
    )

print("✅ Static KV cache ONNX exported")

[torch.onnx] Obtain model graph for `GPT2StaticKVONNX([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GPT2StaticKVONNX([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 14 of general pattern rewrite rules.
✅ Static KV cache ONNX exported


In [37]:
import numpy as np

class GPT2StaticKVONNXInference:
    def __init__(self, onnx_path, tokenizer, model_config, seq_len=SEQ_LEN):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.seq_len = seq_len
        self.session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

        self.num_layers = model_config.num_hidden_layers
        self.num_heads = model_config.num_attention_heads
        self.head_dim = model_config.hidden_size // model_config.num_attention_heads

    def generate(self, prompt, max_new_tokens=10):
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Tokenize the prompt without padding to get its actual length
        initial_encoded_prompt = self.tokenizer(prompt, return_tensors="np", truncation=True)
        prompt_ids = initial_encoded_prompt.input_ids # Shape (1, actual_prompt_len)
        actual_prompt_len = prompt_ids.shape[1]

        generated_ids = prompt_ids.tolist()[0]

        # Initialize `current_past_key_values` and `current_attention_mask` dynamically.
        # These represent the *actual* history accumulated so far.
        current_past_key_values = [
            np.zeros((BATCH_SIZE, self.num_heads, 0, self.head_dim), dtype=np.float32)
            for _ in range(self.num_layers * 2)
        ]
        current_attention_mask = np.array([], dtype=np.int64).reshape(BATCH_SIZE, 0) # Initially empty

        # Process prompt tokens one by one to build up the initial KV cache and attention mask
        for token_idx in range(actual_prompt_len):
            input_id_token = prompt_ids[:, token_idx:token_idx+1] # Shape (1, 1)

            # Update the dynamic attention mask by concatenating a '1' for the current token
            current_attention_mask = np.concatenate([current_attention_mask, np.array([[1]])], axis=-1)

            # --- Prepare ONNX inputs, ensuring fixed shapes for the ONNX model inputs ---
            onnx_input_ids = input_id_token # Already (1,1)

            # Create ONNX-compatible attention mask by padding/truncating `current_attention_mask` to `self.seq_len`
            onnx_attention_mask = np.zeros((BATCH_SIZE, self.seq_len), dtype=np.int64)
            current_mask_len = current_attention_mask.shape[1]
            if current_mask_len > self.seq_len:
                # If current mask is longer than SEQ_LEN, take the last SEQ_LEN elements (sliding window)
                onnx_attention_mask[:, :] = current_attention_mask[:, -self.seq_len:]
            else:
                # Pad with zeros at the beginning if shorter
                onnx_attention_mask[:, self.seq_len - current_mask_len:] = current_attention_mask

            # Create ONNX-compatible past_key_values by padding/truncating `current_past_key_values` to `self.seq_len`
            onnx_past_key_values = []
            for j in range(self.num_layers):
                current_k = current_past_key_values[2*j]
                current_v = current_past_key_values[2*j+1]

                onnx_k = np.zeros((BATCH_SIZE, self.num_heads, self.seq_len, self.head_dim), dtype=np.float32)
                onnx_v = np.zeros((BATCH_SIZE, self.num_heads, self.seq_len, self.head_dim), dtype=np.float32)

                current_kv_len = current_k.shape[2]
                if current_kv_len > self.seq_len:
                    onnx_k[:, :, :, :] = current_k[:, :, -self.seq_len:, :]
                    onnx_v[:, :, :, :] = current_v[:, :, -self.seq_len:, :]
                else:
                    onnx_k[:, :, self.seq_len - current_kv_len:, :] = current_k
                    onnx_v[:, :, self.seq_len - current_kv_len:, :] = current_v

                onnx_past_key_values.extend([onnx_k, onnx_v])

            onnx_inputs = {
                "input_ids": onnx_input_ids,
                "attention_mask": onnx_attention_mask
            }
            for j in range(self.num_layers):
                onnx_inputs[f"past_key_{2*j}"] = onnx_past_key_values[2*j]
                onnx_inputs[f"past_value_{j}"] = onnx_past_key_values[2*j + 1]

            outputs = self.session.run(None, onnx_inputs)
            new_past_key_values = outputs[1:]

            # Update `current_past_key_values` with the new output, these dynamically grow
            current_past_key_values = new_past_key_values

        # After processing the prompt, `current_past_key_values` holds the KV cache for the full prompt,
        # and `current_attention_mask` is the mask for the prompt (actual length).

        # Now, generate `max_new_tokens`.
        # `input_ids` for the first generation step is the *last token processed from prompt*.
        input_ids = onnx_input_ids # This is the last input_id_token from the loop

        for i in range(max_new_tokens):
            # --- Prepare ONNX inputs, adhering to fixed shapes ---
            # Update the dynamic attention mask (for the *actual* history, then pad for ONNX input)
            current_attention_mask = np.concatenate([current_attention_mask, np.array([[1]])], axis=-1)

            onnx_input_ids = input_ids # (1,1)

            onnx_attention_mask = np.zeros((BATCH_SIZE, self.seq_len), dtype=np.int64)
            current_mask_len = current_attention_mask.shape[1]
            if current_mask_len > self.seq_len:
                onnx_attention_mask[:, :] = current_attention_mask[:, -self.seq_len:]
            else:
                onnx_attention_mask[:, self.seq_len - current_mask_len:] = current_attention_mask

            onnx_past_key_values = []
            for j in range(self.num_layers):
                current_k = current_past_key_values[2*j]
                current_v = current_past_key_values[2*j+1]

                onnx_k = np.zeros((BATCH_SIZE, self.num_heads, self.seq_len, self.head_dim), dtype=np.float32)
                onnx_v = np.zeros((BATCH_SIZE, self.num_heads, self.seq_len, self.head_dim), dtype=np.float32)

                current_kv_len = current_k.shape[2]
                if current_kv_len > self.seq_len:
                    onnx_k[:, :, :, :] = current_k[:, :, -self.seq_len:, :]
                    onnx_v[:, :, :, :] = current_v[:, :, -self.seq_len:, :]
                else:
                    onnx_k[:, :, self.seq_len - current_kv_len:, :] = current_k
                    onnx_v[:, :, self.seq_len - current_kv_len:, :] = current_v

                onnx_past_key_values.extend([onnx_k, onnx_v])

            onnx_inputs = {
                "input_ids": onnx_input_ids,
                "attention_mask": onnx_attention_mask
            }
            for j in range(self.num_layers):
                onnx_inputs[f"past_key_{2*j}"] = onnx_past_key_values[2*j]
                onnx_inputs[f"past_value_{j}"] = onnx_past_key_values[2*j + 1]

            outputs = self.session.run(None, onnx_inputs)
            logits = outputs[0]
            new_past_key_values = outputs[1:]

            next_token_logits = logits[:, -1, :]
            next_token = np.argmax(next_token_logits, axis=-1)

            generated_ids.append(next_token[0].item())

            if next_token[0].item() == self.tokenizer.eos_token_id:
                break

            input_ids = next_token[:, np.newaxis] # New token for next iteration
            current_past_key_values = new_past_key_values # Accumulate KV cache

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

# Instantiate the inference class
print("\nInitializing ONNX Runtime inference...")
onnx_inference_model = GPT2StaticKVONNXInference(
    onnx_path="gpt2_static_kv.onnx",
    tokenizer=tokenizer,
    model_config=model.config,
    seq_len=SEQ_LEN
)

# Generate text
initial_prompt = "Explain quantum computing in a few sentences:"
print(f"Generating text with ONNX model from prompt: '{initial_prompt}'")
start_time_generate = time.time()
onnx_generated_text = onnx_inference_model.generate(
    prompt=initial_prompt,
    max_new_tokens=50
)
end_time_generate = time.time()
print(f"Generated text: {onnx_generated_text}")
print(f"ONNX text generation time: {(end_time_generate - start_time_generate) * 1000:.2f} ms")
print("Complete\n")



Initializing ONNX Runtime inference...
Generating text with ONNX model from prompt: 'Explain quantum computing in a few sentences:'
Generated text: Explain quantum computing in a few sentences: Quantum computing in a few sentences: Quantum computing in a few sentences: Quantum computing in a few sentences sentences: Quantum computing in a few sentences sentences sentences in a quantum computing quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum quantum
ONNX text generation time: 2914.28 ms
Complete



In [41]:
import numpy as np

# Helper function to prepare ONNX inputs with padding/truncation
def prepare_onnx_inputs_for_step(input_id_token, dynamic_attention_mask, dynamic_past_kv, session_seq_len, num_layers, num_heads, head_dim):
    # ONNX input_ids is always (BATCH_SIZE, 1)
    onnx_input_ids = input_id_token

    # ONNX attention_mask is always (BATCH_SIZE, session_seq_len)
    onnx_attention_mask = np.zeros((BATCH_SIZE, session_seq_len), dtype=np.int64)
    current_mask_len = dynamic_attention_mask.shape[1]
    if current_mask_len > session_seq_len:
        onnx_attention_mask[:, :] = dynamic_attention_mask[:, -session_seq_len:]
    else:
        # Pad with zeros at the beginning
        onnx_attention_mask[:, session_seq_len - current_mask_len:] = dynamic_attention_mask

    # ONNX past_key_values are always (BATCH_SIZE, NUM_HEADS, session_seq_len, HEAD_DIM)
    onnx_past_key_values = []
    for j in range(num_layers):
        current_k = dynamic_past_kv[2*j]
        current_v = dynamic_past_kv[2*j+1]

        onnx_k = np.zeros((BATCH_SIZE, num_heads, session_seq_len, head_dim), dtype=np.float32)
        onnx_v = np.zeros((BATCH_SIZE, num_heads, session_seq_len, head_dim), dtype=np.float32)

        current_kv_len = current_k.shape[2]
        if current_kv_len > session_seq_len:
            onnx_k[:, :, :, :] = current_k[:, :, -session_seq_len:, :]
            onnx_v[:, :, :, :] = current_v[:, :, -session_seq_len:, :]
        else:
            # Pad with zeros at the beginning
            onnx_k[:, :, session_seq_len - current_kv_len:, :] = current_k
            onnx_v[:, :, session_seq_len - current_kv_len:, :] = current_v

        onnx_past_key_values.extend([onnx_k, onnx_v])

    inputs_dict = {
        "input_ids": onnx_input_ids,
        "attention_mask": onnx_attention_mask
    }
    for j in range(num_layers):
        inputs_dict[f"past_key_{2*j}"] = onnx_past_key_values[2*j]
        inputs_dict[f"past_value_{j}"] = onnx_past_key_values[2*j + 1]
    return inputs_dict


print("\nBenchmarking ONNX Runtime static KV cache inference...")

initial_prompt_benchmark = "The quick brown fox"
encoded_input_benchmark = tokenizer(
    initial_prompt_benchmark,
    return_tensors="np",
    padding="max_length", # We'll handle this dynamically for ONNX inputs
    truncation=True,
    max_length=SEQ_LEN # Use SEQ_LEN for max prompt length
)

# Initialize dynamic states for the benchmarking loop, similar to `generate` method
# These states will dynamically grow (current_attention_mask, current_past_kv)
# The `input_id_token` for ONNX will always be (1,1)
current_input_ids_dynamic = None # Will hold the single input token for each step
current_attention_mask_dynamic = np.array([], dtype=np.int64).reshape(BATCH_SIZE, 0)
current_past_kv_dynamic = [
    np.zeros((BATCH_SIZE, NUM_HEADS, 0, HEAD_DIM), dtype=np.float32)
    for _ in range(NUM_LAYERS * 2)
]


# --- Process prompt tokens first to build initial KV cache ---
prompt_ids = encoded_input_benchmark.input_ids
prompt_len = prompt_ids.shape[1]

for token_idx in range(prompt_len):
    input_id_token = prompt_ids[:, token_idx:token_idx+1]

    current_attention_mask_dynamic = np.concatenate([current_attention_mask_dynamic, np.array([[1]])], axis=-1)

    temp_onnx_inputs = prepare_onnx_inputs_for_step(
        input_id_token,
        current_attention_mask_dynamic,
        current_past_kv_dynamic,
        SEQ_LEN, NUM_LAYERS, NUM_HEADS, HEAD_DIM
    )

    outputs = onnx_inference_model.session.run(None, temp_onnx_inputs)
    new_past_kv = outputs[1:]

    current_input_ids_dynamic = input_id_token # Keep track of the last token processed
    current_past_kv_dynamic = new_past_kv # Update dynamic KV cache


# After processing prompt, `current_input_ids_dynamic`, `current_attention_mask_dynamic`, `current_past_kv_dynamic`
# hold the state *after* the prompt. `current_input_ids_dynamic` is the last token of the prompt.


# --- Warm-up for new token generation ---
for _ in range(5):
    temp_onnx_inputs = prepare_onnx_inputs_for_step(
        current_input_ids_dynamic,
        current_attention_mask_dynamic,
        current_past_kv_dynamic,
        SEQ_LEN, NUM_LAYERS, NUM_HEADS, HEAD_DIM
    )

    outputs_warmup = onnx_inference_model.session.run(None, temp_onnx_inputs)
    next_token_logits_warmup = outputs_warmup[0][:, -1, :]
    next_token_warmup = np.argmax(next_token_logits_warmup, axis=-1)

    current_input_ids_dynamic = next_token_warmup[:, np.newaxis]
    current_attention_mask_dynamic = np.concatenate([current_attention_mask_dynamic, np.array([[1]])], axis=-1)
    current_past_kv_dynamic = outputs_warmup[1:]


# --- Reset for actual timing (re-initialize dynamic states and re-process prompt) ---
current_input_ids_dynamic = None
current_attention_mask_dynamic = np.array([], dtype=np.int64).reshape(BATCH_SIZE, 0)
current_past_kv_dynamic = [
    np.zeros((BATCH_SIZE, NUM_HEADS, 0, HEAD_DIM), dtype=np.float32)
    for _ in range(NUM_LAYERS * 2)
]

for token_idx in range(prompt_len):
    input_id_token = prompt_ids[:, token_idx:token_idx+1]

    current_attention_mask_dynamic = np.concatenate([current_attention_mask_dynamic, np.array([[1]])], axis=-1)

    temp_onnx_inputs = prepare_onnx_inputs_for_step(
        input_id_token,
        current_attention_mask_dynamic,
        current_past_kv_dynamic,
        SEQ_LEN, NUM_LAYERS, NUM_HEADS, HEAD_DIM
    )

    outputs_reset = onnx_inference_model.session.run(None, temp_onnx_inputs)
    new_past_kv = outputs_reset[1:]

    current_input_ids_dynamic = input_id_token
    current_past_kv_dynamic = new_past_kv


# --- Actual timing loop for single token generation ---
onnx_kv_times = []
for _ in range(20):
    start = time.time()

    temp_onnx_inputs = prepare_onnx_inputs_for_step(
        current_input_ids_dynamic,
        current_attention_mask_dynamic,
        current_past_kv_dynamic,
        SEQ_LEN, NUM_LAYERS, NUM_HEADS, HEAD_DIM
    )

    outputs_timed = onnx_inference_model.session.run(None, temp_onnx_inputs)
    end = time.time()
    onnx_kv_times.append(end - start)

    next_token_logits_timed = outputs_timed[0][:, -1, :]
    next_token_timed = np.argmax(next_token_logits_timed, axis=-1)

    current_input_ids_dynamic = next_token_timed[:, np.newaxis]
    current_attention_mask_dynamic = np.concatenate([current_attention_mask_dynamic, np.array([[1]])], axis=-1)
    current_past_kv_dynamic = outputs_timed[1:]

onnx_kv_mean_time = np.mean(onnx_kv_times) * 1000
onnx_kv_std_time = np.std(onnx_kv_times) * 1000

print(f"ONNX static KV cache inference time (per token): {onnx_kv_mean_time:.2f} ms ± {onnx_kv_std_time:.2f} ms")
print("Complete\n")

print("\n--- Performance Comparison ---")
print(f"PyTorch (full sequence): {pytorch_mean_time:.2f} ms ± {pytorch_std_time:.2f} ms")
print(f"ONNX (simple inference): {mean_time:.2f} ms ± {std_time:.2f} ms")
print(f"ONNX (static KV cache, per token): {onnx_kv_mean_time:.2f} ms ± {onnx_kv_std_time:.2f} ms")



Benchmarking ONNX Runtime static KV cache inference...
ONNX static KV cache inference time (per token): 43.72 ms ± 1.36 ms
Complete


--- Performance Comparison ---
PyTorch (full sequence): 354.79 ms ± 115.15 ms
ONNX (simple inference): 77.01 ms ± 2.67 ms
ONNX (static KV cache, per token): 43.72 ms ± 1.36 ms
