# Brevitas

Brevitas alows us to train quantized NN. This tool is very useful and future tools used in the course are based on this.

This notebook serves as an introducion to quantizing NNs through the crration (and training) on a fully quantized MNIST classifier.

I opteed for a more robust architecture this time to avoid low precision.

In [10]:
from torch.nn import Module, Flatten
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int8Bias

### The model itself.

- As it is a fully quantized model, we introduduce a quntidentity to quantize the input (4 bit activation)
- All the data passing through this network will be quantized until the output ass all operation are int

In [2]:
class QuantWeightActBiasLeNet(Module):
    def __init__(self):
        super(QuantWeightActBiasLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4, return_quant_tensor=True)
        self.fc1   = qnn.QuantLinear(28*28, 64, bias=True, weight_bit_width=4, bias_quant=Int8Bias)
        self.relu3 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.fc2   = qnn.QuantLinear(64, 64, bias=True, weight_bit_width=4, bias_quant=Int8Bias)
        self.relu4 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.fc3   = qnn.QuantLinear(64, 10, bias=True, weight_bit_width=4, bias_quant=Int8Bias)


    def forward(self, x):
        out = self.quant_inp(x)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

quant_weight_act_bias_lenet = QuantWeightActBiasLeNet()


### Some inspections

Lets play around with the layers, see what they have that's so special !

In [13]:
model

QuantWeightActBiasLeNet(
  (quant_inp): QuantIdentity(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): Identity()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
          )
          (scaling_impl): ParameterFromRuntimeStatsScaling(
            (stats_input_view_shape_impl): OverTensorView()
            (stats): _Stats(
              (stats_impl): AbsPercentile()
            )
            (restrict_scaling): _RestrictValue(
              (restrict_value_impl): FloatRestrictValue()
            )
            (clamp_scaling): _ClampValue(
            

In [30]:
print(model.fc1.weight)
print(model.fc1.quant_weight())
print(model.fc1.quant_weight().int())
print(model.fc1.quant_weight().int().dtype)

Parameter containing:
tensor([[ 0.0172, -0.0245,  0.0327,  ...,  0.0017,  0.0097,  0.0283],
        [-0.0331, -0.0084, -0.0166,  ...,  0.0080,  0.0136, -0.0135],
        [-0.0155,  0.0343, -0.0267,  ...,  0.0319,  0.0018,  0.0105],
        ...,
        [-0.0203, -0.0080, -0.0112,  ..., -0.0201,  0.0035,  0.0320],
        [-0.0194, -0.0177,  0.0003,  ..., -0.0292, -0.0258, -0.0317],
        [-0.0167, -0.0203, -0.0191,  ..., -0.0003,  0.0342,  0.0018]],
       requires_grad=True)
QuantTensor(value=tensor([[ 0.0153, -0.0255,  0.0306,  ...,  0.0000,  0.0102,  0.0306],
        [-0.0306, -0.0102, -0.0153,  ...,  0.0102,  0.0153, -0.0153],
        [-0.0153,  0.0357, -0.0255,  ...,  0.0306,  0.0000,  0.0102],
        ...,
        [-0.0204, -0.0102, -0.0102,  ..., -0.0204,  0.0051,  0.0306],
        [-0.0204, -0.0153,  0.0000,  ..., -0.0306, -0.0255, -0.0306],
        [-0.0153, -0.0204, -0.0204,  ..., -0.0000,  0.0357,  0.0000]],
       grad_fn=<MulBackward0>), scale=tensor(0.0051, grad_fn=<Div

### Training and testing

sameprinciples as studied previously

In [11]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)



In [12]:
from torch import nn
import torch.optim as optim

# Model, loss function, and optimizer
model = QuantWeightActBiasLeNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
# Training loop
for epoch in range(5):  # Train for 5 epochs
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')



Epoch 1, Loss: 0.052803657948970795
Epoch 2, Loss: 0.21269558370113373
Epoch 3, Loss: 0.06838513910770416
Epoch 4, Loss: 0.009453174658119678
Epoch 5, Loss: 0.007582823280245066


In [7]:
# Testing loop
import torch
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test Accuracy: {accuracy:.2f}%')


Test Accuracy: 96.97%


# Learn more

You can learn more about quantizing your model here : [Quant getting started](https://xilinx.github.io/brevitas/getting_started.html)

This documentation will introduction you to weight-only quantization all the way to full quantization in a simple lighthearted way .

We will also have a lot of tie during the lab where we'll take time to slow down and look at what's happenning. Stay tuned !