# Notebook 4: Use GraphStorm APIs for Customizing Model Components

This notebook shows how to customize components of GML models to fit specific requirements. Customized models should extend GraphStorm's higher-level APIs, which lets them implement their own functionality while still integrating easily with GraphStorm's training and inference pipelines.

----

### An Example of a Customized Model

A widely used GNN model is the `RGAT` model, proposed in [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811). The original `RGAT` model accounts for the different importance of a node's neighbors and uses an attention mechanism to aggregate messages from neighbors within the same relation type. The aggregations from different relation types are then summed to form the node's output representation,

$$h_i = \sum_{r\in \mathscr{R}}\sum_{j\in \mathcal{N}^{r}_{(i)}} \alpha^{r}_{i,j} W^{r} h_j^{r}.$$
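The double sum can be made concrete with a tiny hand-worked example. The sketch below is a plain-Python illustration, not GraphStorm code: `rgat_sum` is a hypothetical helper, the attention coefficients $\alpha^{r}_{i,j}$ are taken as given (in RGAT they come from the attention mechanism), and all vectors and matrices are made-up toy values.

```python
def rgat_sum(neighbors_by_rel, dim):
    """Compute h_i = sum_r sum_j alpha^r_ij * W^r h_j for one target node i.

    neighbors_by_rel maps a relation name to (W, [(alpha, h_j), ...]),
    where W is a dense matrix given as a list of rows and h_j is a list of floats.
    """
    h_i = [0.0] * dim
    for W, neighbors in neighbors_by_rel.values():
        for alpha, h_j in neighbors:
            # relation-specific transform W^r h_j (matrix-vector product)
            msg = [sum(w * x for w, x in zip(row, h_j)) for row in W]
            # weight by the attention coefficient and accumulate (plain summation over relations)
            h_i = [acc + alpha * m for acc, m in zip(h_i, msg)]
    return h_i


toy = {
    'cites':  ([[1, 0], [0, 1]], [(0.5, [1.0, 0.0]), (0.5, [0.0, 1.0])]),
    'writes': ([[2, 0], [0, 2]], [(1.0, [1.0, 1.0])]),
}
print(rgat_sum(toy, dim=2))  # [2.5, 2.5]
```

The `cites` relation contributes `[0.5, 0.5]` and `writes` contributes `[2.0, 2.0]`; RGAT simply adds them, regardless of which relation is more informative for node $i$.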

An alternative way to aggregate representations across relation types is to use attention instead of summation. An additional learnable weight vector $\phi$ computes the attention coefficient of each relation type for node $i$,

$$\beta^r_i = \dfrac{\exp(h^{r}_i \cdot \phi)}{\sum_{k \in \mathscr{R}} \exp(h^{k}_i \cdot \phi)},$$

and the output is the weighted sum of the per-relation aggregations,

$$h_i = \sum_{r \in \mathscr{R}}{\beta^r_i \times h_i^{r}},$$

where

$$h_i^{r} = \mathrm{LeakyReLU}\Big(\sum_{j\in \mathcal{N}^{r}_{(i)}} \alpha^{r}_{i,j} W^{r} h_j^{r}\Big).$$
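The across-relation attention for a single node $i$ can likewise be sketched in plain Python. `across_relation_attention` is a hypothetical helper; it assumes the per-relation aggregations $h_i^r$ are already computed (LeakyReLU included), and it subtracts the maximum score before exponentiating, which leaves the softmax $\beta^r_i$ unchanged but avoids overflow.

```python
import math


def across_relation_attention(h_rel, phi):
    """Return (beta, h_i) for one node.

    h_rel: list of per-relation vectors h_i^r, one per relation type.
    phi:   attention weight vector with the same length as each h_i^r.
    """
    # attention score h_i^r . phi for every relation type
    scores = [sum(a * b for a, b in zip(h_r, phi)) for h_r in h_rel]
    # softmax over relation types (shifted by the max for numerical stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    beta = [e / z for e in exps]
    # weighted sum of the per-relation vectors
    h_i = [sum(b * h_r[d] for b, h_r in zip(beta, h_rel)) for d in range(len(phi))]
    return beta, h_i


beta, h_i = across_relation_attention([[1.0, 0.0], [0.0, 1.0]], phi=[1.0, 0.0])
print(beta)  # the first relation gets the larger weight, about 0.731 vs 0.269
```

Unlike plain summation, the coefficients always sum to 1, so $h_i$ stays on the scale of a single relation's aggregation and the model can learn which relation types matter for each node.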

In this notebook, we will implement this Across Relation Attention GAT (`ARA_GAT`), fit it into the GraphStorm model architecture, and run training and inference with the existing pipelines.

### Prerequisites

- GraphStorm. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).
- ACM data that has been created according to **[Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html)**, and is stored in the `./acm_gs_1p/` folder.
- Installation of supporting libraries, e.g., matplotlib.

## 1. Recap GraphStorm Model Architecture

As explained in **[Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_3_Model_Variants.html)**, a GraphStorm model normally contains four modules:

- An input encoder that converts input node features into embeddings of the hidden dimension.
- A GNN encoder that takes the embeddings from the input encoder and performs message passing computation.
- A task-specific decoder, e.g., the `EntityClassifier` for classification tasks.
- A loss function that matches the task, e.g., the `ClassifyLossFunc`.

Given this architecture, we only need to build a **customized GNN encoder** that implements the `ARA_GAT` variant, and can leave the other modules untouched.
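The division of labor can be illustrated with a framework-free sketch of one forward pass. The function `forward_pass` and the toy lambdas are hypothetical stand-ins, not GraphStorm APIs; the point is only that swapping the GNN encoder leaves the other three modules unchanged.

```python
def forward_pass(input_encoder, gnn_encoder, decoder, loss_fn, blocks, feats, labels):
    """Hypothetical composition of the four modules for one mini-batch."""
    emb = input_encoder(feats)      # input features -> hidden embeddings
    h = gnn_encoder(blocks, emb)    # message passing over the sampled blocks
    logits = decoder(h)             # task-specific prediction head
    return loss_fn(logits, labels)  # task-specific loss


# toy stand-ins operating on {node_type: [float, ...]} dicts
input_encoder = lambda feats: {nt: [2.0 * x for x in v] for nt, v in feats.items()}
identity_encoder = lambda blocks, h: h  # placeholder "GNN" that passes embeddings through
shift_encoder = lambda blocks, h: {nt: [x + 1.0 for x in v] for nt, v in h.items()}
decoder = lambda h: h['paper']
loss_fn = lambda logits, labels: sum((p - y) ** 2 for p, y in zip(logits, labels))

feats, labels = {'paper': [0.5, 1.0]}, [1.0, 2.0]
# only the GNN encoder argument differs between the two calls
loss_a = forward_pass(input_encoder, identity_encoder, decoder, loss_fn, None, feats, labels)
loss_b = forward_pass(input_encoder, shift_encoder, decoder, loss_fn, None, feats, labels)
print(loss_a, loss_b)  # 0.0 2.0
```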

## 2. ARA_GAT Variant Encoder Implementation

To build a customized GNN encoder, we can refer to the implementation of GraphStorm's built-in GNN encoders, such as `graphstorm.model.RelationalGATEncoder`, which extends `graphstorm.model.GraphConvEncoder` and implements the required `forward()` method.

The code in the cells below includes a layer module named `Ara_GatLayer`, which implements one ARA_GAT layer, and an encoder module named `Ara_GatEncoder`, which extends `graphstorm.model.gnn_encoder_base.GraphConvEncoder` and stacks `Ara_GatLayer` instances.

### 2.1 `Ara_GatLayer` Implementation

```python
import dgl
import torch as th
import torch.nn as nn

class Ara_GatLayer(nn.Module):
    """ One layer of ARA_GAT
    """
    def __init__(self, in_dim, out_dim, num_heads, rel_names, bias=True,
                 activation=None, self_loop=False, dropout=0.0, norm=None):
        super(Ara_GatLayer, self).__init__()
        # each head outputs out_dim // num_heads features; heads are concatenated back to out_dim
        assert out_dim % num_heads == 0, 'out_dim must be divisible by num_heads'
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.leaky_relu = nn.LeakyReLU(0.2)

        # GAT module for each relation type
        self.rel_gats = nn.ModuleDict()
        for rel in rel_names:
            self.rel_gats[str(rel)] = dgl.nn.GATConv(in_dim, out_dim//num_heads,
                                                     num_heads, allow_zero_in_degree=True)

        # across-relation attention weight set (phi in the formulas above)
        self.acr_attn_weights = nn.Parameter(th.Tensor(out_dim, 1))
        nn.init.normal_(self.acr_attn_weights)

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_dim))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_dim, out_dim))
            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))

        # dropout
        self.dropout = nn.Dropout(dropout)

        # normalization for each node type
        ntypes = set()
        for rel in rel_names:
            ntypes.add(rel[0])
            ntypes.add(rel[2])

        # normalization layers are modules, so they are registered in a ModuleDict
        if norm == "batch":
            self.norm = nn.ModuleDict({ntype: nn.BatchNorm1d(out_dim) for ntype in ntypes})
        elif norm == "layer":
            self.norm = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})
        else:
            self.norm = None

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGL.block
            A DGL block.
        inputs : dict[str, torch.Tensor]
            Node features for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()

        # loop over each edge type to run GAT within that edge type
        for src_type, e_type, dst_type in g.canonical_etypes:

            # extract the subgraph of this edge type
            sub_graph = g[src_type, e_type, dst_type]

            # skip edge types that have no edges in this block
            if sub_graph.num_edges() == 0:
                continue

            # extract source and destination node features;
            # in a block, destination nodes are the first num_dst_nodes() source nodes
            src_feat = inputs[src_type]
            dst_feat = inputs[dst_type][:sub_graph.num_dst_nodes()]

            # GAT within one relation type: (num_dst, num_heads, out_dim // num_heads)
            agg = self.rel_gats[str((src_type, e_type, dst_type))](sub_graph, (src_feat, dst_feat))
            # concatenate the heads: (num_dst, out_dim)
            agg = agg.view(agg.shape[0], -1)

            # store h_i^r = LeakyReLU(aggregation) on the destination nodes
            sub_graph.dstdata['agg_' + str((src_type, e_type, dst_type))] = self.leaky_relu(agg)

        h = {}
        for n_type in g.dsttypes:
            num_dst = g.num_dst_nodes(n_type)
            if num_dst == 0:
                continue

            # collect the per-relation aggregations stored on this node type
            agg_list = [v for k, v in g.dstnodes[n_type].data.items() if k.startswith('agg_')]

            h_n = None
            if agg_list:
                # (num_dst, num_rels, out_dim)
                acr_agg = th.stack(agg_list, dim=1)
                # attention scores h_i^r . phi: (num_dst, num_rels, 1)
                acr_att = th.matmul(acr_agg, self.acr_attn_weights)
                # beta_i^r: softmax over the relation dimension
                acr_sfm = th.softmax(acr_att, dim=1)
                # cross-relation weighted aggregation: (num_dst, out_dim)
                h_n = (acr_agg * acr_sfm).sum(dim=1)
            elif not self.self_loop:
                raise ValueError(f'Some nodes of the {n_type} type have no in-degree. '
                                 'Please check the data or set "self_loop=True".')

            # add the self-loop term, or use it alone if no relation reached this node type
            if self.self_loop:
                loop_h = th.matmul(inputs[n_type][:num_dst], self.loop_weight)
                h_n = loop_h if h_n is None else h_n + loop_h

            if self.bias:
                h_n = h_n + self.h_bias
            if self.activation:
                h_n = self.activation(h_n)
            if self.norm:
                h_n = self.norm[n_type](h_n)
            h_n = self.dropout(h_n)

            h[n_type] = h_n

        return h
```
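Both `dst_feat` and the self-loop term slice the input features with `[:num_dst_nodes]`. This relies on the DGL block convention that, when blocks are built by `dgl.to_block` with its default `include_dst_in_src=True`, destination nodes are placed first among the source nodes; the sketch assumes the sampler used here produces blocks that way. `self_loop_term` is a hypothetical plain-Python stand-in for `th.matmul(inputs[n_type][:num_dst], self.loop_weight)`.

```python
def self_loop_term(src_feats, num_dst_nodes, loop_weight):
    """Multiply the destination-node rows of src_feats by loop_weight.

    src_feats:   list of rows, one per source node of the block (in_dim floats each).
    loop_weight: in_dim x out_dim matrix given as a list of rows.
    """
    # destination nodes are the first num_dst_nodes source nodes of the block
    dst_feats = src_feats[:num_dst_nodes]
    columns = list(zip(*loop_weight))  # out_dim columns, each with in_dim entries
    return [[sum(x * w for x, w in zip(row, col)) for col in columns] for row in dst_feats]


src_feats = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # third row: a source-only neighbor
loop_weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(self_loop_term(src_feats, 2, loop_weight))  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```

The neighbor-only row is ignored, so the output has exactly one row per destination node and can be added to the across-relation aggregation.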


### 2.2 `Ara_GatEncoder` Implementation

Here, we implement the `Ara_GatEncoder` by extending the base GraphStorm encoder, `graphstorm.model.gnn_encoder_base.GraphConvEncoder`, and implementing the `forward(self, blocks, h)` function, which makes this class compatible with the GraphStorm model architecture. The `forward()` function takes a list of DGL blocks and a dictionary of node representations as input arguments, and returns a dictionary that contains the new node representations. GraphStorm model classes call this function inside their own `forward()` function.

```python
from graphstorm.model.gnn_encoder_base import GraphConvEncoder
import torch.nn.functional as F

class Ara_GatEncoder(GraphConvEncoder):
    """ Across Relation Attention GAT Encoder by extending GraphStorm APIs
    """
    def __init__(self, g, h_dim, out_dim, num_heads, num_hidden_layers=1,
                 dropout=0, use_self_loop=True, norm='batch'):
        super(Ara_GatEncoder, self).__init__(h_dim, out_dim, num_hidden_layers)

        # h2h: hidden-to-hidden layers
        for _ in range(num_hidden_layers):
            self.layers.append(Ara_GatLayer(h_dim, h_dim, num_heads, g.canonical_etypes,
                                            activation=F.relu, self_loop=use_self_loop, dropout=dropout, norm=norm))
        # h2o: hidden-to-output layer
        self.layers.append(Ara_GatLayer(h_dim, out_dim, num_heads, g.canonical_etypes,
                                        activation=F.relu, self_loop=use_self_loop, norm=norm))

    def forward(self, blocks, h):
        """ Accept a list of DGL blocks and a dict of node features as inputs,
        and return a dict of new node features. One layer is applied per block.
        """
        for layer, block in zip(self.layers, blocks):
            h = layer(block, h)
        return h
```
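The `forward()` contract can be illustrated without DGL: each layer consumes one block and maps a `{node_type: features}` dict to a new dict that covers only that block's destination nodes. In this hypothetical sketch, a "block" is just a dict of destination-node counts and `toy_layer` stands in for `Ara_GatLayer`; the explicit length check is an addition, since `zip()` in the encoder above would silently ignore surplus layers or blocks.

```python
def encoder_forward(layers, blocks, h):
    """Apply one layer per block, passing a {node_type: features} dict along."""
    # zip() would silently drop extra layers or blocks, so check explicitly
    if len(layers) != len(blocks):
        raise ValueError(f'{len(layers)} layers but {len(blocks)} blocks')
    for layer, block in zip(layers, blocks):
        h = layer(block, h)
    return h


def toy_layer(block, h):
    """Hypothetical layer: keep the destination rows and add 1 to each value."""
    return {nt: [x + 1 for x in feats[:block[nt]]] for nt, feats in h.items() if nt in block}


blocks = [{'paper': 3}, {'paper': 1}]  # toy blocks: number of destination nodes per type
out = encoder_forward([toy_layer, toy_layer], blocks, {'paper': [0, 1, 2, 3, 4]})
print(out)  # {'paper': [2]}
```

The node set shrinks block by block until only the target (seed) nodes remain, which is why the encoder output can be passed straight to the decoder.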

## 3. Build a Node Classification Model based on the `Ara_GatEncoder`

The `RgatNCModel` below follows the same node classification model architecture used in **[Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)**. For the GNN encoder component, this model provides the option to use either the `Ara_GatEncoder` or GraphStorm's built-in `RelationalGATEncoder`.

```python
import graphstorm as gs
from graphstorm.model import (GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGATEncoder,
                              EntityClassifier, ClassifyLossFunc)

class RgatNCModel(GSgnnNodeModel):
    """ A customized RGAT model for node classification using GraphStorm APIs
    """
    def __init__(self, g, num_heads, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False,
                 encoder_type='ara'    # 'ara' for Ara_GatEncoder, 'rgat' for the built-in RelationalGATEncoder
                ):
        super(RgatNCModel, self).__init__(alpha_l2norm=0.)

        # extract the input feature size of each node type
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input layer encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # use either the customized ARA_GAT encoder or the built-in RGAT encoder
        if encoder_type == 'ara':
            gnn_encoder = Ara_GatEncoder(g, hid_size, hid_size, num_heads,
                                         num_hidden_layers=num_hid_layers-1)
        elif encoder_type == 'rgat':
            gnn_encoder = RelationalGATEncoder(g, hid_size, hid_size, num_heads,
                                               num_hidden_layers=num_hid_layers-1)
        else:
            raise ValueError(f'Unsupported encoder type "{encoder_type}". Use "ara" or "rgat".')
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to the node classification task
        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize the model's optimizer
        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)
```
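`RgatNCModel` passes `num_hid_layers - 1` as `num_hidden_layers`, and the encoder appends one output layer, so the model has `num_hid_layers` GNN layers in total. Each layer consumes one block, so, assuming the data loader samples one block per `fanout` entry, the fanout list should have the same length; `Ara_GatLayer` also requires `hid_size` to be divisible by `num_heads`. The helper below is a hypothetical pre-flight check, not part of GraphStorm.

```python
def check_model_config(num_hid_layers, hid_size, num_heads, fanout):
    """Hypothetical sanity check for RgatNCModel-style arguments."""
    num_hidden_layers = num_hid_layers - 1   # value passed to the encoder
    num_layers = num_hidden_layers + 1       # h2h layers plus the final h2o layer
    if num_layers != len(fanout):
        raise ValueError(f'{num_layers} GNN layers but {len(fanout)} fanout entries')
    if hid_size % num_heads != 0:
        raise ValueError(f'hid_size={hid_size} is not divisible by num_heads={num_heads}')
    # number of layers and per-head feature size
    return num_layers, hid_size // num_heads


print(check_model_config(2, 128, 8, [20, 20]))  # (2, 16)
```

With the settings used in the training cell (2 layers, 128 hidden units, 8 heads, fanout `[20, 20]`), each head produces 16 features.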

## 4. Node Classification Pipeline Using the `RgatNCModel`

The overall pipeline for using the customized model in node classification tasks is identical to the one in **[Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)**.

### 4.1 Training pipeline

```python
import logging
logging.basicConfig(level=logging.INFO)
import graphstorm as gs
gs.initialize()

acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json')

nfeats_4_modeling = {'author':['feat'], 'paper':['feat'], 'subject':['feat']}

train_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_train_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[20, 20],
    batch_size=64,
    train_task=True)
val_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_val_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)
test_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_test_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)

model = RgatNCModel(g=acm_data.g, num_heads=8, num_hid_layers=2, node_feat_field=nfeats_4_modeling,
                    hid_size=128, num_classes=14, encoder_type='ara')

evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100)

trainer = gs.trainer.GSgnnNodePredictionTrainer(model)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())

trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/')
```

### 4.2 Visualize Model Performance History

```python
import matplotlib.pyplot as plt

# each history entry holds the (validation, test) metrics of one evaluation
val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['accuracy'])
    test_metrics.append(test_metric['accuracy'])

fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Evaluation round', ylabel='Accuracy')
ax.legend(loc='best')
plt.show()
```
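Besides plotting, a common way to summarize such a history is to report the test accuracy at the evaluation with the best validation accuracy, so that model selection never looks at test scores. The sketch assumes each history entry is a `(val_metrics, test_metrics)` pair of dicts, as the loop above unpacks them; `best_eval_point` is a hypothetical helper and the toy history holds made-up numbers, not results from this notebook.

```python
def best_eval_point(history, metric='accuracy'):
    """Return (index, val score, test score) of the evaluation with the best val score."""
    # max() returns the first index on ties
    best_idx = max(range(len(history)), key=lambda i: history[i][0][metric])
    val_metric, test_metric = history[best_idx]
    return best_idx, val_metric[metric], test_metric[metric]


# toy history for illustration only
toy_history = [({'accuracy': 0.60}, {'accuracy': 0.58}),
               ({'accuracy': 0.70}, {'accuracy': 0.66}),
               ({'accuracy': 0.68}, {'accuracy': 0.69})]
print(best_eval_point(toy_history))  # (1, 0.7, 0.66)
```

After training, the same call could be applied to `trainer.evaluator.history`.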

### 4.3 Inference pipeline

```python
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)

model.restore_model(best_model_path)

infer_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_test_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)

infer = gs.inference.GSgnnNodePredictionInferrer(model)

infer.infer(infer_dataloader,
            save_embed_path='infer/embeddings',
            save_prediction_path='infer/predictions',
            use_mini_batch_infer=True)
```