# Build a Custom 8-Bit Quantizer
Here we build a quantizer that can quantize the linear layers of any model to 8-bit precision using a per-channel quantization scheme (W8A16: 8-bit weights, 16-bit activations). The quantizer is modality-agnostic, meaning we can apply it to vision, audio, text, and even multimodal models.<br>
- **Step 1 :-** creating a **`W8A16LinearLayer`** class to store 8-bit weights and scales
- **Step 2 :-** replacing all **`torch.nn.Linear`** layers with **`W8A16LinearLayer`**
- **Step 3 :-** building a quantizer and quantizing a model end-to-end
- **Step 4 :-** testing naive absmax quantization on many scenarios and studying its impact

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
random_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)  # upper bound is exclusive
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

**NOTE:** the weight matrix has shape **(output dimension, input dimension)**. When we multiply the hidden states by the int8 weight matrix, we get an output of shape **(batch size, output dimension)**. So the scales must match the output dimension of the weight matrix (one scale per output channel), and the same holds for the bias.
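A quick shape check makes the note concrete. This sketch reuses the shapes above (32 output features, 16 input features) and also feeds a 3-D `(batch, seq_len, in_features)` input, to show that `(1, out_features)` scales and bias broadcast over the last dimension. The variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

out_features, in_features = 32, 16

# Weight matrix stored as (output dimension, input dimension)
weight_int8 = torch.randint(-128, 128, (out_features, in_features)).to(torch.int8)
scales = torch.randn((1, out_features), dtype=torch.bfloat16)  # one scale per output channel
bias = torch.randn((1, out_features), dtype=torch.bfloat16)    # one bias per output channel

hs_2d = torch.randn((1, in_features), dtype=torch.bfloat16)     # (batch, in_features)
hs_3d = torch.randn((2, 6, in_features), dtype=torch.bfloat16)  # (batch, seq_len, in_features)

w = weight_int8.to(torch.bfloat16)
out_2d = F.linear(hs_2d, w) * scales + bias
out_3d = F.linear(hs_3d, w) * scales + bias

print(out_2d.shape)  # torch.Size([1, 32])
print(out_3d.shape)  # torch.Size([2, 6, 32])
```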

In [3]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

tensor([[-125.0000,  240.0000, -346.0000, -524.0000, -436.0000, -185.0000,
         -150.0000,  -17.6250,  160.0000, -330.0000, -330.0000,  328.0000,
         -225.0000,  418.0000,  580.0000, -136.0000, -122.0000,   31.1250,
           32.2500, -107.5000,  169.0000,   36.0000,  276.0000,  -33.5000,
         -380.0000,  143.0000,   97.0000, -162.0000, -199.0000,   74.0000,
          159.0000, -612.0000]], dtype=torch.bfloat16)

First we cast the int8 weights to the same dtype as the hidden states. Then we perform the matrix multiplication with PyTorch's **`F.linear()`** function.
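`F.linear(input, weight)` computes `input @ weight.T` (plus `bias` when one is passed), which is why the weight is stored as `(out_features, in_features)`. A small float32 check with random, made-up tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 16)                                       # (batch, in_features)
w_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)   # (out_features, in_features)

w = w_int8.to(x.dtype)            # cast int8 -> float32 before the matmul
via_linear = F.linear(x, w)       # computes x @ w.T
via_matmul = x @ w.T

print(via_linear.shape)                                    # torch.Size([3, 32])
print(torch.allclose(via_linear, via_matmul, atol=1e-4))   # True
```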

In [4]:
F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales

tensor([[  -3.7188,  284.0000, -264.0000, -752.0000, -154.0000, -158.0000,
         -131.0000,  -14.3125,  100.5000, -430.0000,  109.0000, -416.0000,
          -32.7500, -103.5000,   24.7500,  -81.5000, -178.0000,   39.5000,
           -9.0000,  -28.5000,  103.0000,    4.5625,  -54.7500,    2.4844,
          124.0000,  -40.5000,   97.0000,  220.0000,   11.6875,  132.0000,
            2.5000,  520.0000]], dtype=torch.bfloat16)

In [5]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

tensor([[  -3.2031,  286.0000, -266.0000, -752.0000, -155.0000, -158.0000,
         -132.0000,  -13.8750,  100.0000, -430.0000,  107.5000, -416.0000,
          -33.0000, -104.0000,   26.2500,  -80.5000, -178.0000,   38.7500,
           -9.4375,  -28.3750,  104.5000,    5.0312,  -55.7500,    0.8750,
          124.0000,  -40.2500,   97.0000,  220.0000,   12.6250,  133.0000,
            3.2188,  520.0000]], dtype=torch.bfloat16)

Then we multiply the result by the per-output-channel scales and optionally add a bias term.
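Scaling after the matmul is valid because each scale belongs to one output channel (one row of the weight matrix), so it can be pulled out of that row's dot product: `x @ (s[:, None] * W).T == (x @ W.T) * s`. The sketch below checks this identity in float32 with random data: dequantizing first and scaling afterwards give the same result up to rounding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)                                       # hidden states
w_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)   # (out, in)
s = torch.rand(32) * 0.01 + 0.001                            # positive per-output-channel scales
b = torch.randn(32)

w = w_int8.to(x.dtype)

# Option A: dequantize the weights first, then run a regular linear layer
dequant_first = F.linear(x, w * s.unsqueeze(1), b)

# Option B: matmul with the int8 values, then scale each output channel, then add bias
scale_after = F.linear(x, w) * s + b

print(torch.allclose(dequant_first, scale_after, atol=1e-4))  # True
```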

In [6]:
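def w8_a16_forward(weight, input, scales, bias=None):
    # Cast the int8 weights to the activation dtype (e.g. bfloat16)
    casted_weights = weight.to(input.dtype)
    # input @ casted_weights.T, then rescale each output channel
    output = F.linear(input, casted_weights) * scales

    # Optionally add the bias
    if bias is not None:
        output = output + bias

    return output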

In [7]:
print("With bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales))

With bias:

 tensor([[  -3.2031,  286.0000, -266.0000, -752.0000, -155.0000, -158.0000,
         -132.0000,  -13.8750,  100.0000, -430.0000,  107.5000, -416.0000,
          -33.0000, -104.0000,   26.2500,  -80.5000, -178.0000,   38.7500,
           -9.4375,  -28.3750,  104.5000,    5.0312,  -55.7500,    0.8750,
          124.0000,  -40.2500,   97.0000,  220.0000,   12.6250,  133.0000,
            3.2188,  520.0000]], dtype=torch.bfloat16)

Without bias:

 tensor([[  -3.7188,  284.0000, -264.0000, -752.0000, -154.0000, -158.0000,
         -131.0000,  -14.3125,  100.5000, -430.0000,  109.0000, -416.0000,
          -32.7500, -103.5000,   24.7500,  -81.5000, -178.0000,   39.5000,
           -9.0000,  -28.5000,  103.0000,    4.5625,  -54.7500,    2.4844,
          124.0000,  -40.5000,   97.0000,  220.0000,   11.6875,  132.0000,
            2.5000,  520.0000]], dtype=torch.bfloat16)


In [8]:
### running this will result in an error
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]).to(dtype=torch.int8))

try:
    
    W8A16LinearLayer(1, 1)
    
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

RuntimeError :  Only Tensors of floating point and complex dtype can require gradients


When we create an **`nn.Parameter`**, PyTorch sets `requires_grad=True` by default, so it expects a tensor it can compute gradients on. Autograd only supports floating-point and complex tensors, not **int8 tensors**, so this raises the error above.
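The error comes from `requires_grad`, not from `nn.Parameter` itself. The sketch below shows that an int8 tensor is rejected with the default `requires_grad=True` but accepted once `requires_grad=False` is passed. Buffers, introduced in the next cell, are still the cleaner choice, since these weights are never meant to be trained.

```python
import torch
import torch.nn as nn

int8_tensor = torch.tensor([0, 1], dtype=torch.int8)

try:
    nn.Parameter(int8_tensor)  # requires_grad=True by default
    default_ok = True
except RuntimeError as error:
    default_ok = False
    print("Default:", error)

# Without gradients, an int8 parameter is allowed
frozen = nn.Parameter(int8_tensor, requires_grad=False)
print(frozen.dtype, frozen.requires_grad)  # torch.int8 False
```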

In [9]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

The right way to store the int8 weights is not as an `nn.Parameter` but with the **`register_buffer()`** method. A buffer is stored with the module (it appears in its `state_dict`), but no gradients are computed for it, so we can initialize it with any dtype we want, including int8.
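To see what a buffer buys us, the sketch below registers an int8 buffer on a tiny module (`TinyInt8Holder` is a hypothetical name used only for illustration) and checks where it shows up: it is part of `state_dict()` (so it is saved and loaded with the model), it is not returned by `parameters()` (so optimizers ignore it), and it keeps its int8 dtype when the module is cast with `.to(torch.bfloat16)`, because `Module.to(dtype)` only casts floating-point tensors.

```python
import torch
import torch.nn as nn

class TinyInt8Holder(nn.Module):
    def __init__(self):
        super().__init__()
        # int8 storage, no gradients needed
        self.register_buffer("int8_weights", torch.zeros((2, 3), dtype=torch.int8))
        # floating-point per-channel scales
        self.register_buffer("scales", torch.ones(2))

m = TinyInt8Holder()
print(list(m.state_dict().keys()))  # ['int8_weights', 'scales']
print(len(list(m.parameters())))    # 0

m = m.to(torch.bfloat16)
print(m.int8_weights.dtype, m.scales.dtype)  # torch.int8 torch.bfloat16
```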

In [10]:
dummy_instance = W8A16LinearLayer(16, 32)

In [11]:
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

torch.Size([32, 16])
torch.Size([32])


Next, we add a forward pass to the class.

In [12]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer("int8_weights", torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)

In [13]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(1, 6, 16)

In [14]:
module(dummy_hidden_states).shape

torch.Size([1, 6, 32])

In [15]:
module(dummy_hidden_states).dtype

torch.float32

The layer now has a working forward pass: a `(1, 6, 16)` input produces a `(1, 6, 32)` output in `float32`.
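The output dtype is decided by PyTorch's type promotion between the activations and the stored `scales`/`bias`. The sketch below repeats the helper's logic on random data: with bfloat16 activations and bfloat16 scales the output stays bfloat16, while float32 scales (the constructor's default `dtype`) promote it to float32. So for a layer that really keeps 16-bit activations, the `dtype` passed to the constructor should match the activation dtype.

```python
import torch
import torch.nn.functional as F

def w8_a16_forward(weight, input, scales, bias=None):
    # Same steps as the helper above: cast, matmul, rescale, optional bias
    output = F.linear(input, weight.to(input.dtype)) * scales
    if bias is not None:
        output = output + bias
    return output

torch.manual_seed(0)
int8_weights = torch.randint(-128, 128, (32, 16)).to(torch.int8)
hidden_bf16 = torch.randn(1, 6, 16, dtype=torch.bfloat16)

# Scales and bias stored in bfloat16 (layer built with dtype=torch.bfloat16)
out_bf16 = w8_a16_forward(int8_weights, hidden_bf16,
                          torch.rand(32).to(torch.bfloat16),
                          torch.randn(1, 32).to(torch.bfloat16))

# Scales and bias stored in float32 (the constructor's default dtype)
out_mixed = w8_a16_forward(int8_weights, hidden_bf16,
                           torch.rand(32), torch.randn(1, 32))

print(out_bf16.dtype, out_mixed.dtype)  # torch.bfloat16 torch.float32
```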

In [16]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        self.register_buffer("int8_weights", torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8))
        
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

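    def quantize(self, weights):
        # Upcast to FP32 before computing the scales
        w_fp32 = weights.clone().to(torch.float32)

        # One scale per output channel (row): max |w| in the row maps to 127
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        # Divide each row by its scale and round. There is no clamp here, so a value
        # that rounds to 128 in low precision overflows when cast to int8.
        int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales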
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)      

Here we add the **`quantize`** method. It first upcasts the weights to FP32, then computes one scale per output channel (row): the row's maximum absolute value divided by 127, cast back to the dtype of the input weights. Finally, it divides each row by its scale and rounds to obtain the int8 weights.
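A worked example on a tiny float32 matrix makes the per-channel absmax steps concrete. The helper `absmax_per_channel` is a hypothetical standalone function that follows the same steps as the method; the values are chosen so that no element lands exactly halfway between two integers.

```python
import torch

def absmax_per_channel(weights: torch.Tensor):
    # One scale per output channel (row): max |w| in that row maps to 127
    w_fp32 = weights.to(torch.float32)
    scales = w_fp32.abs().max(dim=-1).values / 127
    # Divide each row by its own scale and round to the nearest integer
    q = torch.round(w_fp32 / scales.unsqueeze(1)).to(torch.int8)
    return q, scales

W = torch.tensor([[0.4, -1.0, 0.25],
                  [2.0, -0.3, 1.1]])
q, scales = absmax_per_channel(W)
print(q.tolist())  # [[51, -127, 32], [127, -19, 70]]
print(scales)      # about [0.0079, 0.0157], i.e. 1/127 and 2/127

# Dequantize: each element is off by at most half of its row's scale
print((q.float() * scales.unsqueeze(1) - W).abs().max())
```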

In [17]:
module = W8A16LinearLayer(4, 8)

In [18]:
print("Weights before:\n" , module.int8_weights)

Weights before:
 tensor([[  70, -105,   74,   95],
        [  48,  -17,  -91,  -97],
        [  10,  106,  -92, -108],
        [ -44,   70,   -2,  -21],
        [  78,   10,  -43,   31],
        [ -33,  -82, -113,  -31],
        [ -52,   56,  -33,   78],
        [ -15,  -35,  -94,  120]], dtype=torch.int8)


In [19]:
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

In [20]:
module.quantize(random_matrix)

In [21]:
print("Weights After:\n" , module.int8_weights)

Weights After:
 tensor([[  91,   63,  127,   85,   92,  -28,  -98,  -19],
        [ -18,   82,  -55,   -2,   30, -128,   40,  121],
        [   3, -110,   37,    0,   29,  -70,  127,  -73],
        [  76,  -60,   22,  127, -104,  100,   32,  -21]], dtype=torch.int8)


In [22]:
module.scales

tensor([0.0086, 0.0110, 0.0102, 0.0144], dtype=torch.bfloat16)

In [23]:
module.scales.shape

torch.Size([4])

In [24]:
module.int8_weights.shape

torch.Size([4, 8])

In [25]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

tensor([[ 0.7812,  0.5430,  1.0938,  0.7305,  0.7930, -0.2412, -0.8438, -0.1631],
        [-0.1973,  0.9023, -0.6055, -0.0220,  0.3301, -1.4062,  0.4395,  1.3281],
        [ 0.0305, -1.1250,  0.3770,  0.0000,  0.2949, -0.7148,  1.2969, -0.7422],
        [ 1.0938, -0.8633,  0.3164,  1.8281, -1.5000,  1.4375,  0.4609, -0.3027]],
       dtype=torch.bfloat16)

In [26]:
### original weights
random_matrix

tensor([[ 7.8125e-01,  5.4297e-01,  1.0938e+00,  7.3047e-01,  7.8906e-01,
         -2.4316e-01, -8.3984e-01, -1.6602e-01],
        [-2.0020e-01,  9.0625e-01, -6.0156e-01, -2.4292e-02,  3.2617e-01,
          1.3984e+00,  4.4336e-01,  1.3281e+00],
        [ 2.6001e-02, -1.1250e+00,  3.7891e-01,  8.8501e-04,  2.9492e-01,
         -7.1484e-01,  1.2969e+00, -7.4609e-01],
        [ 1.0859e+00, -8.6719e-01,  3.1055e-01,  1.8281e+00, -1.5000e+00,
          1.4375e+00,  4.6484e-01, -3.0469e-01]], dtype=torch.bfloat16)

In [27]:
(random_matrix - module.int8_weights * module.scales.unsqueeze(1)).abs().mean()

tensor(0.0898, dtype=torch.bfloat16)
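Almost all of this mean error comes from a single element. In the second row, the original value `1.3984` came back as `-1.4062` (its int8 value is `-128`). The scale `1.3984 / 127` is rounded down when cast to bfloat16, so `weights / scales` lands just above 127 (about 127.29), which bfloat16 can only represent as `127.5`; `torch.round` then gives 128, which does not fit in int8, and here the cast produced -128. A minimal fix, sketched below as a standalone function rather than as an edit of the method, is to clamp to the int8 range before casting. The name `quantize_per_channel_clamped` is hypothetical.

```python
import torch

def quantize_per_channel_clamped(weights: torch.Tensor):
    """Per-channel absmax int8 quantization with an explicit clamp (hypothetical helper)."""
    w_fp32 = weights.to(torch.float32)
    scales = (w_fp32.abs().max(dim=-1).values / 127).to(weights.dtype)
    # Same division as the method above, in the weights' own dtype
    q = torch.round(weights / scales.unsqueeze(1))
    # Clamp before casting so a value like 128 cannot overflow the int8 range
    int8_weights = q.clamp(-128, 127).to(torch.int8)
    return int8_weights, scales

# Second row of the example above, in bfloat16 (max |w| = 1.3984 at index 5)
row = torch.tensor([[-0.2002, 0.9062, -0.6016, -0.0243,
                     0.3262, 1.3984, 0.4434, 1.3281]], dtype=torch.bfloat16)
q, s = quantize_per_channel_clamped(row)
print(torch.round(row / s.unsqueeze(1))[0, 5].item())  # 128.0 before clamping
print(q[0, 5].item())                                  # 127 after clamping
```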

# Quantization Pipeline
Replace all of the `torch.nn.Linear` layers with the `W8A16LinearLayer` layer. Call `quantize` on the linear layers using the original weights.

In [28]:
def replace_linear_with_target(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            old_bias = child.bias

            new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
            setattr(module, name, new_module)
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(child, target_class, module_name_to_exclude)

The function takes three arguments: the model (`module`), `target_class` (the class that will replace each linear layer), and `module_name_to_exclude` (the names of child modules to leave untouched). Keeping the last module unquantized usually gives better results.<br>
We loop over the module's **named children**. If a child is an instance of **`nn.Linear`** and its name is not in `module_name_to_exclude`, we replace it. First we store the child's bias in **`old_bias`**, since we need it to create and populate the new layer.<br>
Then we create the new module from `target_class` with the same `in_features` and `out_features` as the linear layer, a flag saying whether it has a bias, and the same dtype as the child's weights.<br>
Next, **`setattr()`** replaces the attribute called `name` on `module` with `new_module`.<br>
If the old layer had a bias, we explicitly set the bias of the new module to `old_bias`.<br>
Otherwise (the child is not a linear layer to replace), we call the function recursively on the child with the same arguments, so nested submodules are handled too.
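Because the function iterates over `named_children()`, the names it compares against `module_name_to_exclude` are local attribute names such as `"lm_head"`, not dotted paths such as `"block.0"`. A linear layer nested inside an `nn.Sequential` is seen under its index (`"0"`, `"1"`, ...) once the recursion reaches it. The small model below is hypothetical and only shows which names each call sees.

```python
import torch.nn as nn

model = nn.Module()
model.block = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
model.lm_head = nn.Linear(4, 2)

# Names seen by the top-level call
top_level = [name for name, _ in model.named_children()]
print(top_level)  # ['block', 'lm_head']

# Names seen by the recursive call on model.block
nested = [name for name, _ in model.block.named_children()]
print(nested)     # ['0', '1', '2']

# Full dotted paths, for comparison (named_modules includes the root as '')
all_paths = [name for name, _ in model.named_modules()]
print(all_paths)  # ['', 'block', 'block.0', 'block.1', 'block.2', 'lm_head']
```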

In [29]:
class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1, 1)
    # Try with bias
    self.linear_1 = nn.Linear(1, 1)
    # Try without bias
    self.linear_2 = nn.Linear(1, 1, bias=False)
    # Lm prediction head
    self.lm_head = nn.Linear(1, 1, bias=False)

For testing purposes, we create a dummy model with an embedding, two linear layers (one with a bias, one without), and a language-model head, which is usually the last module in a transformer model.

In [30]:
model_1 = DummyModel()
model_2 = DummyModel()

The function modifies the model in place, so we create two instances: one to test the `module_name_to_exclude` feature, and one in which every linear layer is replaced.

In [31]:
replace_linear_with_target(model_1, W8A16LinearLayer, ["lm_head"])
print(model_1)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)


Because we passed `["lm_head"]` as the exclusion list, the last layer stays a regular `nn.Linear`.

In [32]:
replace_linear_with_target(model_2, W8A16LinearLayer, [])
print(model_2)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)


With an empty exclusion list, every linear layer, including `lm_head`, is replaced.

### Linear Layer Replacement + Quantization
- Modify the `replace_linear_with_target` function to also perform quantization.
- Implement `replace_linear_with_target_and_quantize`.

In [33]:
def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
            setattr(module, name, new_module)

            # Quantize the new layer using the original layer's weights
            getattr(module, name).quantize(old_weight)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)

This is the same function with one addition: right after `setattr()`, we fetch the new module and call its `quantize` method with the original layer's weights.

In [34]:
model_3 = DummyModel()

In [35]:
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, ["lm_head"])
print(model_3)

DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)
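The printout confirms which layers were replaced, but not whether the model still behaves the same. The sketch below condenses the pieces above into one self-contained cell and compares a toy model's outputs before and after replacement. `TinyMLP` is a hypothetical model, and two details differ from the code above by our choice: `quantize` runs under `torch.no_grad()` and clamps to [-128, 127] before the int8 cast. Expect a small but non-zero relative error; the exact value depends on the random weights.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def w8_a16_forward(weight, input, scales, bias=None):
    output = F.linear(input, weight.to(input.dtype)) * scales
    if bias is not None:
        output = output + bias
    return output

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.register_buffer("int8_weights", torch.zeros((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
        if bias:
            self.register_buffer("bias", torch.zeros((1, out_features), dtype=dtype))
        else:
            self.bias = None

    @torch.no_grad()  # our choice: no autograd history for the stored buffers
    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)
        scales = (w_fp32.abs().max(dim=-1).values / 127).to(weights.dtype)
        q = torch.round(weights / scales.unsqueeze(1))
        self.int8_weights = q.clamp(-128, 127).to(torch.int8)  # clamp is our addition
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)

def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            new_module = target_class(child.in_features, child.out_features,
                                      child.bias is not None, child.weight.dtype)
            setattr(module, name, new_module)
            new_module.quantize(child.weight)
            if child.bias is not None:
                new_module.bias = child.bias
        else:
            replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)

class TinyMLP(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.lm_head = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        return self.lm_head(F.relu(self.fc2(F.relu(self.fc1(x)))))

torch.manual_seed(0)
original = TinyMLP()
quantized = copy.deepcopy(original)
replace_linear_with_target_and_quantize(quantized, W8A16LinearLayer, ["lm_head"])

x = torch.randn(8, 64)
with torch.no_grad():
    ref, out = original(x), quantized(x)
rel_err = ((ref - out).norm() / ref.norm()).item()
print(type(quantized.fc1).__name__, type(quantized.lm_head).__name__)  # W8A16LinearLayer Linear
print(f"relative output error: {rel_err:.4f}")
```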
