In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from torch_geometric.data import DataLoader, Data
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import OneHotDegree
from torch_scatter import scatter_max

cwd = os.getcwd()

In [28]:
class MLP(nn.Module):
    """A class implementing the MLP layer of a GIN model."""
    
    def __init__(self, num_layers, input_dim, hidden_dim):
        """ 
        Args:
            num_layers: An integer indicating the number of layers. Input layer is not counted.
            input_dim: An integer indicating the input dimension.
            hiddem_dim: An integer indicating the size of the hidden layers. Also indicates the output dimension.
        """
        
        super(MLP, self).__init__()
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.linears = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        if self.num_layers == 0:
            self.linears.extend([nn.Linear(input_dim, hidden_dim)])
            self.batch_norms.extend([nn.BatchNorm1d(hidden_dim)])
        else:
            self.linears = nn.ModuleList([nn.Linear(input_dim, hidden_dim)] + 
                                         [nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_layers - 1)])
            self.batch_norms = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(self.num_layers)])
            
    def forward(self, x):
        """ Forward pass of the model on input.
        
        Args:
            x: A Tensor representing the input.
        
        Returns: 
            x: A Tensor representing the output.
        """
        
        for i in range(max(1, self.num_layers)):
            x = self.linears[i](x)
            x = self.batch_norms[i](x)
            x = F.relu(x)
        return x

    def reset_parameters(self):
        """ Reset the parameters of the model. """
        
        for i in range(max(1, self.num_layers)):
            self.linears[i].reset_parameters()
            self.batch_norms[i].reset_parameters()
                

class GIN(nn.Module):
    """A class implementing the GIN model."""
    
    def __init__(self, num_layers, num_mlps, input_dim, hidden_dim, output_dim, dropout_rate, 
                 nbh_agg, graph_agg, learn_eps, random):
        """ 
        Args:
            num_layers: An integer indicating the number of layers. Input layer is counted.
            num_mlps: An integer indicating the number of MLP layers. If is 0, the model will be a GIN 1-LAYER.
            input_dim: An integer indicating the input dimension.
            output_dim: An integer indicating the input dimension.
            hiddem_dim: An integer indicating the size of the hidden layers.
            dropout_rate: A float between 0. and 1., indicating the dropout rate, applied at the final layer.
            nbh_agg: A string in ['sum', 'mean', 'max', 'lstm'] indicating the type of node neighborhood aggregation.
            graph_agg: A string in ['sum'] indicating the readout.
            learn_eps: A boolean indicating whether the epsilon parameter is learnable or is fixed as a zero tensor.
            random: A boolean indicating whether random initialization is used. If True, the input dimension is double,
                the input features random features from a normal distribution are concatenated.
        """
        
        super(GIN, self).__init__()
        self.num_layers = num_layers
        self.num_mlps = num_mlps
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate
        self.nbh_agg = nbh_agg
        self.graph_agg = graph_agg
        self.learn_eps = learn_eps
        self.random = random
        if self.learn_eps:
            self.eps = nn.Parameter(torch.zeros(self.num_layers - 1))
        else:
            self.eps = torch.zeros(self.num_layers - 1)
            
        if self.random:
            self.input_dim *= 2

        self.mlps = nn.ModuleList([MLP(self.num_mlps, self.input_dim, self.hidden_dim)] + 
                                  [MLP(self.num_mlps, self.hidden_dim, self.hidden_dim) for _ in range(
                                      self.num_layers - 2)])
        
        self.linears_graph = nn.ModuleList([nn.Linear(self.input_dim, self.output_dim)] + 
                                           [nn.Linear(self.hidden_dim, self.output_dim) for _ in range(
                                               self.num_layers - 1)])
        
        if nbh_agg == 'lstm':
            self.lstm_input = nn.LSTM(input_dim, input_dim, batch_first = True)
            self.lstm_hidden = nn.LSTM(hidden_dim, hidden_dim, batch_first = True)
        
    def __adj_matrix(self, batch):
        """ Obtain the node adjacency matrix of a graph. An adjacency matrix adj_matrix satisfies: 
        adj_matrix[i][j] = 1. if nodes i and j are adjacen; 0. otherwise.
        
        Args:
            batch: A batch corresponding to a graph.
            
        Returns: A sparse tensor representing the adjacency matrix.
        """
        
        num_nodes = batch.x.shape[0]
        return torch.sparse_coo_tensor(batch.edge_index, [1.] * batch.edge_index.shape[1], (num_nodes, num_nodes))
    
    def __nbh_indices(self, adj_matrix):
        """ Obtain a list containing the tensor of neighbor indices for each node. nbh_indices satisfies: 
        nbh_indices[i] = torch.tensor([j | j neighbour of i]).
        
        Args: 
            adj_matrix: A tensor representing the adjacency matrix of a graph.
            
        Returns:
            nbh_indices: A list of tensors representing the neighbors indices.
        """
        
        nbh_indices = []
        for i in range(adj_matrix.shape[0]):
            nbh_indices.append(adj_matrix[i].coalesce().indices()[0])
        return nbh_indices
        
    def __graphs_matrix(self, batch):
        """ Obtain the graphs matrix given a batch which is a graph formed from other graphs. A graphs_matrix satisfies: 
        graphs_matrix[g][n] = 1. if graph g has node n; 0. otherwise.
        
        Args:
            batch: A batch corresponding to a graph.
        
        Returns:
            graphs_matrix: A tensor representing the graphs matrix.
        """
        
        num_graphs = batch.y.shape[0]
        num_nodes = batch.x.shape[0]
        
        graphs_matrix = torch.zeros((num_graphs, num_nodes))
        for node in range(num_nodes):
            graphs_matrix[batch.batch[node]][node] = 1.
        return graphs_matrix
        
    def __nbh_agg(self, x, adj_matrix, nbh_indices = None, k = None):
        """ Obtain neighborhood aggregation tensor.
        
        Args:
            x: A tensor containing node features.
            adj_matrix: adj_matrix: A tensor representing the adjacency matrix.
            nbh_indices: A list of tensors representing the neighbors indices, used in max and LSTM aggregation.
            k: An integer indicating the layer index, used in LSTM aggregation.
        
        Returns: A tensor representing the neighborhoods aggregations.
        
        """
        
        if self.nbh_agg == 'sum':
            return torch.spmm(adj_matrix, x)

        if self.nbh_agg == 'mean':
            ones = torch.ones(adj_matrix.shape[0], 1)
            num_neighbours = torch.spmm(adj_matrix, ones)
            return torch.spmm(adj_matrix, x) / num_neighbours

        if self.nbh_agg == 'max':
            max_agg = []
            for i in range(x.shape[0]):
                if nbh_indices[i].shape[0] != 0:
                    index = nbh_indices[i].expand((x.shape[1], nbh_indices[i].shape[0])).t()
                    nbhs = x.gather(0, index)
                    index = torch.arange(nbhs.shape[1]).expand(nbhs.shape[0], nbhs.shape[1])
                    max_agg.append(scatter_max(nbhs.flatten(), index.flatten())[0])
                else:
                    max_agg.append(torch.zeros(x.shape[1]))
            return torch.stack(max_agg)

        if self.nbh_agg == 'lstm':
            lstm_agg = []
            for i in range(x.shape[0]):
                if nbh_indices[i].shape[0] != 0:
                    index = nbh_indices[i].expand((x.shape[1], nbh_indices[i].shape[0])).t()
                    nbhs = x.gather(0, index)
                    nbhs = nbhs[torch.randperm(nbhs.shape[0])]
                    nbhs = nbhs.reshape(1, nbhs.shape[0], nbhs.shape[1])
                    if k == 0:
                        output = self.lstm_input(nbhs)[0]
                    else:
                        output = self.lstm_hidden(nbhs)[0]
                    lstm_agg.append(output[0][-1])
                else:
                    lstm_agg.append(torch.zeros(x.shape[1]))
            return torch.stack(lstm_agg)
            
    def __graph_agg(self, x, graphs_matrix):
        """ Obtain graph aggregation tensor.
        
        Args:
            x: A tensor containing node features.
            graphs_matrix: A tensor representing the graphs matrix.
            
        Returns: A tensor representing the graph readout.
        """
        
        if self.graph_agg == 'sum':
            return torch.spmm(graphs_matrix, x)
        
    def forward(self, batch):  
        """ Forward pass of the model on input.
        
        Args:
            batch: A batch, representing a  graph containing all the graphs part of the same batch.
    
        Returns: 
            output: A Tensor representing the output.
        """
        
        x = batch.x
        x = x.type(torch.FloatTensor)
        if self.random:
            x = torch.cat((x, torch.empty((x.shape[0], x.shape[1])).normal_(mean=0.,std=1.)), axis = 1)
            
        num_graphs = batch.y.shape[0]
        graphs_matrix = self.__graphs_matrix(batch)
        adj_matrix = self.__adj_matrix(batch)
        nbh_indices = None
        if self.nbh_agg in ['max', 'lstm']:
            nbh_indices = self.__nbh_indices(adj_matrix)
        
        output = torch.zeros((num_graphs, self.output_dim))
        graph_agg = self.__graph_agg(x, graphs_matrix)
        output += F.dropout(self.linears_graph[0](graph_agg), self.dropout_rate, training = self.training)
            
        for k in range(self.num_layers - 1):
            x = (1 + self.eps[k]) * x + self.__nbh_agg(x, adj_matrix, nbh_indices, k)
            x = self.mlps[k](x)       
            graph_agg = self.__graph_agg(x, graphs_matrix)
            output += F.dropout(self.linears_graph[k + 1](graph_agg), self.dropout_rate, training = self.training)
            
        return output
        
    def reset_parameters(self):
        """ Reset the parameters of the model. """
        
        for i in range(self.num_layers - 1):
            self.mlps[i].reset_parameters()
        for i in range(self.num_layers):
            self.linears_graph[i].reset_parameters()                

In [29]:
def k_fold_cross_validation(k, model, dataset, epochs, batch_size, output_dim = None):
    """ Run k fold cross validation. 
    
    Args:
        k: An integer indicating the number of folds.
        model: A pytorch module.
        dataset: A dataset.
        epochs: An integer indicating the number of epochs.
        batch_size: An integer indicating the size of batches.
        output_dim: An integer indicating the output dimension.
    
    """
    
    N = len(dataset)
    accuracies = []
    for i in range(k):
        model.reset_parameters()
        train_data_loader = DataLoader(dataset[:i * N // k] + dataset[(i + 1) * N // k:], batch_size = batch_size, 
                                       shuffle = True)
        test_data_loader = DataLoader(dataset[i * N // k: (i + 1) * N // k], batch_size = 1, shuffle = False)
        
        train(model, train_data_loader, epochs) 
        print(model.eps)
        acc = test(model, test_data_loader, output_dim)
        print('Fold ' + str(i + 1) + ' accuracy: ' + str(acc * 100) + '%')
        accuracies.append(acc)
    print()
    accuracies = torch.tensor(accuracies)
    print('Global accuracy: ' + str(torch.mean(accuracies).item() * 100) + '%')
    print('Standard deviation: ' + str(torch.std(accuracies).item() * 100) + '%')
    
def train(model, data_loader, epochs):
    """ Train the model on data for a given number of epochs.
    
    Args:
        model: A pytorch module.
        data_loader: A data loader corresponding to a dataset.
        epochs: An integer indicating the number of epochs.
    """
    
    model.train()
    loss =  nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 50, gamma = 0.5)

    for epoch in range(epochs):
        for batch in data_loader:
            optimizer.zero_grad()
            cur_loss = loss(model(batch), batch.y)
            cur_loss.backward()
            optimizer.step()
        scheduler.step()
            
def test(model, data_loader, output_dim = None):
    """ Evaluate the model on data.
    
    Args:
        model: A pytorch module.
        data_loader: A data loader corresponding to a dataset.
        output_dim: An integer indicating the output dimension. If is None, then it's considered to be 2.
    """
    
    if output_dim is None: # binary by default
        output_dim = 2
    model.eval()
    labels = []
    predictions = []
    for batch in data_loader:
        labels.append(batch.y.item())
        predictions.append(torch.argmax(model(batch)[0]).item())
    print('Predictions:' , end = ' ')
    for i in range(output_dim):
        print(str(predictions.count(i)), end = ' ')
    print()
    labels = torch.tensor(labels)
    predictions = torch.tensor(predictions)
    return (torch.sum(labels == predictions) / labels.shape[0]).item()            

In [30]:
def shuffle_dataset(dataset):
    """  
    Args:
        dataset: A dataset.
    
    Returns: The dataset after it's permuted.
    """
    
    index = list(range(len(dataset)))
    random.shuffle(index)
    return [dataset[i] for i in index]

def add_node_features(dataset, num_node_features):
    """ Add zero tensors as node features to a dataset.
    
    Args:
        dataset: A dataset.
        num_node_features: An integer indicating the number of node features.
    """
    
    new_dataset = []
    for i in range(len(dataset)):
        num_nodes = torch.max(dataset[i].edge_index) + 1
        new_dataset.append(Data(x = torch.zeros((num_nodes,  num_node_features)), edge_index = dataset[i].edge_index, 
                                y = dataset[i].y))
    return new_dataset

In [31]:
dataset_mutag = TUDataset(os.path.join(cwd, 'Datasets'), 'MUTAG') # input_dim = 7

dataset_proteins = TUDataset(os.path.join(cwd, 'Datasets'), 'PROTEINS') # input_dim = 3

dataset_imdb_b = TUDataset(os.path.join(cwd, 'Datasets'), 'IMDB-BINARY', transform = OneHotDegree(max_degree = 540))

dataset_imdb_m = TUDataset(os.path.join(cwd, 'Datasets'), 'IMDB-MULTI', transform = OneHotDegree(max_degree = 352)) 

dataset_rdt_b = TUDataset(os.path.join(cwd, 'Datasets'), 'REDDIT-BINARY') 
dataset_rdt_b = add_node_features(dataset_rdt_b, 64)

dataset_rdt_m = TUDataset(os.path.join(cwd, 'Datasets'), 'REDDIT-MULTI-5K') 
dataset_rdt_m = add_node_features(dataset_rdt_b, 64)

dataset_collab = TUDataset(os.path.join(cwd, 'Datasets'), 'COLLAB') 
dataset_collab = add_node_features(dataset_collab, 64)

dataset_nci1 = TUDataset(os.path.join(cwd, 'Datasets'), 'NCI1') # input_dim = 37

dataset_ptc = TUDataset(os.path.join(cwd, 'Datasets'), 'PTC_MR') # input_dim = 18

dataset_syntheticnew = TUDataset(os.path.join(cwd, 'Datasets'), 'SYNTHETICnew', 
                                 transform = OneHotDegree(max_degree = 18))

dataset_cuneiform = TUDataset(os.path.join(cwd, 'Datasets'), 'Cuneiform') # input_dim = 3

num_node_features = 64
datasets = {'MUTAG': (dataset_mutag, 7, 2),
            'PROTEINS': (dataset_proteins, 3, 2), 
            'IMDB-BINARY': (dataset_imdb_b, 541, 2),
            'IMDB-MULTI': (dataset_imdb_m, 353, 3),
            'REDDIT-BINARY': (dataset_rdt_b, num_node_features, 2),
            'REDDIT-MULTI-5K': (dataset_rdt_b, num_node_features, 5),
            'COLLAB': (dataset_collab, num_node_features, 3),
            'NCI1': (dataset_nci1, 37, 2),
            'PTC_MR': (dataset_ptc, 18, 2),
            'SYNTHETICnew': (dataset_syntheticnew, 19, 2),
            'Cuneiform': (dataset_cuneiform, 3, 30)
           }

In [26]:
(dataset, input_dim, output_dim) = datasets['MUTAG']
model = GIN(5, 2, input_dim, 32, output_dim, 0.5, 'sum', 'sum', True, False)
k_fold_cross_validation(10, model, dataset, 350, 32, output_dim)

Parameter containing:
tensor([-0.3252, -0.2164, -0.0030,  0.0417], requires_grad=True)
Predictions: 6 12 
Fold 1 accuracy: 83.33333134651184%
Parameter containing:
tensor([-0.5209, -0.2829, -0.1203,  0.4685], requires_grad=True)
Predictions: 4 15 
Fold 2 accuracy: 84.21052694320679%
Parameter containing:
tensor([-0.7635, -0.2579, -0.0443,  0.9134], requires_grad=True)
Predictions: 6 13 
Fold 3 accuracy: 84.21052694320679%
Parameter containing:
tensor([-1.1456, -0.2362, -0.0925,  1.1386], requires_grad=True)
Predictions: 3 16 
Fold 4 accuracy: 89.47368264198303%
Parameter containing:
tensor([-1.4640, -0.2330, -0.1603,  1.3404], requires_grad=True)
Predictions: 8 11 
Fold 5 accuracy: 89.47368264198303%
Parameter containing:
tensor([-1.5799, -0.3840,  0.0194,  1.8077], requires_grad=True)
Predictions: 4 14 
Fold 6 accuracy: 88.88888955116272%
Parameter containing:
tensor([-1.6540, -0.4309, -0.6354,  2.1429], requires_grad=True)
Predictions: 9 10 
Fold 7 accuracy: 94.73684430122375%
Parame