# AIMET Quantization workflow for Gemma3-4B Language Model

This notebook shows a working code example of how to use AIMET to quantize the Gemma3-4B language model.

---
### Required packages
The notebook assumes AIMET and Gemma3 related packages are already installed.

In [None]:
# Worker processes started with the multiprocessing "spawn" method re-import the main script under a
# module name other than '__main__'. Stop them here so they do not re-run the whole notebook.
_running_as_main = (__name__ == '__main__')
if not _running_as_main:
    raise Exception("Killing multiprocessing spawn started by Converter during model preparation.")

In [None]:
# Install packages only if running in jupyter notebook mode.
# The '!' lines are IPython shell escapes, so they are gated on the '__IPYTHON__' builtin that IPython defines.
# transformers and tokenizers are pinned to exact versions; libc++-dev is a system package
# (presumably a runtime dependency of the SDK tooling -- TODO confirm).
if hasattr(__builtins__,'__IPYTHON__'):
    !sudo -H apt-get -qq update
    !sudo -H apt-get -qq install libc++-dev
    !sudo -H pip install --quiet --upgrade --root-user-action=ignore --no-cache-dir transformers==4.50.0
    !sudo -H pip install --quiet --upgrade --root-user-action=ignore --no-cache-dir tokenizers==0.21.4

### Overall flow
This notebook covers the following
1. Setting QNN SDK and NSP target
2. Instantiate and evaluate HuggingFace model
3. Instantiate and adapt FP32 model
4. Model Sample Input
5. Prepare model using QAIRT model preparer pro
6. Evaluation of prepared model
7. Quantization 
8. Export


### What this notebook is not 
* This notebook is not intended to show the full scope of optimization. For example, the flow deliberately does not use QAT or KD-QAT, so that the notebook executes more quickly.

### 1.1 Setting QNN SDK

In [None]:
import sys
import os
import copy

QNN_SDK_ROOT = "/tmp/qnn"
assert QNN_SDK_ROOT is not None, 'Please point the QNN_SDK_ROOT variable to your QNN SDK'
assert os.path.exists(QNN_SDK_ROOT), "QNN_SDK_ROOT doesn't exist!"
sys.path.insert(0, QNN_SDK_ROOT + '/lib/python')

lib_clang_path = os.path.join(QNN_SDK_ROOT, 'lib', 'x86_64-linux-clang')
LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH', None)
os.environ['LD_LIBRARY_PATH'] = lib_clang_path + ':' + LD_LIBRARY_PATH if LD_LIBRARY_PATH is not None else lib_clang_path


### 1.2 Setting NSP Target

In [None]:
# Select quantsim config based on target
# AIMET quantsim config file for the selected NSP target (v81 HTP).
htp_config_file = "htp_quantsim_config_v81.json"

### 2. Instantiate and evaluate HuggingFace model

In [None]:
from genai_lib.common.debug.recipe_logger import recipe_dump_init
from genai_lib.common.debug.recipe_logger import llm_lib_log_env_info
from transformers import AutoConfig, AutoTokenizer, AutoProcessor

#======================Configurable setting by users================================
run_ppl_eval = True
load_optimized_weights = True   # load the SpinQuant checkpoint from the SpinQuant notebook for better quantization accuracy
cache_dir = '/tmp/cache_dir'
output_dir = '/tmp/output_dir'  # directory where this notebook's export artifacts are saved

model_name = 'gemma_4b'
model_id = "google/gemma-3-4b-it"
if load_optimized_weights:
    # Optimized language-model checkpoint; only read when load_optimized_weights is True.
    optimized_model_id = "/tmp/output_dir/spinquant"

lmm_config = AutoConfig.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True)

context_length = 8192
# To help with debugging, export NUM_HIDDEN_LAYERS (e.g. 6) to quickly verify the pipeline on a
# truncated model; unset / 0 keeps the checkpoint's full depth.
# NOTE(review): later cells index layer sliding_window_pattern - 1 as the first global layer, so a
# value below sliding_window_pattern is likely to fail -- confirm.
num_hidden_layers = int(os.getenv("NUM_HIDDEN_LAYERS", 0))
if num_hidden_layers > 0:
    lmm_config.text_config.num_hidden_layers = num_hidden_layers

# The two f-string pieces are concatenated, so the first must end with a separator.
print(f'num_layer: {lmm_config.text_config.num_hidden_layers}, context_length: {context_length}, '
      f'num_attention_heads: {lmm_config.text_config.num_attention_heads}, num_kv_heads: {lmm_config.text_config.num_key_value_heads}')

# Recipe_logger: Initialize the logger and log environment details
os.makedirs(output_dir, exist_ok=True)
recipe_dump_init(output_dir)

llm_lib_log_env_info()

#### 2.1 Instantiate the HuggingFace model

In [None]:
import torch
from transformers.models.gemma3 import modeling_gemma3
from genai_lib.common.debug.profiler import event_marker

with event_marker('Load FP model'):
    # Unadapted HuggingFace checkpoint, used for the FP baseline evaluation.
    model = modeling_gemma3.Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        config=lmm_config,
        cache_dir=cache_dir,
    )

    # Turn off HF tokenizers' internal parallelism (presumably to avoid fork-related issues with
    # worker processes -- TODO confirm).
    os.environ['TOKENIZERS_PARALLELISM'] = '0'
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        use_fast=True,
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    # Cap the tokenizer's maximum sequence length at the model context length.
    tokenizer.model_max_length = context_length

#### 2.2 Instantiate Dataloaders

In [None]:
from llm_utils.wikitext_dataloader import get_wiki_dataset
from llm_utils.llava_dataloader import get_llava_dataset

# WikiText loaders (the test split is used for perplexity evaluation below).
train_dataloader, test_dataloader, _ = get_wiki_dataset(
    context_length,
    tokenizer,
    cache_dir=cache_dir,
)

# LLaVA instruction samples; set dataset_path to the folder that contains the COCO dataset root folder.
data_files = "llm_utils/llava_dataset/llava_v1_5_mix665k_300.json"
dataset_path = "<path to folder containing the coco dataset root folder>"
llava_dataset = get_llava_dataset(
    tokenizer,
    processor,
    dataset_path=dataset_path,
    data_files=data_files,
    cache_dir=cache_dir,
)

#### 2.3 HuggingFace FP model eval

In [None]:
from aimet_torch.utils import place_model
from genai_lib.llm.evaluation_utils import llm_evaluate_ppl_with_dataloader
from genai_lib.common.debug.recipe_logger import llm_lib_log_property, Property
from genai_lib.common.debug.recipe_logger import llm_lib_log_metric, ModelType, Metric


# Recipe_logger: record the context_length property, then (optionally) the baseline PPL metric.
llm_lib_log_property({Property.context_length: context_length})

if run_ppl_eval:
    # The event marker is entered first and the model placement second, exactly as in a nested `with`.
    with event_marker("HuggingFace FP model eval"), place_model(model, torch.device('cuda')):
        orig_ppl = llm_evaluate_ppl_with_dataloader(model=model.language_model, dataloader=test_dataloader)

    llm_lib_log_metric(ModelType.hf_model, Metric.ppl, orig_ppl, model_name="base")
    print(f"PPL score of HuggingFace FP model = {orig_ppl}")

# Drop the unadapted HuggingFace model; the adapted model is re-created from the checkpoint later.
del model

### 3. Instantiate and adapt FP32 model

#### 3.1 Adapt FP32 model definition for inference on HTP.
- The following adaptations replace the default attention module with an attention definition that is compatible with the NSP backend:
  * use conv instead of linear for Q,K,V,O projections
  * bypass attention and causal mask generation and replace with pre-generated 2D-mask input
  * output only newly created V and transposed K instead of entire augmented KV sequence
  * input pre-calculated positional embedding instead of position ids, thus bypass the embedding generation in the model
  


In [None]:
from transformers import cache_utils
from transformers.models.gemma3 import modeling_gemma3
from gemma3.adaptation import (
    Gemma3Attention,
    Gemma3DecoderLayer,
    Gemma3TextModel,
    Gemma3ForCausalLM,
    adapted_update_causal_mask,
    DynamicCache_update,
    DynamicCache_get_seq_length,
    update_attr
)

with event_marker("FP model adaptation configuration"):
    modeling_gemma3.Gemma3TextModel = Gemma3TextModel
    modeling_gemma3.Gemma3ForCausalLM = Gemma3ForCausalLM
    modeling_gemma3.Gemma3DecoderLayer = Gemma3DecoderLayer
    modeling_gemma3.Gemma3Attention = Gemma3Attention

    # Bypass attention_mask preparation
    assert hasattr(modeling_gemma3.Gemma3TextModel, '_update_causal_mask'), \
        "GaussModel does not have _update_causal_mask as attribute"
    modeling_gemma3.Gemma3TextModel._update_causal_mask = adapted_update_causal_mask
    
    # Adapting KV$ management
    assert update_attr(cache_utils.DynamicCache, 'update', DynamicCache_update), \
        f"Unknown DynamicCache definition: {cache_utils.DynamicCache}"
    assert update_attr(cache_utils.DynamicCache, 'get_seq_length', DynamicCache_get_seq_length), \
        f"Unknown DynamicCache definition: {cache_utils.DynamicCache}"

#### 3.2 Instantiate adapted FP32 model definition

In [None]:
import types

In [None]:
#======================Fixed setting that should not be changed by users==============
# Auto-regression length (ARN): number of tokens consumed and number of logits produced per inference.
# This value should NOT be changed due to downstream consumption requirements.
ARN = int(os.getenv("ARN", 473))

# Right padding causes an error in the model-prepare step; only left padding is supported currently.
enable_right_padding = False
pad_to_left = not enable_right_padding
num_slices = None

# Text-config settings for the adapted model, applied in this order.
_adapted_text_config_overrides = {
    'return_new_key_value_only': True,
    'transposed_key_cache': True,
    'use_combined_mask_input': True,
    'use_position_embedding_input': True,
    '_attn_implementation': 'eager',
    '_attn_implementation_internal': 'eager',
    'return_dict': False,
    'logits_to_keep': 0,
    'input_tokens_per_inference': ARN,
}
for _attr_name, _attr_value in _adapted_text_config_overrides.items():
    setattr(lmm_config.text_config, _attr_name, _attr_value)
del _adapted_text_config_overrides, _attr_name, _attr_value

lmm_config.save_pretrained(output_dir)

with event_marker('Adapted FP model creation'):
    model = modeling_gemma3.Gemma3ForConditionalGeneration.from_pretrained(model_id, config=lmm_config, cache_dir=cache_dir)
    # Gemma3ForConditionalGeneration builds its language model through AutoModelForCausalLM, so bind
    # the adapted Gemma3ForCausalLM.forward onto that instance explicitly.
    model.language_model.forward = types.MethodType(Gemma3ForCausalLM.forward, model.language_model)

    # Swap in the optimized language model when one is provided.
    if load_optimized_weights:
        model.language_model = modeling_gemma3.Gemma3ForCausalLM.from_pretrained(
            optimized_model_id,
            config=lmm_config.text_config,
            cache_dir=cache_dir,
        )

#### 3.3 Complete the last step(s) of Model Adaptation
The following model adaptations are enabled for inference:
- replace linear layers with conv layers in attention, MLP and lm_head, arranging the linear weights appropriately for conv

In [None]:
from genai_lib.common.dev.model_adaptation.linear_to_conv import replace_linears_with_convs

with event_marker('FP model adaptation for NSP backend completion'):
    # Rebuild the language model with its linear layers replaced by conv layers.
    setattr(model, 'language_model', replace_linears_with_convs(model.language_model))

#### 3.4 Instantiate VEG (vision tower + projector) and embedding layer

In [None]:
class VisualEmbeddingGenerator(torch.nn.Module):
    """Visual embedding generator (VEG): the vision tower followed by the multi-modal projector.

    Maps image pixel values (as produced by the AutoProcessor) to image features; it is handed to
    ``lmm_preprocess_inputs`` together with the token embedding layer.
    """

    def __init__(self, vision_tower, multi_modal_projector):
        super().__init__()
        # Registration order (projector first) is kept from the original definition.
        self.multi_modal_projector = multi_modal_projector
        self.vision_tower = vision_tower

    @property
    def device(self):
        """Current device of the wrapped vision tower.

        Looked up on every access instead of being captured once in ``__init__``, so it stays
        correct after the module is moved (e.g. by ``place_model(vision_model, torch.device('cuda'))``).
        """
        return self.vision_tower.device

    # This forward gets the image pixel values that we get from the AutoProcessor when we pass the
    # image and text (text -> input ids, image -> pixel values).
    # Input shape is [1,3,896,896], output shape is [1,256,2560].
    def forward(self, pixel_values):
        image_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(image_outputs)
        return image_features

# VEG wrapper around the model's own vision tower and projector, plus a standalone copy of the
# language model's token-embedding table; both are passed to the adapted forward as kwargs.
vision_model = VisualEmbeddingGenerator(
    vision_tower=model.vision_tower,
    multi_modal_projector=model.multi_modal_projector,
)
embedding_layer = copy.deepcopy(model.language_model.model.embed_tokens)

#### 3.5 Changes to HuggingFace model to work with the Adapted Model and Prepared Model
- As a result of adapting the model we introduce changes to the types of the model inputs.
- As a result of model preparation, we make the shapes of the inputs static.
- `adapted_model_forward` works with both the adapted model (dynamic-shape inputs) and the prepared model (static-shape inputs); the internal `static_shape` flag selects between the two.
- Override the 'forward' function and the function 'prepare_inputs_for_generation'. With these overrides, we make the adapted model or prepared model work just like the original model.
- adapted_model_prepare_inputs_for_dynamic_shapes is utility function for forward pass of adapted model with dynamic shapes.
- adapted_model_prepare_inputs_for_static_shapes is utility function for forward pass of prepared model with static shapes.

##### 3.5.1 Define prepare inputs function for adapted model with dynamic shapes
In the dynamic-shape case, the inputs are passed the same way as in the original HF model, but with the adaptation / pre-computation logic applied. This means we do not need to pad the inputs.

In [None]:
from transformers.cache_utils import DynamicCache
from genai_lib.llm.static_graph_utils import llm_create_1d_attn_mask, llm_pad_position_ids
from gemma3.utils import llm_update_causal_mask, llm_create_position_embeddings, lmm_preprocess_inputs, llm_get_kv_length

llm_config = lmm_config.text_config
# Index of the first global-attention layer: layers whose (index + 1) is a multiple of
# sliding_window_pattern use global attention, all others use sliding-window attention.
global_layer_idx = llm_config.sliding_window_pattern - 1

# Sliding-window layers, i.e. the layers whose KV$ is subject to sliding eviction.
layer_indices_to_perform_sliding_eviction = [
    idx
    for idx in range(llm_config.num_hidden_layers)
    if (idx + 1) % llm_config.sliding_window_pattern != 0
]


def adapted_model_prepare_inputs_for_dynamic_shapes(self, inputs_embeds_slice, attn_mask_slice, position_ids_slice, outputs, token_type_ids=None, **kwargs):
    """Build adapted-model inputs for one slice without any padding (dynamic shapes).

    Causal masks are sized to exactly (cached KV$ length + slice length), separately for the global
    layers and the sliding-window layers. Global and local (sliding-window) RoPE embeddings are
    precomputed from the slice's position ids.

    Args:
        self: the language model; not referenced here, kept to match the static-shape variant.
        inputs_embeds_slice: embeddings of the current slice; dim 0 is batch, dim 1 is tokens.
        attn_mask_slice: attention mask of the slice's tokens.
        position_ids_slice: position ids of the slice's tokens.
        outputs: running state dict. ``outputs['past_key_values']`` is read and, when None,
            replaced in place by an empty ``DynamicCache``.
        token_type_ids: optional token type ids forwarded to ``llm_update_causal_mask``.
        **kwargs: ignored.

    Returns:
        dict with 'attention_mask', 'position_ids' (global RoPE), 'past_key_values' (deep copy of
        the running cache), 'inputs_embeds', 'swa_attention_mask', 'swa_position_ids' (local RoPE).
    """
    slice_device = inputs_embeds_slice.device
    bsz, num_new_tokens = inputs_embeds_slice.shape[0], inputs_embeds_slice.shape[1]
    past_len_global, past_len_sliding = llm_get_kv_length(outputs, global_layer_idx, layer_indices_to_perform_sliding_eviction)

    if outputs['past_key_values'] is None:
        outputs['past_key_values'] = DynamicCache()

    def build_causal_mask(past_len, **extra_mask_kwargs):
        # Combine an all-ones mask over the cached tokens with the slice's own mask, then expand it.
        past_ones = torch.ones((bsz, past_len), dtype=torch.long, device=slice_device)
        mask_1d = llm_create_1d_attn_mask(attn_mask_past_kv=past_ones,
                                          attn_mask_input=attn_mask_slice)
        return llm_update_causal_mask(prepared_1d_attn_mask=mask_1d,
                                      input_tensor=inputs_embeds_slice,
                                      max_input_tokens=num_new_tokens,
                                      model_context_len=past_len + num_new_tokens,
                                      model_id_or_path=model_id,
                                      token_type_ids=token_type_ids,
                                      **extra_mask_kwargs)

    # Global layers. During dynamic mask creation the KV$ update inside the attention block only
    # concatenates (the past KV$ has no space to scatter new KV$ into), so neither pad_to_left nor
    # cache_index is passed to llm_update_causal_mask.
    global_causal_mask = build_causal_mask(past_len_global)
    # Sliding layers: sliding_window yields the strided mask so each token of the current input only
    # looks at the most recent sliding window worth of KV$.
    sliding_causal_mask = build_causal_mask(past_len_sliding, sliding_window=llm_config.sliding_window)

    padded_position_ids = llm_pad_position_ids(position_ids_slice=position_ids_slice,
                                               max_input_tokens=num_new_tokens,
                                               pad_to_left=pad_to_left)

    # Global RoPE uses the text config unchanged; local RoPE swaps in rope_local_base_freq with default scaling.
    global_rope = llm_create_position_embeddings(config=llm_config, position_ids=padded_position_ids)
    local_rope_config = copy.deepcopy(llm_config)
    local_rope_config.rope_theta = llm_config.rope_local_base_freq
    local_rope_config.rope_scaling = {"rope_type": "default"}
    local_rope = llm_create_position_embeddings(config=local_rope_config, position_ids=padded_position_ids)

    return {
        'attention_mask': global_causal_mask,
        'position_ids': global_rope,
        'past_key_values': copy.deepcopy(outputs['past_key_values']),
        'inputs_embeds': inputs_embeds_slice,
        'swa_attention_mask': sliding_causal_mask,
        'swa_position_ids': local_rope,
    }

##### 3.5.2 Define prepare inputs function for adapted model with static shapes

In [None]:
from genai_lib.llm.static_graph_utils import llm_pad_inputs, llm_pad_past_kv, llm_pad_input_attn_mask, llm_create_kv_attn_mask, llm_get_dummy_kv
from genai_lib.llm.dev.model_adaptation.common.utils import KEY_CONCAT_AXIS, VALUE_CONCAT_AXIS

def adapted_model_prepare_inputs_for_static_shapes(self, inputs_embeds_slice, attn_mask_slice, position_ids_slice, outputs, token_type_ids=None, **kwargs):
    """Build fixed-shape inputs for one slice of the prepared (static-graph) language model.

    The slice is padded to ARN tokens. The unpadded past KV$ in ``outputs['past_key_values']`` is
    padded into dummy caches of fixed length: with left padding, context_length - ARN for the global
    layers and sliding_window - ARN for the sliding-window layers. Global and sliding-window causal
    masks and RoPE embeddings are built to match.

    Args:
        self: the language model; not referenced in this function.
        inputs_embeds_slice: embeddings of the current slice; dim 0 is used as the batch size.
        attn_mask_slice: attention mask of the slice's tokens (padded to ARN here).
        position_ids_slice: position ids of the slice's tokens (padded to ARN here).
        outputs: running state dict; only ``outputs['past_key_values']`` is read (it may be None,
            as in ``get_dummy_data``). It is not modified.
        token_type_ids: optional token type ids forwarded to ``llm_update_causal_mask``.
        **kwargs: ignored.

    Returns:
        dict of model inputs. Key order is significant because ``get_dummy_data`` turns the values
        into a positional tuple. The order is:
        'attention_mask', 'position_ids' (global RoPE), 'past_key_values', 'inputs_embeds', then
        'cache_index' and 'swa_cache_index' (right padding only), 'swa_attention_mask',
        'swa_position_ids' (local RoPE).
    """
    batch_size = inputs_embeds_slice.shape[0]
    # The EOS token id is used as the padding token.
    pad_token = tokenizer.eos_token_id
    device = inputs_embeds_slice.device
    # Fall back to hidden_size // num_attention_heads when the config has no explicit head_dim.
    head_dim = llm_config.head_dim if hasattr(llm_config, 'head_dim') else llm_config.hidden_size // llm_config.num_attention_heads
    kv_length, kv_length_sliding = llm_get_kv_length(outputs, global_layer_idx, layer_indices_to_perform_sliding_eviction)

    ####### input id preparation #######
    pad_input_embeds = llm_pad_inputs(pad_token=pad_token,
                                      max_input_tokens=ARN,
                                      inputs_embeds_slice=inputs_embeds_slice,
                                      pad_to_left=pad_to_left)

    ####### KV input preparation #######
    # Fixed-size placeholder caches for global layers (full context) and sliding-window layers.
    global_dummy_kv = llm_get_dummy_kv(batch_size=batch_size,
                                       num_key_value_heads=llm_config.num_key_value_heads,
                                       head_dim=head_dim,
                                       key_concat_axis=KEY_CONCAT_AXIS,
                                       device=device,
                                       cache_len=context_length-ARN if pad_to_left else context_length)
    
    sliding_dummy_kv = llm_get_dummy_kv(batch_size=batch_size,
                                        num_key_value_heads=llm_config.num_key_value_heads,
                                        head_dim=head_dim,
                                        key_concat_axis=KEY_CONCAT_AXIS,
                                        device=device,
                                        cache_len=llm_config.sliding_window-ARN if pad_to_left else llm_config.sliding_window)
    
    # Per-layer placeholder: every sliding_window_pattern-th layer (global attention) gets the full
    # cache, all other (sliding-window) layers get the short one.
    dummy_kv = [sliding_dummy_kv if ((layer_idx + 1) % llm_config.sliding_window_pattern) != 0 else global_dummy_kv
                for layer_idx in range(llm_config.num_hidden_layers)]

    padded_past_kv_in = llm_pad_past_kv(dummy_past_kv=dummy_kv,
                                        unpadded_past_kv=outputs['past_key_values'],
                                        num_hidden_layers=llm_config.num_hidden_layers,
                                        key_concat_axis=KEY_CONCAT_AXIS,
                                        value_concat_axis=VALUE_CONCAT_AXIS,
                                        pad_to_left=pad_to_left)
    
    # cache_index / swa_cache_index hold the current KV$ lengths and are only produced for right
    # padding (presumably the write positions for the new KV$ -- TODO confirm).
    cache_index = None
    swa_cache_index = None
    if enable_right_padding:
        cache_index = torch.tensor([kv_length], dtype=torch.int64, device=device)
        swa_cache_index = torch.tensor([kv_length_sliding], dtype=torch.int64, device=device)  
    
    ######### Attention mask Input preparation #######
    inp_attn_mask = llm_pad_input_attn_mask(attn_mask_slice=attn_mask_slice,
                                            max_input_tokens=ARN,
                                            pad_to_left=pad_to_left)
    
    # 1. Global Layer
    past_kv_attn_mask = llm_create_kv_attn_mask(unpadded_past_kv=outputs['past_key_values'],
                                                model_context_len=context_length,
                                                max_input_tokens=ARN,
                                                batch_size=batch_size,
                                                device=device,
                                                pad_to_left=pad_to_left,
                                                global_layer_idx=global_layer_idx)
    
    prepared_1d_attention_mask = llm_create_1d_attn_mask(attn_mask_past_kv=past_kv_attn_mask,
                                                         attn_mask_input=inp_attn_mask,
                                                         cache_index=cache_index)

    prepared_causal_mask = llm_update_causal_mask(prepared_1d_attn_mask=prepared_1d_attention_mask,
                                                  input_tensor=pad_input_embeds,
                                                  max_input_tokens=ARN,
                                                  model_context_len=context_length,
                                                  model_id_or_path=model_id,
                                                  cache_index=cache_index,
                                                  pad_to_left=pad_to_left,
                                                  token_type_ids=token_type_ids)

    ########### Position ID preparation #######
    padded_position_ids = llm_pad_position_ids(position_ids_slice=position_ids_slice,
                                               max_input_tokens=ARN, 
                                               pad_to_left=pad_to_left)

    # Global RoPE
    prepared_position_embeddings = llm_create_position_embeddings(config=llm_config,
                                                                  position_ids=padded_position_ids)

    # Local RoPE: same config but with rope_local_base_freq as theta and default (unscaled) RoPE.
    config_local = copy.deepcopy(llm_config)
    config_local.rope_theta = llm_config.rope_local_base_freq
    config_local.rope_scaling = {"rope_type": "default"}
    swa_position_embeddings = llm_create_position_embeddings(config=config_local,
                                                                  position_ids=padded_position_ids)

    # Computing the sliding_cache_indices: the columns of the global causal mask that form the
    # sliding-window mask -- prefix_kv_length "left" columns followed by
    # sliding_window - prefix_kv_length "right" columns (sliding_window columns in total).
    if enable_right_padding:
        offset = max(0, cache_index.item() + ARN - llm_config.sliding_window)
        prefix_kv_length = getattr(llm_config, "prefix_kv_length", 0)
        sliding_cache_indices_left = torch.arange(0, prefix_kv_length)
        sliding_cache_indices_right = torch.arange(offset+prefix_kv_length, offset+llm_config.sliding_window)
    else:
        # for left padding, during quantization the prefix_kv_len should be 0.
        prefix_kv_length = getattr(llm_config, "prefix_kv_length", 0)
        sliding_cache_indices_left = torch.arange(context_length-kv_length-prefix_kv_length, context_length-kv_length)

        # for the right indices, choose the sliding window length worth of positions, shifting by prefix kv if present.
        sliding_cache_indices_right = torch.arange(context_length-llm_config.sliding_window+prefix_kv_length, context_length)
    sliding_cache_indices = torch.cat([sliding_cache_indices_left, sliding_cache_indices_right], dim=-1).to(device)
    
    swa_attention_mask = prepared_causal_mask[..., sliding_cache_indices]
    
    # Key order matters (see docstring).
    prepared_inputs = {
        'attention_mask': prepared_causal_mask,
        'position_ids': prepared_position_embeddings,
        'past_key_values': padded_past_kv_in,
        'inputs_embeds': pad_input_embeds,
    }

    if enable_right_padding:
        prepared_inputs.update({'cache_index': cache_index})
        prepared_inputs.update({'swa_cache_index': swa_cache_index})
    prepared_inputs.update({'swa_attention_mask': swa_attention_mask})
    prepared_inputs.update({'swa_position_ids': swa_position_embeddings})

    return prepared_inputs

##### 3.5.3 Define forward function for adapted model

In [None]:
from transformers.modeling_outputs import CausalLMOutputWithPast
from genai_lib.llm.static_graph_utils import llm_slice_inputs_for_inference, llm_trim_pad_logits
from genai_lib.llm.dev.model_adaptation.common.utils import llm_update_kv_cache, trim_current_kv
from genai_lib.llm.long_context_utils import llm_scatter_exceeded_kv_using_rotating_eviction, replenish_rotating_index_cache
from gemma3.utils import slice_token_type_ids

# Redefinition of the forward function to work with model I/O adaptations and static shapes of the tensors that the model consumes as input
def adapted_model_forward(
    self,
    input_ids=None,
    pixel_values=None,
    attention_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    token_type_ids=None,
    return_dict=False,
    output_hidden_states=False,
    **kwargs
):
    """Run the adapted (dynamic-shape) or prepared (static-shape) language model over an input sequence.

    The input is embedded once, split into slices by ``llm_slice_inputs_for_inference`` and fed
    slice by slice. KV$ and logits are accumulated outside ``self.model`` between slices.

    Static-shape mode is selected when ``self`` has a ``num_logits_to_return`` attribute: slices are
    ARN tokens long and prepared by ``adapted_model_prepare_inputs_for_static_shapes``. Otherwise the
    whole input is a single slice prepared by ``adapted_model_prepare_inputs_for_dynamic_shapes``,
    and ``self.lm_head`` is applied here.

    Args:
        input_ids: token ids; may be None when ``inputs_embeds`` is given.
        pixel_values: optional image pixel values, passed to ``lmm_preprocess_inputs``.
        attention_mask: unused; kept for HF ``forward`` signature compatibility.
        past_key_values: unpadded KV$ from a previous call, or None to start a new sequence.
        inputs_embeds: optional precomputed embeddings, passed to ``lmm_preprocess_inputs``.
        token_type_ids: optional; only used when the KV$ is empty, to build per-slice token type
            ids for the bidirectional attention mask.
        return_dict: return a ``CausalLMOutputWithPast`` instead of a ``(logits, past_key_values)`` tuple.
        output_hidden_states: unused.
        **kwargs: ``num_slices`` (process at most this many slices), ``embedding_layer`` and
            ``vision_model`` (forwarded to ``lmm_preprocess_inputs``).
    """
    head_dim = llm_config.head_dim if hasattr(llm_config, 'head_dim') else llm_config.hidden_size // llm_config.num_attention_heads
    static_shape = hasattr(self, 'num_logits_to_return')
    num_slices = kwargs.get('num_slices', None)
    embedding_layer = kwargs.get("embedding_layer", None)
    vision_model = kwargs.get("vision_model", None)

    # dictionary to store the running output which contains the logits and the useful past kv cache until that execution
    outputs = {'past_key_values': past_key_values}

    kv_length, _ = llm_get_kv_length(outputs, global_layer_idx, layer_indices_to_perform_sliding_eviction)
    # Start a new sequence with an empty rotating-eviction cache. Also initialise it when the
    # attribute was never set (first call already carrying a KV$); otherwise the read in the
    # static-shape branch below would raise AttributeError.
    if kv_length == 0 or not hasattr(self, 'rotating_eviction_cache'):
        self.rotating_eviction_cache = None

    # generate text + (vision) embeddings
    inputs_embeds = lmm_preprocess_inputs(input_ids=input_ids, pixel_values=pixel_values, inputs_embeds=inputs_embeds,
                                          past_key_values=past_key_values, image_token_index=lmm_config.image_token_index,
                                          embedding_layer=embedding_layer, vision_model=vision_model)

    # Tokens per slice: ARN for the static graph, otherwise the whole input as one slice. Fall back to
    # the embeddings' token dim when the caller passed inputs_embeds instead of input_ids.
    if static_shape:
        max_input_tokens = ARN
    elif input_ids is not None:
        max_input_tokens = input_ids.shape[-1]
    else:
        max_input_tokens = inputs_embeds.shape[1]

    # create the generator which slices input into chunks of AR (and pads if necessary)
    slice_inputs_gen_obj = llm_slice_inputs_for_inference(max_input_tokens=max_input_tokens,
                                                          model_context_len=context_length,
                                                          inputs_embeds=inputs_embeds,
                                                          past_seen_tokens=kv_length)

    # create token type id slices for bidirectional attention mask
    token_type_ids_slices = None
    if token_type_ids is not None and kv_length == 0:
        token_type_ids_slices = slice_token_type_ids(token_type_ids, max_input_tokens=max_input_tokens)

    for i, inputs in enumerate(slice_inputs_gen_obj):
        # Honour num_slices before using the slice, so token_type_ids_slices is never indexed past the cap.
        if num_slices is not None and i >= num_slices:
            break

        inputs_embeds_slice = inputs['inputs_embeds_slice']
        attn_mask_slice = inputs['attn_mask_slice']
        position_ids_slice = inputs['position_ids_slice']
        token_type_ids_slice = token_type_ids_slices[i] if token_type_ids_slices is not None else None

        if static_shape:
            prepared_inputs = adapted_model_prepare_inputs_for_static_shapes(self, inputs_embeds_slice=inputs_embeds_slice,
                                                                             attn_mask_slice=attn_mask_slice,
                                                                             position_ids_slice=position_ids_slice,
                                                                             outputs=outputs, token_type_ids=token_type_ids_slice)
        else:
            prepared_inputs = adapted_model_prepare_inputs_for_dynamic_shapes(self, inputs_embeds_slice=inputs_embeds_slice,
                                                                              attn_mask_slice=attn_mask_slice,
                                                                              position_ids_slice=position_ids_slice,
                                                                              outputs=outputs, token_type_ids=token_type_ids_slice)

        cur_outputs = self.model(**prepared_inputs)
        if not static_shape:
            # Dynamic path: project the first output through lm_head here to obtain logits.
            cur_outputs = (self.lm_head(cur_outputs[0]),) + cur_outputs[1:]

        ############# KV$ management outside the self.model #####################
        outputs['past_key_values'] = llm_update_kv_cache(unpadded_past_kv=outputs['past_key_values'],
                                                         current_key_values=cur_outputs[-1],
                                                         key_concat_axis=KEY_CONCAT_AXIS,
                                                         value_concat_axis=VALUE_CONCAT_AXIS,
                                                         inputs_embeds_slice=inputs_embeds_slice,
                                                         pad_to_left=pad_to_left)

        if static_shape:
            # Replenish the rotating eviction cache as needed
            self.rotating_eviction_cache = replenish_rotating_index_cache(cache_length=llm_config.sliding_window - ARN,
                                                                          num_kv_heads=llm_config.num_key_value_heads,
                                                                          head_dim=head_dim,
                                                                          rotating_eviction_cache=self.rotating_eviction_cache)

            num_sliding_extra_kvs = outputs['past_key_values'][layer_indices_to_perform_sliding_eviction[0]][1].shape[2] - (llm_config.sliding_window - ARN)
            if num_sliding_extra_kvs > 0:
                # if exceeding KV$ for sliding layer, evict the oldest KV$ out
                outputs['past_key_values'], self.rotating_eviction_cache = llm_scatter_exceeded_kv_using_rotating_eviction(self.rotating_eviction_cache,
                                                                                                                           outputs['past_key_values'],
                                                                                                                           num_sliding_extra_kvs,
                                                                                                                           KEY_CONCAT_AXIS,
                                                                                                                           VALUE_CONCAT_AXIS,
                                                                                                                           layer_indices_to_perform_sliding_eviction)
        else:
            # assuming old KV$ is on the left side, we want to remove the older KV from the left, the size left after trimming should be sliding_window-1.
            # The tensor only conveys that target size, so allocate it uninitialised instead of
            # drawing from (and advancing) the global RNG with torch.randn.
            input_tensor = torch.empty((1, llm_config.sliding_window-1))
            outputs['past_key_values'] = trim_current_kv(outputs['past_key_values'], input_tensor, KEY_CONCAT_AXIS, VALUE_CONCAT_AXIS,
                                                         layer_indices_to_perform_trimming=layer_indices_to_perform_sliding_eviction)

        ############# Logit management outside the self.model #####################

        lm_logits = llm_trim_pad_logits(cur_logits=cur_outputs[0],
                                        inputs_embeds_slice=inputs_embeds_slice,
                                        pad_to_left=pad_to_left)
        bsz, _, dim = lm_logits.shape
        prev_logits = outputs.get('logits')
        if prev_logits is None:
            # Empty accumulator matching the logits' dtype and device, so torch.cat never promotes
            # low-precision logits to float32.
            prev_logits = lm_logits.new_zeros((bsz, 0, dim))
        outputs['logits'] = torch.cat((prev_logits, lm_logits), dim=1)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=outputs.get('loss', None),
            logits=outputs.get('logits', None),
            past_key_values=outputs.get('past_key_values', None),
            hidden_states=None,
            attentions=None,
        )
    return outputs['logits'], outputs['past_key_values']

##### 3.6 Adapted FP model eval

In [None]:
import types

if run_ppl_eval:
    # Temporarily route language_model through the adapted (static-shape) forward for this eval.
    model.language_model.forward = types.MethodType(adapted_model_forward, model.language_model)
    try:
        with event_marker("Adapted FP model eval", flush_ram=True):
            with place_model(model.language_model, torch.device('cuda')), place_model(vision_model, torch.device("cuda")), place_model(embedding_layer, torch.device("cuda")):
                adapted_ppl = llm_evaluate_ppl_with_dataloader(model=model.language_model, dataloader=test_dataloader,
                                                               model_forward_kwargs={"embedding_layer": embedding_layer, "vision_model": vision_model})
            llm_lib_log_metric(ModelType.adapted_model, Metric.ppl, adapted_ppl, model_name="base")
            print(f"PPL score of Adapted HF FP model = {adapted_ppl}")
    finally:
        # Revert forward passes for model preparation. Done in `finally` so a failed eval does not
        # leave the adapted forward installed and silently change what model preparation traces.
        model.language_model.forward = types.MethodType(Gemma3ForCausalLM.forward, model.language_model)

### 4. Model Sample Input

In [None]:
from aimet_torch.utils import change_tensor_device_placement

def get_dummy_data(device="cuda", dtype=torch.float32, return_dict=False):
    """Build one synthetic batch of inputs for tracing / preparing the adapted language model.

    A random ``inputs_embeds`` tensor of shape (1, ARN, hidden_size), an all-ones attention mask
    and positions 0..ARN-1 (in ``dtype``) are passed through
    ``adapted_model_prepare_inputs_for_static_shapes`` with no past key/values.

    Args:
        device: device on which the tensors are created and finally placed.
        dtype: dtype of the embeddings, attention mask and position ids.
        return_dict: if True return the prepared inputs mapping, otherwise a tuple of its values.
    """
    embeds = torch.rand(1, ARN, llm_config.hidden_size, device=device, dtype=dtype)
    mask = torch.ones((1, ARN), device=device, dtype=dtype)
    positions = mask.cumsum(dim=1) - 1

    prepared_inputs = adapted_model_prepare_inputs_for_static_shapes(
        model.language_model, embeds, mask, positions, {'past_key_values': None})

    # Re-place every entry (including nested past key/values) on the requested device.
    for input_name, input_value in prepared_inputs.items():
        prepared_inputs[input_name] = change_tensor_device_placement(input_value, device)

    return prepared_inputs if return_dict else tuple(prepared_inputs.values())

### 5. Prepare model using QAIRT model preparer pro

#### 5.1 KVCache MHA model preparation

In [None]:
import time
from qti.aisw.preparer_api import prepare_model
from qti.aisw.emitter.utils.torch_utils import load_torch_model_using_safetensors

from genai_lib.llm.model_preparation_utils import llm_build_preparer_converter_args
from genai_lib.llm.utils import llm_model_input_output_names

model.language_model.num_logits_to_return = ARN # configuring the model for KVCache mode

skip_prepare = False  # whether to skip model prepare and use existing prepared model
if skip_prepare:
    # Must point at a folder that already contains the prepared artifacts.
    prepare_path = "<path to existing prepared model folder when skip_prepare=True>"
else:
    prepare_path = os.path.join(output_dir, 'prepare')
    # Only create the folder when we are about to write into it; creating it in the skip path
    # would just leave an empty directory named after the placeholder string.
    os.makedirs(prepare_path, exist_ok=True)
prepare_filename = f'{model_name}_kvcache_{llm_config.num_hidden_layers}_layer'

if skip_prepare:
    with event_marker(f"KVCache load pre-prepared {prepare_filename}", flush_ram=True):
        prepared_model_path = os.path.join(prepare_path, f'{prepare_filename}.py')
        if not os.path.exists(prepared_model_path):
            raise ValueError(f"prepared artifacts not found in {prepare_path}")
        print(f'WARNING: preparation skipped for model={prepare_filename}, prepared at {time.ctime(os.path.getmtime(prepared_model_path))}')
        prepared_model = load_torch_model_using_safetensors(path=prepare_path, filename=prepare_filename, model_name=prepare_filename)

else:
    dummy_input = get_dummy_data(device=model.language_model.model.device, dtype=model.language_model.dtype, return_dict=True)
    input_names, output_names = llm_model_input_output_names(llm_config.num_hidden_layers, use_input_embedding = True)
    if enable_right_padding:
        input_names += ["cache_index"]
        input_names += ["swa_cache_index"]
    input_names += ['swa_attention_mask']
    input_names += ['swa_position_ids']  # sliding window attention RoPE embeddings
    # Build converter args
    # TODO: temporary fix - SWA RoPE embeddings (swa_position_ids) manually added to llm_build_preparer_converter_args in genai-lib (get the fix from the new commits)
    converter_args = llm_build_preparer_converter_args(llm_config.num_hidden_layers, input_names, use_qairt_mpp=True)
    with event_marker("KVCache prepare model", flush_ram=True):
        if __name__ == '__main__': # We use the main guard to prevent child processes from re-running the top-level code
            prepared_model = prepare_model(model.language_model,
                                           dummy_input,
                                           model_name=prepare_filename,
                                           filename=prepare_filename,
                                           path=prepare_path,
                                           input_names=input_names,
                                           output_names=output_names,
                                           onnx_export_args={"opset_version":17},
                                           keep_original_model_structure=False, # flatten the model to enable weight-sharing
                                           converter_args=converter_args,
                                           order_inputs=True,
                                           order_outputs=True,
                                           skipped_optimizers=['eliminate_common_subexpression',
                                                               'eliminate_nop_with_unit',
                                                               'eliminate_duplicate_initializer'
                                                              ],
                                           return_prepare_model=True,
                                           )
        else:
            raise Exception("Killing multiprocessing spawn started by Converter during model preparation.")

### 6. Evaluation of prepared models

#### 6.1 Changes to HuggingFace model to work with the prepared model
Replace the decoder inside the HuggingFace model with the prepared model.
Note that the prepared model already fuses `model.language_model.model` and `model.language_model.lm_head`
into one module, so here we simply set `model.language_model.lm_head` to `None`.

In [None]:
# The prepared model already fuses the decoder and the LM head into one module, so drop the
# original HF submodules before plugging the prepared model in.
del model.language_model.model
del model.language_model.lm_head

model.language_model.model = prepared_model
# Clear the LM head on language_model (the module whose lm_head was just deleted). The previous
# code set `model.lm_head` on the outer wrapper instead, which left `language_model` with no
# `lm_head` attribute at all, so any access to it raised AttributeError instead of seeing None.
model.language_model.lm_head = None

# Route language_model calls through adapted_model_forward (defined in an earlier cell).
model.language_model.forward = types.MethodType(adapted_model_forward, model.language_model)

#### 6.2 Evaluation of perplexity score using prepared model

In [None]:
if run_ppl_eval:
    with event_marker("KVcache prepared FP eval", flush_ram=True):
        # Put the prepared LM and the vision / embedding front-end on GPU for this evaluation
        # (place_model presumably restores the original placement on exit).
        with place_model(prepared_model, torch.device("cuda")):
            with place_model(vision_model, torch.device("cuda")):
                with place_model(embedding_layer, torch.device("cuda")):
                    # Ensure the HF wrapper dispatches to the prepared model.
                    model.language_model.model = prepared_model
                    prepared_kvcache_ppl = llm_evaluate_ppl_with_dataloader(
                        model=model.language_model,
                        dataloader=test_dataloader,
                        model_forward_kwargs={"embedding_layer": embedding_layer,
                                              "vision_model": vision_model},
                    )
        llm_lib_log_metric(ModelType.prepared_model, Metric.ppl, prepared_kvcache_ppl, model_name="base")
        print(f"ppl score of KVCACHE prepared fp model = {prepared_kvcache_ppl}")

### 7. Quantization

The _Quantization_ step is the primary focus of this notebook; this section can be modified to run various quantization experiments.


In [None]:
# Config for quantization

# --- Sequential MSE (SeqMSE) ---
apply_decoder_seqmse = True     # run SeqMSE on decoder-layer Conv2d weights
apply_lm_head_seqmse = False    # run SeqMSE on LM-head Conv2d weights
num_seqmse_candidates = 60      # passed to SeqMseParams(num_candidates=...)
num_seqmse_batches = 20         # passed to SeqMseParams(num_batches=...)

# --- Low Power Blockwise Quantization (LPBQ) ---
apply_decoder_lpbq = False
apply_lm_head_lpbq = False

# --- Calibration and mixed precision ---
num_calibration_batches = 20    # number of batches used when computing encodings
embedding_table_bitwidth = 16   # param bitwidth override for the embedding Gather

#### 7.1 Create quantsim configured for QNN HTP target 


In [None]:
import inspect
from copy import deepcopy
from aimet_common.defs import QuantScheme
from aimet_torch.v2.quantsim import QuantizationSimModel

if apply_lm_head_seqmse or apply_decoder_seqmse:
    import functools

    def copy_model_with_shared_weights(source_model):
        """Deep-copy ``source_model``, then point the copy's parameters at the source's Parameter objects.

        The copy gets its own module objects, but each parameter reported by
        ``source_model.named_parameters()`` is shared with the source instead of duplicated.
        """
        target_model = deepcopy(source_model)
        for param_name, shared_param in source_model.named_parameters():
            owner_path, _, attr_name = param_name.rpartition('.')
            owner = functools.reduce(getattr, owner_path.split('.'), target_model) if owner_path else target_model
            setattr(owner, attr_name, shared_param)
        return target_model

    # FP reference copy of the prepared model, used later by SeqMSE.
    fp_prepared_model = copy_model_with_shared_weights(prepared_model)

dummy_input = get_dummy_data(device="cuda", dtype=model.language_model.dtype, return_dict=True)

# Order the dummy inputs to match the prepared model's forward() signature.
sig = inspect.signature(prepared_model.forward)
dummy_input_sorted = {param_name: dummy_input[param_name] for param_name in sig.parameters}
dummy_input = tuple(dummy_input_sorted.values())

with event_marker("create KVCache Quantsim"), place_model(prepared_model, "cuda"):
    quantsim = QuantizationSimModel(
        model=prepared_model,
        dummy_input=dummy_input,
        quant_scheme=QuantScheme.post_training_tf,
        config_file=htp_config_file,
        default_param_bw=4,     # weights default to 4 bits (mixed-precision overrides come later)
        default_output_bw=16,   # activations default to 16 bits
        in_place=True,          # build the quantsim on prepared_model instead of a copy
    )

#### 7.2 Setting 16bit x 8bit matmuls
Keep the key and value tensors at 8 bits to reduce the data I/O cost associated with KV-cache orchestration.

In [None]:
from aimet_torch.v2.experimental.quantsim_utils import set_matmul_second_input_producer_to_8bit_symmetric

# Per the helper's name and the section note: make the quantizer that produces each MatMul's
# second input 8-bit symmetric, so K/V tensors feed the matmuls at 8 bits (16-bit x 8-bit matmuls).
# NOTE(review): exact op selection is inside the AIMET helper — confirm against its docs.
set_matmul_second_input_producer_to_8bit_symmetric(quantsim)

#### 7.3 Concat encoding unification
Configure concat ops to share a single encoding across their input and output activations.

In [None]:
from aimet_torch.v2.experimental import propagate_output_encodings
from aimet_torch.nn.modules import custom as aimet_ops

# Propagate each Concat's output encoding to its inputs so input and output activations share
# one encoding (as described in the section note). NOTE(review): propagation details are
# implemented by the AIMET helper — confirm against its docs.
propagate_output_encodings(quantsim, aimet_ops.Concat)

#### 7.4 Manual Mixed Precision
Apply the manual mixed-precision configuration to ops.

In [None]:
import json
from llm_utils.mixed_precision_overrides import ManualQuantsimMixedPrecisionConfig
from aimet_torch.v2.nn.modules.custom import QuantizedRmsNorm
from aimet_torch.v2.quantization.affine import QuantizeDequantize

def apply_manual_mixed_precision(sim):
    """Apply the hand-tuned mixed-precision exceptions to ``sim`` in place.

    Steps:
      1. Load ``./config/mixed_precision_config/exceptions.json``.
      2. Override the param bitwidth of the first ``model_embed_tokens_Gather`` entry with
         the global ``embedding_table_bitwidth``.
      3. Apply all exceptions via ``ManualQuantsimMixedPrecisionConfig``.
      4. Replace every RMSNorm weight quantizer with a 16-bit asymmetric per-tensor one.
    """
    with open("./config/mixed_precision_config/exceptions.json", "r") as config_handle:
        precision_cfg = json.load(config_handle)

    # Customize mixed precision llm_config based on user parameters (first match only).
    embedding_entry = next(
        (item for item in precision_cfg['name_list'] if "model_embed_tokens_Gather" in item['module_name']),
        None,
    )
    if embedding_entry is not None:
        embedding_entry['exceptions']['param_exceptions']['bitwidth'] = embedding_table_bitwidth

    # NOTE(review): the parsed dict (not a file path) is passed despite the keyword name;
    # presumably the helper accepts either — confirm.
    ManualQuantsimMixedPrecisionConfig(mixed_precision_config_file=precision_cfg).apply_exceptions(sim)

    # Make RMSNorm encodings per-tensor (they default to per-channel): shape=() means one encoding.
    rms_norm_qmodules = (qmod for _, qmod in sim.named_qmodules() if isinstance(qmod, QuantizedRmsNorm))
    for rms_qmodule in rms_norm_qmodules:
        per_tensor_quantizer = QuantizeDequantize(shape=(), bitwidth=16, symmetric=False)
        rms_qmodule.param_quantizers['weight'] = per_tensor_quantizer.to(rms_qmodule.weight.device)

apply_manual_mixed_precision(quantsim)

#### 7.5 Apply Block Quantization

Swap the weight quantizers of the selected modules to LPBQ quantizers.

In [None]:
from aimet_torch.v2.nn.true_quant import QuantizedConv2d
from aimet_torch.v2.quantsim.config_utils import set_grouped_blockwise_quantization_for_weights

def apply_lpbq(sim):
    """Switch selected Conv2d weight quantizers in ``sim`` to grouped-blockwise (LPBQ) quantizers.

    Module selection follows the global flags:
      * ``apply_decoder_lpbq`` and ``apply_lm_head_lpbq``: every QuantizedConv2d.
      * ``apply_decoder_lpbq`` only: QuantizedConv2d whose weight quantizer is currently 4-bit.
      * ``apply_lm_head_lpbq`` only: QuantizedConv2d modules whose qmodule name contains "lm_head".
    Does nothing when both flags are False.
    """
    if not (apply_decoder_lpbq or apply_lm_head_lpbq):
        return

    if apply_decoder_lpbq and apply_lm_head_lpbq:
        def use_lpbq(module):
            return isinstance(module, QuantizedConv2d)
    elif apply_decoder_lpbq:
        def use_lpbq(module):
            return isinstance(module, QuantizedConv2d) and module.param_quantizers['weight'].bitwidth == 4
    else:
        lm_head_qmodules = [qmod for qname, qmod in sim.named_qmodules() if "lm_head" in qname]

        def use_lpbq(module):
            return module in lm_head_qmodules and isinstance(module, QuantizedConv2d)

    BLOCK_QUANT_SIZE = 64
    set_grouped_blockwise_quantization_for_weights(
        sim=sim,
        arg=use_lpbq,
        bitwidth=4,
        symmetric=True,
        decompressed_bw=8,
        block_size=BLOCK_QUANT_SIZE,
        block_grouping=-1,
    )


if apply_decoder_lpbq or apply_lm_head_lpbq:
    apply_lpbq(quantsim)

#### 7.7 Sequential MSE
Apply the Sequential MSE technique to optimize parameter encodings.

In [None]:
import math
from aimet_torch.v2.seq_mse import apply_seq_mse, SeqMseParams

def perform_seqmse(sim, fp_model):
    """Run Sequential MSE on ``sim`` using ``fp_model`` as the floating-point reference.

    Conv2d modules of ``fp_model`` whose name contains "lm_head" are treated as the LM head and all
    other Conv2d modules as the decoder. The global flags ``apply_decoder_seqmse`` /
    ``apply_lm_head_seqmse`` choose which group is optimized (the other group is excluded).

    Raises:
        ValueError: if neither flag is set. Previously this case surfaced as an UnboundLocalError on
            ``modules_to_exclude``, after the models had already been moved to GPU.
    """
    if not (apply_decoder_seqmse or apply_lm_head_seqmse):
        raise ValueError("perform_seqmse requires apply_decoder_seqmse and/or apply_lm_head_seqmse to be True")

    def _seq_mse_forward_fn(_model, inputs):
        # Plug the module SeqMSE wants to run (FP or quantized) into the HF wrapper.
        model.language_model.model = _model
        # Build a fresh kwargs dict rather than inputs.update(...): the batch may be owned by the
        # dataset/dataloader, and mutating it would store module references inside the data.
        model.language_model(**{**inputs, "embedding_layer": embedding_layer, "vision_model": vision_model})

    lm_head_fp_modules = [module
                          for module_name, module in fp_model.named_modules()
                          if isinstance(module, torch.nn.Conv2d) and 'lm_head' in module_name]
    decoder_fp_modules = [module
                          for module_name, module in fp_model.named_modules()
                          if isinstance(module, torch.nn.Conv2d) and 'lm_head' not in module_name]

    if apply_decoder_seqmse and apply_lm_head_seqmse:
        modules_to_exclude = []
    elif apply_decoder_seqmse:
        modules_to_exclude = lm_head_fp_modules
    else:  # only apply_lm_head_seqmse is set (guaranteed by the check above)
        modules_to_exclude = decoder_fp_modules

    seqmse_params = SeqMseParams(num_batches=num_seqmse_batches,
                                 inp_symmetry='symqt',
                                 num_candidates=num_seqmse_candidates,
                                 loss_fn='mse',
                                 forward_fn=_seq_mse_forward_fn)

    with place_model(sim.model, torch.device("cuda")), place_model(fp_model, torch.device("cuda")), \
         place_model(embedding_layer, torch.device("cuda")), place_model(vision_model, torch.device("cuda")):
        with torch.no_grad():
            apply_seq_mse(fp_model, sim, llava_dataset, seqmse_params, modules_to_exclude=modules_to_exclude)
            # Alternative data source:
            # apply_seq_mse(fp_model, sim, train_dataloader, seqmse_params, modules_to_exclude=modules_to_exclude)


if apply_decoder_seqmse or apply_lm_head_seqmse:
    with event_marker("SeqMSE for base model"):
        perform_seqmse(quantsim, fp_prepared_model)

    # The FP reference copy is no longer needed once SeqMSE is done.
    del fp_prepared_model

#### 7.8 Calibration


In [None]:
from itertools import islice
from tqdm import tqdm
from aimet_torch.v2.experimental.quantsim_utils import clip_weights_to_7f7f

def perform_calibration(sim, calibration_dataloader, num_batches=200):
    """Compute encodings for ``sim`` on up to ``num_batches`` calibration batches, then clip weights.

    Args:
        sim: quantsim whose ``sim.model`` is plugged into ``model.language_model`` for each pass.
        calibration_dataloader: iterable of per-batch kwargs dicts for ``model.language_model``.
        num_batches: maximum number of batches to run through the model.
    """
    def _calibration_forward_fn(sim_model, kwargs):
        model.language_model.model = sim_model
        data_loader = kwargs['data_loader']
        max_iterations = kwargs['num_batches']
        # islice stops after max_iterations batches instead of fetching (and discarding) one extra
        # batch before breaking; max(..., 0) keeps islice valid for non-positive limits.
        for batch in tqdm(islice(data_loader, max(max_iterations, 0)), total=max_iterations):
            inputs = change_tensor_device_placement(batch, device="cuda")
            # Fresh kwargs dict: avoid mutating a batch object that may belong to the dataset.
            _ = model.language_model(**{**inputs, "embedding_layer": embedding_layer, "vision_model": vision_model})

    kwargs = {
        'data_loader': calibration_dataloader,
        'num_batches': num_batches
    }

    with place_model(sim.model, "cuda"), place_model(embedding_layer, torch.device("cuda")), place_model(vision_model, torch.device("cuda")):
        with torch.no_grad():
            sim.compute_encodings(_calibration_forward_fn, kwargs)

    clip_weights_to_7f7f(sim)


with event_marker("compute encoding for base model", flush_ram=True):
    perform_calibration(quantsim, llava_dataset, num_batches=num_calibration_batches)

#### 7.9 Eval KV Cache quantsim model

In [None]:
if run_ppl_eval:
    with event_marker("KV cache quantsim model eval", flush_ram=True):
        # Put the quantized LM and the vision / embedding front-end on GPU for this evaluation
        # (place_model presumably restores the original placement on exit).
        with place_model(quantsim.model, torch.device("cuda")):
            with place_model(vision_model, torch.device("cuda")):
                with place_model(embedding_layer, torch.device("cuda")):
                    # Evaluate through the HF wrapper with the quantsim model plugged in.
                    model.language_model.model = quantsim.model
                    sim_ppl = llm_evaluate_ppl_with_dataloader(
                        model=model.language_model,
                        dataloader=test_dataloader,
                        model_forward_kwargs={"embedding_layer": embedding_layer,
                                              "vision_model": vision_model},
                    )
        llm_lib_log_metric(ModelType.qsim_model, Metric.ppl, sim_ppl, model_name="base")
        print(f"ppl score of KVCACHE quantsim model = {sim_ppl}")

### 8. Export
The cells below export the ONNX model, the encodings, and test vectors for the KVCache model.

#### 8.1 Export Onnx and Encodings

In [None]:
from aimet_torch import onnx_utils
from aimet_torch.onnx_utils import OnnxExportApiArgs

def export_onnx_and_encodings(sim, onnx_dir, filename_prefix, generate_updatable_tensors = False):
    """Export ``sim`` as an ONNX model plus its encodings into ``onnx_dir``.

    Args:
        sim: quantsim to export; ``sim.model`` is placed on CPU for the export.
        onnx_dir: output directory (created if missing).
        filename_prefix: prefix for both the ONNX model and the encodings files.
        generate_updatable_tensors: currently unused; kept for call-site compatibility.
    """
    # Get input names and output names. This is different from the input names and output names we created for model preparation.
    # The reason for this difference stems from the fact that we want the prepared model to have inputs and outputs named similar to original HF model
    # ONNX does not allow tupling the inputs or outputs and we want to give meaningful names to the input and output tensors in the ONNX graph
    input_names, output_names = llm_model_input_output_names(llm_config.num_hidden_layers,
                                                             use_position_embedding_input=True,
                                                             separate_tuple_input_output=True,
                                                             use_input_embedding=True)
    if enable_right_padding:
        input_names += ["cache_index"]
        input_names += ["swa_cache_index"]
    input_names += ["swa_attention_mask"]
    input_names += ['swa_position_ids_cos', 'swa_position_ids_sin']  # sliding window attention RoPE embeddings

    # Replace past_ with swa_ in the input and output names of sliding-window-attention layers.
    # A layer keeps its past_* names only when (kv_index + 1) is a multiple of sliding_window_pattern
    # (assuming the numeric token in the name is the 0-based layer index).
    for index, name in enumerate(input_names + output_names):
        if name.startswith("past_key") or name.startswith("past_value"):
            kv_index = int([i for i in name.split("_") if i.isdigit()][0])
            if ((kv_index + 1) % llm_config.sliding_window_pattern) != 0:
                if index < len(input_names):
                    input_names[index] = name.replace("past_", "swa_")
                else:
                    output_names[index - len(input_names)] = name.replace("past_", "swa_")

    # NOTE(review): an earlier comment here claimed "setting this flag to False flattens the
    # prepared model", which does not match the value set; confirm the flag's semantics for the
    # AIMET version in use before relying on it.
    onnx_utils.EXPORT_TO_ONNX_DIRECT = True

    onnx_api_args = OnnxExportApiArgs(input_names=input_names, output_names=output_names, opset_version=17)
    onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = True

    # Order the dummy inputs to match sim.model.forward()'s signature.
    dummy_input = get_dummy_data(device="cpu", dtype=model.language_model.dtype, return_dict=True)
    dummy_input_sorted = {}
    sig = inspect.signature(sim.model.forward)
    for key in list(sig.parameters.keys()):
        dummy_input_sorted[key] = dummy_input[key]
    dummy_input = tuple(dummy_input_sorted.values())

    os.makedirs(onnx_dir, exist_ok=True)
    with torch.no_grad():
        with place_model(sim.model, torch.device("cpu")):
            sim.export(onnx_dir, filename_prefix, dummy_input, onnx_export_args=onnx_api_args,
                       export_model=True, filename_prefix_encodings=filename_prefix)


with event_marker("KVCache export onnx and encodings for base generation model", flush_ram=True):
    base_onnx_dir = os.path.join(output_dir, 'export', 'onnx')
    base_filename_prefix = f"{model_name}"
    export_onnx_and_encodings(quantsim, base_onnx_dir, base_filename_prefix)

# Exporting Tokenizer
tokenizer_dir = os.path.join(output_dir, 'tokenizer')
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer.save_pretrained(tokenizer_dir)

# Export embedding bin: the scaled embedding table (weight * embed_scale) as raw float32.
# The explicit float32 cast makes the file match its "fp32" name when the model runs in
# fp16/bf16. Without it, fp16 would be written as 2-byte values, and bf16 would fail in .numpy()
# because NumPy has no bfloat16 dtype.
with torch.no_grad():
    inputs_embeds = embedding_layer.weight * embedding_layer.embed_scale
lut_array = inputs_embeds.detach().to(torch.float32).cpu().numpy()
lut_array.tofile(os.path.join(output_dir, "embedding_fp32.bin"))

#### 8.2 Generating test vectors

In [None]:
from itertools import islice
from genai_lib.llm.test_vectors import generate_test_vectors

def generate_test_vectors_for_usecase(sim, output_dir, num_test_vectors = 1, slice_num = 0):
    """Dump test vectors for the first ``num_test_vectors`` batches of ``train_dataloader``.

    Each batch's first ARN token ids are embedded, wrapped into static-shape model inputs with no
    past KV, and handed to ``generate_test_vectors`` for the layers matched by ``test_vector_layers``.

    Args:
        sim: quantsim whose model produces the test vectors.
        output_dir: directory passed through to ``generate_test_vectors``.
        num_test_vectors: number of batches to process (non-positive values process none).
        slice_num: currently unused; kept for call-site compatibility.
    """
    test_vector_layers = [
        "model_embed_tokens_Gather",
        "model_layers_\\d+_Add_1"
    ]

    device = torch.device('cuda')
    with place_model(sim.model, device), place_model(vision_model, device), place_model(embedding_layer, device):
        # islice stops after num_test_vectors batches instead of fetching one extra batch just to break.
        for index, batch in enumerate(islice(train_dataloader, max(num_test_vectors, 0))):
            # Text-only inputs here (original note: "Consider LLM data (wikitext) here").
            inputs_embeds_slice = embedding_layer(batch['input_ids'][..., :ARN].to(device=device))
            attention_mask_slice = torch.ones((inputs_embeds_slice.shape[0], ARN), dtype=torch.long, device=device)
            position_ids_slice = torch.cumsum(attention_mask_slice, dim=1) - 1
            outputs = {'past_key_values': None}
            model_inputs = adapted_model_prepare_inputs_for_static_shapes(model.language_model,
                                                                          inputs_embeds_slice=inputs_embeds_slice,
                                                                          attn_mask_slice=attention_mask_slice,
                                                                          position_ids_slice=position_ids_slice,
                                                                          outputs=outputs)

            generate_test_vectors(sim=sim, model_inputs=model_inputs, output_dir=output_dir,
                                  batch_index=index, test_vector_layers=test_vector_layers)

with event_marker("generate base model test vectors"):
    test_vec_dir = os.path.join(output_dir, 'export', 'test_vectors')
    generate_test_vectors_for_usecase(quantsim, os.path.dirname(test_vec_dir),
                                      slice_num=1 + (llm_config.sliding_window + ARN - 1) // ARN)

#### 8.3 Save Quantsim Model

In [None]:
import pickle as pkl

# Pickling the full quantsim model recurses deeply; raise the limit so the dump does not hit
# Python's default recursion limit.
sys.setrecursionlimit(100000)

quantsim_pickle_path = f"{output_dir}/{prepare_filename}.pkl"
with event_marker("save quantsim model"):
    with open(quantsim_pickle_path, 'wb') as file:
        pkl.dump(quantsim, file)
    

### Summary

In [None]:
from genai_lib.common.debug.profiler import EventProfiler
from genai_lib.common.debug.recipe_logger import dump_logs_to_json

# Report the profiling stats, presumably those collected by the event_marker(...) blocks above.
# NOTE(review): EventProfiler() is constructed twice; this only works if it returns shared
# (e.g. singleton) state — confirm before collapsing into a single instance.
EventProfiler().report()
# Persist the same profiling stats as JSON next to the other artifacts.
EventProfiler().json_dump(os.path.join(output_dir, 'profiling_stats.json'))
# Write the recipe logs (presumably including metrics from llm_lib_log_metric) to JSON.
dump_logs_to_json()

Copyright (c) 2024 Qualcomm Technologies, Inc. and/or its subsidiaries.