Patchification Block:
takes an image, splits it into smaller patches, flattens each patch, and sends the
resulting sequence to the input layer.

- Here in case of DiT, the patchification method is done on Noised Latent Images
- Latents are the output obtained after passing the image through the Variational Autoencoder (VAE) encoder; noise is then added to them
- So, in general, this layer produces patchified latents + positional encodings; the positional encodings are
responsible for keeping the information about where each patch sits in the image.
- Our initial image is of size 256 * 256 which is converted to 32 * 32 after VAE 
- Now we have 2 options for the patch size: 2 * 2 or 4 * 4 (let's keep it configurable, depending on how much training compute is left)


In [3]:
import torch
import torch.nn as nn

In [None]:
class PatchEmbedding(nn.Module):
    """Convert a square latent image into a sequence of patch tokens, and back.

    ``forward`` cuts the input into non-overlapping ``patchSize x patchSize``
    patches with a strided convolution, projects each patch to
    ``embedDimension`` features and adds a learned positional embedding.
    ``unPatchify`` folds a token sequence back into a grid and maps it to
    ``inChannels`` channels with a transposed convolution.

    Args:
        imageSize: Side length of the (square) input latent.
        patchSize: Side length of each (square) patch.
        inChannels: Number of channels of the input latent.
        embedDimension: Number of features of each patch token.
    """

    def __init__(self, imageSize: int, patchSize: int, inChannels: int, embedDimension: int):
        super().__init__()
        self.patchSize = patchSize
        self.inChannels = inChannels
        self.embedDimension = embedDimension
        self.imageSize = imageSize

        # Square the per-side patch count explicitly. The unparenthesised form
        # `imageSize // patchSize * imageSize // patchSize` is evaluated left
        # to right as ((imageSize // patchSize) * imageSize) // patchSize and
        # over-counts whenever imageSize is not a multiple of patchSize, which
        # makes the positional embedding disagree with the conv output.
        patchesPerDim = imageSize // patchSize
        self.patches = patchesPerDim * patchesPerDim

        # kernel_size == stride, so every patch is seen exactly once (no overlap).
        self.encode = nn.Conv2d(
            in_channels=inChannels,
            out_channels=embedDimension,
            kernel_size=patchSize,
            stride=patchSize,
            bias=True,
        )
        # Learned up-projection used by unPatchify; it has its own weights and
        # is not an exact inverse of `encode`.
        self.decode = nn.ConvTranspose2d(
            in_channels=embedDimension,
            out_channels=inChannels,
            kernel_size=patchSize,
            stride=patchSize,
            bias=True,
        )
        # One learned embedding vector per patch position, shared across the batch.
        self.positionalEmbedding = nn.Parameter(torch.zeros(1, self.patches, embedDimension))
        nn.init.trunc_normal_(self.positionalEmbedding, std=0.02)

    def unPatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold patch tokens back into an image-shaped tensor.

        Args:
            x: Tensor of shape ``(batchSize, patches, embedDimension)``.

        Returns:
            The output of ``self.decode`` applied to the tokens arranged as a
            ``(batchSize, embedDimension, patchesPerDim, patchesPerDim)`` grid.
        """
        batchSize, NPatches, EmbedDim = x.shape
        patchPerDim = self.imageSize // self.patchSize
        # (B, N, D) -> (B, D, N) -> (B, D, H', W'); this undoes the
        # flatten + transpose performed in forward().
        x = x.transpose(1, 2).reshape(batchSize, EmbedDim, patchPerDim, patchPerDim)
        out = self.decode(x)
        return out

    def forward(self, latentImage: torch.Tensor) -> torch.Tensor:
        """Embed a latent image as a sequence of position-aware patch tokens.

        Args:
            latentImage: Tensor of shape
                ``(batchSize, inChannels, imageSize, imageSize)``.

        Returns:
            Tensor of shape ``(batchSize, patches, embedDimension)``.
        """
        allPatch = self.encode(latentImage)
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        flattened = allPatch.flatten(2).transpose(1, 2)
        # The positional embedding broadcasts over the batch dimension.
        out = flattened + self.positionalEmbedding
        return out

# Round-trip smoke test on a single random latent: 128 channels, 8x8 spatial size.
# The latent is drawn before the module is built so random-number consumption
# happens in the same order as before.
latent = torch.randn(128, 8, 8)[None]  # prepend the batch dimension
pEmbed = PatchEmbedding(
    imageSize=8,
    patchSize=2,
    inChannels=128,
    embedDimension=768,
)
out = pEmbed(latent)  # (batchSize, totalPatches, embeddingDimension)
unpatched = pEmbed.unPatchify(out)  # back to the shape of `latent`
out.shape, unpatched.shape

(torch.Size([1, 16, 768]), torch.Size([1, 128, 8, 8]))

In [6]:
# Fold the token sequence back into a 2-D grid of embeddings.
# `out` holds one token per patch, so the grid side is the number of patches per
# side (imageSize // patchSize = 4 here), not the latent's side length: a reshape
# to 8x8 would fail because (1, 16, 768) does not hold 768 * 8 * 8 elements.
patchesPerSide = pEmbed.imageSize // pEmbed.patchSize
out = out.transpose(1, 2).reshape(out.shape[0], out.shape[2], patchesPerSide, patchesPerSide)
out.shape

torch.Size([1, 768, 4, 4])