In [3]:
import torch
from torch_geometric.datasets import TUDataset
import numpy as np
import matplotlib.pyplot as plt

In [9]:
# Load the PROTEINS collection of graphs (downloaded into ./data if not already there).
# use_node_attr=True also keeps the continuous node attributes in the node features.
dataset = TUDataset(root="data", name="PROTEINS", use_node_attr=True)
# First graph only -- every per-graph statistic printed below refers to this one graph.
data = dataset[0]

# Dataset-level summary
print(f'Dataset: {dataset}')
print('-------------------')
dataset_info = {
    'Number of graphs': len(dataset),
    # NOTE: node count of the first graph (`data`), not of the whole dataset
    'Number of nodes': data.num_nodes,
    'Number of features': dataset.num_features,
    'Number of classes': dataset.num_classes,
}
for caption, stat in dataset_info.items():
    print(f'{caption}: {stat}')

# Structural properties of the first graph
print('\nGraph:')
print('------')
graph_info = {
    'Edges are directed': data.is_directed(),
    'Graph has isolated nodes': data.has_isolated_nodes(),
    'Graph has loops': data.has_self_loops(),
}
for caption, stat in graph_info.items():
    print(f'{caption}: {stat}')

Dataset: PROTEINS(1113)
-------------------
Number of graphs: 1113
Number of nodes: 42
Number of features: 4
Number of classes: 2

Graph:
------
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False


In [None]:
# TODO: implement mini-batching

In [10]:
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
from torch.nn import Dropout, Linear

class SAGE(torch.nn.Module):
    """GraphSAGE network built from stacked SAGEConv layers.

    Layout: an input SAGEConv (in_channels -> hidden_channels), then
    ``num_layers - 2`` independent hidden SAGEConv layers
    (hidden_channels -> hidden_channels), then an output SAGEConv
    (hidden_channels -> out_channels). A ReLU is applied after every layer
    except the last.

    Params:
        in_channels      size of each input node feature vector
        hidden_channels  size of the hidden node embeddings
        out_channels     size of each output node vector
        num_layers       total number of SAGEConv layers; must be >= 2
        aggregator_type  neighbourhood aggregation passed to SAGEConv's ``aggr``
                         argument (default "mean", SAGEConv's own default)
    Raises:
        ValueError       if num_layers < 2

    NOTE(review): forward() returns one vector per node. PROTEINS is a
    collection of labelled graphs, so a graph-level readout (pooling over the
    nodes of each graph) is presumably still needed before a classification loss.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, aggregator_type="mean"):
        super().__init__()
        if num_layers < 2:
            # The input and output convolutions always exist, so fewer than two
            # layers cannot be represented.
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")
        self.num_layers = num_layers
        # SAGEConv layers
        self.sage_in = SAGEConv(in_channels, hidden_channels, aggr=aggregator_type)
        # One independent layer per hidden step. Reusing a single layer would make
        # all hidden steps share the same weights.
        self.sage_hid = torch.nn.ModuleList(
            [SAGEConv(hidden_channels, hidden_channels, aggr=aggregator_type)
             for _ in range(num_layers - 2)]
        )
        self.sage_out = SAGEConv(hidden_channels, out_channels, aggr=aggregator_type)
        # Use the Adam optimizer (cf. Experimental Setup). `parameters` must be
        # *called*: Adam expects an iterable of tensors, not the bound method.
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, x, edge_index):
        """Forward propagation.

        Params:
            x           node feature matrix (assumed (num_nodes, in_channels) -- confirm)
            edge_index  graph connectivity, passed unchanged to every SAGEConv
        Returns:
            output of the final SAGEConv layer (no activation applied)
        """
        # F.relu applies the activation; torch.nn.ReLU(h) would only construct a module.
        h = F.relu(self.sage_in(x, edge_index))
        for conv in self.sage_hid:
            h = F.relu(conv(h, edge_index))
        return self.sage_out(h, edge_index)
