In [1]:
import torch 
import torch.nn as nn

In [2]:
class Patcher(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Implemented as a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so each output position corresponds to exactly one patch.

    Args:
        num_patches: Number of patches per image. Not used by this layer;
            kept for interface compatibility.
        dim_in: Number of input image channels.
        channels: Embedding dimension of each patch token.
        patch_size: Side length (in pixels) of a patch.

    Shape:
        input: ``(batch, dim_in, H, W)``
        output: ``(batch, (H // patch_size) * (W // patch_size), channels)``
    """

    def __init__(self, num_patches, dim_in=3, channels=512, patch_size=16):
        super().__init__()
        # Use ``dim_in``; it was previously ignored in favour of a hard-coded 3.
        self.projection = nn.Conv2d(dim_in, channels, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C): one row per patch token.
        return self.projection(x).flatten(2).transpose(1, 2)

In [3]:
class MLP(nn.Module):
    """Two-layer feed-forward block applied to the last input dimension.

    Pipeline: Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout. The output has the same last
    dimension as the input.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.1):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        stages = [expand, nn.GELU(), nn.Dropout(dropout), contract, nn.Dropout(dropout)]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        result = self.mlp(x)
        return result

In [4]:
class PoolingLayer(nn.Module):
    """Classification head: mean over dimension 1, then a linear map to ``n_classes`` logits."""

    def __init__(self, n_classes, channels):
        super().__init__()
        self.projection = nn.Linear(channels, n_classes)

    def forward(self, x):
        # Average over dim 1. In this notebook that is the token axis of the
        # (batch, tokens, channels) Mixer output; the last dim must equal `channels`.
        pooled = x.mean(dim=1)
        return self.projection(pooled)

In [5]:
class MixerLayer(nn.Module):
    """One Mixer block: a token-mixing MLP followed by a channel-mixing MLP.

    Each sub-block is pre-norm (LayerNorm over the last dimension) and wrapped
    in a residual connection. Token mixing transposes the input so the MLP acts
    across dimension 1; channel mixing acts on the last dimension directly.

    Args:
        channels: Size of the last (feature) dimension of the input.
        num_patches: Size of dimension 1 of the input (number of tokens).
        dropout: Dropout probability used inside both MLPs.
        tokens_hidden_dim: Hidden width of the token-mixing MLP. Defaults to
            ``num_patches`` (the previously hard-coded value).
        channels_hidden_dim: Hidden width of the channel-mixing MLP. Defaults
            to ``channels`` (the previously hard-coded value).

    Shape:
        input/output: ``(batch, num_patches, channels)``
    """

    def __init__(self, channels, num_patches, dropout=0.1,
                 tokens_hidden_dim=None, channels_hidden_dim=None):
        super().__init__()
        if tokens_hidden_dim is None:
            tokens_hidden_dim = num_patches
        if channels_hidden_dim is None:
            channels_hidden_dim = channels

        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

        # Token mixing: MLP over the patch axis (applied after a transpose).
        self.mlp1 = MLP(num_patches, tokens_hidden_dim, dropout)
        # Channel mixing: MLP over the feature axis.
        self.mlp2 = MLP(channels, channels_hidden_dim, dropout)

    def forward(self, x):
        # (B, N, C) -> (B, C, N) so mlp1 mixes across tokens, then transpose back.
        x = x + self.mlp1(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.mlp2(self.norm2(x))
        return x
        
        

In [6]:
class Mixer(nn.Module):
    """MLP-Mixer classifier: patch embedding -> ``n`` Mixer blocks -> pooled linear head.

    Args:
        n: Number of stacked ``MixerLayer`` blocks.
        channels: Token embedding dimension used throughout the network.
        num_patches: Number of tokens per image. Must equal
            ``(H // patch_size) * (W // patch_size)`` for the input images.
        n_classes: Number of output logits.
        dropout: Dropout probability forwarded to every ``MixerLayer``.
        patch_size: Side length of the patches (default 16, as before).
        dim_in: Number of input image channels (default 3, as before).

    Returns (forward):
        Logits of shape ``(batch, n_classes)``.
    """

    def __init__(self, n, channels, num_patches, n_classes, dropout=0.1,
                 patch_size=16, dim_in=3):
        super().__init__()

        self.blocks = nn.Sequential(*[
            # Forward the caller's dropout; it was previously hard-coded to 0.1.
            MixerLayer(channels, num_patches, dropout=dropout)
            for _ in range(n)])

        self.pooling = PoolingLayer(n_classes, channels)

        # Embed patches into `channels` dims so tokens match the blocks. Previously
        # Patcher's default of 512 was used regardless of `channels`.
        self.patcher = Patcher(num_patches, dim_in=dim_in, channels=channels,
                               patch_size=patch_size)

    def forward(self, x):
        z = self.patcher(x)
        z = self.blocks(z)
        logits = self.pooling(z)
        return logits

In [7]:
# Seed so the random input batch (and later weight inits) are reproducible.
torch.manual_seed(0)

b = 128          # batch size
h = 224          # image height in pixels
w = 224          # image width in pixels
c_in = 3         # input image channels

x = torch.rand(b, c_in, h, w)

patch_size = 16
channels = 512   # token embedding dimension
dim_in = c_in    # same quantity as c_in; kept for compatibility
n_classes = 10
n = 6            # number of Mixer blocks

# Patches must tile the image exactly; a stride-`patch_size` conv drops any remainder.
assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch_size"
num_patches = (h // patch_size) * (w // patch_size)

In [8]:
# Individual components for a manual, step-by-step forward pass. The patch embedding
# is built from the config so its output width matches `channels`.
patcher = Patcher(num_patches, dim_in=c_in, channels=channels, patch_size=patch_size)
layer1 = MixerLayer(channels=channels, num_patches=num_patches)
layer2 = MixerLayer(channels=channels, num_patches=num_patches)
pooling = PoolingLayer(n_classes, channels)

# The full end-to-end model with `n` stacked Mixer blocks.
mixer = Mixer(n, channels, num_patches, n_classes, dropout=0.1)

In [9]:
# Shape check only: disable autograd so intermediate activations are not retained.
with torch.no_grad():
    z = patcher(x)
    z = layer1(z)
    z = layer2(z)
    logist = pooling(z)

assert logist.shape == (b, n_classes)
logist.shape

torch.Size([128, 10])

In [10]:
# End-to-end shape check; no_grad avoids storing activations for all `n` blocks.
with torch.no_grad():
    logist = mixer(x)

assert logist.shape == (b, n_classes)
logist.shape

torch.Size([128, 10])