# Node classification with SAGE

Authored by:
* Emil Riis Hansen 
* Jonas Brusokas
* Kashif Rabbani

References:
* General experiment setup - StellarGraph SAGE demo: https://stellargraph.readthedocs.io/en/stable/demos/node-classification/graphsage-node-classification.html
* NN setup with DGL - https://docs.dgl.ai/guide/training-node.html

### Installing Python preliminaries
* StellarGraph (hosts the preprocessed graph representation of the original dataset)
* pandas, NumPy, etc. - standard PyData libraries
* scikit-learn - data preprocessing
* DGL - GNN model definition, training, etc.


In [1]:
# install StellarGraph if running on Google Colab
import sys
if 'google.colab' in sys.modules:
  %pip install -q stellargraph[demos]==1.2.1



In [2]:
# verify that we're using the correct version of StellarGraph for this notebook
import stellargraph as sg

try:
    sg.utils.validate_notebook_version("1.2.1")
except AttributeError:
    raise ValueError(
        f"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>."
    ) from None

In [3]:
import networkx as nx
import pandas as pd
import itertools
import json
import os

import numpy as np

from networkx.readwrite import json_graph

from sklearn.preprocessing import StandardScaler

import stellargraph as sg
from stellargraph.mapper import ClusterNodeGenerator
from stellargraph.layer import GCN
from stellargraph import globalvar

from tensorflow.keras import backend as K

from tensorflow.keras import layers, optimizers, losses, metrics, Model
from sklearn import preprocessing, feature_extraction, model_selection
from stellargraph import datasets
from IPython.display import display, HTML
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
!pip install dgl

# Specific dependencies for DGL with PyTorch acting as a back-end
# NOTE: PyTorch is the default back-end 
import dgl
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.utils.data
from torch.utils.data import Dataset, DataLoader

Collecting dgl
  Downloading https://files.pythonhosted.org/packages/4d/05/9627fd225854f9ab77984f79405e78def50eb673a962940ed30fc07e9ac6/dgl-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.5MB)
Installing collected packages: dgl
Successfully installed dgl-0.5.2
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


## Loading the dataset

This notebook demonstrates node classification with DGL's GraphSAGE implementation (`dgl.nn.SAGEConv`). It uses the Cora dataset of scientific publications, each labelled with one subject area.

In [5]:
display(HTML(datasets.Cora().description))

In [6]:
G, labels = datasets.Cora().load()

print(G.info())

StellarGraph: Undirected multigraph
 Nodes: 2708, Edges: 5429

 Node types:
  paper: [2708]
    Features: float32 vector, length 1433
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5429]
        Weights: all 1 (default)
        Features: none


As the `G.info()` output shows, the graph has 2708 nodes and 5429 edges. Each node has 1433 features, one per word in the dictionary: a binary bag-of-words vector indicating whether the paper contains that word.
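To make the bag-of-words representation concrete, the sketch below builds such a binary vector from a list of words. The 5-word `vocabulary` and the `bag_of_words` helper are hypothetical illustrations; Cora's real dictionary has 1433 words, and its features are already precomputed.

```python
import numpy as np

# Hypothetical toy dictionary (Cora's real one has 1433 words)
vocabulary = ["graph", "learning", "neural", "probability", "reinforcement"]
word_index = {word: i for i, word in enumerate(vocabulary)}

def bag_of_words(words):
    """Return a binary vector: 1.0 where a dictionary word occurs, else 0.0."""
    vec = np.zeros(len(vocabulary), dtype=np.float32)
    for word in words:
        if word in word_index:  # words outside the dictionary are ignored
            vec[word_index[word]] = 1.0
    return vec

print(bag_of_words(["neural", "graph", "learning", "graph"]))
# [1. 1. 1. 0. 0.]
```

Repeated words still produce a 1, so the vector records presence rather than frequency.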


In [7]:
set(labels)

{'Case_Based',
 'Genetic_Algorithms',
 'Neural_Networks',
 'Probabilistic_Methods',
 'Reinforcement_Learning',
 'Rule_Learning',
 'Theory'}

There are 7 distinct target classes, each corresponding to the subject area of a publication.

In [8]:
from collections import Counter
Counter(labels)

Counter({'Case_Based': 298,
         'Genetic_Algorithms': 418,
         'Neural_Networks': 818,
         'Probabilistic_Methods': 426,
         'Reinforcement_Learning': 217,
         'Rule_Learning': 180,
         'Theory': 351})

The classes are imbalanced (from 818 Neural_Networks papers down to 180 Rule_Learning papers). This might lead to sub-optimal predictive performance, but it is acceptable for demonstrating how GNNs work.
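One way to quantify the imbalance is to compute the accuracy of a trivial classifier that always predicts the most frequent class. This gives a useful reference point when reading accuracy figures. The counts below are copied from the `Counter(labels)` output above.

```python
from collections import Counter

# Class counts copied from the Counter(labels) output above
class_counts = Counter({'Case_Based': 298,
                        'Genetic_Algorithms': 418,
                        'Neural_Networks': 818,
                        'Probabilistic_Methods': 426,
                        'Reinforcement_Learning': 217,
                        'Rule_Learning': 180,
                        'Theory': 351})

total = sum(class_counts.values())
majority_class, majority_count = class_counts.most_common(1)[0]
# Accuracy obtained by always predicting the majority class
baseline_accuracy = majority_count / total
print(f"Majority-class baseline ({majority_class}): {baseline_accuracy:.3f}")
# Majority-class baseline (Neural_Networks): 0.302
```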

## Graph preprocessing for DGL | PyTorch

In [9]:
print(f"Number of nodes in graph: {len(G.nodes())}")

# Tensor of shape [nodes, features]: for each node (paper), a binary indicator of which dictionary words it contains
node_features = torch.from_numpy(G.node_features())
print(f"Extracted node features tensor shape: {node_features.shape}")

# Create one-hot encoded vectors for labels (targets | paper classes) 
target_encoding = preprocessing.LabelBinarizer()
one_hot_node_labels = target_encoding.fit_transform(labels)
print(f"Extracted node one-hot encoded labels tensor shape: {one_hot_node_labels.shape}")

# F.cross_entropy expects integer class indices, i.e. the position of the '1' within each one-hot encoded vector
node_labels = torch.from_numpy(np.argmax(one_hot_node_labels, axis=1))
print(f"Extracted node labels tensor shape: {node_labels.shape}")

Number of nodes in graph: 2708
Extracted node features tensor shape: torch.Size([2708, 1433])
Extracted node one-hot encoded labels tensor shape: (2708, 7)
Extracted node labels tensor shape: torch.Size([2708])
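The one-hot-then-`argmax` step yields the same integer labels as scikit-learn's `LabelEncoder`, because both assign indices to the classes in sorted order. The sketch below checks this on a small hypothetical label list. One caveat: with exactly two classes, `LabelBinarizer` returns a single column, so the `argmax` trick would break. `LabelEncoder` handles that case as well.

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer, LabelEncoder

# Hypothetical subset of Cora subjects (4 distinct classes)
toy_labels = ["Theory", "Neural_Networks", "Case_Based", "Theory", "Rule_Learning"]

# Approach used in the notebook: one-hot encode, then take the position of the 1
one_hot = LabelBinarizer().fit_transform(toy_labels)
via_argmax = np.argmax(one_hot, axis=1)

# Direct integer encoding (classes are also sorted alphabetically)
via_encoder = LabelEncoder().fit_transform(toy_labels)

print(via_encoder)  # [3 1 0 3 2]
```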


## Constructing a SAGE NN (2-layer) for node classification

In [10]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F

class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        print(f"SAGE Model constructed: in_feats={in_feats}, hid_feats={hid_feats}, out_feats={out_feats}")
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        # Returns raw (unnormalised) logits of shape [node_count, n_classes]; F.cross_entropy applies log-softmax internally
        return h
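Each `SAGEConv` layer with `aggregator_type='mean'` combines a node's own representation with the mean of its neighbours' representations. The NumPy sketch below reproduces that idea on a toy 4-node path graph. It is not DGL's code: bias, activation and normalisation are omitted. The weights `W_self` and `W_neigh` are hypothetical random matrices, and nodes without neighbours are assumed to receive a zero aggregate.

```python
import numpy as np

def sage_mean_layer(features, neighbours, W_self, W_neigh):
    """Sketch of a GraphSAGE mean-aggregator layer (no bias, no activation).

    features:   array of shape [n_nodes, in_feats]
    neighbours: neighbours[v] lists the nodes whose messages reach node v
    """
    h_neigh = np.zeros_like(features)
    for v in range(features.shape[0]):
        if neighbours[v]:  # nodes with no neighbours keep a zero aggregate
            h_neigh[v] = features[neighbours[v]].mean(axis=0)
    # Combine each node's own features with its aggregated neighbourhood
    return features @ W_self + h_neigh @ W_neigh

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))          # 4 toy nodes, 3 input features each
nbrs = [[1], [0, 2], [1, 3], [2]]    # undirected path 0-1-2-3
W_self = rng.normal(size=(3, 2))     # hypothetical weights, 3 -> 2 features
W_neigh = rng.normal(size=(3, 2))

out = sage_mean_layer(x, nbrs, W_self, W_neigh)
print(out.shape)  # (4, 2)
```

Stacking two such layers, as the `SAGE` class does, lets each node's output depend on its 2-hop neighbourhood.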

### Defining necessary preliminaries (node, feature and label counts; conversion to a DGL graph)

In [12]:
n_nodes = node_features.shape[0]
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)

print(f"Number of nodes: {n_nodes}")
print(f"Number of features: {n_features}")
print(f"Number of labels (classes): {n_labels}")

nx_graph = G.to_networkx()
graph = dgl.from_networkx(nx_graph)

Number of nodes: 2708
Number of features: 1433
Number of labels (classes): 7


### Defining which nodes will be used for training, validation and testing

In [13]:
n_train_nodes = 100
n_valid_nodes = 500
n_test_nodes = 1000

# We divide the dataset (by using bool masks as identifiers) as follows:
# [100 training nodes] [500 validation nodes] [1000 test nodes] [all other (1108) nodes]
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:n_train_nodes] = True

valid_mask = torch.zeros(n_nodes, dtype=torch.bool)
valid_mask[n_train_nodes:n_train_nodes + n_valid_nodes] = True

test_mask = torch.zeros(n_nodes, dtype=torch.bool)
test_mask[n_train_nodes + n_valid_nodes:n_train_nodes + n_valid_nodes + n_test_nodes] = True
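The masks above take nodes in storage order, so how representative each split is depends on how the dataset happens to be ordered. A possible alternative, sketched below, draws the same split sizes from a seeded random permutation. The `make_random_masks` helper is hypothetical, and the split is random rather than stratified by class.

```python
import numpy as np

def make_random_masks(n_nodes, n_train, n_valid, n_test, seed=0):
    """Hypothetical helper: disjoint boolean masks over a random permutation of node IDs."""
    if n_train + n_valid + n_test > n_nodes:
        raise ValueError("split sizes exceed the number of nodes")
    perm = np.random.default_rng(seed).permutation(n_nodes)
    masks = []
    start = 0
    for size in (n_train, n_valid, n_test):
        mask = np.zeros(n_nodes, dtype=bool)
        mask[perm[start:start + size]] = True  # consecutive, non-overlapping slices
        masks.append(mask)
        start += size
    return tuple(masks)

train_np, valid_np, test_np = make_random_masks(2708, 100, 500, 1000)
print(train_np.sum(), valid_np.sum(), test_np.sum())  # 100 500 1000
```

In the notebook, `torch.from_numpy(train_np)` would turn each array into a boolean tensor usable in place of `train_mask`.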

### Defining our evaluation metric (correct/all predictions)

In [14]:
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        
        # Our evaluation metric is the ratio of correct_predictions/all_predictions
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
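The metric takes the `argmax` of each row of logits and compares it with the true label. Softmax is monotonic, so taking the argmax of the raw logits gives the same prediction as taking it of the probabilities. A small NumPy version on hypothetical toy logits:

```python
import numpy as np

def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == labels))

# Hypothetical logits for 4 nodes and 3 classes
toy_logits = np.array([[2.0, 1.0, 0.0],
                       [0.0, 3.0, 1.0],
                       [1.0, 0.0, 4.0],
                       [5.0, 1.0, 1.0]])
toy_labels = np.array([0, 2, 2, 0])

# Predictions are [0, 1, 2, 0]; 3 of 4 match the labels
print(accuracy(toy_logits, toy_labels))  # 0.75
```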

### Training code (model, optimizer instantiation, training loop)

In [15]:
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)

opt = torch.optim.Adam(model.parameters())
epoch_count = 20

print("\n--- [Training started] --- ")
for epoch in range(epoch_count):
    model.train()
    # Forward propagation by using all nodes - generates prediction
    logits = model(graph, node_features)

    # Compute loss from the training data
    # NOTE: mask defines which subset of the data to use for loss calculation
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])

    # Compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)

    # Backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()

    print(f"Epoch: {epoch+1} | training loss: {loss.item()}, validation accuracy: {acc}")

SAGE Model constructed: in_feats=1433, hid_feats=100, out_feats=7

--- [Training started] --- 
Epoch: 1 | training loss: 2.013437271118164, validation accuracy: 0.116
Epoch: 2 | training loss: 1.8030956983566284, validation accuracy: 0.138
Epoch: 3 | training loss: 1.6141239404678345, validation accuracy: 0.168
Epoch: 4 | training loss: 1.443945288658142, validation accuracy: 0.21
Epoch: 5 | training loss: 1.290073037147522, validation accuracy: 0.244
Epoch: 6 | training loss: 1.1508654356002808, validation accuracy: 0.266
Epoch: 7 | training loss: 1.0245959758758545, validation accuracy: 0.292
Epoch: 8 | training loss: 0.9097762107849121, validation accuracy: 0.308
Epoch: 9 | training loss: 0.8051037788391113, validation accuracy: 0.326
Epoch: 10 | training loss: 0.7101536393165588, validation accuracy: 0.342
Epoch: 11 | training loss: 0.6242868900299072, validation accuracy: 0.356
Epoch: 12 | training loss: 0.54667067527771, validation accuracy: 0.36
Epoch: 13 | training loss: 0.4772