In [49]:
from pathlib import Path

import torch
import torchvision
from easyimages import EasyImageList

In [50]:
# Device agnostic code: use the GPU when torch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [51]:
# Hyperparameters
BATCH_SIZE = 32

# Patches
# A 224-pixel side split into 16-pixel patches gives 14 patches per side,
# i.e. 14 * 14 = 196 patches per image.
# Floor division keeps the count an int; true division (`/`) would make it
# the float 196.0, which cannot be used where an integer size is required.
num_patches = (224 // 16) ** 2
patch_size = (16, 16)

# Patches to Embeddings
embed_dims = 768

# Data

In [52]:
# Data paths: the dataset root, plus its 'train' and 'test' subdirectories
data_path = Path('./data/desserts')
train_path, test_path = (data_path / split for split in ('train', 'test'))

In [61]:
# Visualizing images
# Build an image list from the class subfolders under train_path, then render
# an HTML preview using sample=5 and size=150.
Li = EasyImageList.from_multilevel_folder(train_path)
Li.html(sample=5, size=150)

Drawing cannoli


Drawing donuts


Drawing pancakes


Drawing tiramisu


Drawing waffles


In [54]:
# Dataset transforms
# Both pipelines resize to 224x224 and finish with ToTensor; only the training
# pipeline applies TrivialAugmentWide in between.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.TrivialAugmentWide(num_magnitude_bins=31),
    torchvision.transforms.ToTensor(),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor(),
])

# Datasets
# One ImageFolder per split, each paired with its own transform pipeline.
# No target transform is applied to the labels.
train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=train_transform,
    target_transform=None,
)

test_dataset = torchvision.datasets.ImageFolder(
    root=test_path,
    transform=test_transform,
    target_transform=None,
)

In [55]:
# Dataloaders
# Pinned (page-locked) host memory only helps host-to-GPU transfers, so it is
# requested only when CUDA is available. This keeps the loaders consistent
# with the device-agnostic setup instead of hard-coding pin_memory=True.
use_pinned_memory = torch.cuda.is_available()

# Training batches are reshuffled on every pass over the data.
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               pin_memory=use_pinned_memory)

# Test batches keep the dataset order.
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              pin_memory=use_pinned_memory)

# Encoder Input

In [60]:
# Take the first batch from a fresh iterator over the training loader and
# report the shape of its image tensor.
batch_X, batch_y = next(iter(train_dataloader))
print(f"Image shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

Image shape: torch.Size([32, 3, 224, 224]) -> (batch_dim, color_channels, image_height, image_width)


In [None]:
# Patch embedding stack.
# Shapes below are for the (32, 3, 224, 224) batch sampled earlier.

# Convolution whose kernel and stride both equal the patch size, so each
# output position covers exactly one non-overlapping patch.
# (32, 3, 224, 224) -> (32, 768, 14, 14)
patch_projection = torch.nn.Conv2d(
    in_channels=3,
    out_channels=embed_dims,
    kernel_size=patch_size,
    stride=patch_size,
)

# Merge the two spatial grid dimensions into a single patch dimension.
# (32, 768, 14, 14) -> (32, 768, 196)
patch_grid_flatten = torch.nn.Flatten(start_dim=2, end_dim=3)

embed_processing = torch.nn.Sequential(patch_projection, patch_grid_flatten)

In [None]:
# Run the sampled image batch through the patch-embedding stack and report
# the shape of the result.
embeddings = embed_processing(batch_X)
print(
    f"Image shape: {embeddings.shape} -> (batch_dim, embedding_dim, num_patches)"
)

Image shape: torch.Size([32, 768, 196]) -> (batch_dim, embedding_dim, num_patches)
