## GraphStorm Model Training 


After constructing the binary partitioned graph using GConstruct, you can run GraphStorm locally to train a model that detects fraudulent transactions.

GraphStorm supports model training in multiple environments, including local single-instance training and distributed training on Amazon SageMaker.

For quick proofs of concept and experiments with smaller graphs, you can train a model locally. Production deployments with enterprise-scale graph data can use cluster environments such as Amazon SageMaker, AWS Batch, and Amazon EC2 clusters. This notebook demonstrates local model training on a single instance.

----

### Model training configuration YAML file

GraphStorm accepts [model training and inference configurations](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/configuration-run.html#model-training-and-inference-configurations) through YAML files and CLI arguments.

This allows you to define baseline configurations in a YAML file and use CLI arguments to modify or extend these settings. For this node classification model, we have prepared a baseline configuration file named `ieee_nc.yaml`.

In [None]:
!cat ieee_nc.yaml

The YAML configuration file defines several key configuration categories, including:

- GNN model parameters specify the model architecture through settings like ``model_encoder_type``, ``num_layers``, and ``hidden_size``. In this configuration, we define a two-layer [RGCN model](https://arxiv.org/pdf/1703.06103) with a 128-dimensional hidden layer. The ``node_feat_name`` parameter identifies the input node features for the model.
- Hyperparameter configurations control the core settings of the training pipeline, including the learning rate (``lr``), batch size (``batch_size``), and number of training epochs (``num_epochs``). Here we set the learning rate to 0.001 and the batch size to 1024, and we pass a CLI argument to train the model for 2 epochs.
- Node classification configurations define the required classification parameters, such as the target node type (``target_ntype``), label field (``label_field``), and number of classes (``num_classes``).
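The interplay between the YAML baseline and CLI arguments can be pictured with a small sketch. It assumes, for illustration, that the YAML groups keys into sections under a top-level `gsf` entry and that a CLI flag such as `--num-epochs` maps to the key `num_epochs`. The `target_ntype`, `label_field`, and `num_classes` values are hypothetical placeholders rather than the contents of `ieee_nc.yaml`, and `apply_cli_overrides` is not a GraphStorm function.

```python
# Hypothetical, simplified mirror of a sectioned GraphStorm-style YAML config.
# Values marked "placeholder" are NOT taken from ieee_nc.yaml.
BASELINE = {
    "gsf": {
        "basic": {"model_encoder_type": "rgcn"},
        "gnn": {"num_layers": 2, "hidden_size": 128},
        "hyperparam": {"lr": 0.001, "batch_size": 1024},
        "node_classification": {
            "target_ntype": "Transaction",  # placeholder
            "label_field": "isFraud",       # placeholder
            "num_classes": 2,               # placeholder (binary labels 0/1)
        },
    }
}


def apply_cli_overrides(config, cli_args):
    """Flatten the config sections, then let CLI values override or extend them.

    Illustrative only: assumes "--some-flag" corresponds to the key "some_flag".
    The input config is not modified.
    """
    effective = {}
    for section in config["gsf"].values():
        effective.update(section)
    for flag, value in cli_args.items():
        key = flag.lstrip("-").replace("-", "_")
        effective[key] = value  # replaces an existing key or adds a new one
    return effective


# Baseline run: --num-epochs extends the config with a key this sketch's YAML lacks.
simple = apply_cli_overrides(BASELINE, {"--num-epochs": 2})
# Longer run: --model-encoder-type modifies a value that the YAML already sets.
advanced = apply_cli_overrides(BASELINE, {"--model-encoder-type": "hgt", "--num-epochs": 50})
print(simple["model_encoder_type"], simple["num_epochs"])      # rgcn 2
print(advanced["model_encoder_type"], advanced["num_epochs"])  # hgt 50
```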

To train a model for your own task, you can start from one of the [GraphStorm training configuration examples](https://github.com/awslabs/graphstorm/tree/main/training_scripts), then run a [SageMaker HPO](https://graphstorm.readthedocs.io/en/latest/cli/model-training-inference/distributed/sagemaker.html#launch-hyper-parameter-optimization-task) job to find the best-performing set of hyperparameters for your problem.

### Model Training Commands

With the baseline configuration in place, you can run the node classification command to train an initial model and evaluate its performance. Training for 2 epochs should take around 4 minutes on an `ml.m5.4xlarge` notebook instance.

In [None]:
# Use sys.executable to refer to the gsf kernel's python binary
import sys

PYTHON = sys.executable

In [None]:
!{PYTHON} -m graphstorm.run.gs_node_classification \
           --workspace ./ \
           --part-config ieee_gs/ieee-cis.json \
           --num-trainers 1 \
           --cf ieee_nc.yaml \
           --eval-metric roc_auc \
           --save-model-path ./model-simple/ \
           --topk-model-to-save 1 \
           --num-epochs 2

The best-performing model is saved in the `./model-simple` folder, identified by its epoch number (`epoch-01`). The folder also contains two files that are needed when deploying a real-time inference endpoint: `data_transform_new.json`, which contains information about the graph data structure, and `GRAPHSTORM_RUNTIME_UPDATED_TRAINING_CONFIG.yaml`, which contains information about the GNN model.

During inference, we combine the information from these files to re-create the model and re-apply any data transformations online.
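Before moving on, you could confirm that these artifacts exist with a quick check. This is a sketch: the helper name `check_model_artifacts` is hypothetical, and matching model directories with the pattern `epoch-*` is an assumption based on the `epoch-01` name mentioned for this run.

```python
from pathlib import Path

# File names taken from the description of the ./model-simple folder.
REQUIRED_FILES = (
    "data_transform_new.json",
    "GRAPHSTORM_RUNTIME_UPDATED_TRAINING_CONFIG.yaml",
)


def check_model_artifacts(model_dir):
    """Return (epoch_dirs, missing_files) for a model output folder.

    epoch_dirs lists sub-directories whose names start with "epoch-" (assumed
    naming scheme); missing_files lists required files absent from model_dir.
    """
    root = Path(model_dir)
    epoch_dirs = sorted(p.name for p in root.glob("epoch-*") if p.is_dir())
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    return epoch_dirs, missing


# Usage in this notebook:
# epochs, missing = check_model_artifacts("./model-simple")
# print(epochs, missing)  # hope for at least one epoch directory and no missing files
```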

Finally, we save a small JSON file that records where the model is stored; the following notebook reads it when deploying the model.

In [None]:
import json
import os

with open("task_config.json", "w", encoding="utf-8") as f:
    task_config = {"MODEL_PATH": os.path.abspath("./model-simple")}
    json.dump(task_config, f)
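The following notebook can read the path back from this file. A minimal sketch, assuming the `task_config.json` layout written in the cell above; the helper name `load_model_path` is hypothetical:

```python
import json
from pathlib import Path


def load_model_path(config_file="task_config.json"):
    """Return the model directory recorded in config_file as a Path.

    Raises FileNotFoundError if the recorded directory does not exist.
    """
    with open(config_file, encoding="utf-8") as f:
        task_config = json.load(f)
    model_path = Path(task_config["MODEL_PATH"])
    if not model_path.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_path}")
    return model_path


# Usage in the deployment notebook:
# model_path = load_model_path()
```

Failing early on a missing directory gives a clearer error than a failure deep inside the deployment code.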

----

#### [Optional] Addressing Label Imbalance in training

Recall that in our dataset only 3.5% of transaction labels are positive (1), while the rest are negative (0). Plain accuracy would therefore be a misleading evaluation metric: a majority classifier that always predicts 0 achieves an accuracy of about 96.5%.

To get a fair evaluation, we used the area under the receiver operating characteristic curve ([ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) as the evaluation metric (`--eval-metric roc_auc`). However, this setting only affects evaluation, not the training process.
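The gap between the two metrics is easy to see on synthetic labels with the same 3.5% positive rate. This sketch computes accuracy and ROC AUC from scratch, treating AUC as the probability that a randomly chosen positive scores higher than a randomly chosen negative (ties count one half). The labels are made up for illustration and are not drawn from the dataset.

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the labels."""
    return sum(int(y == p) for y, p in zip(labels, preds)) / len(labels)


def roc_auc(labels, scores):
    """ROC AUC via pairwise comparison of positive and negative scores."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                wins += 1.0
            elif sp == sn:
                wins += 0.5  # a tie is counted as half a win
    return wins / (len(pos) * len(neg))


# Synthetic labels: 35 positives out of 1,000 (3.5%), illustrative only.
labels = [1] * 35 + [0] * 965
majority_preds = [0] * len(labels)     # always predict "not fraud"
constant_scores = [0.0] * len(labels)  # the same score for every transaction

print(accuracy(labels, majority_preds))  # 0.965
print(roc_auc(labels, constant_scores))  # 0.5
```

The majority classifier looks strong on accuracy but scores 0.5 on ROC AUC, the same as random guessing, because it cannot rank fraudulent transactions above legitimate ones.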

To account for the label imbalance during training, we can assign a separate weight to each class, weighting the positive class more heavily so that fraudulent cases carry more importance during training.
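For intuition about the weights `0.1,1.0` passed with `--imbalance-class-weights` in this section's optional command (class 0 first, then class 1), the sketch shows class-weighted cross-entropy and how much of the total weight the positive examples carry. It follows the weighted-mean convention of PyTorch's `CrossEntropyLoss(weight=...)`; whether GraphStorm normalizes its loss exactly this way is an assumption, and both helper names are hypothetical.

```python
import math


def weighted_cross_entropy(probs, labels, class_weights):
    """Class-weighted mean negative log-likelihood.

    probs[i][c] is the predicted probability of class c for example i.
    Computes sum(w[y] * -log p[y]) / sum(w[y]), the weighted-mean convention.
    """
    num = sum(class_weights[y] * -math.log(p[y]) for p, y in zip(probs, labels))
    den = sum(class_weights[y] for y in labels)
    return num / den


def positive_loss_share(pos_rate, class_weights):
    """Fraction of the total class weight carried by positive examples."""
    w_neg, w_pos = class_weights
    return pos_rate * w_pos / (pos_rate * w_pos + (1 - pos_rate) * w_neg)


print(round(positive_loss_share(0.035, (1.0, 1.0)), 3))  # 0.035 (unweighted)
print(round(positive_loss_share(0.035, (0.1, 1.0)), 3))  # 0.266 (weights 0.1,1.0)
```

With per-example losses of similar size, the fraudulent transactions go from about 3.5% to about 27% of the total weight in the loss.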

The updated command below switches the encoder to HGT, applies class weights with `--imbalance-class-weights`, trains for 50 epochs, and saves the top-1 performing model for potential real-time inference deployment.

Since you already have a model that is ready to deploy, you can let this potentially more accurate model train in the background while you work on deploying the initial model. If you want, you can come back and deploy the more accurate model later.
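A `!` cell runs synchronously, so this notebook's kernel stays busy until the optional training command in this section finishes, although you can keep working in another notebook. If you would rather keep this kernel free, one option is to start the same command as a separate process with Python's `subprocess` module. This is a sketch: `launch_in_background` and `ADVANCED_CMD` are hypothetical names, and the flags are copied unchanged from the optional command.

```python
import subprocess
import sys

# Same module and flags as the optional long-running command in this section.
ADVANCED_CMD = [
    sys.executable, "-m", "graphstorm.run.gs_node_classification",
    "--workspace", "./",
    "--part-config", "ieee_gs/ieee-cis.json",
    "--num-trainers", "1",
    "--cf", "ieee_nc.yaml",
    "--model-encoder-type", "hgt",
    "--num-ffn-layers-in-gnn", "2",
    "--num-heads", "16",
    "--hidden-size", "128",
    "--eval-metric", "roc_auc",
    "--imbalance-class-weights", "0.1,1.0",
    "--fanout", "10,10",
    "--save-model-path", "model-advanced/",
    "--topk-model-to-save", "1",
    "--num-epochs", "50",
]


def launch_in_background(cmd, log_path):
    """Start cmd without waiting for it; stdout and stderr go to log_path.

    Returns the subprocess.Popen handle; call .poll() to check whether it is done.
    """
    with open(log_path, "w", encoding="utf-8") as log:
        # The child keeps its own copy of the log file descriptor, so closing
        # the parent's handle when this block exits does not stop the logging.
        return subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)


# Usage (uncomment to start training; follow progress with: !tail -f train-advanced.log)
# proc = launch_in_background(ADVANCED_CMD, "train-advanced.log")
```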

<div style="background-color: #fff8e6; color: #994d00; padding: 10px; border-left: 4px solid #994d00; margin-bottom: 10px;">
<strong>Note:</strong> Depending on the instance type, this training process may take minutes to hours.</div>

In [None]:
# [Optional] Start a long-running training command and move on to the next notebook
!{PYTHON} -m graphstorm.run.gs_node_classification \
           --workspace ./ \
           --part-config ieee_gs/ieee-cis.json \
           --num-trainers 1 \
           --cf ieee_nc.yaml \
           --model-encoder-type hgt \
           --num-ffn-layers-in-gnn 2 \
           --num-heads 16 \
           --hidden-size 128 \
           --eval-metric roc_auc \
           --imbalance-class-weights 0.1,1.0 \
           --fanout 10,10 \
           --save-model-path model-advanced/ \
           --topk-model-to-save 1 \
           --num-epochs 50
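Two settings in this command are easy to sanity-check with arithmetic. In the original HGT formulation, multi-head attention splits the hidden size evenly across heads, so `--hidden-size 128` with `--num-heads 16` leaves 8 dimensions per head; we assume GraphStorm's HGT encoder follows that convention. The two values in `--fanout 10,10` match the two GNN layers set in the YAML, and for a single edge type they cap the sampled neighbors per target node at 10 one-hop and 100 two-hop neighbors. How the fanout applies across multiple edge types depends on GraphStorm's sampler and is not modeled here; the helper names are hypothetical.

```python
def per_head_dim(hidden_size, num_heads):
    """Hidden units per attention head when hidden_size is split evenly."""
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size={hidden_size} is not divisible by num_heads={num_heads}"
        )
    return hidden_size // num_heads


def max_sampled_neighbors(fanouts):
    """Upper bound on sampled neighbors per target node at each hop (one edge type)."""
    bounds, frontier = [], 1
    for fanout in fanouts:
        frontier *= fanout  # each node sampled at this hop samples up to `fanout` more
        bounds.append(frontier)
    return bounds


print(per_head_dim(128, 16))            # 8
print(max_sampled_neighbors([10, 10]))  # [10, 100]
```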

## Next steps

With the artifacts from the simple model available under `./model-simple`, you can proceed to notebook `3-GraphStorm-Endpoint-Deployment.ipynb`, where you will deploy the trained model and prepare for online GNN inference.