# MLP-Mixer: Key Innovations and Controversies

## Main Innovations

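### 1. **No Convolution, No Attention**
- One of the first proposals to build a vision model purely from MLPs, with neither the convolutions of a traditional CNN nor the self-attention of a Transformer.
- Core idea: two kinds of MLP layers perform **Channel Mixing** and **Token Mixing**. Channel mixing models the relationships among the feature channels of each token, while token mixing models the interactions among spatial positions (patches).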
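To make the two mixing directions concrete, here is a minimal PyTorch sketch. In the real model each mixer is a two-layer MLP with GELU; here each is reduced to a single random matrix (an assumption made only to keep the shapes visible), and the sizes correspond to a 32×32 image cut into 4×4 patches.

```python
import torch

torch.manual_seed(0)
batch, num_patches, dim = 2, 64, 128      # 32x32 image, 4x4 patches -> 64 tokens of 128 channels
x = torch.randn(batch, num_patches, dim)  # one row per patch (token), one column per channel

# Token mixing: a (num_patches x num_patches) matrix multiplies from the left,
# so each output token is a weighted sum over all input tokens, channel by channel.
w_token = torch.randn(num_patches, num_patches) / num_patches**0.5
token_mixed = torch.einsum("pq,bqc->bpc", w_token, x)

# Channel mixing: a (dim x dim) matrix multiplies from the right,
# so each token's channel vector is transformed on its own.
w_channel = torch.randn(dim, dim) / dim**0.5
channel_mixed = torch.einsum("bpc,cd->bpd", x, w_channel)

# Token mixing is usually implemented by transposing and applying an ordinary
# linear map along the last axis; the result is the same.
via_transpose = (x.transpose(1, 2) @ w_token.T).transpose(1, 2)
print(token_mixed.shape, channel_mixed.shape)  # torch.Size([2, 64, 128]) torch.Size([2, 64, 128])
print(torch.allclose(token_mixed, via_transpose, atol=1e-4))  # True
```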

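### 2. **Efficient Mixing Operations**
- **Token Mixing MLP**: operates across the sequence of image patches and exchanges information between spatial positions, playing the spatial-modeling role of convolution or attention. It can be viewed as a global convolution: a depth-wise convolution whose kernel covers the entire patch grid and is shared across all channels.
- **Channel Mixing MLP**: applies the same nonlinear transformation to the channel vector at every spatial position, mixing information across channels. It is equivalent to 1×1 convolutions. Its role is somewhat like channel attention, except that its weights are fixed rather than computed from the input.
- Residual connections, layer normalization and dropout are added to improve training stability and expressiveness.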
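Both equivalences can be checked numerically. The sketch below looks at one linear layer of each MLP (the second layer behaves the same way); the feature-map sizes and the hidden width of 32 are arbitrary choices for illustration. It copies a `Linear` layer's weights into a 1×1 `Conv2d` for channel mixing, and into a full-size, single-input-channel kernel reused for every channel for token mixing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 16, 8, 8      # 16 channels on an 8x8 grid of patches -> S = 64 tokens
S, hidden = H * W, 32
x = torch.randn(B, C, H, W)   # the token sequence viewed as a feature map

# Channel mixing == 1x1 convolution
channel_fc = nn.Linear(C, hidden)                   # acts on each token's C-dim vector
out_fc = channel_fc(x.flatten(2).transpose(1, 2))   # (B, S, hidden)

conv1x1 = nn.Conv2d(C, hidden, kernel_size=1)
with torch.no_grad():                               # copy the Linear weights into the conv
    conv1x1.weight.copy_(channel_fc.weight.view(hidden, C, 1, 1))
    conv1x1.bias.copy_(channel_fc.bias)
out_conv = conv1x1(x).flatten(2).transpose(1, 2)    # (B, S, hidden)
print(torch.allclose(out_fc, out_conv, atol=1e-4))  # True

# Token mixing == depth-wise conv whose kernel spans the whole grid, shared by all channels
token_fc = nn.Linear(S, hidden, bias=False)         # acts on each channel's S-dim vector
out_tok = token_fc(x.flatten(2))                    # (B, C, hidden)

# Treat every channel as its own 1-channel image and convolve it with `hidden`
# kernels of size H x W; the same kernels are reused for every channel.
kernels = token_fc.weight.view(hidden, 1, H, W)
out_global = F.conv2d(x.reshape(B * C, 1, H, W), kernels).view(B, C, hidden)
print(torch.allclose(out_tok, out_global, atol=1e-4))  # True
```

Because the kernel is as large as the input, each "convolution" produces a single output position per kernel, which is why this is best described as a global rather than a local convolution.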

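### 3. **Reduced Architectural Complexity**
- Relies only on basic operations such as matrix multiplication (plus reshapes/transposes, layer normalization and GELU), which lowers the barrier to hardware support and, in principle, makes deployment in lightweight settings easier.
- With large-scale pre-training, it reaches ImageNet accuracy competitive with state-of-the-art CNNs and Vision Transformers, showing that a pure-MLP architecture is viable.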
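Putting the pieces together, the sketch below is a minimal MLP-Mixer classifier in PyTorch built only from reshapes, `Linear` layers, LayerNorm, GELU, dropout and average pooling. It follows the paper's overall layout (per-patch linear embedding, pre-LayerNorm residual token-mixing and channel-mixing MLPs, pre-head LayerNorm, global average pooling, linear head). The names `MixerBlock` and `TinyMixer`, the hidden widths (0.5×num_patches for token MLPs, 4×dim for channel MLPs) and the defaults are assumptions; this is not the `hdd` implementation used in the training cell.

```python
import torch
import torch.nn as nn


class MixerBlock(nn.Module):
    """One Mixer layer: token mixing then channel mixing, each with pre-LayerNorm and a residual."""

    def __init__(self, num_patches, dim, token_hidden, channel_hidden, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(    # mixes along the patch axis
            nn.Linear(num_patches, token_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(token_hidden, num_patches), nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(  # mixes along the channel axis
            nn.Linear(dim, channel_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(channel_hidden, dim), nn.Dropout(dropout),
        )

    def forward(self, x):  # x: (batch, num_patches, dim)
        # Transpose so the token MLP sees each channel's values across all patches.
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.channel_mlp(self.norm2(x))
        return x


class TinyMixer(nn.Module):
    def __init__(self, image_size=32, channels=3, patch_size=4, dim=128, depth=8,
                 num_classes=10, token_factor=0.5, channel_factor=4, dropout=0.0):
        super().__init__()
        assert image_size % patch_size == 0, "image must split evenly into patches"
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        self.embed = nn.Linear(channels * patch_size**2, dim)  # per-patch linear embedding
        self.blocks = nn.Sequential(*[
            MixerBlock(num_patches, dim, int(num_patches * token_factor),
                       dim * channel_factor, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def patchify(self, img):  # (B, C, H, W) -> (B, num_patches, C * p * p)
        b, c, h, w = img.shape
        p = self.patch_size
        img = img.reshape(b, c, h // p, p, w // p, p)
        return img.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // p) * (w // p), c * p * p)

    def forward(self, img):
        x = self.blocks(self.embed(self.patchify(img)))
        return self.head(self.norm(x).mean(dim=1))  # global average pooling over patches


model = TinyMixer()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # 1099146
print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
```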

---

![MLP-Mixer architecture](resources/mlp_mixer.png "MLP-Mixer architecture")

## Controversies and Limitations

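### 1. **Limitations of Static Feature Processing**
- MLP weights are fixed after training and do not depend on the input, so the model is weaker at modeling input-dependent relationships than attention, which computes its weights dynamically from each input. This may limit adaptability in complex scenes.
- Token mixing is global, so long-range interactions are possible, but they follow a fixed pattern: unlike Transformer global attention, the model cannot flexibly decide from the content which distant positions should interact.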
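The difference between fixed and input-dependent mixing can be seen directly. In this minimal sketch the token-mixing MLP is reduced to one linear layer with weight `W` (the real MLP adds a hidden layer and GELU, but its weights are likewise fixed after training), and it is compared against a hypothetical single-head self-attention with random projections `to_q` and `to_k`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens, dim = 6, 8
token_mix = nn.Linear(num_tokens, num_tokens, bias=False)  # token mixing: one learned matrix W
to_q, to_k = nn.Linear(dim, dim), nn.Linear(dim, dim)       # single-head attention projections


def attention_matrix(x):
    # (tokens x tokens) weights softmax(Q K^T / sqrt(d)), computed from the input itself
    q, k = to_q(x), to_k(x)
    return torch.softmax(q @ k.T / dim**0.5, dim=-1)


x1, x2 = torch.randn(num_tokens, dim), torch.randn(num_tokens, dim)
with torch.no_grad():
    W = token_mix.weight
    # Token mixing applies the same W to every input: output = W @ x.
    same_w = all(torch.allclose(token_mix(x.T).T, W @ x, atol=1e-5) for x in (x1, x2))
    # Attention builds a different mixing matrix for each input.
    a1, a2 = attention_matrix(x1), attention_matrix(x2)

print(same_w)                  # True
print(torch.allclose(a1, a2))  # False
```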

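### 2. **Training Difficulty and Resource Cost**
- The models have large parameter counts and compute costs, and training needs large amounts of data and compute; small-scale experiments often fail to reach good accuracy.
- The models are sensitive to hyperparameters and hard to tune, which has drawn the criticism that they are "hard to reproduce without a gambler's mindset".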
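Where the parameters go can be checked with plain arithmetic. The helper `mixer_param_count` below is hypothetical; it assumes biases in every `Linear` layer, affine LayerNorms, a token-MLP hidden width of 0.5×num_patches and a channel-MLP hidden width of 4×dim. Under those assumptions the CIFAR-10 configuration used in the training cell (32×32 images, 4×4 patches, dim 128, depth 8) comes to 1,099,146 parameters, the same number that cell prints as `#Parameter`, and the channel-mixing MLPs account for about 96% of it. The second call shows that the token-mixing part grows roughly with the square of the number of patches.

```python
def mixer_param_count(image_size, patch_size, channels, dim, depth, num_classes,
                      token_factor=0.5, channel_factor=4):
    """Trainable parameters of an MLP-Mixer, split by component.

    Assumes biases in every Linear layer, affine LayerNorms, and hidden widths of
    token_factor * num_patches (token MLP) and channel_factor * dim (channel MLP).
    """
    num_patches = (image_size // patch_size) ** 2
    token_hidden = int(num_patches * token_factor)
    channel_hidden = dim * channel_factor

    def linear(n_in, n_out):  # weight matrix + bias vector
        return n_in * n_out + n_out

    layer_norm = 2 * dim      # per-channel scale and shift
    token_mlp = linear(num_patches, token_hidden) + linear(token_hidden, num_patches)
    channel_mlp = linear(dim, channel_hidden) + linear(channel_hidden, dim)
    parts = {
        "patch_embedding": linear(channels * patch_size**2, dim),
        "token_mixing": depth * (layer_norm + token_mlp),
        "channel_mixing": depth * (layer_norm + channel_mlp),
        "head": layer_norm + linear(dim, num_classes),  # pre-head LayerNorm + classifier
    }
    parts["total"] = sum(parts.values())
    return parts


cifar = mixer_param_count(32, 4, 3, dim=128, depth=8, num_classes=10)
print(cifar)
# {'patch_embedding': 6272, 'token_mixing': 35584, 'channel_mixing': 1055744, 'head': 1546, 'total': 1099146}

# Doubling the image side (64x64) quadruples the patch count (64 -> 256 tokens).
print(mixer_param_count(64, 4, 3, dim=128, depth=8, num_classes=10)["token_mixing"])  # 529408
```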

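### 3. **Questionable Generalization**
- It performs well on large datasets such as ImageNet, but its performance drops noticeably on small datasets and on more complex tasks such as object detection and segmentation, which limits its generality. Part of the reason is that the token-mixing weights have a size tied to the number of patches, so the model expects a fixed input resolution, whereas dense prediction tasks typically use larger, variable-size inputs.
- Some argue that its success comes from overfitting to the training data distribution rather than from a genuine breakthrough past the limits of network architecture.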
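The fixed-resolution constraint is easy to reproduce. In this sketch (sizes match the CIFAR-10 setup; `token_mixing_accepts` is a hypothetical helper), a token-mixing layer built for 64 tokens rejects the 256 tokens produced by a 64×64 image, whereas a convolution accepts any spatial size. Inputs at another resolution therefore have to be resized to the training resolution, or the token-mixing layers rebuilt and retrained.

```python
import torch
import torch.nn as nn

patch_size, dim = 4, 128
train_tokens = (32 // patch_size) ** 2               # 64 tokens for 32x32 training images
token_mlp = nn.Linear(train_tokens, train_tokens)    # weight shape is fixed at (64, 64)
conv = nn.Conv2d(3, dim, kernel_size=3, padding=1)   # a conv layer, for contrast


def token_mixing_accepts(image_size):
    """Return True if the token-mixing layer can process an image of this size."""
    num_tokens = (image_size // patch_size) ** 2
    x = torch.randn(1, num_tokens, dim)
    try:
        token_mlp(x.transpose(1, 2))                 # mix along the token axis
        return True
    except RuntimeError:                             # shape mismatch in the matmul
        return False


print(token_mixing_accepts(32), token_mixing_accepts(64))  # True False
print(conv(torch.randn(1, 3, 64, 64)).shape)               # torch.Size([1, 128, 64, 64])
```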

---

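## Summary and Significance
MLP-Mixer challenges the CNN and Transformer paradigms for vision models with a pure-MLP architecture and demonstrates how much can be achieved by combining a basic operation (matrix multiplication). Despite the controversy, it opened new directions for model design, for example later work on hybrid architectures that combine MLPs with attention.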

```python
# Automatically reload external modules so code changes take effect without re-importing
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

from hdd.device.utils import get_device

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Path to the training data
DATA_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for TensorBoard logs
TENSORBOARD_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for pretrained model weights
TORCH_HUB_PATH = "~/workspace/hands-dirty-on-dl/pretrained_models"
torch.hub.set_dir(TORCH_HUB_PATH)
# Pick the most suitable training device
DEVICE = get_device(["cuda", "cpu"])
print("Use device: ", DEVICE)
```

```text
Use device:  cuda
```


```python
from hdd.data_util.auto_augmentation import CIFAR10Policy

# Training hyperparameters and data augmentation are taken from https://github.com/omihub777/ViT-CIFAR
CIFAR_10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_10_STD = [0.2470, 0.2435, 0.2616]
BATCH_SIZE = 256

val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

val_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=False, download=True, transform=val_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

train_transform = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

train_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=True, download=True, transform=train_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
```

```text
Files already downloaded and verified
Files already downloaded and verified
```
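The model cell uses `patch_size=4` on 32×32 CIFAR-10 images, so each image becomes (32/4)² = 64 tokens of 3·4·4 = 48 values. Assuming `MLPMixer` embeds patches with a per-patch linear layer as the paper describes (its internals are not shown here), this sketch performs the patch extraction on a random stand-in batch and checks that the per-patch linear embedding computes exactly what a `Conv2d` with `kernel_size = stride = patch_size` computes, a nuance worth keeping in mind when reading the "no convolution" claim.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch, dim = 4, 128
images = torch.randn(8, 3, 32, 32)  # stand-in for a normalized CIFAR-10 batch

# Cut each image into non-overlapping 4x4 patches and flatten each patch in (c, ph, pw) order.
b, c, h, w = images.shape
patches = images.reshape(b, c, h // patch, patch, w // patch, patch)
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // patch) * (w // patch), c * patch * patch)
print(patches.shape)  # torch.Size([8, 64, 48])

embed = nn.Linear(c * patch * patch, dim)  # per-patch linear embedding
tokens = embed(patches)                    # (8, 64, 128)

# The same weights used as a convolution with kernel_size = stride = patch size.
conv = nn.Conv2d(c, dim, kernel_size=patch, stride=patch)
with torch.no_grad():
    conv.weight.copy_(embed.weight.view(dim, c, patch, patch))
    conv.bias.copy_(embed.bias)
tokens_conv = conv(images).flatten(2).transpose(1, 2)  # (8, 64, 128)
print(torch.allclose(tokens, tokens_conv, atol=1e-4))  # True
```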


```python
from hdd.train.warmup_scheduler import GradualWarmupScheduler
from hdd.models.cnn.mlp_mixer import MLPMixer
from hdd.train.classification_utils import naive_train_classification_model
from hdd.models.nn_utils import count_trainable_parameter

max_epochs = 800
# 32x32 images split into 4x4 patches (64 tokens), 8 Mixer layers of width 128
net = MLPMixer(
    image_size=32,
    channels=3,
    patch_size=4,
    dim=128,
    depth=8,
    num_classes=10,
).to(DEVICE)
print(f"#Parameter: {count_trainable_parameter(net)}")
criteria = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=1e-3,
    betas=(0.9, 0.99),
)

# Cosine annealing down to eta_min, preceded by a 5-epoch warmup
base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, max_epochs, eta_min=1e-6
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=1.0,
    total_epoch=5,
    after_scheduler=base_scheduler,
)
patch_4 = naive_train_classification_model(
    net,
    criteria,
    max_epochs,
    train_dataloader,
    val_dataloader,
    DEVICE,
    optimizer,
    scheduler,
    verbose=True,
)
```

#Parameter: 1099146
Epoch: 1/800 Train Loss: 2.3554 Accuracy: 0.1019 Time: 2.20599  | Val Loss: 2.3622 Accuracy: 0.1092
Epoch: 2/800 Train Loss: 2.1206 Accuracy: 0.2278 Time: 1.98097  | Val Loss: 1.8754 Accuracy: 0.3405
Epoch: 3/800 Train Loss: 1.9304 Accuracy: 0.3281 Time: 1.87739  | Val Loss: 1.6786 Accuracy: 0.4550
Epoch: 4/800 Train Loss: 1.7973 Accuracy: 0.3991 Time: 1.97301  | Val Loss: 1.5359 Accuracy: 0.5199
Epoch: 5/800 Train Loss: 1.7024 Accuracy: 0.4444 Time: 1.95160  | Val Loss: 1.4563 Accuracy: 0.5715
Epoch: 6/800 Train Loss: 1.6303 Accuracy: 0.4820 Time: 1.85779  | Val Loss: 1.4071 Accuracy: 0.5863
Epoch: 7/800 Train Loss: 1.5478 Accuracy: 0.5205 Time: 1.91336  | Val Loss: 1.2997 Accuracy: 0.6432
Epoch: 8/800 Train Loss: 1.4955 Accuracy: 0.5473 Time: 1.95487  | Val Loss: 1.2688 Accuracy: 0.6539
Epoch: 9/800 Train Loss: 1.4409 Accuracy: 0.5756 Time: 1.88831  | Val Loss: 1.2391 Accuracy: 0.6705
Epoch: 10/800 Train Loss: 1.4052 Accuracy: 0.5885 Time: 1.92552  | Val Loss: 1.1