# Import the necessary libraries

In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os
import numpy as np

# Load the MNIST dataset

In [2]:
# Make torch deterministic: fix the seed of torch's global random number generator
# so that operations drawing from it are reproducible from run to run.
# The return value is bound to `_` so the notebook does not echo it as cell output.
_ = torch.manual_seed(0)

In [3]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it with
# mean 0.1307 and std 0.3081 (presumably the MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST training split (downloaded into ./data when missing); it is divided into
# a training part and a validation part below.
mnist_train_valset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataset_size = len(mnist_train_valset)

# 90% of the samples go to training, the remainder to validation.
train_size = int(0.9 * dataset_size)
val_size = dataset_size - train_size
train_dataset, val_dataset = random_split(mnist_train_valset, [train_size, val_size])

batch_size = 10

# Only the training batches are reshuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# MNIST test split, used for the accuracy measurements further down.
# NOTE(review): shuffle=True is not required for evaluation; kept as in the original.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(mnist_testset, shuffle=True, batch_size=batch_size)

# Every model in this notebook is placed on this device.
device = "cpu"

In [4]:
from Model import MLP, train, test, print_size_of_model

In [5]:
# Build the MLP for flattened 28x28 inputs and 10 output classes, then move it to `device`.
net = MLP(input_size=28 * 28, output_size=10)
net = net.to(device)

# Train the model

In [6]:
# Checkpoint used to skip training when a trained model is already on disk.
MODEL_FILENAME = './Saved_models/Mnist.pt'

if Path(MODEL_FILENAME).exists():
    # map_location remaps the stored tensors onto `device`, so a checkpoint written
    # from another device (e.g. a GPU) still loads on this machine.
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
    print('Loaded model from disk')
else:
    # Create the checkpoint folder *before* training: otherwise a missing directory
    # would make torch.save fail only after the whole training run has finished,
    # and the trained weights would be lost.
    Path(MODEL_FILENAME).parent.mkdir(parents=True, exist_ok=True)
    # The positional arguments 1000 and 5 are forwarded to train() unchanged;
    # their meaning is defined in Model.py (not visible here).
    train(net, train_loader, val_loader, 1000, 5, device=device)
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


# Print the accuracy of the model before quantization

In [7]:
# Evaluate the trained network on the test loader, then report the returned value.
accuracy_before = test(net, test_loader)
print(f'Accuracy of the model before quantization: {accuracy_before}')

  0%|          | 0/1000 [00:00<?, ?it/s]

Accuracy of the model before quantization: 97.37


# Factorization

In [8]:
from Factorization import factorize

In [9]:
# Run this cell after changing anything in the Factorization module so that the notebook picks up the new code.
# Without the reload you would have to restart the kernel, which is inconvenient.

import importlib
import Factorization

importlib.reload(Factorization)
# Re-import the names so they refer to the objects of the freshly reloaded module.
from Factorization import factorize, SVD_quant

## NN approximation

In [10]:
# Divisor applied to a layer's matrix rank to obtain the target rank passed to the
# factorization routines (target rank = matrix_rank(W) // reduction_rate).
reduction_rate = 3

In [11]:
import re

# Shallow copy of the state dict: entries replaced below do not touch `net` itself.
net_state_copy = net.state_dict().copy()
linear_weights_keys = [
    layer
    for layer in net_state_copy
    if 'linear' in layer and '.weight' in layer
]

# Only the interior linear layers are approximated; the first and the last
# weight matrices are skipped by the [1:-1] slice.
for layer in linear_weights_keys[1:-1]:
    print(layer)
    W = net.state_dict()[layer].detach()
    target_rank = torch.linalg.matrix_rank(W) // reduction_rate
    # Alternative kept from the original notebook:
    #   factorize(W, target_rank, method='min-max')
    net_state_copy[layer] = SVD_quant(W, target_rank, method='min-max')

linear2.weight


In [12]:
# Fresh MLP with the same constructor arguments, populated with the approximated weights.
new_net = MLP(input_size=28 * 28, output_size=10)
new_net = new_net.to(device)
new_net.load_state_dict(net_state_copy)

# Bare last expression so the notebook displays the value returned by test().
test(new_net, test_loader)

  0%|          | 0/1000 [00:00<?, ?it/s]

83.91

## Calculation of memory usage

In [13]:
import re

# Compare the number of bytes taken by the interior linear layers before and after
# factorization.
#
# The trained weights are read directly from `net`. `net_state_copy` is deliberately
# NOT rebound in this cell: it holds the approximated weights built earlier, and
# overwriting it with the original weights would make a re-run of the evaluation
# cell silently test the unmodified network.
original_state = net.state_dict()
linear_weights_keys = [layer for layer in original_state if 'linear' in layer and '.weight' in layer]

mem_usage_init = []        # bytes of each original weight matrix
mem_usage_compressed = []  # bytes of the two factors replacing it

for layer in linear_weights_keys[1:-1]:
    W = original_state[layer].detach()
    # `sc` is returned by factorize but is not included in the compressed byte count below.
    A, B, sc = factorize(W, torch.linalg.matrix_rank(W) // reduction_rate, return_int=True)

    mem_usage_init.append(W.element_size() * W.numel())
    mem_usage_compressed.append(A.element_size() * A.numel() + B.element_size() * B.numel())

print(f'Initial (byte): {sum(mem_usage_init)}')
print(f'Compressed (byte): {sum(mem_usage_compressed)}')

rank: 33
Initial (byte): 40000
Compressed (byte): 6600
