In [None]:
import matplotlib.pyplot as plt
import torch 
import torchvision

In [None]:
from torch import nn
from torchvision import transforms
from torchinfo import summary

In [None]:
import data_setup
import engine
from helper_functions import download_data, set_seeds, plot_loss_curves

In [None]:
# Device-agnostic setup: prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
# Fetch the pizza_steak_sushi image dataset.
# Per the original intent, the download should only happen if the data isn't already present
# (that behaviour lives in helper_functions.download_data).
image_path = download_data(
    destination="pizza_steak_sushi",
    source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
)
image_path


In [None]:
# The train and test splits live in sub-folders of the downloaded dataset directory
train_dir, test_dir = image_path / "train", image_path / "test"

In [None]:
# Every image is resized to a square IMG_SIZE x IMG_SIZE and converted to a tensor
# before the dataloaders batch it.
IMG_SIZE = 224

manual_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)

print(f'Image transformer created: {manual_transforms}')

In [None]:
# The ViT paper trains with a batch size of 4096. That needs a capable GPU, so use a
# small batch size of 24 when running on the CPU and the paper's value otherwise.
if device == "cpu":
    BATCH_SIZE = 24
else:
    BATCH_SIZE = 4096

# Build the train/test dataloaders with the transform pipeline defined above
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms,
    batch_size=BATCH_SIZE,
)

train_dataloader, test_dataloader, class_names

In [None]:
# Sanity-check the dataloaders: pull one batch from the training set,
# then take its first image and the matching label.
image_batch, label_batch = next(iter(train_dataloader))
image = image_batch[0]
label = label_batch[0]

# Inspect the image tensor's shape and its label
image.shape, label

In [None]:
# Shapes alone don't show the picture, so plot it with matplotlib.
# The tensor is channels-first [color_channels, height, width]; matplotlib expects
# channels-last [height, width, color_channels], hence the permute.
plt.imshow(image.permute(1, 2, 0))
plt.title(class_names[label])
plt.axis(False)

Mmm, good-looking pizza... or steak... or sushi!

## Now we are ready to actually replicate the paper

In [None]:
# Calculate the patch embedding input and output shapes for our training resolution.
# Reuse IMG_SIZE from the transform pipeline instead of repeating the literal 224,
# so changing the resolution in one place keeps these numbers consistent.
height = IMG_SIZE  # H
width = IMG_SIZE  # W
color_channels = 3  # C
patch_size = 16  # P - taken from column ViT-B/16 from table 5 in the ViT paper

# N = (H // P) * (W // P): count whole patches along each axis.
# This equals H*W / P^2 when both sides divide evenly. When they don't, it still matches
# what a kernel=stride=P convolution produces; int(H*W / P^2) would overcount.
number_of_patches = (height // patch_size) * (width // patch_size)
print(f"Number of patches N w/ image height (H = {height}), width (W = {width}) and patch size (P = {patch_size}) is (N = {number_of_patches})")

In [None]:
# Replicate the input and output shapes of the patch embedding layer:
#   input  - a single 2D image of size (H x W x C)
#   output - a sequence of N flattened 2D patches of size (N x (P^2 * C))
embedding_layer_input_shape = (height, width, color_channels)
embedding_layer_output_shape = (number_of_patches, color_channels * patch_size**2)

print(f"Input Shape (single 2D image): {embedding_layer_input_shape}")
print(f"Output Shape (single 2d Image flattened into patches): {embedding_layer_output_shape}")

In [None]:
# Now convert an image into patches.
# First switch PyTorch's channels-first layout to matplotlib's channels-last:
# (color_channels, height, width) -> (height, width, color_channels)
image_permuted = image.permute(1, 2, 0)

# Plot the top row of patches: the first `patch_size` rows of pixels across the full width.
# patch_size is reused from the earlier cell rather than re-declared here. Re-declaring it
# would silently override any change made to the original definition.
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(image_permuted[:patch_size, :, :])

In [None]:
# Zoom in on a single patch: the top-left patch_size x patch_size block of pixels (all channels)
plt.figure(figsize=(patch_size,) * 2)
plt.imshow(image_permuted[:patch_size, :patch_size])

In [None]:
# Turn the top row of the image into individual patches.
# Reuse IMG_SIZE from the transform pipeline and the patch_size defined earlier,
# and validate compatibility before deriving anything from them.
img_size = IMG_SIZE
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
num_patches = img_size // patch_size  # patches per row, as an int (`/` produced 14.0)
print(f"Number of patches per row: {num_patches} \n Patch size: {patch_size} pixels x {patch_size} pixels")

# One subplot per patch. squeeze=False keeps a 2D array of Axes even when there is
# only a single patch, so indexing below never breaks.
fig, axs = plt.subplots(
    nrows=1,
    ncols=num_patches,  # one column per patch
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True,
    squeeze=False,
)
axs = axs[0]  # the single row of Axes

# Walk across the top row: keep the height slice fixed and slide the width slice
for i, patch in enumerate(range(0, img_size, patch_size)):
    axs[i].imshow(image_permuted[:patch_size, patch:patch + patch_size, :])
    axs[i].set_xlabel(i + 1)  # patch number
    axs[i].set_xticks([])
    axs[i].set_yticks([])

In [None]:
# Expand the top-row view to the full image: a grid of patches, img_size // patch_size per side.
# img_size comes from the previous cell and patch_size from earlier. The per-side count is
# recomputed here (as an int) instead of relying on the previous cell's value.
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
num_patches = img_size // patch_size  # patches per row and per column
# Implicit string concatenation avoids the stray spaces a backslash continuation
# inside the f-string would embed before each newline.
print(
    f"Number of patches per row: {num_patches}"
    f"\nNumber of patches per column: {num_patches}"
    f"\nTotal patches: {num_patches * num_patches}"
    f"\nPatch size: {patch_size} pixels x {patch_size} pixels"
)

# A num_patches x num_patches grid of subplots. squeeze=False keeps axs 2D so
# axs[i, j] also works for a 1x1 grid.
fig, axs = plt.subplots(
    nrows=num_patches,
    ncols=num_patches,
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True,
    squeeze=False,
)

# Outer loop walks down the image (height), inner loop walks across it (width)
for i, patch_height in enumerate(range(0, img_size, patch_size)):
    for j, patch_width in enumerate(range(0, img_size, patch_size)):
        ax = axs[i, j]
        ax.imshow(image_permuted[
            patch_height:patch_height + patch_size,  # rows of this patch
            patch_width:patch_width + patch_size,  # columns of this patch
            :])  # all color channels
        ax.set_ylabel(i + 1, rotation="horizontal", horizontalalignment="right", verticalalignment="center")
        ax.set_xlabel(j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.label_outer()  # only show row/column labels on the outer edge of the grid

# Super title for the whole grid
fig.suptitle(f"{class_names[label]} -> Patchified", fontsize=16)
plt.show()

In [None]:
# Patchifying every image by hand would take forever. Instead, use a Conv2d whose
# kernel_size and stride both equal patch_size (no padding). Each kernel application
# then covers exactly one non-overlapping patch and projects it to a 768-dim vector
# (hidden size D, per Table 1 of the ViT paper). A Flatten layer will later
# collapse the spatial dimensions of the resulting feature maps.
conv2d = nn.Conv2d(
    3,  # in_channels: number of color channels
    768,  # out_channels: embedding size D
    kernel_size=(patch_size, patch_size),
    stride=(patch_size, patch_size),
    padding=0,
)

In [None]:
# Look at the original image again before passing it through the conv layer.
# The permute converts channels-first to the channels-last layout matplotlib expects.
plt.imshow(image.permute(1, 2, 0).numpy())
plt.gca().set_title(class_names[label])
plt.axis("off")

In [None]:
# Pass the image through the patchifying conv layer.
# Conv2d is given a batch, so unsqueeze(0) turns [color_channels, height, width]
# into [1, color_channels, height, width].
image_out_of_conv = conv2d(image.unsqueeze(0))
print(image_out_of_conv.shape)
# Expected: torch.Size([1, 768, 14, 14])
#   -> [batch_size, embedding_dim, feature_map_height, feature_map_width]
# (Kept as comments: a trailing bare string would be echoed as the cell's output.)

In [None]:
# Peek at five feature maps picked at random from the 768 the conv layer produces.
# The RNG isn't seeded here, so the selection can change between runs.
import random

random_indexes = random.sample(range(768), k=5)
print(f"Showing random convolutional feature maps from indexes: {random_indexes}")

fig, axs = plt.subplots(1, 5, figsize=(12, 12))

for i, idx in enumerate(random_indexes):
    # Select channel `idx` (keeping the batch dim), drop the size-1 dims and detach
    # from autograd so matplotlib can consume it as a NumPy array
    image_conv_feature_map = image_out_of_conv[:, idx]
    axs[i].imshow(image_conv_feature_map.squeeze().detach().numpy())
    axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Visually these look like shrunken, filtered versions of the whole image. They are
# meant to capture the image's major features and will change as the network learns.
# Inspect one of them numerically (channel 0) together with its autograd flag.
single_feature_map = image_out_of_conv[:, 0]
(single_feature_map, single_feature_map.requires_grad)

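The `grad_fn` shown in the output of `single_feature_map`, together with its `requires_grad=True` attribute, means PyTorch is tracking gradients through this feature map. It was produced by the `conv2d` layer, whose weights and bias are learnable parameters. During training, gradient descent updates those weights, and with them the feature maps the layer produces.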



In [None]:

# The image has been turned into patch embeddings, but they're still laid out as a 2D
# grid of feature maps. Next we flatten them; first, check the current shape again.
print(f"Current tensor shape: {image_out_of_conv.shape} -> [batch, embedding_dim, feature_map_height, feature_map_width]")

In [None]:
# The 768 part (P^2 * C) is already in place; what's still missing is the number of patches (N).
# Get it by flattening only the spatial dimensions of the feature map, feature_map_height
# (dim 2) and feature_map_width (dim 3), rather than the whole tensor.
# nn.Flatten's start_dim/end_dim arguments select exactly that range.
flatten = nn.Flatten(start_dim=2, end_dim=3)

In [None]:
# Put it all together:
#   Step 1: take a single image
#   Step 2: put it through the convolutional layer (conv2d) to turn the image into
#           2D feature maps (patch embeddings)
#   Step 3: flatten the 2D feature maps into a single sequence

# Step 1: take the single image and view it
plt.imshow(image.permute(1, 2, 0))  # channels-last for matplotlib
plt.title(class_names[label])
plt.axis(False)
print(f"Original image shape: {image.shape}")

# Step 2: turn it into feature maps (add a batch dimension to avoid shape errors)
image_out_of_conv = conv2d(image.unsqueeze(0))
print(f"Image feature map shape: {image_out_of_conv.shape}")

# Step 3: flatten the feature maps
image_out_of_conv_flattened = flatten(image_out_of_conv)
print(f"Flattened image feature map shape: {image_out_of_conv_flattened.shape}")

In [None]:
# Target from the ViT paper (flattened 2D patches): (N, P^2 * C) = (196, 768)
# Current shape:                                    (1, 768, 196)
# Apart from the extra batch dimension, the last two axes are in the opposite order,
# so swap them with Tensor.permute: [batch_size, P^2•C, N] -> [batch_size, N, P^2•C]
image_out_of_conv_flattened_reshaped = image_out_of_conv_flattened.permute(0, 2, 1)
print(f"Patch embedding sequence shape: {image_out_of_conv_flattened_reshaped.shape} -> [batch_size, num_patches, embedding_size]")


In [None]:
# The patch embedding layer's input and output shapes now match the ViT architecture,
# using just two PyTorch layers (Conv2d + Flatten).
# Visualize one flattened feature map: embedding index 0 across every patch,
# i.e. a slice indexed as (batch_size, num_patches, embedding_dimension)[..., 0].
single_flattened_feature_map = image_out_of_conv_flattened_reshaped[..., 0]

plt.figure(figsize=(22, 22))
plt.imshow(single_flattened_feature_map.detach().numpy())
plt.title(f"Flattened feature map shape: {single_flattened_feature_map.shape}")
plt.axis(False)

In [None]:
# Rendered as an image this looks odd. The original Transformer was designed for text,
# and ViT's goal was to reuse it for images by feeding it sequences like this one.
# Inspect the same slice as a tensor: values, autograd flag and shape.
(
    single_flattened_feature_map,
    single_flattened_feature_map.requires_grad,
    single_flattened_feature_map.shape,
)

In [None]:
# Now combine these steps into a single PyTorch layer: subclass nn.Module and build a
# small PyTorch "model" that performs all of the steps above.
# (Written as comments: a trailing bare string would be echoed as the cell's output.)
# Specifically we'll:
#   1. Create a class called PatchEmbedding which subclasses nn.Module
#   2. Initialize it with parameters in_channels=3, patch_size=16 and embedding_dim=768
#   3. Create a layer to turn an image into patches using nn.Conv2d
#   4. Create a layer to flatten the patch feature maps into a single dimension
#   5. Define a forward() method that takes an input and passes it through the layers from 3 & 4
#   6. Make sure the output shape matches the required output shape of the ViT architecture

In [None]:
# Step 1: create a class
class PatchEmbedding(nn.Module):
    """Turn a batch of 2D images into a sequence of learnable patch embeddings.

    A Conv2d with kernel_size == stride == patch_size splits each image into
    non-overlapping patch_size x patch_size patches and projects each patch to an
    embedding_dim vector; the grid of patches is then flattened into a sequence.

    Args:
        in_channels (int): Number of color channels of the input images. Defaults to 3.
        patch_size (int): Height and width of each square patch. Defaults to 16.
        embedding_dim (int): Size of each patch embedding. Defaults to 768.

    Shape:
        Input: [batch_size, in_channels, height, width]; both height and width must
            be divisible by patch_size.
        Output: [batch_size, num_patches, embedding_dim], where
            num_patches = (height // patch_size) * (width // patch_size).

    Raises:
        AssertionError: from forward() if height or width is not divisible by patch_size.
    """

    # Step 2: initialize the class with the appropriate variables
    def __init__(self, in_channels: int = 3, patch_size: int = 16, embedding_dim: int = 768):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim

        # Step 3: layer that turns an image into patch embeddings.
        # kernel_size == stride means each kernel application sees exactly one patch.
        self.patcher = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embedding_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )

        # Step 4: flatten only the spatial dims (feature_map_height, feature_map_width)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    # Step 5: define the forward method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Check BOTH spatial dims. A stride-patch_size conv silently drops the leftover
        # rows/columns of a non-divisible side, so checking only the width would let a
        # mismatched height through unnoticed.
        height, width = x.shape[-2], x.shape[-1]
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size \n"
            f" image shape: {(height, width)}, patch_size: {self.patch_size}"
        )

        # Forward pass: [B, C, H, W] -> [B, embedding_dim, H/P, W/P] -> [B, embedding_dim, N]
        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)

        # Step 6: put the embedding last -> [batch_size, N, embedding_dim]
        return x_flattened.permute(0, 2, 1)


In [None]:
# Let's test it out. set_seeds (from helper_functions) is called first, presumably so the
# layer's randomly initialised conv weights are reproducible.
set_seeds()

# Create an instance of the patch embedding layer, reusing the earlier patch_size
patchify = PatchEmbedding(in_channels=3, patch_size=patch_size, embedding_dim=768)

# Pass an image through; unsqueeze(0) adds the batch dimension at index 0, otherwise errors occur
print(f"Input image shape: {image.unsqueeze(0).shape}")
patch_embedded_image = patchify(image.unsqueeze(0))
print(f"Output patch embedding shape: {patch_embedded_image.shape}")

In [None]:
# Kewl.
# We've now replicated the patch embedding from Equation 1 of the ViT paper, but not yet
# the class token or the position embedding.
# We'll get there eventually: have patience and keep working at it. All good things in their time.

# First, let's get a summary of our PatchEmbedding layer.

In [None]:
# Input sizes for torchinfo's summary: a valid 224x224 image, plus a 250x250 one that
# is not divisible by patch_size=16 and so should trip PatchEmbedding's assertion.
rando_input_image = (1, 3, 224, 224)
rando_input_image_error = (1, 3, 250, 250)

summary(
    PatchEmbedding(),
    input_size=rando_input_image,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Now try the incompatible 250x250 input. PatchEmbedding.forward asserts that the image
# size is divisible by the patch size, so this is expected to fail. Catch the failure and
# display it instead of letting it abort "Run All". torchinfo may re-raise the
# AssertionError wrapped in a RuntimeError, so both are caught and any chained cause is shown.
try:
    summary(
        PatchEmbedding(),
        input_size=rando_input_image_error,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"],
    )
except (AssertionError, RuntimeError) as err:
    print(f"Expected failure for input size {rando_input_image_error}: {err!r}")
    if err.__cause__ is not None:
        print(f"Caused by: {err.__cause__!r}")

## Creating the class token embedding

In [None]:
# With the image patch embedding done, it's time to work on the class token embedding.
# Start by looking at the patch embedding values and their shape again.
print(patch_embedded_image)
print(f"Patch embedding shape: {patch_embedded_image.shape} -> [batch_size, number_of_patches, embedding_dimension]")

In [None]:
# To prepend a learnable embedding to the sequence of embedded patches, create a learnable
# embedding the size of the embedding dimension (D) and concatenate it in front of the
# patches along the number_of_patches dimension. In pseudo-code:
#   patch_embedding = [image_patch_1, image_patch_2, image_patch_3, ...]
#   class_token = learnable_embedding
#   patch_embedding_with_class_token = torch.cat((class_token, patch_embedding), dim=1)

# Batch size and embedding dimension are read off the patch embedding itself
batch_size, embedding_dimension = patch_embedded_image.shape[0], patch_embedded_image.shape[-1]

# Learnable class token, one per batch element: [batch_size, num_tokens, embedding_dimension]
class_token = nn.Parameter(torch.ones(batch_size, 1, embedding_dimension), requires_grad=True)

# Show the first 10 values of the class token
print(class_token[:, :, :10])

# Print the class_token shape
print(f"Class token shape: {class_token.shape} -> [batch_size, num_tokens, embedding_dimension]")

In [None]:
# Prepend the class token to the patch embeddings along dim 1 (the token/sequence dimension)
patch_embedded_image_w_class_embedding = torch.cat([class_token, patch_embedded_image], dim=1)

# Show the embedding sequence with the class token in front, followed by its shape
print(patch_embedded_image_w_class_embedding)
print(f"Sequence of patch embeddings with class token prepended shape: {patch_embedded_image_w_class_embedding.shape} -> [batch_size, num_patches, embedding_dimension]")