# PyTorch Geometric Workflow

- `__check_input__(**kwargs)`: checks whether the input is a **SparseTensor** or not
- `__collect__(**kwargs)`: collects what is needed to construct the message of **node i**, for every node in the graph
- `message(**kwargs)`: constructs the message for **node i**
- `aggregate(**kwargs)`: aggregates the messages: max, min, mean, add
- `update(**kwargs)`: updates the features by combining the aggregated message with the current features

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch.nn.init import xavier_uniform_, zeros_
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax

In [2]:
num_nodes = 4  # how many nodes the toy graph has
embed_size = 5  # width of each initial node feature vector

# Random node features: one row per node.
node_feat = torch.rand(num_nodes, embed_size, dtype=torch.float)
x = node_feat
print('Node features\n', node_feat, node_feat.shape)

# Edge list in COO layout: row 0 holds the source nodes, row 1 the target nodes.
src_index = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3], dtype=torch.long)
target_index = torch.tensor([1, 2, 3, 3, 0, 0, 0, 2], dtype=torch.long)
edge_index = torch.stack((src_index, target_index), dim=0)
print('\nedge_index\n', edge_index, edge_index.shape)

# Random edge features: one 5-dimensional row per edge.
x_edge = torch.rand(edge_index.size(1), 5, dtype=torch.float)
print('\nEdge features\n', x_edge, x_edge.shape)

Node features
 tensor([[0.5345, 0.0603, 0.5653, 0.4170, 0.9076],
        [0.2998, 0.2692, 0.2079, 0.0464, 0.2558],
        [0.3123, 0.9619, 0.8403, 0.2705, 0.5712],
        [0.2150, 0.3433, 0.1995, 0.6572, 0.4067]]) torch.Size([4, 5])

edge_index
 tensor([[0, 0, 0, 1, 1, 2, 3, 3],
        [1, 2, 3, 3, 0, 0, 0, 2]]) torch.Size([2, 8])

Edge features
 tensor([[0.1068, 0.6574, 0.3379, 0.0812, 0.0539],
        [0.5815, 0.1062, 0.6312, 0.4412, 0.5107],
        [0.6821, 0.3295, 0.1306, 0.7818, 0.6742],
        [0.1610, 0.1011, 0.7122, 0.3853, 0.9222],
        [0.9187, 0.0547, 0.3643, 0.3242, 0.5840],
        [0.9987, 0.3162, 0.3334, 0.5726, 0.5883],
        [0.6838, 0.8965, 0.4611, 0.0078, 0.7463],
        [0.8724, 0.6355, 0.2915, 0.5650, 0.4002]]) torch.Size([8, 5])


# GAT

In [3]:
from torch_geometric.nn import GATv2Conv

# Reference layer shipped with PyG: 5 input features -> 100 output features.
gat = GATv2Conv(5, 100)
# Call the module itself instead of `.forward()` so that any hooks registered
# on the module are run as part of the call.
res1, res2 = gat(x, edge_index, return_attention_weights=True)
# res1[0] is the first row of the node output; res2[0] is the first element of
# the second returned value (printed below as a [2, #edges] tensor).
print(res1[0].shape, res2[0].shape)

torch.Size([100]) torch.Size([2, 12])


In [4]:
class GAT(MessagePassing):
    """Graph attention layer built on ``MessagePassing`` with ``'add'`` aggregation.

    Node features are linearly projected into ``heads`` separate
    ``out_channels``-sized spaces. For every edge the projected source and
    target features are summed, passed through a leaky ReLU and scored
    against the learnable vector ``att``; the scores are normalised with a
    softmax over the edges that share a target node and used to weight the
    source features before aggregation.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector per head.
        heads: Number of attention heads.
        concat: If true, head outputs are concatenated
            (``heads * out_channels`` features); otherwise they are averaged.
        negative_slope: Negative slope of the leaky ReLU.
        dropout: Dropout probability applied to the attention coefficients
            while the module is in training mode.
        add_self_loops: If true, existing self-loops are removed and one
            self-loop per node is added before propagation (tensor
            ``edge_index`` only).
        bias: If true, a learnable bias is added to the output and the linear
            projections use a bias as well.
        share_weights: If true, source and target nodes use the same linear
            projection.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # `aggr='add'` selects sum aggregation of the incoming messages.
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        # Linear projections for source (lin_l) and target (lin_r) nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # reuse the very same module
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head.
        self.att = Parameter(torch.empty(1, heads, out_channels))

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        # Scratch slot: `message` stores the attention weights here so that
        # `forward` can hand them back to the caller.
        self._alpha = None

        xavier_uniform_(self.att)
        # With bias=False there is no tensor to initialise.
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, edge_index, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shape legend: N = #nodes, NH = #heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: Node features, ``(N, H_in)``.
            edge_index: Graph connectivity. Self-loop handling and returning
                attention weights are only done when this is a tensor.
            return_attention_weights: If this is a ``bool`` (either value)
                and ``edge_index`` is a tensor, the attention weights are
                returned alongside the output.

        Returns:
            ``out`` of shape ``(N, NH * H_out)`` if ``concat`` else
            ``(N, H_out)``; or ``(out, (edge_index, alpha))`` with ``alpha``
            of shape ``(#edges, NH)`` when attention weights are requested.
        """
        H, C = self.heads, self.out_channels

        # (N, H_in) -> (N, NH, H_out); x_l feeds source nodes, x_r target nodes.
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = x_l if self.share_weights else self.lin_r(x).view(-1, H, C)

        if self.add_self_loops and torch.is_tensor(edge_index):
            # Drop existing self-loops first so every node ends up with exactly one.
            num_nodes = min(x_l.size(0), x_r.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # construct messages -> aggregate messages -> update representations
        out = self.propagate(edge_index, x=(x_l, x_r), size=None)  # (N, NH, H_out)

        alpha = self._alpha  # (#edges, NH), written by `message`
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)  # (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            return out, (edge_index, alpha)
        # Always return the features; previously a bool flag combined with a
        # non-tensor edge_index fell through and returned None.
        return out

    def message(self, x_j, x_i, index, size_i):
        """Build the attention-weighted message for every edge.

        Args:
            x_j: Source node features per edge, ``(#edges, NH, H_out)``.
            x_i: Target node features per edge, ``(#edges, NH, H_out)``.
            index: Target node index of every edge; defines the softmax groups.
            size_i: Number of target nodes, passed to the softmax.

        Returns:
            Tensor of shape ``(#edges, NH, H_out)``.
        """
        # Element-wise sum of target and source features drives the attention score.
        x = x_i + x_j
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the edges that share a target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout
        # Randomly drop attention coefficients during training.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.unsqueeze(-1)  # (#edges, NH, H_out)
                

In [5]:
gat = GAT(embed_size, 10, heads=8, concat=False, share_weights=True)
# Call the module itself instead of `.forward()` so that any hooks registered
# on the module are run as part of the call.
gat(x, edge_index, return_attention_weights=True)

(tensor([[ 0.0209,  0.0724, -0.1171, -0.0533,  0.0359,  0.0240, -0.0938, -0.0585,
          -0.1673, -0.0335],
         [ 0.0226,  0.0430, -0.1237, -0.0295,  0.0446,  0.0336, -0.0909, -0.0249,
          -0.1550, -0.0169],
         [ 0.0206,  0.0907, -0.1357, -0.0814,  0.0266,  0.0107, -0.1131, -0.0627,
          -0.1599, -0.0364],
         [ 0.0341,  0.0587, -0.0838, -0.0190,  0.0475,  0.0290, -0.0979, -0.0074,
          -0.1709, -0.0145]], grad_fn=<AddBackward0>),
 (tensor([[0, 0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3],
          [1, 2, 3, 3, 0, 0, 0, 2, 0, 1, 2, 3]]),
  tensor([[0.4209, 0.4899, 0.4841, 0.5050, 0.5091, 0.4941, 0.5088, 0.5234],
          [0.3030, 0.3273, 0.3238, 0.3514, 0.3509, 0.3355, 0.3368, 0.3343],
          [0.2909, 0.3182, 0.3288, 0.3447, 0.3433, 0.3317, 0.3300, 0.3444],
          [0.3809, 0.3313, 0.3401, 0.3379, 0.3270, 0.3488, 0.3223, 0.3196],
          [0.2857, 0.2542, 0.2541, 0.2562, 0.2487, 0.2557, 0.2443, 0.2304],
          [0.2351, 0.2329, 0.2431, 0.2419, 0.2374, 0

## Exp

In [6]:
class GAT(MessagePassing):
    """Experimental graph attention layer that scores concatenated features.

    Same overall structure as a sum-aggregating attention layer, but the
    attention score of an edge is computed from the *concatenation* of the
    projected target and source features (hence ``att`` has
    ``2 * out_channels`` entries per head) instead of their sum.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector per head.
        heads: Number of attention heads.
        concat: If true, head outputs are concatenated
            (``heads * out_channels`` features); otherwise they are averaged.
        negative_slope: Negative slope of the leaky ReLU.
        dropout: Dropout probability applied to the attention coefficients
            while the module is in training mode.
        add_self_loops: If true, existing self-loops are removed and one
            self-loop per node is added before propagation (tensor
            ``edge_index`` only).
        bias: If true, a learnable bias is added to the output and the linear
            projections use a bias as well.
        share_weights: If true, source and target nodes use the same linear
            projection.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # `aggr='add'` selects sum aggregation of the incoming messages.
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        # Linear projections for source (lin_l) and target (lin_r) nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # reuse the very same module
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector per head; twice as wide because it scores the
        # concatenation [x_i, x_j].
        self.att = Parameter(torch.empty(1, heads, out_channels * 2))

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        # Scratch slot: `message` stores the attention weights here so that
        # `forward` can hand them back to the caller.
        self._alpha = None

        xavier_uniform_(self.att)
        # With bias=False there is no tensor to initialise.
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, edge_index, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shape legend: N = #nodes, NH = #heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: Node features, ``(N, H_in)``.
            edge_index: Graph connectivity. Self-loop handling and returning
                attention weights are only done when this is a tensor.
            return_attention_weights: If this is a ``bool`` (either value)
                and ``edge_index`` is a tensor, the attention weights are
                returned alongside the output.

        Returns:
            ``out`` of shape ``(N, NH * H_out)`` if ``concat`` else
            ``(N, H_out)``; or ``(out, (edge_index, alpha))`` with ``alpha``
            of shape ``(#edges, NH)`` when attention weights are requested.
        """
        H, C = self.heads, self.out_channels

        # (N, H_in) -> (N, NH, H_out)
        x_src = self.lin_l(x).view(-1, H, C)
        x_dst = x_src if self.share_weights else self.lin_r(x).view(-1, H, C)

        # The same projected features are passed twice: as `x` (the message
        # payload) and as `a` (extra per-edge inputs that `message` accepts
        # but does not read yet).
        feats = (x_src, x_dst)

        if self.add_self_loops and torch.is_tensor(edge_index):
            # Drop existing self-loops first so every node ends up with exactly one.
            num_nodes = min(x_src.size(0), x_dst.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # construct messages -> aggregate messages -> update representations
        out = self.propagate(edge_index, x=feats, a=feats, size=None)  # (N, NH, H_out)

        alpha = self._alpha  # (#edges, NH), written by `message`
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)  # (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            return out, (edge_index, alpha)
        # Always return the features; previously a bool flag combined with a
        # non-tensor edge_index fell through and returned None.
        return out

    def message(self, x_j, x_i, a_i, a_j, index, size_i):
        """Build the attention-weighted message for every edge.

        Args:
            x_j: Source node features per edge, ``(#edges, NH, H_out)``.
            x_i: Target node features per edge, ``(#edges, NH, H_out)``.
            a_i: Per-edge target-side copy of ``a``; currently unused.
            a_j: Per-edge source-side copy of ``a``; currently unused.
            index: Target node index of every edge; defines the softmax groups.
            size_i: Number of target nodes, passed to the softmax.

        Returns:
            Tensor of shape ``(#edges, NH, H_out)``.
        """
        # Concatenate (rather than add) target and source features for scoring.
        x = torch.cat([x_i, x_j], dim=-1)  # (#edges, NH, 2 * H_out)
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the edges that share a target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout
        # Randomly drop attention coefficients during training.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.unsqueeze(-1)  # (#edges, NH, H_out)
                

In [7]:
gat = GAT(embed_size, 10, heads=8, concat=False, share_weights=True)
# Call the module itself instead of `.forward()` so that any hooks registered
# on the module are run as part of the call.
res = gat(x, edge_index, return_attention_weights=True)
# The result is (out, (edge_index, attention_weights)); unpack both levels.
out, edge_index_returned_with_attention_weights = res[0], res[1]
edge_index_returned, attention_weights = edge_index_returned_with_attention_weights
print(out.shape)
print(attention_weights.shape)
print(edge_index_returned.shape)


torch.Size([4, 10])
torch.Size([12, 8])
torch.Size([2, 12])


# My Arch 1

In [11]:
class EIN(MessagePassing):
    """Attention layer prototype that additionally looks at edge features.

    The node update is the same attention-weighted sum as in the ``GAT``
    layer defined earlier in this file. On top of that, ``forward`` computes
    an "influence" term (head-averaged attention weight times edge feature)
    whose shape is printed for inspection; it is not yet part of the output.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector per head.
        heads: Number of attention heads.
        concat: If true, head outputs are concatenated
            (``heads * out_channels`` features); otherwise they are averaged.
        negative_slope: Negative slope of the leaky ReLU.
        dropout: Dropout probability applied to the attention coefficients
            while the module is in training mode.
        add_self_loops: If true, existing self-loops are removed and one
            self-loop per node is added before propagation (tensor
            ``edge_index`` only).
        edge_dim: Stored on the instance; not used by this version.
        fill_value: Passed to ``add_self_loops`` to create the edge features
            of the newly added self-loops.
        bias: If true, a learnable bias is added to the output and the linear
            projections use a bias as well.
        share_weights: If true, source and target nodes use the same linear
            projection.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        edge_dim=None,
        fill_value='mean',
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # `aggr='add'` selects sum aggregation of the incoming messages.
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights
        self.edge_dim = edge_dim
        self.fill_value = fill_value

        # Linear projections for source (lin_l) and target (lin_r) nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # reuse the very same module
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head.
        self.att = Parameter(torch.empty(1, heads, out_channels))

        # NOTE: a learnable projection for the influence mechanism is not
        # part of this version of the layer.

        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        # Scratch slot: `message` stores the attention weights here so that
        # `forward` can hand them back to the caller.
        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise all learnable parameters of the layer."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        xavier_uniform_(self.att)
        # With bias=False there is no tensor to initialise.
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        """Run one round of attention-weighted message passing.

        Shape legend: N = #nodes, NH = #heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: Node features, ``(N, H_in)``.
            edge_index: Graph connectivity. Self-loop handling and returning
                attention weights are only done when this is a tensor.
            edge_attr: Optional edge features with one row per edge. When
                self-loops are added, rows for them are appended using
                ``fill_value``.
            return_attention_weights: If this is a ``bool`` (either value)
                and ``edge_index`` is a tensor, the attention weights are
                returned alongside the output.

        Returns:
            ``out`` of shape ``(N, NH * H_out)`` if ``concat`` else
            ``(N, H_out)``; or ``(out, (edge_index, alpha))`` when attention
            weights are requested, where ``alpha`` is ``(#edges, NH)`` if
            ``concat`` and the head average ``(#edges, 1)`` otherwise.
        """
        H, C = self.heads, self.out_channels

        # (N, H_in) -> (N, NH, H_out); x_l feeds source nodes, x_r target nodes.
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = x_l if self.share_weights else self.lin_r(x).view(-1, H, C)

        if self.add_self_loops and torch.is_tensor(edge_index):
            num_nodes = min(x_l.size(0), x_r.size(0))
            if edge_attr is not None:
                # Keep edge features aligned with the edges: drop the rows of
                # removed self-loops and synthesise rows for the new ones.
                edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr,
                    fill_value=self.fill_value, num_nodes=num_nodes,
                )
            else:
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Debug output: report the edge feature shape seen by the layer.
        if edge_attr is not None:
            print(f'edge_features shape: {edge_attr.shape}')
        else:
            print('No edge features!')

        # construct messages -> aggregate messages -> update representations
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None)  # (N, NH, H_out)

        alpha = self._alpha  # (#edges, NH), written by `message`
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)  # (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # (N, H_out)

        if self.bias is not None:
            out = out + self.bias

        # Influence mechanism (inspection only): head-averaged attention times
        # edge features. Skipped without edge features; multiplying by None
        # used to raise here.
        if edge_attr is not None:
            influence = alpha.mean(dim=1, keepdim=True) * edge_attr
            print('inf: ', influence.shape)

        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            if self.concat:
                return out, (edge_index, alpha)
            return out, (edge_index, alpha.mean(dim=1, keepdim=True))
        # Always return the features; previously a bool flag combined with a
        # non-tensor edge_index fell through and returned None.
        return out

    def message(self, x_j, x_i, index, size_i):
        """Build the attention-weighted message for every edge.

        Args:
            x_j: Source node features per edge, ``(#edges, NH, H_out)``.
            x_i: Target node features per edge, ``(#edges, NH, H_out)``.
            index: Target node index of every edge; defines the softmax groups.
            size_i: Number of target nodes, passed to the softmax.

        Returns:
            Tensor of shape ``(#edges, NH, H_out)``.
        """
        # Element-wise sum of target and source features drives the attention score.
        x = x_i + x_j
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the edges that share a target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha  # stored before dropout
        # Randomly drop attention coefficients during training.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.unsqueeze(-1)  # (#edges, NH, H_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


In [12]:
ein = EIN(embed_size, 10, heads=8, concat=False)
# Call the module itself instead of `.forward()` so that any hooks registered
# on the module are run as part of the call.
res = ein(x, edge_index, edge_attr=x_edge, return_attention_weights=True)
# res is (out, (edge_index, attention_weights)).
out, r = res
print(out.shape)
print(r[1])
print(r[0])

edge_features shape: torch.Size([12, 5])
inf:  torch.Size([12, 5])
torch.Size([4, 10])
tensor([[0.5031],
        [0.3298],
        [0.3339],
        [0.3296],
        [0.2443],
        [0.2584],
        [0.2505],
        [0.3327],
        [0.2469],
        [0.4969],
        [0.3375],
        [0.3365]], grad_fn=<MeanBackward1>)
tensor([[0, 0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3],
        [1, 2, 3, 3, 0, 0, 0, 2, 0, 1, 2, 3]])


In [13]:
# r[1] holds one attention weight per edge, with the rows of the added
# self-loops at the end, while x_edge only covers the original edges. Slice by
# the number of original edges instead of a hard-coded 8 so the rows line up.
temp = r[1][:x_edge.size(0)] * x_edge
temp.shape

torch.Size([8, 5])

In [154]:
from torch_scatter import gather_csr, scatter, segment_csr


In [177]:
class EIN(MessagePassing):
    """Attention layer that mixes attention-scaled edge features into messages.

    For every edge the attention-weighted source features are averaged over
    the heads. If an edge projection is configured (``edge_dim`` given) and
    edge features are supplied, the edge features are scaled by the
    head-averaged attention weight, projected to ``out_channels`` and added
    to the message. Messages are summed per target node and the
    head-averaged target features, scaled by ``1 + eps``, are added on top.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector.
        heads: Number of attention heads (averaged, never concatenated).
        negative_slope: Negative slope of the leaky ReLU.
        dropout: Dropout probability applied to the attention coefficients
            while the module is in training mode.
        edge_dim: Size of each edge feature vector. ``None`` disables the
            edge-feature ("influence") branch.
        train_eps: If true, ``eps`` is a learnable parameter; otherwise it is
            registered as a buffer.
        eps: Initial value of ``eps``. With the default ``0.0`` the self
            features are added unscaled.
        bias: If true, a learnable bias is added to the output and the linear
            node projections use a bias as well.
        share_weights: If true, source and target nodes use the same linear
            projection.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        negative_slope=0.2,
        dropout=0.0,
        edge_dim=None,
        train_eps = False,
        eps = 0.0,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        # `aggr='add'` selects sum aggregation of the incoming messages.
        super().__init__(node_dim=0, aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.share_weights = share_weights
        self.edge_dim = edge_dim
        self.initial_eps = eps

        # Linear projections for source (lin_l) and target (lin_r) nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l  # reuse the very same module
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector, one per head.
        self.att = Parameter(torch.empty(1, heads, out_channels))

        # Influence mechanism: projects attention-scaled edge features to the
        # output size. Without edge_dim there is nothing to project, so the
        # branch is switched off (`message` checks for None).
        self.inf = Linear(edge_dim, out_channels) if edge_dim is not None else None

        # Scale factor for the self features added in `update`.
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Scratch slot: `message` stores the head-averaged attention weights
        # here so that `forward` can hand them back to the caller.
        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise all learnable parameters of the layer."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.inf is not None:
            self.inf.reset_parameters()
        self.eps.data.fill_(self.initial_eps)
        xavier_uniform_(self.att)
        # With bias=False there is no tensor to initialise.
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        """Run one round of message passing.

        Shape legend: N = #nodes, NH = #heads, H_in = in_channels,
        H_out = out_channels.

        Args:
            x: Node features, ``(N, H_in)``.
            edge_index: Graph connectivity. No self-loops are added here.
            edge_attr: Optional edge features with one row per edge.
            return_attention_weights: If this is a ``bool`` (either value)
                and ``edge_index`` is a tensor, the head-averaged attention
                weights are returned alongside the output.

        Returns:
            ``out`` of shape ``(N, H_out)``, or ``(out, alpha)`` with
            ``alpha`` of shape ``(#edges, 1)`` when attention weights are
            requested.
        """
        H, C = self.heads, self.out_channels

        # (N, H_in) -> (N, NH, H_out); x_l feeds source nodes, x_r target nodes.
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = x_l if self.share_weights else self.lin_r(x).view(-1, H, C)

        # Debug output: report the edge feature shape seen by the layer.
        if edge_attr is not None:
            print(f'edge_features shape: {edge_attr.shape}')
        else:
            print('No edge features!')

        # construct messages -> aggregate messages -> update representations
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None)  # (N, H_out)

        alpha = self._alpha  # (#edges, 1), written by `message`
        assert alpha is not None
        # Clear the scratch slot so no stale weights outlive this call.
        self._alpha = None

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool) and torch.is_tensor(edge_index):
            # alpha is already averaged over the heads.
            return out, alpha
        # Always return the features; previously a bool flag combined with a
        # non-tensor edge_index fell through and returned None.
        return out

    def message(self, x_j, x_i, index, size_i, edge_attr):
        """Build the message for every edge.

        Args:
            x_j: Source node features per edge, ``(#edges, NH, H_out)``.
            x_i: Target node features per edge, ``(#edges, NH, H_out)``.
            index: Target node index of every edge; defines the softmax groups.
            size_i: Number of target nodes, passed to the softmax.
            edge_attr: Edge features per edge, or ``None``.

        Returns:
            Tensor of shape ``(#edges, H_out)``.
        """
        # Element-wise sum of target and source features drives the attention score.
        x = x_i + x_j
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)  # (#edges, NH)
        # Sparse softmax: normalise over the edges that share a target node.
        alpha = softmax(alpha, index, num_nodes=size_i)
        self._alpha = alpha.mean(dim=1, keepdim=True)  # (#edges, 1), stored before dropout
        # Randomly drop attention coefficients during training.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        node_out = (x_j * alpha.unsqueeze(-1)).mean(dim=1)  # average over heads -> (#edges, H_out)

        # Debug output only: per-target mean of the node messages.
        tmp = scatter(node_out, index, dim=0, reduce='mean')
        print('tmp', tmp.shape)
        print(index, node_out.shape)

        # The influence branch needs both the projection and edge features.
        if self.inf is not None and edge_attr is not None:
            edge_out = self.inf(self._alpha * edge_attr)  # (#edges, H_out)
            return node_out + edge_out
        return node_out

    def update(self, aggr_out, x):
        """Add the head-averaged target node features to the aggregated messages.

        Args:
            aggr_out: Aggregated messages per node, ``(N, H_out)``.
            x: The ``(x_l, x_r)`` tuple handed to ``propagate``; ``x[1]`` has
                shape ``(N, NH, H_out)``.

        Returns:
            Tensor of shape ``(N, H_out)``.
        """
        print(aggr_out.shape)  # debug output
        # Scale the self features by (1 + eps); with the default eps=0 this is
        # a plain addition. Done out of place so the aggregated tensor is not
        # mutated.
        return aggr_out + (1 + self.eps) * x[1].mean(dim=1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


In [178]:
ein = EIN(embed_size, 32, heads=10, edge_dim=5, train_eps=False)
# Call the module itself instead of `.forward()` so that any hooks registered
# on the module are run as part of the call.
res = ein(x, edge_index, edge_attr=x_edge, return_attention_weights=True)
print('Returned results:')
print('Output dim:', res[0].shape)
print('Alpha weights: ', res[1].shape)

edge_features shape: torch.Size([8, 5])
tmp torch.Size([4, 32])
tensor([1, 2, 3, 3, 0, 0, 0, 2]) torch.Size([8, 32])
torch.Size([4, 32])
Returned results:
Output dim: torch.Size([4, 32])
Alpha weights:  torch.Size([8, 1])
