# Weights Packing
Assume we want to quantize our model to 4-bit precision and store the weights in a
torch tensor. Ideally, we would create the tensor and pass `dtype=torch.int4`.
The problem is that PyTorch has no native support for 4-bit weights.
```python
import torch
tensor = torch.tensor([0,1], dtype=torch.int4) # is not supported!
```
The only option is to store the tensor in 8-bit precision, since that is currently the
smallest dtype available in PyTorch. But with the naive approach, where each 4-bit weight
occupies its own 8-bit element, there is no point in quantizing the model to 4-bit: every
parameter still takes 8 bits, which is a significant overhead for large models. To avoid this,
we need to pack several 4-bit weights into each 8-bit element.
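
To make the cost concrete, here is a minimal sketch comparing the storage of unpacked and packed 4-bit values. The element count `num_weights` is an arbitrary illustrative number, not taken from any real model.

```python
import torch

num_weights = 1024  # hypothetical number of 4-bit weights (illustrative only)
bits = 4

# Naive approach: one 4-bit value per uint8 element, so half of every byte is unused.
naive = torch.zeros(num_weights, dtype=torch.uint8)
naive_bytes = naive.numel() * naive.element_size()

# Packed approach: 8 // bits values share each uint8 element.
packed = torch.zeros(num_weights * bits // 8, dtype=torch.uint8)
packed_bytes = packed.numel() * packed.element_size()

print(naive_bytes, packed_bytes)  # 1024 512
```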

## How does packing work?
Consider the tensor below, which stores 4 values that can be represented in 2-bit precision but are stored in 8-bit.
```python
import torch
unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.int8)
```
Currently this tensor is stored as [00000001  00000000  00000011  00000010]. This is not optimal, as we need to allocate four 8-bit elements in memory to store weights that could each be encoded in only 2 bits.

To solve this, we can "pack" all these data points into a single 8-bit value as [10110001]-->(10,11,00,01), where the first value occupies the lowest two bits. In uint8, this value is 177.
```python
packed_tensor = torch.tensor([177], dtype=torch.uint8)
```
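
The value 177 can be checked with plain Python integer arithmetic. The sketch below also illustrates why an unsigned 8-bit type is the natural container here: the same bit pattern read as a signed 8-bit integer would be -79.

```python
# Each value sits in its own 2-bit slot; the first value (1) takes the lowest bits.
packed = (2 << 6) | (3 << 4) | (0 << 2) | 1

print(packed)                 # 177
print(format(packed, "08b"))  # 10110001

# int8 covers -128..127, so 177 does not fit; reinterpreted as signed it is 177 - 256.
signed_view = packed - 256 if packed > 127 else packed
print(signed_view)            # -79
```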
#### Advantages:
- It reflects the "true" memory footprint of the quantized weights
#### Disadvantages:
- The length of the unpacked tensor needs to be a multiple of `8 // nbits`
- The packed tensor needs to be unpacked before any operation can be performed on it
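
One possible way to handle the shape constraint is to pad the unpacked tensor with zeros up to the next multiple of `8 // nbits`, and to remember the original length so the padding can be dropped after unpacking. The helper name `pad_to_multiple` is hypothetical and used only for illustration.

```python
import torch

def pad_to_multiple(unpacked, bits):
    """Zero-pad a 1-D tensor so its length is a multiple of 8 // bits."""
    values_per_byte = 8 // bits
    remainder = unpacked.shape[0] % values_per_byte
    if remainder == 0:
        return unpacked
    padding = torch.zeros(values_per_byte - remainder, dtype=unpacked.dtype)
    return torch.cat([unpacked, padding])

original = torch.tensor([1, 0, 3, 2, 1], dtype=torch.uint8)  # length 5, not a multiple of 4
padded = pad_to_multiple(original, bits=2)
print(padded.tolist())  # [1, 0, 3, 2, 1, 0, 0, 0]
# Keep original.shape[0] around to slice the padding off after unpacking.
```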




## Packing 2-bit Weights

```python
import torch
```

```python
# Example Tensor: [1, 0, 3, 2]
# 1 0 3 2 - 01 00 11 10

# Starting point of packed uint8 Tensor
# [0000 0000]

##### First Iteration Start:
# packed uint8 Tensor State: [0000 0000]
# 1 = 0000 0001
# 0000 0001
# No left shifts in the First Iteration
# After bit-wise OR operation between 0000 0000 and 0000 0001:
# packed uint8 Tensor State: 0000 0001
##### First Iteration End

##### Second Iteration Start:
# packed uint8 Tensor State: [0000 0001]
# 0 = 0000 0000
# 0000 0000
# 2 left shifts:
# [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
# After bit-wise OR operation between 0000 0001 and 0000 0000:
# packed uint8 Tensor State: 0000 0001
##### Second Iteration End

##### Third Iteration Start:
# packed uint8 Tensor State: [0000 0001]
# 3 = 0000 0011
# 0000 0011
# 4 left shifts:
# [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
# 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
# After bit-wise OR operation between 0000 0001 and 0011 0000:
# packed uint8 Tensor State: 0011 0001
##### Third Iteration End

##### Fourth Iteration Start:
# packed uint8 Tensor State: [0011 0001]
# 2 = 0000 0010
# 0000 0010
# 6 left shifts:
# [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
# 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
# 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
# After bit-wise OR operation between 0011 0001 and 1000 0000:
# packed uint8 Tensor State: 1011 0001
##### Fourth Iteration End

# Final packed uint8 Tensor State: [1011 0001]
```
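
The walkthrough above can be reproduced with plain Python integers, printing the packed byte after each iteration. This is only an illustrative trace, not the tensor-based implementation.

```python
values = [1, 0, 3, 2]
bits = 2

packed = 0
states = []
for j, value in enumerate(values):
    shifted = value << (bits * j)  # move the value into its 2-bit slot
    packed |= shifted              # merge it into the packed byte
    states.append(format(packed, "08b"))
    print(f"iteration {j + 1}: shift by {bits * j}, state {states[-1]}")

# The last state printed is 10110001, i.e. 177.
```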

```python
def pack_weights(uint8tensor, bits):
    # Each packed byte holds 8 // bits values, so the input length must divide evenly.
    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError(
            f"The input shape needs to be a multiple of {8 // bits} "
            f"- got {uint8tensor.shape[0]}"
        )

    num_values = uint8tensor.shape[0] * bits // 8  # number of packed bytes
    num_steps = 8 // bits  # values per packed byte (4 when bits=2)

    unpacked_idx = 0
    packed_tensor = torch.zeros((num_values), dtype=torch.uint8)

    # Example with bits=2 and input [1, 0, 3, 2] (01 00 11 10):
    #   j=0: 0000 0001 << 0 -> 0000 0001; OR with 0000 0000 -> 0000 0001
    #   j=1: 0000 0000 << 2 -> 0000 0000; OR with 0000 0001 -> 0000 0001
    #   j=2: 0000 0011 << 4 -> 0011 0000; OR with 0000 0001 -> 0011 0001
    #   j=3: 0000 0010 << 6 -> 1000 0000; OR with 0011 0001 -> 1011 0001
    for i in range(num_values):
        for j in range(num_steps):
            # Shift the value left by bits * j and merge it into the current byte.
            packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
            unpacked_idx += 1
    return packed_tensor
```

```python
unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
```

```
tensor([177], dtype=torch.uint8)
```

```python
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
```

```
tensor([177, 255], dtype=torch.uint8)
```
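
The nested Python loops are easy to follow but slow for large tensors. A possible vectorized variant is sketched below, under the assumptions that every input value already fits in `bits` bits and that `bits` divides 8. It reshapes the input so that each row holds the values of one output byte. The name `pack_weights_vectorized` is hypothetical.

```python
import torch

def pack_weights_vectorized(uint8tensor, bits):
    values_per_byte = 8 // bits
    if uint8tensor.shape[0] % values_per_byte != 0:
        raise ValueError(
            f"The input length needs to be a multiple of {values_per_byte} "
            f"- got {uint8tensor.shape[0]}"
        )
    # One row per output byte; widen to int64 so the shifts cannot overflow.
    rows = uint8tensor.to(torch.int64).reshape(-1, values_per_byte)
    shifts = torch.arange(values_per_byte, dtype=torch.int64) * bits
    # The shifted fields never overlap, so summing them is equivalent to OR-ing them.
    return (rows << shifts).sum(dim=1).to(torch.uint8)

packed = pack_weights_vectorized(
    torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8), 2
)
print(packed)  # tensor([177, 255], dtype=torch.uint8)
```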

## Unpacking 2-bit Weights

```python
# Example Tensor: [10110001]
# Which was Originally: 1 0 3 2 - 01 00 11 10

# Starting point of unpacked Tensor
# [00000000 00000000 00000000 00000000]

##### First Iteration Start:
# packed uint8 Tensor: [10110001]
# You want to extract 01 from [101100 01]
# No right shifts in the First Iteration
# After bit-wise OR operation between 00000000 and 10110001:
# [10110001 00000000 00000000 00000000]
# unpacked Tensor state: [10110001 00000000 00000000 00000000]
##### First Iteration End

##### Second Iteration Start:
# packed uint8 Tensor: [10110001]
# You want to extract 00 from [1011 00 01]
# 2 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# After bit-wise OR operation between 00000000 and 00101100:
# [10110001 00101100 00000000 00000000]
# unpacked Tensor state: [10110001 00101100 00000000 00000000]
##### Second Iteration End

##### Third Iteration Start:
# packed uint8 Tensor: [10110001]
# You want to extract 11 from [10 11 0001]
# 4 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# After bit-wise OR operation between 00000000 and 00001011:
# [10110001 00101100 00001011 00000000]
# unpacked Tensor state: [10110001 00101100 00001011 00000000]
##### Third Iteration End

##### Fourth Iteration Start:
# packed uint8 Tensor: [10110001]
# You want to extract 10 from [10 110001]
# 6 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# 00001011 (5 shift)-> 00000101 (6 shift)-> 00000010
# After bit-wise OR operation between 00000000 and 00000010:
# [10110001 00101100 00001011 00000010]
# unpacked Tensor state: [10110001 00101100 00001011 00000010]
##### Fourth Iteration End

# Last step: Perform masking (bit-wise AND operation)
# Mask: 00000011
# Bit-wise AND operation between
# unpacked Tensor and 00000011
# [10110001 00101100 00001011 00000010] <- unpacked tensor
# [00000011 00000011 00000011 00000011] <- Mask
# [00000001 00000000 00000011 00000010] <- Result

# Final
# unpacked Tensor state: [00000001 00000000 00000011 00000010]
```
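
With plain Python integers, the shift-then-mask idea from the walkthrough collapses to one expression per extracted value. A small illustrative check:

```python
packed = 0b10110001  # 177
bits = 2
mask = 2 ** bits - 1  # 0b11

# Shift the wanted 2-bit field down to the lowest bits, then mask away the rest.
unpacked = [(packed >> (bits * j)) & mask for j in range(8 // bits)]
print(unpacked)  # [1, 0, 3, 2]
```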

```python
def unpack_weights(uint8tensor, bits):
    num_values = uint8tensor.shape[0] * 8 // bits  # number of unpacked values
    num_steps = 8 // bits  # values per packed byte

    unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)
    unpacked_idx = 0

    # Example with bits=2 and packed input [10110001] (originally 1 0 3 2):
    #   start:       [00000000 00000000 00000000 00000000]
    #   after shift: [10110001 00101100 00001011 00000010]
    #   after mask:  [00000001 00000000 00000011 00000010]
    mask = 2 ** bits - 1  # 00000011 when bits=2

    for i in range(uint8tensor.shape[0]):
        for j in range(num_steps):
            # Shift the packed byte right so the wanted bits land in the lowest positions.
            unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)
            unpacked_idx += 1

    # Keep only the lowest `bits` bits of every element.
    unpacked_tensor &= mask
    return unpacked_tensor
```

```python
packed_tensor = torch.tensor([177, 255], dtype=torch.uint8)
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(packed_tensor, 2)
```

```
tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
```
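
As with packing, the loops can be replaced by broadcasting. The sketch below is one possible vectorized variant (the name `unpack_weights_vectorized` is hypothetical); it assumes `bits` divides 8.

```python
import torch

def unpack_weights_vectorized(uint8tensor, bits):
    values_per_byte = 8 // bits
    mask = 2 ** bits - 1
    shifts = torch.arange(values_per_byte, dtype=torch.int64) * bits
    # Column j holds every packed byte shifted right by bits * j.
    shifted = uint8tensor.to(torch.int64).unsqueeze(1) >> shifts
    # Masking keeps only the lowest `bits` bits; flattening restores the original order.
    return (shifted & mask).reshape(-1).to(torch.uint8)

packed = torch.tensor([177, 255], dtype=torch.uint8)
unpacked = unpack_weights_vectorized(packed, 2)
print(unpacked)  # tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
```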