In [3]:
%load_ext autoreload

%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import rand_rotation_tensor, test_close

# Hyperparameters swept in the VN-Transformer paper (sets hold the candidate values).
# NOTE: in Python `^` is bitwise XOR, not exponentiation (`10^-3` evaluates to -9),
# so powers of ten are written as float literals (`1e-3`).
feature_dimension = {32, 64, 128, 256, 512, 1024}
number_of_attention_heads = {4, 8, 16, 32, 64, 128}
hidden_layer_dimension_in_encoder_VN_MLP = {32, 64, 128, 256, 512}
learning_rate = 1e-3
learning_rate_schedule = "linear_decay"
optimizer = "Adam"
epochs = 4000
epsilon_VN_With_Bias = {0, 1e-6}


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Rotation-equivariant Attention

Consider two tensors $Q \in \mathbb{R}^{M \times C \times 3}$ and $K \in \mathbb{R}^{N \times C \times 3}$, which can be thought of as sets of $M$ and $N$ tokens, respectively, each token being a $C \times 3$ matrix. Using the Frobenius inner product, we define an attention matrix $A(Q, K) \in \mathbb{R}^{M \times N}$ between the two sets, row by row, as follows:

$$
A(Q, K)^{(m)} \triangleq \operatorname{softmax}\left(\frac{1}{\sqrt{3 C}}\left[\left\langle Q^{(m)}, K^{(n)}\right\rangle_F\right]_{n=1}^N\right)
$$

Following [Vaswani et al.](https://arxiv.org/abs/1706.03762), we divide the inner products by $\sqrt{3C}$: since $Q^{(m)}, K^{(n)} \in \mathbb{R}^{C \times 3}$, each Frobenius inner product is a sum of $3C$ elementwise products.

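For any $R \in SO(3)$ we have $R R^\top = I$, so the Frobenius inner product is unchanged when both arguments are rotated by the same $R$: $\langle Q^{(m)} R, K^{(n)} R \rangle_F = \operatorname{tr}\left(Q^{(m)} R R^\top K^{(n)\top}\right) = \langle Q^{(m)}, K^{(n)} \rangle_F$. Therefore: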

$$
A(QR, KR) = A(Q,K) \ \ \forall R \in SO(3)
$$

i.e. the attention matrix is invariant to rotating both $Q$ and $K$ by the same rotation $R$.

In [4]:
# Random token sets: M = 8 query tokens and N = 5 key tokens, each a (C = 7) x 3 matrix.
m, n, c = 8, 5, 7
Q, K = torch.randn(m, c, 3), torch.randn(n, c, 3)

# Random rotation from utils, used to check that A is rotation-invariant.
R = rand_rotation_tensor()

def A(Q, K):
    """
    Rotation-invariant attention matrix between two sets of vector-neuron tokens.

    Args:
        Q: tensor of shape (M, C, 3) -- M query tokens, each a C x 3 matrix.
        K: tensor of shape (N, C, 3) -- N key tokens, each a C x 3 matrix.

    Returns:
        Tensor of shape (M, N) whose row m is
        softmax_n(<Q^(m), K^(n)>_F / sqrt(3C)); each row sums to 1.
    """
    C = Q.shape[1]
    # Contract both the channel (c) and spatial (k) axes in a single einsum, which
    # yields the Frobenius inner products directly as an (M, N) matrix without
    # materialising an (M, N, C, 3) intermediate first.
    scores = torch.einsum("mck,nck->mn", Q, K) / (3 * C) ** 0.5
    return F.softmax(scores, dim=-1)

# Each entry: (label printed, left-hand side, right-hand side).
# The first row is a baseline (identical inputs); the second checks rotation invariance.
checks = [
    ("A(Q, K) = A(Q, K)", A(Q, K), A(Q, K)),
    ("A(QR, KR) = A(Q, K)", A(Q @ R, K @ R), A(Q, K)),
]
for label, lhs, rhs in checks:
    res, eps = test_close(lhs, rhs)
    print(f'{label} is {res} with eps {eps:.0e}')


A(Q, K) = A(Q, K) is True with eps 1e-15
A(QR, KR) = A(Q, K) is True with eps 1e-06


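Next, we define the operation VN_Attention, which uses $A(Q, K)$ to take weighted sums of a third set of tokens $Z \in \mathbb{R}^{N \times C' \times 3}$: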

$$
\mathrm{VN}\_\operatorname{Attention}(Q, K, Z)^{(m)} \triangleq \sum_{n=1}^N A(Q, K)^{(m, n)} Z^{(n)}
$$

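where $\mathrm{VN}\_\operatorname{Attention} : \mathbb{R}^{M \times C \times 3} \times \mathbb{R}^{N \times C \times 3} \times \mathbb{R}^{N \times C' \times 3} \rightarrow \mathbb{R}^{M \times C' \times 3}$. Because the weights $A(Q, K)$ are rotation-invariant and the sum is linear in $Z$, rotating $Q$, $K$ and $Z$ by the same $R$ rotates the output by $R$, so the operation is rotation-equivariant.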

In [154]:
def VN_Attention(Q, K, Z):
    """
    Vector-neuron attention: weight the value tokens Z by rotation-invariant
    attention scores between the query tokens Q and key tokens K.

    Args:
        Q: tensor of shape (M, C, 3)  -- query tokens.
        K: tensor of shape (N, C, 3)  -- key tokens.
        Z: tensor of shape (N, C', 3) -- value tokens.

    Returns:
        Tensor of shape (M, C', 3). Rotation-equivariant with respect to rotating
        all inputs: VN_Attention(QR, KR, ZR) = VN_Attention(Q, K, Z) R.
    """
    C = Q.shape[1]

    # Frobenius inner product <Q^(m), K^(n)>_F for every (m, n) pair, contracted
    # straight to (M, N) rather than summing an (M, N, C, 3) intermediate.
    # It is unchanged when Q and K are rotated together, so the scores are invariant.
    scores = torch.einsum('mck,nck->mn', Q, K) / (3 * C) ** 0.5
    scores = F.softmax(scores, dim=-1)

    # Each output token is a linear combination of value tokens with invariant
    # weights, so rotating Z rotates the output: this step is equivariant.
    return torch.einsum('mn,nck->mck', scores, Z)

# Inputs: M = 8 queries, N = 5 keys/values, C = 7 query/key channels, C' = 12 value channels.
m, n, c, c_prime = 8, 5, 7, 12
Q, K, Z = torch.rand(m, c, 3), torch.rand(n, c, 3), torch.rand(n, c_prime, 3)
R = rand_rotation_tensor()  # random rotation, used to check rotation-equivariance

rotated_out = VN_Attention(Q @ R, K @ R, Z @ R)

# Rotating the inputs does change the output (so the layer is not merely invariant) ...
res, _ = test_close(rotated_out, VN_Attention(Q, K, Z))
print(f'VN_Attention(QR, KR, ZR) != VN_Attention(Q, K, Z) is {not res}')

# ... and it changes it by exactly R: VN_Attention(QR, KR, ZR) = VN_Attention(Q, K, Z) R
res, eps = test_close(rotated_out, VN_Attention(Q, K, Z) @ R)
print(f'VN_Attention(QR, KR, ZR) = VN_Attention(Q, K, Z) R is {res} with eps {eps:.0e}')

VN_Attention(QR, KR, ZR) != VN_Attention(Q, K, Z) is True
VN_Attention(QR, KR, ZR) = VN_Attention(Q, K, Z) R is True with eps 1e-06


This is extendable to multi-head attention with $H$ heads, $\mathrm{VN}\_\operatorname{MultiHeadAttention} : \mathbb{R}^{M \times C \times 3} \times \mathbb{R}^{N \times C \times 3} \times \mathbb{R}^{N \times C' \times 3} \rightarrow \mathbb{R}^{M \times C' \times 3} $ 

$$
\mathrm{VN}\_\operatorname{MultiHeadAttention}(Q, K, Z) \triangleq W^{O} \left [ \mathrm{VN}\_\operatorname{Attention}(W^{Q}_{h}Q, W^{K}_{h}K, W^{Z}_{h}Z) \right ]^{H}_{h = 1}
$$

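where $W^{Q}_{h}, W^{K}_{h} \in \mathbb{R}^{P \times C}$ and $W^{Z}_{h} \in \mathbb{R}^{P \times C'}$ project each head's inputs to $P$ channels. The $H$ head outputs are concatenated along the channel axis into an $HP \times 3$ matrix per token, and $W^{O} \in \mathbb{R}^{C' \times HP}$ maps them back to $C'$ channels. $H$ and $P$ are chosen such that $HP = C'$.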

In [155]:
class VN_AttentionHead(nn.Module):
    """
    One head of vector-neuron multi-head attention.

    Queries/keys are projected from C to P channels and values from C' to P
    channels by learnable matrices applied on the channel axis (the left). Such a
    projection commutes with a rotation R applied on the spatial axis (the right),
    so the head stays rotation-equivariant.

    Args:
        p: channels per head (P).
        c: channels of the query/key tokens (C).
        c_prime: channels of the value tokens (C').
    """

    def __init__(self, p, c, c_prime):
        super().__init__()
        # nn.Parameter registers the weights with the module, so they appear in
        # .parameters(), receive gradients and follow .to(device) / state_dict().
        # The uniform [0, 1) initialisation of the prototype is kept.
        self.W_q = nn.Parameter(torch.rand(p, c))
        self.W_k = nn.Parameter(torch.rand(p, c))
        self.W_z = nn.Parameter(torch.rand(p, c_prime))

    def forward(self, q, k, z):
        """
        q: (M, C, 3), k: (N, C, 3), z: (N, C', 3)  ->  returns (M, P, 3).
        """
        # Contract the input channel axis c and keep the projected axis p:
        # (P, C) x (M, C, 3) -> (M, P, 3).
        q = torch.einsum('pc,mck->mpk', self.W_q, q)
        k = torch.einsum('pc,nck->npk', self.W_k, k)
        z = torch.einsum('pc,nck->npk', self.W_z, z)

        return VN_Attention(q, k, z)

class VN_MultiHeadAttention(nn.Module):
    """
    Vector-neuron multi-head attention.

    Runs `num_heads` independent VN_AttentionHead's, concatenates their outputs
    along the channel axis (H * P channels) and mixes them back to C' channels
    with W_o, which again acts only on the channel axis and so preserves
    rotation-equivariance.

    Args:
        num_heads: number of heads (H).
        p: channels per head (P); typically H * P == c_prime.
        c: channels of the query/key tokens (C).
        c_prime: channels of the value tokens and of the output (C').
    """

    def __init__(self, num_heads, p, c, c_prime):
        super().__init__()
        # Registered as a parameter (see VN_AttentionHead); shape (C', H * P).
        self.W_o = nn.Parameter(torch.rand(c_prime, num_heads * p))

        self.heads = nn.ModuleList(
            [VN_AttentionHead(p, c, c_prime) for _ in range(num_heads)]
        )

    def forward(self, q, k, z):
        """
        q: (M, C, 3), k: (N, C, 3), z: (N, C', 3)  ->  returns (M, C', 3).
        """
        # (M, H * P, 3): head outputs stacked along the channel axis.
        heads_out = torch.cat([head(q, k, z) for head in self.heads], dim=1)
        # Contract the H * P concatenated channels with W_o:
        # (C', H * P) x (M, H * P, 3) -> (M, C', 3).
        return torch.einsum('cp,mpk->mck', self.W_o, heads_out)

In [159]:
# Inputs for a single head: C = 7 query/key channels, C' = 12 value channels, P = 2 head channels.
m, n, c, c_prime = 8, 5, 7, 12
p = 2

Q, K, Z = torch.rand(m, c, 3), torch.rand(n, c, 3), torch.rand(n, c_prime, 3)

R = rand_rotation_tensor()  # random rotation, used to check rotation-equivariance

VN_Att_H = VN_AttentionHead(p, c, c_prime)
rotated_out = VN_Att_H(Q @ R, K @ R, Z @ R)

# Rotating the inputs changes the head's output ...
res, _ = test_close(rotated_out, VN_Att_H(Q, K, Z))
print(f'VN_Att_H(QR, KR, ZR) != VN_Att_H(Q, K, Z) is {not res}')

# ... by exactly R: VN_Att_H(QR, KR, ZR) = VN_Att_H(Q, K, Z) R
res, eps = test_close(rotated_out, VN_Att_H(Q, K, Z) @ R)
print(f'VN_Att_H(QR, KR, ZR) = VN_Att_H(Q, K, Z) R is {res} with eps {eps:.0e}')

VN_Att_H(QR, KR, ZR) != VN_Att_H(Q, K, Z) is True
VN_Att_H(QR, KR, ZR) = VN_Att_H(Q, K, Z) R is True with eps 1e-06


In [171]:
# Multi-head attention with H = C' // P heads of P = 2 channels each (H * P = C' = 12).
m, n, c, c_prime = 8, 5, 7, 12
p = 2
h = c_prime // p

Q, K, Z = torch.rand(m, c, 3), torch.rand(n, c, 3), torch.rand(n, c_prime, 3)

R = rand_rotation_tensor()  # random rotation, used to check rotation-equivariance

VN_MHA = VN_MultiHeadAttention(h, p, c, c_prime)
rotated_out = VN_MHA(Q @ R, K @ R, Z @ R)

# Rotating the inputs changes the output ...
res, _ = test_close(rotated_out, VN_MHA(Q, K, Z))
print(f'VN_MHA(QR, KR, ZR) != VN_MHA(Q, K, Z) is {not res}')

# ... by exactly R: VN_MHA(QR, KR, ZR) = VN_MHA(Q, K, Z) R
res, eps = test_close(rotated_out, VN_MHA(Q, K, Z) @ R)
# NOTE(review): compare this eps with the single-head check above; a much looser
# tolerance points to accumulated float error or a wrong channel contraction.
print(f'VN_MHA(QR, KR, ZR) = VN_MHA(Q, K, Z) R is {res} with eps {eps:.0e}')

VN_MHA(QR, KR, ZR) != VN_MHA(Q, K, Z) is True
VN_MHA(QR, KR, ZR) = VN_MHA(Q, K, Z) R is True with eps 1e-04


### Rotation-Equivariant Layer Normalisation


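$$
\mathrm{VN}\_\operatorname{LayerNorm}(V^{(n)}) \triangleq \left[ \frac{V^{(n, c)}}{\left\| V^{(n, c)} \right\|_{2}} \right]_{c=1}^{C} \odot \operatorname{LayerNorm}\left( \left[ \left\| V^{(n, c)} \right\|_{2} \right]_{c=1}^{C} \right) \mathbf{1}_{1 \times 3}
$$

Each channel vector is split into its direction and its norm. The norms, which are rotation-invariant, are layer-normalised, and the result rescales the directions, which rotate with the input. The layer is therefore rotation-equivariant.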


In [7]:
class VNLayerNorm(nn.Module):
    """
    Rotation-equivariant layer normalisation for vector-neuron features.

    Every channel vector V^(c) (a length-3 row) is split into its direction
    V^(c) / ||V^(c)|| and its norm ||V^(c)||. A standard nn.LayerNorm is applied
    to the C norms of each token, and the directions are rescaled by the result.
    The norms are rotation-invariant and the directions rotate with the input, so
    layernorm(V R) = layernorm(V) R.

    Args:
        in_channels: number of channels C (size of the second-to-last axis).
        eps: lower bound applied to the norms used as divisors, so a zero-length
            channel vector maps to 0 instead of NaN. Only norms below `eps`
            are affected. Defaults to 1e-6.
    """

    def __init__(self, in_channels, eps=1e-6):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_channels)
        self.eps = eps

    def forward(self, x):
        '''
        x: tensor of shape (..., C, 3), e.g. (B, C, 3). Returns a tensor of the same shape.
        '''
        row_wise_norm = torch.norm(x, dim=-1)  # (..., C)
        # Unit directions; the clamp guards the division against zero-norm rows.
        directions = x / row_wise_norm.clamp(min=self.eps).unsqueeze(-1)
        # LayerNorm over the C (unclamped) norms of each token.
        ln = self.layer_norm(row_wise_norm)  # (..., C)
        # Broadcasting over the trailing spatial axis rescales every direction
        # and works for any number of leading dims.
        return directions * ln.unsqueeze(-1)

# VN layer norm over C = 12 channels for N = 8 tokens.
n, c = 8, 12
R = rand_rotation_tensor()  # random rotation, used to check rotation-equivariance
V = torch.rand(n, c, 3)
layernorm = VNLayerNorm(c)

rotated_out = layernorm(V @ R)

# Rotating the input changes the output ...
res, _ = test_close(rotated_out, layernorm(V))
print(f'layernorm(VR) != layernorm(V) is {not res}')

# ... by exactly R: layernorm(VR) = layernorm(V) R
res, eps = test_close(rotated_out, layernorm(V) @ R)
print(f'layernorm(VR) = layernorm(V) @ R is {res} with eps {eps:.0e}')

layernorm(VR) != layernorm(V) is True
layernorm(VR) = layernorm(V) @ R is True with eps 1e-06


### Rotation-Invariant Classification Model

In [None]:
from vnn.vn_models import VN_DGCNN as VN_MLP

class VN_Transformer_Classifier(nn.Module):
    """
    Work-in-progress rotation-invariant classifier: a VN_DGCNN encoder (imported
    as VN_MLP), VN multi-head attention, VN layer norm and a final matrix W.

    NOTE(review): this cell has not been executed (In [None]) and no forward()
    is visible here -- confirm whether the class continues.
    """
    def __init__(self, vn_mlp_hidden_dim, feature_dimension, attention_heads):
        super(VN_Transformer_Classifier, self).__init__()
        # Encoder; keyword arguments follow VN_DGCNN in vnn.vn_models (not visible here).
        self.vn_mlp = VN_MLP(c_dim=feature_dimension, dim=3, hidden_dim=vn_mlp_hidden_dim, k=20, meta_output='invariant_latent')
        # NOTE(review): `c` is not a parameter of __init__; it resolves to the notebook
        # global `c` left by whichever cell ran last (12 after the layer-norm cell).
        # Presumably meant to be the encoder's output channel count -- confirm.
        self.attention = VN_MultiHeadAttention(attention_heads, feature_dimension // attention_heads, c, feature_dimension)
        # NOTE(review): `c_prime` is likewise a notebook global (12), not derived from
        # feature_dimension, although the attention output above has feature_dimension channels.
        self.layer_norm = VNLayerNorm(c_prime)
        # NOTE(review): `num_classes` is not defined anywhere in the visible notebook
        # (NameError on construction). `self.W` is a plain tensor, not an nn.Parameter,
        # so it would not be trained or moved by .to(device).
        self.W = torch.rand(num_classes, c_prime)
