In [2]:
import torch
import torch.nn as nn
import torch.optim as optim 
import numpy as np
import random
from tqdm import tqdm
import torch_scatter
import torch_geometric
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.utils import to_dense_adj
from torch_geometric.datasets import Planetoid

# Prepare Data

In [3]:
dataset = Planetoid('./data', 'Cora')
data = dataset[0]

# GCN Layer and GCN Model

$$
\begin{aligned}
h_{v}^k = \sigma\left(W_k\sum_{u \in \mathcal{N}(v)\cup \{v\}} \dfrac{h_{u}^{k-1}}{\sqrt{|N(u)||N(v)|}} \right) = \sigma\left(W_k\sum_{u } h_{u}^{k-1}\hat{A}_{uv} \right)
\end{aligned}
$$
其中$\hat{A}_{uv}$是归一化的邻接矩阵，即$\hat{A} = D^{-1/2}AD^{-1/2}$，$D$是度数对角矩阵，$A$是增加了对角元后的邻接矩阵

In [4]:
class GCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNLayer, self).__init__()
        self.W = nn.Linear(in_channels, out_channels)
        self.act = nn.Sigmoid()
    
    def forward(self, A, X):
        '''
        A: adjacency matrix, shape: (N, N)
        X: feature matrix, shape: (N, in_channels)
        
        return: shape: (N, out_channels)
        '''
        A = A + torch.eye(A.shape[0])
        degree = torch.sum(A, dim=1)
        D = torch.diag(1 / torch.sqrt(degree))
        A_hat = D @ A @ D
        X_hat = A_hat @ X
        return self.act(self.W(X_hat))

In [5]:
class SingleLayerGCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SingleLayerGCN, self).__init__()
        self.gcn = GCNLayer(in_channels, out_channels)
    
    def forward(self, A, X):
        return self.gcn(A, X)

class MultiLayerGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(MultiLayerGCN, self).__init__()
        self.gcn1 = GCNLayer(in_channels, hidden_channels)
        self.gcn2 = GCNLayer(hidden_channels, out_channels)
    
    def forward(self, A, X):
        X = self.gcn1(A, X)
        X = self.gcn2(A, X)
        return X

# Re-split train test data by 9:1

In [6]:
N_train = data.x.shape[0] * 9 // 10
all_samples = list(range(data.x.shape[0]))
random.shuffle(all_samples)
train_samples = all_samples[:N_train]
test_samples = all_samples[N_train:]
len(train_samples), len(test_samples)

(2437, 271)

# Training

In [7]:
n_classes = data.y.max() + 1
n_features = data.x.shape[1]
model_single = SingleLayerGCN(n_features, n_classes)
model_multiple = MultiLayerGCN(n_features, 64, n_classes)

In [8]:
def train(A, features, labels, train_idx, model, optimizer, epochs=100):
    
    for epoch in tqdm(range(epochs)):
        model.train()
        optimizer.zero_grad()
        logits = model(A, features)
        loss_train = torch.nn.CrossEntropyLoss()(logits[train_idx], labels[train_idx])
        loss_train.backward()
        optimizer.step()

def test(A, features, labels, test_idx, model):
    model.eval()
    with torch.no_grad():
        logits = model(A, features)
        pred = logits.argmax(dim=1)
        acc_test = int((pred[test_idx] == labels[test_idx]).sum()) / len(test_idx)
        return acc_test

In [9]:
model_single = SingleLayerGCN(n_features, n_classes)
model_multiple = MultiLayerGCN(n_features, 128, n_classes)
optim_single = torch.optim.Adam(model_single.parameters(), lr=0.01)
optim_multiple = torch.optim.Adam(model_multiple.parameters(), lr=0.01)

In [10]:
A = to_dense_adj(data.edge_index)[0]
features = data.x

In [11]:
train(A, features, data.y, train_samples, model_single, optim_single)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [01:16<00:00,  1.31it/s]


In [12]:
test(A, features, data.y, test_samples, model_single)

0.8634686346863468

In [13]:
train(A, features, data.y, train_samples, model_multiple, optim_multiple)

100%|██████████| 100/100 [02:24<00:00,  1.45s/it]


In [14]:
test(A, features, data.y, train_samples, model_multiple)

0.9396799343455068

# Bonus: GCN with scatter

In [15]:
class GCNLayer_scatter(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNLayer_scatter, self).__init__()
        self.W = nn.Linear(in_channels, out_channels)
        self.act = nn.Sigmoid()
    
    def forward(self, edge_index, X):
        '''
        edge_index: [2, E]
        x: input features, shape [N, in_channels],

        return: shape [N, out_channels]
        '''
        edge_index, _ = add_self_loops(edge_index, num_nodes=X.size(0))
        row, col = edge_index

        deg = degree(col, X.size(0), dtype=X.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]    # (E,)

        X = self.W(X)                                   # (N, out_channels)
        X = X[edge_index[0]] * norm.unsqueeze(1)        # (E, out_channels)
        
        target = edge_index[1]                          # (E,)

        out = torch_scatter.scatter(X, target, dim=0, reduce='sum')   # (N, out_channels)
        return self.act(out)


class SingleLayerGCN_scatter(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SingleLayerGCN_scatter, self).__init__()
        self.gcn = GCNLayer_scatter(in_channels, out_channels)
    
    def forward(self, edge_index, X):
        return self.gcn(edge_index, X)

class MultiLayerGCN_scatter(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(MultiLayerGCN_scatter, self).__init__()
        self.gcn1 = GCNLayer_scatter(in_channels, hidden_channels)
        self.gcn2 = GCNLayer_scatter(hidden_channels, out_channels)
    
    def forward(self, edge_index, X):
        X = self.gcn1(edge_index, X)
        X = self.gcn2(edge_index, X)
        return X        

# Comparison of results

In [16]:
model_scatter_single = SingleLayerGCN_scatter(in_channels=n_features, out_channels=n_classes)

In [17]:
model_scatter_single.eval()
model_single.eval()
with torch.no_grad():
    model_scatter_single.gcn.W.weight = model_single.gcn.W.weight
    print(test(data.edge_index, data.x, data.y, test_samples, model_scatter_single))
    print(test(A, data.x, data.y, test_samples, model_single))

0.8634686346863468
0.8634686346863468


# Comparison of time and memory

In [17]:
import time
def comp_scatter(num_nodes, edge_prob, n_features=128, out_channels=20):
    model_scatter = SingleLayerGCN_scatter(n_features, out_channels)
    X = torch.randn(num_nodes, n_features)
    edge_index = torch_geometric.utils.erdos_renyi_graph(num_nodes, edge_prob)
    start_time = time.time()
    logits = model_scatter(edge_index, X)
    print("Time taken: ", time.time() - start_time)
def comp(num_nodes, edge_prob, n_features=128, out_channels=20):
    model = SingleLayerGCN(n_features, out_channels)
    X = torch.randn(num_nodes, n_features)
    edge_index = torch_geometric.utils.erdos_renyi_graph(num_nodes, edge_prob)
    A = torch_geometric.utils.to_dense_adj(edge_index)[0]
    start_time = time.time()
    logits = model(A, X)
    print("Time taken: ", time.time() - start_time)


In [18]:
%load_ext memory_profiler

节点太多会爆内存且用时太久，因此选取的节点和边数比作业要求少

## 使用scatter的实验

In [19]:
%%memit 
comp_scatter(20000, 0.1)

Time taken:  5.369081735610962
peak memory: 10287.58 MiB, increment: 10025.90 MiB


In [25]:
%%memit 
comp_scatter(10000, 0.1)

Time taken:  1.2652032375335693
peak memory: 2819.27 MiB, increment: 2639.00 MiB


In [27]:
%%memit 
comp_scatter(10000, 0.3)

Time taken:  3.7466859817504883
peak memory: 8087.11 MiB, increment: 7906.42 MiB


In [28]:
%%memit 
comp_scatter(10000, 0.7)

Time taken:  47.16345977783203
peak memory: 9739.15 MiB, increment: 9558.45 MiB


In [29]:
%%memit 
comp_scatter(5000, 0.1)

Time taken:  0.2821488380432129
peak memory: 706.67 MiB, increment: 593.97 MiB


In [30]:
%%memit 
comp_scatter(5000, 0.3)

Time taken:  0.8651759624481201
peak memory: 1959.32 MiB, increment: 1839.07 MiB


In [31]:
%%memit 
comp_scatter(5000, 0.7)

Time taken:  2.0101325511932373
peak memory: 4729.50 MiB, increment: 4609.20 MiB


## 使用邻接矩阵的实验

In [20]:
%%memit 
comp(20000, 0.1)

Time taken:  205.634783744812
peak memory: 9642.47 MiB, increment: 9531.35 MiB


In [34]:
%%memit 
comp(10000, 0.1)

Time taken:  29.314781665802002
peak memory: 2469.82 MiB, increment: 2331.17 MiB


In [35]:
%%memit 
comp(10000, 0.3)

Time taken:  29.426659107208252
peak memory: 2511.20 MiB, increment: 2370.11 MiB


In [36]:
%%memit 
comp(10000, 0.7)

Time taken:  27.021151542663574
peak memory: 4998.66 MiB, increment: 4857.57 MiB


In [37]:
%%memit 
comp(5000, 0.1)

Time taken:  3.8847031593322754
peak memory: 658.40 MiB, increment: 517.30 MiB


In [38]:
%%memit 
comp(5000, 0.3)

Time taken:  3.6381733417510986
peak memory: 734.85 MiB, increment: 593.72 MiB


In [39]:
%%memit 
comp(5000, 0.7)

Time taken:  3.6242895126342773
peak memory: 1250.27 MiB, increment: 1109.14 MiB


结果

运行时间上
- 基于邻接矩阵的GCN时间只与节点数有关，大致为平方关系
- 基于scatter的GCN时间与边数有关，大致为线性关系

内存占用上
- 基于scatter的GCN内存与边数有关，大致为线性关系
- 基于邻接矩阵的GCN内存与节点数和边数均有关