# Swin Transformer
自己动手实现一个 *Swin Transformer*，最终目的是把它改造成 U-Net 架构。我希望对它做如下改进：
1. 去除相对位置编码：在做 qkv 映射时使用卷积来引入卷积的归纳偏置，看看这样是否可以去掉相对位置编码
2. 去除滑动窗口：因为使用了 U-Net 架构以及卷积，滑动窗口显得不是很必要

> 暂时就想到了这些，之后想到什么再加上去吧，也算锻炼一下自己的复现能力

# 1. 预处理
最基本的部分，包含 embedding，patch 划分，patch 还原，窗口划分，窗口复原

## 1.1 embedding
对图像进行映射，准备转化为 Token 以便交给 Transformer 作处理
可以选择使用线性映射或卷积；映射完成后会自动转换成 Token 的 2d 表示，便于之后的窗口划分
1. embedding 并划分 patch

> 后来想了一下，既然后面还要下采样，不如直接把每一个像素当成一个 Token，256*256 也不是太大的分辨率，直接上吧

In [22]:
import torch
import torch.nn as nn
from einops import rearrange, repeat


class patchEmbedding(nn.Module):
    """
        Embed an image into a 2d grid of per-pixel tokens (patch size 1).

        Every pixel becomes one token, so the spatial resolution is preserved.

        Args:
            in_channels: number of input image channels.
            emb_dim: embedding (token) dimension.
            methods: how to embed:
                'conv'   -> 3x3 convolution (stride 1, padding 1), then channels-last;
                'linear' -> channels-last, then a per-pixel nn.Linear.
        Input:
            x: (B, in_channels, H, W)
        Returns:
            (B, H, W, emb_dim)
        Raises:
            ValueError: if `methods` is neither 'conv' nor 'linear' (previously such a
                value silently produced a module whose forward returned its input).
    """

    def __init__(self, in_channels, emb_dim, methods):
        super().__init__()
        if methods not in ('conv', 'linear'):
            raise ValueError(f"methods must be 'conv' or 'linear', got {methods!r}")
        self.methods = methods
        if methods == 'conv':
            self.proj = nn.Conv2d(in_channels, emb_dim, 3, 1, 1)
        else:
            self.proj = nn.Linear(in_channels, emb_dim)

    def forward(self, x):
        # With a patch size of 1, "patchifying" is just moving channels last
        # (equivalent to einops 'b c (h 1) (w 1) -> b h w c').
        if self.methods == 'conv':
            # Conv2d consumes channels-first input, so permute after projecting.
            return self.proj(x).permute(0, 2, 3, 1)
        # nn.Linear acts on the last dim, so permute before projecting.
        return self.proj(x.permute(0, 2, 3, 1))


In [23]:
# test: conv embedding keeps the 256x256 resolution and yields one 16-dim token per pixel
pe = patchEmbedding(3, 16, 'conv')
test = torch.randn(1, 3, 256, 256)
emb = pe(test)
assert emb.shape == (1, 256, 256, 16)
emb.shape


torch.Size([1, 256, 256, 16])

## 1.2 窗口划分
将生成的 2d patch 进行窗口划分，以便于使用窗口自注意力
1. 划分窗口

In [10]:
def windowPartition(x, window_size):
    """
        Split a channels-last token grid into non-overlapping square windows.

        Args:
            x: (B, H, W, C) token grid; H and W must be multiples of window_size.
            window_size: side length of each window, in tokens.
        Returns:
            (B*num_windows, window_size, window_size, C). Windows of one image are
            contiguous and ordered row-major; images follow batch order.
        Raises:
            ValueError: if window_size <= 0 or H / W is not divisible by window_size.
    """
    B, H, W, C = x.shape
    if window_size <= 0 or H % window_size or W % window_size:
        raise ValueError(
            f"grid of size {H}x{W} cannot be split into {window_size}x{window_size} windows")
    rows, cols = H // window_size, W // window_size
    # Equivalent to einops 'b (h wh) (w ww) c -> (b h w) wh ww c'.
    x = x.reshape(B, rows, window_size, cols, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)
    # Explicit leading size (not -1) so an empty batch still reshapes cleanly.
    return x.reshape(B * rows * cols, window_size, window_size, C)


2. 窗口还原

In [15]:
def windowReverse(x, window_size, image_width=None):
    """
        Reassemble windows (as produced by windowPartition) into a (B, H, W, C) grid.

        Args:
            x: (B*num_windows, Wh, Ww, C) windows, ordered as windowPartition emits them.
            window_size: NOTE - despite its name, this is the full image HEIGHT in tokens
                (e.g. windowReverse(windows, 256) for a 256x256 grid), not the window side.
            image_width: full image width in tokens; defaults to `window_size`
                (square image), which matches the previous behavior.
        Returns:
            (B, H, W, C)
        Raises:
            ValueError: if the image size is not a multiple of the window size, or the
                number of windows does not form a whole batch.
    """
    num_windows, win_h, win_w, channels = x.shape
    height = window_size
    width = height if image_width is None else image_width
    if win_h == 0 or win_w == 0 or height % win_h or width % win_w:
        raise ValueError(
            f"image of size {height}x{width} is not a multiple of the {win_h}x{win_w} window")
    rows, cols = height // win_h, width // win_w
    per_image = rows * cols
    if per_image == 0 or num_windows % per_image:
        raise ValueError(
            f"{num_windows} windows cannot be regrouped into images of {per_image} windows")
    batch = num_windows // per_image
    # Equivalent to einops '(b h w) wh ww c -> b (h wh) (w ww) c'.
    x = x.reshape(batch, rows, cols, win_h, win_w, channels)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(batch, height, width, channels)


In [17]:
# test: windowReverse must exactly invert windowPartition (256x256 grid, 8x8 windows)
test = torch.randn(1, 256, 256, 16)
test_win = windowPartition(test, 8)
test_rev = windowReverse(test_win, 256)
assert torch.equal(test, test_rev), "windowReverse did not invert windowPartition"
test == test_rev


tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...

## 2. 模型架构

这里要开始编写模型的架构细节了，从每一个小结构来构成整个模型

## 2.1 窗口自注意力

这里要编写窗口自注意力方法。Swin Transformer 的窗口自注意力中包含 mask 以及相对位置编码；这里打算不使用滑动窗口（因此不需要 mask），但会保留是否使用相对位置编码的选项

In [42]:
class windowAttention(nn.Module):
    """
        Multi-head self-attention computed independently inside each window.

        The qkv projection and the output projection are 3x3 convolutions over the
        window's 2d token grid (not linear layers), so each window is reshaped to an
        image-like layout around them. No shifted-window mask is applied.

        Args:
            window_size: side length of the square window, in tokens (N = window_size**2).
            dim: token dimension C; must be divisible by num_heads.
            num_heads: number of attention heads.
            pe_flag: if truthy, add a learnable relative position bias (zero-initialized)
                to the attention logits.
        Input:
            x: (B*num_windows, N, C)
        Returns:
            (B*num_windows, N, C)
        Raises:
            ValueError: in __init__ if dim % num_heads != 0; in forward if
                N != window_size**2.
    """

    def __init__(self, window_size, dim, num_heads, pe_flag):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.to_qkv = nn.Conv2d(dim, dim * 3, 3, 1, 1)
        self.proj = nn.Conv2d(dim, dim, 3, 1, 1)
        self.num_heads = num_heads
        self.pe_Flag = pe_flag
        self.window_size = window_size
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        if pe_flag:
            # Relative offsets per axis lie in [-(ws-1), ws-1]: 2*ws-1 distinct values.
            num_offsets = 2 * window_size - 1
            # One learnable bias per (2d relative offset, head).
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(num_offsets * num_offsets, num_heads))

            # For every (query token, key token) pair inside a window, compute the
            # flat row index into the bias table.
            coords = torch.stack(torch.meshgrid(
                torch.arange(window_size), torch.arange(window_size), indexing='ij'))  # 2, ws, ws
            coords_flatten = torch.flatten(coords, 1)  # 2, N
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 2
            relative_coords += window_size - 1  # shift both axes to start from 0
            relative_coords[:, :, 0] *= num_offsets  # row-major flattening of (dy, dx)
            self.register_buffer("relative_position_index", relative_coords.sum(-1))  # N, N
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, num_tokens, channels = x.shape
        ws = self.window_size
        if num_tokens != ws * ws:
            raise ValueError(
                f"expected {ws * ws} tokens per window (window_size={ws}), got {num_tokens}")
        heads = self.num_heads
        head_dim = channels // heads

        # Tokens -> 2d grid, as einops 'b (wh ww) c -> b c wh ww'.
        grid = x.transpose(1, 2).reshape(batch, channels, ws, ws)
        qkv = self.to_qkv(grid)  # (B*num_windows, 3*C, ws, ws)
        # As einops 'b (num head head_dim) wh ww -> num b head (wh ww) head_dim'.
        qkv = qkv.reshape(batch, 3, heads, head_dim, num_tokens).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B*num_windows, heads, N, head_dim)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B*num_windows, heads, N, N)
        if self.pe_Flag:
            bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
            bias = bias.view(num_tokens, num_tokens, -1).permute(2, 0, 1)  # heads, N, N
            attn = attn + bias.unsqueeze(0)  # broadcast over windows
        attn = self.softmax(attn)

        out = attn @ v  # (B*num_windows, heads, N, head_dim)
        # As einops 'b head (wh ww) head_dim -> b (head head_dim) wh ww'.
        out = out.permute(0, 1, 3, 2).reshape(batch, heads * head_dim, ws, ws)
        out = self.proj(out)  # (B*num_windows, C, ws, ws)
        return out.flatten(2).transpose(1, 2)  # (B*num_windows, N, C)


In [47]:
# test: 4 windows of 2x2 tokens, dim 16 split into 4 heads (head_dim 4).
# The input layout must match the window (N == window_size**2); change both together to try other sizes.
testWA = windowAttention(2, 16, 4, True)
test = torch.randn(1 * 4, 2 * 2, 16)
out_wa = testWA(test)
assert out_wa.shape == (4, 4, 16)
out_wa.shape

torch.Size([4, 4, 16])

## 2.2 FeedForward
这里要处理 FeedForward 结构，有多种选择：一种是像传统 Transformer 那样使用全连接层，另一种是使用 CNN
但是使用 CNN 也会有问题，是要在窗口内使用 CNN 还是在整张图像上使用 CNN

> 其实我比较倾向于在窗口内使用 CNN：既然已经划分了窗口，随着网络加深、图像分辨率降低，不同窗口之间的信息同样可以进行交互
> 但如果想要实现全局自注意力，还需要让网络再深一层；否则就需要随着网络的加深同步增大窗口大小


In [49]:
class FeedForward(nn.Module):
    """
        Feed-forward block applied to window tokens.

        Args:
            method: 'conv' -> two 3x3 convolutions (padding 1) over each window's 2d
                    token grid; any other value (e.g. 'linear') -> two nn.Linear layers
                    applied per token.
            dim: token dimension C.
            mlp_ratio: hidden-width multiplier; hidden dim = int(dim * mlp_ratio), so
                float ratios such as 2.5 are accepted.
        Input:
            x: (B*num_windows, N, C); for 'conv', N must be a perfect square.
        Returns:
            (B*num_windows, N, C)
        Raises:
            ValueError: in forward ('conv' only) if N is not a perfect square.
    """

    def __init__(self, method, dim, mlp_ratio):
        super().__init__()
        self.method = method
        hidden = int(dim * mlp_ratio)
        if method == 'conv':
            self.layer1 = nn.Conv2d(dim, hidden, 3, 1, 1)
            self.layer2 = nn.Conv2d(hidden, dim, 3, 1, 1)
        else:
            self.layer1 = nn.Linear(dim, hidden)
            self.layer2 = nn.Linear(hidden, dim)
        self.act = nn.GELU()

    def forward(self, x):
        if self.method != 'conv':
            return self.layer2(self.act(self.layer1(x)))

        batch, num_tokens, channels = x.shape
        # round() instead of int() so float error in sqrt cannot truncate a perfect
        # square down; the product check rejects non-square windows explicitly.
        side = round(num_tokens ** 0.5)
        if side * side != num_tokens:
            raise ValueError(f"'conv' FeedForward needs a square window, got {num_tokens} tokens")
        # Tokens -> 2d grid, as einops 'b (wh ww) c -> b c wh ww'.
        grid = x.transpose(1, 2).reshape(batch, channels, side, side)
        grid = self.layer2(self.act(self.layer1(grid)))
        # Back to tokens, as einops 'b c wh ww -> b (wh ww) c'.
        return grid.flatten(2).transpose(1, 2)



In [52]:
# test: both FeedForward variants must preserve the (B*num_windows, N, C) token shape
test_fw = FeedForward('linear', 16, 4)
test_input = torch.randn(1 * 4, 8 * 8, 16)
out_fw = test_fw(test_input)
assert out_fw.shape == (4, 64, 16)
assert FeedForward('conv', 16, 4)(test_input).shape == (4, 64, 16)
out_fw.shape

torch.Size([4, 64, 16])