In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# GCN

In [4]:
# Example graph
# Edge list format: [source, destination]
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3],  # source
    [1, 0, 2, 1, 3, 2]   # destination
], dtype=torch.long)

# Node features: 4 nodes × 3 features
x = torch.tensor([
    [1, 0, 2],
    [0, 1, 0],
    [1, 1, 0],
    [0, 0, 1]
], dtype=torch.float32)

y = torch.tensor([0, 1, 1, 0])  # node labels

In [5]:
# --- Define GCN Layer using edge list ---
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_features, out_features) * 0.01)

    def forward(self, x, edge_index):
        """
        x: [N, in_features]
        edge_index: [2, E] tensor where each column is (src, dst)
        """
        num_nodes = x.size(0)
        src, dst = edge_index

        # Step 1: Add self-loops
        self_loops = torch.arange(num_nodes)
        self_loops = self_loops.unsqueeze(0).repeat(2, 1)
        edge_index = torch.cat([edge_index, self_loops], dim=1)
        src, dst = edge_index

        # Step 2: Compute degree
        deg = torch.bincount(dst, minlength=num_nodes).float()
        deg_inv_sqrt = torch.pow(deg, -0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

        # Step 3: Linear transformation
        x = x @ self.W

        # Step 4: Message passing (normalized aggregation)
        messages = deg_inv_sqrt[dst].unsqueeze(1) * x[src]
        out = torch.zeros_like(x)
        out.index_add_(0, dst, messages * deg_inv_sqrt[dst].unsqueeze(1))

        # Step 5: Activation
        return F.relu(out)

In [6]:
class GCN(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.conv1 = GCNLayer(in_features, hidden)
        self.conv2 = GCNLayer(hidden, out_features)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        return x


In [7]:
model = GCN(in_features=3, hidden=4, out_features=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == y).float().mean()
        print(f"Epoch {epoch:03d}: Loss={loss.item():.4f}, Acc={acc.item():.4f}")


Epoch 000: Loss=0.6931, Acc=0.5000
Epoch 020: Loss=0.6926, Acc=0.5000
Epoch 040: Loss=0.6910, Acc=0.5000
Epoch 060: Loss=0.6892, Acc=0.5000
Epoch 080: Loss=0.6865, Acc=0.5000
Epoch 100: Loss=0.6844, Acc=0.5000
Epoch 120: Loss=0.6836, Acc=0.5000
Epoch 140: Loss=0.6835, Acc=0.5000
Epoch 160: Loss=0.6835, Acc=0.5000
Epoch 180: Loss=0.6835, Acc=0.5000


# MPNN

In [10]:
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3],  # src
    [1, 0, 2, 1, 3, 2]   # dst
], dtype=torch.long)

x = torch.tensor([
    [1.0, 0.0, 2.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

y = torch.tensor([0, 1, 1, 0])


In [11]:
# --- Define MPNN layer ---
class MPNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Learnable message function
        self.msg_fn = nn.Sequential(
            nn.Linear(2 * in_features, out_features),
            nn.ReLU()
        )
        # Update function
        self.update_fn = nn.Sequential(
            nn.Linear(in_features + out_features, out_features),
            nn.ReLU()
        )

    def forward(self, x, edge_index):
        """
        x: [N, F_in]
        edge_index: [2, E]
        """
        src, dst = edge_index
        num_nodes = x.size(0)

        # Step 1: Compute messages for all edges
        # Concatenate source and destination node features
        messages = self.msg_fn(torch.cat([x[src], x[dst]], dim=1))  # [E, out_features]

        # Step 2: Aggregate messages by destination node (sum)
        aggr_msg = torch.zeros(num_nodes, messages.size(1), device=x.device)
        aggr_msg.index_add_(0, dst, messages)

        # Step 3: Update node features using previous state + aggregated message
        new_x = self.update_fn(torch.cat([x, aggr_msg], dim=1))
        return new_x


In [12]:
class MPNN(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.mp1 = MPNNLayer(in_features, hidden)
        self.mp2 = MPNNLayer(hidden, hidden)
        self.out = nn.Linear(hidden, out_features)

    def forward(self, x, edge_index):
        x = self.mp1(x, edge_index)
        x = self.mp2(x, edge_index)
        x = self.out(x)
        return x

In [13]:
model = MPNN(in_features=3, hidden=8, out_features=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == y).float().mean()
        print(f"Epoch {epoch:03d}: Loss={loss.item():.4f}, Acc={acc.item():.4f}")


Epoch 000: Loss=0.7062, Acc=0.5000
Epoch 020: Loss=0.0021, Acc=1.0000
Epoch 040: Loss=0.0000, Acc=1.0000
Epoch 060: Loss=0.0000, Acc=1.0000
Epoch 080: Loss=0.0000, Acc=1.0000
Epoch 100: Loss=0.0000, Acc=1.0000
Epoch 120: Loss=0.0000, Acc=1.0000
Epoch 140: Loss=0.0000, Acc=1.0000
Epoch 160: Loss=0.0000, Acc=1.0000
Epoch 180: Loss=0.0000, Acc=1.0000


In [14]:
with torch.no_grad():
    preds = model(x, edge_index).argmax(dim=1)
print("Predicted labels:", preds.tolist())
print("True labels:     ", y.tolist())

Predicted labels: [0, 1, 1, 0]
True labels:      [0, 1, 1, 0]
