In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_undirected

In [3]:
class EdgeWeightedMPNN(MessagePassing):
    """Parameter-free message-passing layer that scales neighbor features by edge weight.

    One call performs:
      message:   w_e * x_j for every edge e = (j -> i)
      aggregate: sum of incoming messages per target node (aggr='add')
      update:    relu(x_i + aggregated messages)
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        Args:
            x: node features, [num_nodes, node_feat_dim].
            edge_index: [2, num_edges]; row 0 holds sources, row 1 holds targets.
            edge_attr: one scalar weight per edge, [num_edges, 1], so it
                broadcasts across the feature dimension in ``message``.

        Returns:
            Updated node features, same shape as ``x``.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Weight each source node's features (x_j) by its edge's scalar attribute."""
        weighted = x_j * edge_attr
        return weighted

    def update(self, aggr_out, x):
        """Residual update: own features plus aggregated messages, then ReLU."""
        return F.relu(aggr_out + x)

In [4]:
"""
A -- B -- C

로 연결 되어있는 그래프

A <--> B = 0.5
B <--> C = 1.0

의 edge attr
"""

# Node features, one row per node: A, B, C
x = torch.tensor(
    [[1.0, 0.0],   # A
     [0.0, 1.0],   # B
     [1.0, 1.0]],  # C
    dtype=torch.float,
)

# Edge list: each undirected edge (A-B, B-C) is stored as two directed edges.
# Row 0 = source node index, row 1 = target node index.
edge_index = torch.tensor(
    [[0, 1, 1, 2],   # source: A, B, B, C
     [1, 0, 2, 1]],  # target: B, A, C, B
    dtype=torch.long,
)

# Edge weights, one scalar per directed edge (column order of edge_index):
# A→B, B→A share 0.5; B→C, C→B share 1.0. Shape [num_edges, 1].
edge_attr = torch.tensor([0.5, 0.5, 1.0, 1.0], dtype=torch.float).unsqueeze(1)


In [5]:
mpnn = EdgeWeightedMPNN()

# Apply the same layer twice; the second pass consumes the first pass's output,
# so each node's result depends on neighbors up to two hops away.
x1 = mpnn(x, edge_index, edge_attr)
x2 = mpnn(x1, edge_index, edge_attr)

for label, feats in (("Step 1:", x1), ("Step 2:", x2)):
    print(label, feats)


Step 1: tensor([[1.0000, 0.5000],
        [1.5000, 2.0000],
        [1.0000, 2.0000]])
Step 2: tensor([[1.7500, 1.5000],
        [3.0000, 4.2500],
        [2.5000, 4.0000]])


Step-by-step 손계산 (manual calculation)


In [16]:
def dprint(kwargs):
    """Print labelled values one after another for step-by-step inspection.

    Args:
        kwargs: a mapping of label -> value (a plain dict passed positionally,
            despite the name). Entries are printed in the mapping's order.

    Each entry prints as ``label:`` on its own line, then the value's ``str()``
    with every line indented by one tab, then a blank line.
    """
    for name, val in kwargs.items():
        # Indent continuation lines too, so multi-line reprs (e.g. 2-D tensors)
        # stay aligned under their label instead of only the first line being tabbed.
        body = "\t" + str(val).replace("\n", "\n\t")
        print(f"{name}:\n{body}\n")

In [11]:
# Toy directed graph A → B → C with 3-dim node features and 2-dim edge features.

# Node features: row i is node i (0 = A, 1 = B, 2 = C), filled row-major with 1..9.
x = torch.arange(1.0, 10.0).reshape(3, 3)  # shape: [3 nodes, 3 features]

# Directed edges as (source, target) pairs, transposed into the [2, num_edges] layout.
edge_index = torch.tensor([
    (0, 1),  # edge 0: A → B
    (1, 2),  # edge 1: B → C
]).t().contiguous()  # shape: [2, 2]

# One 2-dim feature vector per edge, in the same order as the columns of edge_index.
edge_attr = torch.tensor([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)  # shape: [2 edges, 2]

In [17]:
# Gather step: each edge pulls the features of its source node (x_j).
src = edge_index[0]           # source node of each edge: [0, 1]
x_j = x.index_select(0, src)  # one row of x per edge

dprint({
    "src": src,
    "x_j": x_j,
    "edge_attr": edge_attr,
})

src:
	tensor([0, 1])

x_j:
	tensor([[1., 2., 3.],
        [4., 5., 6.]])

edge_attr:
	tensor([[0.1000, 0.2000],
        [0.3000, 0.4000]])



MPNN: message function (concatenate x_j with edge_attr, then apply an MLP)

In [18]:
# Message input: per-edge concatenation [x_j || edge_attr] along the feature axis.
# NOTE(review): `input` shadows the Python builtin; kept because later cells read it.
input = torch.cat((x_j, edge_attr), dim=-1)

dprint({"input": input})

input:
	tensor([[1.0000, 2.0000, 3.0000, 0.1000, 0.2000],
        [4.0000, 5.0000, 6.0000, 0.3000, 0.4000]])



In [21]:
# Single bias-free linear layer: 5 inputs (3 node features + 2 edge features) -> 3 outputs.
mlp = nn.Sequential(
    nn.Linear(5, 3, bias=False)
)

# Manual initialization: overwrite the random weights with fixed values so the
# message can be checked by hand. An in-place copy under no_grad is the supported
# way to set parameter values; unlike assigning `.data`, it rejects values whose
# shape does not broadcast to the weight's [3, 5].
with torch.no_grad():
    mlp[0].weight.copy_(torch.tensor([
        [1.0, 0.0, 0.0, 1.0, 0.0],  # output dim 0 = x_j[0] + edge_attr[0]
        [0.0, 1.0, 0.0, 0.0, 1.0],  # output dim 1 = x_j[1] + edge_attr[1]
        [0.0, 0.0, 1.0, 0.0, 0.0],  # output dim 2 = x_j[2]
    ]))

message = mlp(input)

dprint({
    "mlp weight": mlp[0].weight.detach(),
    "message": message
})


mlp weight:
	tensor([[1., 0., 0., 1., 0.],
        [0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0.]])

message:
	tensor([[1.1000, 2.2000, 3.0000],
        [4.3000, 5.4000, 6.0000]], grad_fn=<MmBackward0>)



AGG: sum-aggregate messages per target node, then update (x + aggr_out)

In [25]:
dst = edge_index[1]  # target node of each edge; messages are summed per target

# Sum-aggregate messages into their target nodes; a node with no incoming edge
# (node 0 here) gets an all-zero row. Sizing by x.size(0) instead of a literal 3
# keeps this correct for any graph, and the native Tensor.index_add removes the
# extra torch_scatter dependency.
aggr_out = message.new_zeros((x.size(0), message.size(1))).index_add(0, dst, message)

# Update step: residual connection (own features + aggregated messages).
x_updated_manual = x + aggr_out


dprint({
    "dst": dst,
    "aggr_out": aggr_out,
    "x_updated": x_updated_manual
})

dst:
	tensor([1, 2])

aggr_out:
	tensor([[0.0000, 0.0000, 0.0000],
        [1.1000, 2.2000, 3.0000],
        [4.3000, 5.4000, 6.0000]], grad_fn=<ScatterAddBackward0>)

x_updated:
	tensor([[ 1.0000,  2.0000,  3.0000],
        [ 5.1000,  7.2000,  9.0000],
        [11.3000, 13.4000, 15.0000]], grad_fn=<AddBackward0>)



In [29]:
class MessagePassingLayerMLP(MessagePassing):
    """Message-passing layer whose messages come from an MLP over [x_j || edge_attr].

    message:   mlp(cat(x_j, edge_attr))  -> [num_edges, 3]
    aggregate: sum of incoming messages per target node (aggr='add')
    update:    x + aggr_out (residual, no non-linearity)

    The MLP weights are fixed by hand so the output can be compared with the
    step-by-step manual computation in this notebook.
    """

    def __init__(self):
        super().__init__(aggr='add')  # sum aggregation
        self.mlp = nn.Sequential(
            nn.Linear(5, 3, bias=False)  # in: 3 (x_j) + 2 (edge_attr), out: 3
        )

        # Overwrite the random init with fixed weights. An in-place copy under
        # no_grad is the supported way to set parameter values; unlike
        # assigning `.data`, it rejects values whose shape does not broadcast
        # to the weight's [3, 5].
        with torch.no_grad():
            self.mlp[0].weight.copy_(torch.tensor([
                [1.0, 0.0, 0.0, 1.0, 0.0],  # output dim 0 = x_j[0] + edge_attr[0]
                [0.0, 1.0, 0.0, 0.0, 1.0],  # output dim 1 = x_j[1] + edge_attr[1]
                [0.0, 0.0, 1.0, 0.0, 0.0],  # output dim 2 = x_j[2]
            ]))

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        Args:
            x: node features; the MLP expects 3 features per node.
            edge_index: [2, num_edges]; row 0 holds sources, row 1 holds targets.
            edge_attr: edge features; the MLP expects 2 features per edge.

        Returns:
            Updated node features, same shape as ``x``.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Build one message per edge from source features and edge features."""
        # x_j: [num_edges, 3], edge_attr: [num_edges, 2]
        m_input = torch.cat([x_j, edge_attr], dim=1)  # [E, 5]
        return self.mlp(m_input)  # [E, 3]

    def update(self, aggr_out, x):
        """Residual update: keep each node's own features and add its aggregated messages."""
        return x + aggr_out

pyg_mpnn = MessagePassingLayerMLP()
x_updated_mpnn = pyg_mpnn(x, edge_index, edge_attr)

dprint({
    "x_updated_manual": x_updated_manual,
    "x_updated_mpnn": x_updated_mpnn
})

# The PyG layer must reproduce the hand-computed result; fail loudly instead of
# relying on a visual comparison of the two printed tensors.
torch.testing.assert_close(x_updated_mpnn, x_updated_manual)

x_updated_manual:
	tensor([[ 1.0000,  2.0000,  3.0000],
        [ 5.1000,  7.2000,  9.0000],
        [11.3000, 13.4000, 15.0000]], grad_fn=<AddBackward0>)

x_updated_mpnn:
	tensor([[ 1.0000,  2.0000,  3.0000],
        [ 5.1000,  7.2000,  9.0000],
        [11.3000, 13.4000, 15.0000]], grad_fn=<AddBackward0>)



Undirected Graph?

In [30]:
# to_undirected (imported in the top import cell) adds the reverse of every
# directed edge, so A → B, B → C also gains B → A, C → B.
edge_index_undirected = to_undirected(edge_index)

dprint({
    "edge_index": edge_index,
    "edge_index_undirected": edge_index_undirected
})

edge_index:
	tensor([[0, 1],
        [1, 2]])

edge_index_undirected:
	tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])

