# FP4 Quantization

## High-Level Overview of FP4 Quantization for LLMs
FP4 (4-bit floating-point) quantization is a technique for significantly reducing the memory footprint and computational cost of Large Language Models (LLMs). The core idea is to represent the high-precision floating-point numbers (like FP16 or FP32) that make up a model's weights and activations with a much smaller 4-bit floating-point format. This lets a large model fit in a fraction of the memory and, because LLM inference is often limited by memory bandwidth, it can also make computation faster.

Here's a high-level breakdown of how it works:

- The Challenge: LLMs are huge. A model like LLaMA-7B has 7 billion parameters, each typically stored as a 16-bit (2-byte) floating-point number. That is roughly 14 GB of VRAM for the weights alone, which limits who can run these models. The goal of quantization is to reduce this number.

- The Idea: Instead of using 16 bits to represent each number, we use only 4 bits. A 4-bit number has only 2^4 = 16 bit patterns, so it can represent far fewer distinct values, over a much smaller range. This is a trade-off: we save memory and compute, but we lose precision. The challenge is to do this without degrading the model's performance too much.
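The 14 GB figure is plain arithmetic: parameters x bits per parameter / 8 gives bytes. A small helper (the name `weight_memory_gb` is hypothetical) that counts weight storage only, ignoring activations, the KV cache and any per-block scale factors:

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Hypothetical helper: storage for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


llama_7b = 7e9
print(weight_memory_gb(llama_7b, 16))  # FP16: 14.0
print(weight_memory_gb(llama_7b, 4))   # 4-bit: 3.5
```

Going from 16 to 4 bits cuts weight storage by 4x, before the overhead of the scale factors that quantization also has to store.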

The Core Process: Quantization and De-Quantization:

- Quantization: When a model is loaded, its high-precision weights are converted to the 4-bit format. This involves a scaling factor and a data type conversion. The key is to find the right scaling factor that minimizes the loss of information.

- De-Quantization (on-the-fly): During a forward pass (inference), the 4-bit weights are loaded from memory. However, the matrix multiplication (the core operation in a Transformer's linear layer) usually runs in a higher precision such as FP16, because GPU hardware often has no native support for the 4-bit format. So the 4-bit weights are de-quantized back to that precision on the fly, the matrix multiplication is performed in it, and the output stays in that precision.

- Handling Outliers: A major issue with quantizing LLMs is the presence of "outliers": a few values in the weight or activation tensors that are much larger than the rest. With a single scale for the whole tensor, the scale is dominated by these outliers and the remaining values lose almost all their precision. bitsandbytes' 4-bit types (FP4 and NF4) mitigate this with blockwise quantization: the tensor is split into small blocks (64 values by default), each with its own scale stored in higher precision, so an outlier only degrades the precision of its own block. (Keeping outlier features in 16-bit while quantizing the rest, a "mixed-precision" decomposition, is the approach of bitsandbytes' 8-bit LLM.int8() method.)
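A small experiment makes the outlier problem and the blockwise fix concrete. It is a sketch under stated assumptions: it reuses the 15-value FP4 grid that the `FP4_E2M1` class later in this notebook produces, the helper name `fake_quantize` is hypothetical, and, unlike the classes below, it maps each block's absmax onto 6.0 (the largest FP4 value) so the whole grid is used:

```python
import torch

# 15-value FP4 grid produced by the FP4_E2M1 class later in this notebook
FP4_GRID = torch.tensor([-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.75, 0.0,
                         0.75, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quantize(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Hypothetical helper: round x to FP4 and back, one absmax scale per block."""
    blocks = x.reshape(-1, block_size)  # requires x.numel() % block_size == 0
    # Map each block's absmax onto the largest FP4 magnitude (6.0)
    scales = blocks.abs().amax(dim=1, keepdim=True) / FP4_GRID.max()
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)  # avoid 0/0
    # (blocks, block_size, 1) - (15,) -> distances to every grid value
    idx = torch.argmin(((blocks / scales).unsqueeze(-1) - FP4_GRID).abs(), dim=-1)
    return (FP4_GRID[idx] * scales).reshape(x.shape)


torch.manual_seed(0)
w = torch.randn(1024)
w[10] = 100.0  # inject a single outlier

err_per_tensor = (fake_quantize(w, block_size=w.numel()) - w).abs().mean()
err_blockwise = (fake_quantize(w, block_size=64) - w).abs().mean()
print(f"per-tensor: {err_per_tensor.item():.3f}  blockwise(64): {err_blockwise.item():.3f}")
```

With one scale for the whole tensor, the outlier itself is reconstructed well but every ordinary value rounds to 0; with 64-value blocks, only the outlier's own block suffers.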

The key components of a bitsandbytes-like implementation are:

## 4-bit Floating-Point Data Type
bitsandbytes (bnb) defines NF4 (NormalFloat4), which we'll come back to later. For now, let's use the standard FP4.
In the FP4 format, each value is represented as:

|sign|exp|exp|mantissa| (E2M1)

$$v = (-1)^{s} \cdot \left(1 + \frac{m}{2}\right) \cdot 2^{\,e - \text{bias}}$$

- $s$: sign bit (0 gives +1, 1 gives -1)
- $m$: mantissa bit, 0 or 1
- $e$: exponent bits, 00, 01, 10 or 11 (with bias = 1)

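The bias shifts the exponent range toward smaller values, so the format covers magnitudes around and below 1 instead of only large ones. In the standard E2M1 encoding, exponent 00 is reserved for zero and the subnormal value ±0.5, $(-1)^{s} \cdot \frac{m}{2} \cdot 2^{1-\text{bias}}$, which fills the gap between 0 and the smallest normal value, 1. Precision near zero matters in deep learning because most weights and activations are small.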

So the total representable range is:

$$\left[-(1 + 0.5) \cdot 2^{3-1},\ +(1 + 0.5) \cdot 2^{3-1}\right] = [-6,\ 6]$$
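To see every value the 4 bits can encode, here is a sketch of a decoder for all 16 bit patterns laid out as |sign|exp|exp|mantissa|. It assumes the standard E2M1 interpretation (as I understand the OCP Microscaling MXFP4 format), where exponent 00 encodes zero and subnormals; `decode_e2m1` is a hypothetical helper name. Note that the `FP4_E2M1` class below applies the normal formula to exponent 00 as well, so it produces ±0.75 where this decoder produces ±0.5:

```python
def decode_e2m1(code: int, bias: int = 1) -> float:
    """Hypothetical helper: decode a 4-bit pattern |sign|exp|exp|mantissa|."""
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11
    mantissa = code & 1
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias
        return sign * (mantissa / 2) * 2.0 ** (1 - bias)
    # Normal: implicit leading 1
    return sign * (1 + mantissa / 2) * 2.0 ** (exp - bias)


# +0 and -0 compare equal, so the set keeps a single zero
values = sorted({decode_e2m1(code) for code in range(16)})
print(values)
```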



In [1]:
class FP4_E2M1:
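  '''
  Enumerate every value representable by the 4-bit E2M1 format
  (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias = 1).
  '''
  def __init__(self):
    self.values = []
    for sign in [0,1]:
      for exp in range(2**2):
        for mantissa in range(2):
          if exp==0 and mantissa == 0:
            value = 0 # Exponent 00 with mantissa 0 encodes zero
          else:
            # Simplification: the normal formula is also applied to exponent 00,
            # so the pattern 0001 decodes to 1.5 * 2^-1 = 0.75. Standard E2M1
            # treats it as a subnormal: (m/2) * 2^(1-bias) = 0.5.
            exp_val = exp-1 # Remove the bias
            mantissa_val = 1+mantissa*0.5 # Implicit leading 1
            value = (1 if sign==0 else -1) * mantissa_val * (2**(exp_val))

          if value not in self.values: # Skip duplicates (zero is produced for both signs)
            self.values.append(value)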
    self.values = sorted(self.values)

In [2]:
# In case of E2M1
fp4_range = FP4_E2M1()
fp4_range.values

[-6.0,
 -4.0,
 -3.0,
 -2.0,
 -1.5,
 -1.0,
 -0.75,
 0,
 0.75,
 1.0,
 1.5,
 2.0,
 3.0,
 4.0,
 6.0]

In [3]:
import torch
import torch.nn as nn

In [4]:
fp4_range = torch.tensor(FP4_E2M1().values)
fp4_range

tensor([-6.0000, -4.0000, -3.0000, -2.0000, -1.5000, -1.0000, -0.7500,  0.0000,
         0.7500,  1.0000,  1.5000,  2.0000,  3.0000,  4.0000,  6.0000])

## Educational FP4 Quantization

A first, deliberately simple version that shows the algorithm step by step, for educational purposes:

- Flatten the input tensor to one dimension

- Compute the scale as the absolute maximum of the tensor (absmax)

- Divide the entire tensor by the scale

- Using broadcasting, compare each scaled value with every representable FP4 value and pick the closest one

- Return the quantized data together with the scale
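The steps above can be traced by hand on a tiny 2x2 tensor. This sketch hard-codes the 15-value grid printed by `FP4_E2M1().values`, so it runs without the class:

```python
import torch

# The 15-value grid printed by FP4_E2M1().values
fp4_grid = torch.tensor([-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.75, 0.0,
                         0.75, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

x = torch.tensor([[0.5, -1.0], [2.0, 0.1]])
flat = x.reshape(-1)                  # step 1: flatten
scale = flat.abs().max()              # step 2: absmax -> 2.0
scaled = flat / scale                 # step 3: about [0.25, -0.5, 1.0, 0.05]
idx = torch.argmin((scaled.unsqueeze(1) - fp4_grid).abs(), dim=1)  # step 4: nearest value
q = fp4_grid[idx]                     # [0.0, -0.75, 1.0, 0.0]
x_hat = (q * scale).reshape(x.shape)  # dequantize: [[0.0, -1.5], [2.0, 0.0]]
print(q, x_hat)
```

Note how 0.5 collapses to 0 and -1.0 comes back as -1.5: once the data is scaled into [-1, 1], only the few FP4 levels near zero are in play.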

In [5]:
torch.tensor([0.1,0.2,0.3,0.4]).unsqueeze(1) - torch.tensor([0.1,0.2,0.3,0.4])

tensor([[ 0.0000, -0.1000, -0.2000, -0.3000],
        [ 0.1000,  0.0000, -0.1000, -0.2000],
        [ 0.2000,  0.1000,  0.0000, -0.1000],
        [ 0.3000,  0.2000,  0.1000,  0.0000]])

In [6]:
class FP4_Quantizer():
  def __init__(self):
    self.fp4_values = torch.tensor(FP4_E2M1().values)
  def quantize(self, input_tensor):
    block = input_tensor.reshape(-1) # Flatten (reshape also handles non-contiguous tensors)
    scale = block.abs().max() # absmax of the whole tensor is the single scale factor
    if scale == 0:
      return torch.zeros_like(block), scale
    # NOTE: dividing by absmax maps the data into [-1, 1], but the FP4 grid spans [-6, 6],
    # so only 0, ±0.75 and ±1 can ever be selected
    scaled_block = block/scale

    # Broadcast (N, 1) - (15,) -> (N, 15) distances and take the nearest FP4 value
    indices = torch.argmin(torch.abs(scaled_block.unsqueeze(1)-self.fp4_values),dim=1)
    quantized_data = self.fp4_values[indices]
    # This returns the FP4 values themselves (float32), not packed 4-bit codes;
    # the blockwise implementation below stores packed indices instead
    return quantized_data, scale

  def dequantize(self,quantized_tensor,scale,original_shape):
    t = quantized_tensor*scale
    t = t.reshape(original_shape)
    return t

In [7]:
quantizer = FP4_Quantizer()

In [8]:
input_tensor = torch.randn((1,512))
input_tensor.shape, input_tensor

(torch.Size([1, 512]),
 tensor([[ 4.1789e-02, -1.0357e+00,  1.4939e-01,  7.1623e-01, -6.4241e-01,
          -3.8041e-01, -1.2908e+00,  1.7141e-01, -5.3247e-01,  1.2557e+00,
          -1.9053e+00, -7.6853e-01, -1.0663e+00,  9.3685e-01,  5.0596e-01,
           9.1629e-01, -1.1104e+00, -8.7649e-01, -4.0421e-01, -3.8728e-01,
          -1.8048e+00,  5.5232e-01,  2.2793e-01,  6.8236e-01, -5.9425e-02,
           6.2779e-01, -8.2173e-01, -1.8350e-01, -2.2424e-02, -1.3844e+00,
           6.0245e-01, -1.7137e-01,  3.8091e-01,  4.3450e-02,  1.4742e+00,
          -5.2175e-01,  2.0057e-01,  7.9602e-01,  2.2168e-01,  4.7247e-02,
           6.2529e-01,  2.0800e+00,  6.3924e-01, -2.3966e-01, -4.9708e-01,
           7.9678e-01,  5.3253e-01,  1.6228e-02,  9.0824e-01,  1.5564e+00,
          -5.8099e-01, -1.2539e+00,  4.9693e-01,  3.0307e-01,  3.2636e-01,
           8.2259e-01,  1.4596e+00, -2.0068e+00,  1.1978e+00, -7.4647e-01,
           1.0010e+00,  2.8108e-01,  1.1000e+00,  8.4446e-01,  4.7295e-01,
          ...

In [9]:
quantized_tensor, scale = quantizer.quantize(input_tensor=input_tensor)
quantized_tensor, scale

(tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7500,  0.0000,
          0.0000,  0.0000, -0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000, -0.7500,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7500,  0.0000,  0.0000,
          0.0000,  0.0000,  0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.7500, -0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, -0.7500,  0.0000,  0.7500,  0.0000,  0.0000,
         -0.7500, -0.7500,  0.0000,  0.0000,  0.0000, -0.7500, -0.7500,  0.0000,
          0.0000,  0.0000,  0.0000,  0.7500,  0.0000,  0.0000,  0.7500, -0.7500,
         -0.7500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, ...

In [10]:
dequantized_tensor = quantizer.dequantize(quantized_tensor,scale, input_tensor.shape)

In [11]:
dequantized_tensor-input_tensor

tensor([[-4.1789e-02,  1.0357e+00, -1.4939e-01, -7.1623e-01,  6.4241e-01,
          3.8041e-01, -1.2277e+00, -1.7141e-01,  5.3247e-01, -1.2557e+00,
         -6.1314e-01,  7.6853e-01,  1.0663e+00, -9.3685e-01, -5.0596e-01,
         -9.1629e-01,  1.1104e+00,  8.7649e-01,  4.0421e-01,  3.8728e-01,
         -7.1369e-01, -5.5232e-01, -2.2793e-01, -6.8236e-01,  5.9425e-02,
         -6.2779e-01,  8.2173e-01,  1.8350e-01,  2.2424e-02, -1.1341e+00,
         -6.0245e-01,  1.7137e-01, -3.8091e-01, -4.3450e-02,  1.0442e+00,
          5.2175e-01, -2.0057e-01, -7.9602e-01, -2.2168e-01, -4.7247e-02,
         -6.2529e-01,  4.3851e-01, -6.3924e-01,  2.3966e-01,  4.9708e-01,
         -7.9678e-01, -5.3253e-01, -1.6228e-02, -9.0824e-01,  9.6202e-01,
          5.8099e-01,  1.2539e+00, -4.9693e-01, -3.0307e-01, -3.2636e-01,
         -8.2259e-01,  1.0589e+00, -5.1163e-01, -1.1978e+00,  7.4647e-01,
         -1.0010e+00, -2.8108e-01, -1.1000e+00, -8.4446e-01, -4.7295e-01,
          7.0670e-01,  9.9185e-01, ...

In [12]:
(dequantized_tensor-input_tensor).abs().mean()

tensor(0.6053)

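As you can see, we quantized and dequantized the original tensor, and of course the conversion loses information. The mean absolute error of about 0.6 is large for this data: the mean absolute value of a standard normal sample is only about 0.8. The output shows why: almost every value collapsed to 0 or ±0.75. Dividing by the absmax maps the data into [-1, 1], while the FP4 grid spans [-6, 6], so only 0, ±0.75 and ±1 are ever selected.

Individual errors can be much larger, and a bigger outlier would make things worse, because it inflates the single scale shared by the whole tensor.

#### NB: This is a basic educational implementation. It doesn't save any memory (the quantized values are still stored as float32); it only shows the algorithm.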
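A sketch of one possible fix, not what produced the outputs above: divide by absmax / 6 instead of absmax, so the largest value lands on 6.0 and the data spreads over the whole [-6, 6] grid. The helper `fp4_roundtrip` and its `full_range` flag are hypothetical; `full_range=False` mirrors `FP4_Quantizer`:

```python
import torch

# Same 15-value grid as FP4_E2M1().values
FP4_GRID = torch.tensor([-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.75, 0.0,
                         0.75, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fp4_roundtrip(x: torch.Tensor, full_range: bool) -> torch.Tensor:
    """Hypothetical helper: quantize to FP4 with one absmax scale, then dequantize."""
    flat = x.reshape(-1)
    absmax = flat.abs().max()
    if absmax == 0:
        return torch.zeros_like(x)
    # full_range=False reproduces FP4_Quantizer: data lands in [-1, 1].
    # full_range=True maps absmax onto the largest FP4 value: data lands in [-6, 6].
    scale = absmax / FP4_GRID.max() if full_range else absmax
    idx = torch.argmin(((flat / scale).unsqueeze(1) - FP4_GRID).abs(), dim=1)
    return (FP4_GRID[idx] * scale).reshape(x.shape)


torch.manual_seed(0)
x = torch.randn(1, 512)
err_original = (fp4_roundtrip(x, full_range=False) - x).abs().mean()
err_full = (fp4_roundtrip(x, full_range=True) - x).abs().mean()
print(err_original, err_full)
```

Using the full range lets ordinary values land on 1.5, 2, 3 and so on instead of collapsing to 0 or ±0.75, which lowers the mean error substantially on this kind of data.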

# Blockwise Quantization

This approach is simple, but real libraries use a finer-grained approach: the tensor is split into blocks, and each block gets its own scale factor, so an outlier only affects the values in its own block.

Let's change our code to do that

In [13]:
input_tensor = torch.randn((1,512))
input_tensor.shape, input_tensor

(torch.Size([1, 512]),
 tensor([[ 1.9134e+00,  6.2684e-01, -5.3194e-01,  8.6158e-04,  2.4018e-01,
           6.5365e-01, -1.2006e+00, -9.7287e-01, -8.6038e-01, -8.2020e-01,
          -1.4605e+00, -5.5664e-01, -1.1364e+00, -2.1831e+00,  4.7433e-01,
           1.4194e+00, -3.1464e-01, -9.2699e-01, -4.6061e-01,  1.5749e+00,
          -1.0689e+00,  7.8841e-02, -4.3668e-01,  1.6188e+00, -1.9432e-01,
           9.3042e-02,  3.1148e-01, -8.9370e-01, -1.5305e-01, -1.4858e-01,
          -1.0182e+00,  9.1295e-01, -1.0882e-01, -7.5600e-01,  7.1207e-01,
          -3.4220e-01,  1.6956e+00,  1.2045e+00,  4.9959e-01, -3.3056e-01,
          -1.0106e+00, -5.2655e-03,  1.2117e+00, -6.7760e-01,  1.0650e+00,
           1.9966e+00, -3.0152e-01,  6.9609e-01, -5.0688e-01, -1.6772e+00,
          -1.8749e-01,  1.2643e+00,  6.3555e-01,  1.2632e+00,  9.2076e-02,
           4.2368e-01, -2.0519e+00, -5.6756e-01, -2.1367e-02,  1.2719e-01,
          -1.1901e+00,  6.1733e-02, -4.7935e-01,  1.6718e+00, -1.3675e+00,
          ...

In [14]:
fp4_range

tensor([-6.0000, -4.0000, -3.0000, -2.0000, -1.5000, -1.0000, -0.7500,  0.0000,
         0.7500,  1.0000,  1.5000,  2.0000,  3.0000,  4.0000,  6.0000])

In [15]:
class FP4_Quantizer_Blockwise():
  def __init__(self,block_size=8):
    self.fp4_values = torch.tensor(FP4_E2M1().values)
    self.block_size = block_size

  def quantize(self,input_tensor):
    data_flat = input_tensor.reshape(-1) # Flatten
    num_blocks = (data_flat.numel()+ self.block_size -1) // self.block_size # Ceiling division
    # Each uint8 byte stores two 4-bit indices (assumes an even block_size)
    quantized_data = torch.zeros(num_blocks * (self.block_size//2), dtype=torch.uint8)
    scales = torch.zeros(num_blocks) # One float32 scale per block

    for i in range(num_blocks):
      start = i*self.block_size
      end = min((i+1)*self.block_size,data_flat.numel())
      block = data_flat[start:end]
      scale = block.abs().max() # Get the max value of the block for the scale
      if scale == 0:
        scale = 1.0
      scales[i] = scale # Saving the scale factor for the block

      # NOTE: this maps the block into [-1, 1], but the FP4 grid spans [-6, 6],
      # so only 0, ±0.75 and ±1 are reachable
      scaled_block = block/scale # Scale the block
      indices = torch.argmin(torch.abs(scaled_block.unsqueeze(1)-self.fp4_values),dim=1) # Index of the nearest FP4 value
      # Pack two 4-bit indices into one uint8 value:
      # group the indices in pairs [[i0, i1], [i2, i3], ...], keep the first index of each pair
      # in the low nibble and shift the second one 4 bits to the left into the high nibble.
      # For example, index 5 (0101) shifted left by 4 becomes 0101 0000 (80).
      if indices.numel() % 2 != 0:
        # Pad with a dummy value to make the number of elements even
        indices = torch.cat((indices, torch.tensor([0], dtype=indices.dtype)))

      packed_indices = indices.view(-1,2)
      packed_values = packed_indices[:, 0] | (packed_indices[:, 1] << 4)
      quantized_data[i * (self.block_size // 2) : i * (self.block_size // 2) + packed_values.numel()] = packed_values
    return quantized_data, scales

  def dequantize(self,quantized_tensor,scales, original_shape):
    num_elements = torch.Size(original_shape).numel()
    dequantized_flat = torch.zeros(num_elements, dtype=torch.float32)

    num_blocks = scales.numel()
    current_index = 0
    for i in range(num_blocks):
      start = i * self.block_size
      end = min((i+1)*self.block_size, num_elements)
      current_block_size = end-start
      # Number of packed bytes holding this block's indices
      packed_block_size = (current_block_size + 1) // 2
      packed_values = quantized_tensor[current_index:current_index+packed_block_size]
      # Unpack with bitwise operations: the low nibble is the first index,
      # the high nibble is the second index

      index_1 = packed_values & 0x0F
      index_2 = (packed_values >> 4) & 0x0F
      indices_unpacked =torch.stack([index_1,index_2], dim=1).view(-1)
      indices_unpacked = indices_unpacked[:current_block_size]
      fp4_block_value = self.fp4_values[indices_unpacked.long()]
      dequantized_flat[start:end] = fp4_block_value * scales[i]
      current_index += packed_block_size
    return dequantized_flat.view(original_shape)
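The nibble packing inside `quantize` and `dequantize` can be checked in isolation. A minimal sketch with hypothetical helper names `pack_nibbles` and `unpack_nibbles`, using the same layout as the class (first index in the low nibble, second in the high nibble):

```python
import torch


def pack_nibbles(indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pack pairs of 4-bit indices (0..15) into uint8 bytes."""
    if indices.numel() % 2:
        indices = torch.cat([indices, indices.new_zeros(1)])  # pad to an even count
    pairs = indices.to(torch.uint8).view(-1, 2)
    return pairs[:, 0] | (pairs[:, 1] << 4)  # low nibble | high nibble


def unpack_nibbles(packed: torch.Tensor, count: int) -> torch.Tensor:
    """Hypothetical helper: recover `count` 4-bit indices from packed bytes."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=1).view(-1)[:count]  # drop any padding


packed = pack_nibbles(torch.tensor([5, 9, 7, 14, 0]))
print(packed)                     # 5 | (9 << 4) = 149, 7 | (14 << 4) = 231, 0 | 0 = 0
print(unpack_nibbles(packed, 5))  # back to 5, 9, 7, 14, 0
```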

In [16]:
quantizer = FP4_Quantizer_Blockwise(block_size=64)

In [17]:
quantized_data, scales = quantizer.quantize(input_tensor)

In [18]:
quantized_data, len(scales)

(tensor([121, 119, 119, 102, 102, 118,  86, 135, 103, 135, 118, 135, 119, 103,
         119, 134, 119, 119, 136, 119, 118, 120, 152, 119, 103, 135, 135, 119,
         117, 119, 118, 135, 134, 119, 103, 102, 135, 119, 119, 119, 118, 119,
         119, 135, 135, 119, 135, 119, 136, 117, 119, 119, 103, 120, 103, 119,
         119, 119, 135, 120, 119, 118, 119, 119, 119, 135, 135, 103, 103, 119,
         102, 119, 119, 103, 119,  87, 119, 104, 119, 119, 119, 119, 134, 119,
         118, 121, 120, 135, 119, 119, 119, 120, 119, 103, 118, 103, 133, 103,
         133, 120, 119, 119,  88, 104, 120, 103, 103, 135, 118, 119, 104, 102,
         120, 119, 120, 135, 119, 119, 119, 103, 119, 119, 119, 135, 102, 119,
         118, 119, 117, 120, 118, 119, 119, 119, 134, 118, 150, 119, 102, 118,
         119, 103,  87, 119, 118, 119, 136, 104, 120, 118, 118, 119, 103, 118,
         119, 118, 135, 103, 118, 102, 104, 119, 119, 120, 117, 135, 103,  87,
         120, 135, 103, 103, 105, 103, 135, 119, 119, ...

There are N scale factors, one per block: N = ceil(number of elements / block_size), here 512 / 64 = 8.
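The scales are not free: `FP4_Quantizer_Blockwise` stores each one as a float32, on top of the 4-bit codes. A quick calculation with a hypothetical helper `blockwise_storage` (`scale_bits=32` matches the float32 `scales` tensor in the class):

```python
import math


def blockwise_storage(numel: int, block_size: int, scale_bits: int = 32):
    """Hypothetical helper: number of scales and average bits per value
    for 4-bit codes plus one scale per block."""
    num_scales = math.ceil(numel / block_size)
    bits_per_value = (numel * 4 + num_scales * scale_bits) / numel
    return num_scales, bits_per_value


print(blockwise_storage(512, 64))  # (8, 4.5)
print(blockwise_storage(512, 8))   # (64, 8.0)
```

With the class's default `block_size=8`, the scales cost as much as the codes and storage doubles to 8 bits per value; with 64, the overhead is 0.5 bits per value. Smaller blocks handle outliers better but cost more memory.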

In [19]:
dequantized_tensor = quantizer.dequantize(quantized_data,scales, input_tensor.shape)

In [20]:
dequantized_tensor

tensor([[ 2.1831,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.6373, -1.6373,
         -1.6373, -1.6373, -1.6373,  0.0000, -1.6373, -2.1831,  0.0000,  1.6373,
          0.0000, -1.6373,  0.0000,  1.6373, -1.6373,  0.0000,  0.0000,  1.6373,
          0.0000,  0.0000,  0.0000, -1.6373,  0.0000,  0.0000, -1.6373,  1.6373,
          0.0000,  0.0000,  0.0000,  0.0000,  1.6373,  1.6373,  0.0000,  0.0000,
         -1.6373,  0.0000,  1.6373,  0.0000,  1.6373,  2.1831,  0.0000,  0.0000,
          0.0000, -1.6373,  0.0000,  1.6373,  0.0000,  1.6373,  0.0000,  0.0000,
         -2.1831,  0.0000,  0.0000,  0.0000, -1.6373,  0.0000,  0.0000,  1.6373,
         -1.7574,  1.7574,  0.0000,  0.0000,  0.0000, -1.7574, -1.7574, -1.7574,
          0.0000,  1.7574,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -1.7574,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.7574,
          0.0000,  1.7574,  0.0000,  0.0000,  0.0000,  1.7574,  0.0000,  0.0000,
          1.7574,  1.7574, ...

In [21]:
(dequantized_tensor- input_tensor).abs().mean()

tensor(0.4377)

# Matmul operations

Let's see what the quantization error does to a matrix multiplication with FP4 weights.

In [22]:
BLOCK_SIZE = 64

In [23]:
class matmul():
  def __init__(self,quantizer):
    self.quantizer = quantizer
  def __call__(self, input_tensor, weights, scales=None, weights_quantized=False, shape=None):
    if weights_quantized:
      if shape is None or scales is None:
        raise ValueError("'shape' and 'scales' are required for quantized weights")
      # Dequantize the whole weight matrix on the fly, then cast it to bfloat16
      weights = self.quantizer.dequantize(weights, scales, shape)
      weights = weights.to(torch.bfloat16)
    # Linear-layer convention: weights have shape (out_features, in_features)
    output = torch.matmul(input_tensor, weights.T)
    return output

In [24]:
quantizer = FP4_Quantizer_Blockwise(block_size=BLOCK_SIZE)

In [25]:
in_features, out_features = 1024, 512
weights = torch.randn(out_features, in_features).to(torch.bfloat16)
input_tensor = torch.randn(1, in_features).to(torch.bfloat16)

In [26]:
matmul_operation = matmul(quantizer = quantizer)

In [27]:
base_matmul_result = matmul_operation(input_tensor, weights, weights_quantized=False)

In [28]:
quantized_weight, scales = quantizer.quantize(weights)
dequantized_matmul_result = matmul_operation(input_tensor, quantized_weight, weights_quantized=True,scales=scales, shape=weights.shape)

In [29]:
dequantized_matmul_result - base_matmul_result

tensor([[-17.5000, -33.2500,  16.7500,  -6.5312,  16.2500,  20.6250, -25.0000,
          -6.5000,  -5.3750, -16.6250,   8.4375,  -2.8750, -18.6250,  -8.5000,
         -19.0000,  33.0000,  18.5000,   4.2500,  18.8750,  -7.3750, -13.3750,
          -4.2188,  -1.5938,   6.0000,  47.5000,  10.6250,  -7.5000,   0.8750,
          -4.1250,  16.3750, -32.2500,  14.0000,   8.8750,  21.2500,  21.7500,
          11.0000,  20.7500,   1.0000,   0.5000,  10.2500,  -3.2500, -18.5000,
          -6.2500,   7.8125, -31.0000,   2.7500,  15.7500, -13.5625, -31.0000,
           0.2500,  27.5000,   8.2500,  20.3750, -24.8750, -19.8750, -13.3750,
           0.2500, -30.0000, -40.0000, -11.4375,  22.2500,  -8.3750,  -7.0000,
         -13.5000,  -3.6562,  -8.0000,  14.5000, -24.0000, -23.2500,  -8.5625,
         -22.5000,  15.7500,  35.0000,   1.8125,  25.5000, -20.6250, -20.5000,
           6.2500,  -4.8750,   0.5000, -14.3750, -14.8750,  -3.3125,  -1.5625,
         -30.3750,  12.0000,   8.7500, -11.2500, ...

In [30]:
(dequantized_matmul_result - base_matmul_result).abs().mean()

tensor(14.2500, dtype=torch.bfloat16)

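The errors are huge, and errors of this size would severely degrade a model's outputs. Part of this comes from how the FP4 quantizer scales each block: dividing by the absmax maps the data into [-1, 1], while the FP4 grid spans [-6, 6], so only 0, ±0.75 and ±1 are ever used. Beyond that, FP4's levels are spaced by powers of two rather than according to how LLM weights are distributed. bitsandbytes supports FP4, but it also defines a data type designed for these weights: NF4.

NF4 works well for LLMs because of how their weights are distributed.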

# NF4 - Normal Float 4

The weights in large neural networks, including LLMs, tend to follow a zero-centered normal distribution. This means most weights are clustered around zero, with fewer weights at the extremes. NF4 takes advantage of this by creating a quantization scheme where the "bins" or discrete values are not equally spaced. Instead, there are more bins around zero to capture the fine-grained details of the majority of the weights, and fewer, wider bins for the less common outlier weights.


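NF4 is built with quantile quantization: its 16 levels are placed at quantiles of a standard normal distribution (normalized to [-1, 1]), so each level is expected to be used by roughly the same share of normally distributed weights. The QLoRA paper describes this as information-theoretically optimal for zero-centered normal data, and reports that NF4 outperforms FP4 and Int4 in practice.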
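A sketch of how the 16 NF4 levels can be derived. It assumes the construction used by bitsandbytes' `create_normal_map`, as I understand it: 8 positive and 7 negative normal quantiles at evenly spaced probabilities starting from an offset of 0.9677083, plus an exact zero, normalized so the largest level is 1. It uses Python's `statistics.NormalDist` instead of SciPy; `build_nf4_levels` is a hypothetical name. The result should match the precomputed table in the class below to about four decimals:

```python
from statistics import NormalDist

import torch


def build_nf4_levels(offset: float = 0.9677083) -> torch.Tensor:
    """Hypothetical helper: 16 NF4 levels in [-1, 1] from standard-normal quantiles."""
    normal = NormalDist()  # mean 0, std 1
    # 8 positive levels: probabilities evenly spaced from `offset` down to 0.5 (excluded)
    pos_probs = torch.linspace(offset, 0.5, 9)[:-1]
    # 7 negative levels built the same way, leaving room for an exact zero
    neg_probs = torch.linspace(offset, 0.5, 8)[:-1]
    positives = [normal.inv_cdf(float(p)) for p in pos_probs]
    negatives = [-normal.inv_cdf(float(p)) for p in neg_probs]
    levels = torch.tensor(positives + [0.0] + negatives)
    levels = levels / levels.max()  # largest level becomes exactly 1.0
    return levels.sort().values


print(build_nf4_levels())
```

The asymmetry (8 positive, 7 negative levels) is what allows an exact zero while still using all 16 codes.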

In [31]:
class NF4_Quantizer_Blockwise():
  def __init__(self,block_size=8):
    self.nf4_values = torch.tensor([
        -1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0000,
         0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7229,  1.0000
    ], dtype=torch.float32) # NF4 levels used by bitsandbytes/QLoRA (to about 4 decimals)
    self.block_size = block_size

  def quantize(self,input_tensor):
    data_flat = input_tensor.reshape(-1) # Flatten
    num_blocks = (data_flat.numel()+ self.block_size -1) // self.block_size # Ceiling division
    # Each uint8 byte stores two 4-bit indices (assumes an even block_size)
    quantized_data = torch.zeros(num_blocks * (self.block_size//2), dtype=torch.uint8)
    scales = torch.zeros(num_blocks) # One float32 scale per block

    for i in range(num_blocks):
      start = i*self.block_size
      end = min((i+1)*self.block_size,data_flat.numel())
      block = data_flat[start:end]
      scale = block.abs().max() # Get the max value of the block for the scale
      if scale == 0:
        scale = 1.0
      scales[i] = scale # Saving the scale factor for the block

      # NF4 levels already span [-1, 1], so dividing by absmax uses the whole grid
      scaled_block = block/scale # Scale the block
      indices = torch.argmin(torch.abs(scaled_block.unsqueeze(1)-self.nf4_values),dim=1) # Index of the nearest NF4 value
      # Pack two 4-bit indices into one uint8 value:
      # group the indices in pairs [[i0, i1], [i2, i3], ...], keep the first index of each pair
      # in the low nibble and shift the second one 4 bits to the left into the high nibble.
      # For example, index 5 (0101) shifted left by 4 becomes 0101 0000 (80).
      if indices.numel() % 2 != 0:
        # Pad with a dummy value to make the number of elements even
        indices = torch.cat((indices, torch.tensor([0], dtype=indices.dtype)))

      packed_indices = indices.view(-1,2)
      packed_values = packed_indices[:, 0] | (packed_indices[:, 1] << 4)
      quantized_data[i * (self.block_size // 2) : i * (self.block_size // 2) + packed_values.numel()] = packed_values
    return quantized_data, scales

  def dequantize(self,quantized_tensor,scales, original_shape):
    num_elements = torch.Size(original_shape).numel()
    dequantized_flat = torch.zeros(num_elements, dtype=torch.float32)

    num_blocks = scales.numel()
    current_index = 0
    for i in range(num_blocks):
      start = i * self.block_size
      end = min((i+1)*self.block_size, num_elements)
      current_block_size = end-start
      # Number of packed bytes holding this block's indices
      packed_block_size = (current_block_size + 1) // 2
      packed_values = quantized_tensor[current_index:current_index+packed_block_size]
      # Unpack with bitwise operations: the low nibble is the first index,
      # the high nibble is the second index

      index_1 = packed_values & 0x0F
      index_2 = (packed_values >> 4) & 0x0F
      indices_unpacked =torch.stack([index_1,index_2], dim=1).view(-1)
      indices_unpacked = indices_unpacked[:current_block_size]
      nf4_block_value = self.nf4_values[indices_unpacked.long()]
      dequantized_flat[start:end] = nf4_block_value * scales[i]
      current_index += packed_block_size
    return dequantized_flat.view(original_shape)

In [32]:
quantizer = NF4_Quantizer_Blockwise(block_size=BLOCK_SIZE)

In [33]:
matmul_operation = matmul(quantizer=quantizer)

In [34]:
quantized_weight, scales = quantizer.quantize(weights)

In [35]:
base_matmul_result = matmul_operation(input_tensor, weights, weights_quantized=False)

In [36]:
quantized_weight, scales = quantizer.quantize(weights)
dequantized_matmul_result = matmul_operation(input_tensor, quantized_weight, weights_quantized=True,scales=scales, shape=weights.shape)

In [37]:
dequantized_matmul_result - base_matmul_result

tensor([[ 0.7500, -1.5938,  1.7500,  0.5000, -4.0000,  0.2500,  1.2500,  1.7500,
         -1.4375, -1.0000, -3.3281, -1.5000,  2.1250,  4.0000,  2.5000,  0.7500,
          1.4375, -2.3750,  0.3125, -2.7500, -5.3750, -4.4062,  3.3125, -1.5000,
          1.7812, -4.1250, -4.1250,  6.6250,  0.2500,  1.0000,  1.7500, -2.5000,
          1.2500, -1.4375, -0.5625, -0.5000,  3.7500, -1.1250,  0.0000,  2.5000,
          2.0000,  3.2500,  4.7500, -0.6250,  1.0000, -2.7500,  4.0000,  5.2500,
         -1.0625,  0.0000, -0.8750,  0.0000, -1.2500,  0.3750, -2.2500,  1.0000,
          2.0000, -0.6250, -0.8750, -4.4688,  2.5000, -0.2500,  3.3750, -0.5000,
          1.0938, -5.0000,  1.8750,  0.8750, -5.0000,  3.7500,  4.5000,  4.7500,
          2.5000,  3.1875, -5.8750, -5.3750,  0.7500,  0.0000, -2.9375, -3.6250,
         -4.3438, -1.2500,  0.6250,  0.8125,  0.5000, -3.0000, -3.0000, -4.2500,
          4.4375,  2.6250, -4.0000,  2.5000,  0.2500,  4.2500,  3.5000, -1.2500,
         -3.7500, -0.3750, ...

In [38]:
(dequantized_matmul_result - base_matmul_result).abs().mean()

tensor(2.4375, dtype=torch.bfloat16)

The error is now much more manageable: a mean absolute difference of about 2.4, compared with about 14 for FP4.

This is, in simplified form, the scheme bitsandbytes uses when training with QLoRA, or when you load a model in 4-bit to reduce its memory requirements: NF4 codes with one absmax scale per block.

Of course, production implementations get their speed from custom CUDA kernels that perform the dequantization and the matrix multiplication in one step, without materializing the full-precision weight matrix as this notebook does.
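Real fused kernels are written in CUDA. The Python sketch below only illustrates the idea of never materializing the whole dequantized weight matrix: it dequantizes a tile of rows at a time and multiplies it right away. The function name `dequant_matmul_tiled`, the `tile_rows` parameter and the unpacked one-code-per-byte layout are assumptions made for readability:

```python
import torch


def dequant_matmul_tiled(x, codes, scales, levels, block_size, tile_rows=64):
    """Hypothetical helper: compute x @ W.T, where
    W[r, c] = levels[codes[r, c]] * scales[r, c // block_size],
    dequantizing only `tile_rows` rows of W at a time."""
    out_features, in_features = codes.shape
    if in_features % block_size != 0:
        raise ValueError("in_features must be a multiple of block_size")
    y = torch.empty(x.shape[0], out_features, dtype=x.dtype)
    for r0 in range(0, out_features, tile_rows):
        r1 = min(r0 + tile_rows, out_features)
        # Dequantize one tile: look up each code's level, then apply its block's scale
        tile = levels[codes[r0:r1].long()]
        tile = tile * scales[r0:r1].repeat_interleave(block_size, dim=1)
        # Multiply immediately; the tile is replaced by the next one on the following iteration
        y[:, r0:r1] = x @ tile.T
    return y


torch.manual_seed(0)
levels = torch.linspace(-1.0, 1.0, 16)  # stand-in 16-level codebook
codes = torch.randint(0, 16, (130, 128), dtype=torch.uint8)
scales = torch.rand(130, 128 // 64) + 0.5  # one scale per 64-value block
x = torch.randn(3, 128)

y_tiled = dequant_matmul_tiled(x, codes, scales, levels, block_size=64)
W = levels[codes.long()] * scales.repeat_interleave(64, dim=1)  # full materialization, for comparison
print(torch.allclose(y_tiled, x @ W.T, atol=1e-4))
```

Peak extra memory is one tile of `tile_rows x in_features` values instead of the full weight matrix; a real kernel goes further and dequantizes inside registers or shared memory.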