# CSWin Transformer: Summary of Innovations and Impact

## Key Innovations

  ![CSWin Transformer overall architecture](resources/cswin_arch.png "CSWin Transformer overall architecture")


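### 1. **Cross-Shaped Window Self-Attention**
- Core idea: the attention heads are split into two groups that compute self-attention in parallel, one within horizontal stripes and one within vertical stripes. Together they form a cross-shaped window, which greatly enlarges the region each token attends to while keeping the cost far below that of global attention.
- Advantages: it is more efficient than full global attention, and compared with local (fixed-window) attention it alleviates the limited receptive field, improving the modeling of long-range dependencies.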
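To make the stripe partition concrete, here is a minimal single-head sketch in PyTorch. It assumes `q`, `k`, `v` are already projected feature maps of shape `(B, H, W, C)` with `H` and `W` divisible by the stripe width `sw`. The helper names `stripe_attention` and `cross_shaped_attention` are hypothetical, and the real model uses several heads per group plus projections and LePE, all omitted here. Half of the channels attend within horizontal stripes, the other half within vertical stripes, and the two results are concatenated.

```python
import torch
import torch.nn.functional as F


def stripe_attention(q, k, v, sw, horizontal=True):
    """Single-head attention restricted to stripes of `sw` rows (or columns).

    q, k, v: (B, H, W, C) tensors; the stripe direction must be divisible by sw.
    """
    if not horizontal:
        # A vertical stripe is a horizontal stripe of the transposed map.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    B, H, W, C = q.shape

    def to_stripes(t):
        # (B, H, W, C) -> (B * H/sw, sw * W, C): one batch entry per stripe.
        return t.reshape(B, H // sw, sw, W, C).reshape(B * (H // sw), sw * W, C)

    qs, ks, vs = (to_stripes(t) for t in (q, k, v))
    scores = qs @ ks.transpose(-2, -1) / C ** 0.5
    out = (F.softmax(scores, dim=-1) @ vs).reshape(B, H, W, C)
    return out if horizontal else out.transpose(1, 2)


def cross_shaped_attention(q, k, v, sw):
    """First half of the channels uses horizontal stripes, second half vertical."""
    h = q.shape[-1] // 2
    out_h = stripe_attention(q[..., :h], k[..., :h], v[..., :h], sw, horizontal=True)
    out_v = stripe_attention(q[..., h:], k[..., h:], v[..., h:], sw, horizontal=False)
    return torch.cat([out_h, out_v], dim=-1)


x = torch.randn(1, 8, 8, 16)
out = cross_shaped_attention(x, x, x, sw=2)
print(out.shape)  # torch.Size([1, 8, 8, 16])
```

Each token only sees the `sw·W` tokens of its horizontal stripe and the `sw·H` tokens of its vertical stripe, so a token lying outside both stripes has no influence on that position's output.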

  ![Cross-shaped window self-attention](resources/cswin_attention_window.png "Cross-shaped window self-attention")

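### 2. **Locally Enhanced Positional Encoding (LePE)**
- Purpose: strengthen the modeling of positional information within local neighborhoods, compensating for local details that the cross-shaped window may miss.
- Implementation: a depthwise convolution is applied to the value tensor V in parallel with the attention computation, and its output is added to the attention result, injecting local positional information into every block.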
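In the paper LePE is written as `Attention(Q, K, V) = SoftMax(QKᵀ/√d)·V + E·V`, where `E` is realized as a depthwise convolution on `V`. The sketch below assumes channels-last `(B, H, W, C)` tensors and a 3×3 kernel; the class name `LePE` and its interface are illustrative, not the official code. For simplicity the convolution runs over the whole feature map, whereas an actual implementation may apply it inside each stripe.

```python
import torch
import torch.nn as nn


class LePE(nn.Module):
    """Adds a depthwise-conv positional term E(V) to the attention output."""

    def __init__(self, dim, kernel_size=3):
        super().__init__()
        # groups=dim makes the convolution depthwise: one k x k filter per channel.
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size, padding=kernel_size // 2, groups=dim
        )

    def forward(self, attn_out, v):
        # attn_out, v: (B, H, W, C); Conv2d expects (B, C, H, W).
        pos = self.dwconv(v.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return attn_out + pos


lepe = LePE(dim=16)
v = torch.randn(2, 8, 8, 16)
attn_out = torch.randn(2, 8, 8, 16)
print(lepe(attn_out, v).shape)  # torch.Size([2, 8, 8, 16])
```

With a 3×3 kernel each position only receives positional information from its immediate neighbors, which is what "locally enhanced" refers to, and the extra cost is just `9·C + C` parameters per layer.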


  ![Locally Enhanced Positional Encoding (LePE)](resources/cswin_lepe.png "Locally Enhanced Positional Encoding (LePE)")

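### 3. **Multi-Scale Stripe Widths**
- Theoretical support: the paper mathematically analyzes how the stripe width affects modeling capability and assigns a different stripe width to each stage (narrow stripes in the high-resolution early stages, wider ones in the low-resolution later stages), further balancing computational cost and performance.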
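The trade-off controlled by the stripe width can be made explicit with a rough multiply-add count. The four C×C projections (Q, K, V, output) cost 4HWC². Each token attends to the sw·W tokens of its horizontal stripe and the sw·H tokens of its vertical stripe, with C/2 channels each, and both QKᵀ and the product with V scale this way, so one cross-shaped attention layer costs about HWC·(4C + sw·H + sw·W), versus 4HWC² + 2(HW)²C for global attention. The functions below simply evaluate these two formulas; they ignore softmax, normalization and LePE, and the 56×56×64 example (a 224×224 input after a stride-4 stem) is an illustrative assumption.

```python
def cswin_attention_macs(H, W, C, sw):
    """Approximate multiply-adds of one cross-shaped window attention layer."""
    return H * W * C * (4 * C + sw * H + sw * W)


def global_attention_macs(H, W, C):
    """Approximate multiply-adds of full global self-attention on an H x W map."""
    return 4 * H * W * C**2 + 2 * (H * W) ** 2 * C


H = W = 56
C = 64
for sw in (1, 2, 7):
    print(f"sw={sw}: {cswin_attention_macs(H, W, C, sw):,}")
print(f"global: {global_attention_macs(H, W, C):,}")
```

For sw = 1 the layer needs roughly 74M multiply-adds, about 18 times fewer than the roughly 1.31G of global attention; the cost grows only linearly with sw, which is why wider stripes are affordable in the later, low-resolution stages.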

---

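## Impact and Significance

### 1. **Balancing Efficiency and Performance**
- CSWin Transformer achieves efficient global-local interaction in vision tasks, reaching state-of-the-art results at the time of publication on ImageNet classification, COCO detection and other tasks while keeping the computational overhead low.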

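### 2. **Advancing Transformer Architecture Design**
- It offers new ideas for follow-up research; for example, CSWin modules have been introduced into detection frameworks such as YOLOv8 to strengthen the backbone's multi-scale feature extraction.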

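### 3. **Open Source and Practical Value**
- Published at CVPR 2022, its code and methods have been widely reproduced and applied to areas such as face restoration and medical image analysis, demonstrating its generality.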

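> **Summary**: with its cross-shaped window design and positional encoding strategy, CSWin Transformer resolves the tension in vision Transformers between the high cost of global attention and the limited receptive field of local attention, and provides an important reference for designing efficient vision models.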

```python
# Automatically reload external modules so that code changes take effect without re-importing
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

from hdd.device.utils import get_device

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Path to the training data
DATA_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for TensorBoard logs
TENSORBOARD_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for pretrained model weights
TORCH_HUB_PATH = "~/workspace/hands-dirty-on-dl/pretrained_models"
torch.hub.set_dir(TORCH_HUB_PATH)
# Pick the most suitable training device
DEVICE = get_device(["cuda", "cpu"])
max_epochs = 150
print("Use device: ", DEVICE)
```

```text
Use device:  cuda
```


```python
from hdd.data_util.auto_augmentation import CIFAR10Policy

# Training hyperparameters and data augmentation follow https://github.com/omihub777/ViT-CIFAR
CIFAR_10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_10_STD = [0.2470, 0.2435, 0.2616]
BATCH_SIZE = 128

val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

val_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=False, download=True, transform=val_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

train_transform = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),  # CIFAR-10 AutoAugment policy
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
    ]
)

train_dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root=DATA_ROOT, train=True, download=True, transform=train_transform
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
```

```text
Files already downloaded and verified
Files already downloaded and verified
```
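`CIFAR_10_MEAN` and `CIFAR_10_STD` are per-channel statistics of the CIFAR-10 training images scaled to [0, 1]. A minimal sketch of how such constants can be computed follows; `channel_mean_std` is a hypothetical helper, and it assumes the whole image set fits in memory as a single `(N, C, H, W)` float tensor.

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and population std of a float tensor shaped (N, C, H, W)."""
    # Put channels first and flatten the rest: (C, N * H * W).
    flat = images.transpose(0, 1).reshape(images.shape[1], -1)
    return flat.mean(dim=1), flat.std(dim=1, unbiased=False)


# Tiny synthetic example: channel 0 is constant, channel 1 is half zeros, half ones.
imgs = torch.zeros(2, 3, 2, 2)
imgs[:, 0] = 0.5
imgs[0, 1] = 1.0
mean, std = channel_mean_std(imgs)
print(mean, std)
```

With torchvision, the raw training images are exposed as a `uint8` NumPy array of shape `(50000, 32, 32, 3)` through the dataset's `data` attribute, so `torch.from_numpy(ds.data).permute(0, 3, 1, 2).float() / 255` is a suitable input; the result should be close to the constants used above.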


```python
from hdd.train.warmup_scheduler import GradualWarmupScheduler
from hdd.models.transformer.cswin_transformer import CSWinTransformer
from hdd.train.classification_utils import naive_train_classification_model
from hdd.models.nn_utils import count_trainable_parameter


net = CSWinTransformer(
    img_size=32,
    in_chans=3,
    num_classes=10,
    embed_dim=48,
    depth=[1, 2, 8, 1],  # number of CSWin blocks in each of the 4 stages
    split_size=[1, 2, 2, 2],  # stripe width sw of each stage
    num_heads=[2, 4, 8, 8],  # attention heads of each stage
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    norm_layer=nn.LayerNorm,
    use_chk=False,
).to(DEVICE)
print(f"#Parameter: {count_trainable_parameter(net)}")
# Cross-entropy loss with label smoothing
criteria = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(
    net.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-5
)

# 10 warmup epochs, then cosine annealing down to eta_min
base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, max_epochs, eta_min=1e-5
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=1.0,
    total_epoch=10,
    after_scheduler=base_scheduler,
)
random_setting = naive_train_classification_model(
    net,
    criteria,
    max_epochs,
    train_dataloader,
    val_dataloader,
    DEVICE,
    optimizer,
    scheduler,
    verbose=True,
)
```
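The `split_size` list sets the stripe width of each stage, and each width has to tile that stage's feature map. The sketch below assumes the layout of the original CSWin design (a stride-4 convolutional stem, 2× downsampling between stages, and global attention in the last stage); `hdd`'s `CSWinTransformer` may differ, and both helpers are hypothetical.

```python
def stage_resolutions(img_size, num_stages=4, stem_stride=4):
    """Side length of each stage's feature map (stride-4 stem, then 2x downsampling)."""
    res = img_size // stem_stride
    sizes = []
    for _ in range(num_stages):
        sizes.append(res)
        res //= 2
    return sizes


def check_split_sizes(img_size, split_size):
    """(resolution, stripe width, ok) per stage; the last stage is assumed global."""
    sizes = stage_resolutions(img_size, len(split_size))
    report = []
    for i, (res, sw) in enumerate(zip(sizes, split_size)):
        is_last = i == len(split_size) - 1
        report.append((res, sw, is_last or res % sw == 0))
    return report


print(check_split_sizes(32, [1, 2, 2, 2]))
# [(8, 1, True), (4, 2, True), (2, 2, True), (1, 2, True)]
```

Under these assumptions a 32×32 input gives stage maps of 8, 4, 2 and 1 pixels per side, so with `sw = 2` the third stage's 2×2 map fits in a single stripe and the cross-shaped attention there already covers the whole map.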

#Parameter: 6488218
Epoch: 1/150 Train Loss: 2.3023 Accuracy: 0.1209 Time: 16.95441  | Val Loss: 2.2906 Accuracy: 0.1328
Epoch: 2/150 Train Loss: 2.0738 Accuracy: 0.2553 Time: 17.46283  | Val Loss: 1.7979 Accuracy: 0.3928
Epoch: 3/150 Train Loss: 1.8749 Accuracy: 0.3539 Time: 16.72593  | Val Loss: 1.6264 Accuracy: 0.4791
Epoch: 4/150 Train Loss: 1.7130 Accuracy: 0.4390 Time: 16.38400  | Val Loss: 1.4757 Accuracy: 0.5578
Epoch: 5/150 Train Loss: 1.5824 Accuracy: 0.5038 Time: 16.71232  | Val Loss: 1.3418 Accuracy: 0.6171
Epoch: 6/150 Train Loss: 1.5204 Accuracy: 0.5358 Time: 16.77067  | Val Loss: 1.3044 Accuracy: 0.6341
Epoch: 7/150 Train Loss: 1.4693 Accuracy: 0.5607 Time: 16.44648  | Val Loss: 1.2820 Accuracy: 0.6494
Epoch: 8/150 Train Loss: 1.4364 Accuracy: 0.5757 Time: 16.24806  | Val Loss: 1.2818 Accuracy: 0.6394
Epoch: 9/150 Train Loss: 1.4046 Accuracy: 0.5925 Time: 16.83452  | Val Loss: 1.2033 Accuracy: 0.6844
Epoch: 10/150 Train Loss: 1.3715 Accuracy: 0.6058 Time: 16.21724  | Val