SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\
SPDX-License-Identifier: Apache-2.0

# Adding a Custom Model Example
---

This notebook shows how to add your own custom model and use it within the Neural Graphics Model Gym.

### Environment Setup
**Before running the notebook:**

Check the environment prerequisites and follow the setup instructions in the [README.md](../README.md).

In [None]:
# Import the ng_model_gym package.
import ng_model_gym as ngmg

### Inspect Configuration Parameters

In [None]:
import logging
from pathlib import Path

# Set training/evaluation configuration parameters using a .json file.
# Let's use the default configuration file provided in the package.
cfg_path = Path("./assets/nss/tutorial_config.json")

# Load config object from the .json file
config = ngmg.load_config_file(cfg_path)

# Enable minimal logging for ng_model_gym.
# For more detailed logs set log_level to logging.INFO or logging.DEBUG.
ngmg.logging_config(config, log_level=logging.ERROR)

In [None]:
from rich import print as rprint

# Let's inspect the default model configuration
rprint(config.model)

### Add and register a new model

Add a new model within your project, making sure it is marked with the `register_model()` decorator.

We have added an example model for this notebook under `assets/custom_model_ex.py`.
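To illustrate the general idea behind a registration decorator, here is a minimal, self-contained sketch. It is not the ng_model_gym implementation: `TOY_REGISTRY`, `make_key`, `register_toy_model`, and `ToyModel` are hypothetical names, and the `name-vVERSION` key format is an assumption based on the `custom_model-v1` entry shown later in this notebook.

```python
from typing import Callable, Dict, List, Optional, Type

# Hypothetical stand-in for a model registry; NOT ng_model_gym's MODEL_REGISTRY.
TOY_REGISTRY: Dict[str, Type] = {}


def make_key(name: str, version: Optional[str] = None) -> str:
    """Build a lookup key such as 'toy_model-v1' (key format is an assumption)."""
    return f"{name}-v{version}" if version is not None else name


def register_toy_model(
    name: str, version: Optional[str] = None
) -> Callable[[Type], Type]:
    """Return a class decorator that stores the decorated class under its key."""

    def decorator(cls: Type) -> Type:
        key = make_key(name, version)
        if key in TOY_REGISTRY:
            raise ValueError(f"Model '{key}' is already registered")
        TOY_REGISTRY[key] = cls
        # Return the class unchanged so it can still be used directly.
        return cls

    return decorator


@register_toy_model(name="toy_model", version="1")
class ToyModel:
    """Placeholder class; a real model would define layers and a forward pass."""


def list_toy_models() -> List[str]:
    """Return the registered keys in sorted order."""
    return sorted(TOY_REGISTRY)


print(list_toy_models())
```

The decorator runs when the class definition is executed, so the class becomes discoverable by name without any explicit registration call at the use site.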

In [None]:
# Show the custom model code
from IPython.display import Code, display

with open("./assets/custom_model_ex.py", "r", encoding="utf-8") as model_file:
    model_code = model_file.read()

display(Code(model_code, language="python"))

Now that we have a new model, we need to import the file containing it so that the model is registered and usable within our workflows.
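The reason an import is enough is that Python executes a module's top-level code, including class decorators, the first time the module is loaded. The sketch below demonstrates this with a throwaway module written to a temporary directory. Everything in it (`toy_registry`, `toy_model_file`, `register`, `ToyModel`) is hypothetical and only stands in for the real registry and model file.

```python
import importlib.util
import sys
import tempfile
import types
from pathlib import Path

# Hypothetical shared registry module, created in memory for this demo only.
registry_mod = types.ModuleType("toy_registry")
registry_mod.MODELS = {}
sys.modules["toy_registry"] = registry_mod

# Source of a throwaway model file that registers a class when executed.
MODULE_SOURCE = '''
import toy_registry


def register(name):
    def decorator(cls):
        toy_registry.MODELS[name] = cls
        return cls
    return decorator


@register("toy_model")
class ToyModel:
    pass
'''

with tempfile.TemporaryDirectory() as tmp_dir:
    module_path = Path(tmp_dir) / "toy_model_file.py"
    module_path.write_text(MODULE_SOURCE, encoding="utf-8")

    # Nothing is registered until the module's code has actually run.
    registered_before = sorted(registry_mod.MODELS)

    # Loading the module executes its top-level code, including the decorator.
    spec = importlib.util.spec_from_file_location("toy_model_file", module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    registered_after = sorted(registry_mod.MODELS)

print(registered_before)
print(registered_after)
```

This is also why the import in the next cell looks unused to linters: it is there purely for its registration side effect.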

In [None]:
# Import the file containing the model
from assets import custom_model_ex  # noqa: F401 pylint: disable=unused-import

In [None]:
# List the registered models
# We can see our new model 'custom_model-v1' has been registered
from ng_model_gym.core.model.model_registry import MODEL_REGISTRY

MODEL_REGISTRY.list_registered()

### Using the model

#### Set the custom model in the config json

To use this new model, we need to update our config with the model name and, optionally, the version that were used when registering it.

In [None]:
# Set the name and optional version which were used when registering
# e.g. @register_model(name="custom_model", version="1")
config.model.name = "custom_model"
config.model.version = "1"

# We can see the model configuration has been updated to our new model
rprint(config.model)

In [None]:
# We'll override a few other config values for a smaller dataset and a shorter training run
config.dataset.recurrent_samples = 4
config.train.batch_size = 4
config.train.fp32.number_of_epochs = 2

We can now run any of our workflows, such as training, using the custom model we have just added.

In [None]:
from ng_model_gym import TrainEvalMode

# We pass in the modified config object containing our model name and version
trained_model_path = ngmg.do_training(config, training_mode=TrainEvalMode.FP32)