In [None]:
import matplotlib.pyplot as plt
import torch 
import torchvision

In [None]:
from torch import nn
from torchvision import transforms
from torchinfo import summary

In [None]:
import data_setup
import engine
from helper_functions import download_data, set_seeds, plot_loss_curves

In [None]:
# Prefer a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Fetch the pizza_steak_sushi dataset through the project's download_data helper
# and keep the value it returns as image_path.
# NOTE(review): the helper presumably skips the download when the destination
# already exists -- confirm in helper_functions.
image_path = download_data(
    destination="pizza_steak_sushi",
    source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
)
image_path


In [None]:
# Build the train and test directory paths underneath image_path.
train_dir, test_dir = image_path / "train", image_path / "test"

In [None]:
# Side length (in pixels) every image is resized to before it reaches the dataloaders.
IMG_SIZE = 224

# Transform pipeline: resize to IMG_SIZE x IMG_SIZE, then convert the image to a tensor.
manual_transforms = transforms.Compose(
    transforms=[
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)

print(f'Image transformer created: {manual_transforms}')

In [None]:
# The ViT paper reports a batch size of 4096. That only makes sense on a machine
# with a capable GPU, so anything running on the CPU uses a batch size of 24 instead.
BATCH_SIZE = 4096 if device != "cpu" else 24

# Build the train/test dataloaders (and pick up the class names) using the
# transform pipeline defined earlier in the notebook.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    batch_size=BATCH_SIZE,
    transform=manual_transforms,
    test_dir=test_dir,
    train_dir=train_dir,
)

train_dataloader, test_dataloader, class_names

In [None]:
# Sanity-check the dataloaders: pull one batch from the training set ...
image_batch, label_batch = next(iter(train_dataloader))

# ... and keep its first image together with the matching label.
image = image_batch[0]
label = label_batch[0]

# Display the image's shape and its label.
image.shape, label

In [None]:
# The shape/label printout is not a picture, so draw the sample with matplotlib.
# The tensor is laid out [color_channels, height, width]; matplotlib wants
# [height, width, color_channels], hence the permute.
plt.imshow(X=image.permute(1, 2, 0))
plt.title(label=class_names[label])
plt.axis("off")

Mmmm, good-looking pizza, or steak, or sushi.

## Now we are ready to actually replicate the paper

In [None]:
# Patch-embedding bookkeeping: work out how many patches one image produces.
# Our training resolution is 224 x 224 (H x W).
height = 224
width = 224
color_channels = 3  # C
patch_size = 16  # P - taken from column ViT-B/16 from table 5 in the ViT paper

# Patches tile the image on a grid, so both sides must be whole multiples of P.
# Without this check, H*W / P^2 silently over-counts the patches that actually fit
# (e.g. H = W = 24, P = 16 would report 2 patches although only 1 fits).
assert height % patch_size == 0 and width % patch_size == 0, (
    "Image height and width must be divisible by patch size"
)

# N = (H / P) * (W / P), computed with exact integer arithmetic (no float round-trip).
number_of_patches = (height // patch_size) * (width // patch_size)
print(f"Number of patches N w/ image height (H = {height}), width (W = {width}) and patch size (P = {patch_size}) is (N = {number_of_patches})")

In [None]:
# Replicate the input and output shapes of the patch-embedding layer.
# Input: a single 2D image of size (H x W x C).
embedding_layer_input_shape = (height, width, color_channels)

# Output: the image as a sequence of N flattened patches, each of length P^2 * C.
embedding_layer_output_shape = (number_of_patches, color_channels * patch_size**2)

print(
    f"Input Shape (single 2D image): {embedding_layer_input_shape}",
    f"Output Shape (single 2d Image flattened into patches): {embedding_layer_output_shape}",
    sep="\n",
)

In [None]:
# First step towards patches: look at the top row of patch pixels.
# matplotlib expects (height, width, color_channels); the tensor is
# (color_channels, height, width), so permute it once and reuse the result.
image_permuted = image.permute(1, 2, 0)

patch_size = 16
plt.figure(figsize=(patch_size,) * 2)
# Keep the first patch_size pixel rows, with every column and every color channel.
plt.imshow(image_permuted[:patch_size])

In [None]:
# Zoom in on one patch: the top-left patch_size x patch_size corner of the image.
plt.figure(figsize=(patch_size,) * 2)
plt.imshow(image_permuted[:patch_size, :patch_size])

In [None]:
# Split the top row of the image into individual patches, one subplot per patch.
img_size = 224
# patch_size is deliberately reused from the earlier cells rather than redefined here.

# Check that the two sizes are compatible *before* deriving anything from them.
assert img_size % patch_size == 0, "Image size must be divisible by patch size"

# Integer division: a patch count is a whole number, and the same value is used
# below as the subplot column count (which must be an int, not a float).
num_patches = img_size // patch_size
print(f"Number of patches per row: {num_patches} \n Patch size: {patch_size} pixels x {patch_size} pixels")

# Create a single row of subplots, one column per patch.
fig, axs = plt.subplots(
    nrows=1,
    ncols=num_patches,
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True
)

# Walk across the top row: `patch` is the left pixel column of the i-th patch.
for i, patch in enumerate(range(0, img_size, patch_size)):
    axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :]) # Keep height index constant, alter width index
    axs[i].set_xlabel(i+1) # Label each subplot with its 1-based patch number
    axs[i].set_xticks([])
    axs[i].set_yticks([])

In [None]:
# Patchify the whole image: one subplot per patch, laid out on a square grid.
# img_size and patch_size are reused from the earlier cells.
assert img_size % patch_size == 0, "Image size must be divisible by patch size"

# Derive the per-side patch count here as an int so this cell does not depend on
# how an earlier cell computed it; a float would print as e.g. "14.0" / "196.0".
num_patches = img_size // patch_size

# Adjacent f-strings instead of backslash continuations: the continuation form
# embedded each following line's indentation as trailing spaces in the output.
print(
    f"Number of patches per row: {num_patches}"
    f"\nNumber of patches per column: {num_patches}"
    f"\nTotal patches: {num_patches*num_patches}"
    f"\nPatch size: {patch_size} pixels x {patch_size} pixels"
)

# Create the grid of subplots (rows x columns = patches per column x patches per row).
fig, axs = plt.subplots(
    nrows=num_patches,
    ncols=num_patches,
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True
)

# Loop over the top-left pixel coordinate of every patch.
for i, patch_height in enumerate(range(0, img_size, patch_size)): # iterate through height
    for j, patch_width in enumerate(range(0, img_size, patch_size)): # iterate through width
        # Plot the permuted image patch
        axs[i,j].imshow(image_permuted[
            patch_height:patch_height+patch_size, # rows belonging to this patch
            patch_width:patch_width+patch_size, # columns belonging to this patch
            :]) # all color channels
        # 1-based row/column numbers; label_outer() keeps them only on the outer edge of the grid
        axs[i,j].set_ylabel(i+1, rotation='horizontal', horizontalalignment='right', verticalalignment='center')
        axs[i,j].set_xlabel(j+1)
        axs[i,j].set_xticks([])
        axs[i,j].set_yticks([])
        axs[i,j].label_outer()

# Set super title for overall plot
fig.suptitle(f"{class_names[label]} -> Patchified Bitches", fontsize=16)
plt.show()