# MessagePassing in PyTorch Geometric

[Reference](https://zqfang.github.io/2021-08-07-graph-pyg/)

In [27]:
import torch
from torch_scatter import scatter_add

num_nodes = 4
embed_size = 5

src = torch.randint(0, num_nodes, (num_nodes, embed_size)) # (num_nodes, embed_size)
src_index = torch.tensor([0,0,0,1,1,2,3,3]) # edges
tmp = torch.index_select(src, 0, src_index) # shape [num_edges, embed_size ]
print("input: ")
print(tmp)

target_index = torch.tensor([1,2,3,3,0,0,0,2])
aggr = scatter_add(tmp, target_index, 0) # shape [num_nodes, embed_size] 
print("agg out:")
print(aggr)

# behind the sence, torch.scatter_add is used
# repeat the edge_index
index2 = target_index.expand((embed_size, target_index.size(0))).T
# same result by using torch.scatter_add
aggr2 = torch.zeros(num_nodes, embed_size, dtype=tmp.dtype).scatter_add(0, index2, tmp)


input: 
tensor([[1, 1, 0, 1, 2],
        [1, 1, 0, 1, 2],
        [1, 1, 0, 1, 2],
        [1, 2, 1, 2, 3],
        [1, 2, 1, 2, 3],
        [2, 3, 3, 0, 1],
        [3, 3, 1, 3, 2],
        [3, 3, 1, 3, 2]])
agg out:
tensor([[6, 8, 5, 5, 6],
        [1, 1, 0, 1, 2],
        [4, 4, 1, 4, 4],
        [2, 3, 1, 3, 5]])


In [28]:
print('src\n',src) # source nodes: (num_nodes, embed_size)
print('src_index\n',src_index) # Indexes of the source nodes
res = torch.index_select(src, 0, src_index) # Selected source nodes as in edge format: (num_edges, embed_size)
print('res\n',res)

src
 tensor([[1, 1, 0, 1, 2],
        [1, 2, 1, 2, 3],
        [2, 3, 3, 0, 1],
        [3, 3, 1, 3, 2]])
src_index
 tensor([0, 0, 0, 1, 1, 2, 3, 3])
res
 tensor([[1, 1, 0, 1, 2],
        [1, 1, 0, 1, 2],
        [1, 1, 0, 1, 2],
        [1, 2, 1, 2, 3],
        [1, 2, 1, 2, 3],
        [2, 3, 3, 0, 1],
        [3, 3, 1, 3, 2],
        [3, 3, 1, 3, 2]])


Select the vectors indexed from the src matrix according to the index vector: src_index

In [29]:
print('target node indexes',target_index)
aggreated_node_feats = scatter_add(res, target_index, 0) # shape [num_nodes, embed_size] 
print('aggreated node features\n', aggreated_node_feats)

target node indexes tensor([1, 2, 3, 3, 0, 0, 0, 2])
aggreated node features
 tensor([[6, 8, 5, 5, 6],
        [1, 1, 0, 1, 2],
        [4, 4, 1, 4, 4],
        [2, 3, 1, 3, 5]])


In [36]:
target_index.expand((embed_size, target_index.size(0))).T

tensor([[1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2]])

In [38]:
torch.zeros(num_nodes, embed_size, dtype=tmp.dtype).scatter_add(0, index2, tmp)

tensor([[6, 8, 5, 5, 6],
        [1, 1, 0, 1, 2],
        [4, 4, 1, 4, 4],
        [2, 3, 1, 3, 5]])

In [37]:
aggr == aggr2

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])