In [None]:
import torch
import torchvision

from pathlib import Path

In [None]:
# Device agnostic code: use CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Hyperparameters
BATCH_SIZE = 32

# Images
IMAGE_SIZE = (224, 224)     # (height, width); matches the Resize target used by the dataset transforms

# Patches
PATCH_SIZE = (16, 16)       # (height, width) of one patch
# Patches per image = patches along the height * patches along the width.
# Derived from IMAGE_SIZE and PATCH_SIZE with integer arithmetic so the value
# follows either constant when it is changed (224 / 16 -> 14 per side -> 196).
NUM_PATCHES = (IMAGE_SIZE[0] // PATCH_SIZE[0]) * (IMAGE_SIZE[1] // PATCH_SIZE[1])

# Patches to Embeddings
EMBED_DIMS = 768

# Data

In [None]:
# Data paths: the train and test splits are both sub-folders of the same root
data_path = Path('./data/desserts')
train_path, test_path = (data_path / split for split in ('train', 'test'))

In [None]:
# Visualizing images
from easyimages import EasyImageList

# Build an image list from the training folder and show a preview as the cell output.
# NOTE(review): presumably `sample` is the number of images shown and `size` the
# thumbnail size — confirm against the easyimages documentation.
Li = EasyImageList.from_multilevel_folder(train_path)
Li.html(sample = 5, size = 150)

In [None]:
# Dataset transforms
# Training pipeline: resize to 224x224, apply TrivialAugmentWide, then convert to a tensor.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.TrivialAugmentWide(num_magnitude_bins=31),
    torchvision.transforms.ToTensor(),
])

# Test pipeline: same resize and tensor conversion, but no augmentation step.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor(),
])

# Datasets
# Each split is read from its own folder and gets its own transform; targets are left untouched.
train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=train_transform,
    target_transform=None,
)

test_dataset = torchvision.datasets.ImageFolder(
    root=test_path,
    transform=test_transform,
    target_transform=None,
)

In [None]:
# Class labels
# Class names as exposed by the training ImageFolder dataset; printed to confirm the label set.
class_labels = train_dataset.classes
print(class_labels)

In [None]:
# Dataloaders
# Pinned host memory only helps when batches are copied to a CUDA device. Requesting it
# on a CPU-only machine has no effect and makes the DataLoader emit a warning, so it is
# tied to the notebook's `device` switch.
use_pinned_memory = (device == 'cuda')

# Training batches are reshuffled on every pass over the data.
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               pin_memory=use_pinned_memory)

# Test batches keep the dataset order.
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              pin_memory=use_pinned_memory)

# Encoder Input (Patch Embeddings)

In [None]:
# Pull a single (images, labels) batch from the training dataloader; the embedding
# steps below are prototyped on this one batch.
batch_X, batch_y = next(iter(train_dataloader))

print(f"Image shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

In [None]:
# Patch embedding pipeline.
# Because kernel_size == stride == PATCH_SIZE, the convolution visits each patch exactly
# once (no overlap) and maps it to an EMBED_DIMS-long vector; the flatten step then
# collapses the 2D grid of patches into a single axis.
embed_patch_processing = torch.nn.Sequential(
    # (32, 3, 224, 224) -> (32, 768, 14, 14)
    torch.nn.Conv2d(3, EMBED_DIMS, kernel_size=PATCH_SIZE, stride=PATCH_SIZE),
    # (32, 768, 14, 14) -> (32, 768, 196)
    torch.nn.Flatten(2, 3),
)

In [None]:
# Run the sample batch through the patch embedding pipeline (convolution + flatten).
patch_embeddings = embed_patch_processing(batch_X)
print(f"Output shape: {patch_embeddings.shape} -> (batch_dim, embedding_dims, num_patches)")

In [None]:
# Rearrange the dimensions for better readability
#   (batch_dim, embedding_dims, num_patches) -> (batch_dim, num_patches, embedding_dims)
#   #######################################################
#   # Number of datapoints in each batch -> batch_dim
#   # Number of patches in each datapoint -> num_patches
#   # Number of dimensions in each patch -> embedding_dims
#   #######################################################
# The result is recomputed from `batch_X` rather than permuting `patch_embeddings`
# onto itself, so re-running this cell cannot swap the last two axes back again.
patch_embeddings = embed_patch_processing(batch_X).permute(0, 2, 1)
print(f"Rearranged output shape: {patch_embeddings.shape} -> (batch_dim, num_patches, embedding_dims)")

# Encoder Input (Prepend Class Embeddings)

In [None]:
# Class embedding (Learnable embedding)
# The parameter keeps its own name so the learnable leaf tensor stays reachable
# (e.g. to hand it to an optimizer); `class_embedding` below is only a view of it.
class_token = torch.nn.Parameter(torch.randn(size=(1, 1, EMBED_DIMS)),
                                 requires_grad=True)

# Expanding same rand numbers across all data in a batch.
# The batch size is read from the patch embeddings instead of BATCH_SIZE so a
# batch with fewer than BATCH_SIZE images still lines up when concatenated.
class_embedding = class_token.expand(patch_embeddings.shape[0], -1, -1)

print(f"Shape of class embedding to be prepended: {class_embedding.shape} -> (batch_dim, num_patches, embedding_dims)")

In [None]:
# Adding the class embedding
# Concatenating along dim=1 with the class embedding listed first places it ahead of
# the patch embeddings in every sequence; both tensors must agree on the other dims.
embeddings = torch.cat([class_embedding, patch_embeddings], dim=1)

print(f"Shape of the embeddings with patch and class embeddings together: {embeddings.shape} -> (batch_dim, num_embeddings, embedding_dims)")

# Encoder Input (Positional Embeddings)

In [None]:
# Positional embeddings
# One learnable vector per sequence position: NUM_PATCHES patches + 1 class embedding.
# The parameter keeps its own name so the learnable leaf tensor stays reachable;
# `positional_embeddings` below is only a view of it.
position_table = torch.nn.Parameter(torch.randn(size=(1, NUM_PATCHES + 1, EMBED_DIMS)),
                                    requires_grad=True)

# Expanding same rand numbers across all data in a batch.
# The batch size is read from `embeddings` instead of BATCH_SIZE so a batch with
# fewer than BATCH_SIZE images still lines up when the two are added.
positional_embeddings = position_table.expand(embeddings.shape[0], -1, -1)

print(f"Shape of positional embeddings to be added: {positional_embeddings.shape} -> (batch_dim, num_patches, embedding_dims)")

In [None]:
# Adding the positional embeddings
# Built out-of-place from the class and patch embeddings rather than with an in-place
# `+=` on `embeddings`, so re-running this cell does not add the positional
# embeddings a second time.
embeddings = torch.cat([class_embedding, patch_embeddings], dim=1) + positional_embeddings