# Notebook 6: Converting Customized Model Notebooks to Using GraphStorm CLIs 

Notebooks 1 to 5 provide examples of how to use GraphStorm APIs to implement various GNN components and models. These notebooks can run in the GraphStorm Standalone mode, i.e., on a single CPU or GPU of a single machine. The Standalone mode is good for quick model development and debugging on relatively small graph datasets. To fully leverage GraphStorm's distributed model training and inference capabilities, however, we need to convert the code implemented in these notebooks into Python scripts that can be launched with GraphStorm Command Line Interfaces (CLIs), which can handle extremely large graphs.

This notebook introduces the conversion method and explains the key components of the example Python scripts. As the example, we use the customized model developed in [Notebook 4: Use GraphStorm APIs for Customizing Model Components](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_4_Customized_Models.html).

----

### Prerequisites

- GraphStorm. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).
- ACM data that has been created according to [Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html), and is stored in the `./acm_gs_1p/` folder.

## Brief Introduction and Run

In order to use GraphStorm CLIs, we need to put the customized model into a Python file, which is passed as an argument to the [Task-agnostic CLI for model training and inference](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/single-machine-training-inference.html#task-agnostic-cli-for-model-training-and-inference). We build two separate files, one for model training and one for inference.

We can reuse most of the Notebook 4 code for the customized `RGAT` module in both files, i.e., copy and paste the classes `Ara_GatLayer`, `Ara_GatEncoder`, and `RgatNCModel` into the two files.

For the training file, we copy and paste the code of the training pipeline and enclose it in a `fit()` function. Similarly, for the inference file, we copy and paste the code of the inference pipeline and enclose it in an `infer()` function.

We have created the two files, named `demo_run_train.py` and `demo_run_infer.py` under the [GraphStorm API documentation folder](https://github.com/awslabs/graphstorm/tree/main/docs/source/api). With the two files, we can call GraphStorm's task-agnostic CLIs to run our customized model as shown below.

In [None]:
# download the example yaml configuration file
!wget -O acm_nc.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_nc.yaml

# CLI for the customized RGAT model training
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_1p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           demo_run_train.py --cf acm_nc.yaml \
                             --save-model-path model_path/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --num-epochs 5 \
                             --rgat-encoder-type ara

# CLI for the customized RGAT model inference
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_1p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           demo_run_infer.py --cf acm_nc.yaml \
                             --restore-model-path model_path/epoch-4 \
                             --save-prediction-path predictions/ \
                             --save-embed-path embeddings/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --rgat-encoder-type ara

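Both commands share the same launch CLI arguments, and the inference command's `--restore-model-path` must point to a checkpoint saved under the training command's `--save-model-path`. If you prefer to drive the runs from Python, a minimal sketch like the one below keeps the shared values in one place. The helper names (`launch_prefix`, `train_cmd`, `infer_cmd`) are hypothetical; the flags are the ones used in the commands above, and the checkpoint naming (`epoch-<i>` counted from 0, so 5 epochs end at `epoch-4`) is an assumption based on the inference command.

```python
import shlex
import sys

# Values shared by both commands (taken from the commands above).
PART_CONFIG = "./acm_gs_1p/acm.json"
YAML_CONFIG = "acm_nc.yaml"
MODEL_PATH = "model_path"
NUM_EPOCHS = 5
NODE_FEATS = ["paper:feat", "author:feat", "subject:feat"]


def launch_prefix():
    """Launch CLI arguments; they must come before the script name."""
    return [sys.executable, "-m", "graphstorm.run.launch",
            "--part-config", PART_CONFIG,
            "--num-trainers", "4",
            "--num-servers", "1",
            "--num-samplers", "0"]


def train_cmd():
    """Hypothetical helper: script name, built-in configs, then customized args."""
    return launch_prefix() + [
        "demo_run_train.py", "--cf", YAML_CONFIG,
        "--save-model-path", MODEL_PATH + "/",
        "--node-feat-name", *NODE_FEATS,
        "--num-epochs", str(NUM_EPOCHS),
        "--rgat-encoder-type", "ara"]


def infer_cmd():
    """Hypothetical helper: restores the last checkpoint written by train_cmd()."""
    # Assumption: checkpoints are saved as <save-model-path>/epoch-<i>, i counted
    # from 0, so the last one is epoch-(NUM_EPOCHS - 1).
    last_ckpt = f"{MODEL_PATH}/epoch-{NUM_EPOCHS - 1}"
    return launch_prefix() + [
        "demo_run_infer.py", "--cf", YAML_CONFIG,
        "--restore-model-path", last_ckpt,
        "--save-prediction-path", "predictions/",
        "--save-embed-path", "embeddings/",
        "--node-feat-name", *NODE_FEATS,
        "--rgat-encoder-type", "ara"]


if __name__ == "__main__":
    # Only print the commands here; nothing is launched.
    print(shlex.join(train_cmd()))
    print(shlex.join(infer_cmd()))
```

Printing only builds the command strings; passing one of the lists to `subprocess.run(cmd, check=True)` would actually launch the job.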
## CLI argument processing explanation
Compared to the code in [Notebook 4](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_4_Customized_Models.html), most of the modifications in the two Python files relate to collecting and parsing GraphStorm CLI configurations. Instead of hard-coding some variables, e.g., `nfeats_4_modeling`, or setting fixed input values, e.g., `label_field='label'` or `encoder_type='ara'`, we provide these values via CLI configurations.

As shown in the above commands, there are three types of configurations passed to the GraphStorm task-agnostic command.

- [Launch CLI arguments](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/configuration-run.html#launch-cli-arguments), which directly follow `graphstorm.run.launch`.
- [Model training and inference configurations](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/configuration-run.html#model-training-and-inference-configurations), which are predefined in GraphStorm. These configurations can be put into a yaml file, whose path becomes the value of the `--cf` argument following the training or inference Python file. You can also set them as command-line arguments, which override the same configurations set in the yaml file.
- Configurations for customized modules, which are not predefined by GraphStorm and are used only by the customized modules. These configurations should be defined as arguments of the training or inference Python files.
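
The override rule in the second item can be pictured as a dictionary merge: the yaml file supplies base values, and any built-in configuration given on the command line replaces them. The sketch below only illustrates that behavior with the standard library and is not GraphStorm's implementation; a plain dict stands in for the parsed yaml file, and the option values are made up for illustration.

```python
import argparse

# Stand-in for values loaded from the yaml file given to --cf (illustrative values).
yaml_values = {"num_epochs": 20, "hidden_size": 256}

parser = argparse.ArgumentParser()
parser.add_argument("--cf", type=str)
# default=None lets us tell "not given on the command line" apart from an explicit value.
parser.add_argument("--num-epochs", type=int, default=None)
parser.add_argument("--hidden-size", type=int, default=None)


def resolve(argv):
    """Merge yaml values with command-line values; the command line wins."""
    args = parser.parse_args(argv)
    merged = dict(yaml_values)  # copy, so the base values stay unchanged
    for key, value in vars(args).items():
        if key != "cf" and value is not None:
            merged[key] = value  # command-line value overrides the yaml value
    return merged


if __name__ == "__main__":
    print(resolve(["--cf", "acm_nc.yaml", "--num-epochs", "5"]))
    # {'num_epochs': 5, 'hidden_size': 256}
```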

If we follow this argument placement convention, it is easy to collect and parse the arguments. Below we show the main entry point of the `demo_run_train.py` file.

In [3]:
import argparse
from graphstorm.config import get_argument_parser

......

if __name__ == '__main__':
    # Leverage GraphStorm's argument parser to accept the configuration yaml file
    arg_parser = get_argument_parser()

    # parse all arguments and split GraphStorm's built-in arguments from the customized ones
    gs_args, unknown_args = arg_parser.parse_known_args()
    print(f'GS arguments: {gs_args}')

    # create a new argument parser dedicated for customized arguments
    cust_parser = argparse.ArgumentParser(description="Customized Arguments")
    # add customized arguments
    cust_parser.add_argument('--rgat-encoder-type', type=str, default="ara")
    cust_args = cust_parser.parse_args(unknown_args)
    print(f'Customized arguments: {cust_args}')

    # use both argument sets in our main function
    fit(gs_args, cust_args)

GraphStorm's config module provides a `get_argument_parser` method, which creates an argument parser, e.g., `arg_parser`, dedicated to processing GraphStorm launch CLI arguments and model training and inference configurations. Its `parse_known_args()` method extracts all GraphStorm built-in configurations and returns the remaining, customized arguments, which can be processed by another argument parser, e.g., `cust_parser`. We can then pass both sets of arguments to the corresponding methods. Please refer to the [get_argument_parser API document](https://graphstorm.readthedocs.io/en/latest/api/generated/graphstorm.config.get_argument_parser.html#graphstorm.config.get_argument_parser) for more details about this method.
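
The split performed by `parse_known_args()` is standard `argparse` behavior, so it can be tried without GraphStorm installed. In this minimal sketch, a plain `argparse.ArgumentParser` with three options stands in for the parser returned by `get_argument_parser()` (an assumption for illustration; the real parser defines many more options), and `split_args` is a hypothetical helper.

```python
import argparse


def split_args(argv):
    """Hypothetical helper mirroring the two-parser pattern of demo_run_train.py."""
    # Stand-in for graphstorm.config.get_argument_parser(): knows only built-in options.
    gs_parser = argparse.ArgumentParser()
    gs_parser.add_argument("--cf", type=str, required=True)
    gs_parser.add_argument("--num-epochs", type=int)
    gs_parser.add_argument("--node-feat-name", nargs="+")

    # Known options are parsed; unrecognized ones are returned untouched, in order.
    gs_args, unknown_args = gs_parser.parse_known_args(argv)

    # A second parser handles only the customized options.
    cust_parser = argparse.ArgumentParser(description="Customized Arguments")
    cust_parser.add_argument("--rgat-encoder-type", type=str, default="ara")
    cust_args = cust_parser.parse_args(unknown_args)
    return gs_args, cust_args, unknown_args


if __name__ == "__main__":
    argv = ["--cf", "acm_nc.yaml",
            "--node-feat-name", "paper:feat", "author:feat", "subject:feat",
            "--num-epochs", "5",
            "--rgat-encoder-type", "ara"]
    gs_args, cust_args, unknown = split_args(argv)
    print(gs_args)
    print(unknown)    # ['--rgat-encoder-type', 'ara']
    print(cust_args)
```

Using `parse_args()` rather than `parse_known_args()` for the second parser means that a misspelled option, which neither parser recognizes, stops the script with an error instead of being silently ignored.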

### GraphStorm `GSConfig` object explanation
Once we have obtained these arguments, we can use them to create a `GSConfig` object and then pass it to different modules to get related configurations. The `GSConfig` object checks the format and values of all arguments to ensure compliance with GraphStorm specifications. The code below shows how to create the `GSConfig` object, with examples of how to use it. Please refer to the [GSConfig API doc](https://graphstorm.readthedocs.io/en/latest/api/generated/graphstorm.config.GSConfig.html#graphstorm.config.GSConfig) for more details of this class.

In [20]:
import graphstorm as gs
from graphstorm.config import GSConfig

......

def fit(gs_args, cust_args):
    # Utilize GraphStorm's GSConfig class to accept arguments
    config = GSConfig(gs_args)

    gs.initialize(ip_config=config.ip_config, backend=config.backend, local_rank=config.local_rank)
    acm_data = gs.dataloading.GSgnnData(part_config=config.part_config)

    ......

    model = RgatNCModel(g=acm_data.g,
                        num_heads=config.num_heads,
                        num_hid_layers=config.num_layers,
                        node_feat_field=config.node_feat_name,
                        hid_size=config.hidden_size,
                        num_classes=config.num_classes,
                        encoder_type=cust_args.rgat_encoder_type)   # use the customized argument instead of GSConfig
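
In the code above, `GSConfig` only receives `gs_args`, so any check on a customized argument such as `--rgat-encoder-type` is left to the script itself. The sketch below illustrates this division of work without GraphStorm: `DemoConfig` is a hypothetical stand-in for `GSConfig`, `DemoModel` for `RgatNCModel`, and `SUPPORTED_ENCODERS` is a hypothetical list of accepted encoder types. The positivity checks are illustrative only and do not reproduce the actual validation rules of `GSConfig`.

```python
import argparse
from dataclasses import dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for GSConfig: holds built-in values and checks them."""
    num_heads: int
    num_layers: int
    hidden_size: int
    num_classes: int

    def __post_init__(self):
        # Illustrative checks only; GSConfig applies its own, more extensive rules.
        for name in ("num_heads", "num_layers", "hidden_size", "num_classes"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive, got {getattr(self, name)}")


class DemoModel:
    """Hypothetical stand-in for RgatNCModel that only records its settings."""
    def __init__(self, num_heads, num_hid_layers, hid_size, num_classes, encoder_type):
        self.settings = dict(num_heads=num_heads, num_hid_layers=num_hid_layers,
                             hid_size=hid_size, num_classes=num_classes,
                             encoder_type=encoder_type)


# Hypothetical set of encoder types accepted by the customized module.
SUPPORTED_ENCODERS = ("ara",)


def fit(gs_values, cust_args):
    config = DemoConfig(**gs_values)  # built-in values, checked on creation
    if cust_args.rgat_encoder_type not in SUPPORTED_ENCODERS:
        # The config object never sees customized arguments, so check them here.
        raise ValueError(f"unsupported encoder type: {cust_args.rgat_encoder_type}")
    return DemoModel(num_heads=config.num_heads,
                     num_hid_layers=config.num_layers,
                     hid_size=config.hidden_size,
                     num_classes=config.num_classes,
                     encoder_type=cust_args.rgat_encoder_type)


if __name__ == "__main__":
    cust_args = argparse.Namespace(rgat_encoder_type="ara")
    model = fit({"num_heads": 4, "num_layers": 2, "hidden_size": 128, "num_classes": 3},
                cust_args)
    print(model.settings)
```

Alternatively, the encoder-type check could live in the customized parser itself through the `choices=` parameter of `argparse`'s `add_argument`, so an invalid value is rejected while parsing.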