In [16]:
import einops
import torch
from torch import nn


embedding_dim=17
embedding_dim_V=21
sequence_length=13

Q= torch.rand([sequence_length,embedding_dim])
K= torch.rand([sequence_length,embedding_dim])
V= torch.rand([sequence_length,embedding_dim_V])

edge_index=torch.randint(0,sequence_length,(2,sequence_length*4))

senders,receivers=edge_index



In [17]:
senders[receivers==4]

tensor([11,  8,  9, 10,  4,  3,  7, 12])

In [63]:
import torch

def normalize_strength(strength,receivers,n_nodes,heads):
    """
    lets say we have a directed graph with N nodes and M edges.
    To represent each one i have 3 M-dimentional vectors which are cal `senders`, `receivers`
    and `strength`:
    The i-th element of the `senders` vector represents a node  that is directed towards the
    i-th element of the `receivers` vector. The strength of this connection is represented by
    the i-th element of the `strength` vector.

    This function normalizes the strength of each connection by dividing it by the sum of the
    strengths of all the connections that are directed towards the same node.

    Args:
        receivers (torch.Tensor): A vector of length M, where M is the number of edges in the
        strength (torch.Tensor): strength of each connection, (M,h) where M is the number of edges
            head is the number of heads
        N (int): number of nodes
        heads (int): number of heads

    Returns:
        torch.Tensor: strenght vector normalized by the sum of the strengths of all the
            connections that are directed towards the same node.
    """
    assert strength.dim()==2, "strength must be a 2-dimentional tensor (M,h) where head is the number of heads"
    assert type(n_nodes)==type(heads)==int, "n_nodes and heads must be integers"

    strengths_sum = torch.zeros([n_nodes,heads],device=strength.device)
    strengths_sum.index_add_(0, receivers, strength)

    return strength / strengths_sum[receivers]

# Example usage
sequence_length=3
heads=2

edge_index=torch.randint(0,sequence_length,(2,sequence_length*2))
senders,receivers=edge_index
strength = torch.rand([senders.shape[0],heads])


normalized_strength = normalize_strength(strength, receivers, sequence_length, heads)
print(senders)
print(receivers)
print(normalized_strength)

tensor([1, 0, 1, 2, 0, 2])
tensor([2, 0, 2, 0, 2, 1])
tensor([[0.2201, 0.0155],
        [0.5055, 0.6268],
        [0.2510, 0.8635],
        [0.4945, 0.3732],
        [0.5289, 0.1210],
        [1.0000, 1.0000]])


In [58]:
import einops

def attention_message(K,Q,V,receivers,senders):
    #Q: (N, h, dQ)
    #K: (N, h, dK)
    #V: (N, h, dV)
    #receivers: (M,)
    #senders: (M,)
    assert K.dim()==Q.dim()==V.dim()==3, "K,Q,V must be 3-dimentional tensors"
    assert K.shape[0]==Q.shape[0]==V.shape[0], "K,Q,V must have the same first dimension"
    assert K.shape[1]==Q.shape[1]==V.shape[1], "K,Q,V must have the same second dimension"
    assert K.shape[2]==Q.shape[2], "K,Q must have the same third dimension"

    assert receivers.dim()==senders.dim()==1, "receivers and senders must be 1-dimentional tensors"
    assert receivers.shape[0]==senders.shape[0], "receivers and senders must have the same length"

    N,h,d=K.shape    
    att=(Q[receivers]*K[senders]).sum(dim=-1) #TODO: add multi-head attention
    
    att=att*3/att.max()

    att = torch.exp(att)
    att = normalize_strength(att, receivers, N, h)

    att = einops.einsum(att,V[senders],' ... , ... c -> ... c')

    out=torch.zeros_like(V,device=V.device)

    return out.index_add_(0,receivers,att)

In [62]:
embedding_dim=17
embedding_dim_V=21
sequence_length=13
n_edges=133
heads=3

Q= torch.rand([sequence_length,heads,embedding_dim])
K= torch.rand([sequence_length,heads,embedding_dim])
V= torch.rand([sequence_length,heads,embedding_dim_V])

edge_index=torch.randint(0,sequence_length,(2,n_edges))

senders,receivers=edge_index

att=attention_message(K,Q,V,receivers,senders)

In [19]:
phi(K,Q,V,receivers,senders)

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

In [45]:
senders = torch.tensor([0, 1, 2, 2])
receivers = torch.tensor([1, 2, 1, 0])
strength = torch.tensor([[0.5, 0.2, 0.3, 0.7],[0.3, 0.2, 0.3, 0.7]]).t()

nodes=torch.zeros([3,2])
print(nodes[receivers].shape)
#nodes[:,receivers]+=strength
s
nodes=nodes.index_add(0,receivers,strength)

strength=strength/nodes[receivers]
nodes

torch.Size([4, 2])


tensor([[0.7000, 0.7000],
        [0.8000, 0.6000],
        [0.2000, 0.2000]])

In [30]:
x=torch.tensor([1,2,3,4,5,6,7,8,9,10])
x[:2]


tensor([1, 2])

In [11]:
senders = torch.tensor([0, 1, 2, 2])
receivers = torch.tensor([1, 2, 1, 0])
strength = torch.tensor([0.5, 0.2, 0.3, 0.7])

nodes=torch.zeros(3)

nodes[receivers]+=strength
nodes

tensor([0.7000, 0.3000, 0.2000])

In [79]:
import torch


n_nodes=74
n_edges=563
heads=6
receivers=torch.randint(0,n_nodes,(n_edges,))
strength = torch.rand([n_edges,heads])

strength=normalize_strength(strength,receivers,n_nodes,heads)

nodes=torch.zeros([n_nodes,heads])
nodes=nodes.index_add(0,receivers,strength)
torch.allclose(nodes,torch.ones_like(nodes))


RuntimeError: index_add_(): self (Double) and source (Float) must have the same scalar type