In [1]:
import torch
import torch.nn as nn


In [2]:
# Quantization helpers for a W8A16 linear layer (8-bit weights, 16-bit activations)

def linear_quantize(tensor, dtype=torch.int8):
    """Asymmetric (affine) per-tensor linear quantization.

    Maps real values r to integers q = clamp(round(r / scale + zero_point)),
    so that r ~= scale * (q - zero_point).

    Args:
        tensor: Floating-point tensor to quantize.
        dtype: Integer dtype whose [min, max] range bounds the quantized values.

    Returns:
        scale: Python float step size (always > 0).
        zero_point: Python int inside dtype's range; the integer that 0.0 maps to.
        quantized_tensor: torch.int32 tensor of the same shape as ``tensor``
            holding values inside dtype's range.
    """
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    # Widen the real range so it always contains 0.0. Otherwise an all-positive
    # (or all-negative) tensor yields a zero point outside [q_min, q_max];
    # clamping it then saturates most quantized values. When 0 is already in
    # range this is a no-op.
    r_min = min(tensor.min().item(), 0.0)
    r_max = max(tensor.max().item(), 0.0)

    scale = (r_max - r_min) / (q_max - q_min)
    if scale == 0.0:
        # All-zero tensor: any positive scale reproduces it exactly, and this
        # avoids dividing by zero below.
        scale = 1.0

    # Round and clamp as Python ints so the zero point is never truncated to a
    # narrower integer type than `dtype` (e.g. int16 zero points).
    zero_point = round(q_min - r_min / scale)
    zero_point = int(min(max(zero_point, q_min), q_max))

    quantized_tensor = torch.round(tensor / scale + zero_point)
    quantized_tensor = quantized_tensor.clamp(min=q_min, max=q_max).to(torch.int)
    return scale, zero_point, quantized_tensor

def channel_linear_quantize(tensor, dim=0, dtype=torch.int8):
    """Per-channel asymmetric linear quantization along ``dim``.

    Each slice ``tensor.select(dim, i)`` is quantized independently with
    ``linear_quantize``, giving one scale and one zero point per channel.

    Args:
        tensor: Floating-point tensor to quantize (rank >= 1).
        dim: Channel dimension; negative values count from the end.
        dtype: Integer dtype whose range bounds the quantized values.

    Returns:
        scales: float32 tensor of per-channel scales, shaped to broadcast
            against ``tensor``: size ``tensor.size(dim)`` along ``dim`` and 1
            elsewhere (e.g. (N, 1) for dim=0 and (1, M) for dim=1 on 2-D input).
        zero_pts: int32 tensor of per-channel zero points, same shape as ``scales``.
        quantized_tensor: int32 tensor with the same shape as ``tensor``.
    """
    num_channels = tensor.size(dim)
    scales = torch.zeros(num_channels)
    zero_pts = torch.zeros(num_channels, dtype=torch.int)
    quantized_tensor = torch.zeros_like(tensor, dtype=torch.int)

    for i in range(num_channels):
        scales[i], zero_pts[i], quant = linear_quantize(tensor.select(dim, i), dtype=dtype)
        # Write through a view of the output slice: this works for any `dim`
        # (including negative ones) and any rank, unlike hard-coded [i, :] / [:, i].
        quantized_tensor.select(dim, i).copy_(quant)

    # Size 1 on every axis except `dim`, so the parameters broadcast against `tensor`.
    broadcast_shape = [1] * tensor.dim()
    broadcast_shape[dim] = num_channels
    return scales.view(broadcast_shape), zero_pts.view(broadcast_shape), quantized_tensor

In [None]:

class W8A16LinearLayer(nn.Module):
    """Scaffold for a W8A16 linear layer (8-bit weights, 16-bit activations).

    Currently wraps a bfloat16 ``nn.Linear``. ``quantize`` produces per-row
    quantization parameters for a given tensor, but ``forward`` does not yet
    use quantized weights.
    """

    def __init__(self, in_features=512, out_features=512):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=out_features, dtype=torch.bfloat16)
        # TODO: keep the int8 weights, scales and zero points as buffers, e.g.
        # self.register_buffer("scales", ...). register_buffer returns None, so
        # its result must not be assigned back to an attribute.

    def quantize(self, x):
        """Upcast ``x`` to float32 and quantize it per row (dim=0).

        Uses channel_linear_quantize's default int8 range and returns its
        (scales, zero_pts, quantized) triple unchanged.
        """
        x = x.to(torch.float32)
        scales, zero_pts, quant = channel_linear_quantize(x, dim=0)
        return scales, zero_pts, quant

    def forward(self, x):
        """Apply the wrapped bfloat16 linear layer to ``x``.

        NOTE(review): this still uses the full-precision bf16 weights; swap in
        a dequantized int8 weight path once the buffers above exist.
        """
        return self.linear(x)

In [4]:
def linear_dequantize(scale, zero_point, quantized_tensor):
    """Invert affine quantization: r = scale * (q - zero_point).

    ``scale`` and ``zero_point`` may be Python scalars or tensors that
    broadcast against ``quantized_tensor``. The result is float32.
    """
    centered = quantized_tensor.to(torch.float32) - zero_point
    return scale * centered


In [5]:
# Seed first so both the layer's initial weights and the sample tensor (and
# therefore every number printed below) are reproducible on Restart & Run All.
torch.manual_seed(0)
quantized_ly = W8A16LinearLayer()
test_tensor = torch.randn(512, 512, dtype=torch.bfloat16)
scales, zero_pts, quant = quantized_ly.quantize(test_tensor)
dequantized_tensor = linear_dequantize(scales, zero_pts, quant)

In [6]:
# Scales and zero points are per row (dim=0), so only their shapes are shown.
print("Original Tensor:\n", test_tensor)
print("Scales shape:\n", scales.shape)
print("Zero Points shape:\n", zero_pts.shape)
print("Quantized Tensor:\n", quant)
print("Dequantized Tensor:\n", dequantized_tensor)


Original Tensor:
 tensor([[-1.9824e-01,  1.4375e+00, -3.5547e-01,  ..., -4.2969e-01,
          2.4375e+00, -3.0469e-01],
        [-9.5703e-01,  1.2031e+00, -1.0703e+00,  ..., -3.8477e-01,
          1.2656e+00,  0.0000e+00],
        [-1.1875e+00, -1.3125e+00, -1.3750e+00,  ...,  1.2344e+00,
         -4.8047e-01, -8.3984e-01],
        ...,
        [ 1.5469e+00,  9.0625e-01, -1.8750e-01,  ...,  1.0234e+00,
          3.6719e-01, -2.6367e-01],
        [-3.8281e-01, -4.1992e-01,  4.4434e-02,  ..., -6.0156e-01,
         -4.5117e-01,  5.2734e-01],
        [ 3.7575e-04, -9.4531e-01,  1.8125e+00,  ..., -2.8125e-01,
         -6.8750e-01, -7.8516e-01]], dtype=torch.bfloat16)
Scales:
 torch.Size([512, 1])
Zero Points:
 torch.Size([512, 1])
Quantized Tensor:
 tensor([[ -6,  59, -12,  ..., -15,  98, -10],
        [-39,  56, -44,  ..., -14,  59,   3],
        [-60, -66, -69,  ...,  50, -28, -44],
        ...,
        [ 70,  44,  -1,  ...,  49,  22,  -4],
        [-26, -28,  -7,  ..., -36, -29,  14],
 

In [7]:
# Mean squared reconstruction error of the int8 round trip; enforce the
# expected bound instead of leaving it as a comment.
mse = (dequantized_tensor - test_tensor).square().mean()
assert mse < 1e-2, f"quantization MSE too large: {mse.item():.3e}"
mse

tensor(4.3525e-05)

In [8]:
# NOTE(review): `helper` is a project-local module not shown here; this import
# belongs in the imports cell at the top of the notebook.
from helper import plot_quantization_errors

# NOTE(review): in the saved run this cell was stopped with KeyboardInterrupt,
# so its figure is missing. It receives the full 512x512 tensors; if plotting is
# slow, consider passing a smaller slice (TODO confirm the helper accepts one).
plot_quantization_errors(test_tensor, quant, dequantized_tensor)

KeyboardInterrupt: 