# bitlinear

> A quantized linear layer implementation

## Idea

This paper (https://arxiv.org/pdf/2402.17764.pdf) suggests that we can use a linear layer whose weights effectively take only three values: $(-scale, 0, scale)$.

For pretraining, however, they still keep full- (or half-) precision weights and quantize them only during the forward pass:

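$$
\mathrm{ThreeValuesLinear}(W, x) = \mathrm{ThreeValuesWeight}(W)\, x
$$

$$
\mathrm{ThreeValuesWeight}(W) =
    \mathrm{sign}\left(
        W - {\sum W \over \mathrm{inputFeatures} \cdot \mathrm{outputFeatures}}
    \right)
    \cdot
    {\sum |W| \over \mathrm{inputFeatures} \cdot \mathrm{outputFeatures}}
$$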

The $sign(\cdot)$ part maps every weight to one of $(-1, 0, 1)$, and multiplying by the second factor, $scale$, turns these into $(-scale, 0, scale)$.
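A minimal PyTorch sketch of the formula above, assuming both means are taken over the whole weight matrix (as `_quantize_weight` does further down). The name `three_values_weight` is hypothetical and not part of the module:

```python
import torch


def three_values_weight(W: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sign(W - mean(W)) * mean(|W|)."""
    mean = W.mean()          # sum(W) / (inputFeatures * outputFeatures)
    scale = W.abs().mean()   # sum(|W|) / (inputFeatures * outputFeatures)
    return torch.sign(W - mean) * scale


torch.manual_seed(0)
W = torch.randn(4, 10)
W_q = three_values_weight(W)
scale = W.abs().mean()
# Every entry is one of -scale, 0 or +scale
assert torch.isin(W_q, torch.stack([-scale, torch.zeros(()), scale])).all()
```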

Since it is not clear how to define a gradient for $sign$, or how to apply updates directly to quantized weights, they keep non-quantized weights and perform quantization only in the forward pass.
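For reference, the original BitNet paper handles the non-differentiable $sign$ with a straight-through estimator (STE): the forward pass uses the quantized weights, while the backward pass treats quantization as the identity. A minimal sketch of that idea (`ste_three_values` is a hypothetical name, and this is not what the rest of this notebook does):

```python
import torch


def ste_three_values(W: torch.Tensor) -> torch.Tensor:
    """Quantize W in the forward pass, pass gradients straight through in the backward pass."""
    W_q = torch.sign(W - W.mean()) * W.abs().mean()
    # The forward value equals W_q; the detached difference contributes no gradient,
    # so dL/dW is computed as if the output were W itself.
    return W + (W_q - W).detach()


torch.manual_seed(0)
W = torch.randn(3, 5, requires_grad=True)   # full-precision "latent" weights
x = torch.randn(2, 5)
y = torch.nn.functional.linear(x, ste_three_values(W))
y.sum().backward()
# With an identity backward, the gradient matches an unquantized linear layer
assert torch.allclose(W.grad, x.sum(dim=0).expand(3, 5))
```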

There is, however, another option: the ReLoRA paper (https://arxiv.org/pdf/2307.05695.pdf) makes high-rank updates through a sequence of low-rank ones.
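The observation behind ReLoRA can be checked numerically: each low-rank update has rank at most $r$, but a sum of several independent ones can reach a much higher rank. A small sketch with random factors (the sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
n, r, steps = 16, 2, 4  # matrix size, LoRA rank, number of merge cycles

total_update = torch.zeros(n, n)
for _ in range(steps):
    B = torch.randn(n, r)
    A = torch.randn(r, n)
    delta = B @ A            # one LoRA-style update, rank <= r
    total_update += delta    # merged into the weights, as ReLoRA does between restarts

# Each update has rank r, while the merged sum reaches steps * r (almost surely for random factors)
print(torch.linalg.matrix_rank(delta).item(), torch.linalg.matrix_rank(total_update).item())
```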

So I am going to try the following approach:

- Initialize the $W$ weights
- Save them to a temporary file $File_W$
- Quantize $W$ into $W_{quant}$
- Unload $W$
- For a few iterations:
  - Train the $i$-th LoRA adapter via the ReLoRA procedure
  - Load $W$ from $File_W$
  - Merge in the $\Delta W$ obtained from the adapter: $W = W + \Delta W$
  - Save $W$ back to $File_W$
  - Quantize $W$ into $W_{quant}$
  - Unload $W$
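The procedure above can be sketched in a few lines. This is a toy version under stated assumptions: `train_lora_delta` is a hypothetical stand-in that returns a random rank-$r$ $\Delta W$ instead of actually training an adapter, and $File_W$ lives in a temporary directory:

```python
import os
import tempfile

import torch


def quantize(W: torch.Tensor) -> torch.Tensor:
    # Three-value quantization from the formula above
    return torch.sign(W - W.mean()) * W.abs().mean()


def train_lora_delta(shape, rank: int = 2) -> torch.Tensor:
    # Hypothetical placeholder for "train the i-th adapter": returns B @ A
    out_features, in_features = shape
    return torch.randn(out_features, rank) @ torch.randn(rank, in_features) * 0.01


torch.manual_seed(0)
file_w = os.path.join(tempfile.mkdtemp(), "W.pt")

W = torch.randn(20, 10)                       # initialize W
W_init = W.clone()                            # kept only to check the result
torch.save(W, file_w)                         # save to File_W
W_quant = quantize(W)                         # quantize
del W                                         # unload W

deltas = []
for i in range(3):
    delta_W = train_lora_delta(W_quant.shape)  # train the i-th adapter (stubbed)
    deltas.append(delta_W)
    W = torch.load(file_w)                     # load W from File_W
    W = W + delta_W                            # merge
    torch.save(W, file_w)                      # save back to File_W
    W_quant = quantize(W)                      # re-quantize
    del W                                      # unload W
```

Only $W_{quant}$ stays in memory between iterations; the full-precision copy is touched once per merge.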

## Implementation

```python
#| default_exp bitlinear
```

```python
#| export
from typing import List, Union, Tuple, Iterable
import torch
from bitlinear.adapters import LinearAdapter, LoRAAdapter, MergeableLayer
```

## Determining parameter count per byte

Keeping in mind that each parameter takes one of 3 values, $\{-1, 0, 1\} \cdot scale$, let's see how many parameters we can pack into one uint8.

```python
#| export
STORAGE_BIT_COUNT = 8
STORAGE_DTYPE = torch.ByteTensor
```

```python
#| export
def _get_parameter_count_per_n_bits(n: int) -> int:
    # Largest i such that all 3 ** i ternary groups still fit into n bits
    i = 0
    while 3 ** (i + 1) <= 2 ** n:
        i += 1
    return i


STORAGE_VALUES_PER_ITEM = _get_parameter_count_per_n_bits(STORAGE_BIT_COUNT)
STORAGE_VALUES_PER_ITEM
```

```
5
```

So a group of 5 parameters fits into one uint8: $3^5 = 243 \le 256$, while $3^6 = 729 > 256$.
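A quick check of what this packing buys: with 5 ternary values per byte the cost is $8/5 = 1.6$ bits per parameter, close to the information-theoretic $\log_2 3 \approx 1.585$. The arithmetic below uses an illustrative 5120×5120 layer (chosen so `in_features` is divisible by 5) and ignores the two per-layer scalars `mean` and `scale`:

```python
import math

# Illustrative sizes; in_features must be divisible by 5 for BitLinear
in_features, out_features = 5120, 5120
n_params = in_features * out_features

bits_per_param_packed = 8 / 5        # 5 ternary values per uint8
bits_per_param_ideal = math.log2(3)  # entropy of one uniformly distributed ternary value

fp16_bytes = n_params * 2            # 16 bits per weight
packed_bytes = n_params // 5         # one byte per group of 5 weights

print(f"{bits_per_param_packed:.3f} vs {bits_per_param_ideal:.3f} bits per parameter")
print(f"fp16: {fp16_bytes / 2**20:.1f} MiB, packed: {packed_bytes / 2**20:.1f} MiB")
```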

## Parameter group index to parameter mapping

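Now let's generate a lookup tensor that maps each uint8 index to its 5 parameter values. It is a float32 tensor of shape (256, 5): rows 0 to 242 enumerate every combination of $(-1, 0, 1)$, and the 13 unused rows are padded with ones.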

```python
#| export
def _generate_parameter_mappings(parameter_count: int, pad_to_size: int) -> List[List[int]]:
    # Enumerate all 3 ** parameter_count groups; the first position varies slowest
    def _iter(rest_count):
        if rest_count == 0:
            return [[]]
        result = []
        for p in [-1, 0, 1]:
            for rest in _iter(rest_count - 1):
                result.append([p] + rest)
        return result

    response = _iter(parameter_count)
    assert len(response) <= pad_to_size
    # Pad unused indices with an all-ones group so every byte value is a valid index
    response += [[1] * parameter_count for _ in range(pad_to_size - len(response))]
    return response


def _generate_parameter_mappings_tensor(parameter_count: int, pad_to_size: int) -> torch.Tensor:
    return torch.FloatTensor(_generate_parameter_mappings(parameter_count, pad_to_size))
```

```python
#| export
MAPPING_UINT8_TO_5_PARAMS = _generate_parameter_mappings_tensor(
    STORAGE_VALUES_PER_ITEM,
    2 ** STORAGE_BIT_COUNT
)
assert MAPPING_UINT8_TO_5_PARAMS.shape == (256, 5)
```

```python
MAPPING_UINT8_TO_5_PARAMS
```

```
tensor([[-1., -1., -1., -1., -1.],
        [-1., -1., -1., -1.,  0.],
        [-1., -1., -1., -1.,  1.],
        ...,
        [ 1.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.]])
```
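Because `_iter` enumerates $(-1, 0, 1)$ in that order with the first position varying slowest, row $k$ of the mapping is simply $k$ written in base 3, with digits shifted by one ($-1 \to 0$, $0 \to 1$, $1 \to 2$). A small self-contained check of that reading; the helper `encode_group` is hypothetical and not part of the module:

```python
import itertools

import torch


def encode_group(values) -> int:
    """Hypothetical helper: map 5 values in {-1, 0, 1} to a byte via base 3."""
    index = 0
    for v in values:                 # first value is the most significant digit
        index = index * 3 + (int(v) + 1)
    return index


# Same ordering as _iter: itertools.product varies the last position fastest
groups = list(itertools.product([-1, 0, 1], repeat=5))
mapping = torch.tensor(groups, dtype=torch.float32)

assert mapping.shape == (243, 5)
assert all(encode_group(g) == k for k, g in enumerate(groups))
print(encode_group((-1, -1, -1, -1, 1)), encode_group((1, 1, 1, 1, 1)))  # 2 242
```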

## Dequantization function

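Dequantization is simple: each stored byte is used as a row index into the mapping, the resulting groups of 5 values are concatenated along the last dimension, and the result is multiplied by $scale$.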

```python
#| export
@torch.no_grad()
def dequantize_weights(weight_mapping: torch.Tensor, packed_weights: torch.Tensor, scale: Union[torch.Tensor, float]) \
    -> torch.Tensor:
    weights_per_item = weight_mapping.shape[-1]
    # (..., n_bytes) -> (..., n_bytes * weights_per_item)
    weights_packed_shape = list(packed_weights.shape[:-1]) + \
        [weights_per_item * packed_weights.shape[-1]]
    # Look up the value group for every byte, then flatten the groups
    dequantized_weights_k = weight_mapping[packed_weights.long(), :].view(weights_packed_shape)
    return dequantized_weights_k * scale
```

## Quantization

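Quantization, the reverse direction, is a bit more involved.

Here we center the weights by subtracting $mean$, take their sign, split the signs into groups of 5, compute the L1 distance from each group to every row of the mapping, and pick the index of the closest row.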

```python
#| export
@torch.no_grad()
def quantize_weights(weight_mapping: torch.FloatTensor, weights: torch.FloatTensor, mean: Union[torch.Tensor, float]) \
    -> torch.Tensor:
    weights_centered = weights - mean
    weights_sign = weights_centered.sign()
    # (..., in_features) -> (..., in_features // 5, 1, 5) so each group broadcasts against all 256 rows
    weights_reshaped = weights_sign.view(
        list(weights_sign.shape[:-1]) + [weights_sign.shape[-1] // STORAGE_VALUES_PER_ITEM, 1, STORAGE_VALUES_PER_ITEM]
    )
    weight_mapping_reshaped = weight_mapping.view(
        [1] * len(weights_sign.shape) + [2 ** STORAGE_BIT_COUNT, STORAGE_VALUES_PER_ITEM]
    )
    # L1 distance between every group and every mapping row: (..., in_features // 5, 256)
    weights_mapping_diff = (weights_reshaped - weight_mapping_reshaped).abs()
    weights_group_scores = weights_mapping_diff.sum(dim=-1)
    weights_group_chosen = weights_group_scores.argmin(dim=-1)
    # Never pick a padding row (indices >= 3 ** 5)
    weights_group_chosen.clamp_(0, 3 ** STORAGE_VALUES_PER_ITEM - 1)
    return weights_group_chosen.byte()
```
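Since $sign(\cdot)$ already yields values in $(-1, 0, 1)$, the closest mapping row is always an exact match, so the 256-way distance search (which materializes a tensor 256 times larger than the sign groups) could in principle be replaced by the base-3 formula directly. A sketch of that alternative, assuming the same row ordering as `MAPPING_UINT8_TO_5_PARAMS`; `quantize_weights_base3` is a hypothetical name, not part of this module:

```python
import itertools
from typing import Union

import torch


@torch.no_grad()
def quantize_weights_base3(weights: torch.Tensor, mean: Union[torch.Tensor, float],
                           values_per_item: int = 5) -> torch.Tensor:
    # Shift signs from {-1, 0, 1} to base-3 digits {0, 1, 2}
    digits = (weights - mean).sign().long() + 1
    groups = digits.view(*digits.shape[:-1], digits.shape[-1] // values_per_item, values_per_item)
    # Most significant digit first, matching the enumeration order of the mapping
    powers = 3 ** torch.arange(values_per_item - 1, -1, -1)
    return (groups * powers).sum(dim=-1).to(torch.uint8)


# Round-trip check against an explicitly enumerated mapping
torch.manual_seed(0)
mapping = torch.tensor(list(itertools.product([-1.0, 0.0, 1.0], repeat=5)))
packed = torch.randint(0, 3 ** 5, (4, 8))
weights = mapping[packed].view(4, 40) * 0.5
assert torch.equal(quantize_weights_base3(weights, 0.0).long(), packed)
```

The trade-off is that this relies on the mapping order staying fixed, whereas the distance search works for any mapping table.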

## Test dequantization-quantization sequence

To test the two functions above, I generate random "packed weights", dequantize them, and quantize them back; the round trip should reproduce the original indices exactly.

```python
def test_dequantize_quantize():
    torch.manual_seed(42)
    for _ in range(100):
        index = (torch.rand(10, 200 // STORAGE_VALUES_PER_ITEM) * (3 ** STORAGE_VALUES_PER_ITEM - 1)).round()
        index_restored = quantize_weights(
            MAPPING_UINT8_TO_5_PARAMS,
            dequantize_weights(MAPPING_UINT8_TO_5_PARAMS, index, 50.0),
            0
        )
        assert (index == index_restored).all()


test_dequantize_quantize()
```

```python
#| export
class BitLinear(MergeableLayer):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 original_weights_filename: Union[str, None] = None,
                 adapter: Union[None, LinearAdapter]=None):
        super(BitLinear, self).__init__(adapter=adapter)
        assert in_features % STORAGE_VALUES_PER_ITEM == 0

        # Mapping on the target device for forward(), plus a CPU copy for (re)quantization
        if device:
            self.mapping = MAPPING_UINT8_TO_5_PARAMS.to(device=device)
        else:
            self.mapping = MAPPING_UINT8_TO_5_PARAMS.clone()
        self.mapping_cpu = MAPPING_UINT8_TO_5_PARAMS.clone()

        self.original_weights_filename = original_weights_filename

        # Initialize full-precision weights on CPU, then keep only their quantized form
        initial_linear = torch.nn.Linear(in_features, out_features, bias=bias, device="cpu", dtype=dtype)
        self.mean, self.scale, self.quant_weight = self._wrap_parameters(self._quantize_weight(
            initial_linear.weight,
            device=device
        ))

        if bias:
            bias_tensor = initial_linear.bias.data
            if device is not None:
                bias_tensor = bias_tensor.to(device)
            self.bias = torch.nn.Parameter(bias_tensor)
        else:
            self.register_parameter("bias", None)
        if original_weights_filename:
            # Full-precision weights live on disk between merges
            torch.save(
                initial_linear.weight.detach(),
                self.original_weights_filename,
            )

        self.adapter = adapter

    def _wrap_parameters(self, tensors: Iterable[torch.Tensor]) -> List[torch.Tensor]:
        # Frozen parameters: kept in the state dict but never updated by the optimizer
        return [
            torch.nn.Parameter(tensor, requires_grad=False)
            for tensor in tensors
        ]

    @torch.no_grad()
    def _quantize_weight(self, weight: torch.Tensor, device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        weight = weight.cpu()
        mean = weight.mean()
        scale = weight.abs().mean()
        qweight = quantize_weights(self.mapping_cpu, weight, mean)
        return mean.to(device), scale.to(device), qweight.to(device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Rebuild the (-scale, 0, scale) weight matrix on the fly
        W = dequantize_weights(self.mapping, self.quant_weight, self.scale)
        response = torch.nn.functional.linear(input, W, self.bias)
        if self.adapter is not None:
            response = response + self.adapter(input)
        return response

    def _update_parameter(self, parameters: Iterable[torch.nn.Parameter], tensors: Iterable[torch.Tensor]) -> None:
        for parameter, tensor in zip(parameters, tensors):
            parameter.data = tensor

    @torch.no_grad()
    def _update_weight(self, update: torch.Tensor) -> None:
        assert self.original_weights_filename is not None
        # Load full-precision W, apply the update, re-quantize, and write W back to disk
        weight = torch.load(
            self.original_weights_filename,
            map_location="cpu"
        )
        weight_updated = weight + update.to("cpu")
        self._update_parameter(
            [self.mean, self.scale, self.quant_weight],
            self._quantize_weight(weight_updated, self.quant_weight.device)
        )
        torch.save(
            weight_updated,
            self.original_weights_filename,
        )

    def merge_adapter(self) -> None:
        assert self.adapter is not None
        self._update_weight(
            self.adapter.calculate_weight_update()
        )
        self.adapter.reset()
```

```python
def test_lora_merging():
    torch.manual_seed(42)
    similarities_diff_lora_merged_and_lora_raw = []
    for batch_size in range(1, 1000):
        # Initialize the "linear" layer
        linear = BitLinear(
            in_features=10,
            out_features=20,
            bias=True,
            device=None,
            dtype=None,
            original_weights_filename="test-linear-weights.bin",
            adapter=None,
        )
        # Compute the layer output for a random input
        input = torch.rand(batch_size, 10)
        output_linear = linear(input)
        # Add a non-trained adapter (lora_b is all zeros)
        lora = LoRAAdapter(in_features=10, out_features=20, lora_rank=3)
        linear.adapter = lora
        output_linear_nottrained_lora = linear(input)

        # Check that a non-trained adapter does not change the output
        similarity_raw_and_resetted = torch.nn.functional.cosine_similarity(output_linear, output_linear_nottrained_lora)
        assert similarity_raw_and_resetted.min() >= 1.0 - 1e-5

        # Imitate training by making lora_b non-zero
        lora.lora_b.data = torch.rand(*lora.lora_b.shape) * 10.0
        # Compute the layer+LoRA output
        output_linear_lora_trained = linear(input)

        # Merge the LoRA adapter into the weights and re-quantize them
        linear.merge_adapter()
        # Compute the merged layer output
        output_linear_lora_merged = linear(input)

        # Check that the merged layer is closer to layer+LoRA than the raw layer is
        similarity_lora_and_raw = torch.nn.functional.cosine_similarity(output_linear, output_linear_lora_trained)
        similarity_lora_and_merged = torch.nn.functional.cosine_similarity(output_linear_lora_trained, output_linear_lora_merged)

        similarities_diff_lora_merged_and_lora_raw.append(
            (similarity_lora_and_merged - similarity_lora_and_raw).mean().item()
        )
    assert torch.FloatTensor(similarities_diff_lora_merged_and_lora_raw).mean() > 0.3


test_lora_merging()
```

```python
#| hide
import nbdev; nbdev.nbdev_export()
```