# PyTorch Geometric `MessagePassing` Workflow

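- `__check_input__(**kwargs)`: checks whether `edge_index` is a **SparseTensor** or a plain `edge_index` tensor
- `__collect__(**kwargs)`: gathers the per-edge arguments that `message` asks for (e.g. `x_i`, `x_j`, `index`) for every edge in the graph
- `message(**kwargs)`: constructs the message sent along each edge `j -> i` to **node i**
- `aggregate(**kwargs)`: aggregates the messages arriving at each node: max, min, mean, add
- `update(**kwargs)`: updates the node features from the aggregated messages (optionally combined with the current features)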

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch.nn.init import xavier_uniform_, zeros_
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax

In [2]:
# Size of the toy graph.
num_nodes = 4  # how many nodes the graph has
embed_size = 5  # width of each initial node feature vector

# One random feature row per node.
node_feat = torch.rand((num_nodes, embed_size), dtype=torch.float)
x = node_feat
print('Node features\n', node_feat, node_feat.shape)

# Directed edges in COO layout: row 0 lists the source nodes, row 1 the target nodes.
src_index = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3], dtype=torch.long)
target_index = torch.tensor([1, 2, 3, 3, 0, 0, 0, 2], dtype=torch.long)
edge_index = torch.stack((src_index, target_index), dim=0)  # (2, num_edges), int64
print('\nedge_index\n', edge_index, edge_index.shape)

# One random feature row per edge: (num_edges, 5).
x_edge = torch.rand((edge_index.size(1), 5), dtype=torch.float)
print('\nEdge features\n', x_edge, x_edge.shape)

Node features
 tensor([[0.4753, 0.7336, 0.3778, 0.8182, 0.1771],
        [0.0397, 0.9440, 0.5285, 0.9309, 0.2452],
        [0.2214, 0.1756, 0.7032, 0.9054, 0.5184],
        [0.2881, 0.5517, 0.6869, 0.7084, 0.4849]]) torch.Size([4, 5])

edge_index
 tensor([[0, 0, 0, 1, 1, 2, 3, 3],
        [1, 2, 3, 3, 0, 0, 0, 2]]) torch.Size([2, 8])

Edge features
 tensor([[0.5982, 0.9028, 0.9854, 0.0744, 0.2471],
        [0.8321, 0.6367, 0.5588, 0.7835, 0.9423],
        [0.5290, 0.4307, 0.8699, 0.4362, 0.2070],
        [0.6127, 0.0396, 0.9917, 0.9920, 0.9192],
        [0.8113, 0.2669, 0.9257, 0.5417, 0.9243],
        [0.0989, 0.0144, 0.2029, 0.5586, 0.4973],
        [0.6190, 0.5799, 0.8756, 0.0480, 0.0155],
        [0.7155, 0.2190, 0.7044, 0.6091, 0.3158]]) torch.Size([8, 5])


# GAT

In [3]:
from torch_geometric.nn import GATv2Conv

# PyG's reference GATv2 layer, used for comparison with the hand-written layers below.
gat = GATv2Conv(embed_size, 100)
# Call the module itself (not `.forward`) so any registered nn.Module hooks run.
res1, res2 = gat(x, edge_index, return_attention_weights=True)
# res1: updated node features; res2: (edge_index used for propagation, attention weights).
print(res1[0].shape, res2[0].shape)

torch.Size([100]) torch.Size([2, 12])


In [4]:
class GAT(MessagePassing):
    """GATv2-style multi-head graph attention layer built on ``MessagePassing``.

    For every edge ``j -> i`` the attention logit is
    ``att . LeakyReLU(lin_r(x)_i + lin_l(x)_j)``. Logits are softmax-normalised
    over the incoming edges of each target node ``i``; the normalised weights
    scale the source features ``lin_l(x)_j``, which are then summed
    (``aggr='add'``).

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each per-head output feature vector.
        heads: number of attention heads.
        concat: if True the head outputs are concatenated
            (``heads * out_channels``), otherwise they are averaged
            (``out_channels``).
        negative_slope: negative slope of the LeakyReLU in the attention score.
        dropout: dropout probability applied to the normalised attention weights.
        add_self_loops: if True, existing self-loops are removed and one
            self-loop per node is added before propagation.
        bias: if True, a learnable bias is added to the output.
        share_weights: if True, source and target nodes share one linear layer.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # node_dim=0 because node features are laid out as (N, NH, H_out).
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        # Linear transformations: lin_l for source nodes, lin_r for target nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # use the same matrix for both roles
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head: (1, NH, H_out).
        self.att = Parameter(torch.empty(1, heads, out_channels))

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # attention weights stashed by `message` for `forward`

        xavier_uniform_(self.att)
        if self.bias is not None:  # bias is None when bias=False
            zeros_(self.bias)

    def forward(self, x, edge_index, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shapes: N = number of nodes, NH = heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: node features, (N, H_in).
            edge_index: graph connectivity; self-loops are only rewritten when it
                is a tensor.
            return_attention_weights: if it is a bool (True *or* False), also
                return ``(edge_index, alpha)``. ``edge_index`` is the index used
                for propagation, including any added self-loops, and ``alpha``
                has shape (#edges, NH).

        Returns:
            ``out`` with shape (N, NH * H_out) if ``concat`` else (N, H_out),
            optionally paired with ``(edge_index, alpha)``.
        """
        H, C = self.heads, self.out_channels

        x_l = self.lin_l(x).view(-1, H, C)  # source nodes: (N, H_in) -> (N, NH, H_out)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)  # target nodes

        if self.add_self_loops and torch.is_tensor(edge_index):
            # Normalise to exactly one self-loop per node.
            num_nodes = min(x_l.size(0), x_r.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # message -> aggregate ('add') -> update; result is (N, NH, H_out).
        out = self.propagate(edge_index, x=(x_l, x_r), size=None)

        alpha = self._alpha  # (#edges, NH)
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, H * C)  # (N, NH, H_out) -> (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, NH, H_out) -> (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        # Previously this fell through and returned None for non-tensor edge_index.
        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            return out, (edge_index, alpha)
        return out

    def message(self, x_j, x_i, index, size_i):
        """Build the attention-weighted message for each edge ``j -> i``.

        Args:
            x_j: source-node features per edge, (#edges, NH, H_out).
            x_i: target-node features per edge, (#edges, NH, H_out).
            index: target node of each edge; used to group the softmax.
            size_i: number of target nodes.

        Returns:
            Messages of shape (#edges, NH, H_out).
        """
        x = F.leaky_relu(x_i + x_j, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout, (#edges, NH)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)
                

In [5]:
gat = GAT(embed_size, 10, heads=8, concat=False, share_weights=True)
# Call the module (not `.forward`) so nn.Module hooks run; the cell displays (out, (edge_index, alpha)).
gat(x, edge_index, return_attention_weights=True)

(tensor([[ 0.1608, -0.1138, -0.3489,  0.0903, -0.0670, -0.1132, -0.0971,  0.1186,
           0.0760,  0.1189],
         [ 0.1312, -0.0939, -0.3707,  0.0681, -0.0256, -0.1286, -0.0869,  0.1697,
           0.0737,  0.1206],
         [ 0.1678, -0.1403, -0.3451,  0.0870, -0.0806, -0.0964, -0.1052,  0.0977,
           0.0813,  0.1112],
         [ 0.1531, -0.1009, -0.3622,  0.0824, -0.0441, -0.1213, -0.0989,  0.1488,
           0.0703,  0.1154]], grad_fn=<AddBackward0>),
 (tensor([[0, 0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3],
          [1, 2, 3, 3, 0, 0, 0, 2, 0, 1, 2, 3]]),
  tensor([[0.5167, 0.5093, 0.5177, 0.4917, 0.4983, 0.4948, 0.5194, 0.4951],
          [0.3298, 0.3368, 0.3535, 0.3341, 0.3303, 0.3526, 0.3467, 0.3229],
          [0.3392, 0.3388, 0.3487, 0.3305, 0.3318, 0.3375, 0.3470, 0.3269],
          [0.3178, 0.3293, 0.3293, 0.3375, 0.3312, 0.3504, 0.3237, 0.3322],
          [0.2365, 0.2473, 0.2593, 0.2583, 0.2515, 0.2706, 0.2408, 0.2488],
          [0.2568, 0.2492, 0.2154, 0.2435, 0.2473, 0

## Exp

In [6]:
class GAT(MessagePassing):
    """Experimental multi-head attention layer with concatenation-based scores.

    For every edge ``j -> i`` the attention logit is
    ``att . LeakyReLU([lin_r(x)_i || lin_l(x)_j])``, using an attention vector
    of size ``2 * out_channels`` per head. Logits are softmax-normalised over
    the incoming edges of each target node; the weights scale ``lin_l(x)_j``,
    and the results are summed (``aggr='add'``).

    NOTE(review): LeakyReLU is applied element-wise before the dot product, so
    the logit splits into ``att[:C] . LReLU(x_i) + att[C:] . LReLU(x_j)``.
    Within one target node's softmax the first term is constant, so the
    ranking of neighbours depends only on ``x_j``.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each per-head output feature vector.
        heads: number of attention heads.
        concat: concatenate head outputs if True, otherwise average them.
        negative_slope: negative slope of the LeakyReLU.
        dropout: dropout probability applied to the normalised attention weights.
        add_self_loops: if True, replace existing self-loops with one per node.
        bias: if True, add a learnable bias to the output.
        share_weights: if True, source and target nodes share one linear layer.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # node_dim=0 because node features are laid out as (N, NH, H_out).
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        # Linear transformations: lin_l for source nodes, lin_r for target nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # use the same matrix for both roles
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector over the concatenated [x_i || x_j]: (1, NH, 2 * H_out).
        self.att = Parameter(torch.empty(1, heads, out_channels * 2))

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # attention weights stashed by `message` for `forward`

        xavier_uniform_(self.att)
        if self.bias is not None:  # bias is None when bias=False
            zeros_(self.bias)

    def forward(self, x, edge_index, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shapes: N = number of nodes, NH = heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: node features, (N, H_in).
            edge_index: graph connectivity; self-loops are only rewritten when it
                is a tensor.
            return_attention_weights: if it is a bool (True *or* False), also
                return ``(edge_index, alpha)`` with ``alpha`` of shape
                (#edges, NH).

        Returns:
            ``out`` with shape (N, NH * H_out) if ``concat`` else (N, H_out),
            optionally paired with ``(edge_index, alpha)``.
        """
        H, C = self.heads, self.out_channels

        x_src = self.lin_l(x).view(-1, H, C)  # (N, H_in) -> (N, NH, H_out)
        if self.share_weights:
            x_dst = x_src
        else:
            x_dst = self.lin_r(x).view(-1, H, C)

        if self.add_self_loops and torch.is_tensor(edge_index):
            # Normalise to exactly one self-loop per node.
            num_nodes = min(x_src.size(0), x_dst.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # message -> aggregate ('add') -> update; result is (N, NH, H_out).
        out = self.propagate(edge_index, x=(x_src, x_dst), size=None)

        alpha = self._alpha  # (#edges, NH)
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, H * C)  # (N, NH, H_out) -> (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, NH, H_out) -> (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        # Previously this fell through and returned None for non-tensor edge_index.
        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            return out, (edge_index, alpha)
        return out

    def message(self, x_j, x_i, index, size_i):
        """Build the attention-weighted message for each edge ``j -> i``.

        Args:
            x_j: source-node features per edge, (#edges, NH, H_out).
            x_i: target-node features per edge, (#edges, NH, H_out).
            index: target node of each edge; used to group the softmax.
            size_i: number of target nodes.

        Returns:
            Messages of shape (#edges, NH, H_out).
        """
        x = torch.cat([x_i, x_j], dim=-1)  # (#edges, NH, 2 * H_out)
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout, (#edges, NH)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)
                

In [7]:
gat = GAT(embed_size, 10, heads=8, concat=False, share_weights=True)
# Call the module (not `.forward`) and unpack (out, (edge_index, alpha)) in one step.
out, (edge_index_returned, attention_weights) = gat(x, edge_index, return_attention_weights=True)
print(out.shape)
print(attention_weights.shape)
print(edge_index_returned.shape)


torch.Size([4, 10])
torch.Size([12, 8])
torch.Size([2, 12])


# My Arch 1

In [8]:
class EIN(MessagePassing):
    """GATv2-style attention layer that also receives edge features (first draft).

    Attention is computed as in ``GAT``. When ``edge_attr`` is given, each
    edge's features are scaled by its head-averaged attention weight. The
    "influence" result is currently only printed; it does not contribute to
    the output.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each per-head output feature vector.
        heads: number of attention heads.
        concat: concatenate head outputs if True, otherwise average them.
        negative_slope: negative slope of the LeakyReLU in the attention score.
        dropout: dropout probability applied to the normalised attention weights.
        add_self_loops: if True, replace existing self-loops with one per node.
        edge_dim: size of each edge feature vector (stored, not yet used).
        fill_value: how ``add_self_loops`` fills edge features of new self-loops.
        bias: if True, add a learnable bias to the output.
        share_weights: if True, source and target nodes share one linear layer.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        edge_dim=None,
        fill_value='mean',
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # node_dim=0 because node features are laid out as (N, NH, H_out).
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights
        self.edge_dim = edge_dim
        self.fill_value = fill_value

        # Linear transformations: lin_l for source nodes, lin_r for target nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # use the same matrix for both roles
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head: (1, NH, H_out).
        self.att = Parameter(torch.empty(1, heads, out_channels))

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # attention weights stashed by `message` for `forward`

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layers, the attention vector and the bias."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        xavier_uniform_(self.att)
        if self.bias is not None:  # bias is None when bias=False
            zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shapes: N = number of nodes, NH = heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: node features, (N, H_in).
            edge_index: graph connectivity; self-loops are only rewritten when it
                is a tensor.
            edge_attr: optional edge features, one row per edge.
            return_attention_weights: if it is a bool (True *or* False), also
                return ``(edge_index, alpha)``. ``alpha`` has shape
                (#edges, NH) when ``concat`` is True, otherwise it is the
                head-averaged (#edges, 1).

        Returns:
            ``out`` with shape (N, NH * H_out) if ``concat`` else (N, H_out),
            optionally paired with ``(edge_index, alpha)``.
        """
        H, C = self.heads, self.out_channels

        x_l = self.lin_l(x).view(-1, H, C)  # source nodes: (N, H_in) -> (N, NH, H_out)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)  # target nodes

        if self.add_self_loops and torch.is_tensor(edge_index):
            num_nodes = min(x_l.size(0), x_r.size(0))
            if edge_attr is not None:
                # New self-loops receive edge features according to `fill_value` (default 'mean').
                edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value, num_nodes=num_nodes
                )
            else:
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Debug output: report the edge feature shape seen by propagate.
        if edge_attr is not None:
            print(f'edge_features shape: {edge_attr.shape}')
        else:
            print('No edge features!')

        # message -> aggregate ('add') -> update; result is (N, NH, H_out).
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None)

        alpha = self._alpha  # (#edges, NH)
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, H * C)  # (N, NH, H_out) -> (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, NH, H_out) -> (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        # Influence mechanism (experimental, debug only): scale each edge's features
        # by its head-averaged attention. Previously raised a TypeError without edge_attr.
        if edge_attr is not None:
            res = alpha.mean(dim=1, keepdim=True) * edge_attr
            print('inf: ', res.shape)

        # Previously this fell through and returned None for non-tensor edge_index.
        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            if self.concat:
                return out, (edge_index, alpha)  # ((2, #edges), (#edges, NH))
            return out, (edge_index, alpha.mean(dim=1, keepdim=True))  # ((2, #edges), (#edges, 1))
        return out

    def message(self, x_j, x_i, index, size_i):
        """Build the attention-weighted message for each edge ``j -> i``.

        Args:
            x_j: source-node features per edge, (#edges, NH, H_out).
            x_i: target-node features per edge, (#edges, NH, H_out).
            index: target node of each edge; used to group the softmax.
            size_i: number of target nodes.

        Returns:
            Messages of shape (#edges, NH, H_out).
        """
        x = F.leaky_relu(x_i + x_j, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout, (#edges, NH)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


In [9]:
ein = EIN(embed_size, 10, heads=8, concat=False)
# Call the module (not `.forward`) so nn.Module hooks run.
res = ein(x, edge_index, edge_attr=x_edge, return_attention_weights=True)
out, r = res  # r = (edge_index incl. self-loops, head-averaged attention)
print(out.shape)
print(r[1])
print(r[0])

edge_features shape: torch.Size([12, 5])
inf:  torch.Size([12, 5])
torch.Size([4, 10])
tensor([[0.4930],
        [0.3335],
        [0.3300],
        [0.3398],
        [0.2557],
        [0.2482],
        [0.2483],
        [0.3332],
        [0.2477],
        [0.5070],
        [0.3332],
        [0.3302]], grad_fn=<MeanBackward1>)
tensor([[0, 0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3],
        [1, 2, 3, 3, 0, 0, 0, 2, 0, 1, 2, 3]])


In [10]:
# The first x_edge.size(0) attention rows belong to the original edges; the
# self-loop edges come after them (see the edge_index printed above).
temp = r[1][:x_edge.size(0)] * x_edge  # (8, 1) broadcast against (8, 5)
temp.shape

torch.Size([8, 5])

In [11]:
from torch_scatter import gather_csr, scatter, segment_csr


In [21]:
class EIN(MessagePassing):
    """
    A Edge featured attention based Graph Neural Network Layer for Graph Classification / Regression Tasks

    Per edge ``j -> i`` it computes GATv2-style attention from
    ``lin_r(x)_i + lin_l(x)_j``. It weights the source features with that
    attention and averages over heads. When edge features are present it adds
    ``inf(alpha_mean * edge_attr)`` (the "influence mechanism"). Messages are
    summed per target node. Finally the head-averaged target features, weighted
    by ``1 + eps``, and the bias are added.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output feature vector.
        heads: number of attention heads (averaged in the output).
        negative_slope: negative slope of the LeakyReLU in the attention score.
        dropout: dropout probability applied to the normalised attention weights.
        edge_dim: size of each edge feature vector. If None, no influence layer
            is built and any ``edge_attr`` passed in is ignored.
        train_eps: if True, ``eps`` is learnable; otherwise it is a fixed buffer.
        eps: initial value of the self-feature weighting term.
        bias: if True, add a learnable bias of size ``out_channels``.
        share_weights: if True, source and target nodes share one linear layer.
        **kwargs: forwarded to ``MessagePassing``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        negative_slope=0.2,
        dropout=0.0,
        edge_dim=None,
        train_eps=False,
        eps=0.0,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # node_dim=0 because node features are laid out as (N, NH, H_out).
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.share_weights = share_weights
        self.edge_dim = edge_dim
        self.initial_eps = eps

        # Linear transformations: lin_l for source nodes, lin_r for target nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # use the same matrix for both roles
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head: (1, NH, H_out).
        self.att = Parameter(torch.empty(1, heads, out_channels))

        # Influence mechanism: maps attention-scaled edge features to out_channels.
        # Only built when edge_dim is known; Linear(None, ...) would raise.
        if edge_dim is not None:
            self.inf = Linear(edge_dim, out_channels)
        else:
            self.inf = None

        # Tunable weighting of the node's own features (see `update`).
        if train_eps:
            self.eps = Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # head-averaged attention stashed by `message`, (#edges, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all learnable parameters and reset ``eps``."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.inf is not None:
            self.inf.reset_parameters()
        self.eps.data.fill_(self.initial_eps)
        xavier_uniform_(self.att)
        if self.bias is not None:  # bias is None when bias=False
            zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        """Run one round of message passing.

        Shapes: N = number of nodes, NH = heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: node features, (N, H_in).
            edge_index: graph connectivity (no self-loops are added here).
            edge_attr: optional edge features, (#edges, edge_dim).
            return_attention_weights: if it is a bool (True *or* False), also
                return the head-averaged attention weights, (#edges, 1).

        Returns:
            ``out`` of shape (N, H_out), optionally paired with the attention.
        """
        H, C = self.heads, self.out_channels

        x_l = self.lin_l(x).view(-1, H, C)  # source nodes: (N, H_in) -> (N, NH, H_out)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)  # target nodes

        # Debug output: report the edge feature shape seen by propagate.
        if edge_attr is not None:
            print(f'edge_features shape: {edge_attr.shape}')
        else:
            print('No edge features!')

        # message -> aggregate ('add') -> update (adds self features); result is (N, H_out).
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None)

        alpha = self._alpha  # (#edges, 1)
        assert alpha is not None, 'Alpha weights can not be None value!'
        self._alpha = None  # don't let a stale value survive into the next call

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool):
            return out, alpha.mean(dim=1, keepdim=True)
        return out

    def message(self, x_j, x_i, index, size_i, edge_attr):
        """Build the message for each edge ``j -> i``.

        Args:
            x_j: source-node features per edge, (#edges, NH, H_out).
            x_i: target-node features per edge, (#edges, NH, H_out).
            index: target node of each edge; used to group the softmax.
            size_i: number of target nodes.
            edge_attr: edge features per edge, or None.

        Returns:
            Messages of shape (#edges, H_out).

        Raises:
            ValueError: if ``edge_attr``'s last dimension differs from ``edge_dim``.
        """
        x = F.leaky_relu(x_i + x_j, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha.mean(dim=1, keepdim=True)  # (#edges, 1), taken before dropout
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        node_out = (x_j * alpha.unsqueeze(-1)).mean(dim=1)  # average heads: (#edges, H_out)

        # Edge features are only used when an influence layer exists (edge_dim was set).
        if self.inf is not None and edge_attr is not None:
            if self.edge_dim != edge_attr.size(-1):
                raise ValueError("Node and edge feature dimensionalities do not "
                                 "match. Consider setting the 'edge_dim' attribute")
            edge_out = self.inf(self._alpha * edge_attr)  # influence mechanism
            return node_out + edge_out  # (#edges, H_out)
        return node_out  # (#edges, H_out)

    def update(self, aggr_out, x):
        """Add the weighted, head-averaged target-node features to the aggregate.

        ``x`` is the ``(x_l, x_r)`` tuple given to ``propagate``; ``x[1]`` has
        shape (N, NH, H_out). The addition is done out of place.
        """
        return aggr_out + (1 + self.eps) * x[1].mean(dim=1)  # (N, H_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


In [24]:
# Edge feature width is taken from the data rather than hard-coded.
edge_dim = x_edge.size(-1)

ein = EIN(embed_size, 32, heads=10, edge_dim=edge_dim, train_eps=True)
# Call the modules (not `.forward`) so nn.Module hooks run.
res = ein(x, edge_index, edge_attr=x_edge, return_attention_weights=True)
print('Returned results:')
print(res[0].shape, res[1].shape)
print('==========================================================================')

# Stack a second layer on the first layer's node features.
ein2 = EIN(32, 64, heads=10, edge_dim=edge_dim, train_eps=True)
res2 = ein2(res[0], edge_index, edge_attr=x_edge)
print('Returned results:')
print(res2.shape)


edge_features shape: torch.Size([8, 5])
Returned results:
torch.Size([4, 32]) torch.Size([8, 1])
edge_features shape: torch.Size([8, 5])
Returned results:
torch.Size([4, 64])


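### How aggregation works
- `out`: the per-edge messages returned by `message`, e.g. `[8, 32]` -> `[#edges, H_out]`
- `index`: the target node index of each edge
- `dim`: the dimension to aggregate along; here `dim = 0` (row-wise, one row per edge)

The aggregated result (one row per node) is then passed to the `update` function. This example uses `reduce='mean'`, whereas the layers above use `aggr='add'`.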

    def aggregate(self, out, index):
        """Average the per-edge messages `out` into their target nodes `index` along dim 0."""
        return scatter(out, index, dim=0, reduce='mean')