In [1]:
import torch

original_weight = torch.randn((4,4))
print(original_weight)

tensor([[ 0.2369,  0.2868, -0.1640, -0.3524],
        [ 0.9579, -0.0910, -0.0846,  1.2647],
        [ 0.4284,  0.0883,  1.3122, -0.8039],
        [ 1.3006, -1.2052,  0.2557, -0.8690]])


In [5]:
# Now, we define two functions first for quantization and second for de-quantization.

def asymmetric_quantization(original_weight):
    # define the data type that you want to quantize. In our example, it's INT8.
    quantized_data_type = torch.int8

    # Get Wmax and Wmin value.
    Wmax = original_weight.max().item()
    Wmin = original_weight.min().item()

    # Get Qmax and Qmin value.
    Qmax = torch.iinfo(quantized_data_type).max
    Qmin = torch.iinfo(quantized_data_type).min

    # Calculate the Scale value using the scale formual.
    S = (Wmax - Wmin) / (Qmax - Qmin)

    # Calculate the zero point value using zero point formula.
    Z = Qmin - (Wmin/S)

    # Check if the Z value is out of the range.
    if Z < Qmin:
        Z = Qmin
    elif Z > Qmax:
        Z = Qmax
    else:
        # Zero point datatypes should be INT8 same as the Quantized value.
        Z = int(round(Z))

    # We have original_weight, scale and zero_point, now we can calculate the Quantized weight
    quantized_weight = (original_weight/S) + Z

    # We'll also round it and also use the torch clamp function to ensure the quantized weight doesn't goes out of range and should remain within Qmin and Qmax.
    quantized_weight = torch.clamp(torch.round(quantized_weight), Qmin, Qmax)

    # Finally cast the datatype to INT8
    quantized_data_type = quantized_weight.to(quantized_data_type)

    # return the final quantized weight.
    return quantized_weight, S, Z

def asasymmetric_dequantization(quantized_weight, scale, zero_point):
    # Also make sure to convert quantized_weight to float as substraction between two INT8 values (quantized_weight and zero_point) will give unwanted result. 

    dequantized_weight = scale * (quantized_weight.to(torch.float32) - zero_point)

    return dequantized_weight

quantized_weight, scale, zero_point = asymmetric_quantization(original_weight)
print(f"quantized weight: {quantized_weight}")
print("\n")
print(f"scale: {scale}")
print("\n")
print(f"zero point: {zero_point}")


dequantized_weight = asasymmetric_dequantization(quantized_weight, scale, zero_point)
print(dequantized_weight)

quantized_error = (dequantized_weight - original_weight).square().mean();
print(quantized_error);

quantized weight: tensor([[  18.,   23.,  -23.,  -42.],
        [  91.,  -15.,  -15.,  122.],
        [  37.,    3.,  127.,  -87.],
        [ 126., -128.,   20.,  -94.]])


scale: 0.009872203247219909


zero point: -6
tensor([[ 0.2369,  0.2863, -0.1678, -0.3554],
        [ 0.9576, -0.0888, -0.0888,  1.2636],
        [ 0.4245,  0.0888,  1.3130, -0.7996],
        [ 1.3031, -1.2044,  0.2567, -0.8688]])
tensor(5.6245e-06)
