### 모델 구현

#### MLP

In [None]:
class MLP(nn.Module):
    """Token-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    The last dimension is mapped patches_dim -> n_hidden_layer -> patches_dim.
    Despite its name, ``n_hidden_layer`` is the hidden *width*, not a count of
    layers.
    """

    def __init__(self, patches_dim, n_hidden_layer, drop_p):
        super().__init__()

        expand = nn.Linear(patches_dim, n_hidden_layer)
        project = nn.Linear(n_hidden_layer, patches_dim)
        # Kept as a single Sequential named `mlp` so parameter names stay stable.
        self.mlp = nn.Sequential(expand, nn.GELU(), nn.Dropout(drop_p), project)

    def forward(self, x):
        return self.mlp(x)

#### PatchMerging

In [None]:
class PatchMerging(nn.Module):
    """Halve the spatial resolution of a channels-last token grid.

    The input is unpacked as (batch, H, W, C). Every 2x2 neighbourhood is
    concatenated along channels (C -> 4C), layer-normalised, and projected
    to 2C without bias. An odd H or W is zero-padded at the bottom/right
    first, so every neighbourhood is complete.
    """

    def __init__(self, patches_dim):
        super().__init__()

        self.patches_dim = patches_dim
        merged_dim = 4 * patches_dim
        self.norm = nn.LayerNorm(merged_dim, eps=1e-5)
        self.compression = nn.Linear(merged_dim, 2 * patches_dim, bias=False)

    def forward(self, x):
        _, height, width, _ = x.shape
        # func.pad lists (left, right) pairs from the last dim backwards: C, W, H.
        padded = func.pad(x, (0, 0, 0, width % 2, 0, height % 2))

        # (row, col) offsets are taken in the order (0,0), (1,0), (0,1), (1,1).
        quadrants = [padded[..., row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        return self.compression(self.norm(torch.cat(quadrants, dim=-1)))

#### W-MSA, SW-MSA

In [None]:
class ShiftedWindowAttention(nn.Module):
    """Window multi-head self-attention (W-MSA), optionally with shifted windows (SW-MSA).

    The input is padded to a multiple of ``window_size``. It may be cyclically
    rolled by ``-shift_size`` and is then split into non-overlapping windows.
    Each window runs multi-head attention with a learnable relative position
    bias. When the windows are shifted, an additive mask stops tokens that
    came from different regions of the unshifted map from attending to each
    other.

    Args:
        patches_dim: channel dimension of every token; split evenly across heads.
        window_size: (window_h, window_w).
        shift_size: (shift_h, shift_w) as a list or tuple; zeros give plain W-MSA.
        n_heads: number of attention heads.
        drop_p: accepted for interface compatibility; not used in this module.
        device: device on which the bias table and index buffers are first
            allocated (they follow later ``.to()`` calls).

    NOTE: the bias table is a registered parameter (``B_hat``), so checkpoints
    saved before it was registered lack that key in their state_dict.
    """

    def __init__(self, patches_dim, window_size, shift_size, n_heads, drop_p, device=DEVICE):
        super().__init__()

        self.device = device if device is not None else torch.device("cpu")
        self.window_size = window_size
        self.shift_size = shift_size
        self.n_heads = n_heads
        # Buffer (not a plain attribute) so it moves with the module on .to()/.cuda();
        # non-persistent so it does not add a state_dict entry.
        self.register_buffer(
            "root_dk",
            torch.sqrt(torch.tensor(patches_dim / n_heads, dtype=torch.float32, device=self.device)),
            persistent=False)

        self.q_Linear = nn.Linear(patches_dim, patches_dim)
        self.k_Linear = nn.Linear(patches_dim, patches_dim)
        self.v_Linear = nn.Linear(patches_dim, patches_dim)
        self.last_Linear = nn.Linear(patches_dim, patches_dim)
        self.linear_layers = [self.q_Linear, self.k_Linear, self.v_Linear, self.last_Linear]

        self.get_relative_position_bias()

    def get_relative_position_bias(self):
        """Create the learnable relative-position bias table and its index buffers.

        ``B_hat`` holds, per head, one value for every relative offset
        (dh, dw) with |dh| < window_h and |dw| < window_w. ``relative_index_h``
        and ``relative_index_w`` map each (query, key) pair of a window to
        its entry in the table.
        """
        window_h, window_w = self.window_size

        # Registered Parameter (previously a detached local): this is what makes
        # the bias trainable, saved in the state_dict, and moved with the module.
        self.B_hat = nn.Parameter(
            torch.zeros(self.n_heads, 2 * window_h - 1, 2 * window_w - 1, device=self.device))
        init.trunc_normal_(self.B_hat, std=0.02)

        absolute_position_h = torch.arange(window_h, device=self.device)
        absolute_position_w = torch.arange(window_w, device=self.device)
        absolute_coordinate_h, absolute_coordinate_w = torch.meshgrid(
            absolute_position_h, absolute_position_w, indexing='ij')

        # relative_coords[i, j] = coord(token j) - coord(token i); shifted to be >= 0 below.
        relative_coords_h = absolute_coordinate_h.reshape(1, -1) - absolute_coordinate_h.reshape(-1, 1)
        relative_coords_w = absolute_coordinate_w.reshape(1, -1) - absolute_coordinate_w.reshape(-1, 1)
        self.register_buffer("relative_index_h", relative_coords_h + window_h - 1, persistent=False)
        self.register_buffer("relative_index_w", relative_coords_w + window_w - 1, persistent=False)

    @property
    def B(self):
        """Relative position bias of shape (1, 1, n_heads, P, P), P = window_h * window_w.

        It is recomputed from ``B_hat`` on each access, so gradients flow into the table.
        """
        return self.B_hat[:, self.relative_index_h, self.relative_index_w].unsqueeze(0).unsqueeze(0)

    def _shift_mask(self, padding_h, padding_w, shift_size, reference):
        """Build the additive attention mask used with shifted windows.

        Args:
            padding_h, padding_w: spatial size of the padded feature map.
            shift_size: effective (shift_h, shift_w) used for the roll.
            reference: tensor whose dtype/device the mask is created with.

        Returns:
            Tensor of shape (n_windows, P, P). It is 0 where query and key carry
            the same region label and -1e10 otherwise.
        """
        window_h, window_w = self.window_size
        window_group_number = reference.new_zeros(padding_h, padding_w)
        slice_h = ((0, -window_h), (-window_h, -shift_size[0]), (-shift_size[0], None))
        slice_w = ((0, -window_w), (-window_w, -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h_start, h_stop in slice_h:
            for w_start, w_stop in slice_w:
                window_group_number[h_start:h_stop, w_start:w_stop] = count
                count += 1
        window_group_number = rearrange(
            window_group_number,
            '(n_window_h window_h) (n_window_w window_w) -> (n_window_h n_window_w) (window_h window_w)',
            window_h=window_h, window_w=window_w)
        mask = window_group_number.unsqueeze(2) - window_group_number.unsqueeze(1)
        mask[mask != 0] = -1e10
        return mask

    def forward(self, x):
        """Apply (shifted) window attention.

        Args:
            x: assumed (batch, H, W, channels). The code unpacks four dims and
               pads dims 1 and 2 as the spatial axes.

        Returns:
            Tensor with the same shape as ``x``.
        """
        window_h, window_w = self.window_size
        _, H, W, _ = x.shape

        # Pad bottom/right so H and W become multiples of the window size.
        pad_right = (window_w - W % window_w) % window_w
        pad_under = (window_h - H % window_h) % window_h
        padding_x = func.pad(x, (0, 0, 0, pad_right, 0, pad_under))
        _, padding_h, padding_w, _ = padding_x.shape

        # Shifting is pointless when one window already spans the axis.
        # list() works for list or tuple input and never mutates self.shift_size.
        shift_size = list(self.shift_size)
        if padding_h <= window_h:
            shift_size[0] = 0
        if padding_w <= window_w:
            shift_size[1] = 0
        is_shifted = sum(shift_size) > 0

        if is_shifted:
            padding_x = torch.roll(padding_x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
        ready_for_attention = rearrange(
            padding_x,
            'b (n_window_h window_h) (n_window_w window_w) d -> b (n_window_h n_window_w) (window_h window_w) d',
            window_h=window_h, window_w=window_w)

        Q, K, V = [linear(ready_for_attention) for linear in self.linear_layers[:3]]
        Q, K, V = [rearrange(tensor, 'b nw p (h d) -> b nw h p d', h=self.n_heads) for tensor in (Q, K, V)]

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.root_dk + self.B
        if is_shifted:
            mask = self._shift_mask(padding_h, padding_w, shift_size, ready_for_attention)
            attention_scores = attention_scores + mask.unsqueeze(1).unsqueeze(0)

        attention_weights = func.softmax(attention_scores, dim=-1)
        attention_result = torch.matmul(attention_weights, V)
        concat_out = rearrange(attention_result, 'b nw h p d -> b nw p (h d)')
        concat_out = self.last_Linear(concat_out)
        windows_merged = rearrange(
            concat_out,
            'b (n_window_h n_window_w) (window_h window_w) d -> b (n_window_h window_h) (n_window_w window_w) d',
            n_window_h=padding_h // window_h, window_h=window_h)

        if is_shifted:
            # Undo the cyclic shift so every token returns to its original position.
            windows_merged = torch.roll(windows_merged, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

        # Remove the bottom/right padding added above.
        return windows_merged[:, :H, :W, :]

#### Transformer

In [None]:
class SwinTransformerBlock(nn.Module):
    """One Swin Transformer block with pre-norm residual branches.

        x = x + SD(Dropout(SW_MSA(first_norm(x))))
        x = x + SD(Dropout(MLP(second_norm(x))))

    SD is ``StochasticDepth(stochastic_depth_p, "row")``, presumably torchvision's
    implementation (confirm against the file's imports). The same dropout and
    stochastic-depth modules serve both branches.
    """

    def __init__(self, patches_dim, n_heads, window_size, shift_size, hidden_ratio, drop_p, stochastic_depth_p):
        super().__init__()
        hidden_dim = int(patches_dim * hidden_ratio)

        self.first_norm = nn.LayerNorm(patches_dim, eps=1e-5)
        self.sw_msa = ShiftedWindowAttention(patches_dim, window_size, shift_size, n_heads, drop_p=drop_p)
        self.second_norm = nn.LayerNorm(patches_dim, eps=1e-5)
        self.mlp = MLP(patches_dim, hidden_dim, drop_p=drop_p)
        self.dropout = nn.Dropout(drop_p)
        self.stochastic_depth = StochasticDepth(stochastic_depth_p, "row")

        # MLP linears: Xavier-uniform weights, near-zero normal biases.
        mlp_linears = (module for module in self.mlp.modules() if isinstance(module, nn.Linear))
        for layer in mlp_linears:
            init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        attention_branch = self.sw_msa(self.first_norm(x))
        x = x + self.stochastic_depth(self.dropout(attention_branch))

        feed_forward_branch = self.mlp(self.second_norm(x))
        return x + self.stochastic_depth(self.dropout(feed_forward_branch))

#### SwinTransformer

In [None]:
class SwinTransformer(nn.Module):
    """Swin Transformer image classifier.

    Pipeline:
      1. Patch embedding: a strided Conv2d, then ``Permute([0, 2, 3, 1])``
         (presumably to channels-last), then LayerNorm.
      2. One stage of SwinTransformerBlocks per entry of ``n_transformer``.
         Odd-numbered blocks in a stage use half-window shifts, and a
         PatchMerging layer sits between consecutive stages.
      3. Final LayerNorm, global average pooling, and a linear head.

    Args:
        patch_size: (patch_h, patch_w); kernel size and stride of the embedding conv.
        embedding_dim: channels after embedding; stage ``i`` works at
            ``embedding_dim * 2**i`` channels.
        n_transformer: number of blocks in each stage (must be non-empty).
        n_heads: attention heads for each stage.
        window_size: (window_h, window_w) used by every block.
        hidden_ratio: MLP hidden width as a multiple of the stage's channels.
        drop_p: dropout probability passed to every block.
        stochastic_depth_p: drop probability of the last block. Earlier blocks
            get linearly smaller values, starting at 0.
        n_classes: number of output logits.
        in_channels: channels of the input image (default 3, the previous
            hard-coded value).

    Raises:
        ValueError: if ``n_transformer`` is empty.
    """

    def __init__(self, patch_size, embedding_dim, n_transformer, n_heads, window_size, hidden_ratio, drop_p, stochastic_depth_p, n_classes, in_channels=3):
        super().__init__()

        if len(n_transformer) == 0:
            raise ValueError("n_transformer must contain at least one stage")

        self.n_classes = n_classes
        drop_probs = self._stochastic_depth_schedule(stochastic_depth_p, sum(n_transformer))
        block_index = 0

        layers = [nn.Sequential(
            nn.Conv2d(in_channels, embedding_dim,
                      kernel_size=(patch_size[0], patch_size[1]),
                      stride=(patch_size[0], patch_size[1])),
            Permute([0, 2, 3, 1]),
            nn.LayerNorm(embedding_dim, eps=1e-5))]

        n_stages = len(n_transformer)
        for stage_num, n_blocks in enumerate(n_transformer):
            stage_patches_dim = embedding_dim * 2**stage_num
            stage = []

            for transformer_block_num in range(n_blocks):
                # Even blocks use regular windows; odd blocks shift by half a window.
                shift_size = [0 if transformer_block_num % 2 == 0 else w // 2 for w in window_size]
                stage.append(SwinTransformerBlock(stage_patches_dim,
                                                  n_heads[stage_num],
                                                  window_size=window_size,
                                                  shift_size=shift_size,
                                                  hidden_ratio=hidden_ratio,
                                                  drop_p=drop_p,
                                                  stochastic_depth_p=drop_probs[block_index]))
                block_index += 1

            layers.append(nn.Sequential(*stage))

            # Merge patches between stages (not after the last one).
            if stage_num < n_stages - 1:
                layers.append(PatchMerging(stage_patches_dim))

        self.all_layers = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(stage_patches_dim, eps=1e-5)
        self.GAP = nn.Sequential(Permute([0, 3, 1, 2]),
                                 nn.AdaptiveAvgPool2d((1, 1)))
        self.head = nn.Linear(stage_patches_dim, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    init.zeros_(m.bias)

    @staticmethod
    def _stochastic_depth_schedule(stochastic_depth_p, n_all_blocks):
        """Return per-block drop probabilities rising linearly from 0 to stochastic_depth_p.

        The denominator is clamped to 1, so a single-block model gets [0.0]
        instead of raising ZeroDivisionError.
        """
        denominator = max(n_all_blocks - 1, 1)
        return [stochastic_depth_p * index / denominator for index in range(n_all_blocks)]

    def forward(self, x):
        """Classify an image batch.

        Args:
            x: presumably (batch, in_channels, H, W), as consumed by the
               embedding Conv2d.

        Returns:
            Logits of shape (batch, n_classes).
        """
        patches = self.all_layers(x)
        patches = self.norm(patches)
        patches = self.GAP(patches)
        patches = torch.flatten(patches, start_dim=1)
        model_result = self.head(patches)

        return model_result