# MLP vs SwiGLU

NOTE: I did not do any tuning or early stopping.

In [1]:
%%bash
pip install torchmetrics
pip install torchview



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchmetrics import Accuracy
from tqdm.auto import tqdm, trange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
])

# Define datasets
datasets_kwargs = {'download': True, 'transform': transform}
train_dataset = datasets.MNIST('data/', train=True, **datasets_kwargs)
test_dataset = datasets.MNIST('data/', train=False, **datasets_kwargs)

# Define dataloaders
dataloader_kwargs = {'batch_size': 256, 'shuffle': True}
train_loader = torch.utils.data.DataLoader(train_dataset, **dataloader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **dataloader_kwargs)

In [3]:
# Specify models

### Convolutional module
class ConvModule(nn.Module):
    def __init__(self):
        super(ConvModule, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = self.pool(x)
        return x

### MLP module
class MLPModule(nn.Module):
    def __init__(self, hidden_size=256):
        super(MLPModule, self).__init__()
        self.fc1 = nn.LazyLinear(hidden_size)
        self.fc2 = nn.LazyLinear(hidden_size)
        self.output_layer = nn.LazyLinear(10)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = self.output_layer(x)
        return F.softmax(x, dim=1)

### SwiGLU module
class SwiGLULayer(nn.Module):
    def __init__(self, hidden_size, bias=True):
        super(SwiGLULayer, self).__init__()
        self.gater = nn.LazyLinear(hidden_size, bias=bias)
        self.reprojector = nn.LazyLinear(hidden_size, bias=bias)

    def forward(self, x):
        xw = self.gater(x)
        xv = self.reprojector(x)
        x_swish = xw * torch.sigmoid(xw)
        return x_swish * xv

class SwiGLUModule(nn.Module):
    def __init__(self, hidden_size=256, num_layers=2):
        super(SwiGLUModule, self).__init__()
        self.layers = nn.ModuleList([SwiGLULayer(hidden_size) for _ in range(num_layers)])
        self.output_layer = nn.LazyLinear(10)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return F.softmax(x, dim=1)


In [4]:
# Build models
class MLPBasedModel(nn.Module):
    def __init__(self):
        super(MLPBasedModel, self).__init__()
        self.conv = ConvModule()
        self.mlp = MLPModule()

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.mlp(x)
        return x


class SwiGLUBasedModel(nn.Module):
    def __init__(self):
        super(SwiGLUBasedModel, self).__init__()
        self.conv = ConvModule()
        self.swiglu = SwiGLUModule()

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.swiglu(x)
        return x

## Train Models

In [5]:
def evaluate_model(model, name):
    with torch.inference_mode():
        accuracy = Accuracy(task='multiclass', num_classes=10)
        for x_test, y_test in tqdm(test_loader, desc=f"Evaluating {name} model", total=len(test_loader)):
            # Move data to device
            x_test, y_test = x_test.to(device), y_test.to(device)

            # Forward pass
            y_hat = model.forward(x_test)
            accuracy.update(y_hat.cpu(), y_test.cpu())

        # Print accuracy
        computed_accuracy = accuracy.compute()
        print(f"Accuracy: {computed_accuracy:.2%}")
        accuracy.reset()
    return computed_accuracy

### MLP Based CNN

In [6]:
mlp_based_model = MLPBasedModel().to(device)
mlp_optimizer = torch.optim.Adam(mlp_based_model.parameters(), lr=1e-4)
mlp_criterion = nn.CrossEntropyLoss()

# Training loop
hist = []
n_epochs = 10
best_epoch = 0
best_loss = float('inf')
iterator = trange(n_epochs, desc="Training MLP-based model")
for epoch in iterator:
    loss_i = 0.0
    for i, (x_train, y_train) in tqdm(enumerate(train_loader, start=0), total=len(train_loader)):
        # Move data to device
        x_train, y_train = x_train.to(device), y_train.to(device)

        # Forward pass
        mlp_optimizer.zero_grad()
        y_hat = mlp_based_model.forward(x_train)
        loss = mlp_criterion(y_hat, y_train)

        # Backward pass
        loss.backward()
        mlp_optimizer.step()

        # Update loss
        loss_i += loss.item()
    iterator.set_postfix({'loss': loss_i, 'test_acc': evaluate_model(mlp_based_model, 'MLP-based (training)')})



Training MLP-based model:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 91.55%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 93.29%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 95.04%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 95.91%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 96.51%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.29%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.70%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.99%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 98.10%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating MLP-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 98.13%


In [7]:
mlp_accuracy = evaluate_model(mlp_based_model, "MLP-based")

Evaluating MLP-based model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 98.13%


### SwiGLU based CNN

In [8]:
swish_based_model = SwiGLUBasedModel().to(device)
swish_optimizer = torch.optim.Adam(swish_based_model.parameters(), lr=1e-4)
swish_criterion = nn.CrossEntropyLoss()

# Training loop
hist = []
n_epochs = 10
best_epoch = 0
best_loss = float('inf')
iterator = trange(n_epochs, desc="Training SwiGLU-based model")
for epoch in iterator:
    loss_i = 0.0
    for i, (x_train, y_train) in tqdm(enumerate(train_loader, start=0), total=len(train_loader)):
        # Move data to device
        x_train, y_train = x_train.to(device), y_train.to(device)

        # Forward pass
        swish_optimizer.zero_grad()
        y_hat = swish_based_model.forward(x_train)
        loss = swish_criterion(y_hat, y_train)

        # Backward pass
        loss.backward()
        swish_optimizer.step()

        # Update loss
        loss_i += loss.item()
    iterator.set_postfix({'loss': loss_i, 'test_acc': evaluate_model(swish_based_model, 'SwiGLU-based (training)')})


Training SwiGLU-based model:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 83.34%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 85.67%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 86.09%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 87.48%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 95.39%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 96.85%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.20%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.31%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.62%


  0%|          | 0/235 [00:00<?, ?it/s]

Evaluating SwiGLU-based (training) model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.50%


In [9]:
swish_accuracy = evaluate_model(swish_based_model, "SwiGLU-based")

Evaluating SwiGLU-based model:   0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 97.50%


In [10]:
print(f'Complexity ratio = {count_parameters(swish_based_model) / count_parameters(mlp_based_model):.1f}:1')
print(f'    SwiGLU based model parameters = {count_parameters(swish_based_model):,}')
print(f'    MLP based model parameters    = {count_parameters(mlp_based_model):,}')
print()
print(f'SwiGLU based model Accuracy = {swish_accuracy:.2%}')
print(f'MLP based model Accuracy    = {mlp_accuracy:.2%}')

Complexity ratio = 2.0:1
    SwiGLU based model parameters = 6,576,010
    MLP based model parameters    = 3,298,698

SwiGLU based model Accuracy = 97.50%
MLP based model Accuracy    = 98.13%
