In [66]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.datasets import Planetoid
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

In [48]:
# Neighbourhood sample sizes -- presumably for the first and second hop; they
# are not referenced in the visible cells, so confirm at the sampling call site.
S_1, S_2 = 25, 10
# Default number of message passing layers of the GraphSAGE model (num_layers=K).
K = 2
# Layer widths -- assumed to be the hidden and output feature sizes; TODO confirm.
HID_DIM, OUT_DIM = 512, 256
# Presumably the number of training epochs; verify against the training loop.
EPOCHS = 10

Download the dataset and prepare the training data (mini-batching is currently commented out)

In [85]:
# Download the CiteSeer dataset (Planetoid collection) and cache it under ./data
dataset = Planetoid(root="data", name="CiteSeer")
data = dataset[0]

# Node features and labels of each split, selected with the dataset's masks
masks = {'train': data.train_mask, 'val': data.val_mask, 'test': data.test_mask}
feat_data = {split: data.x[mask] for split, mask in masks.items()}
label = {split: data.y[mask] for split, mask in masks.items()}
#mini-batching of training data (currently disabled)
#train_batches = DataLoader(TensorDataset(feat_data['train'], label['train']), batch_size=4)

# Adjacency list: adj_list[u] holds every v for which edge_index has a column (u, v).
# The edge tensor is converted to Python ints once instead of calling .item()
# twice per edge inside the loop.
adj_list = [[] for _ in range(data.num_nodes)]
for src, dst in data.edge_index.t().tolist():
    adj_list[src].append(dst)


In [29]:
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph.
# Tensor.sum() counts the selected mask entries in a single vectorised call;
# the built-in sum() would iterate the mask element by element in Python.
print('\nGraph:')
print('------')
print(f'Training nodes: {data.train_mask.sum().item()}')
print(f'Evaluation nodes: {data.val_mask.sum().item()}')
print(f'Testing nodes: {data.test_mask.sum().item()}')


Dataset: CiteSeer()
-------------------
Number of graphs: 1
Number of nodes: 3327
Number of features: 3703
Number of classes: 6

Graph:
------
Training nodes: 120
Evaluation nodes: 500
Testing nodes: 1000


Neighborhood Sampler

In [49]:
import random
def sample(node, sample_size, adjacency=None):
    """Draw up to `sample_size` distinct neighbours of `node`.

    Params:
        node         index of the node whose neighbourhood is sampled
        sample_size  maximum number of neighbours to return
        adjacency    optional adjacency list (one neighbour list per node);
                     defaults to the notebook-level `adj_list`
    Returns:
        A new list. If the node has fewer than `sample_size` neighbours, all of
        them are returned; otherwise `sample_size` neighbours are drawn without
        replacement via random.sample.
    """
    neighbours = (adj_list if adjacency is None else adjacency)[node]
    if len(neighbours) < sample_size:
        # Return a copy so callers cannot mutate the shared adjacency list.
        return list(neighbours)
    return random.sample(neighbours, k=sample_size)

Aggregator function

In [51]:
class MaxPoolingAggregator(torch.nn.Module):
    """
    AGGREGATOR: Max Pooling

    Every neighbour vector is fed through one shared fully connected layer and
    a ReLU; the transformed vectors are reduced with an element-wise maximum.

    Params:
        in_channels  feature size of each input sample
        out_channels feature size of each output sample
    """
    def __init__(self, in_channels, out_channels):
        super(MaxPoolingAggregator, self).__init__()
        # fully connected layer with learnable weights
        self.fc_layer = torch.nn.Linear(in_channels, out_channels, bias=True)
        # non-linearity -> ReLU. Keep a reference to the function itself:
        # calling F.relu() without an input raises a TypeError.
        self.non_lin = F.relu

    def _empty_output(self):
        """Zero vector of size out_channels, on the layer's device and dtype."""
        return self.fc_layer.weight.new_zeros(self.fc_layer.out_features)

    def forward(self, neigborhood):
        """
        Forward Propagation
        Params:
            neigborhood neighbourhood sample of the input node to be aggregated:
                        either a tensor whose first dimension indexes the
                        neighbours, or an iterable of per-neighbour feature
                        tensors
        Feeds all neighbours through the fully connected layer and the
        non-linearity and returns the element-wise maximum over the neighbours
        (a tensor of size out_channels). An empty neighbourhood yields zeros.
        """
        if not torch.is_tensor(neigborhood):
            neighbours = list(neigborhood)
            if not neighbours:
                return self._empty_output()
            neigborhood = torch.stack(neighbours)
        if neigborhood.shape[0] == 0:
            return self._empty_output()
        # Stay in torch (no numpy) so gradients flow back into fc_layer.
        activated = self.non_lin(self.fc_layer(neigborhood))
        return activated.max(dim=0).values

GraphSAGE model

In [None]:
class GraphSAGE(torch.nn.Module):
    """ 
    GraphSAGE model
    Params:
        in_channels feature size of each input sample
        hidden_channels feature size of each hidden sample
        out_channel feature size of each output sample
        num_layers =K number of message passing layers
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=K):
        """
        Build the aggregator, the two fully connected layers, the loss function
        and the optimizer.
        Params:
            in_channels feature size of each input sample
            hidden_channels feature size of each hidden sample
            out_channels feature size of each output sample
            num_layers number of message passing layers (defaults to K)
        """
        super(GraphSAGE, self).__init__()
        # NOTE(review): num_layers is accepted but not used by the code visible
        # here; two fully connected layers are created regardless of its value.
        self.aggregate = MaxPoolingAggregator(in_channels, hidden_channels)
        self.fc_layer1 = torch.nn.Linear(in_channels + hidden_channels, out_channels)
        self.fc_layer2 = torch.nn.Linear(in_channels + hidden_channels, out_channels)
        # Cross entropy loss for supervised learning. Store the function itself:
        # calling F.cross_entropy() without input/target raises a TypeError.
        self.loss_fn = F.cross_entropy
        # Use the Adam optimizer (cf. Experimental Setup). It must be created
        # after the sub-modules are registered; otherwise self.parameters() is
        # empty and torch.optim.Adam raises a ValueError.
        self.optimizer = torch.optim.Adam(self.parameters())

    """
    Forward Propagation
    Params:
        x input vector 
        edge_index adjacency matrix 
    """
    