### Vision Transformer
Study notes

In [1]:
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


In [2]:
class PathEmbed(nn.Module):
    """Turn a 2D image into a sequence of patch embeddings.

    A strided convolution whose kernel size equals its stride cuts the image
    into non-overlapping square patches and projects each patch to
    ``embed_dim`` channels in one step.

    Args:
        image_size: Side length of the (square) input image.
        path_size: Side length of each (square) patch.
        in_c: Number of input channels.
        embed_dim: Number of channels of each patch embedding.
        norm_layer: Optional callable building a normalization layer from
            ``embed_dim``; ``nn.Identity`` is used when it is falsy.
    """

    def __init__(self, image_size=224, path_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        side, patch = image_size, path_size
        self.image_size = (side, side)
        self.patch_size = (patch, patch)
        # Number of whole patches that fit along each spatial axis.
        self.grid_size = tuple(
            extent // step for extent, step in zip(self.image_size, self.patch_size)
        )
        rows, cols = self.grid_size
        self.num_patches = rows * cols

        # With the default arguments the projection output is 768 x 14 x 14.
        self.proj = nn.Conv2d(
            in_c, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed ``x`` of shape [B, C, H, W] into [B, num_patches, embed_dim]."""
        _, _, height, width = x.shape
        expected_h, expected_w = self.image_size
        assert height == expected_h and width == expected_w, \
            f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."

        patches = self.proj(x)
        # [B, C, H, W] -> [B, C, HW]
        tokens = patches.flatten(2)
        # [B, C, HW] -> [B, HW, C]
        tokens = tokens.transpose(1, 2)
        return self.norm(tokens)
 

In [3]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: Channel size of every input token.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        qk_scale: Explicit scale applied to the q.k scores. When falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop_ratio: Dropout probability applied to the attention weights.
        proj_drop_ratio: Dropout probability applied after the output
            projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self,
                 dim,  # channel size of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.
                 ):
        super(Attention, self).__init__()
        # Fail early with a clear message instead of a ZeroDivisionError here
        # or an opaque reshape error later in forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # One linear layer produces q, k and v at once (3x the channel size).
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.atten_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, N, C]; output has the same shape."""
        B, N, C = x.shape
        # qkv(x):  [B, N, 3 * C]
        # reshape: [B, N, 3, num_heads, C // num_heads]
        # permute: [3, B, num_heads, N, C // num_heads]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: [B, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.atten_drop(attn)

        # Weighted sum of values, then merge the heads back: [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

        


In [4]:
def drop_path(x, drop_prob:float=0., training:bool=False):
    """Randomly drop whole samples of ``x`` (stochastic depth).

    Reference: https://blog.csdn.net/qq_43135204/article/details/127912029

    :param x: input tensor; its first dimension is treated as the sample axis
    :param drop_prob: probability of zeroing a sample; ``0`` or ``None``
        disables dropping
    :param training: dropping is only applied when this is true
    :return: ``x`` itself when dropping is disabled, otherwise a tensor of the
        same shape in which each sample is either zeroed or scaled by
        ``1 / (1 - drop_prob)``
    """
    # `not drop_prob` also covers None, which would otherwise fail in `1 - drop_prob`.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Every sample is dropped; dividing by keep_prob would yield NaN/inf.
        return torch.zeros_like(x)

    # One random value per sample, broadcastable over all remaining dimensions.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)

    # torch.rand samples uniformly from [0, 1); after adding keep_prob the
    # floor is 1 with probability keep_prob and 0 otherwise.
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()

    # Dividing by keep_prob keeps the expected value of the output equal to x;
    # the 0/1 mask is broadcast across each sample.
    output = x.div(keep_prob) * random_tensor

    return output

In [5]:
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability of dropping a sample's path. ``None`` or ``0``
            makes the module pass its input through unchanged.
    """

    def __init__(self,drop_prob = None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Short-circuit when dropping is disabled: the default drop_prob=None
        # must not reach the arithmetic inside drop_path while training.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x,self.drop_prob, self.training)
    

In [6]:
class MLP(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features`` when falsy.
        act_layer: Callable that builds the activation module.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,act_layer=nn.GELU,drop=0):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # The same dropout module is reused after fc1 and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


In [7]:
class Block(nn.Module):
    """Transformer encoder block: pre-norm attention and pre-norm MLP, each on a residual path.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        qk_scale: Optional explicit attention scale, forwarded to ``Attention``.
        drop_ratio: Dropout probability for the attention output projection
            and for the MLP.
        attn_drop_ratio: Dropout probability for the attention weights.
        drop_path_ratio: Stochastic-depth probability for both residual
            branches; ``nn.Identity`` is used when it is not positive.
        act_layer: Callable building the MLP activation.
        norm_layer: Callable building a normalization layer from ``dim``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer = nn.LayerNorm
                 ):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        # The projection dropout uses drop_ratio; drop_path_ratio is the
        # stochastic-depth rate and only belongs to self.drop_path below.
        self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio
                              )
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio >0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp=MLP(in_features=dim,
                     hidden_features=mlp_hidden_dim,
                     act_layer=act_layer,
                     drop=drop_ratio)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (same shape as the input)."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [1]:
class VisionTransformer(nn.Module):
    """Vision Transformer backbone: patch embedding, learnable tokens and a stack of blocks.

    NOTE(review): only the constructor exists so far. No classification head
    and no ``forward`` are defined in this class, so ``num_classes`` and
    ``representation`` are accepted but not yet used to build layers.
    """

    def __init__(self, image_size=224, path_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12,mlp_ratio=4.0,qkv_bias=True,
                 qk_scale=None, representation=None, distilled=False,drop_ratio=0.,
                 attn_drop_ration=0., drop_path_ratio=0.5,embed_layer=PathEmbed,norm_layer=None,
                 act_layer=None
                 ):
        """Build the patch embedding, the extra tokens and ``depth`` encoder blocks.

        Args:
            image_size (int, optional): Side length of the input image.
            path_size (int, optional): Side length of each image patch.
            in_c (int, optional): Number of input image channels.
            num_classes (int, optional): Number of target classes; stored on
                the instance only.
            embed_dim (int, optional): Embedding dimension of every token.
            depth (int, optional): Number of encoder blocks.
            num_heads (int, optional): Number of attention heads per block.
            mlp_ratio (float, optional): Ratio of the MLP hidden width to
                ``embed_dim``.
            qkv_bias (bool, optional): Whether the q/k/v projections use a bias.
            qk_scale (optional): Explicit attention scale forwarded to every
                block. Defaults to None.
            representation (optional): Currently unused. Defaults to None.
            distilled (bool, optional): When true, a second learnable token
                (``dist_token``) is created next to ``cls_token``.
                Defaults to False.
            drop_ratio (float, optional): Dropout probability for the position
                dropout and inside every block.
            attn_drop_ration (float, optional): Dropout probability for the
                attention weights. Defaults to 0.
            drop_path_ratio (float, optional): Largest stochastic-depth rate;
                the per-block rates grow linearly from 0 to this value.
                Defaults to 0.5.
            embed_layer (optional): Callable building the patch embedding.
                Defaults to PathEmbed.
            norm_layer (optional): Callable building a normalization layer.
                Defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            act_layer (optional): Callable building the activation.
                Defaults to ``nn.GELU``.
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim=embed_dim

        # One class token, plus a distillation token when requested.
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(image_size=image_size, path_size=path_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None
        # One position embedding per patch and per extra token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # Stochastic depth decay rule: the drop-path rate rises linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        # Build `depth` blocks, each with its own drop-path rate and the
        # configured attention/dropout/norm/activation settings.
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop_ratio=drop_ratio,
                  attn_drop_ratio=attn_drop_ration, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])


       
       

NameError: name 'nn' is not defined

In [None]:
*