In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import numpy as np

### <center> Patch Embeddings </center>

In [None]:
class PatchEmbedding(nn.Module):
  """
  Split an image into non-overlapping square patches and linearly embed each one.

  The split and the projection are done in one step by a Conv2d whose kernel
  size and stride both equal ``patch_size``. The defaults (224x224 RGB input,
  16x16 patches, 768-dim embeddings) are the usual ViT-Base/ImageNet settings,
  but any sizes may be passed.

  Args:
    img_size: side length of the (square) input image; used to compute
      ``num_patches``.
    patch_size: side length of each square patch.
    in_channels: number of channels in the input image.
    embed_dim: dimension of each patch embedding.
    bias: whether the projection Conv2d has a learnable bias.
  """
  def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, bias=True):
    super().__init__()

    self.img_size = img_size
    self.patch_size = patch_size
    self.in_channels = in_channels
    self.embed_dim = embed_dim

    # Patches per side is img_size // patch_size. Floor division matches what
    # the stride-patch_size conv does with leftover border pixels.
    self.num_patches = (img_size // patch_size) ** 2

    # Kernel == stride == patch_size: each output position sees exactly one
    # patch, so this is a per-patch linear projection.
    # `bias` was previously accepted but never forwarded to the conv.
    self.proj = nn.Conv2d(
      in_channels=in_channels,
      out_channels=embed_dim,
      kernel_size=patch_size,
      stride=patch_size,
      bias=bias,
    )

  # B: batch size
  # C: num channels
  # P_row: number of patch rows    (H // patch_size)
  # P_col: number of patch columns (W // patch_size)
  def forward(self, x):
    """
    Embed the patches of a batch of images.

    Args:
      x: tensor of shape (B, C, H, W) with C == in_channels.
        NOTE(review): H and W are not checked against img_size. If they
        differ, the output patch count will not equal num_patches.

    Returns:
      Tensor of shape (B, P_row * P_col, embed_dim), with patches in
      row-major order.
    """
    x = self.proj(x)  # (B, C, H, W) -> (B, embed_dim, P_row, P_col)

    x = x.flatten(2)  # (B, embed_dim, P_row, P_col) -> (B, embed_dim, num_patches)

    x = x.transpose(1, 2)  # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)

    return x