In [2]:
import torch 
import torch.nn as nn

In [162]:
class Patcher(nn.Module):
    """Split an image batch into non-overlapping square patches and embed each one.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` acts as
    an independent linear projection of every patch.

    Args:
        num_patches: Not used by this module; kept so existing call sites such as
            ``Patcher(num_patches)`` keep working.
        dim_in: Number of channels of the input images.
        channels: Embedding size of each patch token.
        patch_size: Side length, in pixels, of each square patch.
    """

    def __init__(self, num_patches, dim_in=3, channels=512, patch_size=16):
        super().__init__()
        # Use ``dim_in`` rather than a literal 3 so non-RGB inputs work.
        self.projection = nn.Conv2d(
            dim_in, channels, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Embed the patches of ``x``.

        Args:
            x: Image batch of shape (batch, dim_in, H, W).

        Returns:
            Tensor of shape (batch, (H // patch_size) * (W // patch_size), channels),
            with patches in row-major order.
        """
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C)
        return self.projection(x).flatten(2).transpose(1, 2)

In [63]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps the last dimension from ``dim`` up to ``hidden_dim`` and back to ``dim``,
    so the output has the same shape as the input.

    Args:
        dim: Size of the input/output feature dimension.
        hidden_dim: Size of the intermediate feature dimension.
        dropout: Dropout probability applied after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # A single Sequential named ``mlp`` keeps parameter names (mlp.0.*, mlp.3.*)
        # stable for state_dict save/load.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack along the last dimension of ``x``."""
        result = self.mlp(x)
        return result

In [144]:
class PoolingLayer(nn.Module):
    """Classification head: average over dimension 1, then project to class scores.

    Args:
        n_classes: Number of output classes.
        channels: Feature size of the input's last dimension.
    """

    def __init__(self, n_classes, channels):
        super().__init__()
        self.projection = nn.Linear(channels, n_classes)

    def forward(self, x):
        """Return class scores for ``x``.

        Args:
            x: Tensor whose dimension 1 is averaged away and whose last dimension
                is ``channels`` — in this notebook, token features of shape
                (batch, num_patches, channels).

        Returns:
            For a 3-D input, a tensor of shape (batch, n_classes).
        """
        pooled = torch.mean(x, dim=1)
        return self.projection(pooled)

In [153]:
class MixerLayer(nn.Module):
    """One MLP-Mixer block: token mixing followed by channel mixing.

    Both halves are pre-norm residual branches:
        x = x + TokenMLP(LayerNorm(x)^T)^T   # mixes information across patches
        x = x + ChannelMLP(LayerNorm(x))     # mixes information across channels

    Args:
        channels: Feature size of each token (last input dimension).
        num_patches: Number of tokens (input dimension 1).
        token_hidden_dim: Hidden size of the token-mixing MLP; defaults to
            ``num_patches``.
        channel_hidden_dim: Hidden size of the channel-mixing MLP; defaults to
            ``channels``.
        dropout: Dropout probability inside both MLPs (default 0.0).
    """

    def __init__(self, channels, num_patches, token_hidden_dim=None,
                 channel_hidden_dim=None, dropout=0.0):
        super().__init__()
        if token_hidden_dim is None:
            token_hidden_dim = num_patches
        if channel_hidden_dim is None:
            channel_hidden_dim = channels

        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

        self.mlp1 = MLP(num_patches, token_hidden_dim, dropout=dropout)
        self.mlp2 = MLP(channels, channel_hidden_dim, dropout=dropout)

    def forward(self, x):
        """Apply token mixing, then channel mixing.

        Args:
            x: Tensor of shape (batch, num_patches, channels).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Token mixing: normalise over channels, swap the last two axes so the
        # MLP's Linear layers act along the patch axis, then swap back.
        # (Must use the registered submodules via ``self``.)
        x = x + self.mlp1(self.norm1(x).transpose(2, 1)).transpose(2, 1)
        # Channel mixing: the MLP acts along the channel (last) axis directly.
        x = x + self.mlp2(self.norm2(x))
        return x
        
        

In [154]:
class Mixer(nn.Module):
    """MLP-Mixer classifier: patch embedding -> ``n`` Mixer layers -> pooled linear head.

    Args:
        n: Number of ``MixerLayer`` blocks.
        channels: Token embedding size used by every stage.
        num_patches: Number of patch tokens per image; must equal
            (H // patch_size) * (W // patch_size) for the images passed in.
        n_classes: Number of output classes.
        patch_size: Side length of each square patch (default 16).
        dim_in: Number of input image channels (default 3).
    """

    def __init__(self, n, channels, num_patches, n_classes, patch_size=16, dim_in=3):
        super().__init__()

        self.blocks = nn.Sequential(*[
            MixerLayer(channels, num_patches)
            for _ in range(n)])

        self.pooling = PoolingLayer(n_classes, channels)

        # Forward the model width to the patcher: relying on its default of 512
        # breaks every ``channels`` value other than 512.
        self.patcher = Patcher(num_patches, dim_in=dim_in, channels=channels,
                               patch_size=patch_size)

    def forward(self, x):
        """Classify an image batch.

        Args:
            x: Images of shape (batch, dim_in, H, W).

        Returns:
            Class scores of shape (batch, n_classes).
        """
        tokens = self.patcher(x)
        tokens = self.blocks(tokens)
        # Return the head's output, not the intermediate token features.
        return self.pooling(tokens)

In [155]:
# Toy configuration: one batch of random images pushed through the model.
torch.manual_seed(0)  # reproducible random input and module initialisation

b = 128        # batch size
h = 224        # image height
w = 224        # image width
c_in = 3       # image channels

patch_size = 16
channels = 512
dim_in = c_in  # Patcher input channels must match the image channels
n_classes = 10
n = 6          # number of Mixer layers

assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch_size"
# One token per patch: count along both height and width with integer division,
# so non-square images get the right number of tokens.
num_patches = (h // patch_size) * (w // patch_size)

x = torch.rand(b, c_in, h, w)

In [156]:
# Standalone components for a step-by-step check, plus the full model.
# Pass the configured sizes explicitly instead of relying on Patcher's defaults.
patcher = Patcher(num_patches, dim_in=dim_in, channels=channels, patch_size=patch_size)
layer = MixerLayer(channels=channels, num_patches=num_patches)
pooling = PoolingLayer(n_classes, channels)

mixer = Mixer(n, channels, num_patches, n_classes)

In [158]:
# Step-by-step forward pass through one Mixer layer, checking shapes at each stage.
z = patcher(x)
assert z.shape == (b, num_patches, channels)
z = layer(z)
assert z.shape == (b, num_patches, channels)
out = pooling(z)
assert out.shape == (b, n_classes)

In [159]:
# Full model forward pass: expect one row of class scores per image.
out1 = mixer(x)
assert out1.shape == (b, n_classes)