# Use GraphStorm CLIs for Multi-task Learning

This notebook demonstrates how to use GraphStorm Command Line Interfaces (CLIs) to run multi-task GNN model training and inference. By working through this notebook, users will become familiar with the GraphStorm CLIs and can then apply them to their own tasks and models.

In this notebook, we will train an RGCN model on the ACM dataset with two training supervisions, i.e., link prediction and node feature reconstruction.
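Training with two supervisions means the model optimizes one objective built from both task losses. The minimal sketch below is an assumption, not GraphStorm's code: it treats the combined objective as a weighted sum of per-task losses, in the spirit of the per-task `task_weight` option described in GraphStorm's multi-task documentation. The function name `combine_task_losses` and the loss values are hypothetical.

```python
# Hypothetical sketch: combine per-task losses into a single training objective.
# Assumption: the overall loss is a weighted sum; task names and loss values are illustrative.

def combine_task_losses(losses, weights=None):
    """Return the weighted sum of per-task losses.

    losses:  dict mapping task name -> scalar loss value
    weights: optional dict mapping task name -> weight (tasks not listed default to 1.0)
    """
    weights = weights or {}
    return sum(weights.get(task, 1.0) * loss for task, loss in losses.items())


step_losses = {"link_prediction": 0.70, "reconstruct_node_feat": 0.25}
print(combine_task_losses(step_losses))                                # both weights default to 1.0
print(combine_task_losses(step_losses, {"reconstruct_node_feat": 0.5}))  # down-weight reconstruction
```

Lowering one task's weight reduces its share of the gradient. This is a common way to balance an auxiliary task, such as feature reconstruction, against the main task.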

**Note:** For more details about multi-task learning, please refer to [Multi-task Learning in GraphStorm](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html).

## 0. Setup environment
First, let's install GraphStorm and its dependencies, including PyTorch and DGL.

In [1]:
!pip install scikit-learn==1.4.2
!pip install scipy==1.13.0
!pip install pandas==1.3.5
!pip install pyarrow==14.0.0
!pip install graphstorm
!pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install dgl==1.1.3 -f https://data.dgl.ai/wheels-internal/repo.html
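Because the install cell pins exact versions, it can help to confirm what actually got installed before going further. The sketch below uses only the standard library's `importlib.metadata`. The helper name `check_pins` is our own. The local-tag handling is an assumption made because CPU wheels may report versions such as `2.1.0+cpu`.

```python
# Sketch: confirm that the packages pinned in the install cell resolved as expected.
from importlib.metadata import PackageNotFoundError, version

PINNED = {
    "scikit-learn": "1.4.2",
    "scipy": "1.13.0",
    "pandas": "1.3.5",
    "pyarrow": "14.0.0",
    "torch": "2.1.0",
    "dgl": "1.1.3",
}


def check_pins(pins):
    """Map each package to (installed_version, matches_pin); installed_version is None if absent."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        # Wheels may carry a local tag such as "+cpu"; compare only the public version part.
        report[name] = (installed, installed.split("+")[0] == expected)
    return report


for pkg, (installed, ok) in check_pins(PINNED).items():
    print(f"{pkg:15s} {installed or 'not installed':20s} {'OK' if ok else 'CHECK'}")
```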

## 1. Create the example ACM graph data
This notebook uses the ACM graph as an example. We use the following script to create the ACM graph data.

In [2]:
!mkdir -p /example
!wget -O /example/acm_data.py https://github.com/awslabs/graphstorm/raw/main/examples/acm_data.py
!python /example/acm_data.py --output-path /example/acm_raw

The ACM graph data includes node files and edge files. It also includes a JSON configuration file describing how to construct a graph for model training. More details can be found in [Use Your Own Data (ACM data example)](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html#use-your-own-data).

In [3]:
!ls -al /example/acm_raw/
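To see which node and edge types the JSON configuration declares, you can load it and list them. This is a minimal sketch under an assumed schema: top-level `nodes` and `edges` lists whose entries carry `node_type` and `relation` keys. Check this against the generated `/example/acm_raw/config.json`. The toy config below is hypothetical and is included only so the snippet runs on its own.

```python
# Sketch: summarize a graph-construction config such as /example/acm_raw/config.json.
# Assumption: top-level "nodes"/"edges" lists with "node_type" and "relation" keys.
import json


def summarize_gconstruct_config(conf):
    """Return (node_types, relations) declared in a construction config dict."""
    node_types = [n["node_type"] for n in conf.get("nodes", [])]
    relations = [tuple(e["relation"]) for e in conf.get("edges", [])]
    return node_types, relations


# In the notebook you would instead run:
#     with open("/example/acm_raw/config.json") as f:
#         conf = json.load(f)
# Hypothetical toy config so the sketch is self-contained:
toy_conf = json.loads("""
{
  "nodes": [{"node_type": "paper"}, {"node_type": "author"}],
  "edges": [{"relation": ["paper", "citing", "paper"]},
            {"relation": ["author", "writing", "paper"]}]
}
""")
ntypes, rels = summarize_gconstruct_config(toy_conf)
print(ntypes)
print(rels)
```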

## 2. Construct and Partition ACM Graph
Since GraphStorm is designed for distributed GNN training, we need to construct a graph and split it into multiple partitions. For simplicity, this example creates a graph with a single partition (no actual splitting).

In [4]:
!python -m graphstorm.gconstruct.construct_graph \
           --conf-file /example/acm_raw/config.json \
           --output-dir /example/acm_gs \
           --num-parts 1 \
           --graph-name acm

The generated ACM graph contains all the information required for GNN model training. For more details on preparing data for multi-task learning, please refer to [Preparing multi-task learning data](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#preparing-the-training-data).
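The file later passed as `--part-config` (`/example/acm_gs/acm.json`) describes the partitioned graph. The sketch below assumes DGL-style keys (`graph_name`, `num_parts`, `ntypes`, `etypes`), which may differ between versions, so verify them against the generated file. The toy dictionary is hypothetical. With `--num-parts 1`, you would expect `num_parts` to be 1.

```python
# Sketch: summarize the partition config passed as --part-config.
# Assumption: DGL-style keys "graph_name", "num_parts", "ntypes" and "etypes".


def summarize_partition_config(part_conf):
    """Return the graph name, partition count, and sorted node/edge type names."""
    return {
        "graph_name": part_conf.get("graph_name"),
        "num_parts": part_conf.get("num_parts"),
        "ntypes": sorted(part_conf.get("ntypes", {})),
        "etypes": sorted(part_conf.get("etypes", {})),
    }


# In the notebook:
#     import json
#     with open("/example/acm_gs/acm.json") as f:
#         print(summarize_partition_config(json.load(f)))
# Hypothetical stand-in so the sketch runs on its own:
toy_part_conf = {
    "graph_name": "acm",
    "num_parts": 1,
    "ntypes": {"author": 0, "paper": 1, "subject": 2},
    "etypes": {"paper:citing:paper": 0, "author:writing:paper": 1},
}
summary = summarize_partition_config(toy_part_conf)
print(summary)
```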

## 3. GNN Model Training
Once the graph is constructed, we can call the GraphStorm multi-task learning CLI to run model training. Before launching training, we need a YAML configuration file for the CLI; here we download the one provided in the GraphStorm examples.

In [5]:
!wget -O /example/acm_mt.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_mt.yaml

In [6]:
!cat /example/acm_mt.yaml

The YAML configuration file defines two training tasks:
 * A link prediction task on the `<paper, citing, paper>` edges. The task-specific settings are under the `gsf::multi_task_learning::link_prediction` configuration block.
 * A node feature reconstruction task on the `paper` nodes, with the node feature `label` to be reconstructed. The task-specific settings are under the `gsf::multi_task_learning::reconstruct_node_feat` configuration block.

For more details on multi-task YAML configuration, please refer to [Define Multi-task for training](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#define-multi-task-for-training).
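Once the YAML file is loaded into a Python dict (for example with PyYAML's `yaml.safe_load`), the two task blocks can be enumerated. In this sketch, the nesting (`gsf` → `multi_task_learning` → list of single-key mappings) follows the block paths above. The field names `target_etype`, `target_ntype`, `reconstruct_nfeat_name` and `task_weight` are assumptions modeled on GraphStorm's multi-task documentation, so compare them with the downloaded `/example/acm_mt.yaml`. The toy dict is hypothetical.

```python
# Sketch: enumerate the tasks of a multi-task config after loading the YAML into a dict.
# Assumption: field names target_etype / target_ntype / reconstruct_nfeat_name / task_weight.


def list_tasks(conf):
    """Return (task_type, target, weight) for every task block."""
    tasks = []
    for entry in conf["gsf"]["multi_task_learning"]:
        # Each list entry is assumed to be a single-key mapping: {task_type: settings}.
        for task_type, settings in entry.items():
            target = settings.get("target_etype") or settings.get("target_ntype")
            tasks.append((task_type, target, settings.get("task_weight", 1.0)))
    return tasks


toy_conf = {
    "gsf": {
        "multi_task_learning": [
            {"link_prediction": {"target_etype": ["paper,citing,paper"], "task_weight": 1.0}},
            {"reconstruct_node_feat": {"target_ntype": "paper",
                                       "reconstruct_nfeat_name": "label",
                                       "task_weight": 1.0}},
        ]
    }
}
for task in list_tasks(toy_conf):
    print(task)
```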

#### Launch the training job

In [7]:
!python -m graphstorm.run.gs_multi_task_learning \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml \
           --num-epochs 4

The trained model checkpoints are saved under `/example/acm_lp/models/`.

In [8]:
!ls -a /example/acm_lp/models/
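When deciding which checkpoint to restore, it helps to list them in epoch order. A plain string sort would place `epoch-10` before `epoch-2`. This sketch assumes each checkpoint is a sub-directory named `epoch-N`, like the `epoch-2` path used for inference. Names with extra suffixes are deliberately ignored. The helper name `list_checkpoints` is our own.

```python
# Sketch: list the epoch checkpoints saved during training, sorted numerically.
# Assumption: each checkpoint is a sub-directory named exactly "epoch-N".
import re
from pathlib import Path


def list_checkpoints(models_dir):
    """Return (epoch, path) pairs for 'epoch-N' sub-directories, sorted by N."""
    ckpts = []
    for path in Path(models_dir).iterdir():
        match = re.fullmatch(r"epoch-(\d+)", path.name)
        if match and path.is_dir():
            ckpts.append((int(match.group(1)), path))
    return sorted(ckpts)


models_dir = Path("/example/acm_lp/models")
if models_dir.is_dir():
    for epoch, path in list_checkpoints(models_dir):
        print(epoch, path)
```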

## 4. GNN Model Inference
Once the model is trained, we can run inference with the trained model artifacts using the GraphStorm multi-task learning CLI. The same YAML configuration file can be reused for inference.

#### Launch the inference job
In this example, we load the `epoch-2` model checkpoint for inference. The inference command reports the test scores for both the link prediction and node feature reconstruction tasks.

In [9]:
!python -m graphstorm.run.gs_multi_task_learning \
           --inference \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --restore-model-path /example/acm_lp/models/epoch-2 \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml
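To evaluate a different checkpoint from `/example/acm_lp/models/`, one option is to build the same inference command from Python and change only `--restore-model-path`. The sketch uses exactly the flags from the cell above. The helper name `mt_inference_cmd` and the `epoch-3` choice are hypothetical; `epoch-3` is assumed to exist. The command is only printed here. Running it is left to `subprocess.run`.

```python
# Sketch: rebuild the multi-task inference command above for an arbitrary checkpoint.
import shlex
import sys


def mt_inference_cmd(checkpoint_dir, workspace="/example",
                     part_config="/example/acm_gs/acm.json",
                     yaml_conf="/example/acm_mt.yaml", num_trainers=1):
    """Return the gs_multi_task_learning inference command as an argv list."""
    return [
        sys.executable, "-m", "graphstorm.run.gs_multi_task_learning",
        "--inference",
        "--workspace", workspace,
        "--part-config", part_config,
        "--restore-model-path", str(checkpoint_dir),
        "--num-trainers", str(num_trainers),
        "--cf", yaml_conf,
    ]


cmd = mt_inference_cmd("/example/acm_lp/models/epoch-3")  # hypothetical checkpoint choice
print(shlex.join(cmd))
# To execute it: import subprocess; subprocess.run(cmd, check=True)
```

Passing an argv list rather than a single shell string avoids quoting problems if a path ever contains spaces.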

#### Launch the embedding generation inference job

You can also use the GraphStorm `gs_gen_node_embedding` CLI to generate node embeddings for the ACM graph with the trained GNN model. The node embeddings are saved under `/example/acm_lp/emb/`.

In [10]:
!python -m graphstorm.run.gs_gen_node_embedding \
           --inference \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --restore-model-path /example/acm_lp/models/epoch-2 \
           --save-embed-path /example/acm_lp/emb/ \
           --restore-model-layers "embed,gnn" \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml

In [11]:
!ls -al /example/acm_lp/emb/
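Before loading the embeddings into downstream code, you may want to see which node types were saved and how many shard files each has. This sketch assumes a layout of `<emb_dir>/<ntype>/embed-*` shard files. Node-ID files such as `embed_nids-*` do not match that pattern. The layout is an assumption, so confirm it against the `ls` output above. Loading the shards themselves (for example with `torch.load` for `.pt` files) is left out so the snippet needs only the standard library.

```python
# Sketch: discover saved embedding shards per node type.
# Assumption: layout <emb_dir>/<ntype>/embed-*; node-ID files named embed_nids-* are skipped.
from pathlib import Path


def find_embedding_files(emb_dir):
    """Map each node-type sub-directory to its sorted embedding shard file names."""
    shards = {}
    for ntype_dir in sorted(p for p in Path(emb_dir).iterdir() if p.is_dir()):
        files = sorted(f.name for f in ntype_dir.glob("embed-*"))
        if files:  # skip directories without embedding shards
            shards[ntype_dir.name] = files
    return shards


emb_dir = Path("/example/acm_lp/emb")
if emb_dir.is_dir():
    for ntype, files in find_embedding_files(emb_dir).items():
        print(ntype, files)
```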