# Graph Attention Networks

A Graph Attention Network, or GAT for short, is a type of Graph Neural Network (GNN). In this notebook we will walk through an implementation of the model, along with its training and its accuracy on the Cora dataset.

## 1 Importing the Libraries
We will import the libraries required to build the model:
* [pickle](https://docs.python.org/3/library/pickle.html): The pickle module implements binary protocols for serializing and de-serializing a Python object structure. It is used here to load the data
* [scipy](https://docs.scipy.org/doc/scipy/reference/sparse.html): It provides functions to deal with sparse data
* [numpy](https://numpy.org/): It is a library consisting of multidimensional array objects and a collection of routines for processing arrays
* [torch](https://pytorch.org/): PyTorch is a Python package that provides two high-level features: tensor computation (like NumPy) with strong GPU acceleration, and deep neural networks built on a tape-based autograd system
* [os](https://docs.python.org/3/library/os.html): It is used for operating-system-dependent functionality
* [enum](https://docs.python.org/3/library/enum.html): It is used to implement enumerations
* [git](https://gitpython.readthedocs.io/en/stable/intro.html#): GitPython is a Python library used to interact with Git repositories
* [re](https://docs.python.org/3/library/re.html): This module provides regular expression matching operations
* [argparse](https://docs.python.org/3/library/argparse.html): This module makes it easy to write user-friendly command-line interfaces

In [23]:
# For Loading data
import pickle

# Main computation libraries
import scipy.sparse as sp
import numpy as np

# Deep learning related imports
import torch

# To define constants and access directories
import os
import enum

# To implement GAT inner workings
import torch.nn as nn
from torch.optim import Adam

import git # To store the checkpoint
import re  # regex


import argparse
import time  # used to time the training loop


## 2 Constants for the GAT Model

Here we will define the constants that we will use in the model

In [24]:
class DatasetType(enum.Enum):
    CORA = 0

# We'll be dumping and reading the data from this directory
DATA_DIR_PATH = os.path.join(os.getcwd(), 'data')
CORA_PATH = os.path.join(DATA_DIR_PATH, 'cora')  # this is checked-in no need to make a directory
    
CORA_TRAIN_RANGE = [0, 140]  # we're using the first 140 nodes as the training nodes
CORA_VAL_RANGE = [140, 140+500]
CORA_TEST_RANGE = [1708, 1708+1000]
CORA_NUM_INPUT_FEATURES = 1433
CORA_NUM_CLASSES = 7

## 2 Helper Functions

This section will contain a bunch of helper functions that will be used to build the model

### 2.1 Loading and Saving Data
First, let's define two simple functions for loading and saving pickle files. We need them because all of the Cora data is stored in pickle format.

In [25]:
def pickle_read(path):
    with open(path, 'rb') as file:
        data = pickle.load(file)

    return data

def pickle_save(path, data):
    with open(path, 'wb') as file:
        pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

### 2.2 Normalizing the features

This function takes the node features as input and normalises them so that every non-empty node feature vector sums to 1. We do this by multiplying each row by the inverse of its feature sum.
* **Step 1** : We first calculate the sum of the features for every node feature vector. The shape goes from (N, FIN) -> (N, 1), where N is the number of nodes and FIN is the number of input features
* **Step 2** : We take the inverse of the sums, because multiplying by 1/x is faster than dividing by x. The shape goes from (N, 1) -> (N)
* **Step 3** : Some sums will be 0, so 1/0 gives us inf. We replace those values with 1, which is the neutral element for multiplication
* **Step 4** : We create a diagonal matrix whose diagonal values come from node_features_inv_sum
* **Step 5** : We multiply the diagonal matrix by the feature matrix and return the normalized features.
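
To make the steps above concrete, here is a minimal sketch of the same row normalisation on a tiny dense matrix, written in plain Python so the arithmetic is visible. The helper `normalize_rows` and the toy matrix are illustrative assumptions and are not part of the notebook; the real function below works on a scipy sparse matrix.

```python
def normalize_rows(features):
    """Row-normalize a dense list-of-lists matrix (toy stand-in for the sparse version)."""
    normalized = []
    for row in features:
        row_sum = sum(row)                                 # Step 1: sum of features per node
        inv_sum = 1.0 / row_sum if row_sum != 0 else 1.0   # Steps 2-3: invert, use 1 for all-zero rows
        normalized.append([value * inv_sum for value in row])  # Steps 4-5: scale every entry of the row
    return normalized


toy_features = [
    [1.0, 1.0, 0.0, 2.0],  # sums to 4
    [0.0, 0.0, 0.0, 0.0],  # all-zero row: the inverse sum is replaced by 1, so it stays zero
    [0.0, 4.0, 0.0, 0.0],  # sums to 4
]
toy_normalized = normalize_rows(toy_features)
print(toy_normalized)
```

Multiplying by a diagonal matrix from the left, as the notebook does, scales each row by its own factor, which is exactly what the per-row multiplication in this sketch does.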

In [26]:
def normalize_features_sparse(node_features_sparse):
    assert sp.issparse(node_features_sparse), f'Expected a sparse matrix, got {node_features_sparse}.'

    node_features_sum = np.array(node_features_sparse.sum(-1))  # Step 1

    node_features_inv_sum = np.power(node_features_sum, -1).squeeze() # Step 2

    node_features_inv_sum[np.isinf(node_features_inv_sum)] = 1. # Step 3

    diagonal_inv_features_sum_matrix = sp.diags(node_features_inv_sum) # Step 4

    return diagonal_inv_features_sum_matrix.dot(node_features_sparse) # Step 5

### 2.3 Building Edge Index
This function builds the edge index for the graph
* **Step 1** : It iterates through all the nodes and their lists of neighbours
* **Step 2** : For each neighbour we check whether we have already seen that edge
* **Step 3** : We record every unseen edge (its source and target node ids) and mark it as seen
* **Step 4** : After optionally appending a self-edge for every node, we stack the source and target node ids into the edge index. This is of shape = (2, E), where E is the number of edges in the graph
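
As a small illustration of the steps above, the following sketch builds an edge index for a three-node toy graph using plain Python lists instead of NumPy. The function `toy_edge_index` and the toy adjacency dictionary are assumptions made for this example only.

```python
def toy_edge_index(adjacency, num_nodes, add_self_edges=True):
    """Build a (2, E) edge index as two parallel lists: [sources, targets]."""
    sources, targets = [], []
    seen_edges = set()
    for src, neighbours in adjacency.items():      # Step 1: visit every node
        for trg in neighbours:                     # Step 2: skip edges we have already seen
            if (src, trg) not in seen_edges:
                sources.append(src)                # Step 3: record the new edge
                targets.append(trg)
                seen_edges.add((src, trg))
    if add_self_edges:                             # every node also attends to itself
        sources.extend(range(num_nodes))
        targets.extend(range(num_nodes))
    return [sources, targets]                      # Step 4: row 0 = sources, row 1 = targets


# Node 1 lists neighbour 0 twice on purpose; the duplicate is dropped.
toy_adjacency = {0: [1, 2], 1: [0, 0], 2: [0]}
toy_edges = toy_edge_index(toy_adjacency, num_nodes=3)
print(toy_edges)  # [[0, 0, 1, 2, 0, 1, 2], [1, 2, 0, 0, 0, 1, 2]]
```

Here E = 7: four graph edges plus three self-edges.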

In [27]:
def build_edge_index(adjacency_list_dict, num_of_nodes, add_self_edges=True):
    source_nodes_ids, target_nodes_ids = [], []
    seen_edges = set()

    for src_node, neighboring_nodes in adjacency_list_dict.items(): # Step 1
        for trg_node in neighboring_nodes: # Step 2
            if (src_node, trg_node) not in seen_edges:  
                source_nodes_ids.append(src_node)
                target_nodes_ids.append(trg_node)

                seen_edges.add((src_node, trg_node)) # Step 3

    if add_self_edges:
        source_nodes_ids.extend(np.arange(num_of_nodes))
        target_nodes_ids.extend(np.arange(num_of_nodes))

    edge_index = np.vstack((source_nodes_ids, target_nodes_ids)) # Step 4 (np.row_stack is a deprecated alias of np.vstack)

    return edge_index

### 2.4 Loading Graph Data
This function creates the graph data from the Cora dataset. The predefined data files include
1. node_features.csr - This contains the features of the nodes. Its shape is (N, FIN), where N is the number of nodes and FIN is the number of input features
2. node_labels.npy - This contains the labels of the respective nodes. Its shape is (N), where N is the number of nodes
3. adjacency_list.dict - This defines the nodes and their corresponding neighbours. It is a dictionary, not a matrix: each of the N nodes maps to a list of its neighbouring nodes.

* **Step 1** : We normalize the node features. This helps with training
* **Step 2** : We get the edge index based on the graph. This is of shape = (2, E), where E is the number of edges, and 2 stands for the source and target nodes. Basically the edge index contains tuples of the format S->T, e.g. 0->3 means that the node with id 0 points to the node with id 3.
* **Step 3** : Convert to dense PyTorch tensors. The indices and labels need to be of long int type because later functions like PyTorch's index_select expect it
* **Step 4** : Get the indices that help us extract the nodes that belong to the train, val and test splits

In [28]:
def load_graph_data(training_config, device):
    dataset_name = training_config['dataset_name'].lower()

    if dataset_name == DatasetType.CORA.name.lower():

        
        node_features_csr = pickle_read(os.path.join(CORA_PATH, 'node_features.csr'))

        node_labels_npy = pickle_read(os.path.join(CORA_PATH, 'node_labels.npy'))

        adjacency_list_dict = pickle_read(os.path.join(CORA_PATH, 'adjacency_list.dict'))

        node_features_csr = normalize_features_sparse(node_features_csr) # Step 1
        num_of_nodes = len(node_labels_npy)

        topology = build_edge_index(adjacency_list_dict, num_of_nodes, add_self_edges=True) # Step 2

        
        # Step 3
        topology = torch.tensor(topology, dtype=torch.long, device=device)
        node_labels = torch.tensor(node_labels_npy, dtype=torch.long, device=device)  
        node_features = torch.tensor(node_features_csr.todense(), device=device)

        # Step 4
        train_indices = torch.arange(CORA_TRAIN_RANGE[0], CORA_TRAIN_RANGE[1], dtype=torch.long, device=device)
        val_indices = torch.arange(CORA_VAL_RANGE[0], CORA_VAL_RANGE[1], dtype=torch.long, device=device)
        test_indices = torch.arange(CORA_TEST_RANGE[0], CORA_TEST_RANGE[1], dtype=torch.long, device=device)

        return node_features, node_labels, topology, train_indices, val_indices, test_indices
    else:
        raise Exception(f'{dataset_name} not yet supported.')

### 2.5 Final Data Preprocessing

We now finally create the data that will be used to train the model
* **Step 1** : Checking whether you have a GPU
* **Step 2** : Loading Graph Data

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

config = {
    'dataset_name': 'CORA',
}

node_features, node_labels, edge_index, train_indices, val_indices, test_indices = load_graph_data(config, device)

print(node_features.shape, node_features.dtype)
print(node_labels.shape, node_labels.dtype)
print(edge_index.shape, edge_index.dtype)
print(train_indices.shape, train_indices.dtype)
print(val_indices.shape, val_indices.dtype)
print(test_indices.shape, test_indices.dtype)

torch.Size([2708, 1433]) torch.float32
torch.Size([2708]) torch.int64
torch.Size([2, 13264]) torch.int64
torch.Size([140]) torch.int64
torch.Size([500]) torch.int64
torch.Size([1000]) torch.int64


## 3 GAT Inner Workings

This section explains the inner workings of the GAT Model

### 3.1 GAT Layer
In this section we will define the functions that are required by a single GAT layer

### 3.1.1 Initialisation

* **Step 1** : Initialise the constants that will be used in the model
* **Step 2** : Create the trainable weights: the linear projection matrix (denoted as "W" in the paper), the attention target/source vectors (denoted as "a" in the paper) and the bias (not mentioned in the paper but present in the official GAT repo).
* **Step 3** : After we concatenate the target node (node i) and the source node (node j), we apply the "additive" scoring function, which gives us the un-normalized score "e". Here we split the "a" vector, but the semantics remain the same. Instead of computing the dot product of [x, y] (the concatenation of the node feature vectors x and y) with "a", we compute the dot product of x with "a_left" and of y with "a_right" and sum the two results.
* **Step 4** : Initialise the bias if necessary
* **Step 5** : Initialise the skip connection if necessary
* **Step 6** : Initialise the LeakyReLU, using a negative slope of 0.2 as in the paper
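
The claim in Step 3, that splitting "a" leaves the score unchanged, can be checked with a few lines of plain Python. The vectors below are made-up numbers chosen only for illustration, and `dot` and `leaky_relu` are small helpers written for this sketch rather than PyTorch functions.

```python
def dot(u, v):
    """Plain dot product of two equally long lists."""
    return sum(p * q for p, q in zip(u, v))


def leaky_relu(value, negative_slope=0.2):
    """LeakyReLU with the slope used in the notebook."""
    return value if value > 0 else negative_slope * value


x = [0.5, -1.0, 2.0]                      # projected features of one node
y = [1.5, 0.25, -0.5]                     # projected features of the other node
a = [0.1, 0.2, 0.3, -0.4, 0.5, 0.6]       # full attention vector of length 2 * FOUT
a_left, a_right = a[:3], a[3:]            # the split used by the layer

score_concat = dot(x + y, a)                        # dot product with the concatenation [x, y]
score_split = dot(x, a_left) + dot(y, a_right)      # two smaller dot products, summed
e_ij = leaky_relu(score_split)                      # un-normalized attention score "e"

print(score_concat, score_split, e_ij)
```

Both scores come out as roughly -0.325, and the LeakyReLU turns that into roughly -0.065. The split form is convenient because the per-node dot products can be computed once for all nodes and then reused for every edge.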

### 3.1.2 Forward Implementation
* **Step 1** : Linear Projection + regularization
    * **Step 1.1** We apply the dropout to all of the input node features.
    * **Step 1.2** We project the input node features into NH independent output features (one for each attention head).
    
* **Step 2** : Edge attention calculation
    * **Step 2.1** Apply the scoring function (* represents element-wise (a.k.a. Hadamard) product)
    * **Step 2.2** We simply copy (lift) the scores for source/target nodes based on the edge index. Instead of preparing all the possible combinations of scores we just prepare those that will actually be used, and those are defined by the edge index. The lifted scores are then passed through the LeakyReLU and the neighborhood-aware softmax.
    * **Step 2.3** Apply dropout to the attention coefficients, which adds stochasticity to the neighborhood aggregation

* **Step 3** : Neighborhood aggregation
    * **Step 3.1** Element-wise (aka Hadamard) product
    * **Step 3.2** Sum up weighted and projected neighborhood feature vectors for every target node

* **Step 4** : Residual/skip connections, concat and bias
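
The lifting in Step 2.2 and the aggregation in Step 3 can be traced by hand on a tiny graph. The sketch below uses a single attention head, plain Python lists, and attention values that are simply assumed (in the layer they come from the softmax); the loop at the end plays the role that scatter_add_ plays in the real implementation.

```python
# Projected node features, shape (N=3, FOUT=2), single attention head for simplicity.
proj = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]

# Edge index: edges 0->2, 1->2, 2->2 (self-edge) and 2->0.
src_index = [0, 1, 2, 2]
trg_index = [2, 2, 2, 0]

# Assumed attention coefficients per edge; they sum to 1 for each target node.
attention = [0.5, 0.25, 0.25, 1.0]

# Step 2.2: "lift" the source node features from N rows to E rows.
lifted = [proj[src] for src in src_index]

# Step 3.1: weight each lifted feature vector by its attention coefficient.
weighted = [[att * value for value in feats] for att, feats in zip(attention, lifted)]

# Step 3.2: sum the weighted vectors into the row of their target node.
out = [[0.0, 0.0] for _ in proj]
for trg, feats in zip(trg_index, weighted):
    for k, value in enumerate(feats):
        out[trg][k] += value

print(out)  # [[4.0, 4.0], [0.0, 0.0], [1.5, 1.5]]
```

Node 1 has no incoming edges in this toy graph, so its output stays at zero; in the notebook every node at least receives its own self-edge.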

### 3.1.3 Neighbourhood Aware Softmax
As the function name suggests, it applies softmax over the neighborhoods. Example: say we have 5 nodes in a graph and two of them, 1 and 2, are connected to node 3. If we want to calculate the representation for node 3 we should take into account the feature vectors of 1, 2 and 3 itself. Since we have scores for edges 1-3, 2-3 and 3-3 in the scores_per_edge variable, this function will calculate attention scores like this: 1-3/(1-3+2-3+3-3), where 1-3 is overloaded notation that represents the edge 1-3 and its (exp) score, and similarly for 2-3 and 3-3. In other words, for this neighborhood we don't care about the other edge scores that involve nodes 4 and 5.

* **Step 1** : Calculate the numerator. We subtract the maximum score to make the logits <= 0 so that e^logit <= 1 (this improves the numerical stability)
* **Step 2** : Calculate the denominator. shape = (E, NH)
* **Step 3** : Divide the numerator by the denominator. The 1e-16 is theoretically not needed and is only there for numerical stability (to avoid division by 0), because the computer may round a very small number all the way to 0.
* **Step 4** : Reshape from (E, NH) -> (E, NH, 1) so that we can do element-wise multiplication with the projected node features
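
The example with nodes 1, 2 and 3 can be worked through numerically. The following is a minimal single-head sketch in plain Python; the function `toy_neighborhood_softmax` and the edge scores are assumptions made for this illustration, and the node-sum list has an unused slot 0 so that the node ids match the text.

```python
import math


def toy_neighborhood_softmax(scores_per_edge, trg_index, num_slots):
    """Softmax over the incoming edges of each target node (single head)."""
    max_score = max(scores_per_edge)
    exp_scores = [math.exp(s - max_score) for s in scores_per_edge]   # Step 1: numerator
    neighborhood_sums = [0.0] * num_slots
    for trg, exp_score in zip(trg_index, exp_scores):                 # Step 2: per-target sums
        neighborhood_sums[trg] += exp_score
    return [exp_score / (neighborhood_sums[trg] + 1e-16)              # Step 3: normalize
            for trg, exp_score in zip(trg_index, exp_scores)]


# Edges 1->3, 2->3, 3->3 form the neighborhood of node 3; edges 4->5 and 5->5 belong to node 5.
src_index = [1, 2, 3, 4, 5]
trg_index = [3, 3, 3, 5, 5]
scores = [1.0, 2.0, 3.0, 0.5, 0.5]       # made-up un-normalized edge scores

attentions = toy_neighborhood_softmax(scores, trg_index, num_slots=6)
print(attentions)
```

The first three coefficients sum to 1 and depend only on the scores of edges that point to node 3, while the last two are 0.5 each; the scores of one neighborhood never leak into another.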


In [30]:
class GATLayer(torch.nn.Module):
    
    # We'll use these constants in many functions 
    src_nodes_dim = 0  # position of source nodes in edge index
    trg_nodes_dim = 1  # position of target nodes in edge index

    nodes_dim = 0      # node dimension (axis is maybe a more familiar term nodes_dim is the position of "N" in tensor)
    head_dim = 1       # attention head dim

    def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(),
                 dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):

        super().__init__()

        # Step 1
        self.num_of_heads = num_of_heads
        self.num_out_features = num_out_features
        self.concat = concat  # whether we should concatenate or average the attention heads
        self.add_skip_connection = add_skip_connection

        
        # Step 2
        self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)

        # Step 3
        self.scoring_fn_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
        self.scoring_fn_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))

        # Step 4
        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(num_of_heads * num_out_features))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(num_out_features))
        else:
            self.register_parameter('bias', None)
            
        # Step 5
        if add_skip_connection:
            self.skip_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
        else:
            self.register_parameter('skip_proj', None)

        #
        # End of trainable weights
        #

        self.leakyReLU = nn.LeakyReLU(0.2)  # Step 6
        self.activation = activation
        
        self.dropout = nn.Dropout(p=dropout_prob)

       
        self.init_params()
        
    def forward(self, data):
        #
        # Step 1: Linear Projection + regularization
        #

        in_nodes_features, edge_index = data  # unpack data
        num_of_nodes = in_nodes_features.shape[self.nodes_dim]
        assert edge_index.shape[0] == 2, f'Expected edge index with shape=(2,E) got {edge_index.shape}'
        
        # Step 1.1
        in_nodes_features = self.dropout(in_nodes_features)

        # Step 1.2
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        nodes_features_proj = self.dropout(nodes_features_proj)  # in the official GAT imp they did dropout here as well

        #
        # Step 2: Edge attention calculation
        #
        
        # Step 2.1
        scores_source = (nodes_features_proj * self.scoring_fn_source).sum(dim=-1)
        scores_target = (nodes_features_proj * self.scoring_fn_target).sum(dim=-1)

        
        # Step 2.2
        scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(scores_source, scores_target, nodes_features_proj, edge_index)
        scores_per_edge = self.leakyReLU(scores_source_lifted + scores_target_lifted)

        attentions_per_edge = self.neighborhood_aware_softmax(scores_per_edge, edge_index[self.trg_nodes_dim], num_of_nodes)
        
        # Step 2.3
        attentions_per_edge = self.dropout(attentions_per_edge)

        #
        # Step 3: Neighborhood aggregation
        #

        # Step 3.1
        nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * attentions_per_edge

        # Step 3.2
        out_nodes_features = self.aggregate_neighbors(nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes)

        #
        # Step 4: Residual/skip connections, concat and bias
        #

        out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features)
        return (out_nodes_features, edge_index)

    #
    # Helper functions (without comments there is very little code so don't be scared!)
    #

    def neighborhood_aware_softmax(self, scores_per_edge, trg_index, num_of_nodes):
        
        # Step 1
        scores_per_edge = scores_per_edge - scores_per_edge.max()
        exp_scores_per_edge = scores_per_edge.exp()  # softmax

        # Step 2
        neigborhood_aware_denominator = self.sum_edge_scores_neighborhood_aware(exp_scores_per_edge, trg_index, num_of_nodes)

        # Step 3
        attentions_per_edge = exp_scores_per_edge / (neigborhood_aware_denominator + 1e-16)

        # Step 4
        return attentions_per_edge.unsqueeze(-1)

    def sum_edge_scores_neighborhood_aware(self, exp_scores_per_edge, trg_index, num_of_nodes):
        
        # The shape must be the same as in exp_scores_per_edge (required by scatter_add_) i.e. from E -> (E, NH)
        trg_index_broadcasted = self.explicit_broadcast(trg_index, exp_scores_per_edge)

        # shape = (N, NH), where N is the number of nodes and NH the number of attention heads
        size = list(exp_scores_per_edge.shape)  # convert to list otherwise assignment is not possible
        size[self.nodes_dim] = num_of_nodes
        neighborhood_sums = torch.zeros(size, dtype=exp_scores_per_edge.dtype, device=exp_scores_per_edge.device)

        # position i will contain a sum of exp scores of all the nodes that point to the node i (as dictated by the
        # target index)
        neighborhood_sums.scatter_add_(self.nodes_dim, trg_index_broadcasted, exp_scores_per_edge)

        # Expand again so that we can use it as a softmax denominator. e.g. node i's sum will be copied to
        # all the locations where the source nodes pointed to i (as dictated by the target index)
        # shape = (N, NH) -> (E, NH)
        return neighborhood_sums.index_select(self.nodes_dim, trg_index)

    def aggregate_neighbors(self, nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes):
        
        size = list(nodes_features_proj_lifted_weighted.shape)  # convert to list otherwise assignment is not possible
        size[self.nodes_dim] = num_of_nodes  # shape = (N, NH, FOUT)
        out_nodes_features = torch.zeros(size, dtype=in_nodes_features.dtype, device=in_nodes_features.device)

        # shape = (E) -> (E, NH, FOUT)
        trg_index_broadcasted = self.explicit_broadcast(edge_index[self.trg_nodes_dim], nodes_features_proj_lifted_weighted)
        # aggregation step - we accumulate projected, weighted node features for all the attention heads
        # shape = (E, NH, FOUT) -> (N, NH, FOUT)
        out_nodes_features.scatter_add_(self.nodes_dim, trg_index_broadcasted, nodes_features_proj_lifted_weighted)

        return out_nodes_features

    def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index):
        """
        Lifts i.e. duplicates certain vectors depending on the edge index.
        One of the tensor dims goes from N -> E (that's where the "lift" comes from).

        """
        src_nodes_index = edge_index[self.src_nodes_dim]
        trg_nodes_index = edge_index[self.trg_nodes_dim]

        # Using index_select is faster than "normal" indexing (scores_source[src_nodes_index]) in PyTorch!
        scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index)
        scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index)
        nodes_features_matrix_proj_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index)

        return scores_source, scores_target, nodes_features_matrix_proj_lifted

    def explicit_broadcast(self, this, other):
        # Append singleton dimensions until this.dim() == other.dim()
        for _ in range(this.dim(), other.dim()):
            this = this.unsqueeze(-1)

        # Explicitly expand so that shapes are the same
        return this.expand_as(other)

    def init_params(self):
        
        nn.init.xavier_uniform_(self.linear_proj.weight)
        nn.init.xavier_uniform_(self.scoring_fn_target)
        nn.init.xavier_uniform_(self.scoring_fn_source)

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def skip_concat_bias(self, attention_coefficients, in_nodes_features, out_nodes_features):

        if self.add_skip_connection:  # add skip or residual connection
            if out_nodes_features.shape[-1] == in_nodes_features.shape[-1]:  # if FIN == FOUT
                # unsqueeze does this: (N, FIN) -> (N, 1, FIN), out features are (N, NH, FOUT) so 1 gets broadcast to NH
                # thus we're basically copying input vectors NH times and adding to processed vectors
                out_nodes_features += in_nodes_features.unsqueeze(1)
            else:
                # FIN != FOUT so we need to project input feature vectors into dimension that can be added to output
                # feature vectors. skip_proj adds lots of additional capacity which may cause overfitting.
                out_nodes_features += self.skip_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        if self.concat:
            # shape = (N, NH, FOUT) -> (N, NH*FOUT)
            out_nodes_features = out_nodes_features.view(-1, self.num_of_heads * self.num_out_features)
        else:
            # shape = (N, NH, FOUT) -> (N, FOUT)
            out_nodes_features = out_nodes_features.mean(dim=self.head_dim)

        if self.bias is not None:
            out_nodes_features += self.bias

        return out_nodes_features if self.activation is None else self.activation(out_nodes_features)

### 3.2 GAT Model
This stacks the GAT layers defined above to form a multi-layer GAT
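
To see how the layers fit together, the sketch below computes the input and output width of every layer using the same rule as the constructor that follows: hidden layers concatenate their heads, the last layer averages them. The numbers are the Cora architecture values used later in this notebook; the variable names `heads` and `layer_shapes` exist only in this example.

```python
num_heads_per_layer = [8, 1]
num_features_per_layer = [1433, 8, 7]    # input features, hidden features per head, classes

heads = [1] + num_heads_per_layer        # the raw input behaves like a single "head"
num_of_layers = len(num_heads_per_layer)

layer_shapes = []
for i in range(num_of_layers):
    is_last = i == num_of_layers - 1
    in_dim = num_features_per_layer[i] * heads[i]      # previous layer's heads were concatenated
    out_per_head = num_features_per_layer[i + 1]
    # Hidden layers concatenate their heads, the last layer averages them.
    out_dim = out_per_head if is_last else out_per_head * heads[i + 1]
    layer_shapes.append((in_dim, out_dim))

print(layer_shapes)  # [(1433, 64), (64, 7)]
```

The first layer therefore maps 1433 input features to 8 heads of 8 features each (64 after concatenation), and the second layer maps those 64 features to the 7 class scores.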

In [31]:
class GAT(torch.nn.Module):

    def __init__(self, num_of_layers, num_heads_per_layer, num_features_per_layer, add_skip_connection=True, bias=True,
                 dropout=0.6):
        super().__init__()
        assert num_of_layers == len(num_heads_per_layer) == len(num_features_per_layer) - 1, f'Enter valid arch params.'

        num_heads_per_layer = [1] + num_heads_per_layer  # trick - so that I can nicely create GAT layers below

        gat_layers = []  # collect GAT layers
        for i in range(num_of_layers):
            layer = GATLayer(
                num_in_features=num_features_per_layer[i] * num_heads_per_layer[i],  # consequence of concatenation
                num_out_features=num_features_per_layer[i+1],
                num_of_heads=num_heads_per_layer[i+1],
                concat=True if i < num_of_layers - 1 else False,  # last GAT layer does mean avg, the others do concat
                activation=nn.ELU() if i < num_of_layers - 1 else None,  # last layer just outputs raw scores
                dropout_prob=dropout,
                add_skip_connection=add_skip_connection,
                bias=bias
            )
            gat_layers.append(layer)

        self.gat_net = nn.Sequential(
            *gat_layers,
        )

    # data is just a (in_nodes_features, edge_index) tuple, I had to do it like this because of the nn.Sequential:
    # https://discuss.pytorch.org/t/forward-takes-2-positional-arguments-but-3-were-given-for-nn-sqeuential-with-linear-layers/65698
    def forward(self, data):
        return self.gat_net(data)

## 4 Training GAT
We will finally start training the GAT

### 4.1 Initialisation
* **Step 1** : We will define three different model phases to depict training, validation and testing.
* **Step 2** : We will define the global variables used for early stopping. After some number of epochs (as defined by the patience_period variable) without any improvement on the validation set (measured via the accuracy metric), we'll break out of the training loop.
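
The patience logic from Step 2 can be sketched in isolation. The function `epochs_until_stop` below is a hypothetical helper that only looks at validation accuracy, which is a simplification of the notebook's loop, and the accuracy values are made up.

```python
def epochs_until_stop(val_accuracies, patience_period):
    """Return (epoch at which training stops, best accuracy), or (None, best) if it never stops."""
    best_val_acc = 0.0
    patience_cnt = 0
    for epoch, accuracy in enumerate(val_accuracies):
        if accuracy > best_val_acc:
            best_val_acc = accuracy      # new best: remember it and reset the counter
            patience_cnt = 0
        else:
            patience_cnt += 1            # no improvement: keep counting
        if patience_cnt >= patience_period:
            return epoch, best_val_acc   # patience has run out
    return None, best_val_acc


stop_epoch, best_acc = epochs_until_stop([0.30, 0.75, 0.78, 0.77, 0.78, 0.76], patience_period=3)
print(stop_epoch, best_acc)  # 5 0.78
```

Note that equalling the previous best (the second 0.78) does not count as an improvement, so the counter keeps running.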


In [32]:
# Step 1
class LoopPhase(enum.Enum):
    TRAIN = 0
    VAL = 1
    TEST = 2

# Step 2
BEST_VAL_ACC = 0
BEST_VAL_LOSS = float('inf')  # cross entropy is never negative, so a starting value of 0 could never be beaten
PATIENCE_CNT = 0

### 4.2 Get Training State

These functions collect the training configuration and the model parameters into a single state dictionary, and print the training metadata

In [33]:
def get_training_state(training_config, model):
    training_state = {
        # Training details
        "dataset_name": training_config['dataset_name'],
        "num_of_epochs": training_config['num_of_epochs'],
        "test_acc": training_config['test_acc'],

        # Model structure
        "num_of_layers": training_config['num_of_layers'],
        "num_heads_per_layer": training_config['num_heads_per_layer'],
        "num_features_per_layer": training_config['num_features_per_layer'],
        "add_skip_connection": training_config['add_skip_connection'],
        "bias": training_config['bias'],
        "dropout": training_config['dropout'],

        # Model state
        "state_dict": model.state_dict()
    }

    return training_state


def print_model_metadata(training_state):
    header = f'\n{"*"*5} Model training metadata: {"*"*5}'
    print(header)

    for key, value in training_state.items():
        if key != 'state_dict':  # don't print state_dict just a bunch of numbers...
            print(f'{key}: {value}')
    print(f'{"*" * len(header)}\n')



### 4.3 Get Training Args

This function defines the command-line arguments and the model architecture settings, and returns them merged into a single training configuration dictionary

In [34]:


def get_training_args():
    parser = argparse.ArgumentParser()

    # Training related
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=10000)
    parser.add_argument("--patience_period", type=int, help="number of epochs with no improvement on val before terminating", default=1000)
    parser.add_argument("--lr", type=float, help="model learning rate", default=5e-3)
    parser.add_argument("--weight_decay", type=float, help="L2 regularization on model weights", default=5e-4)
    parser.add_argument("--should_test", type=bool, help='should test the model on the test dataset?', default=True)

    # Dataset related
    parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='dataset to use for training', default=DatasetType.CORA.name)
    parser.add_argument("--should_visualize", type=bool, help='should visualize the dataset?', default=False)

    # Logging/debugging/checkpoint related (helps a lot with experimentation)
    parser.add_argument("--enable_tensorboard", type=bool, help="enable tensorboard logging", default=False)
    parser.add_argument("--console_log_freq", type=int, help="log to output console (epoch) freq (None for no logging)", default=100)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq (None for no logging)", default=1000)
    args = parser.parse_args("")

    # Model architecture related - this is the architecture as defined in the official paper (for Cora classification)
    gat_config = {
        "num_of_layers": 2,  # GNNs, contrary to CNNs, are often shallow (it ultimately depends on the graph properties)
        "num_heads_per_layer": [8, 1],
        "num_features_per_layer": [CORA_NUM_INPUT_FEATURES, 8, CORA_NUM_CLASSES],
        "add_skip_connection": False,  # hurts perf on Cora
        "bias": True,  # result is not so sensitive to bias
        "dropout": 0.6,  # result is sensitive to dropout
    }

    # Wrapping training configuration into a dictionary
    training_config = dict()
    for arg in vars(args):
        training_config[arg] = getattr(args, arg)

    # Add additional config information
    training_config.update(gat_config)

    return training_config

### 4.4 Main Loop

We will now define a simple decorator-style wrapper function that captures the arguments that don't change from epoch to epoch, so that we don't have to pass them on every call. The main_loop function it returns does the following:

* **Step 1** : Certain modules behave differently depending on whether we're training the model or not, e.g. nn.Dropout: we only want to randomly zero out activations during training.
* **Step 2** : Do a forward pass and extract only the relevant node scores (train/val or test ones). Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
* **Step 3** : Calculate the cross entropy loss. Example: let's take an output for a single node on Cora. It's a vector of size 7 that contains unnormalized scores like: V = [-1.393,  3.0765, -2.4445,  9.6219,  2.1658, -5.5243, -4.6247]. For every such vector, PyTorch's cross entropy loss first applies a softmax, so V is transformed into: [1.6421e-05, 1.4338e-03, 5.7378e-06, 0.99797, 5.7673e-04, 2.6376e-07, 6.4848e-07]. Secondly, whatever the correct class is (say it's 3), it takes the element at position 3, 0.99797 in this case, and the loss is -log(0.99797). It does this for every node and takes the mean. You can see that as the probability of the correct class approaches 1 for most nodes, the loss approaches 0!
* **Step 4** : If the model is training, backpropagate the loss and take an optimizer step (gradient descent).
* **Step 5** : Accuracy metric. Find the index of the maximum (unnormalized) score for every node; that's the class prediction for that node. Compare those to the true (ground truth) labels and find the fraction of correct predictions
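
The cross entropy example from Step 3 and the prediction rule from Step 5 can be reproduced for the single vector V with the standard library alone. This is only a sketch of the arithmetic for one node; in the notebook nn.CrossEntropyLoss and torch.argmax do the same thing for all nodes at once. The helper `softmax` is written for this example.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of unnormalized scores."""
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


V = [-1.393, 3.0765, -2.4445, 9.6219, 2.1658, -5.5243, -4.6247]
correct_class = 3                                   # assumed ground-truth label for this node

probs = softmax(V)                                  # Step 3: softmax over the 7 class scores
loss = -math.log(probs[correct_class])              # Step 3: negative log-probability of the true class
predicted_class = max(range(len(V)), key=lambda k: V[k])   # Step 5: index of the largest score

print(round(probs[correct_class], 5), round(loss, 5), predicted_class)
```

The probability of class 3 comes out at about 0.99797 and the loss at about 0.002, which matches the numbers quoted above, and the prediction for this node is counted as correct because the largest score sits at index 3.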

In [None]:
def get_main_loop(config, gat, cross_entropy_loss, optimizer, node_features, node_labels, edge_index, train_indices, val_indices, test_indices, patience_period, time_start):

    node_dim = 0  

    train_labels = node_labels.index_select(node_dim, train_indices)
    val_labels = node_labels.index_select(node_dim, val_indices)
    test_labels = node_labels.index_select(node_dim, test_indices)

    # node_features shape = (N, FIN), edge_index shape = (2, E)
    graph_data = (node_features, edge_index)  # I pack data into tuples because GAT uses nn.Sequential which requires it

    def get_node_indices(phase):
        if phase == LoopPhase.TRAIN:
            return train_indices
        elif phase == LoopPhase.VAL:
            return val_indices
        else:
            return test_indices

    def get_node_labels(phase):
        if phase == LoopPhase.TRAIN:
            return train_labels
        elif phase == LoopPhase.VAL:
            return val_labels
        else:
            return test_labels

    def main_loop(phase, epoch=0):
        global BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT

        # Step 1
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        node_indices = get_node_indices(phase)
        gt_node_labels = get_node_labels(phase)  # gt stands for ground truth

        # Step 2
        nodes_unnormalized_scores = gat(graph_data)[0].index_select(node_dim, node_indices)

        # Step 3
        loss = cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

        # Step 4
        if phase == LoopPhase.TRAIN:
            optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
            loss.backward()  # compute the gradients for every trainable weight in the computational graph
            optimizer.step()  # apply the gradients to weights

        # Step 5
        class_predictions = torch.argmax(nodes_unnormalized_scores, dim=-1)
        accuracy = torch.sum(torch.eq(class_predictions, gt_node_labels).long()).item() / len(gt_node_labels)

        #
        # Logging
        #

        

        if phase == LoopPhase.VAL:

            # Log to console
            if config['console_log_freq'] is not None and epoch % config['console_log_freq'] == 0:
                print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] | epoch={epoch + 1} | val acc={accuracy}')

            # The "patience" logic - should we break out from the training loop? If either validation acc keeps going up
            # or the val loss keeps going down we won't stop
            if accuracy > BEST_VAL_ACC or loss.item() < BEST_VAL_LOSS:
                BEST_VAL_ACC = max(accuracy, BEST_VAL_ACC)  # keep track of the best validation accuracy so far
                BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)
                PATIENCE_CNT = 0  # reset the counter every time we encounter new best accuracy
            else:
                PATIENCE_CNT += 1  # otherwise keep counting

            if PATIENCE_CNT >= patience_period:
                raise Exception('Stopping the training')

        else:
            return accuracy  # in the case of test phase we just report back the test accuracy

    return main_loop  # return the decorated function

### 4.5 Training the GAT
* **Step 1** : Load the graph data
* **Step 2** : Prepare the model
* **Step 3** : Prepare other training related utilities (loss & optimizer and decorator function)
* **Step 4** : Start the training procedure
* **Step 5** : Potentially test the model

In [36]:
import time


def train_gat(config):
    global BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT  # PATIENCE_CNT must be global too, otherwise the reset below only creates a local

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1
    node_features, node_labels, edge_index, train_indices, val_indices, test_indices = load_graph_data(config, device)

    # Step 2
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout']
    ).to(device)

    # Step 3
    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        node_features,
        node_labels,
        edge_index,
        train_indices,
        val_indices,
        test_indices,
        config['patience_period'],
        time.time())

    BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT = 0, float('inf'), 0  # reset vars used for early stopping (loss starts at inf so it can improve)

    # Step 4
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5
    if config['should_test']:
        test_acc = main_loop(phase=LoopPhase.TEST)
        config['test_acc'] = test_acc
        print(f'Test accuracy = {test_acc}')
    else:
        config['test_acc'] = -1


In [39]:
train_gat(get_training_args())

GAT training: time elapsed= 0.10 [s] | epoch=1 | val acc=0.302
GAT training: time elapsed= 8.21 [s] | epoch=101 | val acc=0.756
GAT training: time elapsed= 16.38 [s] | epoch=201 | val acc=0.784
GAT training: time elapsed= 24.46 [s] | epoch=301 | val acc=0.806
GAT training: time elapsed= 32.44 [s] | epoch=401 | val acc=0.812
GAT training: time elapsed= 40.49 [s] | epoch=501 | val acc=0.804
GAT training: time elapsed= 48.97 [s] | epoch=601 | val acc=0.79
GAT training: time elapsed= 57.36 [s] | epoch=701 | val acc=0.794
GAT training: time elapsed= 65.75 [s] | epoch=801 | val acc=0.806
GAT training: time elapsed= 74.09 [s] | epoch=901 | val acc=0.814
GAT training: time elapsed= 82.20 [s] | epoch=1001 | val acc=0.796
GAT training: time elapsed= 90.19 [s] | epoch=1101 | val acc=0.794
GAT training: time elapsed= 98.19 [s] | epoch=1201 | val acc=0.808
GAT training: time elapsed= 106.38 [s] | epoch=1301 | val acc=0.808
GAT training: time elapsed= 114.72 [s] | epoch=1401 | val acc=0.808
GAT trai