# eSCN Block Simplified

In [1]:
from typing import Any
from einops import rearrange
from nshtrainer.ll.typecheck import Float, Tensor, tassert
import torch


# Stand-ins for the layer constructors used by the functions below. Both are
# bound to `...` (Ellipsis), so calling them as written raises TypeError: this
# file is a shape-annotated sketch rather than runnable code.
# NOTE(review): presumably a real linear layer and MLP are meant here — confirm.
Linear: Any = ...
MLP: Any = ...


# Stub (body is `...`). Per its signature, maps a single "E l_sq C" tensor to a
# list of "E l_per_m C" tensors; `escn_conv` treats entry 0 as the m = 0 group
# and the remaining entries as the m > 0 groups.
def _m_primary(x: Float[Tensor, "E l_sq C"]) -> list[Float[Tensor, "E l_per_m C"]]: ...
# Stub (body is `...`). Per its signature, maps a list of "E l_per_m C" tensors
# back to a single "E l_sq C" tensor.
# NOTE(review): presumably the exact inverse of `_m_primary` — confirm.
def _l_primary(x: list[Float[Tensor, "E l_per_m C"]]) -> Float[Tensor, "E l_sq C"]: ...
# Stub (body is `...`). Shape-preserving per its signature; applied by
# `escn_conv` before the per-m convolutions.
# NOTE(review): presumably rotates coefficients into an edge-aligned frame — confirm.
def _rotate(x: Float[Tensor, "E l_sq C"]) -> Float[Tensor, "E l_sq C"]: ...
# Stub (body is `...`). Shape-preserving per its signature; applied by
# `escn_conv` after the per-m convolutions.
# NOTE(review): presumably undoes `_rotate` — confirm.
def _rotate_inv(x: Float[Tensor, "E l_sq C"]) -> Float[Tensor, "E l_sq C"]: ...


def escn_conv_m0(
    x: Float[Tensor, "E l_per_m C"],
    x_e_invariant: Float[Tensor, "E C"],
) -> Float[Tensor, "E l_per_m C"]:
    """Transform the m = 0 coefficient group.

    The degree and channel axes are flattened together, passed through a
    bias-free linear layer, multiplied elementwise by ``MLP()(x_e_invariant)``,
    passed through a second bias-free linear layer, and unflattened again.

    Args:
        x: Coefficients for m = 0, annotated as ``(E, l_per_m, C)``.
        x_e_invariant: Per-edge features, annotated as ``(E, C)``, fed to the
            MLP that produces the multiplicative factor.

    Returns:
        A tensor annotated with the same ``(E, l_per_m, C)`` layout as ``x``.
    """
    # Remember the channel count so the flat axis can be split back later.
    num_channels = x.shape[-1]
    flat = rearrange(x, "E l_per_m C -> E (l_per_m C)")

    hidden = Linear(bias=False)(flat)
    tassert(Float[Tensor, "E H"], hidden)

    # Scale the hidden activations by a function of the per-edge features.
    scale = MLP()(x_e_invariant)
    hidden = scale * hidden
    tassert(Float[Tensor, "E H"], hidden)

    projected = Linear(bias=False)(hidden)
    tassert(Float[Tensor, "E l_times_C"], projected)

    return rearrange(projected, "E (l_per_m C) -> E l_per_m C", C=num_channels)


def escn_conv_mgt0(
    x: Float[Tensor, "E sign l_per_m C"],
    x_e_invariant: Float[Tensor, "E C"],
    m: int,
) -> Float[Tensor, "E sign l_per_m C"]:
    """Transform one m > 0 coefficient group, mixing its two sign components.

    Two independent branches ("r" and "i") each apply: bias-free linear ->
    elementwise multiply by ``MLP()(x_e_invariant)`` -> bias-free linear.
    Their outputs are then combined across the size-2 ``sign`` axis as
    ``out[0] = r[0] - i[1]`` and ``out[1] = r[1] + i[0]``.

    Args:
        x: Coefficients annotated as ``(E, sign, l_per_m, C)``; the shape
            assertions below require ``sign == 2``.
        x_e_invariant: Per-edge features, annotated as ``(E, C)``.
        m: Order of the group being processed.
            NOTE(review): not read anywhere in this body — confirm whether it
            is meant to select layer sizes.

    Returns:
        A tensor annotated with the same ``(E, sign, l_per_m, C)`` layout as ``x``.
    """
    num_channels = x.shape[-1]

    flat = rearrange(x, "E sign l_per_m C -> E sign (l_per_m C)")

    # --- "r" branch -------------------------------------------------------
    branch_r = Linear(bias=False)(flat)
    tassert(Float[Tensor, "E 2 H"], branch_r)

    # NOTE(review): the MLP output must broadcast against an (E, 2, H) tensor;
    # a plain (E, H) output would not — confirm what MLP() returns here.
    scale_r = MLP()(x_e_invariant)
    branch_r = scale_r * branch_r
    tassert(Float[Tensor, "E 2 H"], branch_r)

    branch_r = Linear(bias=False)(branch_r)
    tassert(Float[Tensor, "E 2 l_times_C"], branch_r)

    # --- "i" branch (same pipeline, separate layers) ----------------------
    branch_i = Linear(bias=False)(flat)
    tassert(Float[Tensor, "E 2 H"], branch_i)

    scale_i = MLP()(x_e_invariant)
    branch_i = scale_i * branch_i
    tassert(Float[Tensor, "E 2 H"], branch_i)

    branch_i = Linear(bias=False)(branch_i)
    tassert(Float[Tensor, "E 2 l_times_C"], branch_i)

    # Cross-combine the two sign slots of the branches.
    out_first = branch_r[:, 0] - branch_i[:, 1]
    out_second = branch_r[:, 1] + branch_i[:, 0]
    tassert(Float[Tensor, "E l_times_C"], (out_first, out_second))

    combined = torch.stack((out_first, out_second), dim=1)
    tassert(Float[Tensor, "E 2 l_times_C"], combined)

    return rearrange(combined, "E sign (l_per_m C) -> E sign l_per_m C", C=num_channels)


def escn_conv(
    x: Float[Tensor, "E l_sq C"],
    x_e_invariant: Float[Tensor, "E C"],
) -> Float[Tensor, "E l_sq C"]:
    """Apply the per-m convolutions to a full set of coefficients.

    Pipeline: ``_rotate`` -> split into per-``abs(m)`` groups with
    ``_m_primary`` -> ``escn_conv_m0`` on the first group and
    ``escn_conv_mgt0`` on every remaining group -> reassemble with
    ``_l_primary`` -> ``_rotate_inv``.

    Args:
        x: Coefficients annotated as ``(E, l_sq, C)``.
        x_e_invariant: Per-edge features, annotated as ``(E, C)``, forwarded
            unchanged to every per-m convolution.

    Returns:
        The result of ``_rotate_inv``, annotated as ``(E, l_sq, C)``.
    """
    x = _rotate(x)

    # Reorder coeffs based on `m` and get each `abs(m)`-th set of coefficients
    x_per_m: list[Float[Tensor, "E l_per_m C"]] = _m_primary(x)

    x_m0, *x_mgt0 = x_per_m

    x_m0_new = escn_conv_m0(x_m0, x_e_invariant)
    x_mgt0_new = []
    # `escn_conv_mgt0` requires `m`. Entry 0 of `x_per_m` is the m = 0 group,
    # so the remaining entries are numbered from 1.
    # NOTE(review): assumes `_m_primary` returns groups in increasing abs(m)
    # order with none skipped — confirm against its implementation.
    for m, x_mgt0_m in enumerate(x_mgt0, start=1):
        # Rearrange to get negative and positive `m` coefficients
        x_mgt0_m = rearrange(x_mgt0_m, "E (sign l_per_m) C -> E sign l_per_m C", sign=2)
        x_mgt0_m = escn_conv_mgt0(x_mgt0_m, x_e_invariant, m)
        x_mgt0_m = rearrange(x_mgt0_m, "E sign l_per_m C -> E (sign l_per_m) C")
        x_mgt0_new.append(x_mgt0_m)

    x = _l_primary([x_m0_new, *x_mgt0_new])

    return _rotate_inv(x)

TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. Please install TensorBoard with `pip install tensorboard` or TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging.
