SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\
SPDX-License-Identifier: Apache-2.0

# Model Training Example
---

This notebook shows how to use the Neural Graphics Model Gym to train the Neural Super Sampling model, with or without fine-tuning.

### Environment Setup
**Before running the notebook:**

Check the environment prerequisites and follow the setup instructions in the [README.md](../../README.md).
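
As a quick sanity check before running the cells below, the following sketch reports whether the packages imported later in this notebook can be found in the current environment. It uses only the Python standard library; the helper name `missing_packages` is introduced here for illustration and is not part of `ng_model_gym`. It does not replace the prerequisite checks described in the README.

```python
import importlib.util
import sys


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# "ng_model_gym" and "rich" are the two third-party packages imported in this notebook.
missing = missing_packages(["ng_model_gym", "rich"])
if missing:
    print(f"Missing packages: {', '.join(missing)}; see the README for setup steps.")
else:
    print(f"All packages found (Python {sys.version_info.major}.{sys.version_info.minor}).")
```

If anything is reported as missing, revisit the setup instructions before continuing.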

In [None]:
# Import the ng_model_gym package.
import ng_model_gym as ngmg
from ng_model_gym import TrainEvalMode

### Inspect Configuration Parameters

In [None]:
from pathlib import Path
import logging

# Set training/evaluation configuration parameters using a .json file.
# Here we use the tutorial configuration file provided in the package.
cfg_path = Path("../../assets/nss/tutorial_config.json")

# Load config object from the .json file
config = ngmg.load_config_file(cfg_path)

# Enable minimal logging for ng_model_gym.
# For more detailed logs, set log_level to logging.INFO or logging.DEBUG.
ngmg.logging_config(config, log_level=logging.ERROR)

In [None]:
from rich import print as rprint

# Let's inspect the loaded config object.
rprint(config)
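
It can also help to look at the raw file behind the config object. The sketch below reads the same `.json` file with the standard library and lists its top-level keys with the type of each value. It assumes the top level of the file is a JSON object; the helper name `top_level_sections` is hypothetical and not part of `ng_model_gym`.

```python
import json
from pathlib import Path


def top_level_sections(json_path):
    """Map each top-level key of a JSON file to the type name of its value."""
    with Path(json_path).open(encoding="utf-8") as f:
        data = json.load(f)
    # Assumption: the file's top level is a JSON object (a dict after parsing).
    return {key: type(value).__name__ for key, value in data.items()}


cfg_path = Path("../../assets/nss/tutorial_config.json")
if cfg_path.exists():
    for key, kind in top_level_sections(cfg_path).items():
        print(f"{key}: {kind}")
else:
    print(f"{cfg_path} not found; run this from the notebook's directory.")
```

This only shows what is stored on disk; the loaded config object remains the authoritative view of the parameters.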

For more information on each parameter in the config object, run the following command and note each parameter's *description* field, expected type, and so on:

In [None]:
ngmg.print_config_options()

The config object is mutable, so the default parameters may be overwritten as desired:

In [None]:
# Set specific train dataset
# Note: such a small dataset is only suitable for demonstration and will not produce a high-quality model
config.dataset.path.train = Path("../../data/nss/datasets/train")

# Overrides to allow for a smaller dataset and shorter training run
config.dataset.recurrent_samples = 4
config.train.batch_size = 4
config.train.fp32.number_of_epochs = 2
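
When several overrides are applied at once, a mistyped attribute name can be easy to miss. The sketch below shows one way to guard against that: a small helper that sets a nested attribute only if it already exists. It runs against a stand-in object built from `SimpleNamespace`, not the real config class, and the stand-in's starting values are arbitrary placeholders. The helper name `set_existing` is hypothetical, and the real config object may already validate assignments on its own.

```python
from types import SimpleNamespace


def set_existing(obj, dotted_path, value):
    """Set a nested attribute, raising AttributeError if any part of the path is missing."""
    *parents, leaf = dotted_path.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)  # raises AttributeError on a bad path
    if not hasattr(target, leaf):
        raise AttributeError(f"{dotted_path!r}: no existing field named {leaf!r}")
    setattr(target, leaf, value)


# Stand-in for demonstration only; this is NOT the ng_model_gym config class,
# and the initial values are arbitrary placeholders.
demo_config = SimpleNamespace(
    dataset=SimpleNamespace(recurrent_samples=16),
    train=SimpleNamespace(batch_size=8, fp32=SimpleNamespace(number_of_epochs=100)),
)

# The same three overrides used in the cell above, expressed as dotted paths.
overrides = {
    "dataset.recurrent_samples": 4,
    "train.batch_size": 4,
    "train.fp32.number_of_epochs": 2,
}
for path, value in overrides.items():
    set_existing(demo_config, path, value)

print(demo_config)
```

With this pattern, a typo such as `train.batch_sise` raises an `AttributeError` instead of silently creating a new, unused attribute.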

### Run Training

We can now call the `do_training` function to train the model on 4 frames, as defined in our config.

In [None]:
# We pass in the modified config object, and set the training mode to train in fp32.
trained_model = ngmg.do_training(config, training_mode=TrainEvalMode.FP32)

To run validation and visualise the output of the model, please see the [evaluation tutorial](model_evaluation_example.ipynb).