In [1]:
import torch

In [2]:
def pack_weights(uint8tensor, bits):
    """Pack low-bit integer values into a dense ``torch.uint8`` tensor.

    Every ``8 // bits`` consecutive input values are packed into one byte.
    The first value of each group goes in the least-significant bits, e.g.
    with ``bits=2``: ``[1, 0, 3, 2]`` -> ``0b10_11_00_01`` -> ``[177]``.

    Args:
        uint8tensor: 1-D integer tensor (normally ``torch.uint8``) whose
            values all lie in ``[0, 2**bits - 1]``.
        bits: bit width of each value; must divide 8 (1, 2, 4 or 8).

    Returns:
        1-D ``torch.uint8`` tensor of length ``len(uint8tensor) * bits // 8``.

    Raises:
        ValueError: if ``bits`` does not divide 8, if the input length is
            not a multiple of ``8 // bits``, or if any value does not fit
            in ``bits`` bits.
    """
    if bits not in (1, 2, 4, 8):
        # Other widths would straddle byte boundaries, which this
        # byte-aligned layout cannot represent.
        raise ValueError(f"bits must be one of 1, 2, 4 or 8 - got {bits}")

    num_steps = 8 // bits  # values stored per output byte
    if uint8tensor.shape[0] % num_steps != 0:
        raise ValueError(
            f"The input shape needs to be a multiple of {num_steps} "
            f"- got {uint8tensor.shape[0]}"
        )
    num_values = uint8tensor.shape[0] // num_steps  # number of output bytes

    mask = 2 ** bits - 1
    if ((uint8tensor < 0) | (uint8tensor > mask)).any():
        # An oversized value would silently clobber its neighbours' bits.
        raise ValueError(f"All values must lie in [0, {mask}] for bits={bits}")

    # Widen to int64 so the shifts cannot wrap before the final cast.
    groups = uint8tensor.to(torch.int64).reshape(num_values, num_steps)
    shifts = torch.arange(num_steps, dtype=torch.int64) * bits
    # The shifted fields never overlap, so summing them equals OR-ing them.
    return (groups << shifts).sum(dim=1).to(torch.uint8)

In [3]:
# Four 2-bit values (each in 0..3) that the next cell packs into one byte.
unpacked_tensor = torch.tensor((1, 0, 3, 2), dtype=torch.uint8)

In [4]:
pack_weights(unpacked_tensor, bits=2)

tensor([177], dtype=torch.uint8)

In [5]:
# Eight 2-bit values, which the next cell packs into two bytes.
unpacked_tensor = torch.tensor((1, 0, 3, 2, 3, 3, 3, 3), dtype=torch.uint8)

In [6]:
pack_weights(unpacked_tensor, bits=2)

tensor([177, 255], dtype=torch.uint8)

In [7]:
def unpack_weights(uint8tensor, bits):
    """Unpack a ``torch.uint8`` tensor of packed low-bit values.

    This is the inverse of the packing layout used by ``pack_weights``.
    Each input byte expands to ``8 // bits`` values, read from the
    least-significant bits upward, e.g. with ``bits=2``:
    ``[177]`` (``0b10_11_00_01``) -> ``[1, 0, 3, 2]``.

    Args:
        uint8tensor: 1-D tensor of packed bytes (``torch.uint8``).
        bits: bit width of each packed value; must divide 8 (1, 2, 4 or 8).

    Returns:
        1-D ``torch.uint8`` tensor of length ``len(uint8tensor) * 8 // bits``,
        with every value in ``[0, 2**bits - 1]``.

    Raises:
        ValueError: if ``bits`` does not divide 8.
    """
    if bits not in (1, 2, 4, 8):
        # Other widths would straddle byte boundaries, which this
        # byte-aligned layout cannot represent.
        raise ValueError(f"bits must be one of 1, 2, 4 or 8 - got {bits}")

    num_steps = 8 // bits  # values stored per input byte
    mask = 2 ** bits - 1  # keeps only the low `bits` bits of each field

    # Shift each byte right by 0, bits, 2*bits, ... so every field lands in
    # the low bits: shape (n_bytes, 1) >> (num_steps,) -> (n_bytes, num_steps).
    shifts = torch.arange(num_steps, dtype=torch.int64) * bits
    fields = (uint8tensor.to(torch.int64).unsqueeze(1) >> shifts) & mask

    # Row-major flattening keeps byte order, with LSB-first order inside each byte.
    return fields.reshape(-1).to(torch.uint8)

In [8]:
# NOTE: despite the variable name, these are *packed* bytes: the values
# printed by the pack_weights cell above (tensor([177, 255])).
unpacked_tensor = torch.tensor((177, 255), dtype=torch.uint8)

In [9]:
unpack_weights(unpacked_tensor, bits=2)

tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)

#### LLM quantization papers:


1. LLM.int8()


2. SmoothQuant


3. QLoRA


4. AWQ


5. QuIP (2-bit)


6. HQQ (2-bit)


7. AQLM (2-bit)
