In [1]:
import torch
import torch.nn as nn

In [6]:
# Random 4-D tensor with sizes 2 x 3 x 224 x 224 (values drawn from U[0, 1));
# the trailing expression displays its shape.
x = torch.rand(2, 3, 224, 224)
x.size()

torch.Size([2, 3, 224, 224])

In [None]:
# Scratch cell (never executed): a bare tuple literal with no effect on later cells.
# NOTE(review): presumably an image size written as height, width, channels - confirm or delete this cell.
224,224,3


In [12]:
# Patch embedding walk-through: image batch -> conv feature map -> sequence of patch tokens.
x = torch.rand(2, 3, 224, 224)
print('x shape:',x.shape)

# Geometry of the patch grid.
img_size, patch_size, in_c = 224, 16, 3
grid_size = img_size // patch_size  # 14
num_patches = grid_size ** 2        # 196
embed_dim = 768

# kernel_size == stride, so every output position is computed from one
# non-overlapping patch_size x patch_size window of the input.
proj = nn.Conv2d(
    in_channels=in_c,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
x = proj(x)
print('shape',x.shape)

# Collapse the two spatial dims (dim 2 onwards) into a single patch axis.
x_flatten = torch.flatten(x, start_dim=2)
print('flatten shape:',x_flatten.shape)

# Swap dims 1 and 2 so the patch axis comes before the embedding axis.
x_transpose = torch.transpose(x_flatten, 1, 2)
print('transpose shape:',x_transpose.shape)

x shape: torch.Size([2, 3, 224, 224])
shape torch.Size([2, 768, 14, 14])
flatten shape: torch.Size([2, 768, 196])
transpose shape: torch.Size([2, 196, 768])


In [15]:
# Class token: a single zero-initialised parameter of size embed_dim.
cls_token = nn.Parameter(torch.zeros((1, 1, embed_dim)))
print('class token shape:',cls_token.shape)

# Repeat the token along dim 0 to match the batch size of x (-1 keeps a dim as is).
# NOTE(review): this rebinds cls_token from the Parameter to its expanded tensor.
cls_token = cls_token.expand(x.size(0), -1, -1)
print('class token shape:',cls_token.shape)

class token shape: torch.Size([1, 1, 768])
class token shape: torch.Size([2, 1, 768])


In [17]:
# Prepend the class token to the patch tokens along the token axis (dim 1).
x = torch.cat([cls_token, x_transpose], 1)
print('cat class token shape:',x.shape)

cat class token shape: torch.Size([2, 197, 768])


In [18]:
# Position embedding: one zero-initialised vector per token (num_patches patches + 1 class token).
pos_embed = nn.Parameter(torch.zeros((1, num_patches + 1, embed_dim)))
pos_drop = nn.Dropout(0.1)

# pos_embed has size 1 on dim 0, so the addition broadcasts over the batch.
# NOTE(review): x is overwritten here, so re-running this cell alone re-applies the add and dropout.
x = pos_drop(torch.add(x, pos_embed))
print('pos embedding shape',x.shape)

pos embedding shape torch.Size([2, 197, 768])


In [32]:
# Multi-head self-attention hyper-parameters and layers.
dim, num_heads = embed_dim, 8  # dim is 768
head_dim = dim // num_heads    # 96
scale = head_dim ** -0.5       # factor applied to the q @ k^T scores

# One bias-free linear layer produces q, k and v together (3 * dim outputs).
qkv = nn.Linear(in_features=dim, out_features=3 * dim, bias=False)
# NOTE(review): this rebinds `proj`, which earlier named the Conv2d patch projection.
proj = nn.Linear(in_features=dim, out_features=dim)

In [34]:
# Step-by-step view of how the fused qkv projection is split per head.
B, N, C = x.size()
x_qkv = qkv(x)
print('qkv:',x_qkv.shape)

# Split the last axis (3 * C) into (3, num_heads, C // num_heads).
x_qkv = torch.reshape(x_qkv, (B, N, 3, num_heads, C // num_heads))
print('qkv reshape shape:',x_qkv.shape)

# Reorder the axes to (qkv, batch, head, token, per-head embed).
x_qkv = x_qkv.permute((2, 0, 3, 1, 4))
print(x_qkv.shape)

qkv: torch.Size([2, 197, 2304])
qkv reshape shape: torch.Size([2, 197, 3, 8, 96])
torch.Size([3, 2, 8, 197, 96])


In [38]:
# Same projection / reshape / permute as a single chain, then split along the leading axis of size 3.
x_qkv = (
    qkv(x)
    .reshape(B, N, 3, num_heads, C // num_heads)
    .permute(2, 0, 3, 1, 4)
)
# unbind(0) yields the three slices x_qkv[0], x_qkv[1], x_qkv[2].
q, k, v = x_qkv.unbind(0)
print('qkv shape:',q.shape,k.shape,v.shape)

qkv shape: torch.Size([2, 8, 197, 96]) torch.Size([2, 8, 197, 96]) torch.Size([2, 8, 197, 96])


In [41]:
# Attention scores: matrix-multiply q with k transposed on its last two axes, then multiply by scale.
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
print('attn shape:',attn.shape)

# Normalise along the last axis so each row of scores sums to 1.
attn = torch.softmax(attn, dim=-1)
print('attn shape:',attn.shape)

attn shape: torch.Size([2, 8, 197, 197])
attn shape: torch.Size([2, 8, 197, 197])


In [43]:
# Weight the values by the attention matrix.
x = torch.matmul(attn, v)
print(x.shape)

# Swap the head and token axes (dims 1 and 2).
x = x.transpose(1, 2)
print(x.shape)

# Merge the last two axes back into a single axis of size C.
x = torch.reshape(x, (B, N, C))
print(x.shape)

torch.Size([2, 8, 197, 96])
torch.Size([2, 197, 8, 96])
torch.Size([2, 197, 768])


In [44]:
# Apply whatever module `proj` is currently bound to and show the resulting shape.
# NOTE(review): `proj` is rebound between cells (Conv2d first, Linear later), so this depends on run order.
x = proj(x)
print(x.size())

torch.Size([2, 197, 768])


In [5]:
# Stand-alone check of the fused qkv projection on a fresh random tensor.
# NOTE(review): this cell rebinds x, B, N, C, dim and qkv used by the cells above.
x = torch.randn(2, 197, 768)
B, N, C = x.size()
dim = C
# Bias-free linear layer mapping dim -> 3 * dim.
qkv = nn.Linear(in_features=dim, out_features=3 * dim, bias=False)
x_qkv = qkv(x)
x_qkv.size()

torch.Size([2, 197, 2304])

In [10]:
### DropPath demo: zero out whole samples at random and rescale the rest by 1 / keep probability.

x = torch.rand(2, 3, 4)
drop_prob = 0.2
keep_porb = 1 - drop_prob  # 0.8 (variable name spelled as in the original cell)

# One mask entry per sample; the trailing size-1 dims broadcast over the rest of x.
shape = (x.size(0), *([1] * (x.dim() - 1))) # (2,1,1)

# keep_porb + U[0, 1) lies in [0.8, 1.8); flooring gives 1 with probability keep_porb, otherwise 0.
random_tensor = torch.floor(keep_porb + torch.rand(shape))

# Dividing by keep_porb before masking scales the kept samples by 1 / keep_porb.
output = (x / keep_porb) * random_tensor
x, output

(tensor([[[0.6650, 0.1043, 0.4476, 0.2958],
          [0.6112, 0.5947, 0.1497, 0.4887],
          [0.5964, 0.5270, 0.5247, 0.1507]],
 
         [[0.3382, 0.4961, 0.9240, 0.1225],
          [0.8432, 0.3558, 0.6870, 0.9859],
          [0.1601, 0.4467, 0.8009, 0.7571]]]),
 tensor([[[0.8312, 0.1304, 0.5595, 0.3698],
          [0.7640, 0.7434, 0.1871, 0.6108],
          [0.7454, 0.6588, 0.6559, 0.1884]],
 
         [[0.4228, 0.6201, 1.1550, 0.1531],
          [1.0540, 0.4448, 0.8588, 1.2323],
          [0.2001, 0.5583, 1.0011, 0.9463]]]))