In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from helper import plot_quantization_error

In [23]:
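# Linear quantization formula: r = s * (q - z)
#   r: real (floating-point) value, q: quantized integer, s: scale, z: zero point.
# The layer below uses the symmetric case (z = 0), so r = s * q.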

In [24]:
def w8_a16_forward(input, weight, scales, bias):
    """Apply a linear map with quantized weights, computing in the input's dtype.

    Args:
        input: Activation tensor; its dtype determines the compute dtype.
        weight: Quantized weight tensor; it is cast to ``input.dtype`` before use.
        scales: Multiplier applied to the result of the matrix product. It must
            broadcast against the output (presumably one scale per output
            feature; verify against the caller).
        bias: Optional tensor added after scaling, or ``None`` to skip it.

    Returns:
        ``F.linear(input, weight) * scales`` with ``bias`` added when given.
    """
    # Scaling after the matmul lets the weights stay unscaled until they are used.
    scaled = F.linear(input, weight.to(input.dtype)) * scales
    return scaled if bias is None else scaled + bias

In [25]:
class W8A16LinearLayer(nn.Module):
    """Linear layer holding int8 weights plus one floating-point scale per output row.

    The int8 weights, the scales and (optionally) the bias are registered as
    buffers. They start out random; ``quantize`` overwrites ``int8_weights``
    and ``scales`` from a floating-point weight matrix.

    Args:
        in_features: Number of input features (columns of the weight matrix).
        out_features: Number of output features (rows of the weight matrix).
        bias: When true, a ``(1, out_features)`` bias buffer is created;
            otherwise ``self.bias`` is ``None``.
        dtype: Floating-point dtype of the ``scales`` and ``bias`` buffers.
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8),
        )

        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        """Quantize a 2-D weight matrix per row and store the result on the layer.

        Each row is scaled by ``max(|row|) / 127`` (rmax / qmax) and rounded to
        int8. The scale is computed in float32 and then cast back to
        ``weights.dtype``.

        Args:
            weights: Floating-point tensor of shape ``(out_features, in_features)``.
                It may be an ``nn.Parameter``; it is detached first so the stored
                buffers do not keep an autograd graph (and the original
                weights) alive.
        """
        detached = weights.detach()

        # rmax / qmax, computed in float32 for accuracy with low-precision weights.
        w_fp32 = detached.to(torch.float32)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(detached.dtype)

        # An all-zero row has scale 0; divide by 1 there so the row quantizes
        # to zeros instead of producing NaN from 0 / 0.
        divisor = torch.where(scales == 0, torch.ones_like(scales), scales)
        int8_weights = torch.round(detached / divisor.unsqueeze(1)).to(torch.int8)

        self.scales = scales
        self.int8_weights = int8_weights

    def forward(self, input):
        """Run the layer on ``input`` using the stored int8 weights and scales.

        Quantization is not repeated here: ``quantize`` needs the original
        floating-point weights, which the layer does not keep.
        """
        return w8_a16_forward(input, self.int8_weights, self.scales, self.bias)

In [26]:
class DummyModel(nn.Module):
    """Toy stack of layers used to exercise the linear-layer replacement helpers."""

    def __init__(self):
        super().__init__()
        # Assignment order is the order reported by named_children().
        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=2)
        self.linear_1 = nn.Linear(in_features=2, out_features=4)
        self.linear_2 = nn.Linear(in_features=4, out_features=6, bias=False)
        self.lm_head = nn.Linear(in_features=6, out_features=2)

    def forward(self):
        """No-op: the model is only inspected structurally, never run."""
        return None

In [27]:
# Fresh toy model; its nn.Linear children are swapped out in a later cell.
model = DummyModel()

In [28]:
# Child-module names that the replacement helpers must leave as plain nn.Linear.
module_name_to_exclude = ["lm_head"]

In [31]:
def replace_linear_with_target(module, target_class, module_name_to_exclude):
    """Recursively swap ``nn.Linear`` children of ``module`` for ``target_class``.

    The replacement is built as ``target_class(in_features, out_features,
    has_bias, weight_dtype)`` and installed under the same attribute name.
    The original bias object (if any) is carried over; the original weights
    are NOT transferred.

    Args:
        module: Root module, modified in place.
        target_class: Class to instantiate in place of each ``nn.Linear``.
        module_name_to_exclude: Iterable of child names to leave untouched.
            Names are compared against the immediate attribute name at every
            nesting level.
    """
    for child_name, submodule in module.named_children():
        # Anything that is not a replaceable Linear is searched for nested Linears.
        if not isinstance(submodule, nn.Linear) or any(
            child_name == skipped for skipped in module_name_to_exclude
        ):
            replace_linear_with_target(submodule, target_class, module_name_to_exclude)
            continue

        original_bias = submodule.bias
        replacement = target_class(
            submodule.in_features,
            submodule.out_features,
            original_bias is not None,
            submodule.weight.dtype,
        )
        setattr(module, child_name, replacement)
        if original_bias is not None:
            replacement.bias = original_bias

In [32]:
# Swap every non-excluded nn.Linear in `model` in place; weights are not quantized here.
replace_linear_with_target(model, W8A16LinearLayer, module_name_to_exclude)

In [33]:
model

DummyModel(
  (embedding): Embedding(10, 2)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=6, out_features=2, bias=True)
)

In [37]:
def replace_linear_with_target_quantize(module, target_class, module_name_to_exclude):
    """Recursively replace ``nn.Linear`` children and quantize their weights.

    Each non-excluded ``nn.Linear`` is replaced by
    ``target_class(in_features, out_features, has_bias, weight_dtype)``. The
    new module's ``quantize`` method is then called with the original weight
    tensor, and finally the original bias object (if any) is attached.

    Args:
        module: Root module, modified in place.
        target_class: Replacement class; it must provide ``quantize(weights)``.
        module_name_to_exclude: Iterable of child names to leave untouched.
    """
    for child_name, submodule in module.named_children():
        # Excluded or non-Linear children are only searched for nested Linears.
        if not isinstance(submodule, nn.Linear) or any(
            child_name == skipped for skipped in module_name_to_exclude
        ):
            replace_linear_with_target_quantize(submodule, target_class, module_name_to_exclude)
            continue

        fp_bias, fp_weights = submodule.bias, submodule.weight
        quantized = target_class(
            submodule.in_features,
            submodule.out_features,
            fp_bias is not None,
            fp_weights.dtype,
        )
        setattr(module, child_name, quantized)
        quantized.quantize(fp_weights)
        if fp_bias is not None:
            quantized.bias = fp_bias

In [38]:
# Second fresh model, used for the replace-and-quantize variant.
model_2 = DummyModel()

In [39]:
# Replace the non-excluded nn.Linear layers of `model_2` and quantize their original weights.
replace_linear_with_target_quantize(model_2, W8A16LinearLayer, module_name_to_exclude)

In [40]:
model_2

DummyModel(
  (embedding): Embedding(10, 2)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=6, out_features=2, bias=True)
)

In [42]:
# Inspect the int8 weight buffer of linear_1 after quantization.
model_2.linear_1.int8_weights

tensor([[ 127,   61],
        [-127,   48],
        [ 127,  -77],
        [  30,  127]], dtype=torch.int8)