In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

######################### Step1: Convert image to embedding vector sequence #########################
def image2emb_naive(image, patch_size, weight):
    """Embed an image as a patch sequence by unfolding it and projecting each patch.

    Args:
        image: tensor laid out as (batch_size, channel, h, w).
        patch_size: side length of the square patches; it is used as both the
            kernel size and the stride, so the patches do not overlap.
        weight: projection matrix that is right-multiplied onto the flattened
            patches, so its first dimension must equal
            channel * patch_size * patch_size.

    Returns:
        Tensor of shape (batch_size, num_patches, weight.shape[-1]).
    """
    # F.unfold returns (batch_size, patch_depth, num_patches); swap the last two
    # axes so that every row holds one flattened patch.
    columns = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    patches = columns.transpose(-2, -1)
    return patches @ weight

def image2emb_conv(image, kernel, stride):
    """Embed an image as a patch sequence using a 2-D convolution.

    The image-to-embedding step can be seen as a 2-D convolution whose output
    is (batch_size, output_channel, output_height, output_width). To mimic an
    NLP token sequence, the two spatial axes are merged into one sequence axis
    and moved in front of the channel axis.

    Args:
        image: input tensor passed to F.conv2d.
        kernel: convolution weight passed to F.conv2d.
        stride: convolution stride.

    Returns:
        Tensor of shape (batch_size, output_height * output_width, output_channel).
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    n_batch, n_embed, n_rows, n_cols = feature_map.shape
    # Flatten the spatial grid into a sequence, then put the sequence axis first.
    tokens = feature_map.reshape((n_batch, n_embed, n_rows * n_cols))
    return tokens.transpose(-2, -1)

# Test code for image2emb: build a random image and check that the unfold-based
# and the convolution-based patch embeddings produce the same result.
bs, ic, image_h, image_w = 1, 3, 8, 8 # bs: batch_size; ic: input_channel
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)

# model_dim is the number of output channels; patch_depth is the kernel area
# multiplied by the number of input channels (4 * 4 * 3 = 48 here).
weight = torch.randn(patch_depth, model_dim) # shape: (48, 8)

# Patch embedding from the unfold ("naive") method.
# The unfolded patches have shape (1, 4, 48): 1 is the batch size; 4 is the
# number of patches (the sequence length: an 8x8 image cut into 4x4 blocks gives
# 4 blocks); 48 is the number of values per patch (patch_size * patch_size * ic).
# After projection the result is (1, 4, 8): each of the 4 patches is represented
# by a vector of length model_dim = 8.
patch_embedding_naive = image2emb_naive(image, patch_size, weight)

# Reinterpret the projection matrix as a convolution kernel of shape
# (output_channel, input_channel, kernel_height, kernel_width).
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
# Patch embedding from the 2-D convolution method.
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)

# Both methods compute the same projection, so they must agree up to
# floating-point rounding.
assert patch_embedding_naive.shape == patch_embedding_conv.shape == (bs, 4, model_dim)
assert torch.allclose(patch_embedding_naive, patch_embedding_conv, rtol=1e-4, atol=1e-4)

######################### Step2: CLS token embedding (like BERT) #########################
# One learnable CLS vector per batch element, prepended along the sequence axis.
# The batch size variable in this cell is `bs`; `batch_size` is not defined here.
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)






patch_embedding_naive:  
 tensor([[[-2.8071,  5.8215, -5.6974,  8.9031, 14.3631,  6.9120,  0.8625,
           8.4833],
         [-8.9508, -0.5123,  0.2092, -2.2515,  2.1214, -7.5573,  7.4714,
           4.2689],
         [-0.4929, -2.6356, -0.8143, -5.2192,  6.1698, -2.7434, -4.6646,
           9.7776],
         [-6.6631, -3.6958, -9.9150, -4.3306,  5.2153,  9.9827, -0.1758,
           9.5918]]])
The shape of patch_embedding_naive is:  torch.Size([1, 4, 8]) 

patch_embedding_conv:  
 tensor([[[-2.8071,  5.8215, -5.6974,  8.9031, 14.3631,  6.9120,  0.8625,
           8.4833],
         [-8.9508, -0.5123,  0.2092, -2.2515,  2.1214, -7.5573,  7.4714,
           4.2689],
         [-0.4929, -2.6356, -0.8143, -5.2192,  6.1698, -2.7433, -4.6646,
           9.7776],
         [-6.6631, -3.6958, -9.9150, -4.3306,  5.2153,  9.9827, -0.1758,
           9.5918]]])
The shape of patch_embedding_conv is:  torch.Size([1, 4, 8])


In [8]:
# Test of torch.nn.functional.unfold(): show how a small tensor with known
# values is rearranged into sliding-window columns.

import torch
import torch.nn.functional as F

# 1-D float tensor holding 0, 1, ..., 674 (1 * 3 * 15 * 15 = 675 values).
x = torch.arange(0, 1*3*15*15).float()
print("Before view: ", '\n', x)
print("The shape of x is: ", x.shape, '\n')

# Reinterpret the same values as one 3-channel 15x15 image: (batch, channel, h, w).
x = x.view(1, 3, 15, 15)
print("After view: ", '\n',x)
print("The shape of x is: ", x.shape, '\n')

# Extract every 3x3 window with stride 1 and no padding. Each column of the
# result is one flattened window across all channels, so the expected shape is
# (1, 3 * 3 * 3, 13 * 13) = (1, 27, 169).
x1 = F.unfold(x, kernel_size=3, dilation=1, stride=1)
print("x1 = ", x1)
print("The shape of x1 is: ", x1.shape)

Before view:  
 tensor([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
         12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,
         24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,
         36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
         48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
         60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,  71.,
         72.,  73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,
         84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,
         96.,  97.,  98.,  99., 100., 101., 102., 103., 104., 105., 106., 107.,
        108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.,
        120., 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,
        132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.,
        144., 145., 146.