<a href="https://colab.research.google.com/github/SpencerFonbuena/MentorCruise/blob/main/vision_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torchvision.transforms as transforms
from datetime import datetime as dt
import tracemalloc



## Resources Used

> #### Articles
>> ##### An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale | [Paper](https://arxiv.org/pdf/2010.11929.pdf) | [Code](https://github.com/google-research/vision_transformer) | [Colab](https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax.ipynb)
> #### Links
>> ##### Hugging Face Implementation | [Code](https://huggingface.co/transformers/v4.5.1/_modules/transformers/models/vit/modeling_vit.html)

## General Theory

> #### Questions
>> ##### Why do we have to create image patches? Why not just feed the whole image in?
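>>> ##### Each token in the transformer attends to every other token, so the cost of self-attention grows quadratically with sequence length. Treating every pixel of a typical 256x256 image as a token would give 65,536 tokens and roughly 4.3 billion pairwise attention scores per head per layer, which is not computationally feasible. Patches shrink the sequence to N = HW/P^2 tokens.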
> #### Mathematics
>> ##### N = HW/P^2 = # of patches
>> ##### H, W = height and width of the image
>> ##### P = height and width of each patch, so one patch has shape (P, P, C)
>> ##### C = # of channels
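
To make the quadratic-cost argument concrete, here is a minimal sketch that counts tokens and pairwise attention scores for a 256x256 image at a few patch sizes. The helper name `attention_cost` is hypothetical (it is not part of any library), and the count ignores heads, layers, and any class token.

```python
def attention_cost(height, width, patch_size):
    """Return (tokens, pairwise attention scores) for non-overlapping square patches.

    Hypothetical helper for illustration; assumes height and width are
    divisible by patch_size.
    """
    tokens = (height * width) // (patch_size ** 2)  # N = HW / P^2
    return tokens, tokens ** 2                      # every token attends to every token


# P=1 corresponds to treating every pixel as its own token.
for p in (1, 8, 16):
    tokens, scores = attention_cost(256, 256, p)
    print(f"P={p:>2}: {tokens:>6} tokens, {scores:>13,} attention scores")
```

Going from pixel tokens to 16x16 patches cuts the sequence length by a factor of 256 and the number of attention scores by a factor of 256^2.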

## Mathematics Examples

In [2]:
# [N = HW/P^2]
pre_pro_img = (256, 256, 3)  # (H, W, C)
H, W, C = pre_pro_img
P = 8  # => This is our patch size
N = (H * W) // (P**2)  # => 1024: This is the number of patches according to the formula.
# Assuming we take non-overlapping patches of an image, the patch grid is 256/8 x 256/8. 256/8 = 32
grid_h, grid_w = H // P, W // P  # A 32x32 grid of patches. If we understand this correctly, 32*32 should equal 1024
if N == grid_h * grid_w:
    print('True')

True


## Possible methods of creating the patches

In [4]:
# Reading in the image patches

#I used a local image. For uploading reasons, I've created a random tensor
'''rawimage = torchvision.io.read_image('/Users/spencerfonbuena/Desktop/Screenshot 2023-06-22 at 12.15.28 PM.png')
PIL_img = T.ToPILImage()(rawimage)
preimage = transforms.Resize(224)(PIL_img)
image = transforms.ToTensor()(preimage)
image = image.reshape(1,4,224,224)'''
image = torch.randn(1,4,224,224)

## My attempt at building one from scratch

> #### I attempted this before looking at an online implementation. I changed one thing after looking at other examples, which is outlined below.
> #### I checked with the real image to make sure that this does in fact split it into patches. The final permute and reshape, which fold the channel dimension into each patch's feature vector, were added after the fact to match the dimensions used by the other implementation. This is the same flattening the ViT paper describes, where each patch becomes a vector of length P^2 * C :).

In [5]:
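# For a fun, rough comparison of speed and memory allocation vs. the better implementation.
# Caveats: tracemalloc only traces allocations made through Python's own allocator, so tensor storage
# allocated by PyTorch is likely not counted. This cell also uses a batch of 1, while the online implementation uses a batch of 4.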
start = dt.now()
tracemalloc.start()

# Variable creation to generalize the code (assumes a square image whose side is divisible by P)
B, C, H, W = image.shape
P = 16
N = H // P  # number of patches along each side

# I came to this solution using the toy example below
reshape = torch.cat(image.reshape(B, C, H, N, P).unbind(2), dim=3).reshape(B, C, N, N, P**2).permute(0,1,3,2,4).reshape(B,C,N*N,P**2).permute(0,2,1,3).reshape(B,N*N,C*P**2)
# Folding the channels into each patch vector gives (B, N*N, C*P**2), which matches the N x (P^2 * C) patch sequence in the ViT paper.
# After the linear projection below, the shape lines up with the (4,196,512) output of the other example, apart from the batch size.

mlp = nn.Linear(C*P**2, 512)
patch = mlp(reshape)
print(patch.shape)
# continued comparison metrics
running_secs = (dt.now() - start)
print(running_secs)
current, peak =  tracemalloc.get_traced_memory()
print(f"{current:0.2f}, {peak:0.2f}")
tracemalloc.stop()

torch.Size([1, 196, 512])
0:00:00.076021
40099.00, 69516.00
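
As a cross-check on the chained reshape above, here is a sketch of a shorter way to get the same flattened patches with a single reshape and permute. The tensor sizes are illustrative assumptions, and like the original it assumes a square image whose side is divisible by P. The names `chained` and `direct` are introduced here only for the comparison.

```python
import torch

# Illustrative sizes (assumptions): square image, side divisible by P
B, C, H, W = 2, 4, 224, 224
P = 16
N = H // P  # patches per side
image = torch.randn(B, C, H, W)

# The chained version from the cell above
chained = (
    torch.cat(image.reshape(B, C, H, N, P).unbind(2), dim=3)
    .reshape(B, C, N, N, P**2)
    .permute(0, 1, 3, 2, 4)
    .reshape(B, C, N * N, P**2)
    .permute(0, 2, 1, 3)
    .reshape(B, N * N, C * P**2)
)

# Alternative: split H and W into (patch index, offset inside patch),
# then move both patch indices in front of the channel and offset axes.
direct = (
    image.reshape(B, C, N, P, N, P)   # (B, C, patch_row, dy, patch_col, dx)
    .permute(0, 2, 4, 1, 3, 5)        # (B, patch_row, patch_col, C, dy, dx)
    .reshape(B, N * N, C * P * P)     # one flattened vector per patch
)

print(direct.shape)
print(torch.equal(chained, direct))
```

Both versions order the patches from top left to bottom right and lay out each patch vector channel by channel, so they can be swapped without changing what the linear layer sees.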


## Online Implementation

> #### This image embedding, as I mentioned in our chat, doesn't just take pixels and box them. A convolution whose kernel size and stride both equal the patch size produces embed_dim feature maps, where each output position is computed from exactly one 16x16 patch. Flattening those maps gives one embed_dim-long vector per patch, and that sequence is fed into the transformer. It seems like in this way you get a rich representation, without having to attend to each pixel token.

In [6]:
#This is an example from ChatGPT. Similar examples can be found at the huggingface link

# Example usage
image_size = 224
patch_size = 16
in_channels = 4
embed_dim = 512

# Generate a random image tensor (these dimensions match the size of the actual image)
batch_size = 4
image = torch.randn(batch_size, in_channels, image_size, image_size)

start = dt.now()
tracemalloc.start()

class PatchEmbedding(nn.Module):

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.embed_dim = embed_dim

        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Input x has shape (batch_size, channels, height, width)

        # Project each patch to an embedding with a strided convolution (kernel_size == stride == patch_size)
        patches = self.projection(x)  # Output shape: (batch_size, embed_dim, grid_size, grid_size)
        patches = patches.flatten(2).transpose(1, 2)  # Output shape: (batch_size, grid_size^2, embed_dim)

        return patches




# Create an instance of the PatchEmbedding module
patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)

# Compute patch embeddings
patches = patch_embedding(image)
print(patches.shape)  # Output: (batch_size, grid_size^2, embed_dim)

running_secs = (dt.now() - start)
print(running_secs)
current, peak =  tracemalloc.get_traced_memory()
print(f"{current:0.2f}, {peak:0.2f}")
tracemalloc.stop()

torch.Size([4, 196, 512])
0:00:00.080698
30371.00, 40639.00
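
The two approaches are closer than they look. A convolution with `kernel_size == stride == patch_size` should compute the same thing as a linear layer applied to each flattened patch, provided both share the same weights. The sketch below checks this numerically. The sizes and the seed are illustrative assumptions, and `torch.nn.functional.unfold` is used here only as a convenient way to extract the flattened patches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, P, E = 2, 4, 224, 16, 512  # illustrative sizes (assumptions)
x = torch.randn(B, C, H, H)

conv = nn.Conv2d(C, E, kernel_size=P, stride=P)

# A linear layer that reuses the conv parameters:
# each (C, P, P) kernel is flattened into a row of length C*P*P.
linear = nn.Linear(C * P * P, E)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(E, C * P * P))
    linear.bias.copy_(conv.bias)

    # Conv route: (B, E, N, N) -> (B, N*N, E)
    conv_out = conv(x).flatten(2).transpose(1, 2)

    # Linear route: extract flattened patches (B, N*N, C*P*P), then project
    flat_patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)
    linear_out = linear(flat_patches)

max_diff = (conv_out - linear_out).abs().max().item()
print(conv_out.shape, max_diff)  # max_diff should be at floating-point noise level
```

So the from-scratch version (flatten patches, then `nn.Linear`) and the convolutional version differ in how the patches are gathered, not in what is learned.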


## Toy Example for Intuition

> #### toy.reshape(10,5,2)
>> #### In a 10 by 10 image with patch size = 2, there will be 25 patches. If each position in the 2D matrix is numbered, the first patch consists of positions (0, 1, 10, 11), the second of (2, 3, 12, 13), and so on. Each flattened 2x2 patch therefore has the form (x, x+1, x+10, x+11). Reshaping to (10,5,2) splits every image row into 5 pairs (x, x+1). The pair at the same position in the next row is (x+10, x+11), so the two halves of each patch now sit at the same index in consecutive rows, ready to be paired up into our (0, 1, 10, 11).
> #### toy = torch.cat(toy.unbind(0), dim=-1)
>> #### Now that the pairs are lined up, we unbind along the row dimension to get 10 tensors of shape (5,2), and then concatenate them along the last dimension. Each of the 5 resulting rows holds one 2-pixel-wide column strip of the image, read from top to bottom.
> #### toy = toy.reshape(5,5,4)
>> #### The output of our last operation contains our desired (x, x+1, x+10, x+11) groups, but it has shape (5,20), so each row lines up five of our desired blocks end to end. To get each block on its own row, we reshape it. Now we've got our desired patch size of 4 in the last dimension.
> #### toy = toy.permute(1,0,2)
>> #### The blocks need to be in sequential order. At this point the first dimension indexes the patch column and the second the patch row, so this permute swaps them to put the patches in order, going from top left to bottom right.
> #### toy = toy.reshape(25,4)
>> #### Finally, we reshape it so that we have our desired 25 flattened 2x2 patches.
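
As a quick sanity check of the steps above, this sketch runs the whole chain in one line and verifies that every patch has the (x, x+1, x+10, x+11) form. The loop and the names `row`, `col`, and `expected` are added here for illustration only.

```python
import torch

toy = torch.arange(100).reshape(10, 10)
patches = (
    torch.cat(toy.reshape(10, 5, 2).unbind(0), dim=-1)
    .reshape(5, 5, 4)
    .permute(1, 0, 2)
    .reshape(25, 4)
)

# Patch k sits at patch-row k // 5 and patch-column k % 5,
# so its top-left pixel is x = 20 * row + 2 * col.
for k in range(25):
    row, col = divmod(k, 5)
    x = 20 * row + 2 * col
    expected = torch.tensor([x, x + 1, x + 10, x + 11])
    assert torch.equal(patches[k], expected), f"patch {k} is out of order"

print(patches[0].tolist(), patches[1].tolist(), patches[5].tolist())
# [0, 1, 10, 11] [2, 3, 12, 13] [20, 21, 30, 31]
```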

In [7]:
toy = torch.arange(100).reshape(10,10)
# img = torch.cat(toy.reshape(10,5,2).unbind(0), dim=-1).reshape(5,5,4).permute(1,0,2).reshape(25,4) |=> One-line version of the stretched-out steps below.
print(toy)
toy = toy.reshape(10,5,2)
print(toy)
toy = torch.cat(toy.unbind(0), dim=-1)
print(toy)
toy = toy.reshape(5,5,4)
print(toy)
toy = toy.permute(1,0,2)
print(toy)
toy = toy.reshape(25,4)
print(toy)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
        [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7],
         [ 8,  9]],

        [[10, 11],
         [12, 13],
         [14, 15],
         [16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23],
         [24, 25],
         [26, 27],
         [28, 29]],

        [[30, 31],
         [32, 33],
         [34, 35],
         [36, 37],
         [38, 39]],

        [[40, 41],
         [42, 43],
         [44, 45],
         [46, 47],
         [48, 49]],

        [[50, 

In [8]:
toy = torch.arange(100).reshape(10,10)
# img = torch.cat(toy.reshape(10,5,2).unbind(0), dim=-1).reshape(5,5,4).permute(1,0,2).reshape(25,4) |=> One-line version of the stretched-out steps below.
print(toy)
toy = toy.reshape(10,5,2)
print(toy)
toy = torch.cat(toy.unbind(0), dim=-1)
print(toy)
toy = toy.reshape(5,5,4)
'''print(toy)
toy = toy.permute(1,0,2)
print(toy)
toy = toy.reshape(25,4)
print(toy)'''


## Visual of Patch Embedding Process

![vit_figure.png](https://raw.githubusercontent.com/google-research/vision_transformer/main/vit_figure.png)