# Notebook 6: Running Custom Model with GraphStorm CLIs 

Notebooks 1 to 5 provide examples of how to use GraphStorm APIs to implement various GNN components and models. These notebooks can run in the GraphStorm Standalone mode, i.e., on a single CPU or GPU of a single machine. To fully leverage GraphStorm's distributed model training and inference capabilities, however, we need to convert the code implemented in these notebooks into Python scripts that can be launched with GraphStorm Command Line Interfaces (CLIs).

This notebook describes how to perform this conversion and explains the key components of the example Python scripts. As an example, we use the custom model developed in [Notebook 4: Use GraphStorm APIs for Customizing Model Components](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_4_Customized_Models.html).

----

### Prerequisites

- GraphStorm. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).
- ACM data that has been created according to [Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html), and is stored in the `./acm_gs_1p/` folder.

## Brief Introduction and Run CLIs on a Single Machine

In order to use GraphStorm CLIs, we need to put the custom model into a Python file, which is then passed as an argument to the [Task-agnostic CLI for model training and inference](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/single-machine-training-inference.html#task-agnostic-cli-for-model-training-and-inference). We create two separate files, one for model training and one for inference.

We can reuse most of the code of the customized `RGAT` module in Notebook 4, i.e., `Ara_GatLayer`, `Ara_GatEncoder`, and `RgatNCModel`, in both the training and inference files.

For the training file, we can copy the code of the `4.1 Training pipeline` section in Notebook 4 and wrap it in a `fit()` function. Similarly, for the inference file, we can copy the code of the `4.3 Inference pipeline` section in Notebook 4 and wrap it in an `infer()` function.

We provide the two files, `demo_run_train.py` and `demo_run_infer.py`, in the [GraphStorm API documentation folder](https://github.com/awslabs/graphstorm/tree/main/docs/source/api/notebooks). With these two files, we can call GraphStorm's task-agnostic CLI to run our custom model as shown below.

In [None]:
# download the example yaml configuration file
!wget -O acm_nc.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_nc.yaml

# CLI for the custom RGAT model training
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_1p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           demo_run_train.py --cf acm_nc.yaml \
                             --save-model-path models/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --num-epochs 5 \
                             --rgat-encoder-type ara

# CLI for the custom RGAT model inference
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_1p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           demo_run_infer.py --cf acm_nc.yaml \
                             --restore-model-path models/epoch-4 \
                             --save-prediction-path predictions/ \
                             --save-embed-path embeddings/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --rgat-encoder-type ara
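
With `--num-epochs 5`, the inference command restores from `models/epoch-4`, which suggests that epochs are numbered from 0. If you change the number of epochs, a small helper can locate the newest checkpoint instead of hard-coding the path (note that the newest checkpoint is not necessarily the best-performing one). The sketch below is hypothetical: `latest_epoch_checkpoint` is not a GraphStorm function, and it assumes checkpoints are saved as directories named exactly `epoch-N` under the save path, ignoring any other names.

```python
import re
from pathlib import Path


def latest_epoch_checkpoint(model_dir):
    """Return the `epoch-N` subdirectory with the largest N, or None.

    Hypothetical helper: assumes checkpoint directories are named exactly
    `epoch-N`; files and differently named directories are skipped.
    """
    pattern = re.compile(r"epoch-(\d+)")
    best_n, best_path = -1, None
    for path in Path(model_dir).iterdir():
        match = pattern.fullmatch(path.name)
        # Compare the parsed integer, not the string, so epoch-10 > epoch-9
        if path.is_dir() and match and int(match.group(1)) > best_n:
            best_n, best_path = int(match.group(1)), path
    return best_path


# Usage, assuming training has already written checkpoints under models/:
# ckpt = latest_epoch_checkpoint("models/")
# then pass str(ckpt) as the value of --restore-model-path
```

Comparing parsed integers, rather than sorting directory names as strings, keeps `epoch-10` ahead of `epoch-9`. `Path.iterdir()` raises `FileNotFoundError` if the directory does not exist yet.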

## CLI argument processing explanation
Compared to the code in [Notebook 4](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_4_Customized_Models.html), most of the modifications in the two Python files concern how GraphStorm CLI configurations are collected and parsed. Instead of hard-coding variables, e.g., `nfeats_4_modeling`, or fixing input values, e.g., `label_field='label'` or `encoder_type='ara'`, we provide these values through CLI configurations.

As shown in the above commands, there are three types of configurations passed to the GraphStorm task-agnostic command.

- [Launch CLI arguments](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/configuration-run.html#launch-cli-arguments), which directly follow `graphstorm.run.launch`.
- [Model training and inference configurations](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/configuration-run.html#model-training-and-inference-configurations), which are predefined in GraphStorm. These configurations can be put into a yaml file, which is passed as the value of the `--cf` argument following the training or inference Python file name. You can also set them as command-line arguments, which override the same configurations set in the yaml file.
- Custom configurations, which are not predefined in GraphStorm and are used only by the custom modules, e.g., `--rgat-encoder-type`. These should be defined as input arguments of the training or inference Python files.
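
The override rule in the second item can be illustrated with plain `argparse`. This sketch is not GraphStorm's implementation: `yaml_values` is a hypothetical, flattened dictionary with made-up values standing in for a parsed yaml file (the real GraphStorm yaml layout is nested), and `merge_config` is a made-up helper.

```python
import argparse

# Hypothetical, flattened stand-in for values read from a yaml file
yaml_values = {"num_epochs": 20, "hidden_size": 256, "lr": 0.001}


def merge_config(yaml_values, cli_args):
    """Return the yaml values, overridden by every CLI argument that was set."""
    merged = dict(yaml_values)
    for key, value in vars(cli_args).items():
        if value is not None:  # None means "not given on the command line"
            merged[key] = value
    return merged


parser = argparse.ArgumentParser()
# default=None lets us tell "not passed" apart from "passed"
parser.add_argument("--num-epochs", type=int, default=None)
parser.add_argument("--hidden-size", type=int, default=None)
parser.add_argument("--lr", type=float, default=None)

cli_args = parser.parse_args(["--num-epochs", "5"])
config = merge_config(yaml_values, cli_args)
print(config)  # {'num_epochs': 5, 'hidden_size': 256, 'lr': 0.001}
```

Giving each CLI option `default=None` is the design choice that makes the precedence work: only options actually typed on the command line replace the yaml values.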

Below we show the main entry point of the `demo_run_train.py` file.

In [3]:
import argparse
from graphstorm.config import get_argument_parser

......

if __name__ == '__main__':
    # Leverage GraphStorm's argument parser to accept the configuration yaml file
    arg_parser = get_argument_parser()

    # parse all arguments and split GraphStorm's built-in arguments from the custom ones
    gs_args, unknown_args = arg_parser.parse_known_args()
    print(f'GS arguments: {gs_args}')

    # create a new argument parser dedicated for custom arguments
    cust_parser = argparse.ArgumentParser(description="Customized Arguments")
    # add custom arguments
    cust_parser.add_argument('--rgat-encoder-type', type=str, default="ara")
    cust_args = cust_parser.parse_args(unknown_args)
    print(f'Customized arguments: {cust_args}')

    # use both argument sets in our main function
    fit(gs_args, cust_args)

GraphStorm's config module provides a `get_argument_parser` method, which creates an argument parser, e.g., `arg_parser`, that processes GraphStorm launch CLI arguments and model training and inference configurations. Using the `parse_known_args()` method, this parser extracts all GraphStorm built-in configurations and returns the remaining custom arguments separately, which can then be processed by another argument parser, e.g., `cust_parser`. We can then pass both sets of arguments to the corresponding functions. Please refer to the [get_argument_parser API document](https://graphstorm.readthedocs.io/en/latest/api/generated/graphstorm.config.get_argument_parser.html#graphstorm.config.get_argument_parser) for more details about this method.
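
To see exactly how the two parsers split the arguments, the sketch below replays the training command's script arguments using only the standard library. `make_builtin_parser` is a hypothetical stand-in for `get_argument_parser()` that declares just four built-in options; the real GraphStorm parser defines many more.

```python
import argparse


def make_builtin_parser():
    # Hypothetical stand-in for graphstorm.config.get_argument_parser()
    parser = argparse.ArgumentParser(description="Built-in arguments (stand-in)")
    parser.add_argument("--cf", type=str)
    parser.add_argument("--save-model-path", type=str)
    parser.add_argument("--node-feat-name", nargs="+")
    parser.add_argument("--num-epochs", type=int)
    return parser


# The arguments that follow demo_run_train.py in the training command
argv = ["--cf", "acm_nc.yaml", "--save-model-path", "models/",
        "--node-feat-name", "paper:feat", "author:feat", "subject:feat",
        "--num-epochs", "5", "--rgat-encoder-type", "ara"]

# Step 1: keep the options the built-in parser knows and collect the rest
gs_args, unknown_args = make_builtin_parser().parse_known_args(argv)
print(unknown_args)  # ['--rgat-encoder-type', 'ara']

# Step 2: a second parser handles only the custom options
cust_parser = argparse.ArgumentParser(description="Customized Arguments")
cust_parser.add_argument("--rgat-encoder-type", type=str, default="ara")
cust_args = cust_parser.parse_args(unknown_args)
print(gs_args.num_epochs, cust_args.rgat_encoder_type)  # 5 ara
```

Because the second parser calls `parse_args()` rather than `parse_known_args()`, a mistyped custom flag such as `--rgat-encdoer-type` still stops the script with an error instead of being silently ignored.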

## GraphStorm `GSConfig` object explanation
Once we have obtained these arguments, we can use them to create a `GSConfig` object and then pass the object to different modules to get related configurations. The `GSConfig` object checks each argument's format and value to ensure compliance with GraphStorm specifications. The cells below show how to create the `GSConfig` object and use it to pass configurations. For example, we can pass the IP list file, the GraphStorm backend, and the local rank configurations to GraphStorm's distributed context initialization function, i.e., `gs.initialize()`, to start the GraphStorm distributed context.

For more details on `GSConfig`, please refer to the [GSConfig API documentation page](https://graphstorm.readthedocs.io/en/latest/api/generated/graphstorm.config.GSConfig.html#graphstorm.config.GSConfig).

In [20]:
# in demo_run_train.py file

import graphstorm as gs
from graphstorm.config import GSConfig

......

def fit(gs_args, cust_args):
    # Utilize GraphStorm's GSConfig class to accept arguments
    config = GSConfig(gs_args)

    # Initialize distributed training and inference context
    gs.initialize(ip_config=config.ip_config, backend=config.backend, local_rank=config.local_rank)
    acm_data = gs.dataloading.GSgnnData(part_config=config.part_config)

    ......

    model = RgatNCModel(g=acm_data.g,
                        num_heads=config.num_heads,
                        num_hid_layers=config.num_layers,
                        node_feat_field=config.node_feat_name,
                        hid_size=config.hidden_size,
                        num_classes=config.num_classes,
                        encoder_type=cust_args.rgat_encoder_type)   # here use the custom argument instead of GSConfig

    ......

In [None]:
# in demo_run_infer.py file

import graphstorm as gs
from graphstorm.config import GSConfig

......

def infer(gs_args, cust_args):
    # Utilize GraphStorm's GSConfig class to accept arguments
    config = GSConfig(gs_args)

    ......

    model = RgatNCModel(g=acm_data.g,
                        num_heads=config.num_heads,
                        num_hid_layers=config.num_layers,
                        node_feat_field=config.node_feat_name,
                        hid_size=config.hidden_size,
                        num_classes=config.num_classes,
                        encoder_type=cust_args.rgat_encoder_type)   # here use the custom argument instead of GSConfig

    model.restore_model(config.restore_model_path)

    ......
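
To make the idea of a validating configuration object concrete, the toy class below wraps raw parsed arguments and checks each value when its property is read. `MiniConfig` is hypothetical and far simpler than `GSConfig`; in particular, the `{ntype: [feat, ...]}` shape it returns for `node_feat_name` is an assumption chosen for illustration, not a statement about the real property.

```python
import argparse


class MiniConfig:
    """Hypothetical, heavily simplified GSConfig-like object.

    Checks each raw argument when its property is read and exposes the
    validated value to model code.
    """

    def __init__(self, args):
        self._args = args

    @property
    def num_epochs(self):
        value = self._args.num_epochs
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"num_epochs must be a positive int, got {value!r}")
        return value

    @property
    def node_feat_name(self):
        # Parse ["ntype:feat", ...] into {ntype: [feat, ...]}
        feats = {}
        for item in self._args.node_feat_name:
            ntype, sep, fname = item.partition(":")
            if not sep or not ntype or not fname:
                raise ValueError(f"expected 'ntype:feat', got {item!r}")
            feats.setdefault(ntype, []).append(fname)
        return feats


args = argparse.Namespace(num_epochs=5,
                          node_feat_name=["paper:feat", "author:feat", "subject:feat"])
config = MiniConfig(args)
print(config.num_epochs)      # 5
print(config.node_feat_name)  # {'paper': ['feat'], 'author': ['feat'], 'subject': ['feat']}
```

Centralizing checks in one object means a malformed value such as `paper` (missing `:feat`) fails with a clear message before it reaches the model constructor.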

## Run CLIs on a Distributed Cluster  

The commands above can be easily modified to run on a [distributed cluster](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/distributed/cluster.html). We need to perform three additional steps:

1. As demonstrated in the [Use Your Own Data tutorial](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html#run-graph-construction), partition the ACM data into multiple partitions, e.g., 2 partitions by setting the argument `--num-parts 2`, and record the path of its JSON file, e.g., `./acm_gs_2p/acm.json`.
2. Follow the [tutorial of creating a GraphStorm cluster](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/distributed/cluster.html#create-a-graphstorm-cluster) to prepare a cluster with 2 machines.
3. Prepare an IP list file, e.g., `ip_list.txt` on the cluster, and record its file path, e.g., `./ip_list.txt`.
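
Step 3 can be scripted as well. The sketch below assumes the IP list file simply contains one private IP address per line; `write_ip_list` is a hypothetical helper, and the addresses are placeholders for the private IPs of your own two machines. Parsing each entry with Python's `ipaddress` module catches typos before the launch command tries to connect.

```python
import ipaddress
from pathlib import Path


def write_ip_list(ips, path="ip_list.txt"):
    """Validate IP addresses and write them to `path`, one per line."""
    # ip_address() raises ValueError for malformed entries
    lines = [str(ipaddress.ip_address(ip.strip())) for ip in ips]
    Path(path).write_text("\n".join(lines) + "\n")
    return lines


# Placeholder private IPs; replace them with your machines' addresses
write_ip_list(["172.31.0.10", "172.31.0.11"])
```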

Then we only need to add two additional launch arguments, `--ip-config` and `--ssh-port`, point `--part-config` to the partitioned data, and run the commands below on the cluster inside a running Docker container.

In [None]:
# CLI for the custom RGAT model training
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_2p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           --ip-config ./ip_list.txt \
           --ssh-port 2222 \
           demo_run_train.py --cf acm_nc.yaml \
                             --save-model-path models/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --num-epochs 5 \
                             --rgat-encoder-type ara

# CLI for the custom RGAT model inference
!python -m graphstorm.run.launch \
           --part-config ./acm_gs_2p/acm.json \
           --num-trainers 4 \
           --num-servers 1 \
           --num-samplers 0 \
           --ip-config ./ip_list.txt \
           --ssh-port 2222 \
           demo_run_infer.py --cf acm_nc.yaml \
                             --restore-model-path models/epoch-4 \
                             --save-prediction-path predictions/ \
                             --save-embed-path embeddings/ \
                             --node-feat-name paper:feat author:feat subject:feat \
                             --rgat-encoder-type ara
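
If you would rather start these jobs from a Python script than from notebook `!` commands, the sketch below assembles the same launch command as an argument list. `build_launch_cmd` is a hypothetical helper that uses only the launch arguments appearing in the commands above; it keeps every launch argument before the script name and the script's own arguments after it, as those commands do.

```python
import shlex
import sys


def build_launch_cmd(part_config, script, script_args,
                     num_trainers=4, num_servers=1, num_samplers=0,
                     ip_config=None, ssh_port=None):
    """Assemble a graphstorm.run.launch command as an argument list."""
    cmd = [sys.executable, "-m", "graphstorm.run.launch",
           "--part-config", part_config,
           "--num-trainers", str(num_trainers),
           "--num-servers", str(num_servers),
           "--num-samplers", str(num_samplers)]
    # Only needed when running on a distributed cluster
    if ip_config is not None:
        cmd += ["--ip-config", ip_config]
    if ssh_port is not None:
        cmd += ["--ssh-port", str(ssh_port)]
    # Launch arguments first, then the script and its own arguments
    return cmd + [script] + list(script_args)


train_cmd = build_launch_cmd(
    "./acm_gs_2p/acm.json", "demo_run_train.py",
    ["--cf", "acm_nc.yaml", "--save-model-path", "models/",
     "--node-feat-name", "paper:feat", "author:feat", "subject:feat",
     "--num-epochs", "5", "--rgat-encoder-type", "ara"],
    ip_config="./ip_list.txt", ssh_port=2222)
print(shlex.join(train_cmd))
# To actually launch: import subprocess; subprocess.run(train_cmd, check=True)
```

Leaving `ip_config` and `ssh_port` unset produces the single-machine variant. Passing a list to `subprocess.run` instead of a single shell string avoids shell quoting issues.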

## Run CLIs on an Amazon SageMaker Cluster

To run the custom model on an Amazon SageMaker cluster, we need to perform four steps:

1. Partition the ACM data into multiple partitions, e.g., 2 partitions, and upload them to an Amazon S3 location, e.g., `s3://<PATH_TO_DATA>/acm_gs_2p`.
2. Upload the configuration yaml file to an Amazon S3 location, e.g., `s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml`.
3. Git clone [GraphStorm source code](https://github.com/awslabs/graphstorm), and move the `demo_run_train.py` and `demo_run_infer.py` files from the `graphstorm/docs/source/api/notebooks/` folder to the `graphstorm/python/graphstorm/` folder.
4. Follow the [Setup GraphStorm SageMaker Docker Image](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/distributed/sagemaker.html#step-1-build-a-sagemaker-compatible-docker-image) tutorial to create a docker image.
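
Step 3 can also be done from Python. The sketch below copies rather than moves the two scripts, which leaves the notebooks folder unchanged while still placing the files where the SageMaker commands expect them. `stage_custom_scripts` is a hypothetical helper, and `graphstorm_root` is the path to your local clone of the repository.

```python
import shutil
from pathlib import Path


def stage_custom_scripts(graphstorm_root,
                         names=("demo_run_train.py", "demo_run_infer.py")):
    """Copy the demo scripts from the notebooks folder into the package folder."""
    root = Path(graphstorm_root)
    src_dir = root / "docs" / "source" / "api" / "notebooks"
    dst_dir = root / "python" / "graphstorm"
    copied = []
    for name in names:
        # copy2 also preserves file metadata such as modification time
        copied.append(Path(shutil.copy2(src_dir / name, dst_dir / name)))
    return copied


# Usage, assuming the repository was cloned into ./graphstorm:
# stage_custom_scripts("graphstorm")
```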

Then use the following SageMaker CLIs to run the custom model on an Amazon SageMaker cluster. Please refer to [GraphStorm Model Training and Inference on SageMaker](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/distributed/sagemaker.html#) for more details.

In [None]:
# SageMaker CLIs should be run under the graphstorm/sagemaker folder
%cd /<path-to-graphstorm>/sagemaker/

# SageMaker CLI for the customized RGAT model training
!python launch/launch_train.py \
        --image-url <AMAZON_ECR_IMAGE_URI> \
        --region <REGION> \
        --entry-point run/train_entry.py \
        --role <ROLE_ARN> \
        --instance-count 2 \
        --graph-data-s3 s3://<PATH_TO_DATA>/acm_gs_2p \
        --yaml-s3 s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml \
        --model-artifact-s3 s3://<PATH_TO_SAVE_TRAINED_MODEL> \
        --graph-name acm \
        --task-type node_classification \
        --custom-script graphstorm/python/graphstorm/demo_run_train.py \
        --node-feat-name paper:feat author:feat subject:feat \
        --num-epochs 5 \
        --rgat-encoder-type ara

# SageMaker CLI for the customized RGAT model inference
!python launch/launch_infer.py \
        --image-url <AMAZON_ECR_IMAGE_URI> \
        --region <REGION> \
        --entry-point run/infer_entry.py \
        --role <ROLE_ARN> \
        --instance-count 2 \
        --graph-data-s3 s3://<PATH_TO_DATA>/acm_gs_2p \
        --yaml-s3 s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml \
        --model-artifact-s3 s3://<PATH_TO_SAVE_BEST_TRAINED_MODEL> \
        --raw-node-mappings-s3 s3://<PATH_TO_DATA>/acm_gs_2p/raw_id_mappings \
        --output-emb-s3 s3://<PATH_TO_SAVE_GENERATED_NODE_EMBEDDING>/ \
        --output-prediction-s3 s3://<PATH_TO_SAVE_PREDICTION_RESULTS> \
        --graph-name acm \
        --task-type node_classification \
        --custom-script graphstorm/python/graphstorm/demo_run_infer.py \
        --node-feat-name paper:feat author:feat subject:feat \
        --rgat-encoder-type ara