In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional

### Drop path (Stochastic Depth)

In [2]:
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample (index along dim 0) is either zeroed entirely, with probability
    ``drop_prob``, or kept and rescaled by ``1 / (1 - drop_prob)``.

    This is the same as the DropConnect impl created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... the layer and
    argument names were changed to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor; dim 0 is treated as the sample dimension.
        drop_prob (float): probability of dropping a sample's path, expected in [0, 1].
        training (bool): when False the input is returned unchanged.
    Returns:
        ``x`` itself when ``drop_prob == 0`` or ``training`` is False, otherwise a new
        tensor of the same shape with whole samples zeroed / rescaled.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNets).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 -> keep the sample, 0 -> drop it
    if keep_prob > 0.:
        # Rescale the kept samples. Skipped when keep_prob == 0, because dividing
        # by zero would give inf and inf * 0 would turn the output into NaN.
        x = x.div(keep_prob)
    output = x * random_tensor
    return output

In [9]:
# Demo: with drop_prob=0.3 each of the 5 samples is either zeroed entirely or
# divided by keep_prob = 0.7. No seed is set, so the result differs between runs.
x = torch.rand((5,2,2))
drop_path_f(x, drop_prob=0.3, training=True)

tensor([[[0.4225, 0.6266],
         [0.0390, 0.0111]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.5573, 1.3916],
         [0.5250, 0.1318]],

        [[1.1811, 1.2200],
         [0.6533, 1.2629]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]]])

### Window Partition

In [13]:
def window_partition(x, window_size: int):
    """
    Split a feature map into non-overlapping windows of side ``window_size``,
    for later use by W-MSA / SW-MSA.

    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # view: [B, H, W, C] -> [B, H//M, M, W//M, M, C]
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # permute: [B, H//M, M, W//M, M, C] -> [B, H//M, W//M, M, M, C]
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    # view: [B, H//M, W//M, M, M, C] -> [B*num_windows, M, M, C]
    return grid.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Reassemble a batch of windows into a single feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size (M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Number of windows that tile one H x W feature map.
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows, cols = H // window_size, W // window_size
    # view: [B*num_windows, M, M, C] -> [B, H//M, W//M, M, M, C]
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # permute: [B, H//M, W//M, M, M, C] -> [B, H//M, M, W//M, M, C]
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    # view: [B, H//M, M, W//M, M, C] -> [B, H, W, C]
    return grid.view(batch, H, W, -1)

In [19]:
x = torch.rand(1, 56, 56, 48)
# expected number of windows: batch_size * (56 / 7) * (56 / 7) = 1 * 8 * 8 = 64
y = window_partition(x, window_size=7)
y.shape

torch.Size([64, 7, 7, 48])

In [20]:
# Reassemble the 7x7 windows into a (B, H, W, C) feature map with H = W = 56.
z = window_reverse(y, 7, 56, 56)
z.shape

torch.Size([1, 56, 56, 48])

In [24]:
# Round-trip check: count the elements where z differs from x (0 means x and z are the same).
torch.sum(z != x)

tensor(0)

### Patch Embedding + Linear Projection

In [63]:
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding.

    Splits an image into non-overlapping ``patch_size`` x ``patch_size`` patches with a
    strided convolution and projects every patch to ``embed_dim`` channels.

    Args:
        patch_size (int): side length of a (square) patch.
        in_c (int): number of input image channels.
        embed_dim (int): number of output channels per patch.
        norm_layer (callable, optional): called as ``norm_layer(embed_dim)`` to build the
            normalization applied to the tokens; ``nn.Identity`` is used when omitted.
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) image batch.
        Returns:
            (tokens, H', W'): tokens has shape (B, H'*W', embed_dim); H' and W' are the
            height and width of the patch grid after the projection.
        """
        _, _, H, W = x.shape
        # padding
        # If H or W is not an integer multiple of patch_size, pad on the bottom/right
        # up to the next multiple. The trailing modulo keeps the pad at 0 for an axis
        # that is already aligned (instead of adding one whole extra patch of zeros).
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
        if pad_h or pad_w:
            # F.pad pads the last 3 dimensions, starting from the last one:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0))
        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

In [64]:
# Embed a batch of two 224x224 RGB images with the default PatchEmbed (patch_size=4, embed_dim=96).
img = torch.rand(2, 3, 224, 224)
patch_emb = PatchEmbed()
x1 = patch_emb(img)[0]  # forward returns (tokens, H, W); keep only the tokens
# 3136 = (224/4) * (224/4)
x1.shape

torch.Size([2, 3136, 96])

### Patch Merging

In [65]:
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates every 2x2 neighbourhood of tokens along the channel axis (C -> 4*C),
    normalizes the result and linearly reduces it to 2*C channels.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        merged_dim = 4 * dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        feat = x.view(B, H, W, C)

        # If H or W is odd, pad one row / column of zeros so both become even.
        if H % 2 == 1 or W % 2 == 1:
            # F.pad walks the dimensions from last to first; the layout here is
            # [B, H, W, C], so the tuple reads
            # (C_front, C_back, W_left, W_right, H_top, H_bottom).
            feat = F.pad(feat, (0, 0, 0, W % 2, 0, H % 2))

        # (row, col) offsets inside each 2x2 neighbourhood, in concatenation order.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        parts = [feat[:, r::2, c::2, :] for r, c in offsets]  # each [B, H/2, W/2, C]
        merged = torch.cat(parts, -1)  # [B, H/2, W/2, 4*C]
        merged = merged.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        return self.reduction(self.norm(merged))  # [B, H/2*W/2, 2*C]

In [84]:
patch_mer = PatchMerging(96)
# 784 = 28 * 28: merging 2x2 neighbours halves each side of the 56x56 grid,
# and the linear reduction maps 4 * 96 channels to 2 * 96 = 192.
patch_mer(x1, 56, 56).shape

torch.Size([2, 784, 192])

### Relative Position Bias

In [99]:
# Assume 2 * 2 window size
window_size = (2, 2)
coords_h = torch.arange(window_size[0])  # row indices inside one window
coords_w = torch.arange(window_size[1])  # column indices inside one window

In [100]:
# coords[0] holds the row index and coords[1] the column index of every window position -> [2, Mh, Mw]
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
coords

tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]])

In [101]:
coords_flatten = torch.flatten(coords, 1)
coords_flatten  # [2, Mh*Mw]  absolute position indices

tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])

In [102]:
coords_flatten[0]  # row index i of each window position

tensor([0, 0, 1, 1])

In [103]:
coords_flatten[1]  # column index j of each window position

tensor([0, 1, 0, 1])

In [105]:
# Pairwise offsets between all window positions via broadcasting:
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw] -> [2, Mh*Mw, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords

tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])

In [106]:
# [2, Mh*Mw, Mh*Mw] -> [Mh*Mw, Mh*Mw, 2]: the last axis now holds (row offset, column offset).
# NOTE: this rebinds relative_coords, so re-running this cell on its own permutes the axes again.
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [-1,  0],
         [-1, -1]],

        [[ 0,  1],
         [ 0,  0],
         [-1,  1],
         [-1,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]]])

In [108]:
# Turn each (row offset, column offset) pair into a single index into the bias table.
# Work on a copy: updating relative_coords in place would make this cell
# non-idempotent (re-running it would shift and scale the offsets a second time).
shifted_coords = relative_coords.clone()
shifted_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
shifted_coords[:, :, 1] += window_size[1] - 1
shifted_coords[:, :, 0] *= 2 * window_size[1] - 1  # row offset * number of distinct column offsets
relative_position_index = shifted_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
relative_position_index

tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])

In [110]:
num_heads = 3
# One learnable bias per (relative offset, head):
# (2*Mh-1) * (2*Mw-1) rows, num_heads columns.
relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
relative_position_bias_table.shape

torch.Size([9, 3])

In [116]:
# Look up one bias row per (query position, key position) pair -> [Mh*Mw * Mh*Mw, num_heads]
relative_position_bias_table[relative_position_index.view(-1)].shape

torch.Size([16, 3])

In [117]:
# Reshape the gathered biases to [Mh*Mw, Mh*Mw, num_heads]
relative_position_bias_table[relative_position_index.view(-1)].view(
            window_size[0] * window_size[1], window_size[0] * window_size[1], -1).shape

torch.Size([4, 4, 3])

In [None]:
def create_mask_simplified(self, H, W):
    """Build the attention mask used for shifted-window attention (simplified).

    The H x W map is split into 3 x 3 regions along the window / shift boundaries and
    every region receives its own id. After partitioning into windows, two positions
    of the same window may attend to each other only if their region ids match.

    Args:
        self: object providing ``window_size`` and, optionally, ``shift_size``.
            Missing attributes (or ``self=None``) fall back to window_size=3 and
            shift_size=1.
        H (int): map height, expected to be an integer multiple of window_size.
        W (int): map width, expected to be an integer multiple of window_size.
    Returns:
        attn_mask: [nW, Mh*Mw, Mh*Mw] with 0.0 where attention is allowed and
        -100.0 where it is masked out.
    """
    # Resolve the sizes once so the region slices and the window partition always
    # agree (previously the slices used fixed locals while the partition used
    # self.window_size).
    window_size = getattr(self, "window_size", 3)
    shift_size = getattr(self, "shift_size", 1)
    img_mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    # Label each of the 3 x 3 regions with a distinct id 0..8.
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    print(img_mask)  # shows the region-id map for illustration

    mask_windows = window_partition(img_mask, window_size)  # [nW, Mh, Mw, 1]
    mask_windows = mask_windows.view(-1, window_size * window_size)  # [nW, Mh*Mw]
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
    # [nW, Mh*Mw, Mh*Mw]: a non-zero difference means the two positions lie in different regions
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask

In [136]:
# Worked example of the shifted-window attention mask on a 9x9 map with
# 3x3 windows and a shift of 1.
window_size = 3
shift_size = 1
H = 9
W = 9
img_mask = torch.zeros((1, H, W, 1))  # region-id map, layout [1, H, W, 1]
# Three bands per axis: everything before the last window, the last window
# minus the shifted strip, and the shifted strip itself.
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
# Give each of the 3 x 3 regions its own id 0..8.
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size)  # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, window_size * window_size)  # [nW, Mh*Mw]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
# 0.0 where both positions share a region id, -100.0 where they differ
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask

tensor([[[   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.]],

        [[   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.