<a href="https://colab.research.google.com/github/GuiXu40/deeplearning0/blob/main/Basic_code/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
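# step1. image patch embedding
# Two equivalent ways to turn an image into patch embeddings: split it into patches directly, or apply a convolution.
#  1. Split the image into patches directly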
def image2emb_naive(image, patch_size, weight):
    """Embed an image by cutting it into non-overlapping patches and projecting them.

    Args:
        image: tensor handed to ``F.unfold`` (assumes a batched
            (batch, channels, height, width) layout -- TODO confirm with callers).
        patch_size: used as both the unfold kernel size and the stride, so
            consecutive patches do not overlap.
        weight: projection matrix that the flattened patches are
            right-multiplied by.

    Returns:
        The flattened patches multiplied by ``weight``.

    Side effects:
        Prints the shape of the flattened-patch tensor.
    """
    # unfold puts the flattened patch values on the second-to-last axis;
    # swap the last two axes so each row is one flattened patch.
    columns = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    patch = columns.transpose(-1, -2)
    print(patch.shape)
    return patch @ weight

#  2. Produce the patch embeddings directly with a convolution
def image2emb_conv(image, kernel, stride):
    """Embed an image by running a strided 2-D convolution over it.

    Args:
        image: input tensor passed to ``F.conv2d``.
        kernel: convolution weight passed to ``F.conv2d``.
        stride: convolution stride.

    Returns:
        The convolution output with its two spatial axes merged into one and
        then swapped with the channel axis, i.e.
        (batch, out_h * out_w, out_channels).
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    # The 4-way unpack requires a 4-D (batched) convolution result.
    batch, channels, height, width = feature_map.shape
    flat = feature_map.reshape((batch, channels, height * width))
    # Put the patch index ahead of the embedding channels.
    return flat.permute(0, 2, 1)

# Toy configuration: one 3-channel 8x8 image cut into 4x4 patches,
# each projected to an 8-dimensional embedding.
bs = 1
ic = 3
image_h = image_w = 8
patch_size = 4
model_dim = 8
# Number of values in one flattened patch.
patch_depth = ic * patch_size * patch_size
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)

patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(patch_embedding_naive)

# Reuse the same projection as a conv kernel: each column of `weight`
# becomes one (ic, patch_size, patch_size) filter.
kernel = weight.t().reshape(model_dim, ic, patch_size, patch_size)
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)

print(patch_embedding_conv)

torch.Size([1, 4, 48])
tensor([[[  2.9560,   3.5471,   1.9480,  -6.8817,  -4.3895,  -1.1609,   8.9942,
           -1.1096],
         [  1.8633,  -0.6170,   6.6276,   2.8051,   2.7494,   8.0640,   7.7082,
          -11.2505],
         [  3.1121, -10.1775,   7.9735,  -6.0336,  -0.3792,  -5.2678, -12.5721,
            1.4477],
         [ -3.5535,  10.2715,   4.5857,  -2.0812,  -4.3538,   1.0491,   0.4134,
           -2.9734]]])
tensor([[[  2.9560,   3.5471,   1.9480,  -6.8817,  -4.3895,  -1.1609,   8.9942,
           -1.1096],
         [  1.8633,  -0.6170,   6.6276,   2.8051,   2.7494,   8.0640,   7.7082,
          -11.2505],
         [  3.1121, -10.1775,   7.9736,  -6.0336,  -0.3792,  -5.2678, -12.5721,
            1.4477],
         [ -3.5535,  10.2715,   4.5857,  -2.0812,  -4.3538,   1.0491,   0.4134,
           -2.9734]]])


In [11]:
# step2. prepend CLS token embedding
# One random CLS vector per batch element, marked as requiring gradients.
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
# Concatenate along the token axis so the CLS token sits at index 0.
token_embedding = torch.cat((cls_token_embedding, patch_embedding_conv), dim=1)
print(
    token_embedding.shape,
    cls_token_embedding.shape,
    patch_embedding_conv.shape,
    sep="\n",
)


torch.Size([1, 5, 8])
torch.Size([1, 1, 8])
torch.Size([1, 4, 8])


In [15]:
max_num_token = 16
# step3. add position embedding
# Table with one row per position, marked as requiring gradients.
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.size(1)
print(position_embedding_table[:seq_len])
# Replicate the first seq_len rows once per batch element.
position_embedding = position_embedding_table[:seq_len].repeat(token_embedding.size(0), 1, 1)
# In-place add: token_embedding itself is updated.
token_embedding.add_(position_embedding)
print(position_embedding.shape)

tensor([[-1.6283, -0.8537, -0.3089, -0.1871, -2.3427, -2.3121,  0.6822, -0.9336],
        [-1.2068,  0.2422,  1.2009,  0.3887,  0.5662, -0.0687,  0.7246,  1.3458],
        [ 0.0571, -1.1050,  0.5187, -0.5592, -1.0866, -0.1405,  1.2848, -1.2199],
        [ 2.2285, -0.9752,  0.3891, -0.1614,  0.2204,  0.8927,  1.5179, -0.8595],
        [-0.7251,  0.0619, -1.2961, -1.8820,  0.0202,  0.0031, -0.1315,  0.5817]],
       grad_fn=<SliceBackward0>)
torch.Size([1, 5, 8])


In [16]:
# step4. pass embedding to Transformer Encoder
# token_embedding is laid out as (batch, seq_len, model_dim), while
# nn.TransformerEncoderLayer defaults to (seq_len, batch, model_dim).
# Without batch_first=True every token is treated as its own length-1
# sequence, so the CLS token would never attend to the patch tokens.
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Output keeps the (batch, seq_len, model_dim) layout of the input.
encoder_output = transformer_encoder(token_embedding)

In [20]:
# step5. do classification
num_classes = 10
# Draw random labels from [0, num_classes) so they always stay valid
# indices for the classifier logits, even if num_classes is changed.
label = torch.randint(num_classes, (bs,))

# The CLS token sits at index 0 of the token axis; use its encoder output
# as the input to the classifier.
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
# CrossEntropyLoss takes raw logits and integer class labels.
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)

tensor(2.3791, grad_fn=<NllLossBackward0>)
