In [49]:
from spektral import datasets
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from spektral.layers import GCNConv, GlobalSumPool
from spektral.data import Loader
from spektral.transforms import GCNFilter
from spektral.models.gcn import GCN
from spektral.data.loaders import SingleLoader

The Reddit dataset consists of a graph made of Reddit posts in the month of September, 2014. The label for each node is the community that a post belongs to. The graph is built by sampling 50 large communities and two nodes are connected if the same user commented on both. Node features are obtained by concatenating the average GloVe CommonCrawl vectors of the title and comments, the post's score and the number of comments.

In [5]:
data = datasets.graphsage.Reddit()

Downloading reddit dataset.


100%|██████████████████████████████████████| 1.22G/1.22G [01:50<00:00, 18.4MB/s]

Processing dataset.


100%|██████████████████████████████████████| 1.22G/1.22G [02:01<00:00, 10.8MB/s]


In [6]:
data

Reddit(n_graphs=1)

In [10]:
# Reddit posts in the month of September, 2014
reddit_graph = data[0]
reddit_graph

Graph(n_nodes=232965, n_node_features=602, n_edge_features=None, n_labels=41)

In [17]:
# Adjacency matrix (two nodes are connected if the same user commented on both)
adj_matrix = reddit_graph.a
adj_matrix.shape

(232965, 232965)

In [18]:
adj_matrix[0]

<1x232965 sparse matrix of type '<class 'numpy.float32'>'
	with 367 stored elements in Compressed Sparse Row format>

In [13]:
# Node features (Node features are obtained by concatenating the average GloVe CommonCrawl vectors of the title and comments, 
#                the post's score and the number of comments)
node_features = reddit_graph.x
node_features.shape

(232965, 602)

In [15]:
# Edge features 
edge_features = reddit_graph.e
type(edge_features)

NoneType

In [20]:
# Labels (The label for each node is the community that a post belongs to)
targets = reddit_graph.y
targets.shape

(232965, 41)

In [31]:
# Transform adj matrix for GCNConv
data.apply(GCNFilter())

In [53]:
model = GCN(n_labels=data.n_labels, n_input_channels=data.n_node_features)
model.compile(
    optimizer=Adam(learning_rate=0.01),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)

However, here's where graphs get in our way. Unlike regular data, like images or sequences, graphs cannot be stretched, cut, or reshaped so that we can fit them into tensors of pre-defined shapes. If a graph has 10 nodes and another one has 4, we have to keep them that way.

This means that iterating over a dataset in mini-batches is not trivial and we cannot simply use the model.fit() method of Keras as-is.

We have to use a data Loader

In [56]:
loader = SingleLoader(data)

In [None]:
model.fit(loader.load(), steps_per_epoch=10, epochs=10)

Epoch 1/10


2021-08-08 12:01:56.816952: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-08-08 12:01:56.820571: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2294685000 Hz


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10