# Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models

This notebook demonstrates how to use GraphStorm APIs to implement GraphStorm built-in GNN models such as RGAT and HGT, for different tasks.

In this notebook, we use different `GSgnnEncoder` modules and set the corresponding arguments in a GNN model, thereby reproducing several GraphStorm built-in GNN models, such as `RGAT` and `HGT`. Using the same pipelines demonstrated in **[Notebook 1: Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)** and **[Notebook 2: Link Prediction Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_2_LP_Pipeline.html)**, users can easily conduct node classification and link prediction tasks on the ACM dataset created in **[Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html)**.

### Prerequisites

- GraphStorm. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).
- ACM data that has been created according to the [Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html), and is stored in the `./acm_gs_1p/` folder.
- Installation of supporting libraries, e.g., matplotlib.

---

## 1. Revisit an `RGCN` model in the `demo_models.py`

Notebook 1 and Notebook 2 both use `RGCN` models that share the same GNN model architecture defined by GraphStorm. To modify a GraphStorm GNN model, let's first revisit an `RGCN` model in the `demo_models.py` file. For simplicity, some docstrings are removed, and the code is restructured to fit in notebook cells.

In [15]:
import graphstorm as gs
from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGCNEncoder, EntityClassifier, ClassifyLossFunc

class RgcnNCModel(GSgnnNodeModel):
    """ A simple RGCN model for node classification using GraphStorm APIs
    """
    def __init__(self, g, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgcnNCModel, self).__init__(alpha_l2norm=0.)

        # extract node feature dimensions
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # set a GNN encoder
        gnn_encoder = RelationalGCNEncoder(g=g, h_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to node classification task
        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize model's optimizer
        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)


### 1.1 GraphStorm built-in model architecture

A GraphStorm built-in model normally contains four modules:

- An input encoder that converts input node features into embeddings of the hidden dimension.
- A GNN encoder that takes the embeddings from the input layer and performs message passing computation.
- A decoder that is task specific, e.g., the `EntityClassifier` for classification tasks.
- A loss function that matches specific tasks, e.g., the `ClassifyLossFunc`.

Besides the four modules, a GraphStorm GNN model also needs to initialize its own optimizer object.
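The way these four modules compose can be sketched without GraphStorm at all. The following is a minimal, illustrative sketch in plain Python: `ToyNodeModel` and the four toy functions are hypothetical stand-ins invented for this illustration, not GraphStorm classes, and the optimizer step is left out.

```python
class ToyNodeModel:
    """Hypothetical stand-in mirroring the four-module layout (not a GraphStorm class)."""

    def __init__(self, input_encoder, gnn_encoder, decoder, loss_func):
        self.input_encoder = input_encoder
        self.gnn_encoder = gnn_encoder
        self.decoder = decoder
        self.loss_func = loss_func

    def forward(self, feats, neighbors, labels):
        # 1. input encoder: raw node features -> hidden embeddings
        hidden = {node: self.input_encoder(feat) for node, feat in feats.items()}
        # 2. GNN encoder: one round of message passing over the neighbors
        hidden = self.gnn_encoder(hidden, neighbors)
        # 3. decoder: embeddings -> task-specific predictions
        preds = {node: self.decoder(emb) for node, emb in hidden.items()}
        # 4. loss function: compare predictions with labels
        return self.loss_func(preds, labels)


def identity_input_encoder(feat):
    # Toy input encoder: copies the features unchanged
    return list(feat)


def mean_gnn_encoder(hidden, neighbors):
    # Toy message passing: average a node's embedding with its neighbors' embeddings
    out = {}
    for node, emb in hidden.items():
        msgs = [emb] + [hidden[nbr] for nbr in neighbors[node]]
        out[node] = [sum(vals) / len(msgs) for vals in zip(*msgs)]
    return out


def argmax_decoder(emb):
    # Toy classifier: the predicted class is the index of the largest value
    return max(range(len(emb)), key=emb.__getitem__)


def zero_one_loss(preds, labels):
    # Fraction of labeled nodes that are misclassified
    wrong = sum(1 for node, label in labels.items() if preds[node] != label)
    return wrong / len(labels)


feats = {"a": [2.0, 0.0], "b": [0.0, 1.0], "c": [0.0, 3.0]}
neighbors = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
labels = {"a": 0, "b": 1, "c": 1}

toy_model = ToyNodeModel(identity_input_encoder, mean_gnn_encoder,
                         argmax_decoder, zero_one_loss)
loss = toy_model.forward(feats, neighbors, labels)
print(loss)
```

Swapping any one of the four callables changes the model without touching the others, which is the property the rest of this notebook relies on.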

### 1.2 Model arguments

Each GNN model may have its own arguments. Some arguments are common across models, such as the input and output dimensions, while others are model specific. For example, an `RGCN` model takes the number of bases, which reduces the number of learnable parameters, and attention-based models may need the number of attention heads. GML tasks also need their own arguments. For example, a classification task may be multi-label.

GraphStorm APIs provide default values for many arguments. For more flexibility, we can expose some arguments in the model's initializer, such as `num_hid_layers` and `hid_size`.
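To make the remark about bases concrete, the sketch below counts relation-weight parameters with and without the standard basis decomposition, in which each relation's weight matrix is a learned combination of a few shared basis matrices. The function name and the relation count of 6 are illustrative assumptions, not values read from GraphStorm or from the ACM graph, and bias terms are ignored.

```python
def rgcn_weight_params(num_rels, in_dim, out_dim, num_bases=None):
    """Count relation-weight parameters of one RGCN layer (hypothetical helper)."""
    if num_bases is None:
        # One full (in_dim x out_dim) weight matrix per relation type
        return num_rels * in_dim * out_dim
    # num_bases shared matrices, plus one mixing coefficient per (relation, basis) pair
    return num_bases * in_dim * out_dim + num_rels * num_bases


full = rgcn_weight_params(num_rels=6, in_dim=128, out_dim=128)
with_bases = rgcn_weight_params(num_rels=6, in_dim=128, out_dim=128, num_bases=2)
print(full, with_bases)
```

Under these assumed sizes, two bases cut the relation weights from 98304 to 32780 parameters, and the saving grows with the number of relation types.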

### 1.3 GML task modules

Besides model-related modules, a GNN model also contains task-specific modules, including task-specific decoders and loss functions. For example, to perform a node classification task, the above `RgcnNCModel` model chooses the `EntityClassifier` as its decoder and uses the `ClassifyLossFunc` as its loss function.
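A `multilabel` flag usually switches a classification loss between two formulations: softmax cross-entropy when each node has exactly one class, and per-class sigmoid binary cross-entropy when a node can carry several labels. The plain-Python sketch below illustrates the two formulas. It assumes this conventional choice and does not reproduce the internals of `ClassifyLossFunc`; both function names are made up for the illustration.

```python
import math


def cross_entropy(logits, label):
    """Single-label loss: negative log-softmax probability of the true class."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def binary_cross_entropy_with_logits(logits, targets):
    """Multi-label loss: mean sigmoid cross-entropy over independent classes."""
    total = 0.0
    for z, t in zip(logits, targets):
        # Stable form of -[t*log(sigmoid(z)) + (1-t)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# Single label: the node belongs to class 0 only
single = cross_entropy([0.0, 0.0], label=0)
# Multi-label: each class is an independent yes/no target
multi = binary_cross_entropy_with_logits([0.0, 0.0], targets=[1.0, 0.0])
print(single, multi)
```

With uninformative logits of zero over two classes, both losses equal `log(2)`, which is a handy sanity check for either formulation.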

----
## 2. Reproduce GraphStorm Built-in GNN Model Variants

Knowing the common architecture and arguments, it is easy to reproduce GraphStorm built-in GNN model variants.

### 2.1 Reproduce an `RGAT` Model for Node Classification

To turn the demo `RgcnNCModel` code into an `RgatNCModel` model, only two modifications are needed:

1. For the GNN encoder, replace the `RelationalGCNEncoder` with the `RelationalGATEncoder`.
2. Add some `RelationalGATEncoder` specific arguments in initialization.

Below is the simplified code of the `RgatNCModel` model. The complete code can be found in the `demo_models.py` file.

In [2]:
from graphstorm.model import RelationalGATEncoder

class RgatNCModel(GSgnnNodeModel):
    """ A simple RGAT model for node classification using GraphStorm APIs
    """
    def __init__(self, g,
                 num_heads,    # an argument specific to RelationalGATEncoder
                 num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgatNCModel, self).__init__(alpha_l2norm=0.)

        # input encoder remains the same ......

        # set a GNN encoder
        gnn_encoder = RelationalGATEncoder(g=g, h_dim=hid_size, out_dim=hid_size,
                                           num_heads=num_heads,    # pass the num_heads to the RelationalGATEncoder
                                           num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # decoder, loss function, and optimizer initialization remain the same ......
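To give some intuition for what `num_heads` controls, the sketch below aggregates neighbor embeddings with several attention heads, each working on its own slice of the hidden vector. It is a conceptual illustration only: it uses scaled dot-product scoring for brevity rather than the additive scoring of GAT, the function names are hypothetical, and the requirement that the hidden size be divisible by the number of heads is a common convention assumed here, not a documented property of `RelationalGATEncoder`.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of floats
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_aggregate(target, neighbors, num_heads):
    """Aggregate neighbor vectors into one vector using num_heads attention heads."""
    hid_size = len(target)
    if hid_size % num_heads != 0:
        raise ValueError("hid_size must be divisible by num_heads")
    head_dim = hid_size // num_heads
    output = []
    for head in range(num_heads):
        lo, hi = head * head_dim, (head + 1) * head_dim
        query = target[lo:hi]
        keys = [nbr[lo:hi] for nbr in neighbors]
        # Score each neighbor against the target within this head's slice
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(head_dim)
                  for key in keys]
        weights = softmax(scores)
        # Weighted sum of neighbor slices; head outputs are concatenated
        for d in range(head_dim):
            output.append(sum(w * key[d] for w, key in zip(weights, keys)))
    return output


target = [1.0, 0.0, 0.0, 1.0]
neighbors = [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
aggregated = multi_head_aggregate(target, neighbors, num_heads=2)
print(aggregated)
```

Each head can weight the neighbors differently, and concatenating the head outputs restores a vector of the original hidden size.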

### 2.2 Reproduce an `HGT` Model with `DistMult` Decoder for Link Prediction

Similar to the `RGAT` variant, replacing the `RelationalGCNEncoder` with the `HGTEncoder` and setting the corresponding arguments reproduces an `HGT` model. In addition, this example replaces the `LinkPredictDotDecoder` decoder with the `LinkPredictDistMultDecoder` and sets its arguments. Below is the simplified code of the `HgtLPModel` model. The complete code can be found in the `demo_models.py` file.

In [3]:
from graphstorm.model import GSgnnLinkPredictionModel, HGTEncoder, LinkPredictDistMultDecoder

class HgtLPModel(GSgnnLinkPredictionModel):
    """ A simple HGT model for link prediction using GraphStorm APIs
    """
    def __init__(self, g,
                 num_heads,    # an argument specific to HGTEncoder
                 num_hid_layers, node_feat_field, hid_size):
        super(HgtLPModel, self).__init__(alpha_l2norm=0.)

        # input encoder remains the same ......

        # set a GNN encoder
        gnn_encoder = HGTEncoder(g=g,
                                 num_heads=num_heads,    # pass the num_heads to the HGTEncoder
                                 hid_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to link prediction task
        decoder = LinkPredictDistMultDecoder(etypes=g.canonical_etypes,    # specifically added for the LinkPredictDistMultDecoder
                                             h_dim=hid_size)
        self.set_decoder(decoder)

        # loss function, and optimizer initialization remain the same ......
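The difference between the two decoders comes down to the scoring function. A dot-product decoder scores an edge from the two node embeddings alone, while DistMult also multiplies in one learned embedding per edge type, which is why the decoder above needs the list of edge types. The sketch below shows both scores in plain Python; the function names and the relation embedding values are made up for illustration and are not taken from GraphStorm.

```python
def dot_score(head, tail):
    # Dot product: the same scoring for every edge type
    return sum(h * t for h, t in zip(head, tail))


def distmult_score(head, rel, tail):
    # DistMult: element-wise product with a per-edge-type relation embedding
    return sum(h * r * t for h, r, t in zip(head, rel, tail))


# One relation embedding per canonical edge type (illustrative values)
rel_embs = {
    ("paper", "citing", "paper"): [1.0, 1.0],
    ("paper", "cited", "paper"): [0.5, -1.0],
}

head, tail = [1.0, 2.0], [3.0, 4.0]
plain = dot_score(head, tail)
citing = distmult_score(head, rel_embs[("paper", "citing", "paper")], tail)
cited = distmult_score(head, rel_embs[("paper", "cited", "paper")], tail)
print(plain, citing, cited)
```

With an all-ones relation embedding, DistMult reduces to the dot product, so the relation embeddings are exactly what lets the same pair of nodes score differently under different edge types.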

----

## 3. Link Prediction Pipeline Using the `HGT` Model

To use the GNN model variants above, the overall GML pipeline needs only a few modifications to accommodate model-specific arguments. The example below reuses the link prediction pipeline of **Notebook 2**. For simplicity, this example combines multiple cells and drops most comments.

### 3.1 Training pipeline

In [16]:
import logging
import graphstorm as gs

logging.basicConfig(level=logging.INFO)    # logging.INFO == 20
gs.initialize()

nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}

acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)

train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[20, 20],
    num_negative_edges=10,
    node_feats=nfeats_4_modeling,
    batch_size=64,
    exclude_training_targets=True,
    reverse_edge_types_map={("paper", "citing", "paper"):("paper","cited","paper")},
    train_task=True)
val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

from demo_models import HgtLPModel    # Import the HGT model variant

model = HgtLPModel(g=acm_data.g,
                   num_heads=8,
                   num_hid_layers=2,
                   node_feat_field=nfeats_4_modeling,
                   hid_size=128)

evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)

trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())

trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/',
            save_model_frequency=1000,
            use_mini_batch_infer=True)

### 3.2 Visualize Model Performance History

In [18]:
import matplotlib.pyplot as plt

val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['mrr'])
    test_metrics.append(test_metric['mrr'])

fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Epoch', ylabel='MRR')
ax.legend(loc='best')
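For reference, the `mrr` metric plotted above is the mean reciprocal rank: each positive edge is ranked against its sampled negative edges by score, and the reciprocals of those ranks are averaged. The sketch below computes it in plain Python. The function name is hypothetical, and ranking ties pessimistically is an assumption of this sketch, not a statement about how the GraphStorm evaluator breaks ties.

```python
def mean_reciprocal_rank(pos_scores, neg_scores):
    """pos_scores[i] is one positive edge's score; neg_scores[i] lists its negatives' scores."""
    reciprocal_ranks = []
    for pos, negs in zip(pos_scores, neg_scores):
        # Rank 1 means the positive edge outscored every negative (ties count against it)
        rank = 1 + sum(1 for neg in negs if neg >= pos)
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


# Two positive edges, each with two sampled negatives
pos_scores = [0.9, 0.2]
neg_scores = [[0.1, 0.3], [0.5, 0.1]]
mrr_value = mean_reciprocal_rank(pos_scores, neg_scores)
print(mrr_value)
```

Here the first positive edge ranks first and the second ranks second, giving `(1 + 1/2) / 2 = 0.75`. With 100 negatives per positive, as in the validation and test dataloaders, the worst possible reciprocal rank for a single edge is `1/101`.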

### 3.3 Inference pipeline

In [19]:
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)

model.restore_model(best_model_path)

infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

infer = gs.inference.GSgnnLinkPredictionInferrer(model)

infer.infer(acm_data,
            infer_dataloader,
            save_embed_path='infer/embeddings',
            use_mini_batch_infer=True)