# Generate ONNX

Download the `gen_qnn_ctx_onnx_model.py` script from the onnxruntime repository into the working directory.

## Extract model metadata from the serialized context binaries (`*serialized.bin`) using the QNN SDK's `qnn-context-binary-utility`

### Set environment variables

In [1]:
import os

# Root of the locally installed QAIRT/QNN SDK; adjust to your installation.
QNN_SDK_DIR = r"C:\Qualcomm\AIStack\QAIRT\2.40.0.251030"
os.environ['QNN_ROOT'] = f"{QNN_SDK_DIR}"
os.environ['PYTHONPATH'] = f"{QNN_SDK_DIR}\\lib\\python"
# Append the SDK binaries as a *separate* PATH entry. Plain string concatenation
# would glue the directory onto the last existing entry, so it would never be searched.
# NOTE(review): this adds the aarch64 binaries, while the JSON-extraction step
# runs the x86_64 utility. Confirm which host architecture is intended.
_qnn_bin_dir = f"{QNN_SDK_DIR}\\bin\\aarch64-windows-msvc"
os.environ['PATH'] = os.pathsep.join(p for p in (os.environ.get('PATH'), _qnn_bin_dir) if p)
os.environ['TMPDIR'] = r"C:\Users\HCKTest\AppData\Local\Temp"

### Run `qnn-context-binary-utility.exe` to dump each context binary's metadata to JSON

In [2]:
import glob

# Serialized QNN context binaries in the current working directory.
bin_files = glob.glob(os.path.join(os.getcwd(), '*serialized.bin'))
# Swap only the file extension. str.replace("bin", "json") would also rewrite any
# "bin" elsewhere in the full path (e.g. a "...\bin\..." or "cabinet" directory).
json_files = [os.path.splitext(bin_file)[0] + ".json" for bin_file in bin_files]

In [3]:
bin_files

['C:\\chilukam\\Gemma\\IOT\\artifacts_from_code\\artifacts_testing\\veg.serialized.bin',
 'C:\\chilukam\\Gemma\\IOT\\artifacts_from_code\\artifacts_testing\\weight_sharing_model_1_of_4.serialized.bin',
 'C:\\chilukam\\Gemma\\IOT\\artifacts_from_code\\artifacts_testing\\weight_sharing_model_2_of_4.serialized.bin',
 'C:\\chilukam\\Gemma\\IOT\\artifacts_from_code\\artifacts_testing\\weight_sharing_model_3_of_4.serialized.bin',
 'C:\\chilukam\\Gemma\\IOT\\artifacts_from_code\\artifacts_testing\\weight_sharing_model_4_of_4.serialized.bin']

In [4]:
# One PowerShell line per context binary. Each runs qnn-context-binary-utility to
# dump that binary's metadata into the matching JSON file.
# The executable path is built here from QNN_SDK_DIR. In PowerShell, `${QNN_SDK_ROOT}`
# is a PowerShell variable, not the environment variable ($env:QNN_SDK_ROOT).
# It would only resolve if envsetup.ps1 happened to define such a variable in the
# caller's scope.
# Paths are single-quoted so directories containing spaces still work.
context_util_exe = os.path.join(QNN_SDK_DIR, "bin", "x86_64-windows-msvc", "qnn-context-binary-utility.exe")
generator_cmd = ''.join(
    f"& '{context_util_exe}' --context_binary '{bin_file}' --json_file '{json_file}'\n"
    for bin_file, json_file in zip(bin_files, json_files)
)

In [5]:
import subprocess

# Run envsetup.ps1 and all utility invocations in one PowerShell session, so the
# environment prepared by envsetup.ps1 applies to the utility calls.
envsetup = os.path.join(QNN_SDK_DIR, "bin", "envsetup.ps1")
# Quote the script path so an SDK location containing spaces still works.
command = [f"& '{envsetup}'", generator_cmd]
powershell_script = '\n'.join(command)

try:
    subprocess.run(["powershell.exe", "-Command", powershell_script], check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
    # Best effort: report the failure but keep the notebook going.
    # PowerShell often writes error records to stdout, so show both streams.
    print("Error Output:\n", e.stderr)
    print("Standard Output:\n", e.stdout)

# A failing native command inside `-Command` does not necessarily make
# powershell.exe exit non-zero, so check that the expected outputs exist.
missing_json_files = [path for path in json_files if not os.path.isfile(path)]
if missing_json_files:
    print("Missing JSON outputs:\n" + "\n".join(missing_json_files))


## Generate ONNX wrapper model

In [6]:
# Wrap each weight-sharing context binary in an ONNX model via
# gen_qnn_ctx_onnx_model.py. These parts are generated with --quantized_IO.
python_env_path = os.path.join(os.getcwd(),"Qairt_Env\\Scripts\\Activate.ps1")
gen_qnn_ctx_onnx_cmd = (
    'Get-ChildItem -Filter "weight_sharing_model*.bin" | ForEach-Object { '
    '$binFile = $_.Name; '
    # '\.bin$' escapes the dot, so only a literal ".bin" suffix is replaced
    # (an unescaped '.' matches any character).
    "$jsonFile = $binFile -replace '\\.bin$', '.json'; "
    'python gen_qnn_ctx_onnx_model.py -b $binFile -q $jsonFile --quantized_IO --disable_embed_mode '
    '}'
)
# Activate the virtual environment first, so `python` presumably resolves to its interpreter.
command = f"& '{python_env_path}'; {gen_qnn_ctx_onnx_cmd}"

result = subprocess.run(
    ["powershell.exe", "-Command", command],
    capture_output=True,
    text=True
)
# Output is captured rather than shown, so surface failures explicitly.
if result.returncode != 0:
    print("gen_qnn_ctx_onnx_model.py failed:\n", result.stderr or result.stdout)


In [7]:
# Wrap the "veg" context binary in an ONNX model. Unlike the weight-sharing
# parts, this one is generated without --quantized_IO.
python_env_path = os.path.join(os.getcwd(),"Qairt_Env\\Scripts\\Activate.ps1")
gen_qnn_ctx_onnx_cmd = (
    'Get-ChildItem -Filter "veg*.bin" | ForEach-Object { '
    '$binFile = $_.Name; '
    # '\.bin$' escapes the dot, so only a literal ".bin" suffix is replaced
    # (an unescaped '.' matches any character).
    "$jsonFile = $binFile -replace '\\.bin$', '.json'; "
    'python gen_qnn_ctx_onnx_model.py -b $binFile -q $jsonFile --disable_embed_mode '
    '}'
)
# Activate the virtual environment first, so `python` presumably resolves to its interpreter.
command = f"& '{python_env_path}'; {gen_qnn_ctx_onnx_cmd}"

result = subprocess.run(
    ["powershell.exe", "-Command", command],
    capture_output=True,
    text=True
)
# Output is captured rather than shown, so surface failures explicitly.
if result.returncode != 0:
    print("gen_qnn_ctx_onnx_model.py failed:\n", result.stderr or result.stdout)


In [8]:
import onnx

In [9]:
all_common = ['swa_position_ids_cos', 'swa_position_ids_sin', 'swa_attention_mask', 'position_ids_cos', 'position_ids_sin', 'attention_mask']

In [10]:
def combine_mods(m1, m2, common_inputs=None):
    """Append the graph of ``m2`` onto ``m1`` in place and return ``m1``.

    The last graph output of ``m1`` is treated as the tensor linking the two parts.
    Presumably this is the hidden state passed from one model part to the next;
    verify against the generated models. That tensor, and the inputs shared by
    every part, are not re-added as graph inputs/outputs. Every other input and
    output of ``m2`` is appended, together with its nodes.

    Args:
        m1: Model that receives ``m2``'s nodes, inputs and outputs (mutated).
        m2: Model to append.
        common_inputs: Names of inputs shared between parts. Defaults to the
            module-level ``all_common`` list.

    Returns:
        ``m1`` (the same object, mutated).
    """
    if common_inputs is None:
        common_inputs = all_common
    link_name = [x.name for x in m1.graph.output][-1]
    skip = {link_name, *common_inputs}
    # NOTE(review): nodes are filtered by *node* name against tensor names. This is
    # kept for compatibility, but it is unclear that it ever excludes anything.
    m1.graph.node.extend([x for x in m2.graph.node if x.name not in skip])
    m1.graph.input.extend([x for x in m2.graph.input if x.name not in skip])
    m1.graph.output.extend([x for x in m2.graph.output if x.name not in skip])
    return m1

In [11]:
# Merge the four EPContext parts of each (AR, CL) variant into a single model file.
for ar, cl in zip([1, 128], [8192, 8192]):
    stem = f"ar{ar}_cl{cl}"
    m1, m2, m3, m4 = (onnx.load(f"{stem}_{part}_of_4_qnn_ctx.onnx") for part in range(1, 5))
    for m in (m2, m3, m4):
        m1 = combine_mods(m1, m)
    onnx.save(m1, f"{stem}_all_of_4_qnn_ctx.onnx")

# Generate Position Processor

In [12]:
import os
import torch
import torch.nn as nn
import copy

In [13]:
from transformers.models.gemma3.modeling_gemma3 import (Gemma3TextModel as Model, Gemma3TextConfig as Config, Gemma3RotaryEmbedding as RotaryEmbedding)

In [14]:
batch_size = 1
seq_len = 128
context_length = 8192
config_file_path = os.path.join(os.getcwd(), 'config.json')
config = Config.from_pretrained(config_file_path)

In [15]:
def Qualcomm_prepare_4d_causal_sliding_window_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        sliding_window: int,
        min_value: float = -50,
    ) -> tuple[torch.Tensor, torch.Tensor]:
    """Build additive global and sliding-window causal masks.

    Masked positions get ``min_value`` instead of the dtype minimum. Presumably this
    keeps the values inside the quantized attention-mask input range; verify
    against the model's quantization parameters.

    Args:
        attention_mask: ``(batch, mask_length)`` mask or None. Positions where it is 0
            (and that the causal mask would otherwise allow) are masked out.
        sequence_length: Number of query rows.
        target_length: Number of key columns (full context length).
        dtype: Dtype of the returned masks.
        cache_position: Absolute position of each query row; the mask's device is taken from it.
        batch_size: Leading dimension of the returned masks.
        sliding_window: Number of trailing columns kept in the sliding-window mask.
        min_value: Value written at masked positions (default -50).

    Returns:
        ``(causal_mask, swa_causal_mask)``. ``causal_mask`` is
        ``(batch_size, 1, sequence_length, target_length)``; ``swa_causal_mask``
        holds its last ``sliding_window`` columns.
    """
    min_dtype = min_value
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=target_length-sequence_length)
    # Keep min_value only for keys strictly after each query's cache position.
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        # A sum of 0 means "causally visible" (0) but "padding" (0) in attention_mask.
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
    swa_causal_mask = causal_mask[:, :, :, -sliding_window:]
    return causal_mask, swa_causal_mask

In [16]:
# Make Gemma3TextModel's 4D-mask helper resolve to the Qualcomm variant defined above.
setattr(Model, "_prepare_4d_causal_sliding_window_attention_mask_with_cache_position", Qualcomm_prepare_4d_causal_sliding_window_attention_mask_with_cache_position)

class PositionProcessor(nn.Module):
    """Compute the attention masks and RoPE cos/sin tables fed to the QNN model parts.

    ``forward`` returns six tensors, in order:
    1. the global causal mask
    2. the sliding-window causal mask
    3. the global cos table
    4. the global sin table
    5. the local (sliding-window) cos table
    6. the local (sliding-window) sin table
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rotary_emb = RotaryEmbedding(config)
        # The local rotary embedding uses rope_local_base_freq and default (unscaled)
        # RoPE. It works on a deep copy so self.config is left unchanged.
        local_config = copy.deepcopy(config)
        self.sliding_window = local_config.sliding_window
        local_config.rope_theta = local_config.rope_local_base_freq
        local_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = RotaryEmbedding(local_config)

    def forward(self, attention_mask: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Build masks and RoPE tables for one step.

        Args:
            attention_mask: 2D mask whose second dimension is the full context length.
                0 marks positions to mask out.
            position_ids: ``(batch_size, seq_len)`` positions of the new tokens.

        Returns:
            ``(causal_mask_global, causal_mask_local, cos, sin, cos_local, sin_local)``.

            - The global mask is ``(batch_size, 1, seq_len, context_length)``.
            - The local mask holds its last ``sliding_window`` columns.
            - Each cos/sin table keeps the first ``head_dim // 2`` features, with a
              leading singleton axis added. This assumes the rotary modules return
              ``(batch, seq, dim)`` tensors; verify against the installed transformers version.
        """
        head_dim = self.config.head_dim
        batch_size, seq_len = position_ids.shape
        context_length = attention_mask.shape[1]

        # The new tokens occupy the last seq_len slots of the context window.
        # Create the positions on the mask's device rather than hard-coding CPU.
        cache_position = torch.arange(context_length - seq_len, context_length, device=attention_mask.device)

        causal_mask_global, causal_mask_local = Model._prepare_4d_causal_sliding_window_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=seq_len,
            target_length=context_length,
            dtype=torch.float32,
            cache_position=cache_position,
            batch_size=batch_size,
            sliding_window=self.sliding_window,
        )

        # Rotary embeddings. The placeholder tensor is only passed as the rotary
        # modules' first argument (presumably for dtype/device; verify).
        dummy_tensor = torch.zeros((batch_size, seq_len, self.config.hidden_size), device=position_ids.device)

        # Global
        cos, sin = self.rotary_emb(dummy_tensor, position_ids)
        cos = cos[:1, :seq_len, :head_dim // 2].unsqueeze(0)
        sin = sin[:1, :seq_len, :head_dim // 2].unsqueeze(0)

        # Local (sliding window)
        cos_local, sin_local = self.rotary_emb_local(dummy_tensor, position_ids)
        cos_local = cos_local[:1, :seq_len, :head_dim // 2].unsqueeze(0)
        sin_local = sin_local[:1, :seq_len, :head_dim // 2].unsqueeze(0)

        return causal_mask_global, causal_mask_local, cos, sin, cos_local, sin_local

In [17]:
config = Config.from_pretrained(config_file_path)

In [19]:
# Load config and model

model = PositionProcessor(config)
model.eval()

# Dummy inputs
attention_mask = torch.ones((batch_size, context_length), dtype=torch.int32)
position_ids = torch.arange(seq_len, dtype=torch.int32).unsqueeze(0)

outputs = model(attention_mask, position_ids)

In [20]:
# Export to ONNX.
# Output shapes:
# - both masks: (batch, 1, seq_len, ...)
# - cos/sin tables: (1, 1, seq_len, head_dim // 2)
# So seq_len is axis 2 for every output. Axis 1 of the masks is the constant 1.
position_processor_path = "position-processor_swa_same_attn.onnx"
torch.onnx.export(
    model,
    (attention_mask, position_ids),
    position_processor_path,
    input_names=["attention_mask_before_processor", "position_ids"],
    output_names=["attention_mask_before_quantizer", "swa_attention_mask_before_quantizer",
                  "position_ids_cos_before_quantizer", "position_ids_sin_before_quantizer",
                  "swa_position_ids_cos_before_quantizer", "swa_position_ids_sin_before_quantizer"],
    dynamic_axes={
        "attention_mask_before_processor": {1: "ctx_len"},
        "position_ids": {1: "seq_len"},
        "attention_mask_before_quantizer": {2: "seq_len", 3: "ctx_len"},
        "swa_attention_mask_before_quantizer": {2: "seq_len"},
        "position_ids_cos_before_quantizer": {2: "seq_len"},
        "position_ids_sin_before_quantizer": {2: "seq_len"},
        "swa_position_ids_cos_before_quantizer": {2: "seq_len"},
        "swa_position_ids_sin_before_quantizer": {2: "seq_len"},
    },
    opset_version=21,
    dynamo = False
)
print(f"ONNX model '{position_processor_path}' created successfully.")

ONNX model 'position-processor_swa.onnx' created successfully.


  torch.onnx.export(
  _export(
  if sequence_length != 1:


# Add Quantize Layers at the outputs of the position processor

In [21]:
import json

# Graph-input quantization parameters are read from part 1's JSON dump.
# Use a context manager so the file handle is closed.
with open('weight_sharing_model_1_of_4.serialized.json', 'r') as json_fp:
    model_json_file = json.load(json_fp)

In [22]:
attn_scale = None
attn_offset = None
pos_scale = None
pos_offset = None
graph_inputs = model_json_file['info']['graphs'][0]['info']['graphInputs']

# Collect the quantization parameters of the attention-mask and RoPE inputs.
# NOTE(review): abs() converts QNN's offset into a zero point. This assumes the
# stored offset is <= 0; confirm for new models.
for graph_input in graph_inputs:
    info = graph_input['info']
    if info['name'] == 'attention_mask':
        attn_scale = info['quantizeParams']['scaleOffset']['scale']
        attn_offset = abs(info['quantizeParams']['scaleOffset']['offset'])
    elif info['name'] == 'position_ids_cos':
        pos_scale = info['quantizeParams']['scaleOffset']['scale']
        pos_offset = abs(info['quantizeParams']['scaleOffset']['offset'])
    # Stop once both are found. Compare against None so a legitimate zero
    # offset cannot prevent the early exit.
    if attn_scale is not None and pos_scale is not None:
        break

if attn_scale is None or pos_scale is None:
    raise KeyError("'attention_mask' and/or 'position_ids_cos' not found in graphInputs")
attn_scale, attn_offset, pos_scale, pos_offset

(0.0015259021893143654, 65535, 3.051804378628731e-05, 32768)

In [23]:
import argparse
import onnx
import numpy as np
from onnx import helper, TensorProto, checker

def _tensor_shape_from_value_info(vi):
    """Extract shape as a list of ints/strings/None from a ValueInfoProto."""
    t = vi.type.tensor_type
    if not t.HasField("shape"):
        return None
    dims = []
    for d in t.shape.dim:
        if d.HasField("dim_param") and d.dim_param:
            dims.append(d.dim_param)  # symbolic dim
        elif d.HasField("dim_value"):
            dims.append(d.dim_value)  # concrete dim
        else:
            dims.append(None)         # unknown
    return dims

def _get_opset(model, domain=""):
    for o in model.opset_import:
        if o.domain == domain:
            return o
    return None

In [24]:
input_path = "position-processor_swa_same_attn.onnx"
output_path = "position-processor_swa_same_attn_quant.onnx"

In [25]:
model = onnx.load(input_path)
graph = model.graph

In [26]:
graph_output_names = [vi.name for vi in graph.output]

In [27]:
old_output_names = graph_output_names

In [28]:
x_scale_values = [attn_scale, attn_scale, pos_scale, pos_scale, pos_scale, pos_scale] 
x_zero_point_values = [attn_offset, attn_offset, pos_offset, pos_offset, pos_offset, pos_offset]

In [29]:
new_output_names = [old_output_name.replace('_before_quantizer', '') for old_output_name in old_output_names]

In [30]:
q_node_names = [new_output_name+"_quant" for new_output_name in new_output_names]
q_input_names = [new_output_name+"_in" for new_output_name in new_output_names]
x_scale_names = [new_output_name+"_scale" for new_output_name in new_output_names]
x_zero_point_names = [new_output_name+"_zp" for new_output_name in new_output_names]

In [31]:
old_input_vis = [next((vi for vi in graph.output if vi.name == old_output_name), None) for old_output_name in old_output_names]

In [32]:
old_shapes = [_tensor_shape_from_value_info(old_input_vi) for old_input_vi in old_input_vis]

In [33]:
q_elem = TensorProto.UINT16
zp_np_dtype = np.uint16

In [34]:
def name_in_use(name: str) -> bool:
    """Return True if ``name`` is already taken in the global ``graph``.

    A name counts as taken if it is one of the original graph outputs, a graph
    input, an initializer, or any node's input/output.
    """
    if name in graph_output_names:
        return True
    declared = {vi.name for vi in graph.input} | {init.name for init in graph.initializer}
    if name in declared:
        return True
    return any(name in list(node.input) + list(node.output) for node in graph.node)

In [35]:
for must_be_unique in [*new_output_names, *q_input_names, *x_scale_names, *x_zero_point_names]:
    if name_in_use(must_be_unique):
        raise ValueError(f"Name '{must_be_unique}' already exists in the graph. "
                         f"Please provide a different name.")

In [36]:
new_output_vis = [helper.make_tensor_value_info(new_output_name, q_elem, old_shape) for (new_output_name, old_shape) in zip(new_output_names, old_shapes)]
graph.output.extend(new_output_vis)

In [37]:
scale_nps = [np.array([x_scale_value], dtype=np.float32) for x_scale_value in x_scale_values]
zp_nps = [np.array([x_zero_point_value], dtype=zp_np_dtype) for x_zero_point_value in x_zero_point_values]

In [38]:
x_scale_inits = [helper.make_tensor(name=x_scale_name, data_type=TensorProto.FLOAT, dims=(1,), vals=scale_np) \
                 for x_scale_name, scale_np in zip(x_scale_names, scale_nps)]

In [39]:
x_zero_point_inits = [helper.make_tensor(name=x_zero_point_name, data_type=q_elem, dims=(1,), vals=zp_np) \
                 for x_zero_point_name, zp_np in zip(x_zero_point_names, zp_nps)]

In [40]:
graph.initializer.extend([*x_scale_inits, *x_zero_point_inits])

In [41]:
q_nodes = [helper.make_node("QuantizeLinear", inputs=[q_input_name, x_scale_name, x_zero_point_name], outputs=[new_output_name], name=q_node_name) \
        for new_output_name, x_scale_name, x_zero_point_name, q_input_name, q_node_name in \
        zip(new_output_names, x_scale_names, x_zero_point_names, q_input_names, q_node_names)]

In [42]:
graph.node.extend(q_nodes)

In [43]:
for node in graph.node:
    for i, inp in enumerate(node.output):
        if inp in old_output_names:
            new_ind = old_output_names.index(inp)
            node.output[i] = q_input_names[new_ind]

In [44]:
for node in graph.node:
    for i, inp in enumerate(node.input):
        if inp in old_output_names:
            new_ind = old_output_names.index(inp)
            node.input[i] = q_input_names[new_ind]

In [45]:
old_idxs = []
for i, vi in enumerate(graph.output):
    if vi.name in old_output_names:
        old_idxs.append(i)

In [46]:
for old_idx in old_idxs[::-1]:
    del graph.output[old_idx]

In [47]:
opset = _get_opset(model, "")
opset

version: 21

In [48]:
checker.check_model(model)

In [49]:
onnx.save(model, output_path)

# Add Quantize Layer to Encoder Model

In [239]:
import json

# The inputs_embeds quantization parameters come from part 1's JSON dump.
# Use a context manager so the file handle is closed.
with open('weight_sharing_model_1_of_4.serialized.json', 'r') as json_fp:
    model_json_file = json.load(json_fp)

In [240]:
# Reset the variables this cell actually sets, so stale values from an earlier
# run can never leak through.
embed_scale = None
embed_offset = None
graph_inputs = model_json_file['info']['graphs'][0]['info']['graphInputs']

# NOTE(review): abs() converts QNN's offset into a zero point. This assumes the
# stored offset is <= 0; confirm for new models.
for graph_input in graph_inputs:
    if graph_input['info']['name'] == 'inputs_embeds':
        scale_offset = graph_input['info']['quantizeParams']['scaleOffset']
        embed_scale = scale_offset['scale']
        embed_offset = abs(scale_offset['offset'])
        break
else:
    raise KeyError("'inputs_embeds' not found in graphInputs")
embed_scale, embed_offset

(0.0009816634701564908, 29193)

In [241]:
import argparse
import onnx
import numpy as np
from onnx import helper, TensorProto, checker

def _tensor_shape_from_value_info(vi):
    """Extract shape as a list of ints/strings/None from a ValueInfoProto."""
    t = vi.type.tensor_type
    if not t.HasField("shape"):
        return None
    dims = []
    for d in t.shape.dim:
        if d.HasField("dim_param") and d.dim_param:
            dims.append(d.dim_param)  # symbolic dim
        elif d.HasField("dim_value"):
            dims.append(d.dim_value)  # concrete dim
        else:
            dims.append(None)         # unknown
    return dims

def _get_opset(model, domain=""):
    for o in model.opset_import:
        if o.domain == domain:
            return o
    return None

In [242]:
input_path = "embed_fp32.onnx"
output_path = "embed_fp32_mod.onnx"

In [243]:
model = onnx.load(input_path)
graph = model.graph

In [244]:
graph_output_names = [vi.name for vi in graph.output]

In [245]:
old_output_names = graph_output_names

In [246]:
x_scale_values = [embed_scale] 
x_zero_point_values = [embed_offset]

In [247]:
new_output_names = ["inputs_embeds"]

In [248]:
q_node_names = [new_output_name+"_quant" for new_output_name in new_output_names]
q_input_names = [new_output_name+"_in" for new_output_name in new_output_names]
x_scale_names = [new_output_name+"_scale" for new_output_name in new_output_names]
x_zero_point_names = [new_output_name+"_zp" for new_output_name in new_output_names]

In [249]:
old_input_vis = [next((vi for vi in graph.output if vi.name == old_output_name), None) for old_output_name in old_output_names]

In [250]:
old_shapes = [_tensor_shape_from_value_info(old_input_vi) for old_input_vi in old_input_vis]

In [251]:
q_elem = TensorProto.UINT16
zp_np_dtype = np.uint16

In [252]:
def name_in_use(name: str) -> bool:
    """Report whether ``name`` already exists in the global ``graph``.

    Checked, in order:
    1. the original graph outputs
    2. graph inputs
    3. initializers
    4. every node's inputs/outputs
    """
    taken_groups = (
        graph_output_names,
        [vi.name for vi in graph.input],
        [init.name for init in graph.initializer],
    )
    for taken in taken_groups:
        if name in taken:
            return True
    for node in graph.node:
        if name in list(node.input) + list(node.output):
            return True
    return False

In [253]:
for must_be_unique in [*new_output_names, *q_input_names, *x_scale_names, *x_zero_point_names]:
    if name_in_use(must_be_unique):
        raise ValueError(f"Name '{must_be_unique}' already exists in the graph. "
                         f"Please provide a different name.")

In [254]:
new_output_vis = [helper.make_tensor_value_info(new_output_name, q_elem, old_shape) for (new_output_name, old_shape) in zip(new_output_names, old_shapes)]
graph.output.extend(new_output_vis)

In [255]:
scale_nps = [np.array([x_scale_value], dtype=np.float32) for x_scale_value in x_scale_values]
zp_nps = [np.array([x_zero_point_value], dtype=zp_np_dtype) for x_zero_point_value in x_zero_point_values]

In [256]:
x_scale_inits = [helper.make_tensor(name=x_scale_name, data_type=TensorProto.FLOAT, dims=(1,), vals=scale_np) \
                 for x_scale_name, scale_np in zip(x_scale_names, scale_nps)]

In [257]:
x_zero_point_inits = [helper.make_tensor(name=x_zero_point_name, data_type=q_elem, dims=(1,), vals=zp_np) \
                 for x_zero_point_name, zp_np in zip(x_zero_point_names, zp_nps)]

In [258]:
graph.initializer.extend([*x_scale_inits, *x_zero_point_inits])

In [259]:
q_nodes = [helper.make_node("QuantizeLinear", inputs=[q_input_name, x_scale_name, x_zero_point_name], outputs=[new_output_name], name=q_node_name) \
        for new_output_name, x_scale_name, x_zero_point_name, q_input_name, q_node_name in \
        zip(new_output_names, x_scale_names, x_zero_point_names, q_input_names, q_node_names)]

In [260]:
graph.node.extend(q_nodes)

In [261]:
for node in graph.node:
    for i, inp in enumerate(node.output):
        if inp in old_output_names:
            new_ind = old_output_names.index(inp)
            node.output[i] = q_input_names[new_ind]

In [262]:
for node in graph.node:
    for i, inp in enumerate(node.input):
        if inp in old_output_names:
            new_ind = old_output_names.index(inp)
            node.input[i] = q_input_names[new_ind]

In [263]:
old_idxs = []
for i, vi in enumerate(graph.output):
    if vi.name in old_output_names:
        old_idxs.append(i)

In [264]:
for old_idx in old_idxs[::-1]:
    del graph.output[old_idx]

In [265]:
opset = _get_opset(model, "")
opset

version: 19

In [266]:
opset.version = 21

In [267]:
import copy
new_input = copy.deepcopy(model.graph.input[1])

In [268]:
old_input_name = "image_features"

In [269]:
new_input.name = "image_features_dq"

In [270]:
model.graph.input.extend([new_input])

In [271]:
for node in model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name == old_input_name:
            node.input[i] = new_input.name

In [272]:
input_idx = next((i for i, vi in enumerate(model.graph.input) if vi.name == old_input_name))

In [273]:
del model.graph.input[input_idx]

In [274]:
onnx.save(model, output_path, all_tensors_to_one_file=True, save_as_external_data=True, location = "embed_fp32_mod.data")

In [275]:
checker.check_model(output_path)

# Create Output Dequantizer

In [118]:
import onnx
import numpy as np
import onnx.helper as helper
from onnx import helper, TensorProto
import onnx.numpy_helper as numpy_helper

In [119]:
def create_dequantizer_onnx(output_path: str = 'dequantizer.onnx') -> None:
    """Write a one-node ONNX model that casts uint16 ``logits`` to float32 ``logits_dequantized``.

    NOTE(review): despite the name, only a Cast is applied; no scale or zero point is
    used, so the float outputs are the raw quantized codes. This preserves ordering
    (e.g. argmax) but not the logit values. Confirm this is what the consumer expects.

    Args:
        output_path: Destination file (default ``'dequantizer.onnx'``).
    """
    # Input: quantized logits, shape [1, sequence_length, vocab_size].
    input_tensor = helper.make_tensor_value_info(
        'logits', TensorProto.UINT16, [1, 'sequence_length', 'vocab_size']
    )

    # Output: same shape, float32.
    output_tensor = helper.make_tensor_value_info(
        'logits_dequantized', TensorProto.FLOAT, [1, 'sequence_length', 'vocab_size']
    )

    # Cast node converting uint16 to float32.
    cast_node = helper.make_node(
        'Cast',
        inputs=['logits'],
        outputs=['logits_dequantized'],
        to=TensorProto.FLOAT,
        name='logits_dequantizer'
    )

    graph_def = helper.make_graph(
        nodes=[cast_node],
        name='dequantizer',
        inputs=[input_tensor],
        outputs=[output_tensor]
    )

    # Pin the default-domain opset to 21, matching the other exported processors.
    model_def = helper.make_model(
        graph_def,
        producer_name='onnx-dequantize-example',
        opset_imports=[helper.make_opsetid("", 21)],
    )
    model_def.ir_version = 10
    onnx.save(model_def, output_path)

    print(f"ONNX model '{output_path}' created successfully.")

In [120]:
create_dequantizer_onnx()

ONNX model 'dequantize.onnx' created successfully.


# Input Processor: slice the past KV cache for sliding-window-attention (SWA) layers

In [121]:
import torch

In [122]:
class InputSWAMod(torch.nn.Module):
    """Trim the past KV cache to the sliding window for all but every N-th layer.

    Layers whose 1-based index is a multiple of ``global_layer_period`` keep their
    full cache of ``past_dim`` positions. All other layers keep only the last
    ``swa_dim`` positions.

    Each input is reshaped with ``view``:
    - keys to ``(1, num_kv_heads, head_dim, past_dim)``
    - values to ``(1, num_kv_heads, past_dim, head_dim)``

    Args:
        swa_dim: Number of trailing cache positions kept for sliding-window layers.
        past_dim: Full past-cache length.
        num_kv_heads: KV heads per layer (default 4).
        head_dim: Per-head feature size (default 256).
        global_layer_period: Every this-many layers keeps the full cache (default 6).
    """

    def __init__(self, swa_dim, past_dim, num_kv_heads=4, head_dim=256, global_layer_period=6):
        super().__init__()
        self.swa_dim = swa_dim
        self.past_dim = past_dim
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.global_layer_period = global_layer_period

    def _is_global(self, layer_idx):
        """True for layers (0-based ``layer_idx``) that keep the full cache."""
        return (layer_idx + 1) % self.global_layer_period == 0

    def forward(self, past_key_ins, past_value_ins):
        # Use self.past_dim throughout; the value slice previously read a
        # module-level `past_dim` that only coincidentally matched.
        window_start = self.past_dim - self.swa_dim

        swa_past_key_ins = []
        for ind, past_key_in in enumerate(past_key_ins):
            past_key_in = past_key_in.view((1, self.num_kv_heads, self.head_dim, self.past_dim))
            if self._is_global(ind):
                swa_past_key_ins.append(past_key_in)
            else:
                swa_past_key_ins.append(past_key_in[:, :, :, window_start:self.past_dim])

        swa_past_value_ins = []
        for ind, past_value_in in enumerate(past_value_ins):
            past_value_in = past_value_in.view((1, self.num_kv_heads, self.past_dim, self.head_dim))
            if self._is_global(ind):
                swa_past_value_ins.append(past_value_in)
            else:
                swa_past_value_ins.append(past_value_in[:, :, window_start:self.past_dim, :])

        return swa_past_key_ins, swa_past_value_ins

In [124]:
context_length = 8192
# Past-cache length for this variant is context_length - 128; SWA layers keep
# the last swa_dim of those positions.
past_dim = context_length - 128
swa_dim = 896
num_layers = 34
swa_input_mod = InputSWAMod(swa_dim, past_dim)

# Dummy uint8 KV-cache tensors, used only to trace the export.
past_key_shape = (4, 1, 256, past_dim)
past_value_shape = (4, 1, past_dim, 256)

past_key_inputs = [torch.zeros(past_key_shape, dtype = torch.uint8) for i in range(num_layers)]
past_value_inputs = [torch.zeros(past_value_shape, dtype = torch.uint8) for i in range(num_layers)]

output_key, output_value = swa_input_mod(past_key_inputs, past_value_inputs)

own_past_key_pattern = "own_past_key_%d_in"
own_past_value_pattern = "own_past_value_%d_in"

past_key_pattern = "past_key_%d_in"
past_value_pattern = "past_value_%d_in"
swa_key_pattern = "swa_key_%d_in"
swa_value_pattern = "swa_value_%d_in"

# Every 6th layer keeps the full cache (past_*); the others get the window (swa_*).
# This matches InputSWAMod's default layer selection.
input_names = [own_past_key_pattern%(i) for i in range(num_layers)] + [own_past_value_pattern%(i) for i in range(num_layers)]
output_names = [swa_key_pattern%(i) if (i+1)%6 != 0 else past_key_pattern%(i) for i in range(num_layers)] + \
               [swa_value_pattern%(i) if (i+1)%6 != 0 else past_value_pattern%(i)  for i in range(num_layers)]

# Export to ONNX
onnx_path = "input-processor_swa_cl128_oga_rs.onnx"
torch.onnx.export(
    swa_input_mod,
    (past_key_inputs, past_value_inputs),
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    external_data = False,
    opset_version=21
)
print(f"ONNX model '{onnx_path}' created successfully.")

  torch.onnx.export(
W0212 10:51:22.848000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::nms
W0212 10:51:22.849000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align
W0212 10:51:22.850000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool


[torch.onnx] Obtain model graph for `InputSWAMod()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `InputSWAMod()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


  return cls.__new__(cls, *args)


[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNX model 'input-processor_swa.onnx' created successfully.


In [125]:
# Past-cache length for this variant is context_length - 1; SWA layers keep the
# last swa_dim of those positions.
past_dim = context_length - 1
swa_dim = 1023
num_layers = 34
swa_input_mod = InputSWAMod(swa_dim, past_dim)

# Dummy uint8 KV-cache tensors, used only to trace the export.
past_key_shape = (4, 1, 256, past_dim)
past_value_shape = (4, 1, past_dim, 256)

past_key_inputs = [torch.zeros(past_key_shape, dtype = torch.uint8) for i in range(num_layers)]
past_value_inputs = [torch.zeros(past_value_shape, dtype = torch.uint8) for i in range(num_layers)]

output_key, output_value = swa_input_mod(past_key_inputs, past_value_inputs)

own_past_key_pattern = "own_past_key_%d_in"
own_past_value_pattern = "own_past_value_%d_in"

past_key_pattern = "past_key_%d_in"
past_value_pattern = "past_value_%d_in"
swa_key_pattern = "swa_key_%d_in"
swa_value_pattern = "swa_value_%d_in"

# Every 6th layer keeps the full cache (past_*); the others get the window (swa_*).
# This matches InputSWAMod's default layer selection.
input_names = [own_past_key_pattern%(i) for i in range(num_layers)] + [own_past_value_pattern%(i) for i in range(num_layers)]
output_names = [swa_key_pattern%(i) if (i+1)%6 != 0 else past_key_pattern%(i) for i in range(num_layers)] + \
               [swa_value_pattern%(i) if (i+1)%6 != 0 else past_value_pattern%(i)  for i in range(num_layers)]

# Export to ONNX
onnx_path = "input-processor_swa_cl1_oga_rs.onnx"
torch.onnx.export(
    swa_input_mod,
    (past_key_inputs, past_value_inputs),
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    external_data = False,
    opset_version=21
)
print(f"ONNX model '{onnx_path}' created successfully.")

  torch.onnx.export(
W0212 10:51:26.021000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::nms
W0212 10:51:26.023000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align
W0212 10:51:26.024000 17108 Lib\site-packages\torch\onnx\_internal\exporter\_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool


[torch.onnx] Obtain model graph for `InputSWAMod()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `InputSWAMod()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


  return cls.__new__(cls, *args)


[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNX model 'input-processor_swa.onnx' created successfully.


# Output Processor (transpose and rename the KV-cache outputs)

In [126]:
import torch

In [127]:
class TransposeChangeName(torch.nn.Module):
    """Swap axes 0 and 1 of every tensor in the input list.

    Exporting this module gives the KV-cache tensors new graph input/output names
    while reordering their first two axes.
    """

    def forward(self, inputs):
        return [tensor.transpose(0, 1) for tensor in inputs]

In [128]:
changeNameMod = TransposeChangeName()

In [129]:
inputs = [torch.zeros((1, 4, 256, 128), dtype = torch.uint8) for i in range(34)] + \
            [torch.zeros((1, 4, 128, 256), dtype = torch.uint8) for i in range(34)]

In [130]:
outputs = changeNameMod(inputs)

In [131]:
input_names=["past_key_%d_out"%(i) if (i+1)%6 == 0 else "swa_key_%d_out"%(i) for i in range(34)] + \
                ["past_value_%d_out"%(i)  if (i+1)%6 == 0 else "swa_value_%d_out"%(i) for i in range(34)]

In [132]:
output_names=["own_past_key_%d_out"%(i) for i in range(34)] + ["own_past_value_%d_out"%(i) for i in range(34)]

In [133]:
dynamic_axes = {}
dynamic_axes.update({x: {3: "kv_dim"} if 'key' in x else {2: "kv_dim"} for x in input_names})
dynamic_axes.update({x: {3: "kv_dim"} if 'key' in x else {2: "kv_dim"} for x in output_names})

In [134]:
# Export to ONNX
torch.onnx.export(
    changeNameMod,
    inputs,
    "transpose_outputs_cl_all.onnx",
    input_names = input_names,
    output_names = output_names,
    dynamic_axes = dynamic_axes,
    opset_version=21,
    dynamo = False
)

  torch.onnx.export(
  _export(


# Merge Output Processor Models

In [135]:
import onnx

In [136]:
op_model = onnx.load("transpose_outputs_cl_all.onnx")
dq_model = onnx.load("dequantizer.onnx")

In [137]:
merged_model = onnx.compose.merge_models(op_model, dq_model, {})

In [138]:
onnx.save(merged_model, "output_processor.onnx")

# OGA Inference

In [178]:
import onnxruntime_genai as og
import json

In [179]:
model_path = r"./"

In [180]:
config = og.Config(model_path)

In [181]:
model = og.Model(config)

In [182]:
tokenizer = og.Tokenizer(model)

In [183]:
processor = model.create_multimodal_processor()

In [184]:
stream = processor.create_stream()

In [185]:
image_paths = []
images = None
text = "What are transformers?"

In [186]:
messages = []

In [187]:
content_list = [{"type": "image"} for _ in image_paths]
content_list.append({"type": "text", "text": text})
messages.append({"role": "user", "content": content_list})

In [188]:
message_json = json.dumps(messages)

In [189]:
prompt = tokenizer.apply_chat_template(message_json, add_generation_prompt=True)

In [190]:
inputs = processor(prompt, images=images)

In [191]:
params = og.GeneratorParams(model)
params.set_search_options(max_length=1024)

In [192]:
generator = og.Generator(model, params)

In [193]:
generator.set_inputs(inputs)

In [194]:
generator.generate_next_token()

In [195]:
new_token = generator.get_next_tokens()[0]

In [196]:
print(stream.decode(new_token), end="", flush=True)

Okay

In [197]:
for _ in range(100):
    if generator.is_done(): break
    generator.generate_next_token()
    new_token = generator.get_next_tokens()[0]
    print(stream.decode(new_token), end="", flush=True)

, let's break down what "transformers" are in the context of modern AI – specifically, large language models. They're a complex topic, but I'll explain it in a way that's both detailed and accessible.

**Transformers: The Core Idea**

At their heart, transformers are a type of computer program designed to understand and generate text.▁▁They've become *the* tool for many AI applications because they've proven exceptionally good at this.▁▁Let