In [5]:
from einops import rearrange, reduce, repeat
import numpy as np

In [7]:
# Seeded generator so the demo array is identical on every Restart & Run All.
# Uniform samples in [0, 1) with shape (8, 10, 16).
x = np.random.default_rng(0).random((8, 10, 16))

In [11]:
# Split the length-10 middle axis into p=2 groups of t=5; b and d are untouched,
# so (8, 10, 16) becomes (8, 2, 5, 16).
y = rearrange(x, "b (p t) d -> b p t d", p=2)

In [12]:
y.shape

(8, 2, 5, 16)

In [14]:
# Seeded generator so the values tiled by `repeat` below are reproducible.
z = np.random.default_rng(1).random((2, 3))

In [22]:
# Tile `z` along a new leading axis of length 3; the ellipsis carries over all
# of z's existing axes, so (2, 3) becomes (3, 2, 3). The new axis is labelled
# `copies` so it cannot be mistaken for the notebook variable `x`.
w = repeat(z, "... -> copies ...", copies=3)

In [23]:
w.shape

(3, 2, 3)

In [24]:
w

array([[[0.49657109, 0.52323389, 0.47404402],
        [0.01994771, 0.56012146, 0.3251534 ]],

       [[0.49657109, 0.52323389, 0.47404402],
        [0.01994771, 0.56012146, 0.3251534 ]],

       [[0.49657109, 0.52323389, 0.47404402],
        [0.01994771, 0.56012146, 0.3251534 ]]])

In [66]:
# Seeded generator so q/k, and every product derived from them below, are the
# same on each Restart & Run All. Axis labels follow the patterns: b p t|T h d.
qk_rng = np.random.default_rng(2)
q = qk_rng.random((2, 4, 5, 1, 3))
k = qk_rng.random((2, 4, 6, 1, 3))
# Fold the p axis into the batch axis: (2, 4, ...) -> (8, ...).
q_ = rearrange(q, "b p t h d -> (b p) t h d")
k_ = rearrange(k, "b p T h d -> (b p) T h d")


In [47]:
q_.shape

(8, 5, 1, 3)

In [42]:
# Reference logits for b=0, p=0: pick the singleton h axis explicitly, then
# contract over d, giving a (t, T) = (5, 6) matrix.
q[0, 0, :, 0] @ k[0, 0, :, 0].T

array([[0.75008027, 1.09026072, 0.41217021, 0.25257082, 1.01247137,
        0.86615209],
       [0.52323619, 0.97465198, 0.41852974, 0.43844271, 0.81660024,
        0.61017446],
       [0.82302213, 1.21990252, 0.41481094, 0.23510208, 1.15738577,
        1.05603723],
       [0.58979729, 1.10420923, 0.35060775, 0.33069759, 1.00565737,
        0.94000435],
       [0.71960346, 1.31198019, 0.69463044, 0.75470739, 1.01825805,
        0.56231248]])

In [75]:
%timeit a1 = np.einsum("bpthd, bpThd -> bphtT", q, k)

9.46 µs ± 92.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [60]:
a1.shape

(2, 4, 1, 5, 6)

In [76]:
%timeit a2 = np.einsum("bthd, bThd -> bhtT", q_, k_)

9.61 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [69]:
# Undo the (b p) flattening so a3 has the same (b, p, h, t, T) layout as a1.
# p=4 must equal the size of q's second axis.
a3 = rearrange(a2, "(b p) h t T-> b p h t T", p=4)

In [62]:
a3.shape

(2, 4, 1, 5, 6)

In [70]:
# Elementwise check that the 5-D einsum (a1) and the flattened-batch einsum
# reshaped back (a3) agree.
# NOTE(review): np.allclose(a1, a3) would give one tolerance-aware verdict
# instead of a 240-element boolean dump.
a1 == a3

array([[[[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True]]],


        [[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True]]],


        [[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True]]],


        [[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True

In [8]:
import haiku as hk
import numpy as np
import jax.numpy as jnp
from typing import Optional
from haiku._src.typing import Initializer

In [12]:
class MultiChannelLinear(hk.Module):
    """Applies an independent affine map to each of ``num_channels`` channels.

    Channel ``c`` owns a weight matrix ``w[c]`` of shape
    ``[input_features, output_features]`` and, when ``with_bias`` is true, a
    bias ``b[c]`` of shape ``[1, output_features]``.

    Args:
        num_channels: Number of parallel affine maps (leading axis of ``w``
            and ``b``).
        output_features: Size of the last axis of the output.
        init_scale: Scale passed to ``hk.initializers.VarianceScaling`` for
            ``w``. ``None`` selects a truncated normal with standard deviation
            ``1 / sqrt(input_features)`` instead.
        with_bias: Whether to add the learned per-channel bias.
        b_init: Initializer for the bias; defaults to ``jnp.zeros``.
        name: Optional module name forwarded to ``hk.Module``.
    """

    def __init__(
        self,
        num_channels: int,
        output_features: int,
        init_scale: Optional[float] = 1.0,
        with_bias: bool = True,
        b_init: Optional[Initializer] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self._num_channels = num_channels
        self._output_features = output_features
        self._init_scale = init_scale
        self._with_bias = with_bias
        self._b_init = b_init or jnp.zeros
        # Kept as in the original; nothing in this class reads it back.
        self._name = name

    def __call__(self, inputs):
        """Runs the per-channel affine maps.

        Args:
            inputs: Array shaped ``[..., num_channels, index_dim, input_features]``.
                ``jnp.matmul`` contracts the last axis with ``w`` and broadcasts
                ``w``'s leading ``num_channels`` axis against ``inputs.shape[-3]``.

        Returns:
            Array shaped ``[..., num_channels, index_dim, output_features]``.
        """
        num_channels = self._num_channels
        # The feature size is only known once an input is seen.
        input_features = self._input_features = inputs.shape[-1]
        output_features = self._output_features
        dtype = inputs.dtype

        if self._init_scale is None:
            stddev = 1.0 / np.sqrt(input_features)
            w_init = hk.initializers.TruncatedNormal(stddev=stddev)
        else:
            w_init = hk.initializers.VarianceScaling(self._init_scale)
        w = hk.get_parameter(
            "w", [num_channels, input_features, output_features], dtype, init=w_init
        )

        out = jnp.matmul(inputs, w)

        if self._with_bias:
            b = hk.get_parameter(
                "b", [num_channels, 1, output_features], dtype, init=self._b_init
            )
            # The addition lives inside this branch: with_bias=False creates
            # no `b`, and adding it unconditionally raised UnboundLocalError.
            # broadcast_to also rejects output shapes the bias cannot expand to.
            out = out + jnp.broadcast_to(b, out.shape)

        return out

In [13]:
def multi_channel_linear(x):
    """Forward function for ``hk.transform``: a 4-channel, 16-feature MultiChannelLinear applied to ``x``."""
    return MultiChannelLinear(num_channels=4, output_features=16)(x)

In [14]:
# Wrap the module-building function so its parameters are created with
# forward.init and passed in explicitly to forward.apply.
forward = hk.transform(multi_channel_linear)

In [11]:
import jax

key = jax.random.PRNGKey(0)
# Named dummy_x because the init cells call forward.init(..., x=dummy_x);
# binding this array to `x` left dummy_x undefined on a fresh kernel.
# Shape (2, 4, 5, 8): axis 1 matches num_channels=4, the last axis holds the
# 8 input features.
dummy_x = jax.random.normal(key, (2, 4, 5, 8))



In [15]:
# Run multi_channel_linear once on an example input to create its parameters.
params = forward.init(rng=key, x=dummy_x)

  lax._check_user_dtype_supported(dtype, "zeros")


In [44]:
# All-ones input: every index position carries the same features, so within
# one channel all output rows are identical.
x = jnp.ones(shape=(2, 4, 5, 8))
y = forward.apply(params, key, x)

In [45]:
y.shape

(2, 4, 5, 16)

In [47]:
y[0, 1]

DeviceArray([[-0.18306366, -1.1214032 , -0.20981392,  0.02430223,
              -0.758953  ,  0.19539857,  0.14526668, -0.4697502 ,
              -0.37162304, -0.420932  ,  0.14927003,  0.21616359,
               0.12027161,  0.455296  , -0.6907001 , -0.53500366],
             [-0.18306366, -1.1214032 , -0.20981392,  0.02430223,
              -0.758953  ,  0.19539857,  0.14526668, -0.4697502 ,
              -0.37162304, -0.420932  ,  0.14927003,  0.21616359,
               0.12027161,  0.455296  , -0.6907001 , -0.53500366],
             [-0.18306366, -1.1214032 , -0.20981392,  0.02430223,
              -0.758953  ,  0.19539857,  0.14526668, -0.4697502 ,
              -0.37162304, -0.420932  ,  0.14927003,  0.21616359,
               0.12027161,  0.455296  , -0.6907001 , -0.53500366],
             [-0.18306366, -1.1214032 , -0.20981392,  0.02430223,
              -0.758953  ,  0.19539857,  0.14526668, -0.4697502 ,
              -0.37162304, -0.420932  ,  0.14927003,  0.21616359,
       

In [48]:
def regular_linear(x):
    """Forward function for ``hk.transform``: one ``hk.Linear`` with 16 output features applied to ``x``."""
    return hk.Linear(16)(x)

In [49]:
# Baseline for comparison: the same transform applied to a plain hk.Linear.
forward2 = hk.transform(regular_linear)

In [50]:
# Create the baseline layer's parameters from the same example input and key.
params2 = forward2.init(rng=key, x=dummy_x)

  lax._check_user_dtype_supported(dtype, "zeros")


In [51]:
# Apply the baseline layer to the same `x` that produced `y`.
y2 = forward2.apply(params=params2, x=x, rng=key)

In [52]:
y2.shape

(2, 4, 5, 16)

In [56]:
y2[0, 3]

DeviceArray([[ 0.23766805,  0.09935351,  0.6060016 ,  1.5810465 ,
              -1.6784463 ,  2.7004874 ,  0.6521029 , -1.4530311 ,
              -0.5496504 , -0.12809707, -0.01297408, -1.8025329 ,
              -0.6704819 , -0.3872682 , -0.90747887, -0.15702218],
             [ 0.23766805,  0.09935351,  0.6060016 ,  1.5810465 ,
              -1.6784463 ,  2.7004874 ,  0.6521029 , -1.4530311 ,
              -0.5496504 , -0.12809707, -0.01297408, -1.8025329 ,
              -0.6704819 , -0.3872682 , -0.90747887, -0.15702218],
             [ 0.23766805,  0.09935351,  0.6060016 ,  1.5810465 ,
              -1.6784463 ,  2.7004874 ,  0.6521029 , -1.4530311 ,
              -0.5496504 , -0.12809707, -0.01297408, -1.8025329 ,
              -0.6704819 , -0.3872682 , -0.90747887, -0.15702218],
             [ 0.23766805,  0.09935351,  0.6060016 ,  1.5810465 ,
              -1.6784463 ,  2.7004874 ,  0.6521029 , -1.4530311 ,
              -0.5496504 , -0.12809707, -0.01297408, -1.8025329 ,
       

array([[[[0.8303461 , 0.72315708, 0.82299629, ..., 0.56444649,
          0.96323568, 0.73136931],
         [0.03498549, 0.06814834, 0.56369268, ..., 0.96849281,
          0.12923999, 0.58238242],
         [0.81883653, 0.02746022, 0.17267796, ..., 0.37578728,
          0.73969488, 0.66038231],
         [0.60496612, 0.77983048, 0.26959339, ..., 0.76893138,
          0.75972568, 0.65070706],
         [0.72225114, 0.65261821, 0.82927554, ..., 0.62158078,
          0.12911551, 0.6679843 ]],

        [[0.96223099, 0.73722385, 0.41100629, ..., 0.21013132,
          0.54172832, 0.44045069],
         [0.19218966, 0.34728459, 0.50591513, ..., 0.69986711,
          0.02934971, 0.912753  ],
         [0.91654037, 0.82316081, 0.28468506, ..., 0.92637637,
          0.36311632, 0.46934494],
         [0.09603709, 0.67139424, 0.58297449, ..., 0.79178829,
          0.07896961, 0.74882819],
         [0.01713771, 0.04372111, 0.57681133, ..., 0.94427153,
          0.56619509, 0.14182395]],

        [[0.8954