<a href="https://www.kaggle.com/code/skshmjn/simple-quantisation-qat?scriptVersionId=213516621" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## Imports


In [1]:
import torch 
import torch.nn as nn 
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn.utils.parametrize as parametrize

import os
from tqdm import tqdm 

In [2]:
device = 'cpu'
g = torch.Generator(device=device).manual_seed(545345320492802)

## DATASET


In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=1024, shuffle=True, generator=g)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=1024, shuffle=False, generator=g)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15874026.81it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 469605.56it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4370297.81it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1830197.79it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






## MODEL

In [4]:
class SimpleNet(nn.Module):
    def __init__(self, hidden_size_1=50, hidden_size_2=50):
        super(SimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x
simple_net = SimpleNet().to(device)

## Training

In [5]:
epochs = 10

In [6]:
def train_model(model, train_data_loader, epochs):
    model.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
     
        epoch_loss = 0
    
        data_iterator = tqdm(train_data_loader, desc=f'Epoch {epoch+1}')
        
        for data in data_iterator:     
            X, y = data
            X = X.to(device)
            y = y.to(device)
            # Forward pass
            y_pred = model(X)
            loss = criterion(y_pred, y)
    
            epoch_loss += loss
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss/len(train_data_loader):.4f}")

In [7]:
train_model(simple_net, train_loader, 10)

Epoch 1: 100%|██████████| 59/59 [00:12<00:00,  4.54it/s]


Epoch 1/10 - Loss: 0.9927


Epoch 2: 100%|██████████| 59/59 [00:13<00:00,  4.53it/s]


Epoch 2/10 - Loss: 0.3260


Epoch 3: 100%|██████████| 59/59 [00:12<00:00,  4.57it/s]


Epoch 3/10 - Loss: 0.2606


Epoch 4: 100%|██████████| 59/59 [00:12<00:00,  4.58it/s]


Epoch 4/10 - Loss: 0.2195


Epoch 5: 100%|██████████| 59/59 [00:13<00:00,  4.45it/s]


Epoch 5/10 - Loss: 0.1908


Epoch 6: 100%|██████████| 59/59 [00:13<00:00,  4.48it/s]


Epoch 6/10 - Loss: 0.1684


Epoch 7: 100%|██████████| 59/59 [00:13<00:00,  4.48it/s]


Epoch 7/10 - Loss: 0.1518


Epoch 8: 100%|██████████| 59/59 [00:12<00:00,  4.55it/s]


Epoch 8/10 - Loss: 0.1386


Epoch 9: 100%|██████████| 59/59 [00:13<00:00,  4.50it/s]


Epoch 9/10 - Loss: 0.1301


Epoch 10: 100%|██████████| 59/59 [00:13<00:00,  4.42it/s]

Epoch 10/10 - Loss: 0.1190





In [8]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

# print_size_of_model(very_simple_net)

## Static Quantisation

In [9]:
simple_net.eval()
simple_net.qconfig = torch.quantization.default_qconfig
model_fp32_prepared = torch.quantization.prepare(simple_net)
model_int8_static = torch.quantization.convert(model_fp32_prepared)
model_int8_static = model_int8_static



## Dynamic Quantisation

In [10]:
model_int8_dynamic = torch.quantization.quantize_dynamic(
    simple_net.to(device), {nn.Linear}, dtype=torch.qint8
)

In [11]:

# Testing Loop
def test_model(model, test_loader):
    model = model.to(device)  
    correct = 0
    total = 0
    test_loss = 0.0
    criterion = nn.CrossEntropyLoss().to(device)
    model.eval()
    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)  

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Statistics
            # _, predicted = outputs.max(1)
            # total += labels.size(0)
            # correct += predicted.eq(labels).sum().item()
            for idx, i in enumerate(outputs):
                if torch.argmax(i) == labels[idx]:
                    correct +=1
                else:
                    wrong_counts[labels[idx]] +=1
                total +=1
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

    accuracy = 100. * correct / total
    print(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")


In [12]:
%%time
test_model(simple_net,test_loader) 
print_size_of_model(simple_net)

wrong counts for the digit 0: 19
wrong counts for the digit 1: 15
wrong counts for the digit 2: 28
wrong counts for the digit 3: 52
wrong counts for the digit 4: 37
wrong counts for the digit 5: 40
wrong counts for the digit 6: 26
wrong counts for the digit 7: 44
wrong counts for the digit 8: 40
wrong counts for the digit 9: 66
Test Loss: 0.1264, Accuracy: 96.33%
Size (KB): 171.878
CPU times: user 2.43 s, sys: 2.95 ms, total: 2.43 s
Wall time: 2.26 s


In [13]:
%%time
test_model(model_int8_static,test_loader)
print_size_of_model(model_int8_static)

wrong counts for the digit 0: 19
wrong counts for the digit 1: 21
wrong counts for the digit 2: 74
wrong counts for the digit 3: 33
wrong counts for the digit 4: 133
wrong counts for the digit 5: 164
wrong counts for the digit 6: 83
wrong counts for the digit 7: 131
wrong counts for the digit 8: 149
wrong counts for the digit 9: 176
Test Loss: 0.4161, Accuracy: 90.17%
Size (KB): 47.778
CPU times: user 2.5 s, sys: 13.9 ms, total: 2.52 s
Wall time: 2.34 s


In [14]:
%%time
test_model(model_int8_dynamic,test_loader)
print_size_of_model(model_int8_dynamic)

wrong counts for the digit 0: 19
wrong counts for the digit 1: 15
wrong counts for the digit 2: 27
wrong counts for the digit 3: 54
wrong counts for the digit 4: 37
wrong counts for the digit 5: 39
wrong counts for the digit 6: 26
wrong counts for the digit 7: 45
wrong counts for the digit 8: 40
wrong counts for the digit 9: 67
Test Loss: 0.1269, Accuracy: 96.31%
Size (KB): 47.202
CPU times: user 2.43 s, sys: 2.01 ms, total: 2.44 s
Wall time: 2.25 s


## Quantisation aware training 


In [15]:
qat_model = SimpleNet()

In [16]:
qat_model.qconfig = torch.quantization.default_qconfig
qat_model.train()
qat_model_quantized = torch.quantization.prepare_qat(qat_model) # Insert observers
qat_model_quantized

SimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=50, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=50, out_features=50, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=50, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [17]:
train_model(qat_model_quantized, train_loader, 10)

Epoch 1: 100%|██████████| 59/59 [00:13<00:00,  4.41it/s]


Epoch 1/10 - Loss: 0.9560


Epoch 2: 100%|██████████| 59/59 [00:13<00:00,  4.49it/s]


Epoch 2/10 - Loss: 0.3008


Epoch 3: 100%|██████████| 59/59 [00:13<00:00,  4.48it/s]


Epoch 3/10 - Loss: 0.2439


Epoch 4: 100%|██████████| 59/59 [00:13<00:00,  4.50it/s]


Epoch 4/10 - Loss: 0.2093


Epoch 5: 100%|██████████| 59/59 [00:12<00:00,  4.60it/s]


Epoch 5/10 - Loss: 0.1844


Epoch 6: 100%|██████████| 59/59 [00:13<00:00,  4.54it/s]


Epoch 6/10 - Loss: 0.1643


Epoch 7: 100%|██████████| 59/59 [00:13<00:00,  4.53it/s]


Epoch 7/10 - Loss: 0.1489


Epoch 8: 100%|██████████| 59/59 [00:12<00:00,  4.60it/s]


Epoch 8/10 - Loss: 0.1342


Epoch 9: 100%|██████████| 59/59 [00:13<00:00,  4.53it/s]


Epoch 9/10 - Loss: 0.1214


Epoch 10: 100%|██████████| 59/59 [00:12<00:00,  4.56it/s]

Epoch 10/10 - Loss: 0.1115





In [18]:
%%time
test_model(qat_model_quantized,test_loader )
print_size_of_model(qat_model_quantized)

wrong counts for the digit 0: 17
wrong counts for the digit 1: 19
wrong counts for the digit 2: 49
wrong counts for the digit 3: 45
wrong counts for the digit 4: 40
wrong counts for the digit 5: 39
wrong counts for the digit 6: 31
wrong counts for the digit 7: 35
wrong counts for the digit 8: 48
wrong counts for the digit 9: 47
Test Loss: 0.1220, Accuracy: 96.30%
Size (KB): 178.018
CPU times: user 2.39 s, sys: 981 µs, total: 2.39 s
Wall time: 2.2 s


In [19]:
qat_model_quantized.eval()
qat_model_quantized = torch.quantization.convert(qat_model_quantized)

In [20]:
%%time
test_model(qat_model_quantized,test_loader )
print_size_of_model(qat_model_quantized)

wrong counts for the digit 0: 12
wrong counts for the digit 1: 20
wrong counts for the digit 2: 43
wrong counts for the digit 3: 43
wrong counts for the digit 4: 40
wrong counts for the digit 5: 40
wrong counts for the digit 6: 32
wrong counts for the digit 7: 36
wrong counts for the digit 8: 48
wrong counts for the digit 9: 55
Test Loss: 0.1226, Accuracy: 96.31%
Size (KB): 47.778
CPU times: user 2.42 s, sys: 3.93 ms, total: 2.43 s
Wall time: 2.24 s


## LORA 


In [21]:
simple_model = SimpleNet(1000,2000)

In [22]:
train_model(simple_model, train_loader, 2)

Epoch 1: 100%|██████████| 59/59 [00:21<00:00,  2.68it/s]


Epoch 1/2 - Loss: 0.3615


Epoch 2: 100%|██████████| 59/59 [00:20<00:00,  2.82it/s]

Epoch 2/2 - Loss: 0.1034





In [23]:
test_model(simple_model,test_loader)

wrong counts for the digit 0: 10
wrong counts for the digit 1: 11
wrong counts for the digit 2: 30
wrong counts for the digit 3: 28
wrong counts for the digit 4: 18
wrong counts for the digit 5: 22
wrong counts for the digit 6: 19
wrong counts for the digit 7: 28
wrong counts for the digit 8: 33
wrong counts for the digit 9: 49
Test Loss: 0.0795, Accuracy: 97.52%


In [24]:
total_parameters_original = 0
for index, layer in enumerate([simple_model.linear1, simple_model.linear2, simple_model.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')


Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


In [25]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device=device):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        
        return original_weights

In [26]:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
parametrize.register_parametrization(
    simple_model.linear1, "weight", linear_layer_parameterization(simple_model.linear1, device)
)
parametrize.register_parametrization(
    simple_model.linear2, "weight", linear_layer_parameterization(simple_model.linear2, device)
)
parametrize.register_parametrization(
    simple_model.linear3, "weight", linear_layer_parameterization(simple_model.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [simple_model.linear1, simple_model.linear2, simple_model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [27]:
for name, param in simple_model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


In [28]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 7
mnist_trainset.data = mnist_trainset.data[exclude_indices][:30]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices][:30]

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)




In [29]:
train_model(simple_model, train_loader, epochs=1)

Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 69.77it/s]

Epoch 1/1 - Loss: 0.0697





In [30]:
enable_disable_lora(enabled=True)
test_model(simple_model, test_loader)

wrong counts for the digit 0: 10
wrong counts for the digit 1: 13
wrong counts for the digit 2: 32
wrong counts for the digit 3: 31
wrong counts for the digit 4: 17
wrong counts for the digit 5: 23
wrong counts for the digit 6: 20
wrong counts for the digit 7: 17
wrong counts for the digit 8: 43
wrong counts for the digit 9: 71
Test Loss: 0.0898, Accuracy: 97.23%


In [31]:
enable_disable_lora(enabled=False)
test_model(simple_model, test_loader)

wrong counts for the digit 0: 10
wrong counts for the digit 1: 11
wrong counts for the digit 2: 30
wrong counts for the digit 3: 28
wrong counts for the digit 4: 18
wrong counts for the digit 5: 22
wrong counts for the digit 6: 19
wrong counts for the digit 7: 28
wrong counts for the digit 8: 33
wrong counts for the digit 9: 49
Test Loss: 0.0795, Accuracy: 97.52%
