In [2]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


In [3]:
def pair(t):
    return t if isinstance(t, tuple) else (t, t)



In [4]:

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)



In [5]:
# transformer 里面的MLP 
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        # PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
        # dim = 1024,   mlp_dim = 2048,dropout = 0.1,
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), 
            nn.Dropout(dropout)
        ) # 1024-> 2048 -> GELU-> Dropout->1024->Dropout--->>1024
    def forward(self, x):
        return self.net(x)


In [6]:
   
class Attention(nn.Module):              
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        # dim = 1024,   depth = 6,heads = 16, dim_head=64,  mlp_dim = 2048,dropout = 0.1,
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)           # (b, n(65), dim*3) ---> 3 * (b, n, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)          # q, k, v   (b, h, n, dim_head(64))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


In [7]:

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        # dim = 1024,   depth = 6,heads = 16, dim_head=64,  mlp_dim = 2048,dropout = 0.1,
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth): # depth = 6
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), # 先norm，再算Attention
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x # 加上残差
            x = ff(x) + x   # 加上残差
        return x


In [41]:
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, \
                 mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        '''
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
        '''
        super().__init__()
        image_height, image_width = pair(image_size) # 图片是256*256
        patch_height, patch_width = pair(patch_size) #patch is 32*32

        assert  image_height % patch_height ==0 and image_width % patch_width == 0 # 保证可以整除

        num_patches = (image_height // patch_height) * (image_width // patch_width) # patch个数=64
        print(num_patches)
        patch_dim = channels * patch_height * patch_width # patch_dim = 3*32*32 =  3,072
        assert pool in {'cls', 'mean'}

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim) # 改变patch 维度 到dim=1024
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))					# nn.Parameter()定义可学习参数
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # dim = 1024,   depth = 6,heads = 16, dim_head=64,  mlp_dim = 2048,dropout = 0.1,
        self.pool = pool # pool='cls'
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes) # 1024--->num_classes
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)        # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        b, n, _ = x.shape           # b表示batchSize, n表示每个块的空间分辨率就是patch的个数, _表示一个块内有多少个值
        print('n is:',n)
        print('_ is:',_)
        # repeat() 复制并增添维度
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) 
        # self.cls_token是要操作的tensor名，这一行的目的是把self.cls_token复制为batchsize的个数
        # b是batchsize = 16，后面的n，d不变， cls_token是（1,1，dim）
        # self.cls_token: (1, 1, dim) -> cls_tokens: (batchSize, 1, dim)  
        x = torch.cat((cls_tokens, x), dim=1)               # 将cls_token拼接到patch token中去       (b, 65, dim)
        print('pos_embedding[:, :(n+1)] is:\n',self.pos_embedding[:, :(n+1)] )
        print('pos_embedding[:, :(n+1)] shape is:\n',self.pos_embedding[:, :(n+1)].shape )
        x = x + self.pos_embedding[:, :(n+1)]                  # 加位置嵌入（直接加）      (b, 65, dim)
        x = self.dropout(x) # x中的一部分值会被随机变为零，这个只能在训练的时候用

        x = self.transformer(x)                                                 # (b, 65, dim)
        print("x after transformer:\n",x)
        print("x.shape after transformer:\n",x.shape)
        # self.pool == 'cls'
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]                   # (b, dim) # 只取65个patch中的序号为0的一个也就是cls

        x = self.to_latent(x)                                                   # Identity (b, dim) # 不变，只是占位
        print('取65个patch中的序号为0的一个也就是cls:',x.shape)

        return self.mlp_head(x)                                                 #  (b, num_classes)



In [42]:
model_vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

img = torch.randn(16, 3, 256, 256) # （图片数，3通道，256*256）

preds = model_vit(img) 

print(preds.shape)  # (16, 1000) # 1000是输出类别个数1000个



64
n is: 64
_ is: 1024
pos_embedding[:, :(n+1)] is:
 tensor([[[-0.3722,  0.3396,  1.4037,  ..., -0.4461, -0.7364,  0.2433],
         [ 1.4336,  1.9588, -0.2457,  ...,  0.7117, -0.2338,  0.3758],
         [ 0.5752,  0.3203,  0.2581,  ...,  0.3693, -0.2199,  1.1205],
         ...,
         [-1.3470,  0.5339,  1.1822,  ...,  1.0060, -1.1918, -0.4994],
         [-0.8746, -0.4046, -0.1143,  ..., -0.7404, -1.5288, -1.5283],
         [ 1.2498, -0.3553,  0.9992,  ...,  0.0915,  0.0025, -1.6904]]],
       grad_fn=<SliceBackward0>)
pos_embedding[:, :(n+1)] shape is:
 torch.Size([1, 65, 1024])
x after transformer:
 tensor([[[-1.4542,  1.1076,  0.7253,  ..., -0.4778, -0.4722, -0.0113],
         [ 1.3191,  2.0528, -0.6181,  ..., -0.1488, -1.1117, -0.0330],
         [-0.3973,  0.1895,  0.4334,  ...,  0.2784, -0.6314,  2.3400],
         ...,
         [-0.8121, -0.3650,  1.3722,  ..., -0.1010, -0.9574,  0.6822],
         [-1.2663, -0.9729, -0.5192,  ..., -1.2287, -1.3402, -2.4660],
         [ 2.0645, 