## 超大图的学习
在超大图上训练图神经网络仍然具有挑战性。普通的基于SGD的图神经网络训练方法，要么面临随着网络层数增加、计算成本呈指数增长的问题，要么需要把整个图的信息以及每一层每个节点的表征保存到内存（显存）中，从而消耗巨大的内存（显存）空间。虽然已有一些论文提出了无需把整个图的信息和每一层每个节点的表征保存到GPU显存的方法，但这些方法可能会损失预测精度，或者对内存利用率的提升并不明显。于是论文[Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953)提出了一种新的图神经网络训练方法：先用图聚类算法把节点划分成若干个簇，每个训练步随机选取若干个簇合成一个子图，只在该子图上做前向和反向传播，从而把邻域扩展限制在子图内部，降低计算和显存开销。

In [1]:
from torch_geometric.datasets import Reddit
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv
import torch
from tqdm import tqdm

### 载入数据

In [2]:
# Load the Reddit dataset; the whole graph is its single element, dataset[0].
dataset = Reddit('../datasets/Reddit')
data = dataset[0]

# Basic statistics, one per line: classes, nodes, edges, node features.
for stat in (dataset.num_classes, data.num_nodes, data.num_edges, data.num_features):
    print(stat)

# Cluster-GCN training setup: split the graph into 1500 parts (cached under the
# dataset's processed dir) and draw shuffled mini-batches of 16 parts each.
clustered_data = ClusterData(
    data,
    num_parts=1500,
    recursive=False,
    save_dir=dataset.processed_dir,
)
train_loader = ClusterLoader(clustered_data, batch_size=16, shuffle=True)

# Evaluation loader over the full graph: one hop, all neighbours (sizes=[-1]),
# unshuffled so per-batch outputs can be concatenated in node order.
subgraph_loader = NeighborSampler(
    data.edge_index,
    sizes=[-1],
    batch_size=1024,
    shuffle=False,
)

41
232965
114615892
602


In [3]:
# Inspect one Cluster-GCN mini-batch. It is bound to `batch` rather than `data`
# so that `data` keeps holding the full graph; the full-graph evaluation
# (`test(model, data, device)`) relies on that. `iter()` is the public way to
# get an iterator from the loader (instead of the private `_get_iterator()`).
batch = next(iter(train_loader))
print(batch.num_nodes)
print(batch.num_edges)
print(batch.train_mask.sum())
print(batch.y.unique())

2482
139364
tensor(1528)
tensor([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 15, 16, 18, 19, 21, 22,
        23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40])


### 定义网络

In [4]:
class Net(torch.nn.Module):
    """Two-layer GraphSAGE network for node classification.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Number of output classes.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        hidden_channels: Width of the hidden layer (keyword-only, default 128).
    """

    def __init__(self, in_channels, out_channels, *args, hidden_channels=128, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(in_channels, hidden_channels),
             SAGEConv(hidden_channels, out_channels)]
        )

    def forward(self, x, edge_index):
        """Return per-node log-probabilities for the (sub)graph ``edge_index``.

        Every layer except the last is followed by ReLU and dropout (p=0.5,
        active only in training mode); the last layer feeds ``log_softmax``.
        """
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != last:
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        return torch.nn.functional.log_softmax(x, dim=-1)

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise inference over all nodes of the graph.

        Each layer is applied to every node, batch by batch, before the next
        layer runs; intermediate results are collected on the CPU.

        Args:
            x_all: Features of every node in the graph.
            subgraph_loader: Iterable of ``(batch_size, n_id, adj)`` where
                ``adj`` unpacks to ``(edge_index, _, size)``. Assumed to be an
                unshuffled one-hop sampler covering all nodes in order, since
                the per-batch outputs are concatenated as the new ``x_all``.
            device: Device on which each batch is computed.

        Returns:
            CPU tensor with the last layer's raw outputs for every node (no
            ``log_softmax`` applied; argmax is unaffected).
        """
        # One progress step per target node per layer.
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluation')

        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                # The first size[1] entries of n_id are the batch's target nodes.
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                if i != len(self.convs) - 1:
                    x = torch.nn.functional.relu(x)
                xs.append(x.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all

### 训练与评估函数

In [5]:
def train(model, train_loader, optimizer, device=None):
    """Run one training epoch over the Cluster-GCN mini-batches.

    Args:
        model: Module called as ``model(x, edge_index)`` that returns per-node
            log-probabilities.
        train_loader: Iterable of mini-batch graphs exposing ``x``,
            ``edge_index``, ``y``, ``train_mask`` and ``to(device)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        Mean NLL loss per training node over the epoch, or ``nan`` if no batch
        contained a training node.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    total_loss = 0.0
    total_nodes = 0
    for batch in train_loader:
        batch = batch.to(device)
        nodes = int(batch.train_mask.sum())
        if nodes == 0:
            # The mean loss over zero nodes is NaN, and NaN * 0 is still NaN,
            # which would poison the epoch average; such batches carry no
            # supervision, so skip them.
            continue

        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = torch.nn.functional.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        # Weight by node count so the epoch loss is a per-node average.
        total_loss += loss.item() * nodes
        total_nodes += nodes

    if total_nodes == 0:
        return float('nan')
    return total_loss / total_nodes

def test(model, data, device):
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        out = model.inference(data.x, subgraph_loader, device)
        y_pred = out.argmax(dim=-1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = y_pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs

### 模型训练

In [6]:
# Build the model and optimizer on the first GPU.
device = torch.device('cuda:0')
model = Net(dataset.num_features, dataset.num_classes)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005)

num_epochs = 30
for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer)
    # Per-epoch full-graph evaluation via test(model, data, device) is not run
    # here; it performs layer-wise inference over every node of the graph.
    print(f'Epoch: {epoch + 1:03d}, Loss: {loss:.4f}')

Epoch: 001, Loss: 1.0917
Epoch: 002, Loss: 0.4706
Epoch: 003, Loss: 0.4019
Epoch: 004, Loss: 0.3824
Epoch: 005, Loss: 0.3514
Epoch: 006, Loss: 0.3272
Epoch: 007, Loss: 0.3233
Epoch: 008, Loss: 0.3291
Epoch: 009, Loss: 0.3125
Epoch: 010, Loss: 0.2991
Epoch: 011, Loss: 0.2925
Epoch: 012, Loss: 0.2908
Epoch: 013, Loss: 0.2996
Epoch: 014, Loss: 0.3067
Epoch: 015, Loss: 0.2975
Epoch: 016, Loss: 0.2847
Epoch: 017, Loss: 0.2721
Epoch: 018, Loss: 0.2719
Epoch: 019, Loss: 0.2832
Epoch: 020, Loss: 0.2860
Epoch: 021, Loss: 0.2684
Epoch: 022, Loss: 0.2728
Epoch: 023, Loss: 0.2783
Epoch: 024, Loss: 0.2618
Epoch: 025, Loss: 0.2610
Epoch: 026, Loss: 0.2980
Epoch: 027, Loss: 0.2545
Epoch: 028, Loss: 0.2459
Epoch: 029, Loss: 0.2423
Epoch: 030, Loss: 0.2435
