# Add the parent directory to the Python import path

In [1]:
%load_ext autoreload
%autoreload 2

import sys

# Make the parent directory importable (presumably where the `Models` and
# `Utils` packages imported below live -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '..' not in sys.path:
    sys.path.append('..')

# Import Libraries

In [2]:
from Models.basic_components.patch_embedding import PatchEmbedding as base_embedding
from Models.tensorized_components.patch_embedding import PatchEmbedding as tensorized_embedding
from Utils.Num_parameter import count_parameters

import torch
import torch.nn as nn
from torch import optim
import time

# Dummy Data

In [3]:
# Seed the global RNG so the dummy batch (and later random initialisation)
# is reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 16

# Uniform random batch of shape (batch_size, 3, 224, 224) -- presumably
# (batch, channels, height, width) images. Allocated directly on the target
# device to avoid a host-to-device copy.
dummy = torch.rand(batch_size, 3, 224, 224, device=device)
print(f'Current shape is : {dummy.shape}')

Current shape is : torch.Size([16, 3, 224, 224])


# Patch Embedding Base method

In [4]:
# Build the baseline patch embedding, run one forward pass on the dummy batch
# and report output shape, parameter count and elapsed wall time.
#
# CUDA kernels are launched asynchronously, so the device is synchronised
# before starting and before stopping the timer; otherwise the measurement
# may end before the forward pass has actually finished.
# NOTE(review): this looks like the first heavy GPU op in the notebook, so the
# figure probably also includes one-time CUDA library initialisation --
# consider a warm-up call before comparing against the tensorized cell.
on_cuda = torch.device(device).type == 'cuda'
if on_cuda:
    torch.cuda.synchronize()

# perf_counter is monotonic and high-resolution, unlike time.time().
st = time.perf_counter()
embedding = base_embedding(input_size=dummy.shape,
                           patch_size = 16,
                           embed_dim= 16*16*3,
                           bias=True,
                           device=device,
                           ignore_modes=None).to(device)
embedded = embedding(dummy)
if on_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# The stored output shows a (batch, patches, embed_dim) result, so axis 1 is
# the patch count.
print(f'output shape of patch embedding : {embedded.shape}\nnumber of patches = {embedded.shape[1]}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 196, 768])
number of patches = 196
This embedding has 590592 parameters
This embedding took : 0.26371049880981445


In [5]:
# Show the name of every entry stored in the embedding's state dict,
# one name per line.
state_entry_names = embedding.state_dict().keys()
for entry_name in state_entry_names:
    print(entry_name)

projection.weight
projection.bias


In [6]:
# Smoke test: put the embedding in front of a tiny 2-class linear head, run
# two optimisation steps on random labels, and after each backward pass print
# the shape and device of every embedding parameter and of its gradient.

# Per-sample feature count after nn.Flatten(), derived from the embedding
# output computed above instead of hard-coding 150528.
flat_features = embedded[0].numel()

new_classifier = nn.Sequential(
    embedding,
    nn.Flatten(),
    nn.Linear(flat_features, 2)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(new_classifier.parameters())

# One random class label (0 or 1) per sample, created directly on the device.
temp_y = torch.randint(0, 2, (embedded.shape[0],), device=device)

# Two identical steps; the second is announced with a banner so the printed
# output reads the same as running the step twice by hand.
for banner in (None, 'second backward'):
    if banner is not None:
        print(banner)
    optimizer.zero_grad()
    outputs = new_classifier(dummy)
    loss = criterion(outputs, temp_y)
    loss.backward()
    for p in embedding.parameters():
        # A parameter that received no gradient has p.grad set to None;
        # report that instead of raising AttributeError.
        grad_device = p.grad.device if p.grad is not None else None
        print(p.shape, p.device, grad_device)
    optimizer.step()

torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
second backward
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0


# Patch Embedding Tensorized method

In [7]:
# Build the tensorized patch embedding, run one forward pass on the dummy
# batch and report output shape, parameter count and elapsed wall time.
#
# CUDA kernels are launched asynchronously, so the device is synchronised
# before starting and before stopping the timer; otherwise the measurement
# may end before the forward pass has actually finished.
# NOTE(review): the base-method cell runs earlier and probably absorbs the
# one-time GPU initialisation cost, so the two timings are likely not
# directly comparable without a warm-up call -- confirm before quoting them.
on_cuda = torch.device(device).type == 'cuda'
if on_cuda:
    torch.cuda.synchronize()

# perf_counter is monotonic and high-resolution, unlike time.time().
st = time.perf_counter()
embedding = tensorized_embedding(input_size=dummy.shape,
                                 patch_size = 16,
                                 embed_dim= (16,16,3),
                                 bias=True,
                                 device=device,
                                 ignore_modes=(0,1,2)).to(device)
embedded = embedding(dummy)
if on_cuda:
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st

# Axes 1 and 2 of the output are printed as the patch grid, and their product
# as the total patch count.
print(f'output shape of patch embedding : {embedded.shape}\nnumber pf patches = {embedded.shape[1]} x {embedded.shape[2]} = {embedded.shape[1]*embedded.shape[2]}')

print(f'This embedding has {count_parameters(embedding)} parameters')
print(f'This embedding took : {elapsed}')

output shape of patch embedding : torch.Size([16, 14, 14, 16, 16, 3])
number pf patches = 14 x 14 = 196
This embedding has 1289 parameters
This embedding took : 0.006803750991821289


In [8]:
# Show the name of every entry stored in the embedding's state dict,
# one name per line.
state_entry_names = embedding.state_dict().keys()
for entry_name in state_entry_names:
    print(entry_name)

tcl.b
tcl.u0
tcl.u1
tcl.u2


In [9]:
# Smoke test: put the embedding in front of a tiny 2-class linear head, run
# two optimisation steps on random labels, and after each backward pass print
# the shape and device of every embedding parameter and of its gradient.

# Per-sample feature count after nn.Flatten(), derived from the embedding
# output computed above instead of hard-coding 150528.
flat_features = embedded[0].numel()

new_classifier = nn.Sequential(
    embedding,
    nn.Flatten(),
    nn.Linear(flat_features, 2)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(new_classifier.parameters())

# One random class label (0 or 1) per sample, created directly on the device.
temp_y = torch.randint(0, 2, (embedded.shape[0],), device=device)

# Two identical steps; the second is announced with a banner so the printed
# output reads the same as running the step twice by hand.
for banner in (None, 'second backward'):
    if banner is not None:
        print(banner)
    optimizer.zero_grad()
    outputs = new_classifier(dummy)
    loss = criterion(outputs, temp_y)
    loss.backward()
    for p in embedding.parameters():
        # A parameter that received no gradient has p.grad set to None;
        # report that instead of raising AttributeError.
        grad_device = p.grad.device if p.grad is not None else None
        print(p.shape, p.device, grad_device)
    optimizer.step()

torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
second backward
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
