In [1]:
#import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
import networkx as nx

In [4]:
# Define the GCN model from torch nn
class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

In [5]:
# Load the Cora dataset from torch datasets 
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Get the first graph object

In [6]:
# Initialize model, optimizer, and loss function
model = GCN(in_channels=dataset.num_node_features, hidden_channels=16, out_channels=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()



In [7]:
# Training loop
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients
    out = model(data.x, data.edge_index)  # Forward pass
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights
    return loss.item()

In [8]:
# Evaluation function
def test():
    model.eval()
    with torch.no_grad():  # No gradients needed for evaluation
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)  # Get the predicted classes
        test_accuracy = (pred[data.test_mask] == data.y[data.test_mask]).sum() / data.test_mask.sum()
        return test_accuracy.item()

In [9]:
# Run the training process
for epoch in range(200):  # Number of epochs
    loss = train()
    if epoch % 10 == 0:
        test_acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

print("Training completed.")

Epoch: 0, Loss: 1.9454, Test Accuracy: 0.4850
Epoch: 10, Loss: 0.6218, Test Accuracy: 0.7480
Epoch: 20, Loss: 0.1132, Test Accuracy: 0.7720
Epoch: 30, Loss: 0.0253, Test Accuracy: 0.7660
Epoch: 40, Loss: 0.0094, Test Accuracy: 0.7610
Epoch: 50, Loss: 0.0051, Test Accuracy: 0.7610
Epoch: 60, Loss: 0.0036, Test Accuracy: 0.7600
Epoch: 70, Loss: 0.0029, Test Accuracy: 0.7600
Epoch: 80, Loss: 0.0024, Test Accuracy: 0.7610
Epoch: 90, Loss: 0.0021, Test Accuracy: 0.7600
Epoch: 100, Loss: 0.0018, Test Accuracy: 0.7600
Epoch: 110, Loss: 0.0017, Test Accuracy: 0.7610
Epoch: 120, Loss: 0.0015, Test Accuracy: 0.7630
Epoch: 130, Loss: 0.0013, Test Accuracy: 0.7650
Epoch: 140, Loss: 0.0012, Test Accuracy: 0.7650
Epoch: 150, Loss: 0.0011, Test Accuracy: 0.7650
Epoch: 160, Loss: 0.0010, Test Accuracy: 0.7660
Epoch: 170, Loss: 0.0009, Test Accuracy: 0.7660
Epoch: 180, Loss: 0.0009, Test Accuracy: 0.7660
Epoch: 190, Loss: 0.0008, Test Accuracy: 0.7660
Training completed.


In [10]:
import numpy as np
import plotly.graph_objs as go
import plotly.offline as pyo
import networkx as nx

def plot_simple_graph():
    # Convert PyG edge_index -> networkx graph
    edge_index = data.edge_index.cpu().numpy()
    G = nx.Graph()
    for i in range(edge_index.shape[1]):
        G.add_edge(edge_index[0, i], edge_index[1, i])

    # Use only first 200 nodes to keep it light
    nodes_to_plot = list(G.nodes)[:200]
    H = G.subgraph(nodes_to_plot).copy()

    # Fast layout
    pos = nx.kamada_kawai_layout(H)

    # Edge coordinates
    edge_x, edge_y = [], []
    for edge in H.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines"
    )

    # Node coordinates
    node_x, node_y = [], []
    for node in H.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode="markers",
        hoverinfo="text",
        marker=dict(size=8, color="blue"),
        text=[f"Node {n}" for n in H.nodes()]
    )

    # Plot
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title="Simple Graph Visualization (subgraph of 200 nodes)",
                        showlegend=False,
                        hovermode="closest",
                        margin=dict(b=0,l=0,r=0,t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
                    ))

    pyo.iplot(fig)  # use pyo.plot(fig) if outside Jupyter

# Call it
plot_simple_graph()
