In [1]:
#using_PyTorch
import os
import copy 
import math
import typing
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn

In [3]:
# Alias for an optional float argument: the value may be None or a float.
NoneFloat = typing.Union[None, float]

In [4]:
#Attention_in_General

class Attention(nn.Module):
    """Multi-head self-attention whose residual connection runs over the values.

    Input tokens of length ``dim`` are linearly projected to queries, keys and
    values of total width ``chan``, split evenly across ``num_heads`` heads.
    The per-head attention outputs are merged, passed through a ``chan -> chan``
    linear projection and added to the (head-merged) values. The output token
    length is therefore ``chan`` rather than ``dim``.

    Args:
        dim: length of each input token (last dimension of the input).
        chan: total width of q/k/v and of the output tokens.
        num_heads: number of attention heads; must be positive and divide ``chan``.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        qk_scale: factor applied to the queries before the q.k dot product.
            ``None`` selects the usual ``head_dim ** -0.5``.

    Raises:
        AssertionError: if ``num_heads`` is not positive or does not divide ``chan``.
    """
    def __init__(self, dim:int, chan:int, num_heads:int=1, qkv_bias:bool=False, qk_scale:typing.Optional[float]=None):
        super().__init__()
        # Validate before dividing, so a bad head count fails with a clear
        # message instead of a ZeroDivisionError or a silently truncated head_dim.
        assert num_heads > 0 and chan % num_heads == 0, (
            f'chan ({chan}) must be divisible by a positive num_heads ({num_heads})')
        self.num_heads=num_heads
        self.chan=chan
        self.head_dim=self.chan//self.num_heads
        # `is None` rather than `or`: an explicit qk_scale of 0.0 must be honoured.
        self.scale=self.head_dim**-0.5 if qk_scale is None else qk_scale

        #Define_layers
        # One fused projection produces q, k and v (chan each) from a dim-long token.
        self.qkv=nn.Linear(dim, chan*3, bias=qkv_bias)
        self.proj=nn.Linear(chan, chan)

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, num_tokens, dim).

        Returns:
            Tensor of shape (batch, num_tokens, chan): projected attention output
            plus the head-merged values.
        """
        B, N, _=x.shape #(Dim:(Batch, num_tokens, token_len))
        # (B, N, 3*chan) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv=self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v=qkv[0], qkv[1], qkv[2]

        #calculate_attention
        # (B, heads, N, N) scores; softmax over the key axis.
        attn=(q*self.scale) @ k.transpose(-2, -1)
        attn=attn.softmax(dim=-1)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, chan).
        x=(attn @ v).transpose(1, 2).reshape(B, N, self.chan)
        x=self.proj(x)
        # Skip connection over v (not the input), because the input width is
        # dim while the output width is chan.
        v=v.transpose(1, 2).reshape(B, N, self.chan)
        x=v+x
        return x

In [5]:
#Single_headed_Attention
#Define an input batch of random tokens (token length 7*7 = 49)
token_len = 7 * 7
channels = 64
num_tokens = 100
batch = 13
x = torch.rand((batch, num_tokens, token_len))
B, N, C = x.shape
print('Input Dimensions are\n\tbatchsize:', x.size(0),
      '\n\tnumber of tokens:', x.size(1),
      '\n\ttoken size:', x.size(2))

# A single head projecting 49-long tokens to 64 channels; the trailing ';'
# keeps the notebook from echoing the module repr returned by eval().
A = Attention(dim=token_len, chan=channels, num_heads=1, qkv_bias=False, qk_scale=None)
A.eval();

Input Dimensions are
	batchsize: 13 
	number of tokens: 100 
	token size: 49


In [6]:
# Split the fused projection into (3, B, heads, N, head_dim) and take q, k, v.
qkv=A.qkv(x).reshape(B, N, 3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)
q, k, v=qkv[0], qkv[1], qkv[2]
# Every label starts with '\n\t' (newline + tab) so the fields line up under the
# heading; the last axis is the per-head token length (head_dim), not a token count.
print('Dimensions for queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tlength of tokens:', q.shape[3])
print('See that the dimensions for queries, keys & values are same:')
print('\tShape of Q:', q.shape, '\n\tshape of k:', k.shape, '\n\tshape of v:', v.shape)

Dimensions for queries are
	batchsize: 13 
ttention heads: 1 

umber of tokens: 100 

umber of tokens: 64
See that the dimensions for queries, keys & values are same:
	Shape of Q: torch.Size([13, 1, 100, 64]) 
	shape of k: torch.Size([13, 1, 100, 64]) 
	shape of v: torch.Size([13, 1, 100, 64])


In [7]:
# Scaled query-key dot products: one (num_tokens x num_tokens) score matrix per head.
attn=(q*A.scale) @ k.transpose(-2, -1)
print('Dimensions of attention are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

Dimensions of attention are
	batchsize: 13 
	attention heads: 1 
	number of tokens: 100 
	number of tokens 100


In [8]:
#calculate_softmax_of_A_which_does_not_change_its_shape
# Normalise over the key axis so each query's weights sum to 1.
attn=attn.softmax(dim=-1)
print('Dimensions for attention are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

Dimensions for atten are
	batchsize: 13 
	attention heads: 1 
	number of tokens: 100 
	number of tokens: 100


In [9]:
# Weight the values by the attention scores: (B, heads, N, N) @ (B, heads, N, head_dim).
x=attn @ v
print('Dimensions of x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])

Dimensions of x are
	batchsize: 13 
	attention heads: 1 
	number of tokens 100 
	length of tokens 64


In [10]:
#output_x_is_reshaped: move the head axis next to the channel axis, then fold heads into channels
x = x.permute(0, 2, 1, 3).reshape(B, N, A.chan)
print('Dimensions for x are\n\tbatchsize:', x.size(0),
      '\n\tnumber of tokens:', x.size(1),
      '\n\tlength of tokens:', x.size(2))

Dimensions for x are
	batchsize: 13 
	number of tokens: 100 
	length of tokens: 64


In [11]:
# Output projection (chan -> chan); the shape is unchanged.
x = A.proj(x)
print('Dimensions of x are\n\tbatch size:', x.size(0),
      '\n\tnumber of tokens:', x.size(1),
      '\n\tlength of tokens:', x.size(2))

Dimensions of x are
	batch size: 13 
	number of tokens: 100 
	length of tokens: 64


In [12]:
# Compare shapes before the skip connection: x now has chan channels, not token_len.
orig_shape = (batch, num_tokens, token_len)
curr_shape = tuple(x.shape)
# Fold v's heads into channels so it matches x, then add it as the residual.
v = v.permute(0, 2, 1, 3).reshape(B, N, A.chan)
v_shape = tuple(v.shape)
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x
print('After Skip connections, dimensions of x are\n\tbatchsize:', x.size(0),
      '\n\tnumber of tokens:', x.size(1),
      '\n\tlength of tokens:', x.size(2))
print('After Skip connections, dimensions of x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

Original shape of input x: (13, 100, 49)
Current shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After Skip connections, dimensions of x are
	batchsize: 13 
	number of tokens: 100 
	length of tokens: 64


In [13]:
#Multi-headed Self Attention
#use_4_attention_heads
#define_input (same sizes as the single-head section, plus a head count)
token_len = 7 * 7
channels = 64
num_tokens = 100
batch = 13
num_heads = 4
x = torch.rand((batch, num_tokens, token_len))
B, N, C = x.shape
print('Input Dimensions are\n\tbatchsize:', x.size(0),
      '\n\tnumber of tokens:', x.size(1),
      '\n\ttoken size:', x.size(2))
#Define_the_Module; ';' suppresses the module repr returned by eval()
MSA = Attention(dim=token_len, chan=channels, num_heads=num_heads, qkv_bias=False, qk_scale=None)
MSA.eval();

Input Dimensions are
	batchsize: 13 
	number of tokens: 100 
	token size: 49


In [14]:
# Fused projection -> (3, B, heads, N, head_dim); unbind the leading axis into q, k, v.
qkv = MSA.qkv(x).reshape(B, N, 3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
print('Head Dimension=chan/num_heads=', MSA.chan, '/', MSA.num_heads, '=', MSA.head_dim)
print('Dimensions for Queries are\n\tbatchsize:', q.size(0),
      '\n\tattention heads:', q.size(1),
      '\n\tnumber of tokens:', q.size(2),
      '\n\tnew length of tokens:', q.size(3))
print('See that the dimensions for queries, keys and values are all the same:')
print('\tshape of Q:', q.shape, '\n\tshape of K:', k.shape, '\n\tshape of V:', v.shape)

Head Dimension=chan/num_heads= 64 / 4 = 16
Dimensions for Queries are
	batchsize: 13 
	attention heads: 4 
	number of tokens: 100 
	new length of tokens: 16
See that the dimensions for queries, keys and values are all the same:
	shape of Q: torch.Size([13, 4, 100, 16]) 
	shape of K: torch.Size([13, 4, 100, 16]) 
	shape of V: torch.Size([13, 4, 100, 16])


In [16]:
#Query-Key_multiplication: scaled dot products, one (N x N) matrix per head
attn = torch.matmul(q * MSA.scale, k.transpose(-2, -1))
print('Dimension of attention are\n\tbatchsize:', attn.size(0),
      '\n\tattention heads:', attn.size(1),
      '\n\tnumber of tokens:', attn.size(2),
      '\n\tnumber of tokens:', attn.size(3))

Dimension of attention are
	batchsize: 13 
	attention heads: 4 
	number of tokens: 100 
	number of tokens: 100


In [17]:
# Normalise the scores over keys, then use them to weight the values per head.
attn = torch.softmax(attn, dim=-1)
x = torch.matmul(attn, v)
print('Dimension of x are\n\tbatchsize:', x.size(0),
      '\n\tattention heads:', x.size(1),
      '\n\tnumber of tokens:', x.size(2),
      '\n\tlength of tokens:', x.size(3))

Dimension of x are
	batchsize: 13 
	attention heads: 4 
	number of tokens: 100 
	length of tokens: 16


In [18]:
#concatenate_all_xi's: (B, heads, N, head_dim) -> (B, N, heads*head_dim = chan)
x=x.transpose(1, 2).reshape(B, N, MSA.chan)
print('Dimensions of x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

DImensions of x are
	batchsize: 13 
	number of tokens: 100 
	length of tokens: 64


In [19]:
# Output projection (chan -> chan); the last axis is the token length, not a token count.
x=MSA.proj(x)
print('Dimension of x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
orig_shape=(batch, num_tokens, token_len)
curr_shape=tuple(x.shape)
# Merge v's heads using MSA's own channel count. The single-head module `A` from
# the earlier section is unrelated here (it only happened to share chan=64).
v=v.transpose(1, 2).reshape(B, N, MSA.chan)
v_shape=tuple(v.shape)
print('Original Shape of input x:', orig_shape)
print('Current Shape of x:', curr_shape)
print('Shape of V:', v_shape)
# Skip connection over the values (input width token_len != output width chan).
x=v+x
print('After Skip connections, dimensions of x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

Dimension of x are
	batchsize: 13 
	number of tokens: 100 
	number of tokens: 64
Original Shape of input x: (13, 100, 49)
Current Shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After Skip connections, dimensions of x are
	batchsize: 13 
	number of tokens: 100 
	length of tokens: 64
