In [1]:
import os
import sys

# Make the local `fail` package importable. Set the FAIL_ROOT environment variable to point at
# your own checkout instead of editing this developer-specific default path.
FAIL_ROOT = os.environ.get("FAIL_ROOT", "/home/ckohler/workspace/bdai/projects/_experimental/fail/")
# Guard so re-running this cell does not append duplicate sys.path entries.
if FAIL_ROOT not in sys.path:
    sys.path.append(FAIL_ROOT)

import torch
import numpy as np

import escnn
from escnn import nn as enn
from escnn import gspaces
from escnn import group

from fail.model.so2_transformer import SO2MultiheadAttention, SO2EncoderBlock, SO2TransformerEncoder, SO2Transformer

In [2]:
# Symmetry group: SO(2) (continuous planar rotations).
G = group.so2_group()
# escnn gspace without a spatial base space: G acts on feature vectors only (no grid/image).
gspace = gspaces.no_base_space(G)

In [3]:
# Band-limited (L=5) regular representation of SO(2); one field of this type is the per-copy
# building block of the model's input type (shown as `regular_5` in the in_type repr).
t = G.bl_regular_representation(L=5)
# Dimension of one field of type `t` (11 here). Derived from the representation instead of being
# hard-coded, so it stays consistent if `L` above is changed.
# NOTE: `id` shadows the Python builtin; the name is kept because later cells reference it.
id = t.size

In [4]:
# --- Model hyperparameters ---
model_dim = 8    # number of copies of `t` stacked into the input field type
hidden_dim = 8   # NOTE(review): not passed to SO2Transformer below
out_dim = 1      # presumably the number of output fields — confirm against SO2Transformer
num_heads = 4

# Input field type: `model_dim` copies of the representation `t`.
in_type = enn.FieldType(gspace, model_dim * [t])

# SO(2)-equivariant transformer under test.
m = SO2Transformer(
    in_type=in_type,
    model_dim=model_dim,
    out_dim=out_dim,
    num_heads=num_heads,
    num_layers=4,
    dropout=0.1,
    in_dropout=0.1,
)

In [5]:
# Dummy input: one sequence of 20 tokens, each a flattened feature vector of `in_type`
# (in_type.size == id * model_dim == 88).
x = torch.randn(1, 20, in_type.size)
print(x.shape)

torch.Size([1, 20, 88])


In [6]:
# Forward pass. `m` has not been put in eval() yet (that happens in a later cell), so dropout
# is presumably active here and gradients are tracked.
y = m(x)
y.tensor.shape

torch.Size([20, 11])

In [7]:
# Per-row Euclidean magnitude of the output features.
torch.linalg.vector_norm(y.tensor, dim=1)

tensor([26.1094, 10.7588, 16.8432,  9.8071, 22.7605, 30.3195, 32.3515, 23.1880,
         8.7990, 18.1947, 17.2120, 14.9155, 17.8798, 26.0388, 37.0092, 21.9139,
        27.1757, 26.9467, 18.7503, 29.0477],
       grad_fn=<LinalgVectorNormBackward0>)

In [8]:
np.set_printoptions(linewidth=10000, precision=4, suppress=True)

# Equivariance check: for random rotations g, verify f(g @ x) == g @ f(x).
SEED = 0
torch.manual_seed(SEED)  # reproducible random inputs
np.random.seed(SEED)     # presumably the RNG behind G.sample() — TODO confirm

# eval() turns off dropout layers (dropout=0.1 was configured) so both forward passes are deterministic.
m.eval()
B = 10
L = 20
# Max allowed per-row error norm. Observed errors print as 0.0 at 4 decimals, so this is generous.
EQUIV_ATOL = 1e-3
x = torch.randn(B, L, id*model_dim)

with torch.no_grad():
    y = m(x)
    print("Outputs' magnitudes")
    print(torch.linalg.norm(y.tensor, dim=1).numpy().reshape(-1)[:10])
    print('##########################################################################################')
    print("Errors' magnitudes")
    for _ in range(8):
        # sample a random rotation
        g = G.sample()

        # Flatten to (B*L, features) so each row is one in_type field vector, act with g,
        # then restore the (B, L, features) layout the model expects.
        x_transformed = (g @ m.in_type(x.view(B*L, -1))).tensor.view(B, L, -1)

        y_transformed = m(x_transformed)

        # verify that f(g@x) = g@f(x)=g@y
        err = torch.linalg.norm(y_transformed.tensor - (g @ y).tensor, dim=1)
        print(err.numpy().reshape(-1)[:10])
        # Check every row, not just the 10 that are printed.
        max_err = err.max().item()
        assert max_err < EQUIV_ATOL, f"equivariance violated: max error {max_err:.3e}"

print('##########################################################################################')
print()

Outputs' magnitudes
[ 8.2461  9.698   4.8709  9.202  10.1814 10.4692  6.2604  9.8397  6.2164  9.4652]
##########################################################################################
Errors' magnitudes
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
##########################################################################################



In [9]:
# Inspect the input field type built above (8 copies of `t`, 88 dimensions in total).
in_type

[SO(2): {regular_5 (x8)}(88)]