In [1]:
import torch
from torch.nn import Linear, Parameter, Module
from torch_scatter.scatter import scatter_sum
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GATConv
from torch.nn import Linear
import numpy as np

In [35]:
class MultiHeadGraphAttention(Module):
    
    def __init__(self, in_features, out_features, num_heads=1, *args, **kwargs) -> None:
        super(MultiHeadGraphAttention, self).__init__(*args, **kwargs)
        
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        
        self.head_in_features = in_features//num_heads
        self.head_out_features = out_features//num_heads
        
        self.query_w, self.key_w, self.agg_value_w, self.update_value_w = self.create_weights()
        
        
    def forward(self, x, edge_index):
        if edge_index.shape[0] != self.num_heads:
            raise RuntimeError("The first dimension of edge_index should be equal to the number of heads.")
        x1_h = x.view(self.num_heads, x.shape[0], self.in_features//self.num_heads)
        
        batch_indices = torch.arange(self.num_heads).unsqueeze(1).unsqueeze(2).repeat(1, edge_index.shape[1], edge_index.shape[2]) 
        source_target_features = x1_h[batch_indices, edge_index]
        
        query = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, self.query_w), dim=1)
        key = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, self.key_w), dim=1)
        a = torch.einsum('ijk,ijk->ij', query, key)
        attention = torch.softmax(a, dim=1) / torch.sqrt(torch.tensor(self.head_out_features))
        
        value_update = torch.einsum('ikl,ilm->ikm', x1_h, self.update_value_w)
        value_aggregate = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, self.agg_value_w), dim=1)
        value_aggregate = torch.einsum('ij, ijk->ijk', attention, value_aggregate)
        agg_sum = torch.zeros_like(value_update, device=value_update.device)
        scatter_sum(value_aggregate, index=edge_index.permute(0, 2, 1)[:, :, :1], dim=1, out=agg_sum)
        scatter_sum(value_aggregate, index=edge_index.permute(0, 2, 1)[:, :, 1:], dim=1, out=agg_sum)
        out = value_update + agg_sum
        print(out.shape, x.shape[0])
        return out.permute(1, 0, 2).reshape(x.shape[0], -1)
        
    def create_weights(self):
        query_w = torch.randn((self.num_heads, self.head_in_features, self.head_out_features))
        key_w = torch.randn((self.num_heads, self.head_in_features, self.head_out_features))
        agg_value_w = torch.randn((self.num_heads, self.head_in_features, self.head_out_features))
        update_value_w = torch.randn((self.num_heads, self.head_in_features, self.head_out_features))
        return Parameter(query_w), Parameter(key_w), Parameter(agg_value_w), Parameter(update_value_w)
        
    

In [36]:
num_heads = 4
in_features = 64
out_features = 128
MHGAT = MultiHeadGraphAttention(in_features, out_features, num_heads)

In [5]:
torch.ones((4, 251, 32)).permute(1, 0, 2).reshape(251,-1).shape

torch.Size([251, 128])

In [32]:
edge_index = torch.randint(0, 250, (100, ), dtype=torch.long)
random_values = torch.randn((100,))
base_tensor = torch.zeros((250, ))
print(base_tensor[:10])
scatter_sum(random_values, edge_index, out=base_tensor)
print(base_tensor[:10])
scatter_sum(random_values, edge_index, out=base_tensor)
print(base_tensor[:10])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([ 0.0000, -0.3529, -0.7422,  1.8477,  0.0000, -0.0184,  0.0000, -0.9262,
         0.0000,  0.0000])
tensor([ 0.0000, -0.7059, -1.4844,  3.6954,  0.0000, -0.0367,  0.0000, -1.8524,
         0.0000,  0.0000])


In [37]:
x1 = torch.randn((251, in_features))
edge_index = torch.randint(0, 220, (num_heads, 2, 2600), dtype=torch.long)
out_x = MHGAT(x1, edge_index)

torch.Size([4, 251, 32]) 251


In [6]:



x1_h = x1.view(num_heads, x1.shape[0], in_features//num_heads)
# x1_h = x1_h.permute( 0, 1,2)

In [7]:
batch_indices = torch.arange(num_heads).unsqueeze(1).unsqueeze(2).repeat(1, edge_index.shape[1], edge_index.shape[2]) 
source_target_features = x1_h[batch_indices, edge_index]
source_target_features[0] = source_target_features[0]*0 + 1 
source_target_features[1] = source_target_features[1]*0 + 2 
source_target_features[2] = source_target_features[2]*0 + 3 
source_target_features[3] = source_target_features[3]*0 + 4
source_target_features.shape

torch.Size([4, 2, 2600, 16])

In [8]:
weights1 = torch.randn((num_heads, in_features//num_heads, out_features//num_heads))
weights2 = torch.randn((num_heads, in_features//num_heads, out_features//num_heads))
weights1[0] = weights1[0]*0 + 1 
weights1[1] = weights1[1]*0 + 2 
weights1[2] = weights1[2]*0 + 3 
weights1[3] = weights1[3]*0 + 4
weights1.shape

torch.Size([4, 16, 32])

In [9]:
print(x1_h.shape)
print(source_target_features.shape)

torch.Size([4, 250, 16])
torch.Size([4, 2, 2600, 16])


In [10]:
query = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, weights1), dim=1)
key = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, weights2), dim=1)
a = torch.einsum('ijk,ijk->ij', query, key)
attention = torch.softmax(a, dim=1) / torch.sqrt(torch.tensor(out_features//num_heads))
attention.shape

torch.Size([4, 2600])

In [32]:
value_aggregate = torch.sum(torch.einsum('ijkl,ilm->ijkm', source_target_features, weights2), dim=1)
value_aggregate = torch.einsum('ij, ijk->ijk', attention, value_aggregate)
value_aggregate = scatter_sum(value_aggregate, index=edge_index.permute(0, 2, 1)[:, :, :1], dim=1)+ scatter_sum(value_aggregate, index=edge_index.permute(0, 2, 1)[:, :, 1:], dim=1)
value_update = torch.einsum('ikl,ilm->ikm', x1_h, weights2)
value_update.shape, value_aggregate.shape

(torch.Size([4, 250, 32]), torch.Size([4, 250, 32]))

In [31]:
edge_index[:, :1].shape

torch.Size([4, 1, 2600])

In [25]:
edge_index.shape

torch.Size([4, 2, 2600])

In [29]:
()

torch.Size([4, 250, 32])

In [12]:
value_aggregate.shape

torch.Size([4, 250, 32])

In [206]:
attention = torch.ones((4, 2600))
attention[1]*=torch.linspace(1, 2, 2600)
attention[2]*=3
attention[3]*=4
value_aggregate = torch.ones((4, 2600, 32))

In [214]:
value_aggregate[1,2]

tensor([1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008,
        1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008,
        1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008, 1.0008,
        1.0008, 1.0008, 1.0008, 1.0008, 1.0008])

In [168]:
value_update.shape

torch.Size([4, 250, 32])

In [169]:
value_aggregate.shape

torch.Size([4, 2600, 32])

In [183]:
x1_h.shape, edge_index[:, 0].unsqueeze(2).shape

(torch.Size([4, 250, 16]), torch.Size([4, 2600, 1]))

In [189]:

# x1_h.scatter_reduce(1, edge_index[:, 0].unsqueeze(2), value_update, reduce='sum')

torch.Size([4, 250, 32])

In [171]:
edge_index.shape

torch.Size([4, 2, 2600])

In [163]:
source_target_features.shape

torch.Size([4, 2, 2600, 16])

In [138]:
value_aggregate.shape

torch.Size([4, 26, 32])

In [None]:
torch.einsum('ijkl,ijkl->ijkm', source_target_features, weights2)

In [115]:
query.shape

torch.Size([4, 26, 32])

In [None]:
torch.index_select(x1)

In [17]:
messages = torch.einsum('ijkl,ilm->ijkm', source_target_features, weights1)
messages = F.leaky_relu_(messages[0] + messages[1])
messages = torch.exp(torch.einsum('', w_a, messages))
alpha = messages/torch.sum(messages, dim=(1,2))

x = torch.

torch.Size([26, 4, 4, 16])

In [19]:
sample_layer = GATv2Conv(4, 3, edge_dim=2, heads=2, add_self_loops=False)
x = torch.ones((6, 4))
edge_index = torch.tensor([[0, 0, 1, 1, 2],[2, 3, 4, 5, 5]], dtype=torch.long).unsqueeze(1)
edge_index = torch.cat([edge_index, edge_index], dim=1)
print(edge_index.shape)
edge_attr = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],[0.15, 0.25, 0.35, 0.45, 0.55]]).T
sample_layer(x, edge_index, edge_attr=edge_attr)

torch.Size([2, 2, 5])
1: torch.Size([5, 2]), torch.Size([2, 2, 5])
2: torch.Size([5, 2]), torch.Size([2, 2, 5])


ValueError: Expected 'edge_index' to be two-dimensional (got 3 dimensions)

In [17]:
edge_index

tensor([[  0,   0,  80,  80, 160],
        [160, 240, 320, 400, 400]])