# Swin Transformer (Shifted Window Transformer)

Reference walkthrough: <https://amaarora.github.io/posts/2022-07-04-swintransformerv1.html>

In [1]:
from transformers import AutoImageProcessor, SwinForImageClassification
import torch
from torch import nn
from datasets import load_dataset

In [3]:
# Single source of truth for the checkpoint so the processor and the model
# can never drift apart.
MODEL_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"

# NOTE: trust_remote_code=True executes the dataset's loading script from the
# Hub; only keep it for repositories you trust.
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
# Index the row first, then the column: this fetches one image instead of
# materialising the whole "image" column just to take its first element.
image = dataset["test"][0]["image"]

image_processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(MODEL_CHECKPOINT)

# The processor returns a batch whose `pixel_values` entry is a PyTorch tensor.
inputs = image_processor(image, return_tensors="pt")

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

Using the latest cached version of the module from /home/yangyansheng/.cache/huggingface/modules/datasets_modules/datasets/huggingface--cats-image/68fbc793fb10cd165e490867f5d61fa366086ea40c73e549a020103dcb4f597e (last modified on Mon Oct 14 11:03:54 2024) since it couldn't be found locally at huggingface/cats-image, or remotely on the Hugging Face Hub.


tabby, tabby cat


In [7]:
# Shape of the preprocessed batch: (batch, channels, height, width).
inputs.pixel_values.shape

torch.Size([1, 3, 224, 224])

In [10]:
# Hyperparameters used throughout the walkthrough. The patch and window sizes
# mirror the "patch4-window7" part of the checkpoint name loaded earlier;
# embed_dim is the channel width produced by the patch projection.
patch_size, window_size, embed_dim = 4, 7, 96

# SwinPatchEmbedding

In [25]:
# Unpack the input image geometry; only num_channels is needed below.
_, num_channels, height, width = inputs.pixel_values.shape

# Patch embedding as a strided convolution: kernel_size == stride == patch_size,
# so every patch_size x patch_size patch is projected independently to
# embed_dim channels. The layer is freshly constructed here (randomly
# initialised), so only the output shape is meaningful, not the values.
patch_projection = nn.Conv2d(
    num_channels, embed_dim, kernel_size=patch_size, stride=patch_size
)

# [1, 3, 224, 224] -> [1, 96, 56, 56]
embeddings = patch_projection(inputs.pixel_values)
print(embeddings.shape)

torch.Size([1, 96, 56, 56])


In [26]:
# Spatial size of the patch grid (the last two axes of the 4-D embeddings).
# Unpacking into exactly two names still rejects any shape that is not 4-D.
height, width = embeddings.shape[2:]

In [27]:
# Collapse the 2-D patch grid into a single sequence axis:
# [1, 96, 56, 56] -> [1, 96, 3136]
embeddings = torch.flatten(embeddings, start_dim=2)
print(embeddings.shape)
# Move channels last so each patch becomes one token vector:
# [1, 96, 3136] -> [1, 3136, 96]
embeddings = torch.transpose(embeddings, 1, 2)
print(embeddings.shape)

torch.Size([1, 96, 3136])
torch.Size([1, 3136, 96])


# Window Self-Attention

In [36]:
# Restore the 2-D patch grid from the token sequence:
# [1, 3136, 96] -> [1, 56, 56, 96]
hidden_states = embeddings.view(-1, height, width, embed_dim)
hidden_states.shape

torch.Size([1, 56, 56, 96])

In [39]:
# Split each spatial axis into (window index, offset inside the window):
# [1, 56, 56, 96] -> [1, 8, 7, 8, 7, 96]
hidden_states_windows = hidden_states.view(
    -1,
    height // window_size,
    window_size,
    width // window_size,
    window_size,
    embed_dim,
)
print(hidden_states_windows.shape)

# Swap the two middle axes so both window indices precede both in-window
# offsets: [1, 8, 7, 8, 7, 96] -> [1, 8, 8, 7, 7, 96]
hidden_states_windows = hidden_states_windows.transpose(2, 3)
print(hidden_states_windows.shape)

# Fold the window count into the batch dimension and flatten every 7x7 window
# into a single sequence-length dimension. The transposed tensor is no longer
# contiguous, so it must be copied before it can be viewed with a new shape.
hidden_states_windows = hidden_states_windows.contiguous().view(
    -1, window_size**2, embed_dim
)
print(hidden_states_windows.shape)

torch.Size([1, 8, 7, 8, 7, 96])
torch.Size([1, 8, 8, 7, 7, 96])
torch.Size([64, 49, 96])


# Shifted-Window Self-Attention

In [40]:
# Start again from the un-windowed patch grid for the shifted-window demo:
# [1, 3136, 96] -> [1, 56, 56, 96]
hidden_states = embeddings.view(-1, height, width, embed_dim)
hidden_states.shape

torch.Size([1, 56, 56, 96])

In [42]:
# Small 5x5 grid of consecutive integers, used to make the effect of a
# cyclic shift easy to read off.
t = torch.arange(25).reshape(5, 5)
t

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])

In [46]:
# Cyclic shift by 2 along both axes: rows move down and columns move right,
# wrapping around the edges (element 0 ends up at position [2, 2]).
t.roll(shifts=(2, 2), dims=(0, 1))

tensor([[18, 19, 15, 16, 17],
        [23, 24, 20, 21, 22],
        [ 3,  4,  0,  1,  2],
        [ 8,  9,  5,  6,  7],
        [13, 14, 10, 11, 12]])

# Patch Merging