In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.graph_objs as go
import plotly.offline as pyo

class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Pipeline: ``GCNConv -> ReLU -> GCNConv``. The second convolution has no
    activation, so ``forward`` returns raw per-node class scores (logits);
    in this notebook they are consumed by ``nn.CrossEntropyLoss``.

    Args:
        in_channels: Length of each input node-feature vector.
        hidden_channels: Width of the intermediate node representation.
        out_channels: Number of scores produced per node (one per class).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node logits for node features ``x`` over graph ``edge_index``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        logits = self.conv2(hidden, edge_index)
        return logits

# Seed torch's RNG *before* the model is built so weight initialisation, and
# therefore the loss/accuracy log printed by the training loop, is
# reproducible after "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Model / optimiser hyperparameters (same values as the original run).
HIDDEN_CHANNELS = 16
LEARNING_RATE = 0.01

coraData = Planetoid(root='/tmp/Cora', name='Cora')
# Planetoid/Cora is a single-graph dataset: index 0 is the whole graph
# (features, labels, edges and the masks used by train()/test()).
data = coraData[0]

model = GCN(
    in_channels=coraData.num_node_features,
    hidden_channels=HIDDEN_CHANNELS,
    out_channels=coraData.num_classes,
)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# GCN.forward returns raw logits (no softmax), which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()

def train():
    """Perform one full-batch optimisation step on the training nodes.

    Reads the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``data``. The whole graph is passed through the model, but the loss is
    restricted to the nodes selected by ``data.train_mask``.

    Returns:
        float: The training loss of this step (the value computed in the
        forward pass, i.e. before the parameter update is applied).
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x, data.edge_index)
    train_nodes = data.train_mask
    step_loss = criterion(logits[train_nodes], data.y[train_nodes])

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test():
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1) 
        test_accuracy = (pred[data.test_mask] == data.y[data.test_mask]).sum() / data.test_mask.sum()
        return test_accuracy.item()

NUM_EPOCHS = 300
LOG_EVERY = 10

# NOTE(review): accuracy is monitored on the *test* split purely for
# reporting. If these numbers are ever used to pick an epoch or tune
# hyperparameters, evaluate on a validation mask instead to avoid test leakage.
for epoch in range(NUM_EPOCHS):
    loss = train()
    # Log every LOG_EVERY epochs, and always on the final epoch. Otherwise
    # the last epoch (NUM_EPOCHS - 1 = 299) is skipped by the modulo check,
    # and the trained model's accuracy is never reported.
    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        test_acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

print("Training completed.")

Epoch: 0, Loss: 1.9512, Test Accuracy: 0.4300
Epoch: 10, Loss: 0.7663, Test Accuracy: 0.7200
Epoch: 20, Loss: 0.1661, Test Accuracy: 0.7820
Epoch: 30, Loss: 0.0330, Test Accuracy: 0.7860
Epoch: 40, Loss: 0.0112, Test Accuracy: 0.7890
Epoch: 50, Loss: 0.0058, Test Accuracy: 0.7890
Epoch: 60, Loss: 0.0040, Test Accuracy: 0.7830
Epoch: 70, Loss: 0.0031, Test Accuracy: 0.7820
Epoch: 80, Loss: 0.0026, Test Accuracy: 0.7820
Epoch: 90, Loss: 0.0023, Test Accuracy: 0.7800
Epoch: 100, Loss: 0.0020, Test Accuracy: 0.7780
Epoch: 110, Loss: 0.0018, Test Accuracy: 0.7780
Epoch: 120, Loss: 0.0016, Test Accuracy: 0.7790
Epoch: 130, Loss: 0.0015, Test Accuracy: 0.7790
Epoch: 140, Loss: 0.0013, Test Accuracy: 0.7780
Epoch: 150, Loss: 0.0012, Test Accuracy: 0.7780
Epoch: 160, Loss: 0.0011, Test Accuracy: 0.7770
Epoch: 170, Loss: 0.0010, Test Accuracy: 0.7760
Epoch: 180, Loss: 0.0010, Test Accuracy: 0.7770
Epoch: 190, Loss: 0.0009, Test Accuracy: 0.7770
Epoch: 200, Loss: 0.0008, Test Accuracy: 0.7770
Epo

In [4]:
def _build_nx_graph(edge_index, num_nodes):
    """Build an undirected networkx graph that contains every node id in ``range(num_nodes)``."""
    G = nx.Graph()
    # Add all nodes explicitly: a node with no edges would otherwise be missing
    # from the graph and from the layout, making ``pos[i]`` raise KeyError.
    G.add_nodes_from(range(num_nodes))
    # edge_index has shape (2, E); its transpose is a list of (src, dst) pairs.
    # nx.Graph is undirected, so duplicated or reversed pairs collapse into one edge.
    G.add_edges_from(edge_index.T.tolist())
    return G


def _edge_trace(G, pos):
    """Return a line trace drawing every edge of ``G`` at the positions in ``pos``."""
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # A ``None`` breaks the polyline so each edge is drawn as its own segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    return go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )


def _node_trace(pos, labels):
    """Return a marker trace for nodes ``0..len(labels)-1``, coloured by class label."""
    node_ids = range(len(labels))
    return go.Scatter(
        x=[pos[i][0] for i in node_ids],
        y=[pos[i][1] for i in node_ids],
        # Markers only: 'markers+text' would draw a permanent label on every
        # node, which is unreadable at this size; labels still show on hover.
        mode='markers',
        hoverinfo='text',
        text=[f'Node {i}: {label}' for i, label in enumerate(labels)],
        marker=dict(
            showscale=True,
            colorscale='Blues',
            # Numeric labels (not pre-computed rgba strings) so that the
            # colorscale and colour bar actually take effect.
            color=labels,
            colorbar=dict(title=dict(text='Class')),
            size=10,
            line=dict(width=2),
        ),
    )


def plot_interactive_graph(layout_seed=42):
    """Render an interactive Plotly view of the Cora graph, nodes coloured by class.

    Reads the notebook-level ``data`` graph object.

    Args:
        layout_seed: Seed passed to ``nx.spring_layout`` so the drawing is
            reproducible across runs; pass ``None`` for a random layout.
    """
    G = _build_nx_graph(data.edge_index.cpu().numpy(), data.num_nodes)
    # Read labels once from the already-loaded graph instead of re-indexing the
    # dataset (coraData[0]) for every node.
    labels = data.y.cpu().tolist()

    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=layout_seed)

    fig = go.Figure(
        data=[_edge_trace(G, pos), _node_trace(pos, labels)],
        layout=go.Layout(
            # title.font replaces the deprecated ``titlefont`` property
            # (removed in Plotly 6).
            title=dict(
                text='Interactive Graph Visualization of the Cora Dataset',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    pyo.iplot(fig)

plot_interactive_graph()
