In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset,DataLoader

import matplotlib.pyplot as plt

In [8]:
IN_CH = 1
IMG_SIZE = 28
NR_CLS = 10
LR = 4e-3
BATCH_SIZE = 128
EPOCHS = 25
TOKEN_DIM =64
CHANNEL_DIM = 128
DEPTH = 4

device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(IN_CH)],
        [0.5 for _ in range(IN_CH)]      
    )
])

train_ds = datasets.MNIST(root='./data',train=True,download=True,transform=transforms)
test_ds = datasets.MNIST(root='./data',train=True,download=True,transform=transforms)

train_dl = DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True)
test_dl = DataLoader(test_ds,batch_size=BATCH_SIZE,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz



KeyboardInterrupt: 

In [9]:
class PathEmbedding(nn.Module):
    def __init__(self,in_ch,embedding_dim,path_size):
        super().__init__()
        
        self.patch = nn.Conv2d(in_ch,embedding_dim,kernel_size=path_size,stride=path_size)
        
    def forward(self,x):
        return self.patch(x)
    
class MLP(nn.Module):
    def __init__(self,dim,inter_dim,dropout=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim,inter_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inter_dim,dim),
            nn.Dropout(dropout)
        )
    
    def forward(self,x):
        return self.mlp(x)
    

class T1(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self,x):
        return torch.permute(x,(0,2,1))
    
class T2(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self,x):
        return torch.permute(x,(0,2,3,1)).reshape(x.shape[0],-1,x.shape[1])
    

class MixerLayer(nn.Module):
    def __init__(self,embed_dim,num_patches,token_inter_dim,channel_inter_dim,dropout=0.1):
        super().__init__()
        
        self.token_mixer = nn.Sequential(
            nn.LayerNorm(embed_dim),
            T1(),
            MLP(num_patches,token_inter_dim,dropout),
            T1(),
        )
        self.channel_mixer = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MLP(num_patches,channel_inter_dim,dropout),
        )
        
    def forward(self,x):
        x = x + self.token_mixer(x)
        x = x + self.channel_mixer(x)
        
        return x
    
class MLPMixer(nn.Module):
    def __init__(self,in_ch,embedding_dim,num_classes,patch_size,image_size,depth,token_interm_dim,channel_interm_dim,dropout=0.1):
        super().__init__()
        
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size'
        
        self.num_patches = (image_size//patch_size)**2
        self.pe = nn.Sequential(
            PathEmbedding(in_ch,embedding_dim,patch_size),
            T2()
        )
        
        self.mixers = nn.ModuleList([MixerLayer(embedding_dim,self.num_patches,token_interm_dim,channel_interm_dim) for _ in range(depth)])
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim,num_classes),
            ### PLACE TO PUT MORE HERE
        )
        
    def forward(self,x):
        x = self.pe(x)
        for mixer in self.mixers:
            x = mixer(x)
        x = self.layer_norm(x)
        x = x.mean(dim=-1)
        
        return self.classifier(x)

In [11]:
model = MLPMixer(in_ch = IN_CH,
                 image_size=IMG_SIZE,
                 patch_size=2,
                 num_classes=NR_CLS,
                 embedding_dim=CHANNEL_DIM,
                 depth=DEPTH,
                 token_interm_dim=TOKEN_DIM,
                 channel_interm_dim=CHANNEL_DIM)
model

MLPMixer(
  (pe): Sequential(
    (0): PathEmbedding(
      (patch): Conv2d(1, 128, kernel_size=(2, 2), stride=(2, 2))
    )
    (1): T2()
  )
  (mixers): ModuleList(
    (0-3): 4 x MixerLayer(
      (token_mixer): Sequential(
        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (1): T1()
        (2): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=196, out_features=64, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=64, out_features=196, bias=True)
            (4): Dropout(p=0.1, inplace=False)
          )
        )
        (3): T1()
      )
      (channel_mixer): Sequential(
        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (1): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=196, out_features=128, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
           

In [13]:
total_params = sum([params.numel() for params in model.parameters()])
total_params

307626