In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import FashionMNIST

import random
import copy
import tqdm
import numpy as np
import matplotlib.pyplot as plt

def set_seed(seed):
    """Seed the global PyTorch, Python `random` and NumPy generators.

    Args:
        seed: integer seed handed unchanged to each generator.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

In [2]:
class Args:
    """Notebook-wide configuration, read as plain class attributes.

    NOTE(review): the architecture sizes are close to the Mixer-H/14 setting
    from the MLP-Mixer paper (which uses a token-MLP width of 640). That is a
    very large model for CPU training on 28x28 inputs — confirm intended.
    """

    # --- runtime -----------------------------------------------------------
    device = 'cpu'
    # Accelerator alternative:
    # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')

    # --- data --------------------------------------------------------------
    input_size = 28 * 28       # pixels per flattened image (784)
    output_size = 10           # number of target classes

    # --- architecture ------------------------------------------------------
    patch_size = 14            # side of each square patch fed to the projector
    hidden_dim = 1280          # channels per token
    num_layers = 32            # number of Mixer blocks
    tokens_dim = 630           # intended hidden width of the token-mixing MLP
    channels_dim = 5120        # intended hidden width of the channel-mixing MLP

    # --- optimisation (presumably consumed by training cells not shown) -----
    batch_size = 128
    lr = 1e-2
    weight_decay = 1e-5
    max_epochs = 100

In [3]:
# Load FashionMNIST. download=True skips the download when the files are
# already under ./data, so a fresh kernel / fresh checkout can Run All.
train_raw = FashionMNIST(root='./data', train=True, download=True)
test_raw = FashionMNIST(root='./data', train=False, download=True)

# Scale pixels to [0, 1] and flatten each image to a (784,) vector.
# NOTE(review): MixerMLP projects with a Conv2d, which works on image-shaped
# (B, 1, 28, 28) input — confirm flat batches are reshaped (or accepted)
# before the forward pass.
X_train = train_raw.data.float().flatten(start_dim=1) / 255
X_test = test_raw.data.float().flatten(start_dim=1) / 255

y_train = train_raw.targets.long()
y_test = test_raw.targets.long()

# In-memory datasets of (features, label) pairs.
train_data = torch.utils.data.TensorDataset(X_train, y_train)
test_data = torch.utils.data.TensorDataset(X_test, y_test)

In [12]:
class MixerBlock(nn.Module):
    """One MLP-Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Activations are laid out as ``(batch, hidden_dim, input_size)``: channels
    on axis 1 and tokens (patches) on axis 2. This is the layout MixerMLP
    produces by flattening its Conv2d patch projection.

    Args:
        input_size: number of tokens (patches) per sample.
        hidden_dim: number of channels per token.
        tokens_dim: hidden width of the token-mixing MLP (acts along tokens).
        channels_dim: hidden width of the channel-mixing MLP (acts along channels).
        args: config object; stored but not read by this block.
    """

    def __init__(self, input_size, hidden_dim, tokens_dim, channels_dim, args):
        super().__init__()
        self.args = args

        # NOTE(review): both LayerNorms normalise jointly over (channels, tokens);
        # the MLP-Mixer paper normalises over channels per token. Left unchanged.

        # Token mixing: x is (B, C, S), so a Linear on the last axis mixes tokens.
        self.layer_norm1 = nn.LayerNorm([hidden_dim, input_size])
        self.linear1in = nn.Linear(input_size, tokens_dim)
        self.linear1out = nn.Linear(tokens_dim, input_size)

        # Channel mixing: applied after moving channels to the last axis.
        self.layer_norm2 = nn.LayerNorm([hidden_dim, input_size])
        self.linear2in = nn.Linear(hidden_dim, channels_dim)
        self.linear2out = nn.Linear(channels_dim, hidden_dim)

        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply token mixing then channel mixing, each with a residual connection.

        Args:
            x: tensor of shape (batch, hidden_dim, input_size).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Token-mixing MLP: operates on the token axis (last axis of (B, C, S)).
        y = self.layer_norm1(x)
        y = self.linear1in(y)
        y = self.gelu(y)
        y = self.linear1out(y)
        x = x + y

        # Channel-mixing MLP: swap to (B, S, C), mix channels, swap back.
        y = self.layer_norm2(x)
        y = y.permute(0, 2, 1)
        y = self.linear2in(y)
        y = self.gelu(y)
        y = self.linear2out(y)
        y = y.permute(0, 2, 1)
        x = x + y

        return x

class MixerMLP(nn.Module):
    """MLP-Mixer classifier for single-channel images.

    Pipeline: non-overlapping patch projection (Conv2d with kernel = stride =
    patch_size), then ``num_layers`` MixerBlocks, a LayerNorm, an average over
    tokens, and a linear head.

    Args:
        input_size: pixels per image (H * W); the token count is
            ``input_size // patch_size**2``.
        output_size: number of output logits (classes).
        patch_size: side length of each square patch.
        hidden_dim: channels per token produced by the projector.
        num_layers: number of MixerBlocks.
        tokens_dim: token-MLP hidden width forwarded to each MixerBlock.
        channels_dim: channel-MLP hidden width forwarded to each MixerBlock.
        args: config object; stored but not read here.
    """

    def __init__(self, input_size, output_size, patch_size, hidden_dim, num_layers, tokens_dim, channels_dim, args):
        super().__init__()
        self.args = args
        self.input_size = input_size

        # One token per non-overlapping patch (assumes H and W are multiples
        # of patch_size).
        num_tokens = input_size // (patch_size * patch_size)

        self.projector = nn.Conv2d(1, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.blocks = nn.ModuleList([
            MixerBlock(num_tokens, hidden_dim, tokens_dim, channels_dim, args)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm([hidden_dim, num_tokens])
        # Global average over the token axis.
        self.pool = nn.AvgPool1d(kernel_size=num_tokens)
        self.head = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: ``(B, 1, H, W)`` images, or ``(B, input_size)`` row-major
               flattened square images, which are reshaped to
               ``(B, 1, side, side)`` first.

        Returns:
            ``(B, output_size)`` logits.

        Raises:
            ValueError: if ``x`` is 2-D but cannot be reshaped into square images.
        """
        if x.dim() == 2:
            x = self._unflatten(x)
        x = self.projector(x)            # (B, hidden_dim, H/p, W/p)
        x = x.flatten(start_dim=2)       # (B, hidden_dim, num_tokens)
        for block in self.blocks:
            x = block(x)
        x = self.layer_norm(x)
        x = self.pool(x).squeeze(2)      # (B, hidden_dim)
        x = self.head(x)

        return x

    def _unflatten(self, x):
        """Reshape flat ``(B, input_size)`` pixels into ``(B, 1, side, side)``."""
        side = int(round(self.input_size ** 0.5))
        if side * side != self.input_size or x.shape[1] != self.input_size:
            raise ValueError(
                f"flat input must have {self.input_size} features forming a square "
                f"image; got shape {tuple(x.shape)}"
            )
        return x.reshape(x.shape[0], 1, side, side)


In [13]:
set_seed(0)
args = Args()

# Blocks inside MixerMLP see one token per patch, not one per pixel
# (784 // 14**2 == 4), so build the standalone block the same way.
num_tokens = args.input_size // (args.patch_size * args.patch_size)

block = MixerBlock(num_tokens, args.hidden_dim, args.tokens_dim, args.channels_dim, args)
model = MixerMLP(args.input_size, args.output_size, args.patch_size, args.hidden_dim, args.num_layers, args.tokens_dim, args.channels_dim, args)

# Smoke test: 7 random single-channel 28x28 images -> 7 logit vectors.
x = torch.rand(7, 1, 28, 28)
y = model(x)
assert y.shape == (7, args.output_size), y.shape
assert block(torch.rand(7, args.hidden_dim, num_tokens)).shape == (7, args.hidden_dim, num_tokens)

torch.Size([7, 1280, 4])
torch.Size([7, 1280])
