# SpinQuant-R2 Optimization for Gemma3-4B language model

---
### Required packages
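The notebook assumes AIMET and the Gemma3-related packages are already installed. When run inside Jupyter, the cell below additionally installs `libc++-dev` and pins the `transformers` and `tokenizers` versions.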

In [None]:
# Install packages only when running under IPython/Jupyter.
# The check looks for the `__IPYTHON__` attribute on `__builtins__`; the `!` lines
# below are IPython shell escapes, not plain Python.
# NOTE(review): `sudo -H pip` installs into whichever environment the system `pip`
# targets — confirm that this is the same environment the kernel runs in.
if hasattr(__builtins__,'__IPYTHON__'):
    # Refresh the apt package index, then install the libc++ development package.
    !sudo -H apt-get -qq update
    !sudo -H apt-get -qq install libc++-dev
    # Pin exact versions of transformers and tokenizers (each in its own pip invocation).
    !sudo -H pip install --quiet --upgrade --root-user-action=ignore --no-cache-dir transformers==4.50.0
    !sudo -H pip install --quiet --upgrade --root-user-action=ignore --no-cache-dir tokenizers==0.21.4

### Overall flow
This notebook covers the following:
1. Instantiate model and dataloaders
2. Apply SpinQuant-R2
3. Save optimized model

### 1. Instantiate model and dataloaders

In [None]:
# NOTE(review): `sys`, `copy` and `tqdm` are not referenced by any cell visible in
# this notebook; they are kept in case other tooling relies on them.
import sys, os
import copy
from tqdm import tqdm
import torch

from genai_lib.common.debug.recipe_logger import recipe_dump_init
from genai_lib.common.debug.recipe_logger import llm_lib_log_env_info
from transformers import AutoConfig, AutoTokenizer, AutoProcessor

# Seed via the transformers helper before any model or data is created.
from transformers import set_seed
set_seed(0)

#====================== Settings configurable by users ================================
# context_length: used to cap tokenizer.model_max_length, passed to the Wikitext
#                 dataloader builder, and logged as a recipe property.
# run_ppl_eval: gates both perplexity evaluations (before and after SpinQuant-R2).
# run_spinquant_r2: gates the SpinQuant-R2 rotation step.
context_length = 8192
run_ppl_eval = True
run_spinquant_r2 = True

# cache_dir is forwarded as `cache_dir=` to every from_pretrained call and to the dataloader builder.
cache_dir='/tmp/cache_dir'
output_dir = '/tmp/output_dir'  # directory where the export artifacts of this notebook are saved
os.makedirs(output_dir, exist_ok=True)

# HF configs
# NOTE(review): model_name is not referenced by any other cell visible in this notebook.
model_name = 'gemma_4b'
model_id="google/gemma-3-4b-it"  # HF checkpoint

# Top-level config for the checkpoint; its `text_config` attribute is later passed to apply_spinquant_r2.
llm_config = AutoConfig.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True)


# Recipe_logger: initialize the logger with output_dir and log environment details.
recipe_dump_init(output_dir)
llm_lib_log_env_info()

#### 1.1 Instantiate the HuggingFace model

In [None]:
import torch
from transformers.models.gemma3 import modeling_gemma3
from genai_lib.common.debug.profiler import event_marker

# Load the model, tokenizer and processor for `model_id` under a single
# 'Load FP model' event marker from the profiler module.
with event_marker('Load FP model'):
    # The full Gemma3ForConditionalGeneration model is loaded; later cells operate
    # on its `language_model` submodule only.
    model = modeling_gemma3.Gemma3ForConditionalGeneration.from_pretrained(model_id, config=llm_config, cache_dir=cache_dir)

    # Set TOKENIZERS_PARALLELISM to '0' before the tokenizer is created.
    os.environ['TOKENIZERS_PARALLELISM'] = '0'
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir, use_fast=True, trust_remote_code=True)
    # NOTE(review): `processor` is not referenced by any other cell visible in this notebook.
    processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True)
    # Cap the tokenizer's maximum length at context_length.
    tokenizer.model_max_length = context_length

#### 1.2 Instantiate Dataloaders

In [None]:
from llm_utils.wikitext_dataloader import get_wiki_dataset

# Wikitext dataloaders, built from context_length and the tokenizer loaded above.
# The helper returns three values; the third is discarded via `_`.
# NOTE(review): train_dataloader is not used by any other cell visible in this
# notebook; only test_dataloader feeds the perplexity evaluations.
train_dataloader, test_dataloader, _ = get_wiki_dataset(context_length, tokenizer, cache_dir=cache_dir)

#### 1.3 HuggingFace FP model eval

In [None]:
from aimet_torch.utils import place_model
from genai_lib.llm.evaluation_utils import llm_evaluate_ppl_with_dataloader
from genai_lib.common.debug.recipe_logger import (
    llm_lib_log_property,
    Property,
    llm_lib_log_metric,
    ModelType,
    Metric,
)

# Recipe_logger: the context_length property is logged unconditionally,
# whether or not the perplexity evaluation below is enabled.
llm_lib_log_property({Property.context_length: context_length})

if run_ppl_eval:
    # One `with` carrying both managers: the event marker is entered first, then
    # the model placement on 'cuda' — the same order as two nested `with` blocks.
    with event_marker("HuggingFace FP model eval"), place_model(model, torch.device('cuda')):
        # Perplexity is computed on the language-model submodule using the test split.
        orig_ppl = llm_evaluate_ppl_with_dataloader(
            model=model.language_model,
            dataloader=test_dataloader,
        )

    # Recipe_logger: record the baseline perplexity under model_name "base".
    llm_lib_log_metric(ModelType.hf_model, Metric.ppl, orig_ppl, model_name="base")
    print(f"PPL score of HuggingFace FP model: {orig_ppl}")

### 2. Apply SpinQuant-R2

#### 2.1 Apply Rotation

Apply SpinQuant-R2 to the v/o-proj weights (the weights are updated in place). This makes the V-proj activations more quantization-friendly: we use an 8-bit KV cache on target, which otherwise causes a large accuracy drop on this model. We do not apply R1 here because Gemma3 has a post-RMSNorm in both the attention and MLP blocks, which makes R1 non-mergeable (i.e., it could only be applied as an online rotation).

In [None]:
from llm_utils.spinquant.spinquant_utils import apply_spinquant_r2

# The rotation is optional and controlled by the run_spinquant_r2 flag.
if run_spinquant_r2:
    # One `with` carrying both managers: the event marker is entered first, then
    # the model placement on 'cuda' — the same order as two nested `with` blocks.
    with event_marker('Apply SpinQuant R2'), place_model(model, torch.device('cuda')):
        # Only the language-model submodule and its text config are passed in.
        apply_spinquant_r2(
            model=model.language_model,
            config=llm_config.text_config,
        )

#### 2.2 Evaluate SpinQuant FP Model

The SpinQuant model (i.e., with R2 merged into v/o-proj) should give the same FP PPL as the original model, up to numerical precision.

In [None]:
# Re-evaluate perplexity on the language-model submodule after the SpinQuant-R2
# step, using the same test dataloader as the baseline evaluation.
if run_ppl_eval:
    with event_marker("SpinQuant FP model eval"):
        with place_model(model, torch.device('cuda')):
            spin_ppl = llm_evaluate_ppl_with_dataloader(model=model.language_model, dataloader=test_dataloader)
    print(f"PPL score of SpinQuant FP model: {spin_ppl}")

    # The rotated model is expected to match the baseline FP perplexity, so show
    # the difference explicitly instead of leaving the reader to compare two
    # numbers printed in different cells. The guard keeps the cell usable when
    # the baseline evaluation cell (which defines orig_ppl) was not executed.
    if 'orig_ppl' in globals():
        print(f"PPL delta vs. HuggingFace FP model: {spin_ppl - orig_ppl}")

### 3. Save optimized model

This notebook exports new model weights that are more quantization-friendly and stores them in the `spinquant` sub-directory of `output_dir`. To use the new weights in the quantization pipeline notebook (`gemma3_4b.ipynb`), pass the path of that `spinquant` directory as the Model ID in that notebook.

In [None]:
# Export directory for the optimized checkpoint: <output_dir>/spinquant.
spinquant_dir = os.path.join(output_dir, 'spinquant')
os.makedirs(spinquant_dir, exist_ok=True)

# Only the language-model submodule and the tokenizer are written; `processor`
# is not saved by this cell.
# This cell is not gated on run_spinquant_r2, so it also runs (and saves the
# unmodified weights) when the rotation step was skipped.
# NOTE(review): the effect of flush_ram=True is defined by event_marker and is
# not visible in this notebook — confirm.
with event_marker("save optimized model", flush_ram=True):
    model.language_model.save_pretrained(spinquant_dir)
    tokenizer.save_pretrained(spinquant_dir)

### Summary

In [None]:
from genai_lib.common.debug.profiler import EventProfiler
# Emit the profiler report, then dump the profiling stats to
# <output_dir>/profiling_stats.json.
# NOTE(review): two separate EventProfiler() objects are constructed; this
# presumably relies on EventProfiler sharing state across instances — confirm.
EventProfiler().report()
EventProfiler().json_dump(os.path.join(output_dir, 'profiling_stats.json'))

Copyright (c) 2024 Qualcomm Technologies, Inc. and/or its subsidiaries.