In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

In [3]:
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim):
        super(ViT, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=8), num_layers=6)
        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        print('rearrange', x.shape)
        x = torch.cat((self.positional_embedding, x), dim=1)
        x = self.transformer(x)
        x = x[:, 0, :]  # take only the first token
        x = self.classification_head(x)
        return x

# Example usage
image_size = 224
patch_size = 16
num_classes = 10
dim = 512

model = ViT(image_size, patch_size, num_classes, dim)
input = torch.randn(1, 3, image_size, image_size)
output = model(input)
print(output.shape)



rearrange torch.Size([1, 196, 512])
torch.Size([1, 10])


In [None]:
# https://blog.csdn.net/MengYa_Dream/article/details/126579405

import torch
from vit_pytorch import ViT
 
v = ViT(
    image_size = 256,     
    patch_size = 32,      
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
 
img = torch.randn(1, 3, 256, 256)
 
preds = v(img) # (1, 1000)