# Packing and Unpacking Weights

In [1]:
import torch

## Packing

In [5]:
def pack_weights(uint8tensor, bits):
    """Pack a 1-D tensor of ``bits``-wide values into bytes.

    ``8 // bits`` consecutive input values are combined into each output
    byte. Within a byte the first value occupies the least-significant
    ``bits`` bits, the second value the next ``bits`` bits, and so on.
    For example ``[1, 0, 3, 2]`` with ``bits=2`` becomes ``0b10_11_00_01``
    (177).

    Args:
        uint8tensor: 1-D integer tensor (normally ``torch.uint8``) whose
            elements all lie in ``[0, 2**bits - 1]``.
        bits: Width of every value in bits. Must divide 8 evenly, i.e. be
            one of 1, 2, 4 or 8.

    Returns:
        A ``torch.uint8`` tensor with ``uint8tensor.shape[0] * bits // 8``
        elements, allocated on the same device as the input.

    Raises:
        ValueError: If ``bits`` is not 1, 2, 4 or 8, if the input length is
            not a multiple of ``8 // bits``, or if any value does not fit
            in ``bits`` bits.
    """
    # A width that does not divide 8 would leave part of the input unread
    # (8 // bits values per byte would not cover the whole tensor).
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"bits must be one of 1, 2, 4 or 8 - got {bits}")

    num_steps = 8 // bits  # how many values share one packed byte
    num_inputs = uint8tensor.shape[0]

    if num_inputs % num_steps != 0:
        raise ValueError(
            f"The input shape needs to be a multiple of {num_steps}"
            f" - got {num_inputs}"
        )

    # An oversized value would spill into the bit fields of its neighbours
    # and silently corrupt them, so reject it up front.
    max_value = 2 ** bits - 1
    if num_inputs > 0 and (
        int(uint8tensor.min()) < 0 or int(uint8tensor.max()) > max_value
    ):
        raise ValueError(
            f"All values must fit in {bits} bits (0..{max_value})"
        )

    num_values = num_inputs // num_steps

    packed_tensor = torch.zeros(
        num_values, dtype=torch.uint8, device=uint8tensor.device
    )

    # Slot j of every output byte is filled in one vectorized step: the
    # strided slice picks the j-th member of each group of `num_steps` values.
    for j in range(num_steps):
        packed_tensor |= uint8tensor[j::num_steps] << (bits * j)
    return packed_tensor

In [3]:
# Four small values (each fits in 2 bits) used as the first packing example.
unpacked_tensor = torch.tensor(
    [1, 0, 3, 2],
    dtype=torch.uint8,
)
unpacked_tensor  # bare last expression so the notebook displays the tensor

tensor([1, 0, 3, 2], dtype=torch.uint8)

In [6]:
# Pack the 2-bit values, four per byte.
pack_weights(unpacked_tensor, bits=2)

tensor([177], dtype=torch.uint8)

In [7]:
# Eight 2-bit values: a longer example that spans more than one packed byte.
unpacked_tensor = torch.tensor(
    [1, 0, 3, 2, 3, 3, 3, 3],
    dtype=torch.uint8,
)

In [8]:
# Pack the 2-bit values, four per byte.
pack_weights(unpacked_tensor, bits=2)

tensor([177, 255], dtype=torch.uint8)

## Unpacking

In [10]:
def unpack_weights(uint8tensor, bits):
    """Expand packed bytes into individual ``bits``-wide values.

    Every input byte yields ``8 // bits`` output values, read from the
    least-significant bits upwards: output ``k * (8 // bits) + j`` is bits
    ``[bits * j, bits * (j + 1))`` of input byte ``k``. For example
    ``[177]`` with ``bits=2`` becomes ``[1, 0, 3, 2]``.

    Args:
        uint8tensor: 1-D tensor of packed bytes (normally ``torch.uint8``).
        bits: Width of every packed value in bits. Must divide 8 evenly,
            i.e. be one of 1, 2, 4 or 8.

    Returns:
        A ``torch.uint8`` tensor with ``uint8tensor.shape[0] * 8 // bits``
        elements, each in ``[0, 2**bits - 1]``, allocated on the same
        device as the input.

    Raises:
        ValueError: If ``bits`` is not 1, 2, 4 or 8.
    """
    # A width that does not divide 8 would leave trailing output slots
    # unwritten (zero) and indistinguishable from real data.
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"bits must be one of 1, 2, 4 or 8 - got {bits}")

    num_steps = 8 // bits  # how many values are stored in one byte
    num_values = uint8tensor.shape[0] * num_steps

    mask = 2 ** bits - 1  # keeps only the low `bits` bits after shifting

    unpacked_tensor = torch.zeros(
        num_values, dtype=torch.uint8, device=uint8tensor.device
    )

    # Slot j of every byte is extracted in one vectorized step and written to
    # the strided positions j, j + num_steps, j + 2 * num_steps, ...
    for j in range(num_steps):
        unpacked_tensor[j::num_steps] = (uint8tensor >> (bits * j)) & mask

    return unpacked_tensor

In [11]:
# Two packed bytes used as input for the unpacking example.
# NOTE(review): despite its name, `unpacked_tensor` holds *packed* data here.
unpacked_tensor = torch.tensor(
    [177, 255],
    dtype=torch.uint8,
)

In [14]:
# Expand every packed byte back into its four 2-bit values.
unpack_weights(unpacked_tensor, bits=2)

tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)