In [85]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def img2embed(image, patch_size, stride, weight):
    """Turn an image batch into patch tokens via unfold + a linear projection.

    Args:
        image: batched image tensor, expected (B, C, H, W).
        patch_size: side length of each square patch (unfold kernel size).
        stride: step between patches; equal to patch_size for non-overlapping patches.
        weight: projection matrix of shape (C * patch_size * patch_size, D).

    Returns:
        Tensor of shape (B, num_patches, D): one embedded token per patch.
    """
    # unfold yields (B, C*k*k, L); swap the last two axes so every row is one flattened patch
    flat_patches = F.unfold(image, patch_size, stride=stride).transpose(-2, -1)
    return torch.matmul(flat_patches, weight)
def img2embed_conv(image, kernel, stride):
    """Turn an image batch into patch tokens with a strided convolution.

    Args:
        image: batched image tensor, expected (B, C, H, W).
        kernel: conv weight of shape (D, C, patch, patch); each output channel is one embedding dim.
        stride: convolution stride (patch size for non-overlapping patches).

    Returns:
        Tensor of shape (B, H' * W', D): one token per output spatial location.
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    batch, dim, rows, cols = feature_map.shape
    # collapse the spatial grid into a token axis, then put tokens before features
    return feature_map.reshape(batch, dim, rows * cols).permute(0, 2, 1)

# step1: patch embedding, computed two equivalent ways (unfold + matmul vs. strided conv)
torch.manual_seed(0)  # fix the RNG so the printed results are reproducible on Restart & Run All

bc, i_ch, i_h, i_w = 1, 3, 8, 8   # batch size, channels, height, width
patch_size = 4
stride = 4                        # stride == patch_size -> non-overlapping patches
output_dim = 8                    # embedding dimension of every token
max_num_token = 16                # number of rows in the positional-embedding table (step3)
image = torch.randn(bc, i_ch, i_h, i_w)
weight = torch.randn(patch_size * patch_size * i_ch, output_dim)
num_classes = 10
label = torch.randint(0, num_classes, (bc, ))

embedding1 = img2embed(image, patch_size, stride, weight)
print("embedding1", embedding1.shape)

# F.unfold flattens each patch in (C, kh, kw) order, matching the conv2d kernel
# layout (out_ch, in_ch, kh, kw), so reshaping weight^T gives the same projection.
kernel = weight.transpose(-1, -2).reshape(output_dim, i_ch, patch_size, patch_size)
embedding2 = img2embed_conv(image, kernel, stride)
print("embedding2", embedding2.shape)
# The two paths must agree numerically, not just in shape.
assert torch.allclose(embedding1, embedding2, atol=1e-4), "unfold and conv patch embeddings disagree"


# step2: CLS token embedding
# A single learnable CLS token shared by every sample and broadcast over the batch,
# so its shape does not depend on the batch size; gradients from all samples
# accumulate into this one token through expand().
cls_token_embedding = torch.randn(1, 1, output_dim, requires_grad=True)
cls_tokens = cls_token_embedding.expand(embedding2.shape[0], -1, -1)
# CLS goes first, so it occupies token position 0 ahead of the patch tokens
token_embedding = torch.cat([cls_tokens, embedding2], dim=1)
print("token_embedding", token_embedding.shape)

# step3: Positional embedding
# Learnable table with one row per token position; the first seq_len rows are
# added to the tokens (CLS at position 0, then the patches).
positional_embedding_table = torch.randn(max_num_token, output_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
assert seq_len <= max_num_token, (
    f"sequence of {seq_len} tokens exceeds positional table size {max_num_token}"
)
positional_embedding = positional_embedding_table[:seq_len].repeat([token_embedding.shape[0], 1, 1])
# Out-of-place add: rebind the name instead of mutating the tensor built by torch.cat in step2.
token_embedding = token_embedding + positional_embedding
print("token_embedding", token_embedding.shape)

# step4: feed embedding to Transformer Encoder
# nn.TransformerEncoder in torch 1.7 (no batch_first option) expects input laid out as
# (seq_len, batch, d_model), while token_embedding is (batch, seq_len, d_model).
# Transpose in and back out; without this the encoder would see a length-1 sequence
# and the tokens would never attend to each other.
encoder_layer = nn.TransformerEncoderLayer(d_model=token_embedding.shape[2], nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
output = transformer_encoder(token_embedding.transpose(0, 1)).transpose(0, 1)
print("output", output.shape)

# step5: extract CLS token and classify
# Axis 1 of `output` enumerates tokens and the CLS token was concatenated first,
# so index 0 along that axis is its final hidden state; a linear head maps it to class logits.
cls_token_output = output.select(1, 0)
print("cls_token_output", cls_token_output.shape)
loss_f = nn.CrossEntropyLoss()
linear_layer = nn.Linear(cls_token_output.size(-1), num_classes)
logits = linear_layer(cls_token_output)
loss = loss_f(logits, label)
print("loss", loss)

embedding1 torch.Size([1, 4, 8])
embedding2 torch.Size([1, 4, 8])
token_embedding torch.Size([1, 5, 8])
token_embedding torch.Size([1, 5, 8])
output torch.Size([1, 5, 8])
cls_token_output torch.Size([1, 8])
loss tensor(2.9624, grad_fn=<NllLossBackward>)


In [54]:
# Show the installed PyTorch version so readers know which release produced the outputs above.
torch.__version__

'1.7.0'