## **Swin Transformer Pytorch**

- 깃허브 : [주소](https://github.com/berniwal/swin-transformer-pytorch)

In [2]:
# Install the base packages
# einops : tensor rearrangement utilities
# timm : collection of vision models
!pip install einops timm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->timm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->timm)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->tim

In [3]:
# Clone the GitHub repository, then make the cloned directory the working directory
# for every following cell (%cd changes the kernel's cwd).
!git clone https://github.com/berniwal/swin-transformer-pytorch.git
%cd swin-transformer-pytorch

Cloning into 'swin-transformer-pytorch'...
remote: Enumerating objects: 80, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (55/55), done.[K
remote: Total 80 (delta 39), reused 62 (delta 21), pack-reused 0 (from 0)[K
Receiving objects: 100% (80/80), 194.66 KiB | 4.75 MiB/s, done.
Resolving deltas: 100% (39/39), done.
/content/swin-transformer-pytorch


In [None]:
# swin_transformer.py
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat


## Cyclic Shift : shifted 윈도우 구현을 위한 cyclic roll - 패치를 반칸씩 옮길 때 사용
class CyclicShift(nn.Module):
    """Cyclically roll a feature map by the same offset along dims 1 and 2.

    Used to implement shifted windows: the map is rolled before windowed
    attention and rolled back afterwards.

    Args:
        displacement: number of positions to roll along each of the two
            dimensions; a negative value rolls in the opposite direction.
    """

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        # Elements pushed past one edge re-enter at the opposite edge.
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))



# Residual / PreNorm / FeedForward: Transformer 기본 블록 구성 요소
## Residual - skip connection
class Residual(nn.Module):
    """Skip connection wrapper: returns ``fn(x, **kwargs) + x``.

    Args:
        fn: callable applied to the input; its output is added to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        transformed = self.fn(x, **kwargs)
        # Add the untouched input back onto the transformed branch.
        return transformed + x

## PreNorm - LayerNorm → 함수 순서로 stability 확보
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call ``fn`` on the result.

    Args:
        dim: size of the last dimension that LayerNorm normalizes over.
        fn: callable that receives the normalized tensor plus any keyword
            arguments given to ``forward``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

## FeedForward - MLP (dim → hidden_dim → dim)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim).

    Args:
        dim: input and output feature size.
        hidden_dim: feature size between the two linear layers.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Layers are created in application order and stored at indices 0, 1, 2 of `net`.
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        return self.net(x)


# Mask 생성 및 상대 위치 인덱스 계산 유틸리티
## create_mask 함수 - shifted window 시 경계 간 attention 차단(윈도우 내에서만 계산)
def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive attention mask for one window of a shifted layout.

    The result is a ``(window_size**2, window_size**2)`` tensor holding 0 where
    attention is allowed and ``-inf`` where it is blocked.

    Args:
        window_size: side length of the square window.
        displacement: size of the shifted strip, in rows or columns of the window.
        upper_lower: if true, block attention between the last ``displacement``
            rows of the window and the remaining rows.
        left_right: if true, block attention between the last ``displacement``
            columns of the window and the remaining columns.

    Returns:
        The mask tensor.
    """
    tokens = window_size ** 2
    blocked = float('-inf')
    mask = torch.zeros(tokens, tokens)

    if upper_lower:
        # Tokens are flattened row-major, so the last `displacement` rows are
        # the trailing `displacement * window_size` token indices.
        split = displacement * window_size
        mask[-split:, :-split] = blocked
        mask[:-split, -split:] = blocked

    if left_right:
        # Expose the (row, col) structure of query and key tokens so the
        # column strip can be addressed directly; writes go through the view.
        grid = mask.view(window_size, window_size, window_size, window_size)
        grid[:, -displacement:, :, :-displacement] = blocked
        grid[:, :-displacement, :, -displacement:] = blocked
        mask = grid.view(tokens, tokens)

    return mask

## get_relative_distances 함수 - 상대적 좌표 차 계산 (pos bias 용)
def get_relative_distances(window_size):
    """Return pairwise (row, col) offsets between all positions of a window.

    Positions are enumerated row-major. Entry ``[i, j]`` of the result is
    ``position[j] - position[i]``, so the tensor has shape
    ``(window_size**2, window_size**2, 2)``.

    Args:
        window_size: side length of the square window.
    """
    coords = []
    for row in range(window_size):
        for col in range(window_size):
            coords.append([row, col])
    grid = torch.tensor(np.array(coords))

    # Broadcast the coordinate list against itself to get every pairwise difference.
    targets = grid[None, :, :]
    sources = grid[:, None, :]
    return targets - sources




## WindowAttention - local window self-attention + shifted window 구현
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The input is a ``(B, H, W, C)`` feature map that is split into
    ``window_size x window_size`` windows; attention is computed independently
    inside each window. When ``shifted`` is true the map is cyclically rolled by
    ``window_size // 2`` before attention and rolled back afterwards, and masks
    stop tokens that were wrapped around by the roll from attending to each other.

    Args:
        dim: channel size of the input and output.
        heads: number of attention heads.
        head_dim: channel size per head.
        shifted: whether to use the shifted-window variant.
        window_size: side length of each window; H and W of the input are
            expected to be multiples of it (the reshape in ``forward`` requires it).
        relative_pos_embedding: if true, use a learnable
            ``(2*window_size-1, 2*window_size-1)`` relative position bias;
            otherwise a learnable ``(window_size**2, window_size**2)`` bias.
    """

    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads

        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        # Only the shifted variant needs the cyclic roll and the boundary masks.
        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # Stored as frozen parameters so they follow the module across devices.
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        # Everything below is used by forward() for regular AND shifted blocks,
        # so it must be created outside the `if self.shifted` branch.

        # Joint Q, K, V projection.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # Position bias added to the attention logits.
        if self.relative_pos_embedding:
            # Offsets lie in [-(M-1), M-1]; adding M-1 turns them into indices
            # into the (2M-1, 2M-1) learnable table.
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            # One learnable bias per (query, key) pair inside a window: (M^2, M^2).
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Apply windowed attention to ``x`` of shape (B, H, W, C); output has the same layout."""
        if self.shifted:
            x = self.cyclic_shift(x)  # roll the map so windows straddle the previous window borders

        b, n_h, n_w, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # split the projection into Q, K, V

        # Number of windows along each spatial axis.
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        # (B, H, W, heads*d) -> (B, heads, num_windows, tokens_per_window, d)
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        # Scaled dot-product logits between all token pairs of each window.
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        # Add the position bias.
        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        # After the roll, the last row / last column of windows contain tokens
        # wrapped around from the opposite edge; mask attention between the two groups.
        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask          # last row of windows
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask   # last column of windows

        attn = dots.softmax(dim=-1)

        # Weighted sum of values, then fold windows and heads back into (B, H, W, heads*d).
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)  # undo the roll
        return out

In [None]:
## SwinBlock - 하나의 Transformer block
#    - PreNorm → WindowAttention (regular or shifted) → Residual
#    - PreNorm → FeedForward → Residual
class SwinBlock(nn.Module):
    """One Swin Transformer block: windowed attention followed by an MLP.

    Each sub-layer is wrapped as ``Residual(PreNorm(dim, sublayer))``, i.e.
    LayerNorm -> sub-layer -> add the input back.

    Args:
        dim: channel size of the tokens.
        heads: number of attention heads.
        head_dim: channel size per head.
        mlp_dim: hidden size of the feed-forward network.
        shifted: whether the attention uses shifted windows.
        window_size: side length of the attention windows.
        relative_pos_embedding: forwarded to ``WindowAttention``.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, attention))

        feed_forward = FeedForward(dim, mlp_dim)
        self.mlp_block = Residual(PreNorm(dim, feed_forward))

    def forward(self, x):
        attended = self.attention_block(x)
        return self.mlp_block(attended)

## PatchMerging - CNN의 stride=2와 유사한 다운샘플링
#    - 2×2 패치를 한 토큰으로 병합 → 채널 수 4배 → Linear projection
#    - 해상도는 1/2, 토큰 수는 1/4로 줄어듦.
class PatchMerging(nn.Module):
    """Downsample a feature map by merging each f x f patch into one token.

    ``nn.Unfold`` with ``kernel_size == stride == downscaling_factor`` gathers
    every non-overlapping patch into a vector of ``in_channels * f**2`` values,
    which a linear layer then projects to ``out_channels``.

    Args:
        in_channels: channel count of the input map.
        out_channels: channel count of each output token.
        downscaling_factor: patch side length ``f``; height and width shrink by this factor.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to (B, H // f, W // f, out_channels)."""
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        # Unfold exactly once: (B, C, H, W) -> (B, C*f^2, new_h*new_w). Unfolding the
        # already-unfolded tensor a second time would scramble the patch contents.
        # Then restore the spatial grid and move channels last for the linear layer.
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        x = self.linear(x)  # project C*f^2 -> out_channels
        return x

## StageModule - Swin의 한 스테이지(계층)
#    - PatchMerging → (Regular SwinBlock → Shifted SwinBlock) × (layers/2) 반복
class StageModule(nn.Module):
    """One Swin stage: patch merging followed by (regular, shifted) block pairs.

    Args:
        in_channels: channel count of the incoming (B, C, H, W) map.
        hidden_dimension: token channel size inside this stage.
        layers: total number of SwinBlocks; must be even because blocks are
            always created as a regular + shifted pair.
        downscaling_factor: spatial reduction applied by the patch merging.
        num_heads: attention heads per block.
        head_dim: channel size per head.
        window_size: side length of the attention windows.
        relative_pos_embedding: forwarded to every block.
    """

    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted blocks!'

        # Downsampling at the stage entrance.
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        def build_block(shifted):
            # The two blocks of a pair differ only in the `shifted` flag.
            return SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim,
                             mlp_dim=hidden_dimension * 4, shifted=shifted, window_size=window_size,
                             relative_pos_embedding=relative_pos_embedding)

        block_pairs = [
            nn.ModuleList([build_block(False), build_block(True)])
            for _ in range(layers // 2)
        ]
        self.layers = nn.ModuleList(block_pairs)

    def forward(self, x):
        tokens = self.patch_partition(x)
        for pair in self.layers:
            regular_block, shifted_block = pair
            tokens = shifted_block(regular_block(tokens))
        # Back to channels-first: (B, H', W', C) -> (B, C, H', W').
        return tokens.permute(0, 3, 1, 2)

## SwinTransformer - 전체 모델 정의
#    - 4개의 StageModule → global average pooling → MLP classification head
class SwinTransformer(nn.Module):
    """Full classifier: four StageModules, global average pooling, linear head.

    Channel widths per stage are ``hidden_dim``, ``2*hidden_dim``,
    ``4*hidden_dim`` and ``8*hidden_dim``.

    Args (keyword-only):
        hidden_dim: channel size of the first stage.
        layers: per-stage block counts (indexed 0..3).
        heads: per-stage attention head counts (indexed 0..3).
        channels: channel count of the input image.
        num_classes: size of the output logits.
        head_dim: channel size per attention head.
        window_size: side length of the attention windows.
        downscaling_factors: per-stage spatial reduction factors (indexed 0..3).
        relative_pos_embedding: forwarded to every stage; defaults to True.
    """

    def __init__(self, *, hidden_dim, layers, heads, channels, num_classes, head_dim, window_size,
                 downscaling_factors, relative_pos_embedding=True):
        super().__init__()

        # widths[i] feeds stage i+1, widths[i + 1] is what that stage produces.
        widths = (channels, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        for idx in range(4):
            stage = StageModule(in_channels=widths[idx], hidden_dimension=widths[idx + 1], layers=layers[idx],
                                downscaling_factor=downscaling_factors[idx], num_heads=heads[idx],
                                head_dim=head_dim, window_size=window_size,
                                relative_pos_embedding=relative_pos_embedding)
            # Registered as stage1 .. stage4.
            setattr(self, f'stage{idx + 1}', stage)

        # Classification head: LayerNorm -> Linear.
        final_dim = widths[-1]
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(final_dim),
            nn.Linear(final_dim, num_classes),
        )

    def forward(self, img):
        features = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            features = stage(features)
        pooled = features.mean(dim=[2, 3])  # global average pooling over H and W
        return self.mlp_head(pooled)


## swin_t, swin_s, swin_b, swin_l - Tiny/Small/Base/Large 모델 생성
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinTransformer with the Tiny defaults; remaining settings come from ``kwargs``."""
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)


def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a SwinTransformer with the Small defaults; remaining settings come from ``kwargs``."""
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)


def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    """Build a SwinTransformer with the Base defaults; remaining settings come from ``kwargs``."""
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)


def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    """Build a SwinTransformer with the Large defaults; remaining settings come from ``kwargs``."""
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)

- **swin_transformer.py**

1. `CyclicShift`

    - `torch.roll`을 이용해 feature map을 cyclic하게 shift

    - Shifted Window Attention 구현 시 효율적인 윈도우 경계 처리용

2. `create_mask, get_relative_distances`
    - Shifted Window에서 윈도우 경계 간 정보 누수 방지용 마스킹 생성

    - 상대 위치 바이어스를 위한 거리 인덱스 생성

3. `WindowAttention`
    - 윈도우 단위로 self-attention 계산 (Shifted 여부 포함)

    - `relative_pos_embedding` 옵션으로 상대 위치 바이어스 적용 가능

    - 연산량: 윈도우 크기가 고정이므로 self-attention 연산량이 토큰(패치) 수에 대해 선형으로 증가함 (global attention은 토큰 수의 제곱에 비례)

4. `SwinBlock`
    - Swin의 기본 블록:

    (1) LayerNorm → Window Attention → residual

    (2) LayerNorm → FeedForward MLP → residual

    - shifted 플래그로 Shifted Window 여부를 제어

5. `PatchMerging`
    - CNN의 stride=2와 유사하게 feature map 다운샘플링

    - 2×2 패치를 병합 → 채널 수 4배 → Linear로 원하는 차원으로 축소

    - 해상도 ↓, 채널 수 ↑

6. `StageModule`
    - 하나의 Swin Stage를 구성 (ex: Stage1~4)

    - 내부에 Patch Merging → N개의 SwinBlock 쌍 (regular, shifted)

    예: layers=6이면 → SwinBlock 6개 = 3쌍 (regular → shifted) 반복

7. `SwinTransformer`
    - 전체 모델 구조를 통합

    - 입력: 이미지

    - 구성: Stage1~4 + Global Average Pooling + Linear Head (Classification)

    - `mean(dim=[2, 3])`으로 GAP 수행

8. `swin_t, swin_s, swin_b, swin_l` 함수
    - Swin의 Tiny, Small, Base, Large 모델 생성기

    - 각각 hidden_dim, layers, heads 값만 다름
    (논문에 정의된 구조 그대로 반영)


In [4]:
# Install the `swin-transformer-pytorch` package with pip; the next cell imports
# `SwinTransformer` from it.
! pip install swin-transformer-pytorch

Collecting swin-transformer-pytorch
  Downloading swin_transformer_pytorch-0.4.1-py3-none-any.whl.metadata (5.2 kB)
Downloading swin_transformer_pytorch-0.4.1-py3-none-any.whl (11 kB)
Installing collected packages: swin-transformer-pytorch
Successfully installed swin-transformer-pytorch-0.4.1


In [5]:
# Load the model (example.py)
# NOTE(review): an earlier cell changed the working directory into the cloned repo.
# If that repo contains a `swin_transformer_pytorch` directory, this import may
# resolve to the local clone instead of the pip-installed package — confirm.

import torch
from swin_transformer_pytorch import SwinTransformer

# Same hidden_dim / layers / heads as the swin_t defaults defined in this notebook.
net = SwinTransformer(
    hidden_dim=96,
    layers=(2, 2, 6, 2),
    heads=(3, 6, 12, 24),
    channels=3,
    num_classes=10,         # number of classes set for CIFAR-10
    head_dim=32,
    window_size=7,
    downscaling_factors=(4, 2, 2, 2),
    relative_pos_embedding=True
)

# Smoke test: one random 224x224 RGB image through the untrained network.
dummy_x = torch.randn(1, 3, 224, 224)
logits = net(dummy_x)
print(net)
print(logits)

SwinTransformer(
  (stage1): StageModule(
    (patch_partition): PatchMerging(
      (patch_merge): Unfold(kernel_size=4, dilation=1, padding=0, stride=4)
      (linear): Linear(in_features=48, out_features=96, bias=True)
    )
    (layers): ModuleList(
      (0): ModuleList(
        (0): SwinBlock(
          (attention_block): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): WindowAttention(
                (to_qkv): Linear(in_features=96, out_features=288, bias=False)
                (to_out): Linear(in_features=96, out_features=96, bias=True)
              )
            )
          )
          (mlp_block): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=96, out_features=384, bias=True)
                  (1): GELU(approximate='none

#### **Parameters**
- `hidden_dim` : int -> 논문에서 C로 언급된 아키텍처에 사용하려는 hidden dimension

- `layers` : 4-tuple of ints(짝수) -> 각 단계에서 적용할 레이어 수. 항상 일반 swin block과 shifted swin block을 함께 적용하기 때문에 모든 int는 2로 나눌 수 있어야 함.

- `heads` : 4-tuple of ints -> 각 단계에서 적용할 헤드 수.

- `channels` : int -> 입력 채널 수.

- `num_classes` : int -> 분류할 클래스의 수 (출력 logits의 차원).

- `head_dim` : int -> 각 헤드가 가져야 할 차원.

- `window_size` : int -> 어떤 윈도우 크기를 사용할 것인지. 각 다운스케일링 후에도 feature map 크기가 여전히 윈도우 크기로 나누어떨어지는지 확인.

- `downscaling_factors` : 4-tuple of ints -> 각 단계에서 사용할 다운스케일링 요소. 이미지 크기가 다운스케일링 요소보다 충분히 큰지 확인.

- `relative_pos_embedding` : bool -> 학습 가능한 상대 위치 임베딩 $(2M-1) \times (2M-1)$ 을 사용할지, 절대 위치 임베딩 $M^2 \times M^2$ 을 사용할지 여부.


#### 모델 학습 및 테스트

In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Load the CIFAR10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # upscale to the 224x224 input size used for the model above
    transforms.ToTensor(),
    # NOTE(review): single-element mean/std; presumably broadcast over all 3 RGB channels — confirm.
    transforms.Normalize((0.5,), (0.5,))
])

# The same transform is applied to both the training and the test split.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [00:02<00:00, 62.9MB/s]


In [8]:
# Move the Swin Transformer to the training device (10 output classes).
# Note: no new model is built here — `model` is the `net` instance created in the
# example cell, so that cell must have been run first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = net.to(device)

# Loss function & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch
    for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch.
    print(f"[Epoch {epoch+1}] Loss: {running_loss/len(trainloader):.4f}")

Epoch 1/5: 100%|██████████| 782/782 [10:32<00:00,  1.24it/s]


[Epoch 1] Loss: 1.7773


Epoch 2/5: 100%|██████████| 782/782 [10:46<00:00,  1.21it/s]


[Epoch 2] Loss: 1.3389


Epoch 3/5: 100%|██████████| 782/782 [10:40<00:00,  1.22it/s]


[Epoch 3] Loss: 1.1288


Epoch 4/5: 100%|██████████| 782/782 [10:44<00:00,  1.21it/s]


[Epoch 4] Loss: 0.9777


Epoch 5/5: 100%|██████████| 782/782 [10:39<00:00,  1.22it/s]

[Epoch 5] Loss: 0.8552





In [9]:
# Measure accuracy on the test set
model.eval()
correct = 0  # number of correctly classified samples
total = 0    # number of samples seen
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)  # index of the largest logit = predicted class
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 67.03%
