# Add to the path

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

# Import Libraries

In [2]:
from Models.vit_base import VisionTransformer as vit_base
from Models.vit_tensorized import VisionTransformer as vit_tensorized
from Utils.Num_parameter import count_parameters

import torch
import torch.nn as nn
from torch import optim
import time

# Dummy Data

In [3]:
device = 'cuda'
batch_size = 16
dummy = torch.rand(batch_size, 3, 224, 224).to(device)
print(f'Current shape is : {dummy.shape}')

Current shape is : torch.Size([16, 3, 224, 224])


# VIT base method

In [4]:
vit = vit_base(input_size=dummy.shape,
                patch_size=16,
                num_classes=1000,
                embed_dim=16*16*3,
                num_heads=12,
                num_layers=12,
                mlp_dim=1024,
                dropout=0.1,
                bias=True,
                out_embed=True,
                device=device,
                ignore_modes=None,
                Tensorized_mlp=False).to(device)

st = time.time()
output = vit(dummy)
elapsed = time.time() - st
print(f'output shape of vit : {output.shape}')

print(f'This vit has {count_parameters(vit)} parameters')
print(f'This vit took : {elapsed}')

output shape of vit : torch.Size([16, 1000])
This vit has 48793576 parameters
This vit took : 0.29781436920166016


In [5]:
for key in vit.state_dict():
    print(key)

pos_embedding
cls_token
patch_embedding.projection.weight
patch_embedding.projection.bias
transformer.0.norm1.weight
transformer.0.norm1.bias
transformer.0.norm2.weight
transformer.0.norm2.bias
transformer.0.attention.query.weight
transformer.0.attention.query.bias
transformer.0.attention.key.weight
transformer.0.attention.key.bias
transformer.0.attention.value.weight
transformer.0.attention.value.bias
transformer.0.attention.fc_out.weight
transformer.0.attention.fc_out.bias
transformer.0.mlp.0.weight
transformer.0.mlp.0.bias
transformer.0.mlp.2.weight
transformer.0.mlp.2.bias
transformer.1.norm1.weight
transformer.1.norm1.bias
transformer.1.norm2.weight
transformer.1.norm2.bias
transformer.1.attention.query.weight
transformer.1.attention.query.bias
transformer.1.attention.key.weight
transformer.1.attention.key.bias
transformer.1.attention.value.weight
transformer.1.attention.value.bias
transformer.1.attention.fc_out.weight
transformer.1.attention.fc_out.bias
transformer.1.mlp.0.weight

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit.parameters())

temp_y = torch.randint(0, 1000, (dummy.shape[0],)).to(device)

optimizer.zero_grad()    
outputs = vit(dummy)
loss = criterion(outputs, temp_y)
loss.backward()
for p in vit.parameters():
    print(p.shape, p.device, p.grad.device)
optimizer.step()

print('second backward')
optimizer.zero_grad()    
outputs = vit(dummy)
loss = criterion(outputs, temp_y)
loss.backward()
for p in vit.parameters():
    print(p.shape, p.device, p.grad.device)
optimizer.step()

torch.Size([1, 196, 768]) cuda:0 cuda:0
torch.Size([1, 1, 768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([1024, 768]) cuda:0 cuda:0
torch.Size([1024]) cuda:0 cuda:0
torch.Size([768, 1024]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch.Size([768]) cuda:0 cuda:0
torch.Size([768, 768]) cuda:0 cuda:0
torch

# VIT tensorized method

In [7]:
vit = vit_tensorized(input_size=dummy.shape,
                patch_size=16,
                num_classes=1000,
                embed_dim=(16,16,3),
                num_heads=(2,2,3),
                num_layers=12,
                mlp_dim=(16,16,4),
                dropout=0.1,
                bias=True,
                out_embed=True,
                device=device,
                ignore_modes=(0,1,2),
                Tensorized_mlp=True).to(device)

st = time.time()
output = vit(dummy)
elapsed = time.time() - st
print(f'output shape of vit : {output.shape}')

print(f'This vit has {count_parameters(vit)} parameters')
print(f'This vit took : {elapsed}')

output shape of vit : torch.Size([16, 1000])
This vit has 11499942 parameters
This vit took : 0.019927263259887695


In [8]:
for key in vit.state_dict():
    print(key)

pos_embedding
cls_token
patch_embedding.tcl.b
patch_embedding.tcl.u0
patch_embedding.tcl.u1
patch_embedding.tcl.u2
transformer.0.norm1.weight
transformer.0.norm1.bias
transformer.0.norm2.weight
transformer.0.norm2.bias
transformer.0.attention.tcl_q.b
transformer.0.attention.tcl_q.u0
transformer.0.attention.tcl_q.u1
transformer.0.attention.tcl_q.u2
transformer.0.attention.tcl_k.b
transformer.0.attention.tcl_k.u0
transformer.0.attention.tcl_k.u1
transformer.0.attention.tcl_k.u2
transformer.0.attention.tcl_v.b
transformer.0.attention.tcl_v.u0
transformer.0.attention.tcl_v.u1
transformer.0.attention.tcl_v.u2
transformer.0.attention.tcl_out.b
transformer.0.attention.tcl_out.u0
transformer.0.attention.tcl_out.u1
transformer.0.attention.tcl_out.u2
transformer.0.mlp.0.b
transformer.0.mlp.0.u0
transformer.0.mlp.0.u1
transformer.0.mlp.0.u2
transformer.0.mlp.2.b
transformer.0.mlp.2.core
transformer.0.mlp.2.u0
transformer.0.mlp.2.u1
transformer.0.mlp.2.u2
transformer.0.mlp.2.u3
transformer.0.mlp.2

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit.parameters())

temp_y = torch.randint(0, 1000, (dummy.shape[0],)).to(device)

optimizer.zero_grad()    
outputs = vit(dummy)
loss = criterion(outputs, temp_y)
loss.backward()
for p in vit.parameters():
    print(p.shape, p.device, p.grad.device)
optimizer.step()

print('second backward')
optimizer.zero_grad()    
outputs = vit(dummy)
loss = criterion(outputs, temp_y)
loss.backward()
for p in vit.parameters():
    print(p.shape, p.device, p.grad.device)
optimizer.step()

torch.Size([1, 14, 14, 16, 16, 3]) cuda:0 cuda:0
torch.Size([1, 1, 1, 16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 3]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([16, 16]) cuda:0 cuda:0
torch.Size([3, 3]) cuda:0 cuda:0
torch.Size([16, 16, 4]) cuda:0 cuda:0
torch.Siz