In [None]:
# Input image batch shape: (B, C, H, W)
# Target patch-sequence shape: (B, num_vectors, vector_dim)

In [46]:
import torch
from torch import nn

In [55]:
def patchify(images, n_patches):
    """Split a batch of square images into flattened, non-overlapping patches.

    Each image is cut into an ``n_patches x n_patches`` grid. Patches are
    ordered row-major over the grid (patch ``(i, j)`` lands at index
    ``i * n_patches + j``) and each patch is flattened in
    (channel, row, column) order.

    Args:
        images: Tensor of shape ``(n, c, h, w)`` with ``h == w``.
        n_patches: Number of patches along each spatial side.

    Returns:
        A new tensor of shape ``(n, n_patches ** 2, c * p * p)`` where
        ``p = h // n_patches``. It has the same dtype and device as
        ``images``.

    Raises:
        AssertionError: If the images are not square, or if the side length
            is not a multiple of ``n_patches``.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    # Without this check a non-divisible size either fails later with an
    # opaque shape error or, for 1-element patches, silently broadcasts.
    assert h % n_patches == 0, "Input shape not entirely divisible by number of patches"

    patch_size = h // n_patches

    # (n, c, grid_row, patch_row, grid_col, patch_col)
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    # -> (n, grid_row, grid_col, c, patch_row, patch_col), so that flattening
    # the last three axes yields the same layout as image[:, rows, cols].flatten().
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    # Force a fresh contiguous copy so the result never aliases the input
    # and inherits the input's dtype and device.
    grid = grid.clone(memory_format=torch.contiguous_format)
    return grid.view(n, n_patches ** 2, c * patch_size * patch_size)

In [56]:
class MyViT(nn.Module):
    """Vision-transformer skeleton that, for now, only patchifies its input.

    Attributes are stored as given:
        chw: Expected input shape as (C, H, W).
        n_patches: Number of patches along each spatial side.
    """

    def __init__(self, chw=(3,28, 28), n_patches=7):
        """Store the configuration and check H and W divide evenly into patches."""
        super().__init__()

        self.chw = chw
        self.n_patches = n_patches

        # Axis 1 is H and axis 2 is W; both must split into whole patches.
        for axis in (1, 2):
            assert chw[axis] % n_patches == 0, "Input shape not entirely divisible by number of patches"

    def forward(self, images):
        """Return the flattened patches produced by ``patchify`` for ``images``."""
        return patchify(images, self.n_patches)

In [58]:
if __name__ == '__main__':
  # Smoke test: build the model and push a dummy batch through it.
  model = MyViT(
    chw=(3, 28, 28),
    n_patches=7
  )

  x = torch.randn(7, 3, 28,28) # Dummy images: batch of 7, 3 channels, 28x28
  # 7x7 grid -> 49 patches per image; each patch is 3 channels * 4 * 4 = 48 values.
  print(model(x).shape) # torch.Size([7, 49, 48])

torch.Size([7, 49, 48])


In [None]:
# from torch import nn
# from einops.layers.torch import Rearrange
# import torch

# class PatchEmbedding(nn.Module):
#     def __init__(self, in_channels=3, patch_size=8, emb_size=128):
#         self.patch_size = patch_size
#         super().__init__()
#         self.projection = nn.Sequential(
#             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
#             nn.Linear(patch_size * patch_size * in_channels, emb_size)
#         )

#     def forward(self, x):
#         x = self.projection(x)
#         return x

# sample_datapoint = torch.randn(2, 3, 144, 144)  # Example random image tensor
# datapoint_shape = sample_datapoint.shape
# print("Initial shape: ", datapoint_shape)

# patch_embedding = PatchEmbedding()

# embedding = patch_embedding(sample_datapoint)
# #print("Patches shape: ", embedding.shape)

# batch_size = embedding.shape[0]
# patch_size = patch_embedding.patch_size

# num_vectors = (datapoint_shape[2] // patch_size) * (datapoint_shape[3] // patch_size)

# vector_dim = embedding.shape[2]
# reshaped_embedding = embedding.view(batch_size, num_vectors, vector_dim)

# print("Required_shape: ", reshaped_embedding.shape)
