In [2]:
import json
import os
import enum

import matplotlib.pyplot as plt
import networkx as nx
from networkx.readwrite import json_graph

import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
import zipfile


In [3]:
# Where the PPI dataset lives on disk, and where it is fetched from the first time it is needed.
DATA_DIR_PATH = os.path.join(os.getcwd(), "data")
PPI_PATH = os.path.join(DATA_DIR_PATH, "ppi")
PPI_URL = "https://data.dgl.ai/dataset/ppi.zip"

# Dataset dimensions: 50 input features per node, 121 output classes (scored with independent sigmoids).
PPI_NUM_INPUT_FEATURES = 50
PPI_NUM_CLASSES = 121

PPI_PATH  # echo the resolved data path as this cell's output

'/kaggle/working/data/ppi'

In [4]:
def json_read(path):
    """Parse the JSON document stored at `path` and return the resulting Python object."""
    with open(path) as json_file:
        return json.load(json_file)

In [5]:
def _download_ppi_if_missing(splits):
    """Download and unzip the PPI archive into PPI_PATH unless every file needed for `splits` is present."""
    required_files = [
        os.path.join(PPI_PATH, f'{split}_{suffix}')
        for split in splits
        for suffix in ('feats.npy', 'labels.npy', 'graph.json', 'graph_id.npy')
    ]
    # Check the actual data files rather than just the directory: an interrupted earlier download leaves an
    # empty/partial PPI_PATH behind, which would otherwise never be retried and fail later inside np.load.
    if all(os.path.exists(path) for path in required_files):
        return

    os.makedirs(PPI_PATH, exist_ok=True)
    zip_tmp_path = os.path.join(PPI_PATH, 'ppi.zip')
    torch.hub.download_url_to_file(PPI_URL, zip_tmp_path)
    try:
        with zipfile.ZipFile(zip_tmp_path) as zf:
            zf.extractall(path=PPI_PATH)
        print(f'Unzipping to: {PPI_PATH} finished.')
    finally:
        # Remove the archive even if extraction failed, so no stale zip is left in the data directory.
        os.remove(zip_tmp_path)
        print(f'Removing tmp file {zip_tmp_path}.')


def _to_local_edge_index(graph, graph_node_ids):
    """Return `graph`'s edges as a (2, E) long tensor re-indexed to rows 0..len(graph_node_ids)-1.

    graph_node_ids must be the sorted global ids of the graph's nodes (the rows selected from the split's
    feature/label arrays). Each global id is mapped to its position in that array, so edge endpoints line up
    with the per-graph feature and label tensors. Unlike subtracting the smallest id that appears in an edge,
    this stays correct when the graph's first node is isolated, and it also handles graphs with no edges.
    """
    edges = np.asarray(list(graph.edges), dtype=np.int64).reshape(-1, 2)  # (E, 2); (0, 2) when edgeless
    local_edges = np.searchsorted(graph_node_ids, edges)
    return torch.from_numpy(local_edges).transpose(0, 1).contiguous()


def load_graph_data(training_config, device):
    """Build PPI graph data loaders, downloading the dataset into PPI_PATH on first use.

    Args:
        training_config: dict with at least 'dataset_name' (must be 'PPI', case-insensitive),
            'ppi_load_test_only' (bool) and 'batch_size' (int).
        device: currently unused; all tensors are created on the CPU (presumably the training loop moves
            each batch to the model's device).

    Returns:
        A single test GraphDataLoader when training_config['ppi_load_test_only'] is truthy, otherwise a
        (train, val, test) tuple of GraphDataLoaders. Only the train loader shuffles.

    Raises:
        ValueError: if dataset_name is not PPI, the only dataset this function can read.
    """
    dataset_name = training_config['dataset_name'].lower()
    if dataset_name != 'ppi':
        raise ValueError(f'Only the PPI dataset is supported, got {training_config["dataset_name"]!r}.')

    splits = ['test'] if training_config['ppi_load_test_only'] else ['train', 'valid', 'test']
    _download_ppi_if_missing(splits)

    edge_index_list = []
    node_features_list = []
    node_labels_list = []
    # Graphs of splits[k] occupy [cumulative[k], cumulative[k + 1]) in the three lists above.
    num_graphs_per_split_cumulative = [0]

    for split in splits:
        node_features = np.load(os.path.join(PPI_PATH, f'{split}_feats.npy'))
        node_labels = np.load(os.path.join(PPI_PATH, f'{split}_labels.npy'))
        nodes_links_dict = json_read(os.path.join(PPI_PATH, f'{split}_graph.json'))

        # Convert undirected graph into directed
        collection_of_graphs = nx.DiGraph(json_graph.node_link_graph(nodes_links_dict))

        # For each node in the above collection, ids specify to which graph the node belongs to
        graph_ids = np.load(os.path.join(PPI_PATH, f'{split}_graph_id.npy'))
        # Iterate over exactly the ids that occur, so the loop always agrees with the split boundary count
        unique_graph_ids = np.unique(graph_ids)
        num_graphs_per_split_cumulative.append(num_graphs_per_split_cumulative[-1] + len(unique_graph_ids))

        for graph_id in unique_graph_ids:
            mask = graph_ids == graph_id
            graph_node_ids = np.flatnonzero(mask)  # sorted global ids of this graph's nodes
            graph = collection_of_graphs.subgraph(graph_node_ids)
            print(f'Loading {split} graph {graph_id} to CPU. '
                  f'It has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.')

            edge_index_list.append(_to_local_edge_index(graph, graph_node_ids))
            node_features_list.append(torch.tensor(node_features[mask], dtype=torch.float))
            node_labels_list.append(torch.tensor(node_labels[mask], dtype=torch.float))

    def make_loader(split_idx, shuffle):
        # Slice out the graphs belonging to splits[split_idx] and wrap them in a loader
        start = num_graphs_per_split_cumulative[split_idx]
        end = num_graphs_per_split_cumulative[split_idx + 1]
        return GraphDataLoader(
            node_features_list[start:end],
            node_labels_list[start:end],
            edge_index_list[start:end],
            batch_size=training_config['batch_size'],
            shuffle=shuffle,
        )

    if training_config['ppi_load_test_only']:
        return make_loader(0, shuffle=False)

    return make_loader(0, shuffle=True), make_loader(1, shuffle=False), make_loader(2, shuffle=False)


In [6]:
class GraphDataLoader(DataLoader):
    """DataLoader over whole graphs; each batch of graphs is merged into one disjoint graph by graph_collate_fn."""

    def __init__(self, node_features_list, node_labels_list, edge_index_list, batch_size=1, shuffle=False):
        per_graph_dataset = GraphDataset(node_features_list, node_labels_list, edge_index_list)
        super().__init__(
            per_graph_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=graph_collate_fn,
        )


class GraphDataset(Dataset):
    """Map-style dataset whose item i is the (node_features, node_labels, edge_index) triple of graph i."""

    def __init__(self, node_features_list, node_labels_list, edge_index_list):
        self.node_features_list = node_features_list
        self.node_labels_list = node_labels_list
        self.edge_index_list = edge_index_list

    def __len__(self):
        # One entry per graph; the edge-index list defines the count
        return len(self.edge_index_list)

    def __getitem__(self, idx):
        # A single graph, never a batch — batching is done by the collate function
        per_graph_fields = (self.node_features_list, self.node_labels_list, self.edge_index_list)
        return tuple(field[idx] for field in per_graph_fields)


def graph_collate_fn(batch):
    """Merge a list of (node_features, node_labels, edge_index) graphs into one disconnected batch graph.

    Node features and labels are stacked along dim 0. Each graph's edge indices are offset by the number
    of nodes in the graphs before it. For example, after 10 nodes an edge (1, 3) becomes (11, 13), so edges
    keep pointing at the right rows of the stacked node tensors. Edge indices are concatenated along dim 1.
    """
    features_per_graph, labels_per_graph, edges_per_graph = [], [], []
    node_offset = 0

    for node_feats, node_labels, edge_index in batch:
        features_per_graph.append(node_feats)
        labels_per_graph.append(node_labels)
        edges_per_graph.append(edge_index + node_offset)
        node_offset += len(node_labels)  # the label tensor has one row per node

    return (
        torch.cat(features_per_graph, 0),
        torch.cat(labels_per_graph, 0),
        torch.cat(edges_per_graph, 1),
    )

In [10]:

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

config = dict(
    dataset_name="PPI",
    batch_size=1,
    ppi_load_test_only=False,  # small optimization for loading test graphs only, we won't use it here
)

In [11]:
data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

# Peek at the first training batch to check the shapes and dtypes the loaders produce.
node_features, node_labels, edge_index = next(iter(data_loader_train))

print('*' * 20)
for _tensor in (node_features, node_labels, edge_index):
    print(_tensor.shape, _tensor.dtype)

100%|██████████| 6.76M/6.76M [00:00<00:00, 78.4MB/s]


Unzipping to: /kaggle/working/data/ppi finished.
Removing tmp file /kaggle/working/data/ppi/ppi.zip.
Loading train graph 1 to CPU. It has 1767 nodes and 34085 edges.
Loading train graph 2 to CPU. It has 1377 nodes and 31081 edges.
Loading train graph 3 to CPU. It has 2263 nodes and 61907 edges.
Loading train graph 4 to CPU. It has 2339 nodes and 67769 edges.
Loading train graph 5 to CPU. It has 1578 nodes and 37740 edges.
Loading train graph 6 to CPU. It has 1021 nodes and 19237 edges.
Loading train graph 7 to CPU. It has 1823 nodes and 46153 edges.
Loading train graph 8 to CPU. It has 2488 nodes and 72878 edges.
Loading train graph 9 to CPU. It has 591 nodes and 8299 edges.
Loading train graph 10 to CPU. It has 3312 nodes and 109510 edges.
Loading train graph 11 to CPU. It has 2401 nodes and 66619 edges.
Loading train graph 12 to CPU. It has 1878 nodes and 48146 edges.
Loading train graph 13 to CPU. It has 1819 nodes and 47587 edges.
Loading train graph 14 to CPU. It has 3480 nodes an

In [12]:
import torch.nn as nn
from torch.optim import Adam


class GAT(torch.nn.Module):
    """Multi-layer Graph Attention Network: an nn.Sequential of GATLayer modules.

    Hidden layers concatenate their attention heads and apply ELU. The last layer averages its heads and
    returns raw (unnormalized) scores.

    Args:
        num_of_layers: number of GATLayer modules.
        num_heads_per_layer: sequence of length num_of_layers; attention heads of each layer.
        num_features_per_layer: sequence of length num_of_layers + 1. Entry 0 is the input feature size;
            entry i + 1 is the per-head output size of layer i.
        add_skip_connection, bias, dropout, log_attention_weights: forwarded to every GATLayer.

    Raises:
        AssertionError: if the three layer counts are inconsistent.
    """

    def __init__(self, num_of_layers, num_heads_per_layer, num_features_per_layer, add_skip_connection=True, bias=True,
                 dropout=0.6, log_attention_weights=False):
        super().__init__()
        assert num_of_layers == len(num_heads_per_layer) == len(num_features_per_layer) - 1, (
            f'Expected num_of_layers == len(num_heads_per_layer) == len(num_features_per_layer) - 1, got '
            f'num_of_layers={num_of_layers}, num_heads_per_layer={num_heads_per_layer}, '
            f'num_features_per_layer={num_features_per_layer}.'
        )

        # Layer i consumes the output of layer i - 1. That output is num_features_per_layer[i] * heads[i - 1]
        # wide because hidden layers concatenate their heads. The raw input is not multi-head, so a multiplier
        # of 1 is prepended for the first layer. This builds a new list; the caller's list is not mutated.
        num_heads_per_layer = [1] + list(num_heads_per_layer)

        gat_layers = []
        for i in range(num_of_layers):
            is_last_layer = i == num_of_layers - 1
            layer = GATLayer(
                num_in_features=num_features_per_layer[i] * num_heads_per_layer[i],
                num_out_features=num_features_per_layer[i + 1],
                num_of_heads=num_heads_per_layer[i + 1],
                concat=not is_last_layer,  # last GAT layer averages heads, the others concatenate
                activation=None if is_last_layer else nn.ELU(),  # last layer just outputs raw scores
                dropout_prob=dropout,
                add_skip_connection=add_skip_connection,
                bias=bias,
                log_attention_weights=log_attention_weights,
            )
            gat_layers.append(layer)

        self.gat_net = nn.Sequential(*gat_layers)

    def forward(self, data):
        """Run all layers on data = (in_nodes_features, edge_index); returns (out_nodes_features, edge_index).

        The tuple form is needed because nn.Sequential passes a single object from layer to layer:
        https://discuss.pytorch.org/t/forward-takes-2-positional-arguments-but-3-were-given-for-nn-sqeuential-with-linear-layers/65698
        """
        return self.gat_net(data)

In [13]:
class GATLayer(torch.nn.Module):
    """Single graph attention layer (GAT) operating on an edge-index graph representation.

    forward takes a (node_features, edge_index) tuple and returns (out_node_features, edge_index), so layers
    can be chained inside nn.Sequential. Shapes below use N = nodes, E = edges, NH = heads.

    Args:
        num_in_features: F_in, size of each input node feature vector.
        num_out_features: F_out, size of each head's output feature vector.
        num_of_heads: NH, number of independent attention heads.
        concat: True -> heads concatenated to (N, NH * F_out); False -> heads averaged to (N, F_out).
        activation: module applied to the final output, or None to return raw scores.
        dropout_prob: dropout on input features, projected features and attention coefficients.
        add_skip_connection: add a residual connection (identity if F_in == F_out, else a learned projection).
        bias: add a learnable bias after the heads are merged.
        log_attention_weights: when True, each forward stores a detached copy of the (E, NH, 1) attention
            coefficients in self.attention_weights; otherwise that attribute stays None.
    """

    src_nodes_dim = 0  # position of source nodes in edge index
    trg_nodes_dim = 1  # position of target nodes in edge index

    nodes_dim = 0  # leading axis of node-level (N, ...) and edge-level (E, ...) tensors
    head_dim = 1   # attention-head axis of (N, NH, F_out) tensors

    def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(),
                 dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):

        super().__init__()

        self.num_of_heads = num_of_heads
        self.num_out_features = num_out_features
        self.concat = concat  # whether we should concatenate or average the attention heads
        self.add_skip_connection = add_skip_connection

        # num_of_heads independent W matrices, stored as one F_in -> NH * F_out linear map
        self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)

        # Additive attention: for an edge j -> i and head k,
        #   score_ij = LeakyReLU(a_trg_k . (W_k h_i) + a_src_k . (W_k h_j))
        # The attention vector is split into a target half (scoring_fun_target) and a source half
        # (scoring_fun_source), each of shape (1, NH, F_out). Every node is therefore scored once, and
        # per-edge scores are gathered by index afterwards (see lift).
        self.scoring_fun_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
        self.scoring_fun_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))

        # The bias is added after the heads are merged, so its size depends on concat vs. mean
        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(num_of_heads * num_out_features))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(num_out_features))
        else:
            self.register_parameter('bias', None)

        # Skip (residual) connection: lets the layer pass its input straight through to the output. The
        # projection is only used when F_in != F_out (see skip_concat_bias).
        if add_skip_connection:
            self.skip_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
        else:
            self.register_parameter('skip_proj', None)

        self.leakyReLU = nn.LeakyReLU(0.2)  # negative slope 0.2
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout_prob)

        self.log_attention_weights = log_attention_weights
        self.attention_weights = None  # set by skip_concat_bias when log_attention_weights is True

        self.init_params()

    def forward(self, data):
        """Apply the layer to data = (in_nodes_features, edge_index).

        in_nodes_features: (N, F_in) tensor, one row per node.
        edge_index: (2, E) long tensor. Row 0 holds source node indices and row 1 target node indices, so
            column e is the directed edge edge_index[0, e] -> edge_index[1, e].

        Returns:
            (out_nodes_features, edge_index). out_nodes_features is (N, NH * F_out) when concat is True and
            (N, F_out) otherwise.
        """
        in_nodes_features, edge_index = data  # unpack data
        num_of_nodes = in_nodes_features.shape[self.nodes_dim]
        assert edge_index.shape[0] == 2, f'Expected edge index with shape=(2,E) got {edge_index.shape}'

        # Step 1: linear projection + regularization
        # shape = (N, F_in)
        in_nodes_features = self.dropout(in_nodes_features)

        # (N, F_in) @ (F_in, NH*F_out) -> (N, NH, F_out): one projection per attention head
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
        nodes_features_proj = self.dropout(nodes_features_proj)

        # Step 2: edge attention calculation
        # (N, NH, F_out) * (1, NH, F_out) -> summed over F_out -> (N, NH)
        scores_source = (nodes_features_proj * self.scoring_fun_source).sum(dim=-1)
        scores_target = (nodes_features_proj * self.scoring_fun_target).sum(dim=-1)

        # Only scores for node pairs that are actually connected are needed; gather them per edge.
        # scores shape = (E, NH), nodes_features_proj_lifted shape = (E, NH, F_out)
        scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(
            scores_source, scores_target, nodes_features_proj, edge_index)
        scores_per_edge = self.leakyReLU(scores_source_lifted + scores_target_lifted)

        # shape = (E, NH, 1)
        attentions_per_edge = self.neighborhood_aware_softmax(scores_per_edge, edge_index[self.trg_nodes_dim], num_of_nodes)
        attentions_per_edge = self.dropout(attentions_per_edge)

        # Step 3: neighborhood aggregation
        # (E, NH, F_out) * (E, NH, 1) -> (E, NH, F_out); the trailing 1 broadcasts over F_out
        nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * attentions_per_edge

        # shape = (N, NH, F_out)
        out_nodes_features = self.aggregate_neighbors(nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes)

        # Step 4: residual, head concat/mean, bias, activation
        out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features)
        return (out_nodes_features, edge_index)

    def neighborhood_aware_softmax(self, scores_per_edge, trg_index, num_of_nodes):
        """Softmax of per-edge scores, normalized over the incoming edges of each target node.

        scores_per_edge: (E, NH). trg_index: (E,) target node of every edge.
        Returns attention coefficients of shape (E, NH, 1).
        """
        # Subtracting the global max keeps every exponent <= 0, so exp() cannot overflow. The shift cancels
        # in the ratio below, apart from the tiny epsilon.
        scores_per_edge = scores_per_edge - scores_per_edge.max()
        exp_scores_per_edge = scores_per_edge.exp()

        # shape = (E, NH): for every edge, the sum of exp scores over all edges entering the same target
        neighborhood_aware_denominator = self.sum_edge_scores_neighborhood_aware(exp_scores_per_edge, trg_index, num_of_nodes)

        # 1e-16 guards against division by zero
        attentions_per_edge = exp_scores_per_edge / (neighborhood_aware_denominator + 1e-16)

        # shape = (E, NH) -> (E, NH, 1), so it broadcasts against (E, NH, F_out) features
        return attentions_per_edge.unsqueeze(-1)

    def sum_edge_scores_neighborhood_aware(self, exp_scores_per_edge, trg_index, num_of_nodes):
        """For every edge, return the sum of exp_scores_per_edge over all edges sharing its target node."""
        # The index must have the same shape as exp_scores_per_edge (required by scatter_add_): E -> (E, NH)
        trg_index_broadcasted = self.explicit_broadcast(trg_index, exp_scores_per_edge)

        # shape = (N, NH)
        size = list(exp_scores_per_edge.shape)  # convert to list otherwise assignment is not possible
        size[self.nodes_dim] = num_of_nodes
        neighborhood_sums = torch.zeros(size, dtype=exp_scores_per_edge.dtype, device=exp_scores_per_edge.device)

        # Row i accumulates the exp scores of all edges whose target is node i
        neighborhood_sums.scatter_add_(self.nodes_dim, trg_index_broadcasted, exp_scores_per_edge)

        # Expand back to edges: node i's sum is copied to every edge pointing at i. (N, NH) -> (E, NH)
        return neighborhood_sums.index_select(self.nodes_dim, trg_index)

    def aggregate_neighbors(self, nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes):
        """Sum the attention-weighted (E, NH, F_out) edge messages into their target nodes -> (N, NH, F_out)."""
        size = list(nodes_features_proj_lifted_weighted.shape)  # convert to list otherwise assignment is not possible
        size[self.nodes_dim] = num_of_nodes  # shape = (N, NH, F_out)
        out_nodes_features = torch.zeros(size, dtype=in_nodes_features.dtype, device=in_nodes_features.device)

        # shape = (E) -> (E, NH, F_out)
        trg_index_broadcasted = self.explicit_broadcast(edge_index[self.trg_nodes_dim], nodes_features_proj_lifted_weighted)
        # (E, NH, F_out) -> (N, NH, F_out): every edge's message is added to the row of its target node;
        # messages for the same target accumulate.
        out_nodes_features.scatter_add_(self.nodes_dim, trg_index_broadcasted, nodes_features_proj_lifted_weighted)

        return out_nodes_features

    def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index):
        """Gather node-level tensors to edge level.

        For edge e = (src, trg), returns scores_source[src], scores_target[trg] and
        nodes_features_matrix_proj[src]. For example, edge (0, 1) picks scores_source[0], scores_target[1]
        and nodes_features_matrix_proj[0].
        """
        src_nodes_index = edge_index[self.src_nodes_dim]
        trg_nodes_index = edge_index[self.trg_nodes_dim]

        scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index)
        scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index)
        nodes_features_matrix_projection_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index)

        return scores_source, scores_target, nodes_features_matrix_projection_lifted

    def explicit_broadcast(self, this, other):
        """Append trailing singleton dims to `this` until it has other's rank, then expand it to other's shape."""
        for _ in range(this.dim(), other.dim()):
            this = this.unsqueeze(-1)

        # Explicitly expand so that shapes are the same
        return this.expand_as(other)

    def init_params(self):
        """
        The reason we're using Glorot (aka Xavier uniform) initialization is because it's a default TF initialization:
            https://stackoverflow.com/questions/37350131/what-is-the-default-variable-initializer-in-tensorflow

        The original repo was developed in TensorFlow (TF) and they used the default initialization.
        Feel free to experiment - there may be better initializations depending on your problem.

        """
        nn.init.xavier_uniform_(self.linear_proj.weight)
        nn.init.xavier_uniform_(self.scoring_fun_target)
        nn.init.xavier_uniform_(self.scoring_fun_source)

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def skip_concat_bias(self, attention_coefficients, in_nodes_features, out_nodes_features):
        """Optionally log attention, then add the residual, merge heads, add the bias and apply the activation.

        attention_coefficients: (E, NH, 1) attention used in this forward pass. A detached copy is stored
            in self.attention_weights when log_attention_weights is enabled.
        in_nodes_features: (N, F_in) layer input (after input dropout).
        out_nodes_features: (N, NH, F_out) aggregated neighborhood features.
        """
        if self.log_attention_weights:
            # Keep the latest attention for inspection/visualization. Detach it so the stored tensor does not
            # keep this step's autograd graph alive.
            self.attention_weights = attention_coefficients.detach()

        if self.add_skip_connection:  # add skip or residual connection
            if out_nodes_features.shape[-1] == in_nodes_features.shape[-1]:  # if F_in == F_out
                # (N, F_in) -> (N, 1, F_in); the 1 broadcasts over NH, i.e. the input is added to every head
                out_nodes_features += in_nodes_features.unsqueeze(1)
            else:
                # F_in != F_out, so project the input to a shape that can be added to the output.
                # skip_proj adds lots of additional capacity which may cause overfitting.
                out_nodes_features += self.skip_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        if self.concat:
            # shape = (N, NH, F_out) -> (N, NH * F_out)
            out_nodes_features = out_nodes_features.view(-1, self.num_of_heads * self.num_out_features)
        else:
            # shape = (N, NH, F_out) -> (N, F_out)
            out_nodes_features = out_nodes_features.mean(dim=self.head_dim)

        if self.bias is not None:
            out_nodes_features += self.bias

        return out_nodes_features if self.activation is None else self.activation(out_nodes_features)

In [22]:
from torch.utils.tensorboard import SummaryWriter


class LoopPhase(enum.Enum):
    """Phase the main loop is running in; main_loop uses it to pick train vs. eval mode and what to log."""
    # No trailing commas: `TRAIN = 0,` would make the member's value the tuple (0,) instead of 0.
    TRAIN = 0
    VAL = 1
    TEST = 2


writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default

# Early-stopping state shared with the training loop, which reads and updates it through `global` statements.
BEST_VAL_MICRO_F1 = 0
BEST_VAL_LOSS = 0
PATIENCE_CNT = 0

# Trained models are saved to BINARIES_PATH (see train_gat / get_available_binary_name). CHECKPOINTS_PATH is
# presumably used for periodic checkpoints (checkpoint_freq) — TODO confirm in the main loop.
# BINARIES_PATH was previously never defined, so this cell raised NameError on a fresh kernel.
BINARIES_PATH = os.path.join(os.getcwd(), 'models', 'binaries')
CHECKPOINTS_PATH = os.path.join(os.getcwd(), 'models', 'checkpoints')

# Make sure these exist as the rest of the code assumes it
os.makedirs(BINARIES_PATH, exist_ok=True)
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)


In [15]:
# Testing  the GAT model
num_of_layers = 2  # Example: Number of GAT layers
num_heads_per_layer = [4, 2]  # Example: Number of attention heads per layer
num_features_per_layer = [node_features.shape[1], 32, 16]  # Example: Number of features per layer
add_skip_connection = True
bias = True
dropout = 0.6
log_attention_weights = False

gat_model = GAT(num_of_layers=num_of_layers,
                num_heads_per_layer=num_heads_per_layer,
                num_features_per_layer=num_features_per_layer,
                add_skip_connection=add_skip_connection,
                bias=bias,
                dropout=dropout,
                log_attention_weights=log_attention_weights)

# Pass a batch of graph data through the model
output = gat_model((node_features, edge_index))

# Analyze the output
out_nodes_features, edge_index = output
print("Output node features shape:", out_nodes_features.shape)
print("Edge index shape:", edge_index.shape)


Output node features shape: torch.Size([2794, 16])
Edge index shape: torch.Size([2, 88112])


In [16]:
import re  # regex


def get_training_state(training_config, model):
    """Bundle the run's key hyperparameters, its test score and the model weights into one saveable dict."""
    recorded_keys = (
        # training info
        'num_of_epochs', 'test_perf',
        # model architecture
        'num_of_layers', 'num_heads_per_layer', 'num_features_per_layer',
        'add_skip_connection', 'bias', 'dropout',
    )
    training_state = {key: training_config[key] for key in recorded_keys}
    training_state['state_dict'] = model.state_dict()
    return training_state


def print_model_metadata(training_state):
    """Print every training-state entry except the bulky 'state_dict', between a header and a star rule."""
    header = f'\n{"*"*5} Model training metadata: {"*"*5}'
    report_lines = [header]
    report_lines += [f'{key}: {value}' for key, value in training_state.items() if key != 'state_dict']
    report_lines.append(f'{"*" * len(header)}\n')
    print('\n'.join(report_lines))


def get_available_binary_name(dataset_name='unknown', binaries_dir=None):
    """Return the next unused model file name of the form gat_<dataset_name>_<6-digit counter>.pth.

    Args:
        dataset_name: dataset tag embedded in the file name.
        binaries_dir: directory to scan for existing binaries; defaults to BINARIES_PATH.

    Returns:
        gat_<dataset_name>_000000.pth if no matching file exists, otherwise the name with the highest
        existing counter + 1.
    """
    if binaries_dir is None:
        binaries_dir = BINARIES_PATH

    prefix = f'gat_{dataset_name}'
    # Escape the prefix: dataset names may contain regex metacharacters ('.', '+', ...), which would otherwise
    # change what the pattern matches.
    pattern = re.compile(rf'{re.escape(prefix)}_([0-9]{{6}})\.pth')

    existing_suffixes = [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(binaries_dir))
        if match is not None
    ]
    next_suffix = max(existing_suffixes) + 1 if existing_suffixes else 0
    return f'{prefix}_{str(next_suffix).zfill(6)}.pth'

In [17]:
import argparse


def get_training_args(argv=None):
    """Build the training configuration: CLI-style options merged with the fixed GAT architecture for PPI.

    Args:
        argv: optional list of command-line tokens, e.g. ['--num_of_epochs', '10', '--force_cpu', 'true'].
            Defaults to no tokens: inside a notebook sys.argv holds the kernel's own arguments, not ours.

    Returns:
        dict with every argparse option, 'ppi_load_test_only' and the GAT hyperparameters.
    """

    def str2bool(value):
        # argparse's type=bool would turn ANY non-empty string, including 'False' and '0', into True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')

    parser = argparse.ArgumentParser()

    # Argument order is kept stable because it determines the key order of the returned dict
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=200)
    parser.add_argument("--patience_period", type=int, help="number of epochs with no improvement on val before terminating", default=100)
    parser.add_argument("--lr", type=float, help="model learning rate", default=5e-3)
    parser.add_argument("--weight_decay", type=float, help="L2 regularization on model weights", default=0)
    parser.add_argument("--should_test", type=str2bool, help='should test the model on the test dataset?', default=True)
    parser.add_argument("--force_cpu", type=str2bool, help='use CPU if your GPU is too small', default=False)
    parser.add_argument("--dataset_name", type=str, help='dataset to use for training', default="PPI")

    parser.add_argument("--console_log_freq", type=int, help="log to output console (epoch) freq (None for no logging)", default=10)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq (None for no logging)", default=5)

    parser.add_argument("--batch_size", type=int, help='number of graphs in a batch', default=2)

    # Logging/debugging/checkpoint related (helps a lot with experimentation)
    parser.add_argument("--enable_tensorboard", type=str2bool, help="enable tensorboard logging", default=False)
    args = parser.parse_args([] if argv is None else argv)

    # Fixed GAT architecture for PPI: input width 50, last layer emits one score per PPI class
    gat_config = {
        "num_of_layers": 3,
        "num_heads_per_layer": [4, 4, 6],
        "num_features_per_layer": [PPI_NUM_INPUT_FEATURES, 64, 64, PPI_NUM_CLASSES],
        "add_skip_connection": True,
        "bias": True,
        "dropout": 0.0,
    }

    training_config = dict(vars(args))
    training_config['ppi_load_test_only'] = False  # load both train/val/test data loaders (don't change it)

    # Add additional config information
    training_config.update(gat_config)

    return training_config

In [18]:
# Preview the resolved training configuration; the returned dict is displayed as this cell's output.
get_training_args()

{'num_of_epochs': 200,
 'patience_period': 100,
 'lr': 0.005,
 'weight_decay': 0,
 'should_test': True,
 'force_cpu': False,
 'dataset_name': 'PPI',
 'console_log_freq': 10,
 'checkpoint_freq': 5,
 'batch_size': 2,
 'enable_tensorboard': False,
 'ppi_load_test_only': False,
 'num_of_layers': 3,
 'num_heads_per_layer': [4, 4, 6],
 'num_features_per_layer': [50, 64, 64, 121],
 'add_skip_connection': True,
 'bias': True,
 'dropout': 0.0}

In [19]:
import time


def train_gat(config):
    """Train a GAT on PPI with early stopping, optionally evaluate on the test split, and save the model.

    Args:
        config: dict as returned by get_training_args(). It is mutated: 'test_perf' is set to the test
            micro-F1, or to -1 when config['should_test'] is False.

    Side effects: loads (and on first use downloads) PPI, resets the module-level early-stopping globals,
    and writes a .pth file into BINARIES_PATH.
    """
    # PATIENCE_CNT must be declared global as well. Otherwise the reset below only creates a local variable,
    # and the patience counter from a previous run leaks into this one.
    global BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT

    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
    ).to(device)

    # PPI is multi-label: each class gets an independent sigmoid + binary cross-entropy
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        config['patience_period'],
        time.time())

    BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0]  # reset vars used for early stopping

    for epoch in range(config['num_of_epochs']):
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                # NOTE(review): this also swallows genuine errors raised during validation; a dedicated
                # exception type raised by main_loop would make early stopping explicit.
                print(str(e))
                break

    if config['should_test']:
        # Evaluation needs no gradients; avoid building an autograd graph over the whole test split
        with torch.no_grad():
            micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(
        get_training_state(config, gat),
        os.path.join(BINARIES_PATH, get_available_binary_name(config['dataset_name']))
    )

In [20]:
from sklearn.metrics import f1_score


def get_main_loop(config, gat, sigmoid_cross_entropy_loss, optimizer, patience_period, time_start):
    """Build the closure that runs one pass (train / validation / test) over a data loader.

    Args:
        config: dict of options. Keys read here: 'enable_tensorboard', 'console_log_freq',
            'checkpoint_freq', 'dataset_name'; 'test_perf' is written before checkpointing.
        gat: the model; the device of its first parameter is used for every batch.
        sigmoid_cross_entropy_loss: loss called as loss(raw_scores, gt_labels).
        optimizer: zeroed, back-propagated and stepped once per batch in the TRAIN phase.
        patience_period: number of consecutive non-improving validation batches tolerated
            before early stopping is triggered.
        time_start: wall-clock timestamp used for the "time elapsed" console logs.

    Returns:
        main_loop(phase, data_loader, epoch=0). In the TEST phase it returns the micro-F1 computed
        over all test batches (None if the loader yields nothing); other phases return None.
        In the VAL phase it raises Exception('Stopping the training') once patience runs out.
    """
    device = next(gat.parameters()).device  # fetch the device info from the model instead of passing it as a param

    def main_loop(phase, data_loader, epoch=0):
        global BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT, writer

        # Switch the model between training and inference behaviour (e.g. dropout).
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        is_train = phase == LoopPhase.TRAIN

        # TEST-phase predictions/targets are accumulated so that micro-F1 covers the whole
        # test split instead of only its first batch.
        test_preds, test_gts = [], []

        for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            gt_node_labels = gt_node_labels.to(device)

            graph_data = (node_features, edge_index)

            # Only record the autograd graph when we are going to back-propagate.
            with torch.set_grad_enabled(is_train):
                # [0] selects the node-features part of the model output (index 1 holds edge_index).
                # NOTE(review): presumably shape (N, C), N = nodes in the batch, C = classes (121 for PPI).
                nodes_unnormalized_scores = gat(graph_data)[0]
                loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_value = loss.item()

            # A positive logit means sigmoid(logit) > 0.5, i.e. label predicted as 1; otherwise 0.
            pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
            gt = gt_node_labels.cpu().numpy()
            micro_f1 = f1_score(gt, pred, average='micro')

            global_step = len(data_loader) * epoch + batch_idx

            if phase == LoopPhase.TRAIN:
                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss_value, global_step)
                    writer.add_scalar('training_micro_f1', micro_f1, global_step)

                if config['console_log_freq'] is not None and epoch % config['console_log_freq'] == 0 and batch_idx == 0:
                    print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

                if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                    config['test_perf'] = -1  # test perf not calculated yet, note: perf means main metric micro-F1 here
                    torch.save(get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

            elif phase == LoopPhase.VAL:
                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss_value, global_step)
                    writer.add_scalar('val_micro_f1', micro_f1, global_step)

                if config['console_log_freq'] is not None and epoch % config['console_log_freq'] == 0 and batch_idx == 0:
                    print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

                # Early stopping: an improvement in either micro-F1 or loss resets the patience counter.
                # NOTE(review): patience is counted per validation batch, not per epoch; the two only
                # coincide when the whole validation split fits in a single batch.
                if micro_f1 > BEST_VAL_MICRO_F1 or loss_value < BEST_VAL_LOSS:
                    BEST_VAL_MICRO_F1 = max(micro_f1, BEST_VAL_MICRO_F1)  # best validation micro-F1 so far
                    BEST_VAL_LOSS = min(loss_value, BEST_VAL_LOSS)  # and the minimal loss
                    PATIENCE_CNT = 0
                else:
                    PATIENCE_CNT += 1

                if PATIENCE_CNT >= patience_period:
                    raise Exception('Stopping the training')

            else:
                test_preds.append(pred)
                test_gts.append(gt)

        if test_preds:
            # Micro-F1 over every node of every test batch.
            return f1_score(np.concatenate(test_gts), np.concatenate(test_preds), average='micro')
        return None

    return main_loop

In [23]:
# Run the whole pipeline with the config returned by get_training_args(): load the data,
# train with early stopping, optionally evaluate on the test split, and save the final model.
train_gat(get_training_args())

Loading train graph 1 to CPU. It has 1767 nodes and 34085 edges.
Loading train graph 2 to CPU. It has 1377 nodes and 31081 edges.
Loading train graph 3 to CPU. It has 2263 nodes and 61907 edges.
Loading train graph 4 to CPU. It has 2339 nodes and 67769 edges.
Loading train graph 5 to CPU. It has 1578 nodes and 37740 edges.
Loading train graph 6 to CPU. It has 1021 nodes and 19237 edges.
Loading train graph 7 to CPU. It has 1823 nodes and 46153 edges.
Loading train graph 8 to CPU. It has 2488 nodes and 72878 edges.
Loading train graph 9 to CPU. It has 591 nodes and 8299 edges.
Loading train graph 10 to CPU. It has 3312 nodes and 109510 edges.
Loading train graph 11 to CPU. It has 2401 nodes and 66619 edges.
Loading train graph 12 to CPU. It has 1878 nodes and 48146 edges.
Loading train graph 13 to CPU. It has 1819 nodes and 47587 edges.
Loading train graph 14 to CPU. It has 3480 nodes and 110234 edges.
Loading train graph 15 to CPU. It has 2794 nodes and 88112 edges.
Loading train graph