In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

In [5]:
def dprint(kwargs):
    for name, val in kwargs.items():
        print(f"{name}:\n\t{val}\n")

In [2]:
x = torch.tensor([
    [1.0, 2.0, 3.0],  # node 0 (A)
    [4.0, 5.0, 6.0],  # node 1 (B)
    [7.0, 8.0, 9.0],  # node 2 (C)
])  # shape: [3 nodes, 3 features]

edge_index = torch.tensor([
    [0, 1],  # source (A → B, B → C)
    [1, 2],  # target
])  # shape: [2, 2]

edge_attr = torch.tensor([
    [0.1, 0.2],  # edge 0: A→B
    [0.3, 0.4],  # edge 1: B→C
])  # shape: [2 edges, 2]

In [9]:
from torch_geometric.nn import GCNConv, NNConv
import torch.nn.functional as F

class BasicGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x
    


class EdgeAwareGNN(nn.Module):
    def __init__(self, in_channels, edge_dim, hidden_channels, out_channels):
        super().__init__()
        self.edge_mlp1 = nn.Sequential(
            nn.Linear(edge_dim, in_channels * hidden_channels),
            nn.ReLU(),
            nn.Linear(in_channels * hidden_channels, in_channels * hidden_channels)
        )
        self.edge_mlp2 = nn.Sequential(
            nn.Linear(edge_dim, hidden_channels * out_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels * out_channels, hidden_channels * out_channels)
        )

        self.conv1 = NNConv(in_channels, hidden_channels, self.edge_mlp1, aggr='add')
        self.conv2 = NNConv(hidden_channels, out_channels, self.edge_mlp2, aggr='add')

    def forward(self, x, edge_index, edge_attr):
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = self.conv2(x, edge_index, edge_attr)
        return x
    
class WeightedGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight):
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        return x

Implement with MessagePassing

In [21]:
class BasicGCN_MP(MessagePassing):
    def __init__(self, in_channels=0, out_channels=0):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.lin(x)  # W * x
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return aggr_out

class EdgeAware_MP(MessagePassing):
    def __init__(self, in_channels=0, edge_dim=0, out_channels=0):
        super().__init__(aggr='add')
        self.mlp = nn.Sequential(
            nn.Linear(in_channels + edge_dim, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j: source node feature
        input = torch.cat([x_j, edge_attr], dim=1)
        return self.mlp(input)

    def update(self, aggr_out):
        return aggr_out
    
class WeightedGCN_MP(MessagePassing):
    def __init__(self, in_channels=0, out_channels=0):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_weight):
        x = self.lin(x)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out

class TwoLayerGNN(nn.Module):
    def __init__(self, LayerClass, in_channels, hidden_channels, out_channels, **kwargs):
        super().__init__()
        self.layer1 = LayerClass(in_channels=in_channels, out_channels=hidden_channels, **kwargs)
        self.layer2 = LayerClass(in_channels=hidden_channels, out_channels=out_channels, **kwargs)

    def forward(self, x, edge_index, **kwargs):
        x = self.layer1(x, edge_index, **kwargs)
        x = torch.relu(x)
        x = self.layer2(x, edge_index, **kwargs)
        return x

In [22]:
IN_CHANNELS = 3
OUT_CHANNELS = 2
HIDDEN_CHANNELS = 8
EDGE_DIM = 2
edge_weight = torch.ones(edge_attr.size(0), dtype=torch.float)

# 1. basic
model1 = BasicGCN(in_channels=IN_CHANNELS, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
out1 = model1(x, edge_index)

# 2. edge-aware
model2 = EdgeAwareGNN(in_channels=IN_CHANNELS, edge_dim=EDGE_DIM, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
out2 = model2(x, edge_index, edge_attr)

# 3. weighted
model3 = WeightedGCN(in_channels=IN_CHANNELS, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
out3 = model3(x, edge_index, edge_weight)

#==============================================================================================
# 1. basic
MP_model1 = TwoLayerGNN(BasicGCN_MP, in_channels=IN_CHANNELS, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
                        
MP_out1 = model1(x, edge_index)

# 2. edge-aware
MP_model2 = TwoLayerGNN(EdgeAware_MP, in_channels=IN_CHANNELS, edge_dim=EDGE_DIM, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
MP_out2 = model2(x, edge_index, edge_attr)

# 3. weighted
MP_model3 = TwoLayerGNN(WeightedGCN_MP, in_channels=IN_CHANNELS, hidden_channels=HIDDEN_CHANNELS, out_channels=OUT_CHANNELS)
MP_out3 = model3(x, edge_index, edge_weight)

dprint({
    "x": x,
    "out1": out1,
    "MP_out1": MP_out1,
    "out2": out2,
    "MP_out2": MP_out2,
    "out3": out3,
    "MP_out3": MP_out3
})

x:
	tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

out1:
	tensor([[1.2605, 1.4746],
        [2.0351, 2.4804],
        [2.9233, 3.8178]], grad_fn=<AddBackward0>)

MP_out1:
	tensor([[1.2605, 1.4746],
        [2.0351, 2.4804],
        [2.9233, 3.8178]], grad_fn=<AddBackward0>)

out2:
	tensor([[ 0.7438,  0.1126],
        [ 1.8267, -1.1452],
        [ 3.3220, -3.1283]], grad_fn=<AddBackward0>)

MP_out2:
	tensor([[ 0.7438,  0.1126],
        [ 1.8267, -1.1452],
        [ 3.3220, -3.1283]], grad_fn=<AddBackward0>)

out3:
	tensor([[0.4914, 0.7102],
        [0.7558, 0.9888],
        [0.9134, 1.0068]], grad_fn=<AddBackward0>)

MP_out3:
	tensor([[0.4914, 0.7102],
        [0.7558, 0.9888],
        [0.9134, 1.0068]], grad_fn=<AddBackward0>)

