<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#
# Below is an example implementation of a simple Graph Neural Network (GNN) using PyTorch,
# inspired by the ideas in "A Gentle Introduction to Graph Neural Networks".
#
# https://www.bilibili.com/video/BV1iT4y1d7zP/?spm_id_from=333.337.search-card.all.click&vd_source=83baba81780fd95e96c22e9346057527

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import numpy as np

# Example graph (Adjacency matrix)
def create_graph():
    """Build the NetworkX karate club graph and its dense adjacency tensor.

    Returns:
        A ``(adjacency, graph)`` tuple: the graph's adjacency matrix as a
        ``torch.float`` tensor, followed by the NetworkX graph object it was
        computed from.
    """
    karate_club = nx.karate_club_graph()
    # NetworkX hands back a sparse adjacency; densify it before building the tensor.
    sparse_adjacency = nx.adjacency_matrix(karate_club)
    dense_adjacency = sparse_adjacency.todense()
    adjacency = torch.tensor(dense_adjacency, dtype=torch.float)
    return adjacency, karate_club

# Graph Neural Network
class GCN(nn.Module):
    """Two-layer graph network: aggregate over ``adj``, then apply a linear map.

    Each layer first multiplies the current node representations by the
    adjacency matrix and then feeds the result through an ``nn.Linear``.
    A ReLU sits between the two layers; the second layer's output is
    returned without any activation.

    Note: ``adj`` is used exactly as given. No normalisation or self-loops
    are added here, so any such preprocessing is up to the caller.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of the hidden representation.
        output_dim: Number of output values per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj):
        """Compute per-node outputs.

        Args:
            x: 2-D node feature matrix with ``input_dim`` columns.
            adj: 2-D matrix multiplied onto the node representations before
                each linear layer (presumably the square adjacency matrix
                of the graph; confirm against the caller).

        Returns:
            A tensor with one row per row of ``adj`` and ``output_dim``
            columns.
        """
        # First layer: aggregate -> linear -> ReLU.
        aggregated_input = torch.mm(adj, x)
        hidden = torch.relu(self.layer1(aggregated_input))

        # Second layer: aggregate -> linear (no activation on the output).
        aggregated_hidden = torch.mm(adj, hidden)
        return self.layer2(aggregated_hidden)

# Load the example graph; the adjacency matrix fixes the number of nodes.
adj, G = create_graph()
n_nodes = adj.size(0)

# Model dimensions: features per node, hidden width, number of classes.
input_dim, hidden_dim, output_dim = 5, 8, 2

# Random features stand in for real node attributes (one row per node).
node_features = torch.rand(n_nodes, input_dim)

# Build the network and run a single forward pass over the whole graph.
model = GCN(input_dim, hidden_dim, output_dim)
output = model(node_features, adj)

print("Output node representations:", output, sep="\n")


Output node representations:
tensor([[ 97.1879, -37.2893],
        [ 91.6643, -32.8119],
        [111.8865, -38.8511],
        [ 62.0588, -20.8771],
        [ 27.3404,  -8.8308],
        [ 35.0437, -11.6627],
        [ 31.9782, -10.6693],
        [ 59.7893, -20.6988],
        [ 97.2249, -33.6876],
        [ 22.0759,  -7.5188],
        [ 21.2678,  -7.2313],
        [ 19.5109,  -5.7710],
        [ 14.9175,  -5.5441],
        [ 90.6638, -30.2042],
        [ 35.1552, -12.4398],
        [ 52.0130, -17.8774],
        [ 10.4388,  -4.3482],
        [ 17.3232,  -5.3005],
        [ 22.9296,  -7.8783],
        [ 30.0817,  -9.4191],
        [ 26.7264,  -9.7209],
        [ 21.6528,  -6.7003],
        [ 37.4713, -12.8779],
        [ 89.1182, -31.4153],
        [ 17.2390,  -6.9914],
        [ 46.6149, -21.8273],
        [ 25.2817,  -9.4963],
        [ 63.1515, -23.8775],
        [ 34.5234, -12.6705],
        [ 57.6430, -22.3724],
        [ 59.9191, -21.0433],
        [ 90.6966, -29.3623],
        [12