In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Build a 4x4 single-channel image with batch size 1 (values 0..15 in row-major order).
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
print("输入图像形状:", x.shape)  # torch.Size([1, 1, 4, 4])

# Cut the image into non-overlapping 2x2 patches (kernel 2, stride 2, no padding).
unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
patches = unfold(x)  # (1, 4, 4): each column holds one flattened patch
print(patches.reshape(4, 4).T)  # transpose so that every printed row is one patch
print(patches.shape)

输入图像形状: torch.Size([1, 1, 4, 4])
tensor([[ 0.,  1.,  4.,  5.],
        [ 2.,  3.,  6.,  7.],
        [ 8.,  9., 12., 13.],
        [10., 11., 14., 15.]])
torch.Size([1, 4, 4])


In [11]:
# F.pad reads the pad tuple starting from the LAST dimension, two entries
# (before, after) per dimension: width gets 2 zeros on the right, height gets
# 1 zero row at the bottom, and the channel dimension is left untouched.
# NOTE: this cell overwrites `x`, so running it again pads the tensor again.
x = F.pad(x, pad=(0, 2, 0, 1, 0, 0), mode="constant", value=0.0)
print(x, x.shape, sep="\n")

tensor([[[[ 0.,  1.,  2.,  3.,  0.,  0.],
          [ 4.,  5.,  6.,  7.,  0.,  0.],
          [ 8.,  9., 10., 11.,  0.,  0.],
          [12., 13., 14., 15.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.]]]])
torch.Size([1, 1, 5, 6])


In [20]:
# Same padding recipe on a random 2-channel tensor: only the two trailing
# (spatial) dimensions grow, the channel dimension keeps its size.
y = F.pad(torch.randn(1, 2, 4, 4), pad=(0, 2, 0, 1, 0, 0))
print(y, y.shape, sep="\n")
print(y.transpose(1, 2).shape)  # swapping dims 1 and 2 gives (1, 5, 2, 6)

tensor([[[[ 0.3838,  2.3252,  1.6461,  0.6313,  0.0000,  0.0000],
          [-0.6535, -1.1300,  0.4256, -0.9707,  0.0000,  0.0000],
          [-0.3385,  1.9245, -0.5994, -0.8047,  0.0000,  0.0000],
          [ 0.0867,  0.1324,  0.2788,  0.2740,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

         [[-0.2247, -1.4819, -0.8830, -0.5641,  0.0000,  0.0000],
          [-0.7296, -1.5403, -1.3028, -1.5787,  0.0000,  0.0000],
          [-0.2095,  2.0121, -0.3904, -0.2283,  0.0000,  0.0000],
          [ 1.4176, -0.7287,  1.1338, -0.1613,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]])
torch.Size([1, 2, 5, 6])
torch.Size([1, 5, 2, 6])


In [25]:
# 3x3 matrix holding 1..9 in row-major order:
#   [[1, 2, 3],
#    [4, 5, 6],
#    [7, 8, 9]]
x = torch.arange(1, 10).reshape(3, 3)
# Negative shifts roll towards lower indices: one row up and one column left,
# with the elements that fall off re-entering on the opposite side.
x = x.roll(shifts=(-1, -1), dims=(0, 1))
print(x, x.shape, sep="\n")

tensor([[5, 6, 4],
        [8, 9, 7],
        [2, 3, 1]])
torch.Size([3, 3])


In [18]:
drop_path_rate = 0.1
depths = (2, 2, 6, 2)
# Stochastic depth decay rule: one rate per block, rising linearly from 0 to drop_path_rate.
dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
print(dpr)

# Walk the stages with a running offset and slice out each stage's share of the rates.
stage_start = 0
for i_layer, stage_depth in enumerate(depths):
    print(stage_start)
    stage_end = stage_start + stage_depth
    drop_path = dpr[stage_start:stage_end]
    print(drop_path)
    stage_start = stage_end

[0.0, 0.00909090880304575, 0.0181818176060915, 0.027272727340459824, 0.036363635212183, 0.045454543083906174, 0.054545458406209946, 0.06363636255264282, 0.0727272778749466, 0.08181818574666977, 0.09090909361839294, 0.10000000149011612]
0
[0.0, 0.00909090880304575]
2
[0.0181818176060915, 0.027272727340459824]
4
[0.036363635212183, 0.045454543083906174, 0.054545458406209946, 0.06363636255264282, 0.0727272778749466, 0.08181818574666977]
10
[0.09090909361839294, 0.10000000149011612]


In [36]:
x = torch.arange(2)  # treated as the x coordinates: tensor([0, 1])
y = torch.arange(2)  # treated as the y coordinates: tensor([0, 1])

# indexing='ij' (matrix-style, the default mode): xx_ij varies along rows, yy_ij along columns.
xx_ij, yy_ij = torch.meshgrid(x, y, indexing='ij')
a = torch.stack([xx_ij, yy_ij])  # (2, 2, 2): both coordinate grids on a new leading dim
print(a)
a = a.flatten(start_dim=1)  # (2, 4): one column per grid position
print(a)
# Pairwise coordinate differences between all grid positions, shape (2, 4, 4).
a.unsqueeze(2) - a.unsqueeze(1)

tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]])
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])


tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])

In [37]:
# 3x3 matrix holding 1..9 in row-major order; a single integer index selects a row.
x = torch.arange(1, 10).reshape(3, 3)
first_row = x[0]
print(first_row)

tensor([1, 2, 3])


In [None]:
class CyclicShift(nn.Module):
    """Roll a tensor by the same offset along dimensions 1 and 2, wrapping around."""

    def __init__(self, displacement):
        super().__init__()
        # Offset applied to both rolled dimensions; the sign selects the direction.
        self.displacement = displacement

    def forward(self, x):
        """Return ``x`` rolled by ``displacement`` along dims 1 and 2."""
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))

In [None]:
def get_relative_distances(window_size):
    """Return the pairwise (row, col) offsets between all positions of a square window.

    Positions are enumerated in row-major order, so position ``p`` has
    coordinates ``(p // window_size, p % window_size)``.

    Args:
        window_size: Side length of the window.

    Returns:
        An int64 tensor of shape ``(window_size ** 2, window_size ** 2, 2)``
        whose entry ``[i, j]`` is ``coords[j] - coords[i]``. Every component lies
        in ``[-(window_size - 1), window_size - 1]``.
    """
    # Built directly with torch so the function does not depend on numpy and the
    # dtype is always int64. The reshape keeps the (N, 2) layout even when the
    # window is empty.
    indices = torch.tensor(
        [[x, y] for x in range(window_size) for y in range(window_size)],
        dtype=torch.long,
    ).reshape(-1, 2)
    # Broadcast (1, N, 2) - (N, 1, 2) -> (N, N, 2).
    distances = indices[None, :, :] - indices[:, None, :]
    return distances

In [None]:
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The input feature map is cut into ``window_size x window_size`` windows and
    attention is computed inside each window independently. When ``shifted`` is
    set, the map is first cyclically shifted by ``window_size // 2``, masks built
    by ``create_mask`` are added to the attention logits of the last row and the
    last column of windows, and the shift is undone on the output.

    Args:
        dim: Size of the last (feature) dimension of the input and of the output.
        heads: Number of attention heads.
        head_dim: Feature size per head; q, k and v each get ``heads * head_dim`` features.
        shifted: Whether to use the cyclically shifted window layout.
        window_size: Side length of a window, in feature-map positions.
        relative_pos_embedding: Flag. If truthy, a learnable
            ``(2 * window_size - 1, 2 * window_size - 1)`` bias table is looked up
            by relative offset; otherwise a full
            ``(window_size ** 2, window_size ** 2)`` bias is learned.
    """

    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads
        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        # Flag only; the bias table itself lives in self.pos_embedding below.
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # Masks are stored as frozen parameters (requires_grad=False).
            # Presumably each is (window_size**2, window_size**2), e.g. (49, 49)
            # for window_size=7 -- confirm against create_mask.
            self.upper_lower_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=True, left_right=False),
                requires_grad=False,
            )
            # Fixed: the keyword was misspelled as `pper_lower`.
            self.left_right_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=False, left_right=True),
                requires_grad=False,
            )

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if self.relative_pos_embedding:
            # Offsets are shifted by window_size - 1 so that they become
            # non-negative indices into the (2w - 1, 2w - 1) table.
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Apply windowed attention.

        Args:
            x: Tensor of shape ``(batch, height, width, dim)``. Height and width
                are expected to be multiples of ``window_size``.

        Returns:
            Tensor of shape ``(batch, height, width, dim)``.
        """
        if self.shifted:
            x = self.cyclic_shift(x)

        # Unpacking also enforces a 4-D input, e.g. (1, 56, 56, 96).
        _, n_h, n_w, _ = x.shape
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # three tensors, e.g. (1, 56, 56, 96) each
        nw_h = n_h // self.window_size  # windows along height, e.g. 8
        nw_w = n_w // self.window_size  # windows along width, e.g. 8
        # Split into (height / M) * (width / M) windows, M = window_size.
        q, k, v = map(
            lambda t: rearrange(
                t,
                'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                h=h, w_h=self.window_size, w_w=self.window_size,
            ),
            qkv,
        )
        # q, k, v: (batch, heads, windows, positions per window, head_dim), e.g. (1, 3, 64, 49, 32)
        # Self-attention logits computed separately for every window.
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale  # e.g. (1, 3, 64, 49, 49)

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major, so the last nw_w entries are the
            # bottom row of windows and every nw_w-th entry starting at nw_w - 1
            # is the right-most window of a row.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)  # e.g. (1, 3, 64, 49, 49)
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # Merge the windows back into a feature map, e.g. (1, 56, 56, 96).
        out = rearrange(
            out,
            'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
            h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w,
        )
        out = self.to_out(out)
        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out

In [None]:
class Residual(nn.Module):
    """Wrap a callable with a skip connection: the output is ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Run the wrapped callable on ``x`` and add ``x`` back to its result."""
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then pass the result to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any keyword arguments, to ``fn``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class SwinBlock(nn.Module):
    """Windowed attention followed by a feed-forward network.

    Each of the two sub-layers is pre-normalised (``PreNorm``) and wrapped in a
    skip connection (``Residual``).
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, attention))

        feed_forward = FeedForward(dim=dim, hidden_dim=mlp_dim)
        self.mlp_block = Residual(PreNorm(dim, feed_forward))

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer."""
        return self.mlp_block(self.attention_block(x))