# L5-B: Packing 2-bit Weights

In this lesson, you will learn how to store low-precision weights using a technique called "packing".

## Packing

In [None]:
import torch

**Note:** Younes will explain the code below and walk through each iteration step. After watching Younes's explanation, you can go through the comprehensive explanation in the markdown below.

```Python
# Example Tensor: [1, 0, 3, 2]
    # 1 0 3 2 - 01 00 11 10

    # Starting point of packed int8 Tensor
    # [0000 0000]
    
    ##### First Iteration Start:
    # packed int8 Tensor State: [0000 0000]
    # 1 = 0000 0001
    # 0000 0001
    # No left shifts in the First Iteration
    # After bit-wise OR operation between 0000 0000 and 0000 0001:
    # packed int8 Tensor State: 0000 0001
    ##### First Iteration End

    ##### Second Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 0 = 0000 0000
    # 0000 0000
    # 2 left shifts:
    # [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
    # After bit-wise OR operation between 0000 0001 and 0000 0000:
    # packed int8 Tensor State: 0000 0001
    ##### Second Iteration End

    ##### Third Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 3 = 0000 0011
    # 0000 0011
    # 4 left shifts:
    # [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
    # 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
    # After bit-wise OR operation between 0000 0001 and 0011 0000:
    # packed int8 Tensor State: 0011 0001
    ##### Third Iteration End

    ##### Fourth Iteration Start:
    # packed int8 Tensor State: [0011 0001]
    # 2 = 0000 0010
    # 0000 0010
    # 6 left shifts:
    # [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
    # 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
    # 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
    # After bit-wise OR operation between 0011 0001 and 1000 0000:
    # packed int8 Tensor State: 1011 0001
    ##### Fourth Iteration End
    
    # Final packed int8 Tensor State: [1011 0001]
```
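The four iterations above can be reproduced with plain Python integers. This is a minimal sketch for checking the walkthrough by hand, not part of the lesson code; the names `values`, `bits`, `packed`, and `states` are illustrative.

```python
# Illustrative check of the walkthrough using plain Python ints (no torch).
values = [1, 0, 3, 2]   # the example tensor
bits = 2                # each value occupies 2 bits

packed = 0              # starting state: 0000 0000
states = []             # packed state after each iteration, as bit strings
for j, value in enumerate(values):
    # Shift the value left into its 2-bit slot, then OR it into the packed byte.
    packed |= value << (bits * j)
    states.append(format(packed, "08b"))

print(states)  # ['00000001', '00000001', '00110001', '10110001']
print(packed)  # 177
```

The final state `1011 0001` is the integer 177, which is the value the packed tensor should hold for this example.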

In [None]:
def pack_weights(uint8tensor, bits):
    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError(
            f"The input shape needs to be a multiple of {8 // bits} "
            f"- got {uint8tensor.shape[0]}"
        )

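    # Number of packed bytes needed to hold all input values
    num_values = uint8tensor.shape[0] * bits // 8

    # Number of input values that fit into one packed byte
    num_steps = 8 // bits

    # Index of the next input value to pack
    unpacked_idx = 0

    packed_tensor = torch.zeros(num_values, dtype=torch.uint8)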

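    # Example with [1, 0, 3, 2] and bits=2 (binary: 01 00 11 10):
    #   start:   0000 0000
    #   1 << 0 = 0000 0001 -> packed 0000 0001
    #   0 << 2 = 0000 0000 -> packed 0000 0001
    #   3 << 4 = 0011 0000 -> packed 0011 0001
    #   2 << 6 = 1000 0000 -> packed 1011 0001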
    
    for i in range(num_values):
        for j in range(num_steps):
            # Shift the next value into slot j of byte i, then OR it in
            packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
            unpacked_idx += 1
    return packed_tensor
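`pack_weights` relies on every input value fitting in `bits` bits: a larger value would spill into the neighboring slot of the same byte. The loop structure also assumes that `bits` divides 8 evenly. The sketch below makes both assumptions explicit in a pure-Python counterpart that works on lists of ints. The helper name `pack_values` is hypothetical and not part of the lesson; it is only meant as a reference for sanity-checking results.

```python
def pack_values(values, bits):
    """Hypothetical pure-Python counterpart of pack_weights for lists of ints."""
    # Assumption: bits divides 8, so each byte holds a whole number of values.
    if 8 % bits != 0:
        raise ValueError(f"bits must divide 8, got {bits}")
    per_byte = 8 // bits
    if len(values) % per_byte != 0:
        raise ValueError(
            f"length must be a multiple of {per_byte}, got {len(values)}"
        )
    # Assumption: each value fits in `bits` bits, otherwise slots would overlap.
    if any(not 0 <= v < (1 << bits) for v in values):
        raise ValueError(f"every value must fit in {bits} bits")

    packed = []
    for start in range(0, len(values), per_byte):
        byte = 0
        for j, v in enumerate(values[start:start + per_byte]):
            # Place value j of this group into bits [bits*j, bits*(j+1)).
            byte |= v << (bits * j)
        packed.append(byte)
    return packed


print(pack_values([1, 0, 3, 2], 2))  # [177]
```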

In [None]:
unpacked_tensor = torch.tensor([1, 0, 3, 2], 
                               dtype=torch.uint8)

In [None]:
pack_weights(unpacked_tensor, 2)

In [None]:
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], 
                               dtype=torch.uint8)

In [None]:
pack_weights(unpacked_tensor, 2)
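For the eight-value input, the first four values pack to the same byte as before (177), and the second group `[3, 3, 3, 3]` fills every bit of the second byte. The short sketch below works that second byte out with plain Python integers and shows the resulting storage saving; the variable names are illustrative only.

```python
# Second group of four 2-bit values from the eight-value example.
second_group = [3, 3, 3, 3]
bits = 2

byte = 0
for j, value in enumerate(second_group):
    # Each value 3 (binary 11) fills its whole 2-bit slot.
    byte |= value << (bits * j)

print(format(byte, "08b"), byte)  # 11111111 255

# Storage: 8 values at one byte each vs. 8 values at 2 bits each.
unpacked_bytes = 8
packed_bytes = 8 * bits // 8
print(unpacked_bytes // packed_bytes)  # 4
```

With 2-bit values, packing stores four values per byte, so the packed tensor is four times smaller than the unpacked `uint8` tensor.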