In [None]:
import torch
import torch.nn as nn

# Global switch for the debug printer below; set to False to silence it.
PRINT_IF_FLAG = True


def PRINT_IF(x: torch.Tensor, name=""):
    """Print the shape, mean and value range of ``x`` when debugging is on.

    Nothing is printed unless the module-level ``PRINT_IF_FLAG`` is truthy.

    Args:
        x: Tensor to summarise.
        name: Optional label inserted into the printed line.

    Returns:
        None. The summary line is written to stdout.
    """
    if not PRINT_IF_FLAG:
        return
    if x.numel() == 0:
        # min()/max() raise on empty tensors; report only the shape so a debug
        # print can never abort a forward pass.
        print("Tensor {} shape {}. (empty)".format(name, x.shape))
        return
    stats = x.detach()
    # mean() is only defined for floating-point/complex dtypes, so integer and
    # bool tensors are promoted to float before the statistics are computed.
    if not (stats.is_floating_point() or stats.is_complex()):
        stats = stats.float()
    print(
        "Tensor {} shape {}. mean:{:.2f} range:{:.2f}~{:.2f}".format(
            name, x.shape, stats.mean().item(), stats.min().item(), stats.max().item()
        )
    )


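# Reshape an image x (H, W, C) into a sequence of N flattened patches (N, P*P*C),
# where P is the patch size and N is the number of patches.
# The cls_token goes at the start of that sequence (it is not added by PatchEmbed itself).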
class PatchEmbed(nn.Module):
    """Convert an image batch into a sequence of patch embeddings.

    A strided ``nn.Conv2d`` is applied to the input, then the spatial axes of
    its output are merged into a single token axis and moved in front of the
    channel axis.

    Args:
        in_channels: Number of channels of the input image.
        embed_dims: Output channels of the convolution (size of each token).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
    """

    def __init__(self, in_channels=3, embed_dims=768, kernel_size=16, stride=16):
        super().__init__()
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor):
        """Embed ``x`` and return the tokens with the channel axis last.

        Each intermediate tensor is passed to ``PRINT_IF`` for debugging.
        """
        # Input is expected as (B, C, H, W); the conv yields (B, embed_dims, H', W').
        feature_map = self.projection(x)
        PRINT_IF(feature_map)
        # Merge everything from axis 2 onwards: (B, embed_dims, H' * W').
        channel_major = torch.flatten(feature_map, start_dim=2)
        PRINT_IF(channel_major)
        # Swap axes 1 and 2 so tokens come first: (B, H' * W', embed_dims).
        tokens = torch.transpose(channel_major, 1, 2)
        PRINT_IF(tokens)
        return tokens


class FFN(nn.Module):
    """Feed-forward network (FFN) with an optional identity connection.

    The network is ``num_fcs - 1`` blocks of ``Linear -> GELU -> Dropout``
    that work in ``feedforward_channels``, followed by a final
    ``Linear -> Dropout`` that projects back to ``embed_dims`` so the output
    can be added to the identity branch.

    Args:
        embed_dims: Size of the last axis of the input and of the output.
        feedforward_channels: Width of the hidden layers.
        num_fcs: Total number of ``nn.Linear`` layers; must be at least 2.
        ffn_drop: Dropout probability applied after every linear layer.
        add_identity: If True, ``forward`` returns ``identity + out``;
            otherwise it returns ``out`` alone.
    """

    def __init__(
        self,
        embed_dims,
        feedforward_channels,
        num_fcs,
        ffn_drop=0.0,
        add_identity=True,
    ):
        super().__init__()
        assert num_fcs >= 2, "num_fcs must be at least 2, got {}".format(num_fcs)
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.add_identity = add_identity

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.extend(
                [
                    nn.Linear(in_channels, feedforward_channels),
                    nn.GELU(),
                    nn.Dropout(ffn_drop),
                ]
            )
            in_channels = feedforward_channels
        # The last layer must map back to embed_dims; projecting to
        # feedforward_channels would make `identity + out` fail whenever the
        # hidden width differs from the embedding size.
        layers.extend(
            [
                nn.Linear(in_channels, embed_dims),
                nn.Dropout(ffn_drop),
            ]
        )
        self.layers = nn.Sequential(*layers)

    def forward(self, x, identity=None):
        """Apply the FFN to ``x``.

        Args:
            x: Input whose last axis has size ``embed_dims``
                (e.g. a tensor of shape (2, 196, 768)).
            identity: Tensor added to the FFN output when ``add_identity`` is
                True. Defaults to ``x`` itself.

        Returns:
            ``identity + layers(x)`` if ``add_identity`` is True, else
            ``layers(x)``.
        """
        out = self.layers(x)
        if not self.add_identity:
            return out
        if identity is None:
            identity = x
        return identity + out


class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention followed by an FFN.

    Both sub-layers are wrapped in a residual connection:
    ``x = x + attn(norm1(x))`` and then ``x = x + ffn(norm2(x))``.

    Args:
        embed_dims: Size of the last axis of the input tokens.
        num_heads: Number of attention heads.
        feedforward_channels: Hidden width passed to the FFN.
        drop_rate: Dropout probability used inside the attention module and
            the FFN.
        qkv_bias: Forwarded as ``bias`` to ``nn.MultiheadAttention``, which
            applies it to its input and output projection layers.
    """

    def __init__(
        self, embed_dims, num_heads, feedforward_channels, drop_rate=0.0, qkv_bias=True
    ):
        super().__init__()
        # batch_first=True: tokens arrive as (batch, seq, embed_dims). With the
        # default (False) the batch axis would be treated as the sequence axis
        # and attention would mix samples instead of patches.
        self.attn = nn.MultiheadAttention(
            embed_dims,
            num_heads,
            dropout=drop_rate,
            bias=qkv_bias,
            batch_first=True,
        )
        self.ffn = FFN(embed_dims, feedforward_channels, 2, drop_rate, add_identity=True)
        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)

    def forward(self, x):
        """Run one encoder layer on ``x`` of shape (batch, seq, embed_dims).

        Returns a tensor of the same shape. The result of each sub-layer is
        passed to ``PRINT_IF`` for debugging.
        """
        normed = self.norm1(x)
        # The attention weights are unused, so skip computing them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        # Residual connection around the attention sub-layer.
        x = x + attn_out
        PRINT_IF(x)
        # Pass the un-normalised x as identity so the FFN residual adds the
        # sub-layer input rather than its normalised copy.
        x = self.ffn(self.norm2(x), identity=x)
        PRINT_IF(x)
        return x


class Vit(nn.Module):
    """Placeholder for the full Vision Transformer; not implemented yet.

    The module registers no sub-modules or parameters, and ``forward``
    ignores its input and returns ``None``.
    """

    def __init__(self):
        super(Vit, self).__init__()

    def forward(self, x):
        # TODO: implement the model; for now the input is ignored.
        return None


In [None]:
# Smoke test: push a dummy batch of two 3x224x224 images through the patch
# embedding and a single encoder layer.
x = torch.ones(2, 3, 224, 224)
patch_embed = PatchEmbed(in_channels=3, embed_dims=768)
encoder = TransformerEncoderLayer(
    embed_dims=768,
    num_heads=8,
    feedforward_channels=768,
)
a = patch_embed(x)
b = encoder(a)

In [None]:
# A single all-zero class token of shape (1, 1, 768) ...
cls_token = torch.zeros(1, 1, 768)
# ... expanded along the first axis to a batch of 2; display the resulting shape.
cls_token.expand(2, -1, -1).size()

In [4]:
import torch.nn as nn
# Scratch 4 -> 4 linear layer for experimentation.
layer = nn.Linear(in_features=4, out_features=4)
