## Implementing Swin Transformer by Hand

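### 1. How do we generate patch embeddings from an image?
Method 1
- Use PyTorch's unfold API to split the image into patches. This mirrors a convolution: set kernel_size = stride = patch_size so that the patches do not overlap
- The result is a tensor of shape [bs, num_patch, patch_depth]
- Multiply this tensor by a weight matrix of shape [patch_depth, model_dim_C] to get a patch embedding of shape [bs, num_patch, model_dim_C]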
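The sketch below (toy sizes chosen as assumptions) shows what `F.unfold` returns for method 1: after the transpose, each row is one patch flattened channel by channel, in (channel, row, col) order, and patches are enumerated row by row across the image. This is why patch_depth = input_channel * patch_size * patch_size.

```python
import torch
import torch.nn.functional as F

bs, channel, h, w = 1, 3, 8, 8
patch_size = 4
# Use arange so every pixel value is unique and easy to trace
image = torch.arange(bs * channel * h * w, dtype=torch.float32).reshape(bs, channel, h, w)

# F.unfold returns [bs, channel*patch_size*patch_size, num_patch]; transpose to put patches first
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
print(patches.shape)  # [bs, num_patch, patch_depth]

# Patch 0 is the top-left patch_size x patch_size block of every channel,
# flattened in (channel, row, col) order
top_left = image[0, :, :patch_size, :patch_size].reshape(-1)
print(torch.equal(patches[0, 0], top_left))

# Patches are enumerated row-major: patch 1 is the block to the right of patch 0
second = image[0, :, :patch_size, patch_size:2 * patch_size].reshape(-1)
print(torch.equal(patches[0, 1], second))
```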

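Method 2
- patch_depth equals input_channel * patch_size * patch_size
- model_dim_C plays the role of the number of output channels of a 2D convolution
- Convert the weight matrix of shape [patch_depth, model_dim_C] into a convolution kernel of shape [model_dim_C, input_channel, patch_size, patch_size]. Transpose it to [model_dim_C, patch_depth] first and then reshape; reshaping directly gives the right shape but scrambles the values
- Call PyTorch's conv2d API to get the convolution output, of shape [bs, output_channel, height, width]
- Flatten height and width into num_patch and swap the last two dimensions to get the patch embedding in [bs, num_patch, model_dim_C] format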
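Method 2 can also be written with the `nn.Conv2d` module. The sketch below (sizes are assumptions) copies a matmul weight into a bias-free convolution and checks that both paths agree. The weight is transposed before the reshape because `[patch_depth, model_dim_C]` stores one output channel per column, while a conv kernel stores one output channel per leading index.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, in_ch, img_size, patch_size, model_dim_C = 2, 3, 16, 4, 8
images = torch.randn(bs, in_ch, img_size, img_size)
weight = torch.randn(in_ch * patch_size * patch_size, model_dim_C)  # [patch_depth, model_dim_C]

# Method 1: unfold + matmul -> [bs, num_patch, model_dim_C]
emb_unfold = F.unfold(images, kernel_size=patch_size, stride=patch_size).transpose(-1, -2) @ weight

# Method 2: a bias-free Conv2d with kernel_size == stride == patch_size
conv = nn.Conv2d(in_ch, model_dim_C, kernel_size=patch_size, stride=patch_size, bias=False)
with torch.no_grad():
    # Transpose first: column j of `weight` becomes output channel j of the kernel
    conv.weight.copy_(weight.t().reshape(model_dim_C, in_ch, patch_size, patch_size))
    out = conv(images)                          # [bs, model_dim_C, H/patch_size, W/patch_size]
    emb_conv = out.flatten(2).transpose(1, 2)   # [bs, num_patch, model_dim_C]

print(emb_conv.shape, torch.allclose(emb_unfold, emb_conv, atol=1e-4))
```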




import torch
import torch.nn as nn
import torch.nn.functional as F

import math

#Key point 1: patch embedding
def image2emb_naive(image, patch_size, weight):
    """Patch embedding, straightforward version: unfold + matrix multiplication"""
    # F.unfold expects a 4-D input here, so image has shape [bs, channel, h, w]
    patch_image = F.unfold(image, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)).transpose(-1, -2)  # [bs, num_patch, patch_depth], patch_depth = channel*patch_size*patch_size
    patch_embedding = patch_image @ weight   # [bs, num_patch, model_dim_C]
    return patch_embedding

def image2emb_conv2(image, kernel, stride):
    """Patch embedding via 2D convolution; the embedding dimension equals the number of output channels"""
    out = F.conv2d(image, kernel, stride=stride)
    bs, output_channel, height, width = out.shape
    patch_embedding = out.reshape(bs, output_channel, height*width).transpose(-1, -2)
    return patch_embedding
    
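# Check that both methods give the same result. The kernel must come from weight transposed to
# [model_dim_C, patch_depth] before the reshape; reshaping weight directly gives the right shape but wrong values
patch_size = 4
model_dim_C = 100
images = torch.randn(2, 3, 16, 16)
weight = torch.randn(patch_size*patch_size*3, model_dim_C)
kernel = weight.transpose(0, 1).reshape(model_dim_C, 3, patch_size, patch_size)
emb_naive = image2emb_naive(images, patch_size, weight)
emb_conv = image2emb_conv2(images, kernel, stride=patch_size)
print(emb_naive.shape)
print(emb_conv.shape)
assert torch.allclose(emb_naive, emb_conv, atol=1e-4)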

torch.Size([2, 16, 100])
torch.Size([2, 16, 100])


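### 2. How do we build MHSA (multi-head self-attention), and what is its complexity?
- Map the input x through three linear projections to get q, k and v
  - Complexity of this step: $3LC^2$, where L is the sequence length and C is the feature dimension
- Split q, k and v into multiple heads. The heads are computed independently of each other, so the head dimension can be merged into the bs dimension and handled the same way
- Compute $qk^T$, taking an optional mask into account: the energy between every invalid pair of positions is set to negative infinity. The mask is needed in shifted-window MHSA but not in window MHSA
  - Complexity of this step: $L^2C$
- Multiply the probabilities (the softmax of the energies) by v
  - Complexity of this step: $L^2C$
- Apply the final linear output projection
  - Complexity of this step: $LC^2$
- Total complexity: $4LC^2+2L^2C$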
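The cost of each MHSA step can be tallied in a few lines. This sketch counts only the multiplications inside the matrix products (softmax, scaling and bias terms are ignored), with L and C picked arbitrarily; splitting into h heads of size C/h does not change the attention terms, since h heads each cost $L^2(C/h)$.

```python
def mhsa_step_costs(L, C):
    """Multiplication counts for each MHSA step, ignoring softmax, scaling and bias."""
    return {
        "qkv_projection": 3 * L * C * C,   # three [L, C] x [C, C] products
        "q_kT": L * L * C,                 # [L, C] x [C, L], summed over all heads
        "prob_v": L * L * C,               # [L, L] x [L, C], summed over all heads
        "output_projection": L * C * C,    # final [L, C] x [C, C] product
    }

L, C = 64, 32
costs = mhsa_step_costs(L, C)
total = sum(costs.values())
print(costs)
print(total == 4 * L * C**2 + 2 * L**2 * C)
```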


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, model_dim, num_head) -> None:
        super().__init__()
        self.num_head = num_head

        self.proj_linear_layer = nn.Linear(model_dim, 3*model_dim)
        self.final_linear_layer = nn.Linear(model_dim, model_dim)

    def forward(self, input, additive_mask=None):
        bs, seqlen, model_dim = input.shape
        num_head = self.num_head
        head_dim = model_dim // num_head

        proj_out = self.proj_linear_layer(input)
        q, k, v = proj_out.chunk(3, dim=-1)   # each [bs, seqlen, model_dim]
        # chunk splits the tensor into equal pieces along the given dimension

        # Next, split q, k and v into multiple heads
        q = q.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)  #[bs, num_head, seqlen, head_dim]
        q = q.reshape(bs*num_head, seqlen, head_dim)

        k = k.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)  #[bs, num_head, seqlen, head_dim]
        k = k.reshape(bs*num_head, seqlen, head_dim)

        v = v.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)  #[bs, num_head, seqlen, head_dim]
        v = v.reshape(bs*num_head, seqlen, head_dim)

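        scores = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(head_dim)  # [bs*num_head, seqlen, seqlen]
        if additive_mask is not None:
            # additive_mask: [bs, seqlen, seqlen], 0 for allowed pairs and a large negative value for blocked ones.
            # Heads were merged batch-major (index = b*num_head + h), so each sample's mask is repeated num_head times in place
            additive_mask = additive_mask.repeat_interleave(num_head, dim=0)
            scores = scores + additive_mask
        attn_prob = F.softmax(scores, dim=-1)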

        output = torch.bmm(attn_prob, v)    # [bs*num_head, seqlen, head_dim]
        output = output.reshape(bs, num_head, seqlen, head_dim).transpose(1,2)
        output = output.reshape(bs, seqlen, model_dim)

        output = self.final_linear_layer(output)
        return attn_prob, output
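A minimal sketch of how an additive mask behaves, assuming a toy batch of 2 samples, 2 heads and length 4: entries set to `-inf` receive exactly zero probability after softmax (each row keeps at least one finite entry, so no row becomes NaN). It also shows that when scores are laid out as `[bs*num_head, L, L]` via `reshape(bs, num_head, ...)`, a per-sample mask of shape `[bs, L, L]` has to be expanded with `repeat_interleave` along dim 0; `tile` would interleave the samples in the wrong order.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, num_head, L = 2, 2, 4
scores = torch.randn(bs, num_head, L, L)

# Per-sample additive mask [bs, L, L]: 0 = allowed, -inf = blocked.
# Sample 0 blocks nothing; sample 1 blocks the last key position for every query.
mask = torch.zeros(bs, L, L)
mask[1, :, -1] = float("-inf")

# Merge heads into the batch dim, batch-major: index = b * num_head + h
flat_scores = scores.reshape(bs * num_head, L, L)

# repeat_interleave gives [m0, m0, m1, m1], matching the batch-major layout;
# tile((num_head, 1, 1)) would give [m0, m1, m0, m1] instead
flat_mask = mask.repeat_interleave(num_head, dim=0)
tiled = mask.tile((num_head, 1, 1))
prob = F.softmax(flat_scores + flat_mask, dim=-1).reshape(bs, num_head, L, L)

print(prob[1, :, :, -1])            # all zeros: sample 1 never attends to the last key
print(prob[0, :, :, -1].min() > 0)  # sample 0 is unaffected
```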

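### 3. How do we build window MHSA, and what is its complexity? (fixed window_size)
- Further partition the image made of patches into larger windows
  - First convert the 3-D patch embedding back into image format
  - Use unfold to group the patches into windows
- Compute MHSA inside each window
  - The number of windows can be treated together with the batch size, because windows do not interact with each other
  - Computational complexity
    - If a window has side length W (in patches), the total complexity per window is $4W^2C^2+2W^4C$
    - If the total number of patches is L, there are $L/W^2$ windows
    - Therefore the total complexity of W-MHSA is $4LC^2+2LW^2C$
  - No mask is needed here
  - Reshape the result into a 4-D tensor with a window dimension
- Complexity comparison
  - MHSA: $4LC^2+2L^2C$
  - W-MHSA: $4LC^2+2LW^2C$
  - Since W is fixed, W-MHSA grows linearly with L, whereas MHSA grows quadratically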
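To get a feel for the gap, the sketch below plugs numbers assumed from the first stage of Swin-T (a 224×224 image with patch size 4 gives 56×56 patches, C = 96, W = 7) into the two formulas.

```python
def mhsa_cost(L, C):
    """Global MHSA: 4*L*C^2 + 2*L^2*C."""
    return 4 * L * C**2 + 2 * L**2 * C

def w_mhsa_cost(L, C, W):
    """Window MHSA with window side W (in patches): 4*L*C^2 + 2*L*W^2*C."""
    return 4 * L * C**2 + 2 * L * W**2 * C

# Assumed values: Swin-T stage 1
L = 56 * 56   # number of patches
C = 96        # feature dimension
W = 7         # window side length

full = mhsa_cost(L, C)
windowed = w_mhsa_cost(L, C, W)
print(f"MHSA:   {full:,}")
print(f"W-MHSA: {windowed:,}")
print(f"ratio:  {full / windowed:.1f}x")

# W-MHSA is linear in L: doubling the number of patches exactly doubles its cost
print(w_mhsa_cost(2 * L, C, W) == 2 * windowed)
```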

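def window_multi_head_self_attention(patch_embedding, mhsa, window_size=4, num_head=2):
    """Window MHSA: split the patch grid into non-overlapping windows and run mhsa inside each one.
    patch_embedding: [bs, num_patch, patch_depth], a square patch grid whose side is divisible by window_size.
    num_head is unused here; the number of heads is fixed when mhsa is constructed."""
    num_patch_in_window = window_size * window_size
    bs, num_patch, patch_depth = patch_embedding.shape
    image_height = image_width = int(math.sqrt(num_patch))

    # [bs, num_patch, patch_depth] -> [bs, patch_depth, image_height, image_width]
    patch_embedding = patch_embedding.transpose(-1, -2)
    patch = patch_embedding.reshape(bs, patch_depth, image_height, image_width)
    # [bs, num_window, patch_depth*num_patch_in_window]
    window = F.unfold(patch, kernel_size=(window_size, window_size), stride=(window_size, window_size)).transpose(-1,-2)

    bs, num_window, patch_depth_times_num_patch_in_window = window.shape
    # Merge windows into the batch dimension: [bs*num_window, num_patch_in_window, patch_depth]
    window = window.reshape(bs*num_window, patch_depth, num_patch_in_window).transpose(-1,-2)

    attn_probs, output = mhsa(window)   # no mask is needed inside windows

    output = output.reshape(bs, num_window, num_patch_in_window, patch_depth)
    return output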
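An alternative to `F.unfold` for partitioning windows, similar in spirit to the `window_partition` helper in the official Swin Transformer code, uses only `view` and `permute` on a channels-last tensor. The sketch below (sizes assumed; both function names are hypothetical) checks that it yields the same `[bs*num_window, num_patch_in_window, C]` layout as the unfold route used in `window_multi_head_self_attention`.

```python
import torch
import torch.nn.functional as F

def window_partition_permute(x, window_size):
    """x: [bs, H, W, C] -> [bs*num_window, window_size*window_size, C]."""
    bs, H, W, C = x.shape
    x = x.view(bs, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()   # move the two window-grid axes in front
    return x.view(-1, window_size * window_size, C)

def window_partition_unfold(x, window_size):
    """Same partition via F.unfold on a channels-first tensor."""
    bs, H, W, C = x.shape
    patch = x.permute(0, 3, 1, 2)                   # [bs, C, H, W]
    win = F.unfold(patch, kernel_size=window_size, stride=window_size).transpose(-1, -2)
    num_window = win.shape[1]
    win = win.reshape(bs * num_window, C, window_size * window_size)
    return win.transpose(-1, -2)                    # [bs*num_window, window_size*window_size, C]

torch.manual_seed(0)
x = torch.randn(2, 8, 8, 16)   # [bs, H, W, C], H and W measured in patches
a = window_partition_permute(x, 4)
b = window_partition_unfold(x, 4)
print(a.shape, torch.equal(a, b))
```

Both versions only move data, so the results match exactly rather than approximately; the permute version avoids materializing the unfolded `[bs, C*ws*ws, num_window]` intermediate.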