In [188]:
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
from torch.nn.functional import silu
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
import torch
# Set the default floating-point type to float64
torch.set_default_dtype(torch.float64)

class GroupavgReadoutBlock(torch.nn.Module):

    def __init__(self, irreps_in: o3.Irreps,
                 gate: Optional[Callable],
                 irrep_out: o3.Irreps=o3.Irreps("0e"),
                ):
        super().__init__()
        self.irreps_in = irreps_in
        self.non_linearity = gate
        input_size = irreps_in.dim
        output_size = irrep_out.dim
        hidden_size = 128
        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.BatchNorm1d(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, output_size)
        )
        self.register_buffer("SO3_grid_1_0", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_0.pt").to(torch.get_default_dtype())))
        self.register_buffer("SO3_grid_1_1", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_1.pt").to(torch.get_default_dtype())))
        self.register_buffer("SO3_grid_1_2", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_2.pt").to(torch.get_default_dtype())))
        self.register_buffer("SO3_grid_2_0", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_0.pt").to(torch.get_default_dtype())))
        self.register_buffer("SO3_grid_2_1", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_1.pt").to(torch.get_default_dtype())))
        self.register_buffer("SO3_grid_2_2", 
            o3.quaternion_to_matrix(torch.load("/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_2.pt").to(torch.get_default_dtype())))


    def forward(self, x: torch.Tensor, heads: Optional[torch.Tensor] = None):
        rand_D = o3.rand_matrix(device=x.device)
        gs = self.SO3_grid_1_2 @ rand_D       # [72, 3, 3]
        Ds = self.irreps_in.D_from_matrix(gs) # [72, D, D]

        xs = torch.einsum("nd,rjd->nrj", x, Ds) # [n_graphs, D], [72, D, D] -> [n_graphs, 72, D]
        print(xs.shape)
        outs = self.MLP(xs.view(-1, xs.size(-1)))                    # [n_graph, 72, 1]
        out = torch.mean(outs.view(*xs.shape[:-1], -1), dim=1, keepdim=False)
        return out

In [189]:
irreps_in = o3.Irreps("3x0e+1x1o+1x2e")
n_graph = 32
readout = GroupavgReadoutBlock(irreps_in=irreps_in, gate=torch.nn.SiLU)

In [205]:
x = irreps_in.randn(n_graph, -1)

out = readout(x)

rot_x = x @ irreps_in.D_from_matrix(o3.rand_matrix())

rot_out = readout(rot_x)
# print(x - rot_x)
print((rot_out - out).abs().mean())

torch.Size([32, 4608, 11])
torch.Size([32, 4608, 11])
tensor(0.0003, grad_fn=<MeanBackward0>)
