# Chapter 5: Training Graph Neural Networks

### Overview

The following assumes that the graph(s) and node/edge features are already prepared.

In [None]:
import dgl 
dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0]

### Heterogeneous Graphs

Sometimes we would like to work with heterogeneous graphs. Here we take a synthetic heterogeneous graph as an example for demonstrating node classification, edge classification, and link prediction tasks. 

The synthetic heterogeneous graph `hetero_graph` has these edge types:

- `('user', 'follow', 'user')`
- `('user', 'followed-by', 'user')`
- `('user', 'click', 'item')`
- `('item', 'clicked-by', 'user')`
- `('user', 'dislike', 'item')`
- `('item', 'disliked-by', 'user')`

In [None]:
import numpy as np
import torch

# Set seeds 
np.random.seed(1775)
torch.manual_seed(1775)

n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src), 
    ('user', 'click', 'item'): (click_src, click_dst), 
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst), 
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src),
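})

# Synthetic node features, user labels, click-count labels on 'click' edges,
# and random training masks (about 60% True) used later in this chapter
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
hetero_graph.nodes['user'].data['train_mask'] = torch.rand(n_users) < 0.6
hetero_graph.edges['click'].data['train_mask'] = torch.rand(n_clicks) < 0.6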

## 5.1 Node Classification/Regression

### Writing NN Model

DGL provides a few built-in graph convolution modules, each of which performs one round of message passing. Here we use `dgl.nn.pytorch.SAGEConv`, the graph convolution module for GraphSAGE.

A deep model on graphs needs multiple rounds of message passing, so we build a multi-layer graph neural network by stacking graph convolution modules:

In [None]:
# Construct a two-layer GNN model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean'
        )
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean'
        )
        
    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h
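
Stacking two layers means each node's output depends on its 2-hop neighborhood. The NumPy sketch below makes this concrete with a GraphSAGE-style mean update, h_v' = h_v W_self + mean(h_u over edges u -> v) W_neigh. It is a simplified stand-in for `SAGEConv` (bias and DGL's other options are omitted), and the helper names `sage_mean_layer` and `two_layer_sage` are hypothetical.

```python
import numpy as np

def sage_mean_layer(src, dst, h, w_self, w_neigh):
    """One GraphSAGE-style mean layer (bias omitted):
    h_v' = h_v @ w_self + (mean over edges u -> v of h_u) @ w_neigh."""
    n = h.shape[0]
    agg = np.zeros_like(h)
    np.add.at(agg, dst, h[src])              # sum incoming messages per destination
    deg = np.bincount(dst, minlength=n)      # in-degree of every node
    agg = agg / np.maximum(deg, 1)[:, None]  # mean; nodes without in-edges keep zeros
    return h @ w_self + agg @ w_neigh

def two_layer_sage(src, dst, x, params):
    (ws1, wn1), (ws2, wn2) = params
    h = np.maximum(sage_mean_layer(src, dst, x, ws1, wn1), 0.0)  # ReLU, as in SAGE.forward
    return sage_mean_layer(src, dst, h, ws2, wn2)

# Path graph 0 -> 1 -> 2: node 0 is two hops away from node 2.
rng = np.random.default_rng(0)
src, dst = np.array([0, 1]), np.array([1, 2])
x = rng.uniform(size=(3, 4))
params = [(rng.uniform(size=(4, 8)), rng.uniform(size=(4, 8))),
          (rng.uniform(size=(8, 2)), rng.uniform(size=(8, 2)))]
x_perturbed = x.copy()
x_perturbed[0] += 1.0  # change only node 0's input feature

one_layer_sees_node0 = not np.allclose(
    sage_mean_layer(src, dst, x, *params[0])[2],
    sage_mean_layer(src, dst, x_perturbed, *params[0])[2])
two_layers_see_node0 = not np.allclose(
    two_layer_sage(src, dst, x, params)[2],
    two_layer_sage(src, dst, x_perturbed, params)[2])
print(one_layer_sees_node0, two_layers_see_node0)  # False True
```

Non-negative random inputs and weights are used so that the ReLU cannot mask the change propagated from node 0.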
    

### Training Loop

Training on the full graph simply involves a forward pass of the model defined above, followed by computing the loss between the predictions and the ground-truth labels on the training nodes.

This section uses the DGL built-in dataset `dgl.data.CiteseerGraphDataset` to show a training loop. The node features and labels are stored on its graph instance, and the training/validation/test split is also stored on the graph as boolean masks.
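
Both the training loss and the validation accuracy act only on the nodes selected by a boolean mask. A plain NumPy sketch of the two computations follows; the helpers `masked_cross_entropy` and `masked_accuracy` are hypothetical, and in the PyTorch code `F.cross_entropy` (with its default mean reduction) plays the role of the first.

```python
import numpy as np

def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the nodes where mask is True."""
    z = logits[mask]
    y = labels[mask]
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(y)), y].mean())

def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose arg-max class matches the label."""
    return float((logits[mask].argmax(axis=1) == labels[mask]).mean())

# Toy data: 4 nodes, 2 classes; names are prefixed to avoid clobbering notebook variables
toy_logits = np.array([[2.0, 0.0], [0.0, 2.0], [5.0, -5.0], [0.3, 0.7]])
toy_labels = np.array([0, 1, 1, 1])
toy_train_mask = np.array([True, True, False, False])
toy_valid_mask = np.array([False, False, True, True])

print(masked_cross_entropy(toy_logits, toy_labels, toy_train_mask))  # log(1 + e^-2), about 0.1269
print(masked_accuracy(toy_logits, toy_labels, toy_valid_mask))       # 0.5
```

Node 2 would contribute a very large loss, but it is excluded from training because its mask entry is `False`.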

In [None]:
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item()+1)

In [None]:
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        # torch.max returns (values, indices); the indices are the predicted classes
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

In [None]:
# We then write the training loop as:

model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # forward propagation by using all nodes
    logits = model(graph, node_features)
    # Compute loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # Compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # Back propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item(), acc)

    # Optionally, we could save the model here


DGL provides an end-to-end node classification example for GraphSAGE on homogeneous graphs. The corresponding model implementation is the `GraphSAGE` class in that example, which supports an adjustable number of layers, dropout probabilities, and customizable aggregation functions and nonlinearities.

### Heterogeneous Graph

If our graph is heterogeneous, we may want to gather messages from neighbours along all edge types. The module `dgl.nn.pytorch.HeteroGraphConv` performs message passing on all edge types by applying a separate graph convolution module to each edge type and then combining the results.

The following code defines a heterogeneous graph convolution module that first performs a separate graph convolution on each edge type, then sums the results over all edge types that point into the same node type to produce the final representation of each node type.

In [None]:
# define a heterograph conv model

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        
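        # In the random graph, some nodes receive no edges of a given type;
        # GraphConv raises an error on such nodes unless this flag is set.
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats, allow_zero_in_degree=True)
            for rel in rel_names
        }, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats, allow_zero_in_degree=True)
            for rel in rel_names
        }, aggregate='sum')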
        
    def forward(self, graph, inputs): 
        # Inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

`dgl.nn.HeteroGraphConv` takes in a dictionary of node types and node feature tensors as input, and returns another dictionary of node types and node features. 
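
To make the dictionary-in, dictionary-out behavior concrete, here is a NumPy sketch of one heterogeneous layer: each relation aggregates messages from its source node type into its destination node type, and results for the same destination type are summed. As a simplifying assumption, each relation uses a plain mean plus a linear map instead of `GraphConv`'s degree normalization, and the helper names `mean_aggregate` and `hetero_layer` are hypothetical.

```python
import numpy as np

def mean_aggregate(src, dst, h_src, n_dst):
    """Mean of source representations over each destination's incoming edges."""
    out = np.zeros((n_dst, h_src.shape[1]))
    np.add.at(out, dst, h_src[src])
    deg = np.bincount(dst, minlength=n_dst)
    return out / np.maximum(deg, 1)[:, None]  # destinations without edges get zeros

def hetero_layer(edges, feats, weights):
    """edges:   {(src_type, rel, dst_type): (src_ids, dst_ids)}
    feats:   {node_type: array of shape (num_nodes, in_dim)}
    weights: {rel: array of shape (in_dim, out_dim)}
    Returns {node_type: array of shape (num_nodes, out_dim)}."""
    out = {}
    for (src_type, rel, dst_type), (src, dst) in edges.items():
        n_dst = feats[dst_type].shape[0]
        msg = mean_aggregate(src, dst, feats[src_type], n_dst) @ weights[rel]
        out[dst_type] = out.get(dst_type, 0) + msg  # sum over relations, like aggregate='sum'
    return out

edges = {
    ('user', 'click', 'item'): (np.array([0, 1]), np.array([0, 0])),
    ('user', 'dislike', 'item'): (np.array([1]), np.array([1])),
    ('item', 'clicked-by', 'user'): (np.array([0, 0]), np.array([0, 1])),
}
feats = {'user': np.array([[1.0], [3.0]]), 'item': np.array([[10.0], [20.0]])}
weights = {rel: np.eye(1) for (_, rel, _) in edges}  # identity maps keep the arithmetic visible
out = hetero_layer(edges, feats, weights)
print(out['item'].ravel(), out['user'].ravel())  # [2. 3.] [10. 10.]
```

Item 0 receives the mean of users 0 and 1 via `click` and nothing via `dislike`, while item 1 receives user 1 only via `dislike`; the two relations are summed per item.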

In [None]:
# Using the user-item heterograph example
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

In [None]:
# Forward propagation
node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, node_features)
h_user = h_dict['user']
h_item = h_dict['item']

The training loop is the same as for the homogeneous graph, except that the model now returns a dictionary of node representations from which we compute the predictions. For instance, if we are only predicting the `user` nodes, we can just extract the `user` node embeddings from the returned dictionary:

In [None]:
opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # forward propagation using all nodes and extracting the user embs
    logits = model(hetero_graph, node_features)['user']
    # compute loss
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Compute the validation acc (omitted here)
    # back propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

## 5.2 Edge Classification

In [None]:
# Create a random graph for edge prediction
import numpy as np
import dgl
import torch
import torch.nn as nn
np.random.seed(1775)
torch.manual_seed(1775)

src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# Make the graph undirected
edge_pred_graph = dgl.graph((np.concatenate([src,dst]), np.concatenate([dst,src])))
# synthetic node and edge features, as well as edge labels
edge_pred_graph.ndata['feature'] = torch.randn(100,10)
edge_pred_graph.edata['feature'] = torch.randn(1000,10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# synthetic train-validation-test splits
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

As in node classification, we can use a multi-layer GNN to compute a hidden representation for every node. The prediction on an edge can then be derived from the representations of its incident nodes.

The most common way to compute a prediction on an edge is to express it as a parameterized function of the representations of its incident nodes and, optionally, of the features on the edge itself.

### Model Implementation Difference from Node Classification

Here we compute the node representations with the model from the previous section on node classification (Section 5.1), so we only need to write another component that computes the edge prediction with the `apply_edges()` method.

For instance, if we want to compute a score for each edge for edge regression, the following code computes the dot product of the incident node representations on each edge.

In [None]:
import dgl.function as fn
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1)
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']
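
For intuition, the built-in `fn.u_dot_v('h', 'h', 'score')` computes, for every edge u -> v, the dot product of `h[u]` and `h[v]`, stored with one column per edge. A NumPy equivalent (the `u_dot_v` helper here is hypothetical, not DGL code):

```python
import numpy as np

def u_dot_v(src, dst, h):
    """Dot product of source and destination representations on every edge, shape (E, 1)."""
    return (h[src] * h[dst]).sum(axis=1, keepdims=True)

h = np.array([[1.0, 0.0], [0.5, 2.0], [3.0, 1.0]])
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])
scores = u_dot_v(src, dst, h)
print(scores)  # scores 0.5, 3.5, 3.0 with shape (3, 1)
```

The trailing dimension of size 1 matters later, when these scores are compared against a label vector of shape (E,).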

We may also write a prediction function that predicts a vector for each edge with an MLP. Such a vector can be used in further downstream tasks, for example as logits of a categorical distribution.

In [None]:
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features *2, out_classes)
        
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        # The returned dict becomes the edge feature 'score'
        return {'score': score}

    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1)
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']
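
What `MLPPredictor` computes can be sketched in NumPy: concatenate the source and destination representations of each edge and apply one linear layer, giving one row of C scores per edge. In this sketch `x @ W` uses `W` of shape (2d, C), whereas `nn.Linear` stores its weight transposed as (C, 2d); the helper name `mlp_edge_scores` is hypothetical.

```python
import numpy as np

def mlp_edge_scores(src, dst, h, W, b):
    """Score each edge from [h_u, h_v]; returns logits of shape (E, C)."""
    x = np.concatenate([h[src], h[dst]], axis=1)  # (E, 2 * d)
    return x @ W + b

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))                  # 4 nodes, d = 3
src, dst = np.array([0, 1, 3]), np.array([2, 2, 0])
W, b = rng.normal(size=(6, 5)), np.zeros(5)  # C = 5 output classes
scores = mlp_edge_scores(src, dst, h, W, b)
print(scores.shape)  # (3, 5)
```

Unlike the dot product, this score depends on the edge direction, because swapping `src` and `dst` swaps the two halves of the concatenated input.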
        

### Training Loop

Given the node representation computation model and an edge predictor model, we can easily write a full-graph training loop where we compute the prediction on all edges.

The following example takes `SAGE` from the previous section as the node representation model and `DotProductPredictor` as the edge predictor.

In [None]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g,h)

In this example, we also assume that the training/validation/test edge sets are identified by boolean masks on edges. This example does not include early stopping or model saving. 

In [None]:
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    # The dot-product predictor returns scores of shape (E, 1); squeeze to (E,)
    # so that they align element-wise with the (E,) labels
    pred = model(edge_pred_graph, node_features).squeeze(-1)
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
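
The squeeze matters because the dot-product score has shape (E, 1) while the labels have shape (E,). Subtracting them broadcasts to a matrix and silently yields a wrong loss. A small NumPy illustration with made-up numbers:

```python
import numpy as np

toy_pred = np.array([[1.0], [2.0], [3.0]])  # (E, 1), like the dot-product predictor's output
toy_label = np.array([1.0, 0.0, 3.0])       # (E,)
toy_mask = np.array([True, True, False])

# (2, 1) - (2,) broadcasts to (2, 2): every prediction is compared with every label
wrong = ((toy_pred[toy_mask] - toy_label[toy_mask]) ** 2).mean()
# After squeezing, the subtraction is element-wise over the 2 masked edges
right = ((toy_pred[toy_mask].squeeze(-1) - toy_label[toy_mask]) ** 2).mean()
print(wrong, right)  # 1.5 2.0
```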


### Heterogeneous Graph

Edge classification on heterogeneous graphs is not very different from that on homogeneous graphs. If you wish to perform edge classification on one edge type, you only need to compute the node representation for all node types, and predict on that edge type with the `apply_edges()` method.

For example to make `DotProductPredictor` work on one edge type of a heterogeneous graph, you only need to specify the edge type in the `apply_edges()` method. 

In [None]:
class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN for heterogeneous graphs defined in the node classification section (Section 5.1)
        with graph.local_scope():
            graph.ndata['h'] = h  # assigns 'h' of all node types in one shot
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']
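
On a heterogeneous graph the only new ingredient is picking the right node types: for a canonical edge type (src_type, etype, dst_type), source representations come from `h[src_type]` and destination representations from `h[dst_type]`. A NumPy sketch with a hypothetical helper `hetero_u_dot_v`:

```python
import numpy as np

def hetero_u_dot_v(edges, h, canonical_etype):
    """Dot-product scores for the edges of one relation, shape (E, 1)."""
    src_type, _, dst_type = canonical_etype
    src, dst = edges[canonical_etype]
    return (h[src_type][src] * h[dst_type][dst]).sum(axis=1, keepdims=True)

h = {'user': np.array([[1.0, 1.0], [0.0, 2.0]]),
     'item': np.array([[3.0, 0.0], [1.0, 1.0], [0.0, 5.0]])}
edges = {('user', 'click', 'item'): (np.array([0, 1]), np.array([2, 0])),
         ('user', 'dislike', 'item'): (np.array([1]), np.array([1]))}
click_scores = hetero_u_dot_v(edges, h, ('user', 'click', 'item'))
print(click_scores)  # scores 5.0 and 0.0 with shape (2, 1)
```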

In [None]:
# In a similar fashion we write HeteroMLPPredictor
class HeteroMLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
        
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}
    
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN for heterogeneous graphs in Section 5.1
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges, etype=etype)
            return graph.edges[etype].data['score']

The end-to-end model that predicts a score for each edge on a single edge type will look like this:

In [None]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype)

Using the model simply involves feeding the model a dictionary of node types and features.

In [None]:
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
label = hetero_graph.edges['click'].data['label']
train_mask = hetero_graph.edges['click'].data['train_mask']
node_features = {'user': user_feats, 'item': item_feats}

The training loop looks almost the same as that in the homogeneous graph. For instance, if you wish to predict the edge labels on edge type `click`, then you can simply do:

In [None]:
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(hetero_graph, node_features, 'click').squeeze(-1)  # (E, 1) -> (E,)
    loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

### Predicting Edge Type of an Existing Edge on a Heterogeneous Graph

This problem arises when we want to predict which type an existing edge belongs to.

This is a simplified version of rating prediction, which is common in recommender tasks on graphs.

We can use a heterogeneous graph convolution network to obtain the node representations. For instance, we can still use the `RGCN` defined earlier in this chapter.

To predict the type of an edge, we can repurpose the `HeteroMLPPredictor` above so that it takes another graph with only one edge type that 'merges' all the edge types to be predicted, and emits a score for each type for every edge.

In this example, we need a graph that has two node types, `user` and `item`, and a single edge type that 'merges' all the edge types from `user` to `item`, in this case `click` and `dislike`. This can be conveniently created using the following syntax:


In [None]:
dec_graph = hetero_graph['user', :, 'item']

This returns a heterogeneous graph with node types `user` and `item` and a single edge type that combines all edge types between them, in this case `click` and `dislike`.

Since the resulting graph also stores the original edge type IDs as the edge feature `dgl.ETYPE`, we can use them as labels.

In [None]:
edge_label = dec_graph.edata[dgl.ETYPE]
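
Conceptually, `hetero_graph['user', :, 'item']` concatenates the edge lists of all `user`-to-`item` relations and records which relation each edge came from. The NumPy sketch below mimics that bookkeeping; the helper `merge_edge_types` and the integer IDs are hypothetical stand-ins for the original edge type IDs that DGL stores in `dgl.ETYPE`.

```python
import numpy as np

def merge_edge_types(edges, etype_ids):
    """Concatenate several relations' edge lists; label each edge with its type ID."""
    src = np.concatenate([edges[et][0] for et in etype_ids])
    dst = np.concatenate([edges[et][1] for et in etype_ids])
    etype = np.concatenate([np.full(len(edges[et][0]), etype_ids[et]) for et in etype_ids])
    return src, dst, etype

toy_edges = {'click': (np.array([0, 1, 2]), np.array([5, 5, 6])),
             'dislike': (np.array([0]), np.array([7]))}
toy_ids = {'click': 0, 'dislike': 2}  # hypothetical IDs within the full graph
merged_src, merged_dst, merged_label = merge_edge_types(toy_edges, toy_ids)
print(merged_src, merged_dst, merged_label)  # [0 1 2 0] [5 5 6 7] [0 0 0 2]
```

Because such IDs index the edge types of the full graph rather than running from 0 to 1, the predictor below uses `len(rel_names)` output classes.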

Given the graph above as input to the edge type predictor module, you can write your predictor module as follows:

In [None]:
class HeteroMLPPredictor(nn.Module):
    def __init__(self, in_dims, n_classes):
        super().__init__()
        self.W = nn.Linear(in_dims * 2, n_classes)
    
    def apply_edges(self, edges):
        x = torch.cat([edges.src['h'], edges.dst['h']], 1)
        y = self.W(x)
        return {'score': y}
    
    def forward(self, graph, h):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification section (Section 5.1)
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

The model that combines the node representation module and the edge type predictor module is the following:

In [None]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroMLPPredictor(out_features, len(rel_names))
    def forward(self, g, x, dec_graph):
        h = self.sage(g, x)
        return self.pred(dec_graph, h)
    
# Associated training loop
model =  Model(10,20,5,hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}

opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    logits = model(hetero_graph, node_features, dec_graph)
    loss = F.cross_entropy(logits, edge_label)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
    
    
    

DGL provides Graph Convolutional Matrix Completion as an example of rating prediction, which is formulated by predicting the type of an existing edge on a heterogeneous graph. The node representation module in the model implementation file is called `GCMCLayer`. The edge predictor module is called `BiDecoder`. Both of them are more complicated than the setting described here. 

## 5.3 Link Prediction

https://docs.dgl.ai/en/1.0.x/guide/training-link.html