In [1]:
import torch
from einops import rearrange
import numpy as np

In [4]:
# Demo configuration: batch size B, image height H / width W, channel count C,
# and patch side length p used by patchify/depatchify below.
B, H, W, C = 64, 128, 128, 3
p = 16


In [5]:
def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """Split a batch of images into flattened, non-overlapping p x p patches.

    Args:
        x: image batch of shape (B, C, H, W).
        p: patch side length; must be positive and divide both H and W.

    Returns:
        Tensor of shape (B, (H // p) * (W // p), p * p * C). Patches are ordered
        row-major over the patch grid (the ``(h w)`` group), and each patch is
        flattened as ``(p1 p2 c)``, i.e. channels vary fastest.

    Raises:
        AssertionError: if p is not positive or H / W is not divisible by p.
    """
    _, _, H, W = x.shape

    assert p > 0, f"Invalid patch_size: p must be positive, got {p}"
    assert H % p == 0, f"Invalid patch_size: height {H} is not divisible by p={p}"
    assert W % p == 0, f"Invalid patch_size: width {W} is not divisible by p={p}"

    return rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)

def depatchify(y: torch.Tensor, H: int, W: int, p: int) -> torch.Tensor:
    """Inverse of ``patchify``: reassemble flattened p x p patches into images.

    Args:
        y: patch tokens of shape (B, (H // p) * (W // p), p * p * C), laid out
            as ``patchify`` produces them (row-major patch grid, each patch
            flattened as ``(p1 p2 c)``).
        H: target image height; must be divisible by p.
        W: target image width; must be divisible by p.
        p: patch side length; must be positive.

    Returns:
        Tensor of shape (B, C, H, W).

    Raises:
        AssertionError: if p is not positive, if H or W is not divisible by p,
            or if the token count of y is not (H // p) * (W // p).
    """
    assert p > 0, f"Invalid patch_size: p must be positive, got {p}"
    # Without these checks H // p silently floors: e.g. H=130, p=16 passes the
    # token-count check below yet yields a 128-high image instead of 130.
    assert H % p == 0, f"Invalid patch_size: height {H} is not divisible by p={p}"
    assert W % p == 0, f"Invalid patch_size: width {W} is not divisible by p={p}"

    _, num_tokens, _ = y.shape
    h1 = H // p
    w1 = W // p

    assert num_tokens == h1 * w1, (
        f"Invalid dimensions: expected {h1 * w1} tokens for a {H}x{W} image "
        f"with p={p}, got {num_tokens}"
    )

    return rearrange(
        y, "b (h1 w1) (p1 p2 c) -> b c (h1 p1) (w1 p2)", h1=h1, w1=w1, p1=p, p2=p
    )

x = torch.randn((B, C, H, W))
y = patchify(x, p)
x_rec = depatchify(y, H, W, p)

# Patch tokens: one per p x p patch, each holding p * p * C values.
assert y.shape == (B, (H // p) * (W // p), p * p * C), y.shape

# patchify/depatchify only rearrange elements (no arithmetic), so the round trip
# must be bit-exact. torch.equal also requires identical shapes, whereas
# torch.allclose broadcasts and could hide a shape bug.
print(torch.equal(x, x_rec))




True


In [None]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class Encoder(nn.Module):
    """A stack of ``num_layers`` identical ``nn.TransformerEncoderLayer`` blocks.

    The layers are built with ``batch_first=True``, so inputs and outputs are
    laid out as (batch, num_tokens, d_model).

    Args:
        num_layers: how many encoder layers to stack.
        d_model: token embedding width.
        nhead: number of self-attention heads per layer.
        dim_feedforward: hidden width of each layer's feed-forward sublayer.
    """

    def __init__(self, num_layers: int, d_model: int, nhead: int, dim_feedforward: int):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        # TransformerEncoder clones this template layer num_layers times.
        template_layer = TransformerEncoderLayer(**layer_kwargs)
        self.encoder = TransformerEncoder(template_layer, num_layers=num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` of shape (B, num_tokens, d_model) through the layer stack."""
        encoded = self.encoder(x)
        return encoded


class PositionalEncoding(nn.Module):
    """Learned absolute positional embedding added to a token sequence.

    Holds a trainable table of shape (1, num_tokens, d_model), initialised from
    a standard normal distribution, and adds its first ``seq_len`` rows to the
    input.

    Args:
        num_tokens: maximum sequence length the table supports.
        d_model: embedding width; must match the last dimension of the input.
    """

    def __init__(self, num_tokens: int, d_model: int):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to ``x`` of shape (B, seq_len, d_model).

        Sequences shorter than ``num_tokens`` use the leading rows of the table.

        Raises:
            RuntimeError: if seq_len exceeds ``num_tokens`` (a d_model mismatch
                also raises RuntimeError, from PyTorch broadcasting).
        """
        seq_len = x.shape[-2]
        max_len = self.pos_emb.shape[1]
        if seq_len > max_len:
            raise RuntimeError(
                f"sequence length {seq_len} exceeds the {max_len} positions "
                f"of this PositionalEncoding"
            )
        # Slice rather than add the whole table: adding (1, num_tokens, d) to a
        # length-1 sequence would otherwise silently broadcast it to num_tokens.
        return x + self.pos_emb[:, :seq_len]
    

class Transformer(nn.Module):
    """Learned positional embedding followed by a batch-first encoder stack.

    Args:
        num_layers: number of encoder layers.
        num_tokens: sequence length the positional table is sized for.
        d_model: token embedding width.
        dim_feedforward: hidden width of each layer's feed-forward sublayer.
        nhead: number of attention heads per layer.
    """

    def __init__(self, num_layers, num_tokens, d_model, dim_feedforward, nhead):
        super().__init__()
        # Construct pos_emb before encoder: this fixes parameter registration
        # order (state_dict keys) and the order of random initialisation draws.
        self.pos_emb = PositionalEncoding(num_tokens, d_model)
        self.encoder = Encoder(num_layers, d_model, nhead, dim_feedforward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to ``x`` and encode the resulting tokens."""
        return self.encoder(self.pos_emb(x))
