In [1]:
import torch 
from transformers import ViTImageProcessor, ViTForImageClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = ViTImageProcessor.from_pretrained('./vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('./vit-base-patch16-224')

In [3]:
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

86,567,656 total parameters.


In [4]:
compressed_keys = []

for name, param in model.named_parameters():
    # if 'encoder' in name and 'weight' in name and 'norm' not in name:
    if ('query' in name) and 'weight' in name:
        compressed_keys.append(name)

print(compressed_keys[:5])

['vit.encoder.layer.0.attention.attention.query.weight', 'vit.encoder.layer.1.attention.attention.query.weight', 'vit.encoder.layer.2.attention.attention.query.weight', 'vit.encoder.layer.3.attention.attention.query.weight', 'vit.encoder.layer.4.attention.attention.query.weight']


In [5]:
compressed_params = 0
for name in compressed_keys:
    q_weight = model.state_dict()[name]
    k_weight = model.state_dict()[name.replace('query', 'key')]
    compressed_params += q_weight.numel() + k_weight.numel()
    qk_weight = q_weight.T @ k_weight
    u, s, v = torch.svd_lowrank(qk_weight, q=100)
    # new_q_weight_t = u @ torch.diag(s) @ v.T
    new_q_weight_t = qk_weight.T 
    new_k_weight = torch.eye(k_weight.size(0), k_weight.size(1), device=k_weight.device)
    model.state_dict()[name].copy_(new_q_weight_t.T)
    model.state_dict()[name.replace('query', 'key')].copy_(new_k_weight)
    compressed_params -= u.numel() + s.numel() + v.numel()

print(f'{compressed_params:,} compressed parameters.')

12,311,376 compressed parameters.


In [6]:
print(f'model parameters: {total_params-compressed_params:,}, ratio: {(total_params-compressed_params)/total_params:.2f}')

model parameters: 74,256,280, ratio: 0.86


In [7]:
model_name = 'vit-base-patch16-224-svd-QK16'
model.save_pretrained(model_name)
processor.save_pretrained(model_name)

['vit-base-patch16-224-svd-QK16/preprocessor_config.json']