SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\
SPDX-License-Identifier: Apache-2.0

# Quantization Aware Training Example
---
This notebook shows how to use the Neural Graphics Model Gym to run quantization aware training (QAT) on the Neural Super Sampling (NSS) model, and how to export the trained model to VGF.

### Environment Setup
**Before running the notebook:**

Check the environment prerequisites and follow the setup instructions in the [README.md](../../README.md).

In [None]:
# Import the ng_model_gym package.
import ng_model_gym as ngmg

### Inspect Configuration Parameters

In [None]:
import logging
from pathlib import Path

# Set training/evaluation configuration parameters using a .json file.
# Let's use the tutorial configuration file provided with the repository.
cfg_path = Path("../../assets/nss/tutorial_config.json")

# Load config object from the .json file
config = ngmg.load_config_file(cfg_path)

# Enable minimal logging for ng_model_gym.
# For more detailed logs set log_level to logging.INFO or logging.DEBUG.
ngmg.logging_config(config, log_level=logging.ERROR)
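If you want to see the raw structure of the .json file before `ng_model_gym` parses it, a minimal sketch using only the standard `json` module follows. The helper `top_level_sections` is hypothetical and not part of `ng_model_gym`; it simply lists the top-level keys of a JSON object.

```python
import json
from pathlib import Path
from typing import List


def top_level_sections(json_path: Path) -> List[str]:
    """Return the sorted top-level keys of a JSON config file (hypothetical helper)."""
    with json_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    # A config file is expected to be a JSON object, not a list or scalar.
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object at the top level of {json_path}")
    return sorted(data.keys())
```

For example, `top_level_sections(cfg_path)` would list the sections of the tutorial config loaded above.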

In [None]:
from rich import print as rprint

# Let's inspect the loaded config object.
rprint(config)

For more information about each parameter in the config object, such as its *description* field and expected type, run the following command:

In [None]:
ngmg.print_config_options()

The config object is mutable, so the default parameters may be overwritten as desired:

In [None]:
# Set a specific training dataset
# Note: such a small dataset is only suitable for demonstration and will not produce a high-quality model
config.dataset.path.train = Path("../../data/nss/datasets/train")

# Overrides to allow for a smaller dataset and shorter training run
config.dataset.recurrent_samples = 4
config.train.batch_size = 4
config.train.qat.number_of_epochs = 2

# Enable fine-tuning from an existing model file
config.train.finetune = True

# Path to a previously trained .pt model file. For this demo, we use the provided pretrained .pt file
config.train.pretrained_weights = "../../data/nss/weights/nss_v0.1.0_fp32.pt"
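The dataset and pretrained-weights paths above are relative, so they are presumably resolved against the current working directory (normally the notebook's directory in Jupyter), and a wrong path may only surface once training starts. A minimal pre-flight check with `pathlib` is sketched below; `missing_paths` is a hypothetical helper, not an `ng_model_gym` function.

```python
from pathlib import Path
from typing import List, Union


def missing_paths(*paths: Union[str, Path]) -> List[Path]:
    """Return the given paths that do not exist on disk (hypothetical pre-flight check)."""
    return [Path(p) for p in paths if not Path(p).exists()]
```

For example, `missing_paths(config.dataset.path.train, config.train.pretrained_weights)` should return an empty list if both paths point at existing files or directories.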

**Note:** depending on your machine, you may need a smaller training batch size for QAT than for fp32 training. ExecuTorch adds GPU memory overhead, which can cause out-of-memory (OOM) errors if your training configuration is already close to your GPU's capacity.
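One generic way to find a batch size that fits is to halve it until a trial step stops running out of memory. The sketch below is illustrative only: `find_fitting_batch_size` and the `try_step` callable are hypothetical names, not part of `ng_model_gym`. Recent PyTorch versions raise `torch.cuda.OutOfMemoryError` on GPU OOM; here the exception type is a parameter (defaulting to Python's `MemoryError`) so the logic can be shown without a GPU.

```python
from typing import Callable, Type


def find_fitting_batch_size(
    try_step: Callable[[int], None],
    start: int,
    oom_error: Type[BaseException] = MemoryError,
) -> int:
    """Halve the batch size until try_step(batch_size) completes without oom_error.

    try_step is a hypothetical callable that runs one training step at the
    given batch size and raises oom_error if memory runs out.
    """
    batch_size = start
    while batch_size >= 1:
        try:
            try_step(batch_size)
            return batch_size  # this batch size fits
        except oom_error:
            batch_size //= 2  # retry with half the batch
    raise RuntimeError("No batch size >= 1 fits in memory")
```

You would then set `config.train.batch_size` to the returned value before calling `do_training`.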

### Run Quantization Aware Training

We can now call the `do_training` function to run quantization aware training on the model for 2 epochs with 4 recurrent frames, as defined in our config.

In [None]:
from ng_model_gym import TrainEvalMode

# We pass in the modified config object, and set the training mode to perform QAT fine-tuning on the pretrained weights.
qat_ckpt_path = ngmg.do_training(config, training_mode=TrainEvalMode.QAT_INT8)
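In general, QAT inserts "fake quantization" into the forward pass so the model learns weights that tolerate int8 rounding. The sketch below is a conceptual, list-based illustration of symmetric per-tensor int8 fake quantization: values are scaled, rounded, clamped to [-127, 127], then scaled back to floating point. It is not how `ng_model_gym` or ExecuTorch implement QAT, and the symmetric per-tensor scheme is an assumption chosen for simplicity.

```python
from typing import List


def fake_quantize_int8(values: List[float]) -> List[float]:
    """Quantize to int8 and immediately dequantize (symmetric, per-tensor).

    The output stays in floating point but carries the rounding error an
    int8 model would see at inference time.
    """
    max_abs = max((abs(v) for v in values), default=0.0)
    if max_abs == 0.0:
        return [0.0 for _ in values]
    scale = max_abs / 127.0  # map the largest magnitude to 127
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return [q * scale for q in quantized]  # dequantize back to float
```

In real QAT, the rounding step is bypassed when computing gradients (the straight-through estimator), which this sketch does not model.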

### Export to VGF

We can now use ExecuTorch to export the model to a VGF file, ready to use in your game engine of choice.

In [None]:
from ng_model_gym import ExportType

ngmg.do_export(config, qat_ckpt_path, export_type=ExportType.QAT_INT8)

By default, the exported VGF file should now be in the `tutorials/nss/output/vgf` directory.
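To confirm the export from inside the notebook, you could list the `.vgf` files in the output directory. This is a minimal `pathlib` sketch; `find_vgf_files` is a hypothetical helper, not an `ng_model_gym` function, and the directory you pass it should match wherever your configuration writes exports.

```python
from pathlib import Path
from typing import List


def find_vgf_files(output_dir: Path) -> List[Path]:
    """Return the .vgf files under output_dir (searched recursively), newest first."""
    if not output_dir.is_dir():
        return []
    files = [p for p in output_dir.rglob("*.vgf") if p.is_file()]
    # Sort by modification time so the most recent export comes first.
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)
```

If the notebook's working directory is `tutorials/nss`, the default output directory would be `Path("output/vgf")`.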