In [551]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

In [552]:
S = (25, 10)
K = 2
OUT_DIM = 256
HID_DIM = 1024
EPOCHS = 10

Download Dataset and create mini-batching of training data

In [553]:
#Download PROTEINS dataset and save in data
dataset = Planetoid(root="data", name="CiteSeer", transform=T.NormalizeFeatures())
data = dataset[0]

feat_data = {'train' : data.x[data.train_mask],
        'val' : data.x[data.val_mask],
        'test' : data.x[data.test_mask]}
label = {'train' : data.y[data.train_mask],
        'val' : data.y[data.val_mask],
        'test' : data.y[data.test_mask]}

#mini-batching of training data
#train_batches = DataLoader(TensorDataset(feat_data['train'], label['train']), batch_size=4)

adj_list = [[] for _ in range(data.num_nodes)]
for edge in data.edge_index.T:
        list.append(adj_list[edge[0].item()], edge[1].item())



In [554]:
#Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Training nodes: {sum(data.train_mask).item()}')
#print(f'Training nodes: {next(iter(train_batches))[1].shape}')
print(f'Evaluation nodes: {sum(data.val_mask).item()}')
#print(f'Evaluation nodes: {len(label["val"])}')
print(f'Testing nodes: {sum(data.test_mask).item()}')
#print(f'Testing nodes: {len(label["test"])}')


Dataset: CiteSeer()
-------------------
Number of graphs: 1
Number of nodes: 3327
Number of features: 3703
Number of classes: 6

Graph:
------
Training nodes: 120
Evaluation nodes: 500
Testing nodes: 1000


Neighborhood Sampler

In [555]:
import random
""" 
Neighborhood Sampler
Params:
    node index of node the neighborhood should be sampled of
    sample_size number of sampled nodes
Returns:
    index list of nodes in neigborhood of node
"""
def sample(node, sample_size):
    if len(adj_list[node]) < sample_size:
        return adj_list[node]
    else:
        return random.sample(adj_list[node], k=sample_size) 

Aggregator function

In [556]:
class MaxPoolingAggregator(torch.nn.Module):
    """ 
    AGGREGATOR: Max Pooling
    Params:
        in_channels feature size of each input sample
        out_channel feature size of each output sample
    """
    def __init__(self, in_channels, out_channels):
        super(MaxPoolingAggregator, self).__init__()
        #fully connected layer with learnable weights
        self.fc_layer = torch.nn.Linear(in_channels, out_channels, bias=True)
    """
    Forward Propagation
    Params:
        neighborhood neigborhood sample of imput node to be aggregated
    feed neighborhood of node through fully connected layer and non linearity
    return maximum of all neighbors
    """
    def forward(self, neigborhood):
        out = []
        for h in neigborhood:
            h = self.fc_layer(h).relu()
            list.append(out, h)
            #print(out.size)
        return torch.amax(torch.stack(out), 0)

Graph SAGE model

In [557]:
class GraphSAGE(torch.nn.Module):
    """ 
    GraphSAGE model
    Params:
        in_channels feature size of each input sample
        hidden_channels feature size of each hidden sample
        out_channel feature size of each output sample
        num_layers =K number of message passing layers
    """
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.aggregate = MaxPoolingAggregator(in_channels, hidden_channels)
        self.linears = torch.nn.ModuleList([torch.nn.Linear(in_channels + hidden_channels, out_channels), 
                                            torch.nn.Linear(out_channels + hidden_channels, out_channels)])
        #Use Adam optimizer vgl Experimental Setup
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
        #Cross entropy loss for supervised learning 
        self.loss_fn = F.cross_entropy

    """
    Forward Propagation
    Params:
        data feature vectors
        input_idx index of input vectors in original ds
    """
    def forward(self, d, input_idx):
        x = d
        y = torch.empty(len(input_idx), OUT_DIM)
        for k in range(K):
            for i, idx in enumerate(input_idx):
                neighborhood_features  = data.x[sample(idx, S[k])]
                h_n = self.aggregate(neighborhood_features)
                h = torch.cat((x[i], h_n))
                h = self.linears[k](h)
                y[i]= F.normalize(h, dim=-1)
            x = y
        return x

In [558]:
def train(model, nodes):
    model.train()
    model.optimizer.zero_grad()
    output = model(feat_data['train'], nodes)
    loss = model.loss_fn(output, label['train'])
    loss.backward()
    model.optimizer.step()
    return loss

@torch.no_grad()
def evaluate(model, nodes):
    model.eval()
    output = model(feat_data['val'], nodes)
    loss = model.loss_fn(output, label['val'])
    return loss
    

In [559]:
model = GraphSAGE(data.num_node_features, HID_DIM, OUT_DIM)
print(model)

for epoch in range(EPOCHS):
    loss = train(model, tuple([i for i in range(120)]))
    print(loss.item())

GraphSAGE(
  (aggregate): MaxPoolingAggregator(
    (fc_layer): Linear(in_features=3703, out_features=1024, bias=True)
  )
  (linears): ModuleList(
    (0): Linear(in_features=4727, out_features=256, bias=True)
    (1): Linear(in_features=1280, out_features=256, bias=True)
  )
)
5.564339637756348
5.421643257141113
5.405557155609131
5.38843297958374
5.362799644470215
5.320326805114746
5.244482517242432
5.150076389312744
5.203665256500244
5.19434928894043
