In [44]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from einops import rearrange
from functools import reduce
from operator import mul
def count_learnable_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [45]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.projection = nn.Linear(in_channels * patch_size**3, embed_dim, bias=False)

    def forward(self, x):
        batch_size, channels, x_dim, y_dim, z_dim = x.shape

        patches_x = x.unfold(2, self.patch_size, self.patch_size)
        patches_y = patches_x.unfold(3, self.patch_size, self.patch_size)
        patches_z = patches_y.unfold(4, self.patch_size, self.patch_size)

        patches = patches_z.permute(0, 2, 3, 4, 1, 5, 6, 7).contiguous()
        patches = patches.view(batch_size, -1, self.patch_size**3 * channels)

        embeddings = self.projection(patches)

        return embeddings


class LinearPatchEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim): 

        super().__init__()
 
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
        self.embedder = nn.Sequential(
            Rearrange("b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)", **axes_len),
            nn.Linear(in_channels*reduce(mul, patch_size), embed_dim), 
        )
    def forward(self, x):
        return self.embedder(x)
    
class LinearPatchDeEmbedding(nn.Module):
    def __init__(self, patch_size, embed_dim, out_channels, img_size):

        super().__init__()

        self.embed_dim = embed_dim
        self.out_channels = out_channels

        axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
        h, w, d = [i//p for i, p in zip(img_size, patch_size)]

        self.unembedder = nn.Sequential(
            nn.Linear(embed_dim, out_channels*reduce(mul, patch_size)),
            Rearrange("b (h w d) (p1 p2 p3 c) -> b c (h p1) (w p2) (d p3)", **axes_len, h=h, w=w, d=d),
        )

    def forward(self, x):
        return self.unembedder(x)

In [46]:
example = torch.rand(size=(1, 4, 128, 128, 128))

In [47]:
embedder = PatchEmbedding(patch_size=8, in_channels=32, embed_dim=512)
mine = LinearPatchEmbedding(patch_size=(4, 4, 4), in_channels=4, embed_dim=256)
de = LinearPatchDeEmbedding(patch_size=(4, 4, 4), out_channels=4, embed_dim=256)

In [48]:
embedded = mine(example)
embedded.shape

torch.Size([1, 32768, 256])

In [49]:
out = de(embedded)
out.shape

EinopsError:  Error while processing rearrange-reduction pattern "b (h w d) (p1 p2 p3 c) -> b c (h p1) (w p2) (d p3)".
 Input tensor shape: torch.Size([1, 32768, 256]). Additional info: {'p1': 4, 'p2': 4, 'p3': 4, 'h': 0, 'w': 8192, 'd': 64}.
 Shape mismatch, 32768 != 0

In [32]:
32768//4

8192

In [34]:
8192//(128//4)

256

In [35]:
128//4

32

In [37]:
32*32*32

32768

In [52]:
import numpy as np
np.array((128, 128, 128)) //4

array([32, 32, 32])