# Add the project root to the import path

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

# Import Libraries

In [2]:
from Models.basic_components.multihead_attention import MultiHeadAttention as base_MHA
from Models.tensorized_components.multihead_attention import MultiHeadAttention as tensorized_MHA
from Models.basic_components.patch_embedding import PatchEmbedding as base_embedding
from Models.tensorized_components.patch_embedding import PatchEmbedding as tensorized_embedding
from Utils.Num_parameter import count_parameters

import torch
import torch.nn as nn
import time

# Dummy Data


In [3]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes end-to-end on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 16

# Seed torch's RNG so the dummy batch, the module initializations and the random
# targets used later are reproducible across Restart & Run All.
torch.manual_seed(0)

# Random image-like batch of shape (batch, 3, 224, 224) with values in [0, 1),
# allocated directly on the target device instead of being copied there.
dummy = torch.rand(batch_size, 3, 224, 224, device=device)
print(f'Current shape is : {dummy.shape}')

Current shape is : torch.Size([16, 3, 224, 224])


# MHA base method

1. Embed the dummy batch using the base patch-embedding method.
2. Apply the base multi-head attention (MHA) to the resulting patch embeddings.

In [4]:
# Time construction, device transfer and one forward pass of the base (flat)
# patch embedding. CUDA kernels run asynchronously, so synchronize before reading
# the clock; otherwise `elapsed` only measures kernel-launch overhead.
# NOTE(review): this is the first forward pass in the notebook, so the figure
# probably includes one-off GPU warm-up cost — TODO add a warm-up pass before
# comparing against the tensorized timing below.
st = time.perf_counter()
embedding = base_embedding(patch_size=16,
                           in_channels=3,
                           embed_dim=16 * 16 * 3).to(device)  # patch_size**2 * in_channels
embedded_base = embedding(dummy)
if embedded_base.is_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st
print(f'output shape of patch embedding : {embedded_base.shape}\nnumber of patches = {embedded_base.shape[1]}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 196, 768])
number of patches = 196
This embedding has 590592 parameters
This embedding took : 0.25943589210510254


In [5]:
# Show the names of every entry registered in the base embedding's state_dict.
for param_name in embedding.state_dict().keys():
    print(param_name)

projection.weight
projection.bias


In [6]:
# Time construction, device transfer and one forward pass of the base MHA over
# the flat patch embeddings. The embedding width is taken from the embedding
# output (768 in the run above) rather than repeated as a magic number.
# Synchronize before reading the clock because CUDA kernels run asynchronously.
st = time.perf_counter()
MHA = base_MHA(embed_dim=embedded_base.shape[-1],
               num_heads=4,
               out_embed=True).to(device)
result_mha = MHA(embedded_base)
if result_mha.is_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st
print(f'MHA output shape is : {result_mha.shape}')
print(f'Input of MHA shape was : {embedded_base.shape}')
print(f'This MHA has {count_parameters(MHA)} parameters')
print(f'This MHA took : {elapsed}')

MHA output shape is : torch.Size([16, 196, 768])
Input of MHA shape was : torch.Size([16, 196, 768])
This MHA has 2362368 parameters
This MHA took : 0.02088165283203125


In [7]:
# Show the names of every entry registered in the base MHA's state_dict.
for param_name in MHA.state_dict().keys():
    print(param_name)

query.weight
query.bias
key.weight
key.bias
value.weight
value.bias
fc_out.weight
fc_out.bias


In [8]:
# Smoke test: backpropagate a dummy cross-entropy loss through the base MHA to
# confirm the graph is differentiable. CrossEntropyLoss treats dim 1 of its input
# as the class dimension, so the number of classes is result_mha.shape[1] and the
# target has shape (result_mha.shape[0], *result_mha.shape[2:]).
# NOTE(review): re-running this cell without re-running the MHA cell fails, since
# the forward graph is freed by the first backward().
temp_loss = nn.CrossEntropyLoss()
num_classes = result_mha.shape[1]
temp_y = torch.randint(0, num_classes,
                       (result_mha.shape[0], *result_mha.shape[2:]),
                       device=result_mha.device)
l = temp_loss(result_mha, temp_y)
l.backward()
# Gradients must actually reach every trainable MHA parameter, not just "no error".
assert all(p.grad is not None for p in MHA.parameters() if p.requires_grad)

# MHA tensorized method

1. Embed the dummy batch using the tensorized patch-embedding method.
2. Apply the tensorized multi-head attention (MHA) to the resulting patch embeddings.

In [13]:
# Time construction, device transfer and one forward pass of the tensorized patch
# embedding. Its embedding dimension is given as the tuple (16, 16, 3) instead of
# the flat 16*16*3 used by the base method. Synchronize before reading the clock
# because CUDA kernels run asynchronously.
st = time.perf_counter()
embedding = tensorized_embedding(patch_size=16,
                                 in_channels=3,
                                 embed_dim=(16, 16, 3)).to(device)
embedded_tensorized = embedding(dummy)
if embedded_tensorized.is_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st
# Dims 1 and 2 are treated as the patch grid (height x width) — presumably the
# layout of this module's output; verify against the module if it changes.
num_patches_h, num_patches_w = embedded_tensorized.shape[1], embedded_tensorized.shape[2]
print(f'output shape of patch embedding : {embedded_tensorized.shape}\nnumber of patches = {num_patches_h} x {num_patches_w} = {num_patches_h * num_patches_w}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 14, 14, 16, 16, 3])
number pf patches = 14 x 14 = 196
This embedding has 521 parameters
This embedding took : 0.007994651794433594


In [14]:
# Show the names of every entry registered in the tensorized embedding's state_dict.
for param_name in embedding.state_dict().keys():
    print(param_name)

w_h
w_w
w_c


In [15]:
# Time construction, device transfer and one forward pass of the tensorized MHA.
# The factorized embedding dims are read from the trailing three axes of the
# embedding output ((16, 16, 3) in the run above) so they stay in sync with the
# embedding; num_heads presumably gives one head count per factor.
# Synchronize before reading the clock because CUDA kernels run asynchronously.
st = time.perf_counter()
MHA = tensorized_MHA(embed_dim=tuple(embedded_tensorized.shape[-3:]),
                     num_heads=(2, 2, 1),
                     out_embed=True).to(device)
result_mha = MHA(embedded_tensorized)
if result_mha.is_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st
print(f'MHA output shape is : {result_mha.shape}')
print(f'Input of MHA shape was : {embedded_tensorized.shape}')
print(f'This MHA has {count_parameters(MHA)} parameters')
print(f'This MHA took : {elapsed}')

torch.Size([16, 14, 14, 16, 16, 3])
MHA output shape is : torch.Size([16, 14, 14, 16, 16, 3])
Input of MHA shape was : torch.Size([16, 14, 14, 16, 16, 3])
This MHA has 2084 parameters
This MHA took : 0.12992429733276367


In [16]:
# Show the names of every entry registered in the tensorized MHA's state_dict.
for param_name in MHA.state_dict().keys():
    print(param_name)

w_e1_q
w_e2_q
w_e3_q
w_e1_k
w_e2_k
w_e3_k
w_e1_v
w_e2_v
w_e3_v
w_e1_out
w_e2_out
w_e3_out


In [18]:
# Smoke test: backpropagate a dummy cross-entropy loss through the tensorized MHA
# to confirm the graph is differentiable. CrossEntropyLoss treats dim 1 of its
# input as the class dimension, so the number of classes is result_mha.shape[1]
# (14 in the run above) and the target has shape
# (result_mha.shape[0], *result_mha.shape[2:]); deriving both from the output
# avoids hand-spelling every axis.
# NOTE(review): re-running this cell without re-running the MHA cell fails, since
# the forward graph is freed by the first backward().
temp_loss = nn.CrossEntropyLoss()
num_classes = result_mha.shape[1]
temp_y = torch.randint(0, num_classes,
                       (result_mha.shape[0], *result_mha.shape[2:]),
                       device=result_mha.device)
l = temp_loss(result_mha, temp_y)
l.backward()
# Gradients must actually reach every trainable MHA parameter, not just "no error".
assert all(p.grad is not None for p in MHA.parameters() if p.requires_grad)