In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# <font color = 'red'>1、如何基于图片生成patch_embbeding</font>
## 方法一
- 基于pytorch的unfold API来讲图片进行分块,也就是模仿卷积的思路，设置kernel_size=stride=patch_size,得到分块后的图片
- 得到的图片格式为[bs,num_patch,patch_depth]
- 将上述张量与形状为[patch_depth,model_dim_C]的张量做线性映射，即可得到[bs,num_patch,model_dim_C]的patch_embbeding

In [7]:
def image2emb_naive(image,patch_size,weight):
    patch = F.unfold(image,kernel_size=(patch_size,patch_size),stride = (patch_size,patch_size)).transpose(-1,-2)
    patch_embbeding = patch @ weight
    return patch_embbeding


image = torch.randn(1,3,8,8)
weight = torch.randn(4*4*3,8)

print(image2emb_naive(image,4,weight).shape)


torch.Size([1, 4, 8])


## 方法二
- 卷积形式
- patch_depth相当于patch_size\*patch_size\*input_channel
- model_dim_C相当于二维卷积的输出channel



In [8]:
def image2emb_convolution(image,kernel,stride):
    conv_output = F.conv2d(image, kernel, stride=stride)
    bs,oc,oh,ow = conv_output.shape
    patch_embbeding = conv_output.reshape((bs,oc,oh*ow)).transpose(-1,-2)
    return patch_embbeding

# <font color = 'red'>2、构建MHSA并计算其复杂度</font>
- 1.基于输入x[bs,L,C]进行三个映射分别得到q、k、v三个矩阵
    - 则每个矩阵$q$、$W$、$v$ 的计算复杂度为$LC^2$,一共是$3LC^2$
    
- 2.计算attention时候的复杂度
    - 1.$q@k^T$的复杂度为$L^2C$
    - 2.继续乘以v的复杂度为$L^2C$
    - 3.最后做一个线性映射的复杂度为$L2^C$
    
- 3.虽有MHSA的时间复杂度为
    - $4LC^2+2L^2C$
   
- 4.可以看到传统的MHSA有着与L的平方的关系,具有很高的时间复杂度

In [23]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self,model_dim,num_head):
        super(MultiHeadSelfAttention,self).__init__()
        self.num_head = num_head

        self.proj_linear_layer = nn.Linear(model_dim, 3*model_dim)
        self.final_linear_layer = nn.Linear(model_dim, model_dim)

    def forward(self, input, additive_mask=None):
        bs,seq_len,model_dim = input.shape
        num_head = self.num_head
        head_dim = model_dim // num_head
        proj_output = self.proj_linear_layer(input)
        q, k, v = proj_output.chunk(3,dim=-1) # shape均为[bs,seq_len,model_dim_C]
        # q:[bs,seq_len,model_dim]
        q = q.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2).reshape(bs * num_head, seq_len, head_dim)
        k = k.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2).reshape(bs * num_head, seq_len, head_dim)
        v = v.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2).reshape(bs * num_head, seq_len, head_dim)
        if additive_mask is None:
            attention_prob = F.softmax(torch.bmm(q,k.transpose(-1,-2))/math.sqrt(head_dim),dim=-1)
        else:
            additive_mask = additive_mask.tile(num_head,1,1)
            attention_prob = F.softmax(torch.bmm(q,k.transpose(-1,-2))/math.sqrt(head_dim)+additive_mask,dim=-1)
        output = torch.bmm(attention_prob, v) # [bs*num_head,seq_len,head_dim]
        output = output.reshape(bs, num_head, seq_len, head_dim).transpose(1,2)
        output = output.reshape(bs, seq_len, model_dim)
        output = self.final_linear_layer(output)
        return attention_prob, output

# <font color = 'red'>3、构建window MHSA 并计算其复杂度</font>

- 将patch后的图片进一步分成一个个更大的window
    - 1.需要将3D的patch embedding转换为图片的格式
    - 2.使用unfold函数将patch划分为window
    
- 在每个window内部计算MHSA
    - window数目可以跟batch_size进行同一对待,因为在window之间没有交互计算
    - 关于计算WMHSA的时间复杂度
        - 假设窗的边长为W,那么计算每个窗的复杂度为$4W^2C^2+2W^4C$
        - 一共有窗的个数为$L/W^2$
        - 因此总的复杂度为二者相乘$4LC^2+2LW^2C$
    
    - 此处不需要mask
    - 将计算结果转换成带window的4D tenser
    
 
- 复杂度对比
    - `MHSA`: $4LC^2+2L^2C$
    - `W-MHSA`: $4LC^2+2LW^2C$

In [24]:
def windows_multi_head_self_attention(patch_embbeding,mhsa,window_size=4,num_head=2):
    num_patch_in_window = window_size*window_size
    bs, num_patch, patch_depth = patch_embbeding.shape
    image_height = image_width = int(math.sqrt(num_patch))

    patch_embbeding = patch_embbeding.transpose(-1, -2)
    patch = patch_embbeding.reshape(bs,patch_depth, image_height, image_width)
    window = F.unfold(patch,kernel_size=(window_size,window_size),stride=(window_size,window_size),).transpose(-1,-2)


    bs,num_windows,patch_depth_times_num_patch_in_window = window.shape
    window = window.reshape(bs*num_windows,patch_depth,num_patch_in_window).transpose(-1,-2)

    attn, output = mhsa(window)# [bs*num_window,num_patch_in_window,PATCH_depth]

    output = output.reshape(bs,num_windows,num_patch_in_window,patch_depth)
    return output

# <font color = 'red'>4、构建shifted Window MHSA 及其MASK</font>
- 将上一步的 W-MHSA转换为图片的形式即:[bs,num_windows,num_patch_in_window,patch_depth]转换为[bs,ic,image_h,image_w]
- 假设已经做了新的window划分,这一步叫做shift-window
- 为了保持window数目不变从而有高效的计算,需要将图片的patch往左和往上各自滑动半个窗口大小的步长,保持patch所属window类别不变
- 将图片patch还原成window的数据格式
- 由于shift_window之后,每个window虽然形状罪证,但是部分window存在不属于同一个窗口的patch,所以要生成mask
- 如何生成mask
    - 首先构建一个shift-window的patch所属的window类别矩阵
    - 对该矩阵进行同样的王座和网上各自滑动半个窗口大小的步长的操作
    - 通过unfold操作可得到[bs,num_window,num_patch_in_window]形状的类别矩阵
    - 对该矩阵进行扩维:[bs,num_window,num_patch_in_window,1]
    - 该矩阵与自身的转置作差,得到同类关系矩阵，（为0的位置patch的关系属于同类）
    - 对同类矩阵中的非0的位置用负无穷进行填充,对于零的位置用0去填充,这样就构建好了MHS所需要的Mask
    - 这个mask矩阵的形状为[bs,num_window,num_patch_in_window,num_patch_in_window]
- 将window转换为三维的形式:[bs * num_window,num_patch_in _window，patch_size]
- 将三维格式的特征连通mask一起送入MHSA中计算得到注意力输出
- 将注意力输出换转为图片patch形式:[bs,num_window,num_patch_in_window,patch_size]
- 为了恢复位置,需要将图片的patch往右和往左滑动半个window大小
- 至此,SW-MHSA计算完成



In [25]:
# 构建SW-MHSA附属函数1:将window形式转换为image形式:目的在于将image进行shift
def window2image(msa_output):
    bs, num_windows, num_patch_in_windows, patch_depth = msa_output.shape
    window_size = int(math.sqrt(num_patch_in_windows))
    image_height = image_width = int(math.sqrt(num_windows)) * window_size
    msa_output = msa_output.reshape(bs,int(math.sqrt(num_windows)),int(math.sqrt(num_windows)),window_size,window_size,patch_depth)
    msa_output = msa_output.transpose(2,3)
    image = msa_output.reshape(bs,image_height*image_width,patch_depth)
    image = image.transpose(-2,-1)
    image = image.reshape(bs,patch_depth,image_height,image_width)
    return image



In [26]:
# 构建SW-MHSA附属函数2:进行window-shift
def shift_window(w_msa_output,window_size,shift_size,generate_mask=False):
    bs, num_window, num_patch_in_window,patch_depth = w_msa_output.shape
    w_msa_output = window2image(w_msa_output)
    bs,patch_depth,image_height,image_weight = w_msa_output.shape
    rolled_w_msa_output = torch.roll(w_msa_output,shifts=(shift_size,shift_size),dims=(2,3))
    shifted_w_msa_output = rolled_w_msa_output.reshape(bs,patch_depth,int(math.sqrt(num_window)),window_size,int(math.sqrt(num_window)),window_size)
    shifted_w_msa_output = shifted_w_msa_output.transpose(3,4)
    shifted_w_msa_output = shifted_w_msa_output.reshape(bs,patch_depth,num_window*num_patch_in_window)
    shifted_w_msa_output = shifted_w_msa_output.transpose(-1,-2)
    shifted_window = shifted_w_msa_output.reshape(bs,num_window,num_patch_in_window,patch_depth)

    if generate_mask:
        additive_mask = build_mask_for_shifted_wmsa(bs,image_height,image_weight,window_size)
    else:
        additive_mask = None

    return shifted_window,additive_mask


In [27]:
# 构建SW-MHSA附属函数3:构建mask矩阵
def build_mask_for_shifted_wmsa(batch_size,image_height,image_width,window_size):
    index_matrix = torch.zeros(image_height,image_width)
    for i in range(image_height):
        for j in range(image_width):
            row_times = (i+window_size//2) // window_size
            col_times = (j+window_size//2) // window_size
            index_matrix[i,j] = row_times * (image_height // window_size) + col_times + 1

    roll_index_matrix = torch.roll(index_matrix,shifts=(-window_size//2,-window_size//2),dims=(0,1))
    roll_index_matrix = roll_index_matrix.unsqueeze(0).unsqueeze(0) #[bs,ch,h,w]

    c = F.unfold(roll_index_matrix,kernel_size=(window_size,window_size),stride=(window_size,window_size)).transpose(-1,-2)

    c = c.tile(batch_size,1,1)

    bs,num_window,num_patch_in_window = c.shape

    c1 = c.unsqueeze(-1)
    c2 = (c1-c1.transpose(-1,-2))== 0
    valid_matrix = c2.to(torch.float32)

    unlimit_min = -1e-9
    additive_mask = (1-valid_matrix) * unlimit_min

    additive_mask = additive_mask.reshape(bs * num_window,num_patch_in_window,num_patch_in_window)
    print(additive_mask.shape)
    return additive_mask

In [28]:
def shift_window_multi_head_self_attention(w_msa_output,mhsa,window_size=4,num_head=2):
    bs,num_window,num_patch_in_window,patch_depth = w_msa_output.shape
    shifted_w_msa_output ,additive_mask = shift_window(w_msa_output,window_size,
                                                       shift_size=-window_size//2,generate_mask=True)

    shifted_w_msa_output = shifted_w_msa_output.reshape(bs*num_window,num_patch_in_window,patch_depth)

    _, output = mhsa(shifted_w_msa_output,additive_mask)
    output = output.reshape(bs, num_window,num_patch_in_window,patch_depth)

    output,_ = shift_window(output,window_size,shift_size=window_size//2,generate_mask=False)
    return output

# <font color = 'red'>5、如何构建PatchMerging</font>
- 将window转换为patch
- 利用unfold操作,按照merge_size\*merge_size大小得到新的patch,[bs,num_patch_new,merge_size*merge_size*patch_depth_old]
- 增加全连接层对patch_depth进行映射,一般将维度成0.5倍,输出的patch_embbe的形状为[bs,num_patch,patch_depth]
- 举例说明:以merge_size=2为例,num_patch变为原来的1/4,patch_depth变为原来的2倍

In [29]:
class PatchMerging(nn.Module):
    def  __init__(self, model_dim, merge_size, output_depth_scale = 0.5):
        super(PatchMerging,self).__init__()
        self.merge_size = merge_size
        self.proj_layer = nn.Linear(
            model_dim*merge_size*merge_size,
            int(model_dim*merge_size*merge_size*output_depth_scale))

    def forward(self,input):
        bs,num_window,num_patch_in_window,patch_depth = input.shape
        #window_size = int(math.sqrt(num_patch_in_window))
        input = window2image(input)  # [bs,patch_depth,image_h,image_w]
        merged_window = F.unfold(input,kernel_size=(self.merge_size,self.merge_size),
                                 stride=(self.merge_size,self.merge_size)).transpose(-1,-2)

        merged_window = self.proj_layer(merged_window)

        return merged_window

# <font color = 'red'>6、构建SwinTransformerBlock</font>
- 1、每个block包含有LayerNorm,WMHSA、MLP、SMHSA
- 2、输入的是patch_embbeding[bs,num_patch,patch_depth]
- 3、其中每个MLP与所有的Transformer族一致，都是包含了两个Linear,第一个Layer将model_dim_C映射到4\*model_dimC,第二个layer将4\*model_dim_C映射会model_dim_C维度
- 4、输出的维度为[bs,num_patch,num_patch_in_window,patch_depth]
- 5、注意残差链接的时候输入与输出的维度需要保持一致


In [30]:
class SwinTransformerBlock(nn.Module):
    def __init__(self,model_dim,window_size,num_head):
        super(SwinTransformerBlock,self).__init__()
        self.layer_norm1 = nn.LayerNorm(model_dim)
        self.layer_norm2 = nn.LayerNorm(model_dim)
        self.layer_norm3 = nn.LayerNorm(model_dim)
        self.layer_norm4 = nn.LayerNorm(model_dim)

        self.wsma_mlp1 = nn.Linear(model_dim,4*model_dim)
        self.wsma_mlp2 = nn.Linear(4*model_dim,model_dim)
        
        self.swsma_mlp1 = nn.Linear(model_dim, 4 * model_dim)
        self.swsma_mlp2 = nn.Linear(4 * model_dim, model_dim)

        self.mhsa1 = MultiHeadSelfAttention(model_dim,num_head)
        self.mhsa2 = MultiHeadSelfAttention(model_dim,num_head)

    def forward(self, input):
        bs,num_patch,patch_depth = input.shape
        
        # WMHSA
        input1 = self.layer_norm1(input)
        w_msa_output = windows_multi_head_self_attention(input1,self.mhsa1,window_size=4,num_head=2)
        bs,num_window, num_patch_in_window, patch_depth = w_msa_output.shape
        w_msa_output = input + w_msa_output.reshape(bs,num_patch,patch_depth)
        output1 = self.wsma_mlp2(self.wsma_mlp1(self.layer_norm2(w_msa_output)))
        output1 += w_msa_output

        # SWMHSA
        input2 = self.layer_norm3(output1)
        input2 = input2.reshape(bs,num_window,num_patch_in_window,patch_depth)
        sw_msa_output = shift_window_multi_head_self_attention(input2, self.mhsa2,window_size=4,num_head=2)
        sw_msa_output = output1+sw_msa_output.reshape(bs,num_patch,patch_depth)
        output2 = self.swsma_mlp2(self.swsma_mlp1(self.layer_norm4(sw_msa_output)))
        output2 += sw_msa_output

        output2 = output2.reshape(bs,num_window,num_patch_in_window,patch_depth)
        return output2

# <font color = 'red'>7、构建SwinTransformerModel</font>
- 输入是图片[bs,ic,image_h,image_w]
- 首先对图片进行分块得到patch embedding
- 根据论文进入四个stage
- 对最后一个输出转换为patch embedding的形式[bs,num_patch,patch_depth]
- 对patch embbeding进行时间维度的平均池化操作,并映射得到分类的logits,分类完毕



In [31]:
class SwinTransformerModel(nn.Module):
    def __init__(self,input_image_channel = 3,patch_size = 4,model_dim_C = 8,num_classes = 10,window_size = 4,num_head=2,merge_size = 2):
        super(SwinTransformerModel,self).__init__()

        patch_depth = patch_size*patch_size*input_image_channel
        self.patch_size = patch_size
        self.model_dim_C = model_dim_C
        self.num_classes = num_classes

        self.patch_embbeding_weight = nn.Parameter(torch.randn(patch_depth,model_dim_C))

        self.block1 = SwinTransformerBlock(model_dim_C,window_size,num_head)
        self.block2 = SwinTransformerBlock(model_dim_C*2, window_size, num_head)
        self.block3 = SwinTransformerBlock(model_dim_C*4, window_size, num_head)
        self.block4 = SwinTransformerBlock(model_dim_C*8, window_size, num_head)

        self.patch_merging1 = PatchMerging(model_dim_C,merge_size)
        self.patch_merging2 = PatchMerging(model_dim_C*2, merge_size)
        self.patch_merging3 = PatchMerging(model_dim_C*4, merge_size)

        self.final_linear = nn.Linear(model_dim_C*8,num_classes)


    def forward(self,image):
        patch_embbeding_naive = image2emb_naive(image,self.patch_size,self.patch_embbeding_weight)

        # stage1
        patch_embbeding = patch_embbeding_naive
        print(patch_embbeding.shape)
        sw_msa_output = self.block1(patch_embbeding)
        print("stage1_output:",sw_msa_output.shape)

        # stage2
        merged_patch1 = self.patch_merging1(sw_msa_output)
        sw_msa_output_1 = self.block2(merged_patch1)
        print("stage2_output:", sw_msa_output_1.shape)

        # stage3
        merged_patch2 = self.patch_merging2(sw_msa_output_1)
        sw_msa_output_2 = self.block3(merged_patch2)
        print("stage3_output:", sw_msa_output_2.shape)

        # stage4
        merged_patch3 = self.patch_merging3(sw_msa_output_2)
        sw_msa_output_3 = self.block4(merged_patch3)
        print("stage4_output:", sw_msa_output_3.shape)

        bs,num_window,num_patch_in_window,patch_depth = sw_msa_output_3.shape
        sw_msa_output_3 = sw_msa_output_3.reshape(bs,-1,patch_depth)
        pool_output = torch.mean(sw_msa_output_3,dim=1)
        logits = self.final_linear(pool_output)

        print("logits",logits.shape)

# <font color = 'red'>8、测试主函数</font>

In [32]:
if __name__ == '__main__':
    bs,ic,image_h,image_w = 4,3,256,256
    patch_size = 4
    model_dim_C = 8
    num_classes = 10
    window_size = 4
    num_head = 2
    merge_size = 2

    patch_depth = patch_size*patch_size*ic
    image = torch.randn(bs,ic,image_h,image_w)
    model = SwinTransformerModel(ic,patch_size,model_dim_C,num_classes,window_size,num_head,merge_size)

    model(image)

torch.Size([4, 4096, 8])
torch.Size([1024, 16, 16])
stage1_output: torch.Size([4, 256, 16, 8])
torch.Size([256, 16, 16])
stage2_output: torch.Size([4, 64, 16, 16])
torch.Size([64, 16, 16])
stage3_output: torch.Size([4, 16, 16, 32])
torch.Size([16, 16, 16])
stage4_output: torch.Size([4, 4, 16, 64])
logits torch.Size([4, 10])
