In [11]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, Reddit
from torch_geometric.nn import GCNConv
# from torch_geometric.data import ClusterData, ClusterLoader
from torch_geometric.loader import ClusterData, ClusterLoader



In [12]:

# Load a dataset
dataset = Planetoid(name='Cora', root='/tmp/Cora')


In [13]:
dataset = Reddit(root='/tmp/Reddit')


In [3]:
dataset[0]

Data(x=[232965, 602], edge_index=[2, 114615892], y=[232965], train_mask=[232965], val_mask=[232965], test_mask=[232965])

In [14]:
# Define a simple GCN model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# Prepare Cluster-GCN data
cluster_data = ClusterData(dataset[0], num_parts=100, recursive=False)
train_loader = ClusterLoader(cluster_data, batch_size=1, shuffle=True)


Computing METIS partitioning...
Done!


In [16]:
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(dataset.num_node_features, dataset.num_classes).to(device)


In [18]:
# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Training loop
for epoch in range(300):
    loss = train()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}')

Epoch 1, Loss: 0.7370
Epoch 2, Loss: 0.7605
Epoch 3, Loss: 0.7210
Epoch 4, Loss: 0.6236
Epoch 5, Loss: 0.6082
Epoch 6, Loss: 0.6062
Epoch 7, Loss: 0.6452
Epoch 8, Loss: 0.6437
Epoch 9, Loss: 0.6813
Epoch 10, Loss: 0.7765
Epoch 11, Loss: 0.6802
Epoch 12, Loss: 0.6044
Epoch 13, Loss: 0.5936
Epoch 14, Loss: 0.5930
Epoch 15, Loss: 0.5817
Epoch 16, Loss: 0.5784
Epoch 17, Loss: 0.5729
Epoch 18, Loss: 0.5669
Epoch 19, Loss: 0.5746
Epoch 20, Loss: 0.5777
Epoch 21, Loss: 0.5751
Epoch 22, Loss: 0.6403
Epoch 23, Loss: 0.6000
Epoch 24, Loss: 1.1584
Epoch 25, Loss: 1.6815
Epoch 26, Loss: 1.1583
Epoch 27, Loss: 0.9076
Epoch 28, Loss: 1.0434
Epoch 29, Loss: 0.7563
Epoch 30, Loss: 0.6823
Epoch 31, Loss: 0.6425
Epoch 32, Loss: 0.6150
Epoch 33, Loss: 0.6100
Epoch 34, Loss: 0.5933
Epoch 35, Loss: 0.5880
Epoch 36, Loss: 0.5749
Epoch 37, Loss: 0.5724
Epoch 38, Loss: 0.5642
Epoch 39, Loss: 0.5695
Epoch 40, Loss: 0.5674
Epoch 41, Loss: 0.5664
Epoch 42, Loss: 0.5656
Epoch 43, Loss: 0.5601
Epoch 44, Loss: 0.56