In [1]:
import torch
import torch.nn as nn

In [18]:
class AddBroadcastPosEmbed(nn.Module):
    def __init__(self, shape, embd_dim, dim=-1):
        super().__init__()
        assert dim in [-1, 1] # only first or last dim supported
        self.shape = shape
        self.n_dim = n_dim = len(shape)
        self.embd_dim = embd_dim
        self.dim = dim

        assert embd_dim % n_dim == 0, f"{embd_dim} % {n_dim} != 0"
        self.emb = nn.ParameterDict({
             f'd_{i}': nn.Parameter(torch.randn(shape[i], embd_dim // n_dim) * 0.01
                                    if dim == -1 else
                                    torch.randn(embd_dim // n_dim, shape[i]) * 0.01)
             for i in range(n_dim)
        })

    def forward(self, x, decode_step=None, decode_idx=None):
        embs = []
        for i in range(self.n_dim):
            e = self.emb[f'd_{i}']
            print("embedding: ", i, e.shape)
            if self.dim == -1:
                # (1, 1, ..., 1, self.shape[i], 1, ..., -1)
                e = e.view(1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)), -1)
                print("expand: ", e.shape)
                e = e.expand(1, *self.shape, -1)
                print("expand: ", e.shape)
            else:
                e = e.view(1, -1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)))
                e = e.expand(1, -1, *self.shape)
            embs.append(e)

        embs = torch.cat(embs, dim=self.dim)
        print(embs.shape)
        if decode_step is not None:
            embs = tensor_slice(embs, [0, *decode_idx, 0],
                                [x.shape[0], *(1,) * self.n_dim, x.shape[-1]])

        return x + embs

In [19]:
b = 2
t = 8
h = w = 32
c = 64*3
embed_dim = c
x = torch.randn(b, t, h, w, c)
shape = (t, h, w)

In [20]:
m = AddBroadcastPosEmbed(shape, embed_dim)
m(x).shape

embedding:  0 torch.Size([8, 64])
expand:  torch.Size([1, 8, 1, 1, 64])
expand:  torch.Size([1, 8, 32, 32, 64])
embedding:  1 torch.Size([32, 64])
expand:  torch.Size([1, 1, 32, 1, 64])
expand:  torch.Size([1, 8, 32, 32, 64])
embedding:  2 torch.Size([32, 64])
expand:  torch.Size([1, 1, 1, 32, 64])
expand:  torch.Size([1, 8, 32, 32, 64])
torch.Size([1, 8, 32, 32, 192])


torch.Size([2, 8, 32, 32, 192])