<a href="https://www.kaggle.com/code/aisuko/zero-degradation-matrix-multiplication?scriptVersionId=163020030" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

**Note: all the images are from the blog in the Credits section.**

The practice with Transformers see here [Lighter models on GPU for inference](https://www.kaggle.com/code/aisuko/lighter-models-on-gpu-for-inference/notebook)

The main purpose of the LLM.int8() method is to make large models more accessible without performance degradation.

In the LLM.int8(see the second link in Credits section) paper. It explains:
* Why traditional quantization fails for large models
* The performance deterioration is caused by outlier features
* LLM.int8() algorithm

In essence, LLM.int8() seeks to complete the matrix multiplication computation in three steps:
1. From the input hidden states, extract the outliers(i.e. values that are larger than a certain threshold) by column.
2. Perform the matrix multiplication of the outliers in FP16 and the non-outliers in int8.
3. Dequantize the non-outlier results and add both outlier and non-outlier results together to receive the full result in FP16.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/700/338/979/374/386/original/2133b586980a7691.mp4" width="60%" heigh="60%" alt="mixed-int8"></div>

# The importance of outlier features

A value that is outside the range of some numbers' global distribution is generally referred to as an outlier. Outlier detection has been widely used and covered in the current literature, and having prior knowledge of the distribution of your features helps with the task of outlier detection. More specifically, we have observed that classic quantization at scale fails for transformer-based models > 6B parameters. 

As mentioned earlier, 8-bit precision is extremely constrained, therefore quantizing a vector with several big values can produce widly erroneous results. Additionally, because of a built-in characteristic of the transformer-based architecture that links all the elements together, these errors tend to compound as they get propagated across multiple layers. Therefore, mixed-precision decomposition has been developed to facilitate efficient quantization with such extreme outliers.


# Inside the MatMul

Once the hidden states are computed we extract the outliers using a custom threshold and we decompose the matrix into two parts as explained above, We found that extracting all outliers with magnitude 6 or greater in this way recoveres full inference performance. The outlier part is done in fp16 so it is a classic matrix multiplication, whereas the 8-bit precision using vector-wise multiplication is done by quantizing the weights and hidden states into 8-bit precision using vector-wise quantization --that is, row-wise quantization for the hidden state and column-wise quantization for the weight matrix. After this step, the results are dequantized and returned in half-precision in order to add them to the first matrix multiplication.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/700/455/558/279/509/original/aafbcc617e807c3a.png" width="60%" heigh="60%" alt="matmul"></div>

# How to use it

In [1]:
%%capture
!pip install bitsandbytes==0.41.3

In [2]:
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt



Here we define the model. We can also convert a checkpoint ot model of any precision to 8-bit(FP16, BF16 or FP32) but the input of the model has to be Fp16 for Int8 module to work. So we treat our model here as a fp16 model.

In [3]:
fp16_model=nn.Sequential(
    nn.Linear(64,64),
    nn.Linear(64,64)
)

torch.save(fp16_model.state_dict(), "model.pt")

`has_fp16_weights` is used to train in mixed Int8/Fp16 precision. Here we are interested in memory efficient inference for which we need to use `has_fp16_weights=False`.

In [4]:
# define an int8 model
int8_model=nn.Sequential(
    Linear8bitLt(64,64,has_fp16_weights=False),
    Linear8bitLt(64,64,has_fp16_weights=False)
)

In [5]:
int8_model.load_state_dict(torch.load("model.pt"))
int8_model[0].weight

Parameter containing:
Parameter(Int8Params([[ 0.0810,  0.0623,  0.0093,  ...,  0.0330,  0.0067, -0.1187],
            [ 0.0571,  0.0363,  0.0844,  ...,  0.1002, -0.0044,  0.0729],
            [ 0.0451, -0.0461,  0.0457,  ...,  0.1032, -0.0817,  0.0935],
            ...,
            [ 0.0010,  0.1032,  0.1181,  ..., -0.0441, -0.1196, -0.0173],
            [-0.1142,  0.0183,  0.0183,  ..., -0.0348, -0.1220,  0.0394],
            [ 0.0412, -0.0633,  0.0934,  ..., -0.0913,  0.0324,  0.1136]]))

In [6]:
int8_model=int8_model.to(0) # Quantization happens here
int8_model[0].weight

Parameter containing:
Parameter(Int8Params([[  83,   64,   10,  ...,   34,    7, -122],
            [  60,   38,   88,  ...,  105,   -5,   76],
            [  47,  -49,   48,  ...,  109,  -86,   98],
            ...,
            [   1,  106,  121,  ...,  -45, -123,  -18],
            [-119,   19,   19,  ...,  -36, -127,   41],
            [  42,  -65,   96,  ...,  -94,   33,  116]], device='cuda:0',
           dtype=torch.int8))

The weights values are "truncated" as we have seen when explaning quantization in the [Quantization Technologies](https://www.kaggle.com/code/aisuko/quantization-technologies). Also, the values seem to be distributed between [-127,127]. You might also wonder how to retrieve the FP16 weights in order to perform the outlier MatMul in fp16.

In [7]:
(int8_model[0].weight.CB*int8_model[0].weight.SCB)/127

tensor([[ 0.0808,  0.0613,  0.0095,  ...,  0.0332,  0.0067, -0.1190],
        [ 0.0584,  0.0364,  0.0836,  ...,  0.1024, -0.0048,  0.0741],
        [ 0.0457, -0.0470,  0.0456,  ...,  0.1063, -0.0826,  0.0956],
        ...,
        [ 0.0010,  0.1016,  0.1149,  ..., -0.0439, -0.1182, -0.0176],
        [-0.1158,  0.0182,  0.0180,  ..., -0.0351, -0.1220,  0.0400],
        [ 0.0409, -0.0623,  0.0912,  ..., -0.0917,  0.0317,  0.1132]],
       device='cuda:0')

In [8]:
# We can safely infer using model by making sure the input is in FP16
input_=torch.randn((1,64), dtype=torch.float16)
hidden_states=int8_model(input_.to(torch.device('cuda', 0)))
hidden_states

tensor([[-0.2104, -1.1426, -0.0222,  0.4993, -0.0264, -0.4763,  0.4167, -0.0701,
          0.4368,  0.2043, -0.0095, -0.1930, -0.2064, -0.9932, -0.0107,  0.5508,
         -0.1831,  0.1372,  0.4070, -0.2703,  0.1462, -0.1387, -0.0767,  0.0767,
          0.4946, -0.8608, -0.0306,  1.2139,  0.3298,  0.3669,  0.0539,  0.1787,
          0.1628, -0.0489,  0.1002,  0.2668,  0.3823, -0.4624, -0.5103, -0.5444,
          0.0184,  0.4468, -0.0021, -0.1556,  0.1512,  0.1228,  0.2539, -0.3762,
         -0.4758,  0.5840,  0.2112,  0.3572, -0.0346,  0.1412, -0.3032,  0.1333,
          0.0469, -0.0532, -0.0124, -0.5708,  0.1160, -0.3674,  0.0087, -0.1565]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MatMul8bitLtBackward>)

# Credits

* https://huggingface.co/blog/hf-bitsandbytes-integration?source=post_page-----287da2d5d7f1--------------------------------
* https://arxiv.org/abs/2208.07339