# Notebook Details

- Understand how a Graph Neural Network (GNN) works through a graph-level binary classification task on randomly generated graphs

In [None]:
### Install dependencies
pip install torch torch_geometric

Note: you may need to restart the kernel to use updated packages.


In [1]:
import random

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import global_mean_pool, GCNConv

In [11]:
# Function to create random graphs with binary labels
def generate_random_graphs(num_graphs, num_nodes_range, num_edges_range, num_node_features=10):
    """Build random graphs, each tagged with a random binary graph-level label.

    Args:
        num_graphs: Number of graphs to generate.
        num_nodes_range: ``(low, high)`` pair unpacked into ``np.random.randint``;
            the node count of each graph is drawn from ``[low, high)``.
        num_edges_range: ``(low, high)`` pair unpacked into ``np.random.randint``;
            the edge count of each graph is drawn from ``[low, high)``.
        num_node_features: Length of the random feature vector attached to each
            node. Defaults to 10, the value that was previously hard-coded.

    Returns:
        tuple: ``(graphs, labels)`` where ``graphs`` is a list of ``Data``
        objects (``x``, ``edge_index``, ``y``) and ``labels`` is the list of
        the corresponding integer labels (0 or 1).
    """
    graphs = []
    labels = []
    for _ in range(num_graphs):
        num_nodes = np.random.randint(*num_nodes_range)
        num_edges = np.random.randint(*num_edges_range)

        G = nx.gnm_random_graph(num_nodes, num_edges)

        # reshape(-1, 2) keeps a well-formed (2, 0) edge_index when the graph
        # happens to have no edges (a bare .t() on an empty tensor stays 1-D).
        edges = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t()
        # G is undirected but G.edges lists each edge once; add the reversed
        # copy so messages flow in both directions along every edge.
        edge_index = torch.cat([edges, edges.flip(0)], dim=1).contiguous()

        # Random node features, num_node_features values per node
        x = torch.randn((num_nodes, num_node_features))

        # Assign a binary label to the entire graph (independent of its structure)
        y = torch.tensor([np.random.choice([0, 1])], dtype=torch.long)

        # Append to list as Data object
        data = Data(x=x, edge_index=edge_index, y=y)
        graphs.append(data)
        labels.append(y.item())

    return graphs, labels

In [12]:
# Generate random graphs
# Seed every RNG the data generation draws from so that Restart & Run All
# reproduces the same graphs (and, downstream, the same training run).
SEED = 42
random.seed(SEED)        # networkx's gnm_random_graph uses Python's global `random` by default
np.random.seed(SEED)     # node/edge counts and graph labels
torch.manual_seed(SEED)  # node features drawn with torch.randn

num_graphs = 100
num_nodes_range = (20, 50)  # node count per graph drawn from [20, 50)
num_edges_range = (30, 75)  # edge count per graph drawn from [30, 75)
graphs, labels = generate_random_graphs(num_graphs, num_nodes_range, num_edges_range)


In [13]:
# Wrap the generated graphs in a loader that yields shuffled mini-batches of 10 graphs
loader = DataLoader(
    dataset=graphs,
    shuffle=True,
    batch_size=10,
)


In [14]:
# Define a GNN model
class GraphGNN(nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling, then a linear head."""

    def __init__(self, num_node_features, hidden_channels, num_classes):
        """Create the layers.

        Args:
            num_node_features: Size of each input node feature vector.
            hidden_channels: Output size of both graph convolution layers.
            num_classes: Number of scores produced per graph.
        """
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute one row of class scores per graph in the batch.

        Args:
            x: Node feature matrix.
            edge_index: Edge list passed to both graph convolutions.
            batch: Node-to-graph assignment vector used by the pooling step.

        Returns:
            The output of the final linear layer applied to the pooled
            per-graph representations.
        """
        # First convolution followed by a ReLU non-linearity
        hidden = F.relu(self.conv1(x, edge_index))
        # Second convolution (no activation before pooling)
        node_embeddings = self.conv2(hidden, edge_index)

        # Average node embeddings within each graph
        graph_embeddings = global_mean_pool(node_embeddings, batch)

        # Map each graph embedding to class scores
        return self.lin(graph_embeddings)

# Training function
def train():
    """Run one training epoch over the notebook-level ``loader``.

    Relies on the global ``model``, ``optimizer`` and ``criterion`` defined in
    later cells; they must exist before this function is called.

    Returns:
        float: Loss averaged over all predictions made during the epoch, or
        ``0.0`` if the loader yielded no batches. Callers that ignore the
        return value behave exactly as before.
    """
    model.train()
    loss_sum = 0.0
    num_predictions = 0
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a smaller final batch does not
        # skew the epoch average (assumes criterion returns a per-batch mean,
        # which is the default for nn.CrossEntropyLoss).
        batch_size = out.size(0)
        loss_sum += loss.item() * batch_size
        num_predictions += batch_size

    return loss_sum / num_predictions if num_predictions else 0.0

# Testing function
def test(loader):
    """Return the fraction of graphs in ``loader.dataset`` classified correctly.

    Uses the notebook-level ``model`` in evaluation mode with gradient
    tracking disabled.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``, with a ``dataset`` attribute that supports ``len``.

    Returns:
        float: Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for batch in loader:
            logits = model(batch.x, batch.edge_index, batch.batch)
            # Predicted class is the index of the largest score in each row
            predicted = logits.argmax(dim=1)
            num_correct += (predicted == batch.y).sum().item()
    return num_correct / len(loader.dataset)


In [15]:
# Initialize model, optimizer, and loss function
model = GraphGNN(
    num_node_features=10,  # must equal the per-node feature width of the generated graphs
    hidden_channels=32,
    num_classes=2,         # binary graph labels
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [16]:
# Train the model
for epoch in range(1, 201):
    train()
    if epoch % 10 == 0:
        # `loader` holds the very graphs the model is trained on, so this is
        # training accuracy, not a held-out test score. The graph labels are
        # drawn at random, so a high value here reflects memorisation.
        acc = test(loader)
        print(f'Epoch: {epoch:03d}, Train Accuracy: {acc:.4f}')

Epoch: 010, Test Accuracy: 0.8000
Epoch: 020, Test Accuracy: 0.9200
Epoch: 030, Test Accuracy: 0.9700
Epoch: 040, Test Accuracy: 1.0000
Epoch: 050, Test Accuracy: 1.0000
Epoch: 060, Test Accuracy: 0.9900
Epoch: 070, Test Accuracy: 1.0000
Epoch: 080, Test Accuracy: 1.0000
Epoch: 090, Test Accuracy: 1.0000
Epoch: 100, Test Accuracy: 1.0000
Epoch: 110, Test Accuracy: 1.0000
Epoch: 120, Test Accuracy: 1.0000
Epoch: 130, Test Accuracy: 1.0000
Epoch: 140, Test Accuracy: 1.0000
Epoch: 150, Test Accuracy: 1.0000
Epoch: 160, Test Accuracy: 1.0000
Epoch: 170, Test Accuracy: 1.0000
Epoch: 180, Test Accuracy: 1.0000
Epoch: 190, Test Accuracy: 1.0000
Epoch: 200, Test Accuracy: 1.0000
