# Swin Transformer Summary

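## 1. Background and the Problems It Addresses
Swin Transformer builds on the Vision Transformer (ViT). ViT splits an image into fixed-size patches and applies global self-attention across all of them. It performs well on classification but has two core limitations:
- **High computational cost**: the cost of global self-attention grows quadratically with the number of patches, and therefore with image area, which makes high-resolution images hard to handle.
- **No hierarchical structure**: ViT keeps a single-resolution feature map and cannot build a multi-scale feature hierarchy the way a CNN does, which limits its ability to perceive objects at different scales.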

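Swin Transformer addresses both problems with a **shifted window mechanism** and a **hierarchical feature map design**, while retaining the Transformer's strength in modeling global context.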
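
To make the cost difference concrete, the sketch below evaluates the per-layer cost formulas given in the Swin Transformer paper for global multi-head self-attention (MSA) and window-based attention (W-MSA) on an `h`×`w` token grid with `C` channels and window size `M`. The function names are illustrative. The example numbers assume the first stage of the model trained later in this notebook: a 32×32 image with 2×2 patches gives a 16×16 grid, with `embed_dim=54` and `window_size=2`.

```python
def msa_flops(h, w, C):
    """Global self-attention: projections (4*h*w*C^2) plus attention (2*(h*w)^2*C)."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window attention: same projections, but attention is restricted to M*M tokens."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


# Assumed first-stage geometry of the CIFAR-10 model in this notebook.
print(msa_flops(16, 16, 54))      # 10063872
print(wmsa_flops(16, 16, 54, 2))  # 3096576
```

Doubling the grid side multiplies the windowed attention term by 4 but the global attention term by 16, which is why the gap widens quickly at higher resolutions.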

---

## 2. Innovations and Impact

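### Key Innovations
- **Shifted Window Attention**  
  The feature map is partitioned into non-overlapping local windows (e.g. 7×7 tokens) and self-attention is computed only within each window. For a fixed window size this reduces the cost from $ O(n^2) $ to $ O(n) $ in the number of tokens $ n $. Consecutive layers **alternate between the regular and a shifted window partition** so that information flows across window boundaries, which preserves locality while widening the model's view of the image.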

  ![Shifted window attention in Swin Transformer](resources/swin_transformer_sliding_window.png)

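- **Hierarchical Feature Maps**  
  Neighboring patches are merged step by step (e.g. each 2×2 group becomes one patch), producing feature maps at several resolutions. This supports tasks that need multi-scale features, such as object detection and segmentation.

- **Relative Position Bias**  
  A learnable relative position bias is added inside each window. It compensates for the absence of absolute position information and improves the model's handling of spatial relationships.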
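
The three mechanisms above can be sketched as plain tensor operations. This is a minimal illustration written for this summary, not the `hdd` implementation: the function names are hypothetical, tensors are assumed to be laid out as `(B, H, W, C)`, and the attention mask that a full shifted-window layer applies to tokens wrapped around by the cyclic shift is omitted.

```python
import torch


def window_partition(x, window_size):
    """Split (B, H, W, C) into non-overlapping windows: (B * num_windows, M, M, C)."""
    B, H, W, C = x.shape
    M = window_size
    x = x.reshape(B, H // M, M, W // M, M, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M, M, C)


def cyclic_shift(x, window_size):
    """Roll the map by half a window so the next partition straddles old window borders."""
    s = window_size // 2
    return torch.roll(x, shifts=(-s, -s), dims=(1, 2))


def patch_merge(x):
    """Concatenate every 2x2 neighborhood: (B, H, W, C) -> (B, H/2, W/2, 4C).

    A full patch-merging layer would follow this with a linear projection (omitted here).
    """
    x0 = x[:, 0::2, 0::2, :]
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    return torch.cat([x0, x1, x2, x3], dim=-1)


def relative_position_index(window_size):
    """Map each (query, key) token pair in a window to a row of the learnable bias table.

    The table has (2M - 1)^2 rows, one per possible 2D offset between two tokens.
    """
    M = window_size
    coords = torch.stack(
        torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij")
    ).flatten(1)                                        # (2, M*M)
    rel = coords[:, :, None] - coords[:, None, :]       # (2, M*M, M*M), offsets in [-(M-1), M-1]
    rel = rel.permute(1, 2, 0).contiguous()             # (M*M, M*M, 2)
    rel[..., 0] += M - 1                                # shift row offsets to start at 0
    rel[..., 1] += M - 1                                # shift column offsets to start at 0
    rel[..., 0] *= 2 * M - 1                            # row-major flattening of the 2D offset
    return rel.sum(-1)                                  # (M*M, M*M)


x = torch.arange(16.0).reshape(1, 4, 4, 1)              # a 4x4 map with values 0..15
regular = window_partition(x, 2)
shifted = window_partition(cyclic_shift(x, 2), 2)
print(regular[0].flatten().tolist())                    # [0.0, 1.0, 4.0, 5.0]
print(shifted[0].flatten().tolist())                    # [5.0, 6.0, 9.0, 10.0]
```

In the toy example the first shifted window holds tokens 5, 6, 9, and 10, one from each of the four regular windows, which is how the shift lets information cross window borders.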

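### Impact and Significance
- **General-purpose vision backbone**  
  Swin Transformer was among the first Transformer backbones to match or surpass strong CNNs across classification, detection, and segmentation, which helped popularize Transformers in computer vision.

- **Balance of accuracy and efficiency**  
  It achieved competitive or state-of-the-art results on benchmarks such as ImageNet, COCO, and ADE20K at the time of publication, and its linear complexity lets it process high-resolution images efficiently. It has been widely adopted in areas such as super-resolution and video analysis.

- **Influence on later research**  
  Its hybrid local-global modeling influenced later models such as ConvNeXt and Focal Transformer, and encouraged research on combining CNNs and Transformers.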

  ![Swin Transformer](resources/swin_transformer.png)


In [1]:
# Automatically reload external modules so code changes take effect without re-importing
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

from hdd.device.utils import get_device

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Path to the training data
DATA_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for TensorBoard logs
TENSORBOARD_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for pretrained model weights
TORCH_HUB_PATH = "~/workspace/hands-dirty-on-dl/pretrained_models"
torch.hub.set_dir(TORCH_HUB_PATH)
# Pick the most suitable training device
DEVICE = get_device(["cuda", "cpu"])
max_epochs = 300
print("Use device: ", DEVICE)

Use device:  cuda


In [2]:
from hdd.data_util.auto_augmentation import CIFAR10Policy

# Training hyperparameters and data augmentation follow https://github.com/omihub777/ViT-CIFAR
CIFAR_10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_10_STD = [0.2470, 0.2435, 0.2616]
BATCH_SIZE = 128

val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

val_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=False, download=True, transform=val_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

train_transform = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

train_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=True, download=True, transform=train_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

Files already downloaded and verified
Files already downloaded and verified
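
As a sanity check on the model configured in the next cell, the parameter count can be tallied by hand. The sketch assumes the `hdd` `SwinTransformer` follows the layer layout of the reference Swin implementation: a strided-convolution patch embedding followed by LayerNorm, channels doubling at each patch-merging step, a bias-free `4C -> 2C` merging projection preceded by LayerNorm, and one relative position bias table per block. The function name `swin_param_count` is hypothetical.

```python
def swin_param_count(embed_dim=54, depths=(6, 6, 6), num_heads=(6, 6, 6),
                     window_size=2, mlp_ratio=4, patch_size=2,
                     in_chans=3, num_classes=10):
    """Tally trainable parameters under the layout assumptions stated above."""
    # Patch embedding: conv weight + bias, then LayerNorm (weight + bias).
    total = in_chans * patch_size**2 * embed_dim + embed_dim
    total += 2 * embed_dim

    dim = embed_dim
    for stage, (depth, heads) in enumerate(zip(depths, num_heads)):
        attn = (3 * dim * dim + 3 * dim) + (dim * dim + dim)       # qkv + output projection
        bias_table = (2 * window_size - 1) ** 2 * heads            # relative position bias
        hidden = mlp_ratio * dim
        mlp = (dim * hidden + hidden) + (hidden * dim + dim)       # two linear layers
        norms = 2 * (2 * dim)                                      # two LayerNorms per block
        total += depth * (attn + bias_table + mlp + norms)

        if stage < len(depths) - 1:
            # Patch merging: LayerNorm(4C) + Linear(4C -> 2C, no bias); channels double.
            total += 2 * (4 * dim) + (4 * dim) * (2 * dim)
            dim *= 2

    total += 2 * dim                              # final LayerNorm
    total += dim * num_classes + num_classes      # classification head
    return total


print(swin_param_count())  # 4560796
```

The result equals the `#Parameter` value printed by the training cell, which suggests the layout assumptions hold for this model.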

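The training cell below combines a 10-epoch warmup with cosine annealing. The sketch shows the learning rate this would produce per epoch, assuming `GradualWarmupScheduler` with `multiplier=1.0` ramps linearly from 0 to the base rate and then hands over to `CosineAnnealingLR` with its epoch counter restarted at 0. The `hdd` scheduler itself is not shown in this notebook, so this behavior is an assumption, and `lr_at_epoch` is a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, eta_min=1e-5, warmup_epochs=10, t_max=300):
    """Learning rate during `epoch` (0-based), under the assumed warmup behavior."""
    if epoch < warmup_epochs:
        # Assumed linear warmup from 0 up to base_lr.
        return base_lr * epoch / warmup_epochs
    # Closed form of cosine annealing, counted from the end of warmup.
    t = epoch - warmup_epochs
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))


for epoch in (0, 5, 10, 150, 299):
    print(epoch, f"{lr_at_epoch(epoch):.2e}")
```

Under these assumptions the first epoch runs with a learning rate of 0, which would be consistent with the chance-level accuracy (about 0.10) logged for epoch 1. Because the cosine period is `max_epochs` (300) while warmup already uses 10 epochs, the rate would also end slightly above `eta_min`.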

In [None]:
from hdd.train.warmup_scheduler import GradualWarmupScheduler
from hdd.models.transformer.swin_transformer import SwinTransformer
from hdd.train.classification_utils import naive_train_classification_model
from hdd.models.nn_utils import count_trainable_parameter


net = SwinTransformer(
    img_size=32,
    patch_size=2,
    in_chans=3,
    num_classes=10,
    embed_dim=54,
    depths=[6, 6, 6],
    num_heads=[6, 6, 6],
    dropout=0.0,
    window_size=2,
    mlp_ratio=4,
).to(DEVICE)
print(f"#Parameter: {count_trainable_parameter(net)}")
criteria = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(
    net.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-5
)

base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, max_epochs, eta_min=1e-5
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=1.0,
    total_epoch=10,
    after_scheduler=base_scheduler,
)
random_setting = naive_train_classification_model(
    net,
    criteria,
    max_epochs,
    train_dataloader,
    val_dataloader,
    DEVICE,
    optimizer,
    scheduler,
    verbose=True,
)

#Parameter: 4560796
Epoch: 1/300 Train Loss: 2.3166 Accuracy: 0.1005 Time: 14.07742  | Val Loss: 2.3194 Accuracy: 0.0991
Epoch: 2/300 Train Loss: 2.1354 Accuracy: 0.2207 Time: 14.26677  | Val Loss: 1.8890 Accuracy: 0.3450
Epoch: 3/300 Train Loss: 1.9731 Accuracy: 0.3015 Time: 14.30375  | Val Loss: 1.7422 Accuracy: 0.4166
Epoch: 4/300 Train Loss: 1.8482 Accuracy: 0.3719 Time: 14.00262  | Val Loss: 1.6566 Accuracy: 0.4588
Epoch: 5/300 Train Loss: 1.7520 Accuracy: 0.4175 Time: 14.29534  | Val Loss: 1.5586 Accuracy: 0.5030
Epoch: 6/300 Train Loss: 1.6723 Accuracy: 0.4590 Time: 13.70866  | Val Loss: 1.4725 Accuracy: 0.5522
Epoch: 7/300 Train Loss: 1.6178 Accuracy: 0.4876 Time: 14.03620  | Val Loss: 1.4252 Accuracy: 0.5828
Epoch: 8/300 Train Loss: 1.5684 Accuracy: 0.5099 Time: 13.70768  | Val Loss: 1.3886 Accuracy: 0.5904
Epoch: 9/300 Train Loss: 1.5381 Accuracy: 0.5255 Time: 14.26017  | Val Loss: 1.3929 Accuracy: 0.6039
Epoch: 10/300 Train Loss: 1.5166 Accuracy: 0.5395 Time: 14.16402  | Val