SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\
SPDX-License-Identifier: Apache-2.0

# Model Fine-Tuning Example
---

This notebook shows how to use the Neural Graphics Model Gym to fine-tune the Neural Super Sampling (NSS) model.

### Prerequisites
Familiarity with the concepts covered in:
- [model training tutorial](model_training_example.ipynb)
- [model evaluation tutorial](model_evaluation_example.ipynb)

### Environment Setup
**Before running the notebook:**

Check the environment prerequisites and follow the setup instructions in the [README.md](../../README.md).

In [None]:
# Import the ng_model_gym package.
import ng_model_gym as ngmg

### Inspect Configuration Parameters

In [None]:
import logging
from pathlib import Path

# Load the fine-tuning config (some parameters are set differently from the training example)
cfg_path = Path("../../assets/nss/finetune_config.json")
config = ngmg.load_config_file(cfg_path)

# Enable minimal logging for ng_model_gym.
# For more detailed logs set log_level to logging.INFO or logging.DEBUG.
ngmg.logging_config(config, log_level=logging.ERROR)
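To see exactly which parameters the fine-tuning config sets differently, you could diff its raw JSON against another config file, such as the one used in the training tutorial. The sketch below is a minimal, hypothetical helper built only on the standard `json` module; it assumes both files are plain JSON and does not use any ng_model_gym API. The path to the training config is not given in this notebook, so you would need to supply it yourself.

```python
import json
from pathlib import Path
from typing import Any

# Placeholder marking a key that exists in only one of the two configs.
MISSING = "<missing>"


def diff_configs(base: Any, other: Any, prefix: str = "") -> dict:
    """Hypothetical helper: return {dotted.key: (base_value, other_value)} for differing leaves."""
    if isinstance(base, dict) and isinstance(other, dict):
        diffs = {}
        for key in sorted(set(base) | set(other)):
            path = f"{prefix}.{key}" if prefix else str(key)
            if key not in base:
                diffs[path] = (MISSING, other[key])
            elif key not in other:
                diffs[path] = (base[key], MISSING)
            else:
                # Recurse into nested sections such as "train" or "dataset".
                diffs.update(diff_configs(base[key], other[key], path))
        return diffs
    # Leaf values (numbers, strings, lists): report only when they differ.
    return {} if base == other else {prefix: (base, other)}


def diff_config_files(base_path: Path, other_path: Path) -> dict:
    """Load two JSON config files from disk and diff them."""
    with open(base_path) as f_base, open(other_path) as f_other:
        return diff_configs(json.load(f_base), json.load(f_other))
```

For example, `diff_config_files(Path("<your training config>.json"), cfg_path)` would list each differing key with its value in both files; the first argument is a placeholder for whichever config you want to compare against.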

In [None]:
from rich import print as rprint

# Let's inspect the loaded config object.
rprint(config)

### Run Fine-Tuning

To configure fine-tuning, update the relevant config parameters:

In [None]:
# Start from pretrained weights and fine-tune the model to this sequence of frames
config.train.finetune = True

# Run validation at the end of each epoch to estimate the image quality
config.train.perform_validate = True

# Path to a previously trained .pt model file. For this demo, we use the provided pretrained .pt file.
config.train.pretrained_weights = "../../data/nss/weights/nss_v0.1.0_fp32.pt"
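If you fine-tune repeatedly, the three settings above could be wrapped in a small helper that also fails early when the weights file is missing, instead of failing later inside training. This is a minimal sketch: `configure_finetune` is a hypothetical name, and it assumes only the `config.train` attributes used in the cell above (`finetune`, `perform_validate`, `pretrained_weights`).

```python
from pathlib import Path


def configure_finetune(config, weights_path, validate_each_epoch=True):
    """Hypothetical helper: apply the fine-tuning settings used in this notebook.

    Assumes `config.train` exposes `finetune`, `perform_validate` and
    `pretrained_weights`, as set in the cell above.
    """
    weights = Path(weights_path)
    # Check the weights file up front so a bad path is reported immediately.
    if not weights.is_file():
        raise FileNotFoundError(f"Pretrained weights not found: {weights}")
    if weights.suffix != ".pt":
        raise ValueError(f"Expected a .pt file, got: {weights.name}")

    config.train.finetune = True
    config.train.perform_validate = validate_each_epoch
    # Stored as a string, matching how the path is set in the cell above.
    config.train.pretrained_weights = str(weights)
    return config
```

Usage would be `configure_finetune(config, "../../data/nss/weights/nss_v0.1.0_fp32.pt")`, which mirrors the manual assignments above.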

Then, as before, run training with the `do_training` function.

In [None]:
from ng_model_gym import TrainEvalMode

finetuned_model, final_ckpt_path = ngmg.do_training(
    config, training_mode=TrainEvalMode.FP32
)

### Compare Evaluation Metrics

Now that we have a model fine-tuned to our specific training sequence, we can compare its evaluation metrics with those of the pretrained weights to check that the fine-tuned model performs better on this scene.

First, get the evaluation metrics for the pretrained weights:

In [None]:
ngmg.do_evaluate(
    config, Path("../../data/nss/weights/nss_v0.1.0_fp32.pt"), TrainEvalMode.FP32
)

Now evaluate the final checkpoint from fine-tuning and compare its metrics with the evaluation results for the pretrained weights:

In [None]:
ngmg.do_evaluate(config, final_ckpt_path, TrainEvalMode.FP32)
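Once you have both sets of metrics, a small helper can tabulate the per-metric change. This sketch makes no assumption about what `do_evaluate` returns: it assumes you have collected each run's metrics into a plain dict keyed by metric name (for example, by copying the reported values). `compare_metrics` and `print_report` are hypothetical helpers, and which metrics are lower-is-better must be supplied by you.

```python
def compare_metrics(baseline, finetuned, lower_is_better=frozenset()):
    """Hypothetical helper: {metric: (baseline, finetuned, delta, improved)} for shared metrics."""
    report = {}
    # Only compare metrics present in both runs, in a stable (sorted) order.
    for name in sorted(baseline.keys() & finetuned.keys()):
        before, after = float(baseline[name]), float(finetuned[name])
        delta = after - before
        # A decrease counts as an improvement only for lower-is-better metrics.
        improved = delta < 0 if name in lower_is_better else delta > 0
        report[name] = (before, after, delta, improved)
    return report


def print_report(report):
    """Print one line per metric: baseline -> fine-tuned (signed delta, verdict)."""
    for name, (before, after, delta, improved) in report.items():
        status = "better" if improved else "not better"
        print(f"{name:>12}: {before:.4f} -> {after:.4f} ({delta:+.4f}, {status})")
```

Passing the pretrained-weights metrics as `baseline` and the fine-tuned metrics as `finetuned` to `compare_metrics`, then calling `print_report` on the result, gives a quick side-by-side view of the two evaluations above.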

For even better results, increase `config.dataset.recurrent_samples` to 16 before running fine-tuning.

Note: this significantly increases the execution time of the `do_training` function.
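To try this without overwriting the configuration you have already used, you could make a modified copy first. This is a minimal sketch: `with_recurrent_samples` is a hypothetical helper, and it assumes the config object supports `copy.deepcopy` (standard Python behavior that this notebook does not confirm for the ng_model_gym config class) and exposes `config.dataset.recurrent_samples` as referenced above.

```python
import copy


def with_recurrent_samples(config, samples=16):
    """Hypothetical helper: return a deep copy of `config` with a new recurrent_samples value.

    Assumes `config.dataset.recurrent_samples` exists, as referenced above.
    The original config is left untouched so earlier runs stay reproducible.
    """
    if samples < 1:
        raise ValueError("recurrent_samples must be a positive integer")
    new_config = copy.deepcopy(config)
    new_config.dataset.recurrent_samples = samples
    return new_config
```

You could then pass the copy to training, e.g. `ngmg.do_training(with_recurrent_samples(config, 16), training_mode=TrainEvalMode.FP32)`, keeping the original `config` available for evaluation with the earlier settings.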