# Swin Transformer
自己制造一个 *Swin Transformer* ，当然最终的目的是想把它改造成 U-Net 的架构，我希望对它进行如下改进：
1. 去除相对位置编码：我会在进行 qkv 映射的时候使用卷积，引入卷积的归纳偏置，看看是否可以去除相对位置编码问题
2. 去除滑动窗口：因为使用了 U-Net 架构以及卷积，滑动窗口显得不是很必要

> 暂时就想到了这些，之后有什么东西想到再加上去吧，也算锻炼一下自己的复现能力
> 但是这个还叫 Swin 吗？叫做 window Transformer 好了，没有滑动，只有窗口自注意力

# 1. 预处理
最基本的部分，包含 embedding，patch 划分，patch 还原，窗口划分，窗口复原

## 1.1 embedding
对图像进行映射，准备转化为 Token 以便交给 Transformer 作处理
可以选择是使用线性映射还是卷积，映射完之后自动转换成 Token 的 2d 表示，以便于之后的窗口划分
1. embedding 并划分 patch

> 后面思考了一下，既然要下采样，不如还是每一个像素当成一个 Token，256*256 也不是一个太大的分辨率，直接上吧

In [1]:
import torch
import torch.nn as nn
from einops import rearrange, repeat


class patchEmbedding(nn.Module):
    """Project an image to per-pixel token embeddings in channels-last layout.

    Every pixel is treated as one token (patch size 1), so the spatial size
    of the input is preserved by both projection types.

    Args:
        in_channels: number of channels of the input image.
        emb_dim: embedding dimension of each token.
        methods: projection type, ``'conv'`` (3x3 convolution, stride 1,
            padding 1) or ``'linear'`` (``nn.Linear`` applied per pixel).

    Raises:
        ValueError: if ``methods`` is neither ``'conv'`` nor ``'linear'``.

    Shape:
        input ``(B, in_channels, H, W)`` -> output ``(B, H, W, emb_dim)``.
    """

    def __init__(self, in_channels, emb_dim, methods):
        super().__init__()
        # Fail early: an unknown method used to leave the module without a
        # projection and made forward() return its input untouched.
        if methods not in ('conv', 'linear'):
            raise ValueError(
                f"methods must be 'conv' or 'linear', got {methods!r}")
        self.methods = methods
        if methods == 'conv':
            self.proj = nn.Conv2d(in_channels, emb_dim, 3, 1, 1)
        else:
            self.proj = nn.Linear(in_channels, emb_dim)

    def forward(self, x):
        if self.methods == 'conv':
            # Convolve in (B, C, H, W), then move the channel axis last.
            return self.proj(x).permute(0, 2, 3, 1)
        # nn.Linear acts on the last axis, so move channels last first.
        return self.proj(x.permute(0, 2, 3, 1))


In [23]:
# Smoke test: embed a random 3-channel 256x256 image with the conv projection
# and display the shape of the result.
pe = patchEmbedding(3, 16, 'conv')
test = torch.randn(1, 3, 256, 256)
pe(test).shape


torch.Size([1, 256, 256, 16])

## 1.2 窗口划分
将生成的 2d patch 进行窗口划分，以便于使用窗口自注意力
1. 划分窗口

In [2]:
def windowPartition(x, window_size):
    """Split a channels-last feature map into non-overlapping windows.

    Args:
        x: tensor of shape ``(B, H, W, C)``.
        window_size: window side length as an int, or a ``(Wh, Ww)`` pair
            for rectangular windows.

    Returns:
        Tensor of shape ``(B * num_windows, Wh, Ww, C)``. Windows are ordered
        batch-major, then row-major over the grid of windows.

    Raises:
        ValueError: if ``x`` is not 4-D, the window size is not positive, or
            ``H`` / ``W`` is not a multiple of the window size.
    """
    if isinstance(window_size, int):
        wh = ww = window_size
    else:
        wh, ww = window_size
    if x.dim() != 4:
        raise ValueError(
            f"expected a (B, H, W, C) tensor, got shape {tuple(x.shape)}")
    if wh <= 0 or ww <= 0:
        raise ValueError(f"window size must be positive, got ({wh}, {ww})")
    b, h, w, c = x.shape
    if h % wh != 0 or w % ww != 0:
        raise ValueError(
            f"feature map ({h}, {w}) is not divisible by window ({wh}, {ww})")
    grid_h, grid_w = h // wh, w // ww
    # (B, grid_h, Wh, grid_w, Ww, C) -> (B, grid_h, grid_w, Wh, Ww, C)
    x = x.reshape(b, grid_h, wh, grid_w, ww, c).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(b * grid_h * grid_w, wh, ww, c)


2. 划分窗口还原

In [56]:
def windowReverse(x, resolution):
    """Merge windows back into a channels-last feature map.

    Inverse of a batch-major, row-major window partition.

    Args:
        x: tensor of shape ``(B * num_windows, Wh, Ww, C)``.
        resolution: size of the reassembled feature map, either an int for a
            square ``resolution x resolution`` map or an ``(H, W)`` pair.

    Returns:
        Tensor of shape ``(B, H, W, C)``.

    Raises:
        ValueError: if ``x`` is not 4-D, if ``H`` / ``W`` is not a multiple
            of the window size, or if the number of windows is not a
            multiple of the number of windows per image.
    """
    if isinstance(resolution, int):
        height = width = resolution
    else:
        height, width = resolution
    if x.dim() != 4:
        raise ValueError(
            f"expected a (B*num_windows, Wh, Ww, C) tensor, "
            f"got shape {tuple(x.shape)}")
    n_windows, wh, ww, c = x.shape
    # A resolution that is not a multiple of the window used to be floored
    # silently, producing a smaller map than requested.
    if wh == 0 or ww == 0 or height % wh != 0 or width % ww != 0:
        raise ValueError(
            f"resolution ({height}, {width}) is not divisible by "
            f"window ({wh}, {ww})")
    grid_h, grid_w = height // wh, width // ww
    per_image = grid_h * grid_w
    if per_image == 0 or n_windows % per_image != 0:
        raise ValueError(
            f"{n_windows} windows cannot be split into images of "
            f"{per_image} windows each")
    b = n_windows // per_image
    # (B, grid_h, grid_w, Wh, Ww, C) -> (B, grid_h, Wh, grid_w, Ww, C)
    x = x.reshape(b, grid_h, grid_w, wh, ww, c).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(b, height, width, c)


In [17]:
# Round-trip check: partitioning into 8x8 windows and reversing at
# resolution 256 should give back the original tensor element-wise.
test = torch.randn(1, 256, 256, 16)
test_win = windowPartition(test, 8)
test_rev = windowReverse(test_win, 256)
test == test_rev


tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...

# 2. 模型架构

这里要开始编写模型的架构细节了，从每一个小结构来构成整个模型

## 2.1 窗口自注意力

这里要编写窗口自注意力方法，swin transformer 在窗口自注意力中包含 mask 以及相对位置编码，这里打算不使用滑动窗口，但是会保留是否使用相对位置编码的选项

In [42]:
class windowAttention(nn.Module):
    """Multi-head self-attention inside each window with a convolutional QKV.

    Q, K and V come from one 3x3 convolution (stride 1, padding 1) applied
    to the window viewed as a small 2-D image; the output projection is a
    3x3 convolution as well.

    Args:
        window_size: side length of the (square) window.
        dim: token dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        pe_flag: if true, add a learned relative position bias to the
            attention logits.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.

    Shape:
        input ``(B * num_windows, window_size**2, dim)`` -> same shape.
    """

    def __init__(self, window_size, dim, num_heads, pe_flag):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.to_qkv = nn.Conv2d(dim, dim * 3, 3, 1, 1)
        self.proj = nn.Conv2d(dim, dim, 3, 1, 1)
        self.num_heads = num_heads
        self.pe_Flag = pe_flag
        self.window_size = window_size
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        if pe_flag:
            # One learnable bias per head for every possible 2-D offset
            # between two tokens of a window.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # (2*Wh-1)*(2*Ww-1), nH

            # Pair-wise relative position index for the tokens of a window.
            coords_h = torch.arange(self.window_size)
            coords_w = torch.arange(self.window_size)
            # Explicit 'ij' indexing: relying on the implicit default is
            # deprecated and emits a warning.
            coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size - 1  # shift offsets to start from 0
            relative_coords[:, :, 1] += self.window_size - 1
            relative_coords[:, :, 0] *= 2 * self.window_size - 1  # row-major flattening of the 2-D offset
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, N, C = x.shape  # B*num_windows, window_size**2, C
        ws = self.window_size
        if N != ws * ws:
            raise ValueError(
                f"expected {ws * ws} tokens per window, got {N}")
        heads = self.num_heads
        head_dim = C // heads

        # Tokens -> 2-D image layout so the convolutions can be applied.
        x = x.transpose(1, 2).reshape(B, C, ws, ws)
        qkv = self.to_qkv(x)  # B*num_windows, 3*C, ws, ws
        # Channels are laid out as (qkv index, head, head_dim).
        qkv = qkv.reshape(B, 3, heads, head_dim, N).permute(1, 0, 2, 4, 3)  # 3, B, heads, N, head_dim
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # B, heads, N, N

        if self.pe_Flag:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(N, N, -1)  # Wh*Ww, Wh*Ww, nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)  # broadcast over windows

        attn = self.softmax(attn)
        x = attn @ v  # B*num_windows, heads, N, head_dim
        # Merge the heads and go back to the 2-D layout for the projection.
        x = x.permute(0, 1, 3, 2).reshape(B, C, ws, ws)
        x = self.proj(x)  # B*num_windows, C, ws, ws
        return x.flatten(2).transpose(1, 2)  # B*num_windows, N, C


In [47]:
testWA = windowAttention(2,16,4,True)  # 4 heads, head_dim = 4
test = torch.randn(1*4,2*2,16)  # (B*num_windows, window_size**2, C); adjust the shape when the window size changes
testWA(test).shape

torch.Size([4, 4, 16])

### 2.1.1 不同的映射方式

自注意力映射矩阵可以有不同的方式，上面的方式选择的是直接使用标准卷积核做映射，可能会导致参数过多，这里做出一些修改

1. 深度可分离卷积

In [4]:
class DSconv(nn.Module):
    """Depthwise-separable convolution followed by channel dropout.

    A depthwise convolution (``groups=in_channels``) is followed by a 1x1
    pointwise convolution that maps to ``out_channels``, then ``Dropout2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the depthwise convolution.
        padding: padding of the depthwise convolution. Note that it comes
            before ``stride`` here, unlike in ``nn.Conv2d``.
        stride: stride of the depthwise convolution.
        dropout_ratio: probability for ``nn.Dropout2d``. Defaults to ``0.0``
            (no dropout) so five-argument calls stay valid.

    Shape:
        input ``(B, in_channels, H, W)`` -> output ``(B, out_channels, H', W')``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, dropout_ratio=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
                                padding=padding, stride=stride, groups=in_channels)
        self.pwconv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        self.drop = nn.Dropout2d(dropout_ratio)

    def forward(self, x):
        x = self.dwconv(x)  # per-channel spatial filtering
        x = self.pwconv(x)  # 1x1 channel mixing
        x = self.drop(x)
        return x

### 2.1.2 降低 KV 数量

增大卷积步长来减少 KV 的数量，从而缓解自注意力的运算压力

> 这里直接选择深度可分离卷积作为映射方式

In [18]:
class convWindowAttention(nn.Module):
    """Window self-attention with depthwise-separable projections and reduced K/V.

    Queries keep the full window resolution. Keys and values come from a
    strided depthwise-separable convolution (kernel 4, padding 1, stride 2),
    which halves each side for even window sizes and so reduces the number
    of K/V tokens.

    Args:
        window_size: side length of the (square) window.
        dim: token dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        pe_flag: if true, add a learned relative position bias to the
            attention logits.
        drop_ratio: dropout probability, used for the attention weights and
            for the dropout layer of every ``DSconv`` projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.

    Shape:
        input ``(B * num_windows, window_size**2, dim)`` -> same shape.
    """

    def __init__(self, window_size, dim, num_heads, pe_flag,drop_ratio):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        # DSconv takes (in, out, kernel_size, padding, stride, dropout_ratio);
        # the dropout ratio is required and used to be omitted here.
        kv_kernel, kv_padding, kv_stride = 4, 1, 2
        self.to_kv = DSconv(dim, 2*dim, kv_kernel, kv_padding, kv_stride, drop_ratio)  # strided: fewer K/V tokens
        self.to_q = DSconv(dim, dim, 3, 1, 1, drop_ratio)  # keeps the window resolution
        self.proj = DSconv(dim, dim, 3, 1, 1, drop_ratio)
        self.attn_drop = nn.Dropout(drop_ratio)
        self.num_heads = num_heads
        self.pe_Flag = pe_flag
        self.window_size = window_size
        head_dim = dim//num_heads
        self.scale = head_dim**-0.5
        if pe_flag:
            # One learnable bias per head for every possible 2-D offset.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # (2*Wh-1)*(2*Ww-1), nH

            # Keys live on the strided grid, so the index must be
            # (num_queries, num_keys) rather than (num_queries, num_queries).
            # Each key is anchored at the query coordinate its stride maps to.
            kv_size = (window_size + 2 * kv_padding - kv_kernel) // kv_stride + 1
            q_axis = torch.arange(self.window_size)
            k_axis = torch.arange(kv_size) * kv_stride
            q_coords = torch.flatten(torch.stack(torch.meshgrid([q_axis, q_axis], indexing='ij')), 1)  # 2, Nq
            k_coords = torch.flatten(torch.stack(torch.meshgrid([k_axis, k_axis], indexing='ij')), 1)  # 2, Nk
            relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Nq, Nk
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Nq, Nk, 2
            relative_coords[:, :, 0] += self.window_size - 1  # shift offsets to be non-negative
            relative_coords[:, :, 1] += self.window_size - 1
            relative_coords[:, :, 0] *= 2 * self.window_size - 1  # row-major flattening of the 2-D offset
            relative_position_index = relative_coords.sum(-1)  # Nq, Nk
            self.register_buffer("relative_position_index", relative_position_index)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, N, C = x.shape  # B*num_windows, window_size**2, C
        x = rearrange(x, 'b (wh ww) c -> b c wh ww', wh=self.window_size, ww=self.window_size)  # tokens -> 2-D image layout
        kv = self.to_kv(x)  # B*num_windows, 2*C, reduced height, reduced width
        q = self.to_q(x)  # B*num_windows, C, window_size, window_size
        kv = rearrange(kv, 'b (num head head_dim) wh ww  -> num b head (wh ww) head_dim',
                        num=2, head=self.num_heads, head_dim=C//self.num_heads)  # back to token layout
        q =rearrange(q,'b (head head_dim) h w -> b head (h w) head_dim',head=self.num_heads)
        k,v=kv[0],kv[1]
        q = q*self.scale
        attn = (q @ k.transpose(-2, -1))  # B*num_windows, heads, Nq, Nk

        if self.pe_Flag:
            num_q, num_k = self.relative_position_index.shape
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                num_q, num_k, -1)  # Nq, Nk, nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Nq, Nk
            attn = attn + relative_position_bias.unsqueeze(0)  # broadcast over windows

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v)  # B*num_windows, heads, Nq, head_dim
        x = rearrange(x, 'b head (wh ww) head_dim -> b (head head_dim) wh ww ', wh=self.window_size, ww=self.window_size)
        x = self.proj(x)  # B*num_windows, C, window_size, window_size
        x = rearrange(x, 'b c wh ww -> b (wh ww) c')   # B*num_windows, N, C
        return x


In [15]:
# Smoke test; drop_ratio is a required constructor argument (0.0 = no dropout).
testWA = convWindowAttention(8,16,4,False,0.0)  # 4 heads, head_dim = 4
test = torch.randn(1*4,8*8,16)  # (B*num_windows, window_size**2, C); adjust the shape when the window size changes
testWA(test)

tensor([[[-0.2560, -0.0939, -0.0194,  ...,  0.2518,  0.3050, -0.3785],
         [-0.2602, -0.1133, -0.0187,  ...,  0.2388,  0.3101, -0.3518],
         [-0.2618, -0.1155, -0.0182,  ...,  0.2365,  0.3128, -0.3530],
         ...,
         [-0.3560, -0.1044,  0.0375,  ...,  0.2084,  0.3182, -0.3383],
         [-0.3559, -0.1077,  0.0367,  ...,  0.2072,  0.3186, -0.3393],
         [-0.2261, -0.1360,  0.0473,  ...,  0.2014,  0.3205, -0.3506]],

        [[-0.2129, -0.0876, -0.0165,  ...,  0.2283,  0.2701, -0.3152],
         [-0.1913, -0.1270,  0.0009,  ...,  0.2174,  0.3126, -0.3255],
         [-0.1909, -0.1273,  0.0016,  ...,  0.2143,  0.3127, -0.3229],
         ...,
         [-0.2453, -0.1447,  0.0735,  ...,  0.1842,  0.3495, -0.3479],
         [-0.2495, -0.1425,  0.0751,  ...,  0.1821,  0.3498, -0.3451],
         [-0.1570, -0.1626,  0.0571,  ...,  0.1950,  0.3388, -0.3490]],

        [[-0.2493, -0.0651, -0.0466,  ...,  0.2278,  0.2948, -0.3268],
         [-0.2275, -0.0954, -0.0064,  ...,  0

## 2.2 FeedForward
这里要处理 FeedForward 结构，感觉有多种选择，一种是像传统的使用全连接层，或者是使用 CNN
但是使用 CNN 也会有问题，是要在窗口内使用 CNN 还是在整张图像上使用 CNN

> 其实我比较倾向在窗口内使用 CNN，因为既然划分了窗口，那么随着网络的加深降低图像的分辨率同样可以让不同窗口之间的信息进行交互
> 但是如果想要实现全局自注意力还是需要让网络再深一层，不然就是随着网络的加深，窗口大小同时增大


In [17]:
class FeedForward(nn.Module):
    """Two-layer feed-forward block with GELU, applied to window tokens.

    Args:
        method: ``'conv'`` selects 3x3 convolutions over the window viewed
            as a square image; any other value selects ``nn.Linear`` layers.
        dim: token dimension.
        mlp_ratio: multiplier giving the hidden width ``dim * mlp_ratio``.
        dropout_ratio: dropout probability applied after the activation and
            after the second layer.

    Shape:
        input ``(B * num_windows, N, dim)`` -> same shape. In conv mode the
        tokens are reshaped to ``int(N ** 0.5)`` rows and columns.
    """

    def __init__(self, method, dim, mlp_ratio,dropout_ratio):
        super().__init__()
        self.method = method
        hidden = dim * mlp_ratio
        conv_mode = method == 'conv'
        self.layer1 = nn.Conv2d(dim, hidden, 3, 1, 1) if conv_mode else nn.Linear(dim, hidden)
        self.layer2 = nn.Conv2d(hidden, dim, 3, 1, 1) if conv_mode else nn.Linear(hidden, dim)
        # Dropout2d drops whole feature maps; Dropout drops single elements.
        self.drop = nn.Dropout2d(dropout_ratio) if conv_mode else nn.Dropout(dropout_ratio)
        self.act = nn.GELU()

    def forward(self, x):
        conv_mode = self.method == 'conv'
        if conv_mode:
            # Tokens -> square image so the convolutions see neighbours.
            _, n_tokens, _ = x.shape
            side = int(n_tokens**0.5)
            x = rearrange(x, 'b (wh ww) c -> b c wh ww', wh=side, ww=side)
        # Both modes share the same sequence of stages.
        for stage in (self.layer1, self.act, self.drop, self.layer2, self.drop):
            x = stage(x)
        if conv_mode:
            x = rearrange(x, 'b c wh ww -> b (wh ww) c')
        return x



In [52]:
# Smoke test; dropout_ratio is a required constructor argument (0.0 = no dropout).
test_fw=FeedForward('linear',16,4,0.0)
test_input=torch.randn(1*4,8*8,16)
test_fw(test_input).shape

torch.Size([4, 64, 16])

## 2.3 Basic Block

组成一个基础架构，包含自注意力和 FeedForward 
对其稍加修改，将 Dropout 添加进去了，据说可以在数据量小的情况下防止过拟合

> 但是有一个问题，我希望在训练的开始阶段使用 Dropout，之后就不使用 Dropout 了，用直接定义对象的方式似乎不是很合适

In [21]:
class basicBlock(nn.Module):
    """Transformer block: window attention followed by a feed-forward network.

    Args:
        window_size: side length of the attention window.
        dim: token dimension.
        num_heads: number of attention heads.
        pe_flag: whether the attention uses a relative position bias.
        mlp_method: feed-forward projection type (``'conv'`` or linear).
        mlp_ratio: hidden-width multiplier of the feed-forward network.
        mlp_drop_ratio: dropout probability inside the feed-forward network.
        attn_drop_ratio: dropout probability inside the attention module.

    Shape:
        input ``(B * num_windows, N, dim)`` -> same shape.
    """

    def __init__(self, window_size, dim, num_heads, pe_flag, mlp_method, mlp_ratio, mlp_drop_ratio, attn_drop_ratio):
        super().__init__()
        self.attn = convWindowAttention(
            window_size=window_size,
            dim=dim,
            num_heads=num_heads,
            pe_flag=pe_flag,
            drop_ratio=attn_drop_ratio,
        )
        self.feedForward = FeedForward(
            method=mlp_method,
            dim=dim,
            mlp_ratio=mlp_ratio,
            dropout_ratio=mlp_drop_ratio,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Attention branch: normalise first, then add the residual.
        attended = self.attn(self.norm1(x)) + x
        # NOTE(review): norm2 is applied to the feed-forward output, whereas
        # norm1 is applied to the attention input - confirm this is intended.
        return self.norm2(self.feedForward(attended)) + attended


In [34]:
# Smoke test. attn_drop_ratio was 1, which drops every attention weight;
# 0.1 keeps the attention path active.
testBasicBlock = basicBlock(8,16,2,False,'conv',4,0.1,0.1)
testInput = torch.randn(2,8*8,16)
testBasicBlock(testInput)


tensor([[[-2.6063,  0.7361, -0.2138,  ...,  0.3986, -1.8798, -1.1483],
         [-1.1505, -0.1727, -1.4844,  ...,  0.1473,  0.0247,  1.5883],
         [ 0.9689,  0.7363,  2.8224,  ..., -0.2719, -2.3062,  0.6836],
         ...,
         [-0.8456, -0.0271,  1.9537,  ...,  1.5990, -1.1831,  0.9799],
         [ 1.9941,  1.2018, -1.1034,  ..., -2.2401, -0.4410, -3.7872],
         [ 0.9032,  0.8688, -0.6262,  ..., -0.4189, -0.4473,  0.3869]],

        [[-2.4947, -0.1487, -2.2262,  ...,  0.1756,  0.6031,  0.5275],
         [-0.1865,  0.9018, -1.7052,  ...,  0.7480,  0.3009, -1.6146],
         [-0.2341, -0.6768,  0.7450,  ...,  0.2567, -1.1757, -1.1714],
         ...,
         [-0.2481, -0.1114, -0.1961,  ..., -4.6342,  0.2114, -0.0136],
         [-0.7939, -3.1093,  1.7647,  ...,  1.5772, -0.9247,  0.5921],
         [ 1.4138, -3.4308, -0.9724,  ..., -1.9892,  0.1002,  0.4153]]],
       grad_fn=<AddBackward0>)

## 2.4 Basic Layer

将上面的块组成 Uformer 中的一层

In [35]:
class basicLayer(nn.Module):
    """One stage of the U-shaped network: ``depth`` identical blocks in sequence.

    Args:
        window_size: side length of the attention window.
        dim: token dimension.
        num_heads: number of attention heads.
        pe_flag: whether the attention uses a relative position bias.
        mlp_method: feed-forward projection type (``'conv'`` or linear).
        mlp_ratio: hidden-width multiplier of the feed-forward network.
        mlp_drop_ratio: dropout probability inside the feed-forward network.
        attn_drop_ratio: dropout probability inside the attention module.
        depth: number of ``basicBlock`` modules stacked in this stage.

    Shape:
        input ``(B * num_windows, N, dim)`` -> same shape.
    """

    def __init__(self, window_size, dim, num_heads, pe_flag, mlp_method, mlp_ratio, mlp_drop_ratio, attn_drop_ratio, depth):
        super().__init__()
        # Every block of the stage is built with the same configuration.
        block_config = dict(
            window_size=window_size,
            dim=dim,
            num_heads=num_heads,
            pe_flag=pe_flag,
            mlp_method=mlp_method,
            mlp_ratio=mlp_ratio,
            mlp_drop_ratio=mlp_drop_ratio,
            attn_drop_ratio=attn_drop_ratio,
        )
        self.modlist = nn.ModuleList([basicBlock(**block_config) for _ in range(depth)])

    def forward(self, x):
        for block in self.modlist:
            x = block(x)
        return x


In [44]:
# Smoke test: a stage of 2 blocks on 2 windows of 8x8 tokens with dim 16.
testBasicLayer= basicLayer(8,16,2,False,'conv',4,0.1,0.1,2)
testInput = torch.randn(2,8*8,16)
testBasicLayer(testInput)


tensor([[[-0.7942, -1.8997, -1.0571,  ...,  1.2608, -0.5754, -2.4363],
         [ 0.1233, -1.5293,  1.0393,  ...,  1.8218,  2.8645, -2.0358],
         [-2.7585,  0.5277, -0.8782,  ...,  0.8342,  1.9935,  2.6708],
         ...,
         [ 3.9766,  2.9710, -0.6239,  ..., -1.3399,  3.1986,  3.6853],
         [-3.2670,  0.7657, -1.2656,  ..., -0.9717, -0.7699,  0.6301],
         [ 1.0382,  2.4210,  1.6876,  ..., -0.1289,  0.3380, -0.9510]],

        [[ 1.1781, -2.3276, -4.5955,  ..., -1.2207, -4.9892, -1.4254],
         [ 3.0435, -1.1702, -1.1032,  ...,  0.4988,  0.8493,  0.5271],
         [-0.9873,  0.2970,  0.6085,  ...,  0.6181, -1.7288,  2.8000],
         ...,
         [-0.5273, -2.5459,  1.5629,  ..., -0.7672, -0.0823,  0.3997],
         [-2.0775,  0.8554, -1.2231,  ..., -2.7632,  2.4376, -0.9827],
         [ 1.8451,  1.3500,  0.3397,  ..., -0.9480,  2.7525, -1.2971]]],
       grad_fn=<AddBackward0>)

## 2.5 上下采样

使用卷积进行上下采样，以组成 U-Net

> 这个上下采样不知道有没有什么骚操作可以做呢~
> 刚刚去看了 Uformer 的代码，没有发现有什么可以做的骚操作

### 2.5.1 下采样

In [69]:
class dowmSample(nn.Module):
    """Down-sampling layer operating on window tokens.

    The windows are merged back into a feature map, a strided convolution
    (kernel 4, stride 2, padding 1) halves the spatial size and doubles the
    channels, and the result is partitioned into windows again.

    Args:
        dim: channel count of the incoming tokens.
        window_size: window side length used to re-partition the output.
        resolution: side length of the feature map before down-sampling;
            passed to ``windowReverse``.

    Shape:
        input ``(B * num_windows, N, dim)`` ->
        output ``(B * num_windows', window_size**2, 2 * dim)``.
    """

    def __init__(self,dim,window_size,resolution):
        super().__init__()
        self.window_size = window_size
        self.resolution = resolution
        self.conv = nn.Conv2d(dim, dim*2, 4, 2, 1)

    def forward(self,x):
        _, n_tokens, _ = x.shape
        side = int(n_tokens**0.5)
        # Token sequence -> square windows -> full feature map (B, C, H, W).
        windows = rearrange(x, 'b (Wh Ww) c -> b Wh Ww c', Wh=side, Ww=side)
        image = windowReverse(windows, resolution=self.resolution)
        image = image.permute(0, 3, 1, 2).contiguous()
        # Strided convolution, then back to channels-last (B, H/2, W/2, 2C).
        image = self.conv(image)
        image = image.permute(0, 2, 3, 1).contiguous()
        # Re-partition into windows and flatten each window to tokens.
        windows = windowPartition(image, window_size=self.window_size)
        return rearrange(windows, 'b h w c -> b (h w) c')
        

In [70]:
testDownSample = dowmSample(16,8,16)    # resolution (last argument) must be the side length of the map rebuilt from the input windows
test=torch.randn(1*4,8*8,16)
testDownSample(test).shape

torch.Size([1, 64, 32])

### 2.5.2 上采样

In [86]:
class upSample(nn.Module):
    """Up-sampling layer operating on window tokens.

    The windows are merged back into a feature map, a transposed convolution
    (kernel 2, stride 2) doubles the spatial size and maps ``dim`` channels
    to ``dim // 2``, and the result is partitioned into windows again.

    Args:
        dim: channel count of the incoming tokens.
        window_size: window side length used to re-partition the output.
        resolution: side length of the feature map before up-sampling;
            passed to ``windowReverse``.

    Shape:
        input ``(B * num_windows, N, dim)`` ->
        output ``(B * num_windows', window_size**2, dim // 2)``.
    """

    def __init__(self,dim,window_size,resolution):
        super().__init__()
        self.window_size = window_size
        self.resolution = resolution
        self.conv = nn.ConvTranspose2d(dim, dim//2, 2, 2)

    def forward(self,x):
        _, n_tokens, _ = x.shape
        side = int(n_tokens**0.5)
        # Token sequence -> square windows -> full feature map (B, C, H, W).
        windows = rearrange(x, 'b (Wh Ww) c -> b Wh Ww c', Wh=side, Ww=side)
        image = windowReverse(windows, resolution=self.resolution)
        image = image.permute(0, 3, 1, 2).contiguous()
        # Transposed convolution, then back to channels-last (B, 2H, 2W, C//2).
        image = self.conv(image)
        image = image.permute(0, 2, 3, 1).contiguous()
        # Re-partition into windows and flatten each window to tokens.
        windows = windowPartition(image, window_size=self.window_size)
        return rearrange(windows, 'b h w c -> b (h w) c')
        

In [87]:
# Down-sample then up-sample; the final shape should match the input shape.
testDownSample = dowmSample(16,8,16)    # resolution (last argument) must be the side length of the map rebuilt from the input windows
test=torch.randn(1*4,8*8,16)
down=testDownSample(test)
testUpSample=upSample(32,8,8)
testUpSample(down).shape

torch.Size([4, 64, 16])