## How to train your GraphSAGE

So far we have learned how to build your own graph, how to build a graph neural network, and how to perform forward propagation with a GNN.
In this tutorial, we will learn how to load and use a well-known graph benchmark dataset and how to train a graph neural network model.

The content of this tutorial was originally written by the DGL team. The original documentation can be found [here](https://docs.dgl.ai/guide/training-node.html#guide-training-node-classification).

In [17]:
import torch
import dgl
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F

## MNIST of GNNs: `CiteseerGraphDataset`

`CiteseerGraphDataset` is a well-known graph benchmark dataset. It contains a citation graph of scientific publications.
The node features are predefined vectors of dimension 3703. The target task is to predict the integer-valued labels of the nodes. `CiteseerGraphDataset` contains
six different node labels.

Out of the 3327 nodes in the original graph, a total of 1620 nodes are selected for training, validating, and testing your model:
1. 120 nodes are used as training nodes, on which you can directly compute the cross-entropy loss.
2. 500 nodes are reserved for validation.
3. 1000 nodes are reserved for testing.
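
The split sizes listed above can be checked with a little arithmetic. The snippet below is only a sketch using the numbers quoted in this section (the variable names are made up for illustration); it shows how small the labeled training set is relative to the whole graph.

```python
# Split sizes quoted in this section for CiteseerGraphDataset.
num_nodes = 3327
num_train, num_val, num_test = 120, 500, 1000

# Nodes that belong to one of the three splits.
num_selected = num_train + num_val + num_test

# Nodes that are in the graph but in none of the splits.
num_unused = num_nodes - num_selected

# Fraction of all nodes whose labels contribute to the training loss.
train_fraction = num_train / num_nodes

print(num_selected, num_unused, round(train_fraction * 100, 1))
```

Only about 3.6% of the nodes contribute to the training loss, which is why the model has to rely on the graph structure and the features of the remaining nodes.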

`dgl` provides `CiteseerGraphDataset` with everything preprocessed as described above.

In [22]:
dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0] # since it only has one graph :)

Loading from cache failed, re-processing.


Finished data loading and preprocessing.
  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.


## Simple GraphSAGE model

Let's build a simple GraphSAGE model with `dgl` and `pytorch`. The model is defined as follows:
$$h=\text{GraphSAGE}^{(2)}(\mathcal{G}, \text{ReLU}(\text{GraphSAGE}^{(1)}(\mathcal{G}, X)))$$

With `dgl.nn`, you can code up the above model in a few lines.

In [12]:
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, 
                                    out_feats=hid_feats, 
                                    aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, 
                                    out_feats=out_feats, 
                                    aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h
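
To get a feel for what a single layer with `aggregator_type='mean'` does, here is a minimal sketch in plain `torch`. It assumes the layer adds a linear transform of each node's own feature to a linear transform of the mean of its in-neighbors' features, and it leaves out the bias, dropout, and normalization options that `dglnn.SAGEConv` offers. The function `mean_sage_layer`, the dense adjacency matrix, and the identity weights are illustrative assumptions, not part of `dgl`.

```python
import torch

def mean_sage_layer(adj, h, w_self, w_neigh):
    """Simplified mean-aggregator layer (illustrative, not the dgl implementation).

    adj[i, j] = 1 means node j is an in-neighbor of node i.
    """
    # In-degree of each node; clamp so isolated nodes do not divide by zero.
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    # Mean of the neighbors' features for every node.
    h_neigh = adj @ h / deg
    # Combine the node's own feature with the aggregated neighborhood.
    return h @ w_self + h_neigh @ w_neigh

# Tiny graph: node 0 is connected to nodes 1 and 2 (both directions).
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
h = torch.tensor([[1., 0.],
                  [0., 2.],
                  [4., 0.]])
# Identity weights make the result easy to verify by hand.
w_self = torch.eye(2)
w_neigh = torch.eye(2)

out = mean_sage_layer(adj, h, w_self, w_neigh)
print(out)
```

For node 0, the neighbor mean is the average of the features of nodes 1 and 2, which is then added to node 0's own feature. In the `SAGE` class above, the weights are learned instead of fixed.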

## Feature preparation

To train a neural network model in a supervised fashion, we need input features and the corresponding targets.
Don't forget that we must use only the subset of node features and the corresponding labels for training!
Similarly, we have to use only the corresponding subsets of features and labels for validation and testing. Here, we will see
how to get the right subsets of features and labels for each purpose.

In [14]:
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)
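
The masks loaded above are boolean tensors with one entry per node, and indexing with a mask keeps only the rows where it is `True`. The following sketch illustrates this on small made-up tensors (the names `toy_features`, `toy_labels`, and `toy_train_mask` are hypothetical stand-ins for the real `graph.ndata` entries).

```python
import torch

# Toy stand-ins: 6 nodes, 2 features per node.
toy_features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
toy_labels = torch.tensor([0, 1, 2, 0, 1, 2])

# Boolean mask marking nodes 0 and 3 as training nodes.
toy_train_mask = torch.tensor([True, False, False, True, False, False])

# Boolean indexing selects the rows where the mask is True.
train_x = toy_features[toy_train_mask]
train_y = toy_labels[toy_train_mask]

print(train_x.shape, train_y)
```

The same indexing pattern works with the validation and test masks.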

In [18]:
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
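
The accuracy computation inside `evaluate` can be traced by hand on a tiny example. The sketch below uses made-up logits, labels, and a mask (all hypothetical values) and follows the same steps: select the masked rows, take the arg-max class, and compare with the labels.

```python
import torch

# Made-up logits for 4 nodes and 3 classes.
toy_logits = torch.tensor([[2.0, 0.1, 0.3],
                           [0.2, 1.5, 0.1],
                           [0.3, 0.2, 0.9],
                           [1.2, 0.4, 0.1]])
toy_labels = torch.tensor([0, 1, 1, 0])
# Only the first three nodes are part of this evaluation split.
toy_mask = torch.tensor([True, True, True, False])

# Predicted class = index of the largest logit in each selected row.
pred = toy_logits[toy_mask].argmax(dim=1)
# Fraction of selected nodes whose prediction matches the label.
accuracy = (pred == toy_labels[toy_mask]).float().mean().item()

print(pred, accuracy)
```

Two of the three selected nodes are classified correctly here. The fourth node is ignored because its mask entry is `False`.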

In [19]:
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # forward pass over the full graph (all nodes)
    logits = model(graph, node_features)
    # compute the loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # backward pass and parameter update
    opt.zero_grad()
    loss.backward()
    opt.step()
    # validation accuracy after the update (computed but not printed here)
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    print(loss.item())

    # Save model if necessary.  Omitted in this example.

1.7948145866394043
1.7812025547027588
1.7674504518508911
1.7531793117523193
1.7383815050125122
1.7231699228286743
1.7076141834259033
1.6916249990463257
1.675097107887268
1.6580389738082886
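
The loop above leaves model saving out. One possible way to handle it is to keep a copy of the parameters from the epoch with the best validation accuracy. The sketch below shows that pattern under some loud assumptions: it uses a plain `nn.Linear` as a stand-in for `SAGE` and random synthetic data instead of Citeseer so that it runs without `dgl`, and the learning rate and epoch count are arbitrary.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Synthetic stand-in data (hypothetical): 60 nodes, 8 features, 3 classes.
x = torch.randn(60, 8)
y = torch.randint(0, 3, (60,))
train_mask = torch.zeros(60, dtype=torch.bool)
train_mask[:40] = True
val_mask = ~train_mask

# Stand-in for SAGE; unlike SAGE it takes no graph argument.
model = nn.Linear(8, 3)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

best_val_acc = -1.0
best_state = None
for epoch in range(20):
    # One training step on the training nodes.
    model.train()
    logits = model(x)
    loss = F.cross_entropy(logits[train_mask], y[train_mask])
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Validation accuracy with gradients disabled.
    model.eval()
    with torch.no_grad():
        pred = model(x)[val_mask].argmax(dim=1)
        val_acc = (pred == y[val_mask]).float().mean().item()

    # Keep a deep copy of the best parameters seen so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = copy.deepcopy(model.state_dict())

# Restore the best parameters before final testing.
model.load_state_dict(best_state)
```

With the real model, the forward calls would take the graph as well, e.g. `model(graph, node_features)`, and `torch.save(best_state, path)` could write the checkpoint to disk.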
