### Vision Transformer (ViT)
Study notes (学习)

In [1]:
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


In [2]:
class PathEmbed(nn.Module):
    """Split a 2D image into non-overlapping patches and embed each one.

    ("Path" in the class and argument names stands for "patch".)

    A strided ``nn.Conv2d`` (kernel size == stride == patch size) projects
    every patch to an ``embed_dim``-dimensional vector. The resulting grid
    of vectors is then flattened into a token sequence.

    Args:
        image_size: Expected input size. Either an int (square image) or an
            ``(height, width)`` pair. Defaults to 224.
        path_size: Patch size. Either an int (square patch) or an
            ``(height, width)`` pair. Defaults to 16.
        in_c: Number of input channels. Defaults to 3.
        embed_dim: Output embedding dimension per patch. Defaults to 768.
        norm_layer: Optional callable; ``norm_layer(embed_dim)`` must return
            a module that is applied to the tokens. ``nn.Identity`` is used
            when it is omitted.
    """

    def __init__(self, image_size=224, path_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        # Accept either an int or an (h, w) pair. Wrapping unconditionally
        # would turn a tuple argument into a nested tuple and break grid_size.
        image_size = self._to_2tuple(image_size)
        path_size = self._to_2tuple(path_size)
        self.image_size = image_size
        self.patch_size = path_size
        self.grid_size = (image_size[0] // path_size[0], image_size[1] // path_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # Misspelled name kept for backward compatibility with existing callers.
        self.num_patched = self.num_patches

        # kernel_size == stride, so each output position covers exactly one patch.
        # With the default arguments the output is [B, 768, 14, 14].
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=path_size, stride=path_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    @staticmethod
    def _to_2tuple(value):
        """Return ``value`` as an ``(h, w)`` tuple, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Embed a batch of images into a sequence of patch tokens.

        Args:
            x: Tensor of shape ``[B, C, H, W]``. ``H`` and ``W`` must equal
                ``self.image_size``.

        Returns:
            Tensor of shape ``[B, num_patches, embed_dim]``.
        """
        _, _, H, W = x.shape
        assert H == self.image_size[0] and W == self.image_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."

        # proj:      [B, C, H, W] -> [B, D, H//p, W//p]
        # flatten:   [B, D, H//p, W//p] -> [B, D, N]   (N = number of patches)
        # transpose: [B, D, N] -> [B, N, D]
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)
 

In [None]:
class VisionTransformer(nn.Module):
   """Vision Transformer (ViT) image classification model.

   NOTE(review): only the beginning of ``__init__`` is visible here. No
   submodules or ``forward`` are built in the visible code yet.
   """
   def __init__(self, image_size=224, path_size=16, in_c=3, num_classes=1000,
                embed_dim=768, depth=12, num_heads=12,mlp_ratio=4.0,qkv_bias=True,
                qk_scale=None, representation=None, distilled=False,drop_ratio=0.,
                attn_drop_ration=0., drop_path_ratio=0.5,embed_layer=PathEmbed,norm_layer=None,
                act_layer=None
                ):
       """Store the model's top-level hyper-parameters.

       Most arguments are not yet used by the visible part of this method.
       Their descriptions below follow the author's original notes and are
       hedged where the code does not show their use.

       Args:
           image_size (int, optional): Input image size. Defaults to 224.
           path_size (int, optional): Size of each image patch ("path" is
               presumably a typo for "patch"). Defaults to 16.
           in_c (int, optional): Number of input image channels.
               Defaults to 3.
           num_classes (int, optional): Number of output classes; stored as
               ``self.num_classes``. Defaults to 1000.
           embed_dim (int, optional): Embedding dimension; stored as both
               ``self.embed_dim`` and ``self.num_features``. Defaults to 768.
           depth (int, optional): Transformer depth. Presumably the number
               of encoder blocks; TODO confirm once they are built.
               Defaults to 12.
           num_heads (int, optional): Number of attention heads.
               Defaults to 12.
           mlp_ratio (float, optional): Ratio of the MLP hidden dimension to
               the embedding dimension. Defaults to 4.0.
           qkv_bias (bool, optional): Whether the qkv projection has a bias
               term. Defaults to True.
           qk_scale (optional): Not used in the visible code. Presumably
               overrides the attention scaling factor; TODO confirm.
               Defaults to None.
           representation (optional): Not used in the visible code.
               Defaults to None.
           distilled (bool, optional): If True, ``self.num_tokens`` is 2
               instead of 1 (presumably an extra distillation token; TODO
               confirm). Defaults to False.
           drop_ratio (float, optional): Drop ratio; not used in the visible
               code (presumably dropout). Defaults to 0.
           attn_drop_ration (float, optional): Presumably the attention
               dropout ratio. NOTE(review): the parameter name has a typo
               ("ration"). Defaults to 0.
           drop_path_ratio (float, optional): Drop-path ratio; not used in
               the visible code. NOTE(review): 0.5 is an unusually high
               default; confirm intent. Defaults to 0.5.
           embed_layer (optional): Patch-embedding module class.
               Defaults to PathEmbed.
           norm_layer (optional): Normalization layer factory. When None
               (or otherwise falsy) it falls back to
               ``partial(nn.LayerNorm, eps=1e-6)``.
           act_layer (optional): Activation layer; not used in the visible
               code. Defaults to None.
       """
       super(VisionTransformer, self).__init__()
       self.num_classes = num_classes
       # Chained assignment: both num_features and embed_dim are set to embed_dim.
       self.num_features = self.embed_dim=embed_dim

       # 2 extra tokens when distilled, otherwise 1 (presumably class token
       # plus an optional distillation token; TODO confirm).
       self.num_tokens = 2 if distilled else 1
       # Default normalization: LayerNorm with eps=1e-6.
       norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)