# Weight Packing

In [1]:
import torch

In [2]:
def pack_weights(uint8tensor, bits):
    """Pack low-bit values into bytes, lowest-order field first.

    Each group of ``8 // bits`` consecutive input values is stored in one
    output byte; the first value of the group occupies the least
    significant ``bits`` bits, the next one the following ``bits`` bits,
    and so on.

    Args:
        uint8tensor: 1-D ``torch.uint8`` tensor of values to pack. Only the
            low ``bits`` bits of each value are kept; higher bits are
            discarded so they cannot corrupt neighbouring fields.
        bits: Width of each value in bits. Must be 1, 2, 4 or 8 so that a
            whole number of fields fits in a byte.

    Returns:
        A 1-D ``torch.uint8`` tensor of length ``len(uint8tensor) * bits // 8``.

    Raises:
        ValueError: If ``bits`` is not 1, 2, 4 or 8, or if the total number
            of bits is not a multiple of 8.
    """
    # A width that does not divide 8 would make the loop below skip inputs
    # (and bits == 0 would divide by zero), so reject it up front.
    if bits not in (1, 2, 4, 8):
        raise ValueError("bits must be 1, 2, 4 or 8.")
    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError("The total number of bits must be a multiple of 8.")

    num_values = uint8tensor.shape[0] * bits // 8
    num_steps = 8 // bits
    mask = (1 << bits) - 1
    packed_tensor = torch.zeros(num_values, dtype=torch.uint8)
    # Field j of every output byte comes from inputs j, j + num_steps, ...;
    # handle each field position with one strided, vectorised operation.
    for j in range(num_steps):
        packed_tensor |= (uint8tensor[j::num_steps] & mask) << (j * bits)
    return packed_tensor


def unpack_weights(uint8tensor, bits):
    """Expand packed bytes back into individual low-bit values.

    Each input byte is split into ``8 // bits`` fields of ``bits`` bits,
    emitted from the least significant field to the most significant one.

    Args:
        uint8tensor: 1-D tensor of packed bytes.
        bits: Width of each stored value in bits. Must be 1, 2, 4 or 8 so
            that a whole number of fields fits in a byte.

    Returns:
        A 1-D ``torch.uint8`` tensor of length ``len(uint8tensor) * 8 // bits``.

    Raises:
        ValueError: If ``bits`` is not 1, 2, 4 or 8.
    """
    # A width that does not divide 8 would leave trailing output slots that
    # were never written (and bits == 0 would divide by zero).
    if bits not in (1, 2, 4, 8):
        raise ValueError("bits must be 1, 2, 4 or 8.")

    num_values = uint8tensor.shape[0] * 8 // bits
    num_steps = 8 // bits
    mask = (1 << bits) - 1
    unpacked_tensor = torch.zeros(num_values, dtype=torch.uint8)
    # Field j of byte i lands at output index i * num_steps + j, i.e. the
    # strided slice j::num_steps; fill one field position per iteration.
    for j in range(num_steps):
        unpacked_tensor[j::num_steps] = (uint8tensor >> (j * bits)) & mask
    return unpacked_tensor

In [5]:
# Sample input: eight values that each fit in 2 bits (range 0..3).
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 1, 2, 3], dtype=torch.uint8)
# 8 values * 2 bits = 16 bits, so the packed result holds two bytes.
packed_tensor = pack_weights(unpacked_tensor, 2)

In [6]:
# Round-trip check: unpacking with the same bit width should reproduce the original values.
unpack_weights(packed_tensor, 2)

tensor([1, 0, 3, 2, 3, 1, 2, 3], dtype=torch.uint8)