In [16]:
import einops
import torch
from torch import nn


embedding_dim=17
embedding_dim_V=21
sequence_length=13

Q= torch.rand([sequence_length,embedding_dim])
K= torch.rand([sequence_length,embedding_dim])
V= torch.rand([sequence_length,embedding_dim_V])

edge_index=torch.randint(0,sequence_length,(2,sequence_length*4))

senders,receivers=edge_index



In [17]:
senders[receivers==4]

tensor([11,  8,  9, 10,  4,  3,  7, 12])

In [41]:
import torch

def normalize_strength(strength,receivers,shape):
    """
    lets say we have a directed graph with N nodes and M edges.
    To represent each one i have 3 M-dimentional vectors which are cal `senders`, `receivers`
    and `strength`:
    The i-th element of the `senders` vector represents a node  that is directed towards the
    i-th element of the `receivers` vector. The strength of this connection is represented by
    the i-th element of the `strength` vector.

    This function normalizes the strength of each connection by dividing it by the sum of the
    strengths of all the connections that are directed towards the same node.

    Args:
        receivers (torch.Tensor): A vector of length M, where M is the number of edges in the
        strength (torch.Tensor): strength of each connection, must be a 2-dimentional tensor (h,M) where
            head is the number of heads
        shape (torch.Tensor): shape of the output tensor, must be a 2-dimentional tensor (h,N) where
            head is the number of heads and N is the number of nodes 

    Returns:
        torch.Tensor: strenght vector normalized by the sum of the strengths of all the
            connections that are directed towards the same node.
    """
    assert strength.dim()==2, "strength must be a 2-dimentional tensor (h,M) where head is the number of heads"
    assert len(shape)==2, "shape must be a 2-dimentional tensor (h,N) where head is the number of heads and N is the number of nodes"

    strengths_sum = torch.zeros(shape)
    strengths_sum.index_add_(1, receivers, strength)

    return strength / strengths_sum[:,receivers]

# Example usage
sequence_length=3

edge_index=torch.randint(0,sequence_length,(2,sequence_length*2))
senders,receivers=edge_index
strength = torch.rand([4,senders.shape[0]])


normalized_strength = normalize_strength(strength, receivers, [4,sequence_length])

print(normalized_strength)

tensor([[0.4614, 1.0000, 0.5386, 0.2418, 0.0557, 0.7025],
        [0.8225, 1.0000, 0.1775, 0.0960, 0.6821, 0.2219],
        [0.6421, 1.0000, 0.3579, 0.3533, 0.3917, 0.2550],
        [0.1475, 1.0000, 0.8525, 0.0511, 0.2874, 0.6615]])


In [73]:
def attention_message(K,Q,V,receivers,senders):
    #Q: (N, h, dQ)
    #K: (N, h, dK)
    #V: (N, h, dV)
    #receivers: (M,)
    #senders: (M,)
    assert K.dim()==Q.dim()==V.dim()==3, "K,Q,V must be 3-dimentional tensors"
    assert K.shape[0]==Q.shape[0]==V.shape[0], "K,Q,V must have the same first dimension"
    assert K.shape[1]==Q.shape[1]==V.shape[1], "K,Q,V must have the same second dimension"
    assert K.shape[2]==Q.shape[2], "K,Q must have the same third dimension"

    assert receivers.dim()==senders.dim()==1, "receivers and senders must be 1-dimentional tensors"
    assert receivers.shape[0]==senders.shape[0], "receivers and senders must have the same length"

    N,h,d=K.shape    
    att=(Q[receivers]*K[senders]).sum(dim=-1) #TODO: add multi-head attention
    #att: (M,h)
    att = torch.exp(att)
    att = normalize_strength(att, receivers,[N,h])

    att = einops.einsum(att,V[senders],' ... , ... c -> ... c')

    out=torch.zeros_like(V)

    return out.index_add_(-2,receivers,att)

In [19]:
phi(K,Q,V,receivers,senders)

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

In [25]:
senders = torch.tensor([0, 1, 2, 2])
receivers = torch.tensor([1, 2, 1, 0])
strength = torch.tensor([[0.5, 0.2, 0.3, 0.7],[0.3, 0.2, 0.3, 0.7]])

nodes=torch.zeros([2,3])
print(nodes[:,receivers].shape)
#nodes[:,receivers]+=strength

nodes=nodes.index_add(1,receivers,strength)

strength=strength/nodes[:,receivers]
strength

torch.Size([2, 4])


tensor([[0.6250, 1.0000, 0.3750, 1.0000],
        [0.5000, 1.0000, 0.5000, 1.0000]])

In [30]:
x=torch.tensor([1,2,3,4,5,6,7,8,9,10])
x[:2]


tensor([1, 2])

In [11]:
senders = torch.tensor([0, 1, 2, 2])
receivers = torch.tensor([1, 2, 1, 0])
strength = torch.tensor([0.5, 0.2, 0.3, 0.7])

nodes=torch.zeros(3)

nodes[receivers]+=strength
nodes

tensor([0.7000, 0.3000, 0.2000])

In [25]:
nodes

tensor([0., 0.])