In [19]:
import os
import sys

import torch

from pathlib import Path

# Put the parent directory on the import path so modules located one level
# above this notebook can be imported (presumably where data_setup lives).
super_directory = os.path.abspath('..')
# Only append when missing: re-running this cell must not pile up duplicates.
if super_directory not in sys.path:
    sys.path.append(super_directory)

from data_setup import get_dataloaders

In [20]:
# Device agnostic code: use the GPU when CUDA is usable, else fall back to CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [21]:
# Hyperparameters
BATCH_SIZE = 32             # Number of datapoints per batch

# Patches
IMAGE_SIZE = (28, 28)       # (height, width) of one input image
PATCH_SIZE = (7, 7)         # (height, width) of one patch
# Patches per image, derived with integer arithmetic from the two sizes above
# so the count cannot drift out of sync when PATCH_SIZE is changed.
NUM_PATCHES = (IMAGE_SIZE[0] // PATCH_SIZE[0]) * (IMAGE_SIZE[1] // PATCH_SIZE[1])

# Patches to Embeddings
EMBED_DIMS = 48             # Length of the embedding vector of each patch

# Data

In [22]:
# Relative path to the data directory, one level above this notebook
data_path = Path('..') / 'data'

In [23]:
# Build the train and test dataloaders (plus the class labels) with the
# configured batch size
(
    train_dataloader,
    test_dataloader,
    class_labels,
) = get_dataloaders(batch_size=BATCH_SIZE)

In [24]:
# Pull only the first batch out of the training dataloader:
# batch_X holds the inputs, batch_y the matching targets
first_batch = next(iter(train_dataloader))
batch_X, batch_y = first_batch

# Encoder Input (Patch Embeddings)

In [25]:
# Show how one batch of images is laid out
print(f"Image shape: {batch_X.size()} -> (batch_dim, color_channels, image_height, image_width)")

Image shape: torch.Size([32, 1, 28, 28]) -> (batch_dim, color_channels, image_height, image_width)


In [26]:
# Patch embedding pipeline.
# Because kernel size and stride are both PATCH_SIZE, the convolution visits
# non-overlapping patches and maps each one to EMBED_DIMS output channels.
patch_projection = torch.nn.Conv2d(
    in_channels=1,                  # Input  -> (32, 1, 28, 28)
    out_channels=EMBED_DIMS,
    kernel_size=PATCH_SIZE,
    stride=PATCH_SIZE,              # Output -> (32, 48, 4, 4)
)

# Merge the two spatial axes into a single patch axis
patch_grid_flatten = torch.nn.Flatten(start_dim=2, end_dim=3)   # Output -> (32, 48, 16)

embed_patch_processing = torch.nn.Sequential(patch_projection, patch_grid_flatten)

In [27]:
# Run the first batch through the patch pipeline and check the resulting shape
patch_embeddings = embed_patch_processing(batch_X)
print(f"Output shape: {patch_embeddings.size()} -> (batch_dim, embedding_dims, num_patches)")

Output shape: torch.Size([32, 48, 16]) -> (batch_dim, embedding_dims, num_patches)


In [28]:
# Rearrange the dimensions to -> (batch_dim, num_patches, embedding_dims)
#######################################################
# Number of datapoints in each batch -> batch_dim
# Number of patches in each datapoint -> num_patches
# Number of dimensions in each patch -> embedding_dims
#######################################################
# The permuted tensor is rebuilt from `batch_X` rather than by permuting
# `patch_embeddings` onto itself, so re-running this cell always yields the
# same layout instead of swapping the last two axes back again.
patch_embeddings = embed_patch_processing(batch_X).permute(0, 2, 1)
print(f"Rearranged output shape: {patch_embeddings.shape} -> (batch_dim, num_patches, embedding_dims)")

Rearranged output shape: torch.Size([32, 16, 48]) -> (batch_dim, num_patches, embedding_dims)


# Encoder Input (Prepend Class Embeddings)

In [29]:
# Class embedding (Learnable embedding)
# The learnable parameter keeps its own name so that it is not overwritten by
# the expanded, non-leaf view created from it below.
class_token = torch.nn.Parameter(torch.randn(size=(1, 1, EMBED_DIMS)),
                                 requires_grad=True)

# Share the same class token across all data in the batch. The batch dimension
# is read from `patch_embeddings` (the tensor this is concatenated with), so a
# batch that is smaller than BATCH_SIZE still lines up.
class_embedding = class_token.expand(patch_embeddings.shape[0], -1, -1)

print(f"Shape of class embedding to be prepended: {class_embedding.shape} -> (batch_dim, num_patches, embedding_dims)")

Shape of class embedding to be prepended: torch.Size([32, 1, 48]) -> (batch_dim, num_patches, embedding_dims)


In [30]:
# Prepend the class embedding to the patch embeddings along the sequence axis
sequence_parts = (class_embedding, patch_embeddings)
embeddings = torch.cat(sequence_parts, dim=1)

print(f"Shape of the embeddings with patch and class embeddings together: {embeddings.size()} -> (batch_dim, num_embeddings, embedding_dims)")

Shape of the embeddings with patch and class embeddings together: torch.Size([32, 17, 48]) -> (batch_dim, num_embeddings, embedding_dims)


# Encoder Input (Positional Embeddings)

In [31]:
# Positional embeddings (learnable): one vector per sequence position, i.e.
# NUM_PATCHES patch positions plus one position for the class embedding.
# The learnable parameter keeps its own name so that it is not overwritten by
# the expanded, non-leaf view created from it below.
position_table = torch.nn.Parameter(torch.randn(size=(1, NUM_PATCHES + 1, EMBED_DIMS)),
                                    requires_grad=True)

# Share the same positional values across all data in the batch. The batch
# dimension is read from `embeddings` (the tensor these are added to), so a
# batch that is smaller than BATCH_SIZE still lines up.
positional_embeddings = position_table.expand(embeddings.shape[0], -1, -1)

print(f"Shape of positional embeddings to be added: {positional_embeddings.shape} -> (batch_dim, num_patches, embedding_dims)")

Shape of positional embeddings to be added: torch.Size([32, 17, 48]) -> (batch_dim, num_patches, embedding_dims)


In [32]:
# Adding the positional embeddings
# Built out-of-place from the class and patch embeddings instead of `+=` on
# `embeddings`: re-running this cell would otherwise add the positional
# embeddings a second time on top of the previous result.
embeddings = torch.cat([class_embedding, patch_embeddings], dim=1) + positional_embeddings