# Add to the path

In [1]:
%load_ext autoreload
%autoreload 2

import sys
# Append the parent directory (relative to the current working directory) to
# the module search path. Presumably this is what lets the `Models` and `Utils`
# packages imported below resolve -- assumes the notebook is run from a
# sub-folder of the project root; confirm.
sys.path.append('..')

# Import Libraries

In [2]:
from Models.basic_components.multihead_attention import MultiHeadAttention as base_MHA
from Models.tensorized_components.multihead_attention import MultiHeadAttention as tensorized_MHA
from Models.basic_components.patch_embedding import PatchEmbedding as base_embedding
from Models.tensorized_components.patch_embedding import PatchEmbedding as tensorized_embedding
from Utils.Num_parameter import count_parameters

import torch
import torch.nn as nn
from torch import optim
import time

# Dummy Data


In [3]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 16

# Fix the seed so the random batch is identical after every kernel restart.
torch.manual_seed(0)

# Uniform random values in [0, 1) with shape (batch_size, 3, 224, 224),
# allocated directly on the target device instead of being copied from the CPU.
dummy = torch.rand(batch_size, 3, 224, 224, device=device)
print(f'Current shape is : {dummy.shape}')

Current shape is : torch.Size([16, 3, 224, 224])


# MHA base method

1. Embed the input using the base patch embedding method
2. Apply the base MHA to the embedded patches

In [4]:
# Build the baseline patch embedding, run it on the dummy batch, and time
# construction + forward pass together.
#
# CUDA work is launched asynchronously with respect to the host, so the device
# is synchronised before each clock read; otherwise the timer can start while
# earlier work is still pending or stop before the forward pass has finished.
if torch.cuda.is_available():
    torch.cuda.synchronize()
st = time.perf_counter()  # monotonic, high-resolution clock for intervals
embedding = base_embedding(input_size=dummy.shape,
                           patch_size = 16,
                           embed_dim= 16*16*3,
                           bias=True,
                           device=device,
                           ignore_modes=None).to(device)
embedded_base = embedding(dummy)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# Dimension 1 of the embedded tensor is reported as the number of patches.
print(f'output shape of patch embedding : {embedded_base.shape}\nnumber of patches = {embedded_base.shape[1]}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 196, 768])
number of patches = 196
This embedding has 590592 parameters
This embedding took : 0.264082670211792


In [5]:
# Print the name of every entry stored in the embedding's state dict.
for entry_name in embedding.state_dict().keys():
    print(entry_name)

projection.weight
projection.bias


In [6]:
# Build the baseline multi-head attention block, run it on the embedded
# patches, and time construction + forward pass together.
#
# CUDA work is launched asynchronously with respect to the host, so the device
# is synchronised before each clock read to make the interval cover exactly
# the work done in this cell.
if torch.cuda.is_available():
    torch.cuda.synchronize()
st = time.perf_counter()  # monotonic, high-resolution clock for intervals
MHA = base_MHA(input_size=dummy.shape,
               patch_size=16,
               embed_dim= 768,
               num_heads=4,
               bias=True,
               out_embed= True,
               device=device,
               ignore_modes=None
               ).to(device)
result_mha = MHA(embedded_base)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# Report output and input shapes side by side for comparison.
print(f'MHA output shape is : {result_mha.shape}')
print(f'Input of MHA shape was : {embedded_base.shape}')
print(f'This MHA has {count_parameters(MHA)} parameters')
print(f'This MHA took : {elapsed}')

MHA output shape is : torch.Size([16, 196, 768])
Input of MHA shape was : torch.Size([16, 196, 768])
This MHA has 2362368 parameters
This MHA took : 0.02145075798034668


In [7]:
# Print the name of every entry stored in the attention block's state dict.
for entry_name in MHA.state_dict().keys():
    print(entry_name)

query.weight
query.bias
key.weight
key.bias
value.weight
value.bias
fc_out.weight
fc_out.bias


In [8]:
# Smoke test: put the embedding and the attention block under a small linear
# head and run two optimisation steps, printing for every MHA parameter its
# shape, its device and the device of its gradient.

# Width of the flattened MHA output per sample, derived from the tensor
# produced above instead of a hard-coded constant (196 * 768 = 150528 here).
flat_features = result_mha[0].numel()

new_classifier = nn.Sequential(
    embedding,
    MHA,
    nn.Flatten(),
    nn.Linear(flat_features, 2)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(new_classifier.parameters())

# One random binary label per sample of the batch that is fed to the classifier.
temp_y = torch.randint(0, 2, (dummy.shape[0],), device=device)

for step in range(2):
    if step == 1:
        print('second backward')
    optimizer.zero_grad()
    outputs = new_classifier(dummy)
    loss = criterion(outputs, temp_y)
    loss.backward()
    for p in MHA.parameters():
        # A parameter that received no gradient has `grad is None`; print None
        # for it instead of failing on `.device`.
        grad_device = None if p.grad is None else p.grad.device
        print(p.shape, p.device, grad_device)
    optimizer.step()

torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
second backward
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0


# MHA tensorized method

1. Embed the input using the tensorized patch embedding method
2. Apply the tensorized MHA to the embedded patches

In [9]:
# Build the tensorized patch embedding, run it on the dummy batch, and time
# construction + forward pass together.
#
# CUDA work is launched asynchronously with respect to the host, so the device
# is synchronised before each clock read; otherwise the timer can stop before
# the forward pass has actually finished.
if torch.cuda.is_available():
    torch.cuda.synchronize()
st = time.perf_counter()  # monotonic, high-resolution clock for intervals
embedding = tensorized_embedding(input_size=dummy.shape,
                                 patch_size = 16,
                                 embed_dim= (16,16,3),
                                 bias=True,
                                 device=device,
                                 ignore_modes=(0,1,2)).to(device)
embedded_tensorized = embedding(dummy)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# Dimensions 1 and 2 of the embedded tensor are multiplied to report the
# number of patches.
print(f'output shape of patch embedding : {embedded_tensorized.shape}\nnumber pf patches = {embedded_tensorized.shape[1]} x {embedded_tensorized.shape[2]} = {embedded_tensorized.shape[1]*embedded_tensorized.shape[2]}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 14, 14, 16, 16, 3])
number pf patches = 14 x 14 = 196
This embedding has 1289 parameters
This embedding took : 0.0006999969482421875


In [10]:
# Print the name of every entry stored in the tensorized embedding's state dict.
for entry_name in embedding.state_dict().keys():
    print(entry_name)

tcl.b
tcl.u0
tcl.u1
tcl.u2


In [11]:
# Build the tensorized multi-head attention block, run it on the tensorized
# embedding, and time construction + forward pass together.
#
# CUDA work is launched asynchronously with respect to the host, so the device
# is synchronised before each clock read to make the interval cover exactly
# the work done in this cell.
if torch.cuda.is_available():
    torch.cuda.synchronize()
st = time.perf_counter()  # monotonic, high-resolution clock for intervals
MHA = tensorized_MHA(input_size=dummy.shape,
                     patch_size=16,
                     embed_dim= (16,16,3),
                     num_heads=(2,2,1),
                     bias=True,
                     out_embed= True,
                     device=device,
                     ignore_modes=(0,1,2)).to(device)
result_mha = MHA(embedded_tensorized)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# Report output and input shapes side by side for comparison.
print(f'MHA output shape is : {result_mha.shape}')
print(f'Input of MHA shape was : {embedded_tensorized.shape}')
print(f'This MHA has {count_parameters(MHA)} parameters')
print(f'This MHA took : {elapsed}')

MHA output shape is : torch.Size([16, 14, 14, 16, 16, 3])
Input of MHA shape was : torch.Size([16, 14, 14, 16, 16, 3])
This MHA has 5156 parameters
This MHA took : 0.0018830299377441406


In [12]:
# Print the name of every entry stored in the tensorized attention block's state dict.
for entry_name in MHA.state_dict().keys():
    print(entry_name)

tcl_q.b
tcl_q.u0
tcl_q.u1
tcl_q.u2
tcl_k.b
tcl_k.u0
tcl_k.u1
tcl_k.u2
tcl_v.b
tcl_v.u0
tcl_v.u1
tcl_v.u2
tcl_out.b
tcl_out.u0
tcl_out.u1
tcl_out.u2


In [13]:
# Smoke test: put the tensorized embedding and attention block under a small
# linear head and run two optimisation steps, printing for every MHA parameter
# its shape, its device and the device of its gradient.

# Width of the flattened MHA output per sample, derived from the tensor
# produced above instead of a hard-coded constant
# (14 * 14 * 16 * 16 * 3 = 150528 here).
flat_features = result_mha[0].numel()

new_classifier = nn.Sequential(
    embedding,
    MHA,
    nn.Flatten(),
    nn.Linear(flat_features, 2)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(new_classifier.parameters())

# One random binary label per sample of the batch that is fed to the
# classifier. Sized from `dummy` so this section no longer depends on
# `embedded_base`, which belongs to the baseline section.
temp_y = torch.randint(0, 2, (dummy.shape[0],), device=device)

for step in range(2):
    if step == 1:
        print('second backward')
    optimizer.zero_grad()
    outputs = new_classifier(dummy)
    loss = criterion(outputs, temp_y)
    loss.backward()
    for p in MHA.parameters():
        # A parameter that received no gradient has `grad is None`; print None
        # for it instead of failing on `.device`.
        grad_device = None if p.grad is None else p.grad.device
        print(p.shape, p.device, grad_device)
    optimizer.step()

torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
second backward
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda