In [8]:
from svd_compression import decompose_layers
from svd_compression.naive_svd import NaiveSVDApproximator
from training_utils.trainer import Trainer

import torch
from torch import nn
from torch.utils.data import random_split, DataLoader

from torchvision.datasets import ImageFolder
from torchvision import transforms

from transformers import AutoModelForImageClassification

In [4]:
# Evaluation-time preprocessing: resize, center-crop to 224, convert to a tensor,
# then normalize with the standard ImageNet per-channel statistics.
# NOTE(review): this transform is not referenced by any visible cell — the
# datasets are unpickled already constructed — confirm whether it is still needed.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

ImageNetValidationTransformation = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [5]:
import pickle

DATASETS_PICKLE = 'imgnet200_datasets.pkl'

# Load the pre-built splits, stored as one pickled 3-element sequence
# (train, validation, test).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(DATASETS_PICKLE, 'rb') as pickled_splits:
    loaded_splits = pickle.load(pickled_splits)
train_dataset, val_dataset, test_dataset = loaded_splits

In [6]:
BATCH_SIZE = 256

# Shuffle only the training data. Validation iterates in a fixed order so that
# repeated evaluations see identical batches (including the same partial final
# batch), which keeps reported metrics reproducible.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Experiment configuration. Values are forwarded to the project helpers below;
# their exact semantics are defined in svd_compression / training_utils.
device = 'cuda:4'
BASE_MODEL_NAME = "google/vit-base-patch16-224"
# TODO(review): '.pt' has no file stem and looks like a placeholder — set this to
# the real checkpoint path before re-running.
CHECKPOINT_PATH = '.pt'
SVD_THRESHOLD = 1    # passed to decompose_layers as `threshold`
SVD_MAX_RANK = 192   # passed to decompose_layers as `max_rank`
LEARNING_RATE = 1e-4
# Period of CosineAnnealingLR, counted in scheduler.step() calls.
# NOTE(review): it is not visible here whether Trainer steps the scheduler per
# batch or per epoch — confirm T_max matches the intended schedule length.
COSINE_T_MAX = 1000


def count_parameters(module: nn.Module) -> int:
    """Return the total number of scalar elements across all parameters of `module`."""
    return sum(p.numel() for p in module.parameters())


model = AutoModelForImageClassification.from_pretrained(BASE_MODEL_NAME)
# The model is still on CPU here; map_location='cpu' avoids restoring tensors onto
# whichever device the checkpoint was saved from (which may be unavailable).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

# Apply the project's SVD-based layer decomposition and report the parameter ratio.
num_params_wo_compression = count_parameters(model)
model = decompose_layers(
    model,
    approximator_class=NaiveSVDApproximator,
    threshold=SVD_THRESHOLD,
    max_rank=SVD_MAX_RANK,
)
num_params_with_compression = count_parameters(model)
print(f'Compression ratio: {num_params_with_compression / num_params_wo_compression}')

# Move the model to the target device before constructing the optimizer,
# as recommended by the torch.optim documentation.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=COSINE_T_MAX)

trainer = Trainer(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    device,
)

{'initial_num_of_weights': 86567656, 'post_svd_num_of_weights': 47639272, 'after / before': 0.5503126017412323}
Gotcha!
Compression ratio: 0.5503126017412323


In [28]:
# Fine-tune the compressed model; Trainer.run returns (train_history, val_history).
# NOTE(review): the saved output below shows an epoch bar of 0/2 although
# num_epochs is 3, and the run ended in KeyboardInterrupt — the output may be
# stale, or Trainer counts epochs differently; confirm before relying on it.
NUM_EPOCHS = 3
train_history, val_history = trainer.run(num_epochs=NUM_EPOCHS)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [29]:
# Evaluate on the validation loader; the result dict reports loss, top-1
# 'accuracy' and 'accuracy@K'.
# NOTE(review): the training run above was interrupted, so these numbers likely
# reflect the compressed model with little or no fine-tuning — confirm.
TOP_K = 5
trainer.val_step(k_for_top_k=TOP_K)

  0%|          | 0/40 [00:00<?, ?it/s]

{'accuracy': 0.7447917,
 'accuracy@5': 0.8428819444444444,
 'loss': 1.0370127161343892}