In [2]:
import torch
from torch import nn

# Every stochastic step below (nn.Linear / Attention initialisation, torch.randn,
# torch.randperm) draws from torch's global RNG, so seed once up front to make
# "Restart Kernel -> Run All" reproducible.
SEED = 0
torch.manual_seed(SEED)

# float64 so the permutation-equivariance checks report differences near machine
# epsilon rather than float32 round-off noise.
torch.set_default_dtype(torch.float64)

# in_features = 16 must equal 2 * dim: each message is node i's features
# concatenated with node j's (masked) features, see the torch.cat cell.
# TODO: derive from `dim` once the config cell is moved above this one.
linear = nn.Linear(16, 16)


In [3]:
# Synthetic batch of graphs: one `dim`-dimensional feature vector per node.
batch_size, num_nodes, dim = 128, 16, 8
graph_features = torch.randn(batch_size, num_nodes, dim)

# A single random relabelling of the nodes, applied identically to every graph
# in the batch (gathers along the node axis).
perm = torch.randperm(num_nodes)
graph_features_permuted = graph_features.index_select(1, perm)

In [4]:
# Repeat each graph's full node table once per node i (expand() makes a view, no copy):
# result has shape (batch, num_nodes, num_nodes, dim) and [b, i, j] holds node j's features.
graph_features_expanded = graph_features[:, None].expand(-1, num_nodes, -1, -1)
graph_features_permuted_expanded = graph_features_permuted[:, None].expand(-1, num_nodes, -1, -1)

In [5]:
# Boolean identity broadcast over the feature axis: shape (num_nodes, num_nodes, dim),
# True exactly where i == j.
I = torch.eye(num_nodes, dtype=torch.bool)[..., None].expand(-1, -1, dim)

# Multiply by the off-diagonal mask (~I) so node i's row excludes its own features.
graph_features_masked = graph_features_expanded * ~I
graph_features_permuted_masked = graph_features_permuted_expanded * ~I

# [b, i, j] holds node i's own features, repeated for every j.
node_features = graph_features.unsqueeze(2).expand(-1, -1, num_nodes, -1)
node_features_permuted = graph_features_permuted.unsqueeze(2).expand(-1, -1, num_nodes, -1)

In [6]:
# One message per (i, j) pair: node i's features followed by node j's masked features,
# so the last axis grows to 2 * dim (= linear.in_features).
# NOTE(review): `input` shadows the Python builtin; kept because later cells use it.
input = torch.cat((node_features, graph_features_masked), -1)
input_permuted = torch.cat((node_features_permuted, graph_features_permuted_masked), -1)

In [7]:
# Shared linear map on every (i, j) message, then sum-aggregate over j (axis 2):
# (batch, num_nodes, num_nodes, 2*dim) -> (batch, num_nodes, out_features).
output = linear(input).sum(dim=2)
output_permuted = linear(input_permuted).sum(dim=2)

In [8]:
# Permutation equivariance: permuting the nodes before the layer must equal
# permuting the output rows afterwards. Assert it so "Run All" fails loudly on a
# regression, and still display the max absolute error.
torch.testing.assert_close(output_permuted, output[:, perm, :])
(output[:, perm, :] - output_permuted).abs().max()

tensor(1.0658e-14, grad_fn=<MaxBackward1>)

In [9]:
# Expected (batch_size, num_nodes, linear.out_features).
output_permuted.size()

torch.Size([128, 16, 16])

In [10]:
class Attention(nn.Module):
    """Multi-head attention with optional pre-normalisation of queries and keys/values.

    Wraps ``nn.MultiheadAttention`` with ``batch_first=True``. Queries always come
    from ``h``. Keys and values come from ``h`` alone, or from ``h`` concatenated
    with an extra context ``c`` along dim 1.

    Args:
        dim: embedding dimension (``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``).
        num_heads: number of attention heads.
        q_norm: if True, apply ``norm_layer(dim)`` to the queries; else identity.
        kv_norm: if True, apply ``norm_layer(dim)`` to the keys/values; else identity.
        dropout: attention dropout probability, passed to ``nn.MultiheadAttention``.
        add_bias_kv: passed through to ``nn.MultiheadAttention``.
        norm_layer: factory called as ``norm_layer(dim)``; defaults to ``nn.LayerNorm``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        q_norm: bool = False,
        kv_norm: bool = False,
        dropout: float = 0.0,
        add_bias_kv: bool = False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.q_norm = norm_layer(dim) if q_norm else nn.Identity()
        self.kv_norm = norm_layer(dim) if kv_norm else nn.Identity()

        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, add_bias_kv=add_bias_kv, batch_first=True
        )

    def forward(self, h, c=None):
        """Attend from ``h`` to ``h`` (plus ``c`` when given).

        Args:
            h: query tokens; presumably ``(batch, seq, dim)`` given ``batch_first=True``.
            c: optional extra context tokens, appended after ``h`` along dim 1 to
                form the keys/values.

        Returns:
            The attention output (same shape as ``h``). Attention weights are discarded.
        """
        context = h if c is None else torch.cat((h, c), dim=1)

        q = self.q_norm(h)
        # Keys and values are the same normalised tensor: compute it once instead
        # of running kv_norm twice on identical input.
        kv = self.kv_norm(context)

        return self.attn(q, kv, kv)[0]

In [13]:
# Embedding size must match the node feature size from the data cell.
attn = Attention(dim, 4, q_norm=True, kv_norm=True)

# NOTE: overwrites `output` / `output_permuted` from the linear-layer experiment.
output = attn(graph_features)
output_permuted = attn(graph_features_permuted)

# Attention with no positional information should be permutation-equivariant:
# permuting nodes before the layer must match permuting the output afterwards.
torch.testing.assert_close(output_permuted, output[:, perm, :])
(output[:, perm, :] - output_permuted).abs().max()



tensor(1.6653e-16, grad_fn=<MaxBackward1>)