In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

bs, sl, expansion, dim = 2, 1, 4, 3
X = torch.randn(bs, sl, dim)

H_pre = torch.randn(expansion)
X_pre = torch.einsum('bled,e->bd', X.unsqueeze(1).repeat(1, 1, expansion, 1), H_pre)

tensor([[ 1.3193,  0.2545,  0.4482],
        [-0.1236,  0.4518,  0.5052]])

In [29]:
H = torch.randn(bs, sl, dim).unsqueeze(-2).repeat(1, 1, expansion, 1)
H

tensor([[[[ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028]]],


        [[[-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262]]]])

In [31]:
class HC(nn.Module):
    def __init__(self, n, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n = n
        self.alpha = nn.Parameter(torch.zeros(n))
        self.beta = nn.Parameter(torch.ones(n))
        self.interaction = nn.Parameter(torch.eye(n, n))

    def forward(self, H, layer):
        res = torch.einsum('bled,ee->bled', H, self.interaction)
        H = torch.einsum('bled,e->bld', H, self.alpha)
        H = layer(H)
        H = torch.einsum('bld,e->bled', H, self.beta)
        return H + res

hc = HC(expansion)
hc(H, nn.Identity())

tensor([[[[ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028],
          [ 1.0768, -1.2149,  0.4028]]],


        [[[-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262],
          [-0.2116,  0.9096, -0.2262]]]], grad_fn=<AddBackward0>)

In [38]:
from torch.nn import LayerNorm

class Sinkhorn(nn.Module):
    def __init__(self, n_iters: int = 20, eps: float = 1e-8, temperature: float = 1.0):
        super().__init__()
        self.n_iters = n_iters
        self.eps = eps
        self.temperature = temperature

    def forward(self, H: torch.Tensor) -> torch.Tensor:
        # H: (..., n_q, n_c) - arbitrary batch dimensions
        K = (H * self.temperature).exp()
        *batch_dims, n_q, n_c = K.shape

        # Initialize scaling vectors with batch dimensions
        u = torch.full((*batch_dims, n_q), 1.0 / n_q, device=K.device, dtype=K.dtype)
        v = torch.full((*batch_dims, n_c), 1.0 / n_c, device=K.device, dtype=K.dtype)

        # Sinkhornâ€“Knopp iterations using batched matrix operations
        for _ in range(self.n_iters):
            u = 1.0 / (torch.matmul(K, v.unsqueeze(-1)).squeeze(-1) + self.eps)
            v = 1.0 / (torch.matmul(K.transpose(-2, -1), u.unsqueeze(-1)).squeeze(-1) + self.eps)

        # Final doubly-stochastic matrix: diag(u) @ K @ diag(v)
        return (u.unsqueeze(-1) * K) * v.unsqueeze(-2)

class HyperConnection(nn.Module):
    def __init__(self, dim, rate, layer_id, dynamic, device=None):
        super(HyperConnection, self).__init__()

        self.rate = rate
        self.layer_id = layer_id
        self.dynamic = dynamic

        self.static_beta = nn.Parameter(torch.ones((rate,), device=device))

        init_alpha0 = torch.zeros((rate, 1), device=device)
        init_alpha0[layer_id % rate, 0] = 1.
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye((rate), device=device)], dim=1))

        if self.dynamic:
            self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate+1), device=device))
            self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim, ), device=device))
            self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.layer_norm = LayerNorm(dim)

        self.ln = nn.LayerNorm(dim)
        self.sinkhorn = Sinkhorn()

    def forward(self, H, layer):
        mix_h, beta = self.width_connection(H)
        h = layer(self.ln(F.sigmoid(mix_h[..., 0, :])))
        return self.depth_connection(mix_h, h, beta)

    def width_connection(self, h):
        # get alpha and beta
        if self.dynamic:
            norm_h = self.layer_norm(h)

        if self.dynamic:
            wc_weight = norm_h @ self.dynamic_alpha_fn
            dynamic_alpha = wc_weight * self.dynamic_alpha_scale
            alpha = dynamic_alpha + self.static_alpha[None, None, ...]
        else:
            alpha = self.static_alpha[None, None, ...]

        if self.dynamic:
            dc_weight = norm_h @ self.dynamic_beta_fn
            dynamic_beta = dc_weight * self.dynamic_beta_scale
            beta = dynamic_beta + self.static_beta[None, None, ...]
        else:
            beta = self.static_beta[None, None, ...]

        # width connection
        mix_h = alpha.transpose(-1, -2) @ h

        return mix_h, beta

    def depth_connection(self, mix_h, h_o, beta):
        print(mix_h[..., 1:, :].shape)
        h = 2 * F.sigmoid(torch.einsum("blh,bln->blnh", h_o, beta)) + self.sinkhorn(mix_h[..., 1:, :])

        return h

hc = HyperConnection(dim, expansion, 0, dynamic=False, device='cpu')
hc(H, nn.Identity())

torch.Size([2, 1, 4, 3])


tensor([[[[1.7186, 0.6586, 1.4199],
          [1.7186, 0.6586, 1.4199],
          [1.7186, 0.6586, 1.4199],
          [1.7186, 0.6586, 1.4199]]],


        [[[0.9169, 1.8587, 0.9043],
          [0.9169, 1.8587, 0.9043],
          [0.9169, 1.8587, 0.9043],
          [0.9169, 1.8587, 0.9043]]]], grad_fn=<AddBackward0>)