# Tutorial for creating and training simple Graph Convolutional Networks (GCNs) in PyTorch.

In [None]:
# NumPy for data processing and handling
import numpy as np

# PyTorch for neural networks
import torch

# MatPlotLib for plotting
import matplotlib.pyplot as plt
%matplotlib inline  

In [None]:
# Arbitrarily define the variables needed to create a Graph Convolutional Network.
# Every array carries two leading singleton axes, e.g. A has shape (1, 1, 4, 4).
# The first axis acts as the batch axis: torch.nn.Flatten keeps it and merges
# all remaining axes.

# Define an adjacency matrix for the graph
# (4 nodes, every node connected to every other node, no self-loops)
A_np = np.array([[[[0, 1, 1, 1],
                   [1, 0, 1, 1],
                   [1, 1, 0, 1],
                   [1, 1, 1, 0]]]])

# Define a feature matrix (normally this would come from some dataset and there would be many samples)
# Can have an arbitrary number of columns (features per node) but must have
# as many rows as A, i.e. one row per node
X_np = np.array([[[[0.5, 0.0],
                   [0.0, 1.0],
                   [0.3, 0.2],
                   [1.0, 0.7]]]])

# Define a label (normally this would come from the same dataset as before)
y_np = np.array([[0.5]])

# Convert the NumPy arrays to float32 PyTorch tensors.
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor; it copies the data and casts the integer adjacency matrix and
# the float64 feature/label arrays to float32.
A = torch.tensor(A_np, dtype=torch.float32)
X = torch.tensor(X_np, dtype=torch.float32)
y = torch.tensor(y_np, dtype=torch.float32)

In [None]:
# Define our network as a subclass of torch.nn.Module
class GCN(torch.nn.Module):
    """Two graph-convolution layers followed by a fully connected read-out.

    Each graph-convolution layer computes ``ReLU(Linear(A @ H))``: the node
    features ``H`` are first aggregated over neighbours by multiplying with the
    adjacency matrix ``A`` and then mixed by a bias-free linear layer. The
    result is flattened (keeping the first axis) and mapped to the output by a
    linear layer with a Tanh activation, so every output lies in (-1, 1).

    The default sizes reproduce the original fixed architecture, so ``GCN()``
    and previously saved state dicts keep working.

    Args:
        n_features: Width of the feature matrix (features per node).
        n_hidden: Number of hidden units per node in both convolution layers.
        n_nodes: Number of values per sample along the non-feature axes, i.e.
            the flattened size per sample must equal ``n_nodes * n_hidden``.
        n_outputs: Length of the label vector.
    """

    def __init__(self, n_features: int = 2, n_hidden: int = 128,
                 n_nodes: int = 4, n_outputs: int = 1) -> None:
        super().__init__()

        # Number of input nodes must be equal to the width of the feature matrix.
        self.l_1 = torch.nn.Linear(n_features, n_hidden, bias=False)
        self.l_2 = torch.nn.Linear(n_hidden, n_hidden, bias=False)
        # Flatten keeps the first (batch) axis and merges all the others.
        self.flat = torch.nn.Flatten()
        # Number of output nodes must be equal to the length of the label.
        self.l_3 = torch.nn.Linear(n_nodes * n_hidden, n_outputs)

        self.act_1 = torch.nn.ReLU()
        self.act_2 = torch.nn.ReLU()
        self.act_3 = torch.nn.Tanh()

        self.loss_fct = torch.nn.MSELoss()

    def forward(self, X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Compute the network output for features ``X`` on the graph ``A``.

        Args:
            X: Feature tensor whose last axis has size ``n_features`` and whose
                first axis is the batch axis.
            A: Adjacency tensor that can be matrix-multiplied with ``X``
                (``torch.matmul`` broadcasting rules apply).

        Returns:
            Tensor of shape ``(batch, n_outputs)`` with values in (-1, 1).
        """
        # Graph convolution layer: aggregate neighbours, then transform
        h = self.act_1(self.l_1(torch.matmul(A, X)))

        # Graph convolution layer
        h = self.act_2(self.l_2(torch.matmul(A, h)))

        # Fully connected layer on the flattened node representations
        h = self.flat(h)
        return self.act_3(self.l_3(h))

    def loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the mean squared error between ``y_pred`` and ``y_true``."""
        return self.loss_fct(y_pred, y_true)

In [None]:
# Per-epoch training losses, appended to by the training loop
loss_log = list()

# Epoch counter; starts at -1 so that the first increment in the training
# loop yields epoch 0, which is also the index of its entry in loss_log
epoch = -1

# Freshly initialised (untrained) network
net = GCN()

In [None]:
# Initialize the optimizer. It is re-created every time this cell runs, so
# the SGD momentum buffers start from zero on each run, while `epoch` and
# `loss_log` keep accumulating across runs.
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

n_epoch = 10

# Put network in training mode (nothing in the loop switches it back, so
# doing this once before the loop is enough)
net.train()

# Loop over the training algorithm for n_epoch epochs
for _ in range(n_epoch):
    # Increase epoch count
    epoch += 1

    # Accumulator for the epoch loss. There is only a single sample/batch
    # here, so this ends up being exactly that batch's loss.
    epoch_loss = 0.0

    # Set gradients to zero
    optimizer.zero_grad()

    # Propagate the training data through the network. Tensors are passed
    # directly: the torch.autograd.Variable wrapper is deprecated and no
    # longer needed for autograd to work.
    y_pred = net(X, A)

    # Calculate the loss from the network outputs
    loss = net.loss(y_pred, y)

    # Perform backpropagation
    loss.backward()
    # Update weights using the gradients from backpropagation
    optimizer.step()
    # Add batch loss to epoch loss
    epoch_loss += loss.item()

    # Record the loss for the epoch
    loss_log.append(epoch_loss)

    # Print the value just recorded instead of indexing loss_log[epoch], which
    # breaks if the counter and the log ever get out of step (e.g. after an
    # interrupted run of this cell).
    print(f'\nEpoch: {epoch}\nTraining Loss: {epoch_loss}')

print('\n' + 'Training finished.')

In [None]:
# Plot the training loss curve (no validation loss is computed in this tutorial).
plt.figure()
plt.plot(np.arange(len(loss_log)), loss_log, label='Training')
plt.xlabel('Epoch')
plt.ylabel('Mean Squared Error')
# Keep the x-range valid even when zero or one epochs have been logged
# (len(loss_log) - 1 would otherwise give an inverted or zero-width axis).
plt.xlim([0, max(len(loss_log) - 1, 1)])
# Anchor the y-axis at zero but let the upper limit autoscale; a fixed upper
# limit hides the curve entirely whenever the loss is larger than that limit.
plt.ylim(bottom=0)
plt.legend()
plt.show()

In [None]:
# Save the model: only the learned parameters (the state dict) are written,
# not the GCN class itself
torch.save(obj=net.state_dict(), f='./GCN_model.pkl')

In [None]:
# Load the model: build a fresh network with the same architecture, then
# copy the saved parameters into it.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
loaded_net = GCN()
loaded_net.load_state_dict(state_dict=torch.load(f='./GCN_model.pkl'))

# Put network in inference mode
loaded_net.eval()