# Shifted Windows Transformer

* https://amaarora.github.io/posts/2022-07-04-swintransformerv1.html
* https://www.zhihu.com/question/521494294/answer/3178312617

In [1]:
from transformers import AutoImageProcessor, SwinForImageClassification
import torch
from torch import nn
from datasets import load_dataset

In [2]:
# trust_remote_code=True lets the dataset repository's loading script run
# locally, so only enable it for sources you trust.
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
# Select the row first and the column second: indexing the "image" column
# first would materialise every image in the split just to keep one of them.
image = dataset["test"][0]["image"]

# One checkpoint id shared by the processor and the model so they cannot drift apart
checkpoint = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = SwinForImageClassification.from_pretrained(checkpoint)

inputs = image_processor(image, return_tensors="pt")

# Inference only: no gradients are needed
with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

Using the latest cached version of the module from /home/yangyansheng/.cache/huggingface/modules/datasets_modules/datasets/huggingface--cats-image/68fbc793fb10cd165e490867f5d61fa366086ea40c73e549a020103dcb4f597e (last modified on Mon Oct 14 11:03:54 2024) since it couldn't be found locally at huggingface/cats-image, or remotely on the Hugging Face Hub.


tabby, tabby cat


In [7]:
# Build the (row, column) coordinate of every token inside one attention
# window; the pair-wise relative position index is derived from these.
win_h = 5
win_w = 5
row_ids = torch.arange(win_h)
col_ids = torch.arange(win_w)
grid_rows, grid_cols = torch.meshgrid(row_ids, col_ids, indexing="ij")
coords = torch.stack((grid_rows, grid_cols))  # 2, Wh, Ww
coords_flatten = coords.flatten(start_dim=1)  # 2, Wh*Ww

In [8]:
# Vertical (row) coordinate of every patch in the grid
coords[0]

tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4]])

In [9]:
# Horizontal (column) coordinate of every patch in the grid
coords[1]

tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])

In [12]:
# Broadcast a column view against a row view so that every token is paired
# with every other token.
relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)  # 2, Wh*Ww, Wh*Ww

# relative_coords[0][i][j] is the y (row) offset between the i-th and the j-th patch of the flattened window
# relative_coords[1][i][j] is the x (column) offset between the i-th and the j-th patch of the flattened window
relative_coords[0][0]

tensor([ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3, -3,
        -3, -3, -4, -4, -4, -4, -4])

In [13]:
# Move the (y, x) pair to the last axis: offsets[i][j] holds the y and x
# offsets between the i-th and the j-th patch.
offsets = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
row_offset = offsets[..., 0] + (win_h - 1)  # shift to start from 0
col_offset = offsets[..., 1] + (win_w - 1)

# index = y * (2 * win_w - 1) + x
# 2 * win_w - 1 is the number of distinct shifted x offsets (0 .. 2 * win_w - 2),
# so using it as the row stride gives every (y, x) pair its own index.
# relative_coords[i][j] is the relative position index between the i-th and the j-th patch
relative_coords = row_offset * (2 * win_w - 1) + col_offset  # Wh*Ww, Wh*Ww

In [21]:
# Inspect the relative position index table (one row per patch i, one column per patch j)
relative_coords

tensor([[40, 39, 38, 37, 36, 31, 30, 29, 28, 27, 22, 21, 20, 19, 18, 13, 12, 11,
         10,  9,  4,  3,  2,  1,  0],
        [41, 40, 39, 38, 37, 32, 31, 30, 29, 28, 23, 22, 21, 20, 19, 14, 13, 12,
         11, 10,  5,  4,  3,  2,  1],
        [42, 41, 40, 39, 38, 33, 32, 31, 30, 29, 24, 23, 22, 21, 20, 15, 14, 13,
         12, 11,  6,  5,  4,  3,  2],
        [43, 42, 41, 40, 39, 34, 33, 32, 31, 30, 25, 24, 23, 22, 21, 16, 15, 14,
         13, 12,  7,  6,  5,  4,  3],
        [44, 43, 42, 41, 40, 35, 34, 33, 32, 31, 26, 25, 24, 23, 22, 17, 16, 15,
         14, 13,  8,  7,  6,  5,  4],
        [49, 48, 47, 46, 45, 40, 39, 38, 37, 36, 31, 30, 29, 28, 27, 22, 21, 20,
         19, 18, 13, 12, 11, 10,  9],
        [50, 49, 48, 47, 46, 41, 40, 39, 38, 37, 32, 31, 30, 29, 28, 23, 22, 21,
         20, 19, 14, 13, 12, 11, 10],
        [51, 50, 49, 48, 47, 42, 41, 40, 39, 38, 33, 32, 31, 30, 29, 24, 23, 22,
         21, 20, 15, 14, 13, 12, 11],
        [52, 51, 50, 49, 48, 43, 42, 41, 40, 39,

In [None]:
# Offsets taken directly between flattened 1-D token positions, shifted by
# win_h * win_w - 1 so that the smallest value is 0.
x = torch.arange(win_h * win_w)

real_coor = x.unsqueeze(1) - x.unsqueeze(0) + (win_h * win_w - 1)
real_coor

tensor([[24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,
          6,  5,  4,  3,  2,  1,  0],
        [25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,
          7,  6,  5,  4,  3,  2,  1],
        [26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,
          8,  7,  6,  5,  4,  3,  2],
        [27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,
          9,  8,  7,  6,  5,  4,  3],
        [28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11,
         10,  9,  8,  7,  6,  5,  4],
        [29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
         11, 10,  9,  8,  7,  6,  5],
        [30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13,
         12, 11, 10,  9,  8,  7,  6],
        [31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14,
         13, 12, 11, 10,  9,  8,  7],
        [32, 31, 30, 29, 28, 27, 26, 25, 24, 23,

In [6]:
# Shape of the preprocessed image tensor that is fed to the model
inputs.pixel_values.shape

torch.Size([1, 3, 224, 224])

In [10]:
patch_size = 4  # kernel size and stride of the patch projection convolution
window_size = 7  # side length, in patches, of each attention window
embed_dim = 96  # output channels of the patch projection

## SwinPatchEmbedding

In [25]:
# Patch embedding: a convolution whose kernel size and stride both equal
# patch_size maps every non-overlapping patch_size x patch_size patch to one
# embed_dim-dimensional vector. The layer is constructed fresh here, so its
# weights are randomly initialised rather than loaded from the checkpoint.
_, num_channels, height, width = inputs.pixel_values.shape
patch_projection = nn.Conv2d(
    num_channels, embed_dim, kernel_size=patch_size, stride=patch_size
)

# [1, 3, 224, 224] -> [1, 96, 56, 56]
embeddings = patch_projection(inputs.pixel_values)
print(embeddings.shape)

torch.Size([1, 96, 56, 56])


In [26]:
# Height and width of the patch grid, read from the (batch, channels, height, width) embeddings
_, _, height, width = embeddings.shape

In [27]:
# Collapse the two spatial axes into one token axis, then move channels last:
# (batch, channels, height * width) -> (batch, height * width, channels)
flat_embeddings = torch.flatten(embeddings, start_dim=2)
print(flat_embeddings.shape)
embeddings = flat_embeddings.transpose(1, 2)
print(embeddings.shape)

torch.Size([1, 96, 3136])
torch.Size([1, 3136, 96])


## Window SelfAttention

In [36]:
# Reshape the token sequence back into the 2-D patch grid, channels last
hidden_states = embeddings.view(-1, height, width, embed_dim)
hidden_states.shape

torch.Size([1, 56, 56, 96])

In [39]:
# Split each spatial axis into (window count, position inside the window)
windows_down = height // window_size
windows_across = width // window_size
hidden_states_windows = hidden_states.view(
    -1, windows_down, window_size, windows_across, window_size, embed_dim
)
print(hidden_states_windows.shape)
# Put the two window-count axes next to each other, ahead of the two in-window axes
hidden_states_windows = hidden_states_windows.permute(0, 1, 3, 2, 4, 5)
print(hidden_states_windows.shape)

# Fold the window-count axes into the batch axis and turn every 7x7 window
# into a single sequence-length axis
tokens_per_window = window_size * window_size
hidden_states_windows = hidden_states_windows.contiguous().view(
    -1, tokens_per_window, embed_dim
)
print(hidden_states_windows.shape)

torch.Size([1, 8, 7, 8, 7, 96])
torch.Size([1, 8, 8, 7, 7, 96])
torch.Size([64, 49, 96])


## Shifted Window SelfAttention

In [40]:
# Start again from the un-partitioned patch grid, channels last
hidden_states = embeddings.view(-1, height, width, embed_dim)
hidden_states.shape

torch.Size([1, 56, 56, 96])

In [42]:
# Small 5x5 grid of consecutive integers used to illustrate torch.roll
t = torch.arange(25).reshape(5, 5)
t

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])

In [46]:
# Cyclically shift t by 2 along both axes; values pushed past an edge wrap around to the opposite side
torch.roll(t, shifts=(2, 2), dims=(0, 1))

tensor([[18, 19, 15, 16, 17],
        [23, 24, 20, 21, 22],
        [ 3,  4,  0,  1,  2],
        [ 8,  9,  5,  6,  7],
        [13, 14, 10, 11, 12]])

## Patch Merging