# Building a Custom 8-Bit Quantizer
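Here we will create a quantizer that can quantize the linear layers of any PyTorch model to 8-bit precision using a per-channel quantization scheme: weights are stored in int8 while activations stay in 16-bit precision (hence the name W8A16). Because it only operates on **`torch.nn.Linear`** layers, the quantizer is modality-agnostic, meaning we can apply it to vision, audio, text, and even multimodal models.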
- **Step 1 :-** create a **`W8A16LinearLayer`** class to store 8-bit weights and scales
- **Step 2 :-** replace all **`torch.nn.Linear`** layers with **`W8A16LinearLayer`**
- **Step 3 :-** build a quantizer and quantize a model end-to-end
- **Step 4 :-** test naive absmax quantization in many scenarios and study its impact

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
random_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)  # high is exclusive, so values cover the full int8 range [-128, 127]
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

**NOTE:** the weight matrix has the shape **(output dimension, input dimension)**. When we multiply the hidden states by the (transposed) int8 matrix, the result has the shape **(batch size, output dimension)**. So the scales must have one entry per output dimension, matching the first dimension of the weight matrix, and the same holds for the bias.
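To make the shape requirement concrete, here is a small check. It is a sketch with hypothetical sizes (a batch of 2, 16 input features, 32 output features): the matmul output has shape `(batch, out_features)`, so scales and bias of shape `(1, out_features)` broadcast across the batch dimension.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes chosen for illustration
batch, in_features, out_features = 2, 16, 32

weight = torch.randint(-128, 128, (out_features, in_features)).to(torch.int8)  # (out, in)
hidden_states = torch.randn(batch, in_features, dtype=torch.bfloat16)           # (batch, in)
scales = torch.randn(1, out_features, dtype=torch.bfloat16)                     # one scale per output row
bias = torch.randn(1, out_features, dtype=torch.bfloat16)                       # one bias per output feature

out = F.linear(hidden_states, weight.to(hidden_states.dtype))  # (batch, out)
print(out.shape)     # torch.Size([2, 32])

# (batch, out) * (1, out) + (1, out) broadcasts over the batch dimension
result = out * scales + bias
print(result.shape)  # torch.Size([2, 32])
```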

In [3]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

tensor([[-148.0000,  -56.2500,  192.0000,   74.5000,  103.0000, -189.0000,
          324.0000,  -29.8750,  250.0000,   33.5000,  -50.2500,  -67.5000,
          -69.5000,  316.0000,  -37.2500,   16.0000,  100.0000, -262.0000,
          632.0000,  -53.2500,  428.0000,  388.0000,  -31.2500,   -6.2188,
          -98.5000, -382.0000, -232.0000, -296.0000, -332.0000, -164.0000,
           20.0000,  164.0000]], dtype=torch.bfloat16)

First we cast the int8 weights to the same dtype as the hidden states (bfloat16 here). Then we perform the matrix multiplication with PyTorch's **`F.linear()`** function, which computes `input @ weight.T`.
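As a quick sanity check of what `F.linear()` computes, the sketch below (arbitrary seed and sizes, float32 activations for simplicity) confirms that, once the int8 weights are cast, `F.linear(x, w)` matches `x @ w.T`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
int8_weight = torch.randint(-128, 128, (32, 16)).to(torch.int8)  # (out, in)
x = torch.randn(1, 16)                                            # float32 hidden states

# Cast the int8 weights to the activation dtype before the matmul
w = int8_weight.to(x.dtype)

# F.linear(x, w) computes x @ w.T (no bias here)
y_linear = F.linear(x, w)
y_matmul = x @ w.T
print(torch.allclose(y_linear, y_matmul, rtol=1e-4, atol=1e-3))  # True
```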

In [4]:
F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales

tensor([[ 2.3875e+01, -5.3750e+01,  5.5600e+02, -5.7250e+01, -1.1500e+02,
          4.7000e+02, -2.9000e+02,  3.1750e+01, -9.5000e+01,  2.0250e+01,
         -3.5625e+00,  6.0500e+01, -7.5195e-02, -5.8800e+02,  7.2500e+01,
         -3.1375e+01, -1.9500e+02,  3.2000e+01, -9.4000e+02,  1.1406e+00,
          3.9000e+02,  4.1000e+02, -7.8750e+00,  5.1562e+00,  3.3000e+01,
          4.9250e+01,  7.0500e+01, -2.7600e+02,  1.0900e+02,  1.9600e+02,
          3.3500e+01, -6.0500e+01]], dtype=torch.bfloat16)

In [5]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

tensor([[ 2.3000e+01, -5.4500e+01,  5.5600e+02, -5.7250e+01, -1.1600e+02,
          4.7000e+02, -2.9000e+02,  3.1000e+01, -9.6000e+01,  2.0250e+01,
         -2.5938e+00,  5.9500e+01,  2.4512e-01, -5.8800e+02,  7.3500e+01,
         -3.1125e+01, -1.9400e+02,  3.2000e+01, -9.4000e+02,  1.1953e+00,
          3.9000e+02,  4.1000e+02, -6.7812e+00,  6.0625e+00,  3.3250e+01,
          5.0500e+01,  6.9000e+01, -2.7800e+02,  1.0850e+02,  1.9600e+02,
          3.4500e+01, -6.1250e+01]], dtype=torch.bfloat16)

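In the two cells above, we multiplied the result by the per-channel scales (one per output dimension) and then added an optional bias term. Next, we wrap all three steps into a single function.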
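Why can the scales be applied after the matmul? Each scale belongs to one output row, so `y[i] = sum_j x[j] * (s[i] * q[i, j]) = s[i] * sum_j x[j] * q[i, j]`. The sketch below checks this numerically; the float32 dtype, seed and scale range are assumptions chosen so that rounding does not obscure the comparison (in bfloat16 the two orders can differ slightly).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randint(-128, 128, (32, 16)).to(torch.int8)  # int8 weights, (out, in)
s = torch.rand(32) * 0.01                               # per-output-channel scales (float32)
x = torch.randn(4, 16)                                  # float32 activations

# Option A (the approach used here): matmul with the int8 values, then scale each output column
y_scale_after = F.linear(x, q.to(x.dtype)) * s

# Option B: dequantize first (row i multiplied by s[i]), then matmul
w_dequant = q.to(x.dtype) * s.unsqueeze(1)
y_dequant_first = F.linear(x, w_dequant)

print(torch.allclose(y_scale_after, y_dequant_first, rtol=1e-4, atol=1e-4))  # True
```

Scaling after the matmul means the weights never have to be materialized in 16-bit form as a whole; only the output vector is rescaled.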

In [6]:
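def w8_a16_forward(weight, input, scales, bias=None):
    # cast the int8 weights to the activation dtype (e.g. bfloat16) so F.linear can run
    casted_weights = weight.to(input.dtype)
    # input @ weight.T, then rescale each output channel by its scale
    output = F.linear(input, casted_weights) * scales

    # the bias is optional
    if bias is not None:
        output = output + bias

    return output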

In [7]:
print("With bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales))

With bias:

 tensor([[ 2.3000e+01, -5.4500e+01,  5.5600e+02, -5.7250e+01, -1.1600e+02,
          4.7000e+02, -2.9000e+02,  3.1000e+01, -9.6000e+01,  2.0250e+01,
         -2.5938e+00,  5.9500e+01,  2.4512e-01, -5.8800e+02,  7.3500e+01,
         -3.1125e+01, -1.9400e+02,  3.2000e+01, -9.4000e+02,  1.1953e+00,
          3.9000e+02,  4.1000e+02, -6.7812e+00,  6.0625e+00,  3.3250e+01,
          5.0500e+01,  6.9000e+01, -2.7800e+02,  1.0850e+02,  1.9600e+02,
          3.4500e+01, -6.1250e+01]], dtype=torch.bfloat16)

Without bias:

 tensor([[ 2.3875e+01, -5.3750e+01,  5.5600e+02, -5.7250e+01, -1.1500e+02,
          4.7000e+02, -2.9000e+02,  3.1750e+01, -9.5000e+01,  2.0250e+01,
         -3.5625e+00,  6.0500e+01, -7.5195e-02, -5.8800e+02,  7.2500e+01,
         -3.1375e+01, -1.9500e+02,  3.2000e+01, -9.4000e+02,  1.1406e+00,
          3.9000e+02,  4.1000e+02, -7.8750e+00,  5.1562e+00,  3.3000e+01,
          4.9250e+01,  7.0500e+01, -2.7600e+02,  1.0900e+02,  1.9600e+02,
          3.3500e+01, -6.0500e+01]], dtype=torch.bfloat16)

In [8]:
### running this will result in an error
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]).to(dtype=torch.int8))

try:
    
    W8A16LinearLayer(1, 1)
    
except Exception as error:
    print(type(error).__name__, ":", error)

RuntimeError : Only Tensors of floating point and complex dtype can require gradients


When we create an **`nn.Parameter`**, PyTorch expects to be able to compute gradients on it (it is created with `requires_grad=True` by default). Only floating-point and complex tensors can require gradients, so wrapping an **int8 tensor** in a parameter raises the error above.
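The sketch below reproduces the root cause directly on an int8 tensor, and shows that a parameter created with `requires_grad=False` is accepted. This `requires_grad=False` route is an alternative that this notebook does not use; the next cell takes the buffer approach instead.

```python
import torch
import torch.nn as nn

int8_tensor = torch.tensor([0, 1], dtype=torch.int8)

# Asking an int8 tensor to track gradients fails with the same RuntimeError
try:
    int8_tensor.clone().requires_grad_(True)
    raised = False
except RuntimeError as error:
    raised = True
    print("RuntimeError:", error)

# Turning off gradient tracking sidesteps the error
frozen = nn.Parameter(int8_tensor, requires_grad=False)
print(frozen.dtype, frozen.requires_grad)  # torch.int8 False
```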

In [9]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

Instead of storing the int8 weights as an **`nn.Parameter`**, the right approach is to register them with **`register_buffer()`**. A buffer is part of the module's state (it is saved in `state_dict()` and moved along with the module by `.to()`), but it is not a trainable parameter, so no gradients are required and we can initialize it with whatever dtype we want.
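A small check of the buffer behaviour described above, using a hypothetical minimal module `Int8Holder`: the buffers are not listed as parameters, they are saved in `state_dict()`, and a module-wide `.to(torch.bfloat16)` only casts floating-point buffers, leaving the int8 weights untouched.

```python
import torch
import torch.nn as nn

class Int8Holder(nn.Module):  # hypothetical minimal module, for illustration only
    def __init__(self):
        super().__init__()
        self.register_buffer("int8_weights", torch.zeros(4, 2, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(4))

m = Int8Holder()
print([name for name, _ in m.named_parameters()])  # []
print(list(m.state_dict().keys()))                  # ['int8_weights', 'scales']

# Module.to(dtype) casts floating-point buffers only; integer buffers keep their dtype
m = m.to(torch.bfloat16)
print(m.int8_weights.dtype, m.scales.dtype)          # torch.int8 torch.bfloat16
```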

In [10]:
dummy_instance = W8A16LinearLayer(16, 32)

In [11]:
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

torch.Size([32, 16])
torch.Size([32])


Next, we add a forward pass to the class that calls **`w8_a16_forward()`** with the stored buffers.

In [12]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer("int8_weights", torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)

In [13]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(1, 6, 16)

In [14]:
module(dummy_hidden_states).shape

torch.Size([1, 6, 32])

In [15]:
module(dummy_hidden_states).dtype

torch.float32

We now have a working linear layer with a forward pass: an input of shape (1, 6, 16) produces an output of shape (1, 6, 32), in the same dtype as the input.
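To check that the layer is also numerically correct, one option is to compare it against a regular `nn.Linear` loaded with the dequantized weights `int8_weights * scales`. This is a sketch; the seed, sizes and scale range are arbitrary assumptions, and the forward function is repeated so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def w8_a16_forward(weight, input, scales, bias=None):
    # same logic as the function defined earlier in this notebook
    output = F.linear(input, weight.to(input.dtype)) * scales
    if bias is not None:
        output = output + bias
    return output

torch.manual_seed(0)
in_features, out_features = 16, 32
int8_weights = torch.randint(-128, 128, (out_features, in_features)).to(torch.int8)
scales = torch.rand(out_features) * 0.01
bias = torch.randn(1, out_features)

# Reference: a regular nn.Linear holding the dequantized weights
ref = nn.Linear(in_features, out_features)
with torch.no_grad():
    ref.weight.copy_(int8_weights.float() * scales.unsqueeze(1))
    ref.bias.copy_(bias.squeeze(0))

x = torch.randn(1, 6, in_features)
out_quant = w8_a16_forward(int8_weights, x, scales, bias)
out_ref = ref(x)
print(out_quant.shape, torch.allclose(out_quant, out_ref, rtol=1e-4, atol=1e-4))
# torch.Size([1, 6, 32]) True
```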

In [16]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.register_buffer("int8_weights", torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

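    def quantize(self, weights):
        # upcast to fp32 so the absmax and scale computation runs in full precision
        w_fp32 = weights.clone().to(torch.float32)

        # per-channel (per output row) scale: the row's largest |w| maps to 127
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        # q = round(w / s); the clamp keeps values in the int8 range before casting
        # (without it, a value that rounds to 128 would wrap around when cast to int8)
        int8_weights = torch.round(weights/scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)

        # assigning to a registered buffer name replaces that buffer
        self.int8_weights = int8_weights
        self.scales = scales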
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)      

Here we add the quantization method. First we upcast the weights to FP32 and compute one scale per output channel (row), `scale = max(|row|) / 127`, making sure the scales have the same dtype as the input weights. Then we compute the int8 weights as `round(weights / scale)`.
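The same per-channel absmax recipe can be written as a standalone function for experimenting outside the class. This is a hypothetical helper (`absmax_quantize_per_channel` is not part of the notebook) and it differs slightly from `quantize()`: it divides in fp32 by the stored (rounded) scales and clamps to the int8 range before casting.

```python
import torch

def absmax_quantize_per_channel(weights: torch.Tensor):
    """Hypothetical helper: per-output-channel (row-wise) absmax quantization to int8.

    Returns (int8_weights, scales), with the scales stored in weights.dtype.
    """
    w_fp32 = weights.to(torch.float32)
    # one scale per row: the largest absolute value in that row maps to 127
    scales = w_fp32.abs().max(dim=-1).values / 127
    scales = scales.to(weights.dtype)
    # divide in fp32 by the stored scales, round, then clamp to the int8 range
    q = torch.round(w_fp32 / scales.to(torch.float32).unsqueeze(1))
    int8_weights = q.clamp(-128, 127).to(torch.int8)
    return int8_weights, scales

torch.manual_seed(0)
w = torch.randn(4, 8, dtype=torch.bfloat16)
q, s = absmax_quantize_per_channel(w)

# dequantize and measure the per-element round-trip error
dequantized = q.to(torch.float32) * s.to(torch.float32).unsqueeze(1)
error = (w.to(torch.float32) - dequantized).abs()
print(q.dtype, q.shape, s.shape)  # torch.int8 torch.Size([4, 8]) torch.Size([4])
```

Rounding to the nearest integer bounds each element's error by about half its row's scale. An all-zero row would give a zero scale and a division by zero; the sketch does not guard against that.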

In [17]:
module = W8A16LinearLayer(4, 8)

In [18]:
print("Weights before:\n" , module.int8_weights)

Weights before:
 tensor([[  24,  104,   -4,   70],
        [   1,  116,  -41,   90],
        [  85, -125,   43,   49],
        [ -60,   25, -125,   85],
        [  22, -103,  125, -106],
        [-103,  -39,   84, -124],
        [  45,  -84,   55,   -1],
        [ -80,  -81,  -53,  -11]], dtype=torch.int8)


In [19]:
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

In [20]:
module.quantize(random_matrix)

In [21]:
print("Weights After:\n" , module.int8_weights)

Weights After:
 tensor([[ -16,  -23,  127,    1,   26,   11,  -28,    9],
        [  95,  -14,   24,  -23,    4, -105, -127,  -43],
        [  41,  -49,  -88,   46, -128,   72,  -34,  104],
        [ -34,  -53,  -53,   66, -128,   19,    1,  -82]], dtype=torch.int8)


In [22]:
module.scales

tensor([0.0187, 0.0146, 0.0107, 0.0114], dtype=torch.bfloat16)

In [23]:
module.scales.shape

torch.Size([4])

In [24]:
module.int8_weights.shape

torch.Size([4, 8])

In [25]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

tensor([[-0.2988, -0.4297,  2.3750,  0.0187,  0.4863,  0.2051, -0.5234,  0.1680],
        [ 1.3906, -0.2051,  0.3516, -0.3359,  0.0586, -1.5391, -1.8594, -0.6289],
        [ 0.4414, -0.5273, -0.9453,  0.4941, -1.3750,  0.7734, -0.3652,  1.1172],
        [-0.3867, -0.6016, -0.6016,  0.7500, -1.4531,  0.2158,  0.0114, -0.9297]],
       dtype=torch.bfloat16)

In [26]:
### original weights
random_matrix

tensor([[-0.3086, -0.4355,  2.3750,  0.0231,  0.4941,  0.2031, -0.5273,  0.1641],
        [ 1.3906, -0.2031,  0.3516, -0.3340,  0.0608, -1.5391, -1.8594, -0.6250],
        [ 0.4414, -0.5234, -0.9531,  0.4961, -1.3672,  0.7773, -0.3594,  1.1250],
        [-0.3887, -0.6055, -0.5977,  0.7461, -1.4453,  0.2139,  0.0116, -0.9258]],
       dtype=torch.bfloat16)

In [27]:
(random_matrix - module.int8_weights * module.scales.unsqueeze(1)).abs().mean()

tensor(0.0036, dtype=torch.bfloat16)
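Since Step 4 is about studying the impact of naive absmax quantization, here is one scenario to try. It is a sketch: the outlier value of 100 and the helper name `per_row_absmax_error` are hypothetical. A single large value stretches its row's scale, so the other entries in that row lose precision, while the per-channel scheme leaves the other rows untouched.

```python
import torch

def per_row_absmax_error(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-row mean |w - dequant(quant(w))| with per-row absmax int8."""
    scales = w.abs().max(dim=-1, keepdim=True).values / 127  # one scale per row, shape (rows, 1)
    q = torch.round(w / scales).clamp(-128, 127)              # int8 values, kept as float for the math
    return (w - q * scales).abs().mean(dim=-1)                # per-row mean absolute error

torch.manual_seed(0)
w_clean = torch.randn(4, 8)
w_outlier = w_clean.clone()
w_outlier[0, 0] = 100.0  # inject a single large value into row 0 only

err_clean = per_row_absmax_error(w_clean)
err_outlier = per_row_absmax_error(w_outlier)
print("clean:  ", err_clean)
print("outlier:", err_outlier)  # row 0's error grows sharply; rows 1-3 are unchanged
```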