# L4-A - Building your own Quantizer: Custom Build an 8-Bit Quantizer

In this lesson, you will learn how to compress any model in 8-bit precision.

## Step 1: class `W8A16LinearLayer`

- Build the target class, `W8A16LinearLayer()`, that will be responsible for quantizing your model.


Start by building the W8A16 linear layer class. We'll break this down into several subtasks.

1. Build a forward method, called `w8_a16_forward`. This will take as input:
   - 8-bit weights,
   - 16-bit inputs,
   - scales, and
   - an optional bias.

2. Once you have built this method, the idea is to call it inside the linear layer's forward pass with the layer's parameters: the 8-bit weights of the linear layer, the input, the scales that are stored inside the layer, and the optional bias as well.

Under the hood, `w8_a16_forward` first casts the 8-bit weights to the same data type as the input.
For example, if the input is in float16 or bfloat16, the weights are cast to that precision while keeping
their values in the same range as before, i.e. between -128 and 127.
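The casting step described above does not change any weight value, because every integer in [-128, 127] is exactly representable in both float16 and bfloat16. A minimal sketch that checks this by round-tripping all 256 possible int8 values (the variable names `all_int8` and `lossless` are illustrative, not part of the lesson):

```python
import torch

# Every value an int8 tensor can hold: -128, -127, ..., 127.
all_int8 = torch.arange(-128, 128).to(torch.int8)

# Cast to each 16-bit float type and back; record whether anything changed.
lossless = {}
for dtype in (torch.float16, torch.bfloat16):
    casted = all_int8.to(dtype)
    lossless[dtype] = torch.equal(casted.to(torch.int8), all_int8)

print(lossless)
```

If both entries are `True`, the cast only changes the storage type, not the stored numbers.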

### 1.1 - `w8_a16_forward` Function

-
```Python
W8A16LinearLayer
                    # 8-bit  # 16-bit         # optional
* w8_a16_forward -> weights, input,   scales, bias=None
                    
```
- Cast the 8-bit `weights` to the same data type as the `input` (the "casted weights"),
- keeping the "casted weights" in the same range as before, [-128, 127].
- Next, compute $$((\text{input} \cdot \text{casted weights}^{T}) * \text{scales}) + \text{bias}$$

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


Below, some random inputs are defined:

- a random int8 matrix,
- random hidden states, 
- random scales, and 
- random bias.

In [4]:
random_int8 = torch.randint(-128, 127, (32, 16)).to(torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

**Note:** Since the values are random, what you see when running the code yourself might differ from what is presented in this notebook.

Typically, the workflow is as follows.

We first cast the weights into the same data type as the hidden states. \
Then we perform the matrix multiplication by calling `F.linear` from PyTorch. \
In mathematical terms: random_hs * random_int8^T.

Then we multiply the result by the scales and optionally add a bias term at the end of the operation.

Notice that the weight matrix has the shape (output dimension, input dimension).

When you perform the matrix multiplication between the input hidden states and the transposed weight matrix, you get a tensor of shape (batch size, output dimension).
So (1, 32).

So it's important that the scales have the same shape as the output of the matrix multiplication.
The same applies to the bias, so that both the scales and the bias can be broadcast against the output.
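The shape reasoning above can be checked directly. The following sketch uses float32 tensors (an assumption made only to keep the numerical comparison simple) and confirms that `F.linear(x, W)` computes `x @ W.T`, and that scales and bias of shape (1, 32) broadcast against the (1, 32) output:

```python
import torch
import torch.nn.functional as F

batch_size, in_features, out_features = 1, 16, 32

weight = torch.randn(out_features, in_features)  # (32, 16)
hidden = torch.randn(batch_size, in_features)    # (1, 16)
scales = torch.randn(1, out_features)            # (1, 32)
bias = torch.randn(1, out_features)              # (1, 32)

# F.linear multiplies the input by the transposed weight matrix.
out = F.linear(hidden, weight)
same_as_matmul = torch.allclose(out, hidden @ weight.T, atol=1e-4)

# Scales and bias match the output shape, so both operations are elementwise.
scaled = out * scales + bias

print(out.shape, scaled.shape, same_as_matmul)
```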

In [5]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

tensor([[ 290.0000,  -73.0000, -181.0000,  320.0000, -193.0000,  143.0000,
          -69.0000,  132.0000, -712.0000,  171.0000, -266.0000,  -46.0000,
          128.0000, -258.0000,  320.0000,  274.0000,  169.0000,  -75.0000,
          181.0000,  428.0000,  211.0000,  -18.0000,  576.0000,    8.9375,
         -228.0000,  354.0000,  700.0000,  532.0000,  604.0000,   25.7500,
          412.0000,  -11.1250]], dtype=torch.bfloat16)

In [6]:
F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales

tensor([[ 182.0000,   35.0000,  253.0000, -244.0000,  302.0000,  -68.5000,
           52.5000,  206.0000, -298.0000,    3.9219,  -12.0000,   46.0000,
         -121.5000,  -44.2500,  422.0000, -218.0000, -176.0000,    4.6250,
         -338.0000, -185.0000,  204.0000,  -18.1250,  632.0000,   -5.9062,
         -324.0000,   49.7500, -584.0000, -324.0000,  792.0000,   -3.3125,
          880.0000,   -2.6406]], dtype=torch.bfloat16)

In [7]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

tensor([[ 183.0000,   34.5000,  254.0000, -244.0000,  302.0000,  -69.5000,
           52.7500,  205.0000, -296.0000,    3.7656,  -14.2500,   46.7500,
         -121.0000,  -45.0000,  424.0000, -217.0000, -174.0000,    5.9375,
         -338.0000, -184.0000,  204.0000,  -18.1250,  632.0000,   -5.6875,
         -324.0000,   49.5000, -580.0000, -324.0000,  792.0000,   -2.4062,
          880.0000,   -3.1094]], dtype=torch.bfloat16)

- Implement all this as a function, `w8_a16_forward`

In [8]:
def w8_a16_forward(weight, input, scales, bias=None):
    
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales
    
    if bias is not None:
        output = output + bias
      
    return output

In [9]:
print("With bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales))

With bias:

 tensor([[ 183.0000,   34.5000,  254.0000, -244.0000,  302.0000,  -69.5000,
           52.7500,  205.0000, -296.0000,    3.7656,  -14.2500,   46.7500,
         -121.0000,  -45.0000,  424.0000, -217.0000, -174.0000,    5.9375,
         -338.0000, -184.0000,  204.0000,  -18.1250,  632.0000,   -5.6875,
         -324.0000,   49.5000, -580.0000, -324.0000,  792.0000,   -2.4062,
          880.0000,   -3.1094]], dtype=torch.bfloat16)

Without bias:

 tensor([[ 182.0000,   35.0000,  253.0000, -244.0000,  302.0000,  -68.5000,
           52.5000,  206.0000, -298.0000,    3.9219,  -12.0000,   46.0000,
         -121.5000,  -44.2500,  422.0000, -218.0000, -176.0000,    4.6250,
         -338.0000, -185.0000,  204.0000,  -18.1250,  632.0000,   -5.9062,
         -324.0000,   49.7500, -584.0000, -324.0000,  792.0000,   -3.3125,
          880.0000,   -2.6406]], dtype=torch.bfloat16)
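It may help to see why the scales can be applied after the matrix multiplication. With one scale per output channel, scaling output `i` by `s_i` gives the same result as first multiplying row `i` of the weights by `s_i` (that is, dequantizing) and then running the linear layer. A small sketch comparing both paths, using float64 purely so that rounding differences stay negligible (the names `scale_after` and `scale_before` are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
int8_weights = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
hidden = torch.randn(1, 16, dtype=torch.float64)
scales = torch.rand(32, dtype=torch.float64)

# Path 1 (as in w8_a16_forward): matmul on the casted weights, then scale.
scale_after = F.linear(hidden, int8_weights.to(hidden.dtype)) * scales

# Path 2: dequantize the weights row by row, then matmul.
dequantized = int8_weights.to(hidden.dtype) * scales.unsqueeze(1)
scale_before = F.linear(hidden, dequantized)

max_diff = (scale_after - scale_before).abs().max().item()
print(max_diff)
```

The first path avoids materializing a full dequantized weight matrix, which is one reason to scale the output instead of the weights.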


### 1.2 - `init` Function of class `W8A16LinearLayer`

The next building block will leverage the method that we have just created. To continue
**building our linear layer class**, we'll start implementing the init method of that class.

Recall that for this linear layer we need to store the int8 weights, the scales and the bias.

Let's start by implementing the skeleton of the init method. \
It should roughly match the signature of the init method of a torch linear layer, so it has to take the input features and output features
in order to properly initialize the class.

ref. https://docs.pytorch.org/docs/stable/generated/torch.nn.Linear.html

The int8_weights tensor could be created as `nn.Parameter(...)`. The issue with this is that PyTorch cannot yet compute gradients on int8 tensors. Therefore, if you try to initialize the class W8A16LinearLayer with this approach, it raises an error saying "Only Tensors of floating point and complex dtype can require gradients".

- This is the signature of the `init` of the [PyTorch Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear):
```Python
def __init__(self, in_features, out_features, bias=True,
             device=None, dtype=None)

```

In [10]:
### running this will result in an error
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]
                                     ).to(dtype=torch.int8))

try:
    
    W8A16LinearLayer(1, 1)
    
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

RuntimeError :  Only Tensors of floating point and complex dtype can require gradients


The right approach to store int8 weights is, instead of saving them as an `nn.Parameter`, to call the method `register_buffer`.

That way, instead of storing a parameter, we just store a buffer, meaning we don't need to compute gradients on the tensor,
and you can initialize it with whatever dtype you want.
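To make the difference between a buffer and a parameter concrete, here is a minimal sketch with a hypothetical module, `BufferDemo`, that is not part of the lesson. It registers an int8 buffer and then inspects how PyTorch treats it:

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # An int8 tensor is fine here because buffers never require gradients.
        self.register_buffer("int8_weights",
                             torch.zeros((4, 2), dtype=torch.int8))


demo = BufferDemo()

num_parameters = len(list(demo.parameters()))    # buffers are not parameters
state_keys = list(demo.state_dict().keys())      # but they are saved with the model
requires_grad = demo.int8_weights.requires_grad  # and they carry no gradient

print(num_parameters, state_keys, requires_grad)
```

The buffer is still part of the module's state, so it is saved in the `state_dict` and follows the module when it is moved to another device, but an optimizer iterating over `parameters()` will never see it.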


Let's continue designing the linear layer.

We have our int8 weights.

We'll do the same thing for scales as well, by initializing them with the correct shape. \
Thus, we're also going to call `register_buffer` on scales since, again, here we're just expecting to do simple inference. We're not interested in doing training. Therefore, just calling `register_buffer` is sufficient (there's no need for `nn.Parameter`).

We're also going to store an optional bias.

In [11]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

- Test your implementation.

In [12]:
dummy_instance = W8A16LinearLayer(16, 32)

In [13]:
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

torch.Size([32, 16])
torch.Size([32])


### 1.3 - `forward` Function of class `W8A16LinearLayer`

- Use the `w8_a16_forward` defined earlier (Step 1.1) to define the `forward` function.

The next step is building the forward pass of that class.

Copy the code created for the W8A16LinearLayer class.

Call the function defined in the first subtask, `w8_a16_forward`, passing the input together with
- self.int8_weights, 
- self.scales, and 
- self.bias.

This method will do everything under the hood for us.

It will take care of casting the weights into the correct dtype, multiplying
the result by the scales and optionally adding the bias.

In [14]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)

In [15]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(1, 6, 16)

In [16]:
module(dummy_hidden_states).shape

torch.Size([1, 6, 32])

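Check the data type.

The dtype is as expected: `torch.float32`.

This is because we initialized the input as a random tensor without specifying a dtype, and by default PyTorch creates floating-point tensors in `torch.float32` (the layer's `dtype` argument also defaults to `torch.float32`).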

In [18]:
module(dummy_hidden_states).dtype

torch.float32
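One detail worth keeping in mind when checking dtypes: the output dtype depends on the scales (and bias) as well as on the input. Under PyTorch's type promotion rules, multiplying a bfloat16 tensor by a float32 tensor yields float32. The sketch below reuses `w8_a16_forward` from Step 1.1 and compares scales stored in bfloat16 with scales stored in float32; the variable names are illustrative:

```python
import torch
import torch.nn.functional as F


def w8_a16_forward(weight, input, scales, bias=None):
    # Same function as in Step 1.1.
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales
    if bias is not None:
        output = output + bias
    return output


int8_weights = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
hidden_states = torch.randn(1, 6, 16, dtype=torch.bfloat16)

scales_bf16 = torch.rand(32).to(torch.bfloat16)
scales_fp32 = torch.rand(32)  # float32 by default

out_matched = w8_a16_forward(int8_weights, hidden_states, scales_bf16)
out_promoted = w8_a16_forward(int8_weights, hidden_states, scales_fp32)

print(out_matched.dtype)   # stays in the input dtype
print(out_promoted.dtype)  # promoted because the scales are float32
```

This suggests creating the layer with the same `dtype` as the activations it will receive if the output should stay in 16-bit precision.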

### 1.4 - `quantize` Function of class `W8A16LinearLayer`

- `quantize` function will dynamically quantize half-precision weights into `torch.int8`

Now that we have **a forward pass** that is working and \
**a linear layer class** that has all the needed attributes, \
we need to **build the quantize method** in order to perform the linear quantization algorithm, so that the weights get correctly quantized.

Right now everything is random, so if you replaced all the layers of a model with this linear layer, you would most likely get gibberish output.

Once we have defined that quantize method, the workflow will be the following:
- You have your base model that is, let's say, in half precision, so either FP16 or BF16.
- We loop over all the linear layers,
    - replace each of them with our new linear layer class, and then
    - call quantize, passing the old weights, in order to quantize them into int8.
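The workflow described above could be sketched as follows. This is only an illustration under stated assumptions: `replace_linear_with_w8a16` is a hypothetical helper name, the class is a condensed copy of the one built in this lesson (with placeholder zeros and ones instead of random initial values), and a small float32 `nn.Sequential` stands in for a real half-precision model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()
        self.register_buffer(
            "int8_weights",
            torch.zeros((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
        if bias:
            self.register_buffer(
                "bias", torch.zeros((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)
        self.int8_weights = torch.round(
            weights / scales.unsqueeze(1)).to(torch.int8)
        self.scales = scales

    def forward(self, input):
        output = F.linear(input, self.int8_weights.to(input.dtype))
        output = output * self.scales
        return output if self.bias is None else output + self.bias


def replace_linear_with_w8a16(module):
    """Hypothetical helper: swap every nn.Linear for a quantized layer."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            has_bias = child.bias is not None
            new_layer = W8A16LinearLayer(
                child.in_features, child.out_features,
                bias=has_bias, dtype=child.weight.dtype)
            # Quantize the old weights into int8 plus per-row scales.
            new_layer.quantize(child.weight.detach())
            if has_bias:
                # Keep the original bias, reshaped to (1, out_features).
                new_layer.bias = child.bias.detach().clone().reshape(1, -1)
            setattr(module, name, new_layer)
        else:
            replace_linear_with_w8a16(child)  # recurse into submodules


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
inputs = torch.randn(2, 16)

with torch.no_grad():
    reference = model(inputs)          # output of the original model
    replace_linear_with_w8a16(model)   # modifies the model in place
    quantized = model(inputs)          # output after quantization

max_abs_diff = (reference - quantized).abs().max().item()
print(model)
print("max abs difference:", max_abs_diff)
```

Comparing the outputs before and after the replacement gives a quick sanity check that the quantized layers approximate the original ones.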

In order to quantize the weights and calculate the scales:
- Get the absolute values of the weights. 
- Get the maximum on the last dimension and divide it by 127 in order to get the scales.

Make sure that scales has the same datatype as the input weights.

This is **per channel linear quantization**, as you're taking the maximum over the last dimension, which gives one scale per row of the weight matrix.

The int8 weights are then obtained by dividing each row of the weights by its scale and rounding the result.
Note that this is based on what we've learned so far.

We then simply assign these tensors to `self.int8_weights` and `self.scales`.

The forward pass will stay the same.
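The scale computation described above can be followed by hand on a tiny tensor. The values in this sketch are chosen (as an illustration, not taken from the lesson) so that the row maxima are 1.27 and 2.54, which gives scales of roughly 0.01 and 0.02:

```python
import torch

weights = torch.tensor([[0.5, -1.27],
                        [2.54, 0.02]])

# One scale per row: the largest absolute value in the row maps to 127.
scales = weights.abs().max(dim=-1).values / 127

# Divide each row by its scale and round to the nearest integer.
int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

print(scales)        # approximately [0.01, 0.02]
print(int8_weights)  # [[50, -127], [127, 1]]
```

In each row the entry with the largest magnitude lands on 127 or -127, and the other entries are placed proportionally.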

In [19]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

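    def quantize(self, weights):
        # It's recommended to upcast the weights in fp32 for stability.
        w_fp32 = weights.clone().to(torch.float32)

        # One scale per row: maximum absolute value divided by 127.
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        # Divide each row by its scale and round to the nearest integer.
        int8_weights = torch.round(weights
                        /scales.unsqueeze(1)).to(torch.int8)

        # Overwrite the registered buffers with the quantized values.
        self.int8_weights = int8_weights
        self.scales = scales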
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)      


Let's try that out by first initializing a dummy module.

In [20]:
module = W8A16LinearLayer(4, 8)

In [21]:
print("Weights before:\n" , module.int8_weights)

Weights before:
 tensor([[  -6,    9,  -66,   86],
        [-124,  116, -121,   95],
        [  51,  -79,  101,   23],
        [ -45,   17, -102,   -4],
        [  77,  106,  -77, -122],
        [ -74,  -23, -122,   99],
        [ -84,  126,  115, -104],
        [ -97,  -76,   40,    1]], dtype=torch.int8)


In [31]:
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

In [32]:
module.quantize(random_matrix)

In [33]:
print("Weights After:\n" , module.int8_weights)

Weights After:
 tensor([[  85,   45,   56,  -16, -127,  -21,  -68,   -8],
        [   2,  127,   -4,  -58,    5,  -58,   24,  -28],
        [ -14,  -31,   44,  -24,   30, -127,    0,   42],
        [  83,   18,   36,    0,  -53,  -23,    0, -126]], dtype=torch.int8)


In [34]:
module.scales

tensor([0.0167, 0.0144, 0.0161, 0.0124], dtype=torch.bfloat16)

In [26]:
module.scales.shape

torch.Size([4])

In [35]:
module.int8_weights.shape

torch.Size([4, 8])

If we directly multiply the int8_weights by the scales, it won't work: their shapes, (4, 8) and (4,), cannot be broadcast together. We have to reshape the scales to (4, 1) first, using `unsqueeze(1)`.
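A short sketch of that broadcasting issue, using freshly generated tensors with the same shapes as in the cells above (the names and values are illustrative):

```python
import torch

int8_weights = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
scales = torch.rand(4).to(torch.bfloat16)

# Broadcasting aligns trailing dimensions: 8 (weights) vs 4 (scales) clash.
try:
    int8_weights * scales
    broadcast_failed = False
except RuntimeError as error:
    broadcast_failed = True
    print("Direct multiplication failed:", error)

# After unsqueeze(1) the scales have shape (4, 1): one scale per row.
dequantized = int8_weights * scales.unsqueeze(1)
print(dequantized.shape, dequantized.dtype)
```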

In [36]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

tensor([[ 1.4219,  0.7539,  0.9375, -0.2676, -2.1250, -0.3516, -1.1406, -0.1338],
        [ 0.0288,  1.8281, -0.0576, -0.8359,  0.0723, -0.8359,  0.3457, -0.4023],
        [-0.2256, -0.5000,  0.7109, -0.3867,  0.4844, -2.0469,  0.0000,  0.6758],
        [ 1.0312,  0.2227,  0.4453,  0.0000, -0.6562, -0.2852,  0.0000, -1.5625]],
       dtype=torch.bfloat16)

In [37]:
### original weights
random_matrix

tensor([[ 1.4219e+00,  7.5391e-01,  9.4531e-01, -2.6172e-01, -2.1250e+00,
         -3.4961e-01, -1.1328e+00, -1.2793e-01],
        [ 3.4912e-02,  1.8281e+00, -6.1279e-02, -8.3203e-01,  6.8359e-02,
         -8.2812e-01,  3.4570e-01, -4.0430e-01],
        [-2.2852e-01, -5.0000e-01,  7.1094e-01, -3.9258e-01,  4.7852e-01,
         -2.0469e+00,  0.0000e+00,  6.8359e-01],
        [ 1.0312e+00,  2.2559e-01,  4.5117e-01,  3.7956e-04, -6.5625e-01,
         -2.8711e-01,  0.0000e+00, -1.5703e+00]], dtype=torch.bfloat16)

Calculate the quantization error.

In [38]:
(random_matrix - module.int8_weights 
 * module.scales.unsqueeze(1)).abs().mean()

tensor(0.0031, dtype=torch.bfloat16)
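As a sanity check on the size of this error: with round-to-nearest, each dequantized weight should differ from the original by at most half of its row's scale. The sketch below verifies that bound in float32 (an assumption made to keep bfloat16 rounding out of the picture; the small `1e-6` slack only absorbs floating-point noise):

```python
import torch

torch.manual_seed(0)
weights = torch.randn(4, 8)  # float32 stand-in for the original weights

# Same per-row quantization as in the quantize method.
scales = weights.abs().max(dim=-1).values / 127
int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)
dequantized = int8_weights * scales.unsqueeze(1)

abs_error = (weights - dequantized).abs()
per_row_bound = scales.unsqueeze(1) / 2  # half a quantization step per row

within_bound = bool((abs_error <= per_row_bound + 1e-6).all())
print(abs_error.mean().item(), within_bound)
```

Rows with a larger maximum absolute value get a larger scale, and therefore a looser error bound.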