In [4]:
import setup
import torch
from torchvision import transforms
from data import create_data_loaders


# Runtime setup via the project's `setup` helpers. The call order is kept:
# set_seed() runs before the device and CPU count are queried.
setup.set_seed()
device: torch.device = setup.get_device()
CPU_COUNT: int = setup.get_cpu_count()

# Dataset locations, relative to the notebook's working directory.
train_dir: str = "static/MNIST/train"
test_dir: str = "static/MNIST/test"

# Loader / preprocessing settings.
# MNIST_STATS holds single-element `mean` and `std` lists; its keys match the
# keyword arguments of transforms.Normalize, where it is unpacked with `**`.
BATCH_SIZE: int = 32
MNIST_STATS: dict = {"std": [0.3081], "mean": [0.1307]}

# Input image geometry.
# NOTE(review): color_channels is 3 while MNIST_STATS carries one value per
# statistic — confirm the loader really yields 3-channel images.
height: int = 224
width: int = 224
color_channels: int = 3

# Patch grid: the image is cut into patch_size x patch_size tiles, so the
# number of patches is (rows of tiles) * (columns of tiles) — 14 * 14 = 196
# for the values above.
patch_size: int = 16
num_patches: int = (
    (height // patch_size) * (width // patch_size)
)

In [5]:
# Build the train/test loaders and the class list with the project helper.
# Every image goes through the same pipeline: resize -> tensor -> normalise
# with MNIST_STATS (unpacked into Normalize's `mean` / `std` keywords).
#
# torchvision's Resize takes a sequence size as (height, width), so the tuple
# is passed in that order. The two values are equal today, but the previous
# (width, height) order would silently transpose the output for any
# non-square target size.
#
# NOTE(review): BATCH_SIZE is defined in the config cell but is not passed
# here — confirm whether create_data_loaders should receive it.
train_loader, test_loader, classes = create_data_loaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=transforms.Compose(
        [
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            transforms.Normalize(**MNIST_STATS),
        ]
    ),
    num_workers=CPU_COUNT,
)

In [None]:
# Take the first item the training loader yields and unpack it into two names.
# presumably `image` is a batch of inputs and `label` the matching targets —
# confirm against create_data_loaders.
# The iterator is deliberately not bound to a name, so it is released as soon
# as this statement finishes.
image, label = next(iter(train_loader))

In [11]:
from torch import nn

# Embedding width D: one output channel per raw value in a single patch
# (channels * patch height * patch width).
out_channels: int = color_channels * patch_size * patch_size

# Patch-embedding layer in the style of the ViT paper: the image is treated
# as a sequence of patches, and each patch is mapped to the transformer input
# width by a learnable projection. Because kernel size and stride both equal
# patch_size and there is no padding, the convolution visits every
# non-overlapping patch exactly once and emits one D-dimensional vector for it.
conv2d = nn.Conv2d(
    color_channels,
    out_channels,  # hidden size D, i.e. the embedding size
    kernel_size=patch_size,
    stride=patch_size,
    padding=0,
)

# Merge dimensions 2 and 3 (the two spatial axes of the convolution output)
# into a single axis, leaving the first two dimensions untouched.
flatten = nn.Flatten(2, 3)

In [18]:
# Shape check on a random stand-in image. Note that this rebinds `image`.
# Indexing with None prepends a batch axis of size 1.
image = torch.rand(color_channels, height, width)[None]
print(image.shape)

# Run the patch embedding step by step: convolve, merge the spatial axes,
# then swap the last two axes so the patch axis comes before the embedding
# axis.
image = conv2d(image)
image = flatten(image)
image = image.transpose(1, 2)
print(image.shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])
