In [None]:
import math
from typing import Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class MaskDownSampler(nn.Module):
    """Progressively downsample a single-channel map into an embedding map.

    The encoder is ``num_layers`` repetitions of
    ``Conv2d(kernel_size, stride, padding) -> LayerNorm2d -> activation``,
    where ``stride ** num_layers == total_stride``. Each conv multiplies the
    channel count by ``stride ** 2`` (1 -> stride**2 -> stride**4 -> ...), and a
    final 1x1 conv projects the result to ``embedding_dim`` channels.

    With ``kernel_size == stride`` and ``padding == 0`` (the defaults), each
    layer shrinks H and W by exactly ``stride``, so the overall spatial
    reduction is ``total_stride``.

    Args:
        embedding_dim: Number of output channels (the final embedding dimension).
        kernel_size: Kernel size of each downsampling conv.
        stride: Stride of each downsampling conv; must be >= 2.
        padding: Padding of each downsampling conv.
        total_stride: Overall downsampling factor; must be an integer power of
            ``stride`` (``1`` means no downsampling layers, only the 1x1 projection).
        activation: Activation *class* (not an instance); it is instantiated
            once per downsampling layer.

    Raises:
        ValueError: If ``stride < 2`` or ``total_stride`` is not an integer
            power of ``stride``.
    """

    def __init__(
        self,
        embedding_dim: int = 256,  # final embedding dimension == number of output channels
        kernel_size: int = 4,
        stride: int = 4,
        padding: int = 0,
        total_stride: int = 16,
        activation: Type[nn.Module] = nn.ReLU,
    ) -> None:
        super().__init__()
        num_layers = self._num_layers(total_stride, stride)

        self.encoder = nn.Sequential()
        # The input has a single channel; each layer widens by stride**2.
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans

        # Project the last feature map (or the raw 1-channel input when
        # num_layers == 0) to embedding_dim channels.
        self.encoder.append(nn.Conv2d(mask_out_chans, embedding_dim, kernel_size=1))

    @staticmethod
    def _num_layers(total_stride: int, stride: int) -> int:
        """Return ``n`` such that ``stride ** n == total_stride``, using exact integer math.

        A float ``log2`` ratio divides by zero for ``stride == 1`` and can
        misround; an ``assert``-only check also disappears under ``python -O``.
        Repeated integer division avoids both problems.
        """
        if stride < 2:
            raise ValueError(f"stride must be >= 2, got {stride}")
        num_layers, remaining = 0, total_stride
        while remaining > 1 and remaining % stride == 0:
            remaining //= stride
            num_layers += 1
        if remaining != 1:
            raise ValueError(
                f"total_stride ({total_stride}) must be an integer power of stride ({stride})"
            )
        return num_layers

    def forward(self, x):
        """Run the downsampling encoder.

        Args:
            x: Tensor whose channel dimension is 1 (the first conv has
                ``in_channels=1``); presumably (N, 1, H, W) — TODO confirm with callers.

        Returns:
            Tensor with ``embedding_dim`` channels.
        """
        return self.encoder(x)
        
        

In [None]:
class CXBlock(nn.Module):
    """ConvNeXt-style residual block operating on channels-first tensors.

    Pipeline: spatial conv (depthwise by default) -> LayerNorm2d -> permute to
    channels-last -> Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim) ->
    optional learnable per-channel scale ``gamma`` -> permute back -> residual
    add, with the branch passed through DropPath when ``drop_path > 0``.

    Because the branch output is added to the input, ``kernel_size`` and
    ``padding`` must keep H and W unchanged (the default 7 / 3 does, at stride 1).

    Args:
        dim: Number of input and output channels.
        kernel_size: Kernel size of the spatial conv.
        padding: Padding of the spatial conv.
        drop_path: Drop-path probability for the residual branch; 0 disables it
            (an ``nn.Identity`` is used instead).
        layer_scale_init_value: Initial value of every entry of ``gamma``;
            a value <= 0 disables the layer scale (``gamma`` is None).
        use_dwconv: If True, the spatial conv is depthwise (``groups=dim``);
            otherwise it is a dense conv (``groups=1``).
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 7,
        padding: int = 3,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        use_dwconv: bool = True,
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )
        self.norm = LayerNorm2d(dim, eps=1e-6)
        # Pointwise (1x1) convs implemented as Linear layers on channels-last data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel layer scale, shape (dim,).
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block; the output has the same shape as ``x``, which must have ``dim`` channels in dim 1."""
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        # (N, C, H, W) -> (N, H, W, C) so nn.Linear acts over the channel dimension.
        x = x.permute(0, 2, 3, 1)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            # gamma has shape (C,), broadcasting over the channels-last layout.
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(x)
