# Quantize

In [None]:
import torch

# int8 symmetric quantization
def int8_quantize(tensor):
    """
    Symmetrically quantize a tensor onto the int8 value grid.

    Quantization formula: quantized = round(tensor / scale),
    where scale = max(|tensor|) / 127.

    Args:
        tensor: torch.Tensor to quantize.

    Returns:
        A ``(quantized_tensor, scale)`` tuple. ``quantized_tensor`` holds the
        rounded values clamped to [-128, 127]; it is not cast to torch.int8.
        ``scale`` is a Python float.
    """
    scale = (tensor.abs().max() / 127).item()  # .item() extracts a Python scalar
    if scale == 0:
        # All-zero input: dividing by a zero scale would turn every element
        # into NaN. Any non-zero scale maps zeros to zeros, so use 1.0.
        scale = 1.0
    quantized_tensor = (tensor / scale).round().clamp(-128, 127)
    return quantized_tensor, scale

# int8 dequantization
def int8_dequantize(quantized_tensor, scale):
    """
    Undo symmetric int8 quantization.

    Dequantization formula: dequantized = quantized * scale
    """
    return quantized_tensor * scale

# Round-trip a random tensor through symmetric quantization.
t = torch.randn(10)
quantized_tensor, scale = int8_quantize(t)
dequantized_tensor = int8_dequantize(quantized_tensor, scale)

# Show the original, quantized, and reconstructed tensors, one per line.
print(t, quantized_tensor, dequantized_tensor, sep="\n")

tensor([-1.0613,  0.0791, -1.4436, -0.9547,  0.9770,  0.0605,  0.5642,  2.0765,
         1.8535, -0.3742])
tensor([-65.,   5., -88., -58.,  60.,   4.,  35., 127., 113., -23.])
tensor([-1.0628,  0.0818, -1.4388, -0.9483,  0.9810,  0.0654,  0.5723,  2.0765,
         1.8476, -0.3761])


In [5]:
import torch

# int8 asymmetric quantization
def int8_asym_quantize(tensor):
    """
    Asymmetrically quantize a tensor onto the int8 value grid.

    Quantization formula: quantized = round((tensor - zero_point) / scale)

    Note that ``zero_point`` here is a floating-point offset in the input's
    value domain (not an integer code). It is chosen so that the tensor's
    minimum maps to -128 and, with scale = (max - min) / 255, its maximum
    maps to 127.

    Args:
        tensor: torch.Tensor to quantize.

    Returns:
        A ``(quantized_tensor, scale, zero_point)`` tuple.
        ``quantized_tensor`` holds the rounded values clamped to [-128, 127];
        it is not cast to torch.int8. ``scale`` and ``zero_point`` are Python
        floats.
    """
    # Range of the input values.
    tensor_min = tensor.min().item()
    tensor_max = tensor.max().item()

    scale = (tensor_max - tensor_min) / 255.0
    if scale == 0:
        # Constant input (max == min): a zero scale would produce 0/0 = NaN
        # below. With scale 1.0 every element maps to -128 and still
        # dequantizes back to the original constant.
        scale = 1.0

    # Choose zero_point so that tensor_min lands on -128:
    #   (tensor_min - zero_point) / scale == -128
    #   zero_point = tensor_min - scale * (-128)
    zero_point = tensor_min + scale * 128

    # Quantize: quantized = round((tensor - zero_point) / scale)
    quantized_tensor = torch.round((tensor - zero_point) / scale)
    quantized_tensor = quantized_tensor.clamp(-128, 127)

    return quantized_tensor, scale, zero_point

# int8 asymmetric dequantization
def int8_asym_dequantize(quantized_tensor, scale, zero_point):
    """
    Undo asymmetric int8 quantization.

    Dequantization formula: dequantized = quantized * scale + zero_point
    """
    rescaled = quantized_tensor * scale
    return rescaled + zero_point

# Test asymmetric quantization: round-trip a fixed tensor.
t = torch.tensor([1.0, 2.5, -0.5, 4.2, 0.8, -1.2, 3.1, 0.0, 2.0, -0.8])
quantized_tensor, scale, zero_point = int8_asym_quantize(t)
dequantized_tensor = int8_asym_dequantize(quantized_tensor, scale, zero_point)

# Show the original, quantized, and reconstructed tensors, one per line.
print(t, quantized_tensor, dequantized_tensor, sep="\n")

tensor([ 1.0000,  2.5000, -0.5000,  4.2000,  0.8000, -1.2000,  3.1000,  0.0000,
         2.0000, -0.8000])
tensor([ -24.,   47.,  -95.,  127.,  -34., -128.,   75.,  -71.,   23., -109.])
tensor([ 1.0024,  2.5059, -0.5012,  4.2000,  0.7906, -1.2000,  3.0988,  0.0071,
         1.9976, -0.7976])
