<a href="https://colab.research.google.com/github/Nahom32/ViT/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
import math

# ToTensor is the only preprocessing step: no normalization or augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Build the training split first, then the test split. Both use the same root
# directory and transform, and download=True is passed for each.
train_data, test_data = (
    datasets.CIFAR10(root='./data/cifar-10', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

100%|██████████| 170M/170M [00:03<00:00, 47.4MB/s]


In [3]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch tokens.

    Every image is cut into non-overlapping ``patch_size x patch_size`` patches
    (row-major over the patch grid). Each patch is flattened in ``(C, P, P)``
    order, linearly projected to ``embed_dim``, prefixed with a learnable class
    token, and summed with a learnable position embedding.

    Args:
        img_size: Height and width of the (square) input images.
        patch_size: Side length ``P`` of each square patch.
        in_chans: Number of input channels ``C``.
        embed_dim: Size ``D`` of every output token.

    Shape:
        input:  ``(B, in_chans, img_size, img_size)``
        output: ``(B, 1 + num_patches, embed_dim)``
    """

    def __init__(self, img_size = 32, patch_size = 4, in_chans = 3, embed_dim = 768):
        super().__init__()
        self.img_size   = img_size
        self.patch_size = patch_size    # P
        self.in_chans   = in_chans      # C
        self.embed_dim  = embed_dim     # D

        self.num_patches = (img_size // patch_size) ** 2        # N = H*W/P^2
        self.flatten_dim = patch_size * patch_size * in_chans   # P^2*C

        self.proj = nn.Linear(self.flatten_dim, embed_dim) # (P^2*C,D)

        # Both learnable tensors start at zero; index 0 of the position
        # embedding belongs to the class token.
        self.position_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.class_embed    = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Embed ``x`` of shape ``(B, C, H, W)`` into ``(B, 1 + N, D)`` tokens.

        Raises:
            ValueError: if ``(C, H, W)`` does not match the configured
                ``(in_chans, img_size, img_size)``.
        """
        B, C, H, W = x.shape
        if (C, H, W) != (self.in_chans, self.img_size, self.img_size):
            raise ValueError(
                f"expected input of shape (B, {self.in_chans}, {self.img_size}, {self.img_size}), "
                f"got {tuple(x.shape)}"
            )

        P = self.patch_size
        # (B, C, H, W) -> (B, C, H/P, W/P, P, P): one P x P window per grid cell.
        patches = x.unfold(2, P, P).unfold(3, P, P)
        # Bring the patch-grid axes next to the batch axis so that each sample
        # keeps its own patches and every patch's C*P*P pixels end up together:
        # (B, H/P, W/P, C, P, P) -> (B, N, C*P*P).
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        patches = patches.reshape(B, self.num_patches, self.flatten_dim)

        tokens = self.proj(patches)

        # Prepend one class token per sample, then add position information.
        cls_emb = self.class_embed.expand(B, -1, -1)
        tokens = torch.cat((cls_emb, tokens), dim = 1)

        return tokens + self.position_embed


In [4]:
# Smoke test: embed the first ten training images with default settings.
patch_embed = PatchEmbedding()

# train_data[idx] is indexed with [0] to take the image; the label is not used here.
sample_images = torch.stack([train_data[idx][0] for idx in range(10)])
embeddings = patch_embed(sample_images)

print(embeddings.shape)
print(embeddings)


torch.Size([10, 65, 768])
tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.2078e-01,  1.0504e-02,  5.4104e-01,  ...,  1.8488e-01,
          -3.0784e-01,  1.0644e-01],
         [-4.0490e-01,  4.1700e-03,  4.2748e-01,  ...,  1.1319e-01,
          -2.0800e-01, -2.2263e-02],
         ...,
         [-3.0173e-01,  7.5302e-02,  3.5745e-01,  ...,  1.8374e-01,
          -1.4220e-01,  1.4374e-01],
         [-3.3998e-01, -3.5417e-02,  3.2596e-01,  ...,  1.0464e-01,
          -6.7990e-02, -7.4361e-02],
         [-1.7090e-01, -1.3885e-02,  2.8166e-01,  ...,  1.0261e-01,
          -6.6580e-02, -3.1313e-02]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.6811e-01, -1.0550e-01,  5.6385e-01,  ..., -4.7664e-03,
          -2.9877e-01, -7.1115e-02],
         [-3.9925e-01, -3.2990e-03,  3.9107e-01,  ...,  1.3110e-01,
          -9.6710e-02, -8.2381e-03],
         ...,