# Install dependencies and import libs

In [2]:
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import torch

# Show the installed PyTorch version string followed by the value of
# torch.version.cuda, separated by a single space.
print(f'{torch.__version__} {torch.version.cuda}')

2.0.1+cu118 11.8


In [15]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/pyg_lib-0.3.1%2Bpt20cu118-cp38-cp38-linux_x86_64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.2%2Bpt20cu118-cp38-cp38-linux_x86_64.whl (10.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m49.0 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_sparse-0.6.18%2Bpt20cu118-cp38-cp38-linux_x86_64.whl (4.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00

In [42]:
import copy
import random
import os.path as osp
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp

import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import degree
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx

from networkx.algorithms import community

# Load data

In [55]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root directory handed to the Reddit dataset class (./Reddit).
path = osp.join('.', 'Reddit')
dataset = Reddit(root=path)

In [5]:
# Display the device selected above (last expression of the cell is shown).
device

device(type='cuda')

# Exploratory data analysis

In [56]:
# Take the element at index 0 of the dataset as the working `data` object
# used by the inspection cells below.
data = dataset[0]

In [20]:
# Report the dataset-level dimensions, one per line, with a single print call.
print(
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

Number of features: 602
Number of classes: 41


In [15]:
# Report size and structural properties of the graph, one per line.
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

Number of nodes: 232965
Number of edges: 114615892
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [37]:
# View the edge list as a NumPy array (two rows, one column per edge).
edge_index = data.edge_index.numpy()
print(edge_index.shape)

# Keep only the columns whose first-row entry is node 45, i.e. every edge
# listed for that node, and display them.
edge_example = edge_index[:, edge_index[0] == 45]
edge_example

(2, 114615892)


array([[    45,     45,     45, ...,     45,     45,     45],
       [   160,    175,    258, ..., 232623, 232625, 232648]])

# Data loaders

In [65]:
# Move only the node features ('x') and labels ('y') to the target device up
# front, for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

# Options shared by both loaders.
kwargs = dict(batch_size=1024, num_workers=6, persistent_workers=True)

# Training loader: seeded from the training mask, shuffled, with neighbour
# fan-outs of [25, 10].
train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[25, 10],
    shuffle=True,
    **kwargs,
)

# Evaluation loader: works on a shallow copy of `data`, no explicit input
# nodes, unshuffled, with a single hop whose fan-out is -1.
subgraph_loader = NeighborLoader(
    copy.copy(data),
    input_nodes=None,
    num_neighbors=[-1],
    shuffle=False,
    **kwargs,
)

# No need to maintain these features during evaluation:
del subgraph_loader.data.x
del subgraph_loader.data.y
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)

In [78]:
# Collect the first five mini-batches (indices 0..4) from the training loader
# and print a summary of each one.
subgraphs_4 = []
for i, subgraph in zip(range(5), train_loader):
    print(f'Subgraph {i}: {subgraph}')
    subgraphs_4.append(subgraph)

Subgraph 0: Data(x=[105702, 602], edge_index=[2, 233317], y=[105702], train_mask=[105702], val_mask=[105702], test_mask=[105702], n_id=[105702], e_id=[233317], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[1024], batch_size=1024)
Subgraph 1: Data(x=[106813, 602], edge_index=[2, 236520], y=[106813], train_mask=[106813], val_mask=[106813], test_mask=[106813], n_id=[106813], e_id=[236520], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[1024], batch_size=1024)
Subgraph 2: Data(x=[105892, 602], edge_index=[2, 232976], y=[105892], train_mask=[105892], val_mask=[105892], test_mask=[105892], n_id=[105892], e_id=[232976], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[1024], batch_size=1024)
Subgraph 3: Data(x=[106564, 602], edge_index=[2, 236272], y=[106564], train_mask=[106564], val_mask=[106564], test_mask=[106564], n_id=[106564], e_id=[236272], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[1024], batch_size=1024)
Subgraph 4: Data(x=[105411, 602], edge_i

In [None]:
# Disabled: converting a sampled subgraph to NetworkX and drawing it takes too long.
# G = to_networkx(subgraphs_4[0], to_undirected=True)
# nx.draw_networkx(G, with_labels=False,
#                 node_size=200,
#                 node_color=subgraphs_4[0].y.cpu(),
#                 cmap="cool",
#                 font_size=10)

# GraphSAGE

In [7]:
class SAGE(torch.nn.Module):
    """Two-layer GraphSAGE network built from ``SAGEConv`` layers.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second (final) convolution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        """Apply every convolution to ``x`` over ``edge_index``.

        ReLU and dropout (p=0.5, active only in training mode) follow every
        layer except the last, so the returned tensor holds the raw output
        of the final convolution.
        """
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        """Compute outputs for all nodes, one layer at a time.

        Args:
            x_all: Input features, indexed by the ``n_id`` of each batch.
            subgraph_loader: Loader yielding batches that expose ``n_id``,
                ``edge_index`` and ``batch_size``.

        Returns:
            A CPU tensor with the final-layer output of every batch,
            concatenated along dim 0 in loader order.
        """
        # Run on whatever device the model's weights live on instead of
        # relying on a notebook-level ``device`` variable being defined and
        # in sync with the model.
        model_device = next(self.parameters()).device
        last = len(self.convs) - 1

        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        try:
            # Compute representations of nodes layer by layer, using *all*
            # available edges. This leads to faster computation in contrast
            # to immediately computing the final representations of each
            # batch:
            for i, conv in enumerate(self.convs):
                xs = []
                for batch in subgraph_loader:
                    # Index on the device that currently holds x_all, then
                    # move just this batch's rows to the model's device.
                    x = x_all[batch.n_id.to(x_all.device)].to(model_device)
                    x = conv(x, batch.edge_index.to(model_device))
                    if i < last:
                        x = x.relu_()
                    # Keep only the first batch_size rows; store them on CPU.
                    xs.append(x[:batch.batch_size].cpu())
                    pbar.update(batch.batch_size)
                # The outputs of this layer are the inputs of the next one.
                x_all = torch.cat(xs, dim=0)
        finally:
            # Always release the progress bar, even if a batch fails.
            pbar.close()
        return x_all

In [8]:
# Build the network (input width = number of node features, hidden width 256,
# output width = number of classes) and place it on the selected device.
model = SAGE(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes,
)
model = model.to(device)

# Adam over all model parameters with a learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


# Main methods

In [9]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Only the first ``batch.batch_size`` rows of each batch contribute to the
    loss and to the running statistics.

    Args:
        epoch: Epoch number, used only to label the progress bar.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over all examples seen
        during this pass.
    """
    model.train()

    pbar = tqdm(total=int(len(train_loader.dataset)))
    pbar.set_description(f'Epoch {epoch:02d}')

    loss_sum = 0
    num_correct = 0
    num_seen = 0
    for batch in train_loader:
        n_seed = batch.batch_size

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index.to(device))[:n_seed]
        targets = batch.y[:n_seed]
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the final value is a
        # per-example mean.
        loss_sum += float(loss) * n_seed
        num_correct += int((logits.argmax(dim=-1) == targets).sum())
        num_seen += n_seed
        pbar.update(n_seed)
    pbar.close()

    return loss_sum / num_seen, num_correct / num_seen


@torch.no_grad()
def test():
    """Evaluate the model on the train, validation and test masks.

    Predictions come from ``model.inference`` over ``subgraph_loader``; the
    predicted class is the argmax over the last dimension.

    Returns:
        List ``[train_acc, val_acc, test_acc]`` of accuracies.
    """
    model.eval()
    scores = model.inference(data.x, subgraph_loader)
    y_hat = scores.argmax(dim=-1)
    # Compare on whichever device the predictions ended up on.
    y = data.y.to(y_hat.device)

    def _accuracy(mask):
        # Fraction of masked nodes whose prediction matches the label.
        return int((y_hat[mask] == y[mask]).sum()) / int(mask.sum())

    return [_accuracy(mask)
            for mask in (data.train_mask, data.val_mask, data.test_mask)]

# Train and evaluate

In [10]:
# Train for 10 epochs, evaluating on all splits after every epoch, and record
# how long each epoch (training + evaluation) takes.
times = []
for epoch in range(1, 11):
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring durations; unlike time.time() it is unaffected by system
    # clock adjustments.
    start = time.perf_counter()
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.perf_counter() - start)
# np.median averages the two middle values when the number of epochs is even
# (as it is here); torch.median would return only the lower of the two.
print(f"Median time per epoch: {np.median(times):.4f}s")

Epoch 01: 100%|██████████| 153431/153431 [00:06<00:00, 22376.98it/s]


Epoch 01, Loss: 0.5071, Approx. Train: 0.8945


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9583.01it/s] 


Epoch: 01, Train: 0.9532, Val: 0.9522, Test: 0.9501


Epoch 02: 100%|██████████| 153431/153431 [00:05<00:00, 27000.50it/s]


Epoch 02, Loss: 0.5186, Approx. Train: 0.9225


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9613.82it/s] 


Epoch: 02, Train: 0.9478, Val: 0.9450, Test: 0.9450


Epoch 03: 100%|██████████| 153431/153431 [00:05<00:00, 27148.53it/s]


Epoch 03, Loss: 0.5010, Approx. Train: 0.9241


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9618.08it/s] 


Epoch: 03, Train: 0.9567, Val: 0.9516, Test: 0.9504


Epoch 04: 100%|██████████| 153431/153431 [00:05<00:00, 27012.79it/s]


Epoch 04, Loss: 0.4907, Approx. Train: 0.9276


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9631.18it/s] 


Epoch: 04, Train: 0.9589, Val: 0.9530, Test: 0.9514


Epoch 05: 100%|██████████| 153431/153431 [00:05<00:00, 26323.75it/s]


Epoch 05, Loss: 0.5361, Approx. Train: 0.9285


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9705.44it/s] 


Epoch: 05, Train: 0.9603, Val: 0.9503, Test: 0.9516


Epoch 06: 100%|██████████| 153431/153431 [00:05<00:00, 25751.12it/s]


Epoch 06, Loss: 0.5411, Approx. Train: 0.9289


Evaluating: 100%|██████████| 465930/465930 [00:47<00:00, 9825.06it/s] 


Epoch: 06, Train: 0.9606, Val: 0.9512, Test: 0.9513


Epoch 07: 100%|██████████| 153431/153431 [00:05<00:00, 26536.47it/s]


Epoch 07, Loss: 0.5213, Approx. Train: 0.9298


Evaluating: 100%|██████████| 465930/465930 [00:47<00:00, 9808.63it/s] 


Epoch: 07, Train: 0.9640, Val: 0.9509, Test: 0.9522


Epoch 08: 100%|██████████| 153431/153431 [00:05<00:00, 26746.90it/s]


Epoch 08, Loss: 0.5036, Approx. Train: 0.9322


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9648.99it/s] 


Epoch: 08, Train: 0.9636, Val: 0.9519, Test: 0.9521


Epoch 09: 100%|██████████| 153431/153431 [00:05<00:00, 26738.01it/s]


Epoch 09, Loss: 0.5297, Approx. Train: 0.9329


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9581.31it/s] 


Epoch: 09, Train: 0.9635, Val: 0.9507, Test: 0.9506


Epoch 10: 100%|██████████| 153431/153431 [00:05<00:00, 27099.67it/s]


Epoch 10, Loss: 0.5687, Approx. Train: 0.9320


Evaluating: 100%|██████████| 465930/465930 [00:48<00:00, 9684.86it/s] 


Epoch: 10, Train: 0.9626, Val: 0.9515, Test: 0.9505
Median time per epoch: 54.0749s
