In [1]:
import os 
import torch
import torch.nn as nn
import gemm_int8

In [2]:
def quantize_row_int8_symmetric(mat: torch.Tensor):
    """
    Symmetric int8 quantization per row.

    Every row gets its own scale, chosen so that the row's largest absolute
    value maps to 127. Values are divided by that scale, rounded, clamped to
    [-128, 127] and stored as int8.

    All arithmetic is done in float32 regardless of the input dtype. The 1e-8
    floor that protects against division by zero is not representable in
    float16 (it rounds to 0), so applying it in half precision would leave an
    all-zero row with a zero scale and produce NaNs in the division.

    mat: (N, M) float tensor
    Returns:
      q_mat: (N, M) int8
      scales: (N,) float32
    """
    qmin, qmax = -128, 127

    # Promote once so the max, the epsilon floor and the division all run in float32.
    mat_f32 = mat.to(torch.float32)

    # Avoid division by zero for rows that are entirely zero.
    max_vals = mat_f32.abs().amax(dim=1, keepdim=True)  # (N, 1)
    max_vals = max_vals.clamp(min=1e-8)

    scales = max_vals / qmax                            # (N, 1), float32
    q_mat = torch.clamp(torch.round(mat_f32 / scales), qmin, qmax).to(torch.int8)

    return q_mat, scales.squeeze(1)


def quantize_col_int8_symmetric(mat: torch.Tensor):
    """
    Symmetric int8 quantization per column.

    Every column gets its own scale, chosen so that the column's largest
    absolute value maps to 127. Values are divided by that scale, rounded,
    clamped to [-128, 127] and stored as int8.

    All arithmetic is done in float32 regardless of the input dtype. The 1e-8
    floor that protects against division by zero is not representable in
    float16 (it rounds to 0), so applying it in half precision would leave an
    all-zero column with a zero scale and produce NaNs in the division.

    mat: (N, M) float tensor
    Returns:
      q_mat: (N, M) int8
      scales: (M,) float32
    """
    qmin, qmax = -128, 127

    # Promote once so the max, the epsilon floor and the division all run in float32.
    mat_f32 = mat.to(torch.float32)

    # Avoid division by zero for columns that are entirely zero.
    max_vals = mat_f32.abs().amax(dim=0, keepdim=True)  # (1, M)
    max_vals = max_vals.clamp(min=1e-8)

    scales = max_vals / qmax                            # (1, M), float32
    q_mat = torch.clamp(torch.round(mat_f32 / scales), qmin, qmax).to(torch.int8)

    return q_mat, scales.squeeze(0)
    

def dequant_col_int8(mat_int8: torch.Tensor,
                      scales: torch.Tensor,
                      out_dtype=torch.float16):
    """
    Undo a per-column int8 quantization.

    mat_int8: (N, M) int8
    scales: (M,) float32, one scale per column
    Returns:
      (N, M) tensor in out_dtype, where column j has been multiplied by scales[j]
    """
    # Shape the scales as (1, M) so they broadcast down every row.
    per_column = scales.unsqueeze(0)
    restored = mat_int8.float() * per_column
    return restored.to(out_dtype)

def dequant_int8_gemm(out_int: torch.Tensor,
                      x_scale: torch.Tensor,
                      w_scale: torch.Tensor,
                      out_dtype=torch.float16):
    """
    Rescale the raw output of an INT8 matmul back to real values.

      out_int: (B, out_features) int32 or float32
      x_scale: (B,) scales of the quantized input rows
      w_scale: (out_features,) scales of the quantized weight rows/cols

    Entry (b, o) is multiplied by x_scale[b] and then by w_scale[o]; the
    result is cast to out_dtype.
    """
    # Only int32 is converted explicitly; any other dtype is used as given.
    accum = out_int.to(torch.float32) if out_int.dtype == torch.int32 else out_int

    # Row scales broadcast as (B, 1), column scales as (1, out_features).
    row_scaled = accum * x_scale.unsqueeze(1)
    rescaled = row_scaled * w_scale.unsqueeze(0)

    return rescaled.to(out_dtype)


In [3]:
class CustomLinear(nn.Module):
    """
    Bias-free linear layer that can switch from float weights to int8 weights.

    Before quantization, forward computes ``x @ w`` with ``w`` stored as
    (in_features, out_features). After ``perform_weight_quantization`` the
    float weight is dropped, and forward quantizes the activations per row,
    multiplies the int8 operands with ``gemm_int8.matmul`` and rescales the
    result with the activation and weight scales.
    """

    def __init__(self, in_features, out_features, default_dtype=torch.float16):
        super().__init__()
        # Keep float weights initially as (in_features, out_features)
        self.w = nn.Parameter(torch.randn(in_features, out_features, dtype=default_dtype))

        # For quantized weights:
        # we'll store w_q as (out_features, in_features) to match gemm_int8's (N, K)
        # Both buffers start empty and are non-persistent (excluded from state_dict).
        self.register_buffer("w_q", torch.empty(0, dtype=torch.int8), persistent=False)
        self.register_buffer("w_scale", torch.empty(0, dtype=torch.float32), persistent=False)

        self.is_quantized = False

    @torch.no_grad()
    def perform_weight_quantization(self):
        """
        Replace the float weight with a row-wise int8 copy plus float32 scales.

        Calling this on an already quantized layer is a no-op: the float
        weight has been released by then, so there is nothing left to quantize.
        """
        if self.is_quantized:
            return

        # Re-orient weights: (in_features, out_features) -> (out_features, in_features)
        w_t = self.w.t().contiguous()  # (out_features, in_features)

        w_q, w_scale = quantize_row_int8_symmetric(w_t)  # row-wise over out_features

        self.w_q = w_q
        self.w_scale = w_scale
        self.is_quantized = True

        # Free the float weights; keep the attribute so later accesses see None.
        del self.w
        self.w = None
        print("CustomLinear: quantized weights to int8 and deleted float weights.")

    def forward(self, x):
        """
        Apply the layer to ``x`` whose last dimension is in_features.

        The result has the same leading dimensions as ``x`` and out_features
        as its last dimension. On the quantized path the result is returned in
        ``x.dtype`` so both paths agree on the output dtype.
        """
        if not self.is_quantized:
            return x @ self.w

        # The int8 path works on (rows, in_features); fold any extra leading
        # dimensions into rows so it accepts the same inputs as ``x @ w``.
        lead_shape = x.shape[:-1]
        needs_reshape = x.dim() != 2
        x_2d = x.reshape(-1, x.shape[-1]) if needs_reshape else x

        # Quantize activations per row: (rows, in_features)
        x_q, x_scale = quantize_row_int8_symmetric(x_2d)
        out_int = gemm_int8.matmul(x_q, self.w_q, alpha=1.0)

        out = dequant_int8_gemm(out_int, x_scale, self.w_scale, out_dtype=x.dtype)
        if needs_reshape:
            out = out.reshape(*lead_shape, out.shape[-1])
        return out

In [4]:
class CustomModel(nn.Module):
    """
    Small MLP: two CustomLinear layers followed by a regular nn.Linear, with
    a ReLU applied after each of the first two layers.
    """

    def __init__(self, input_size, hidden_size, output_size, default_dtype=torch.float16):
        super().__init__()
        self.fc1 = CustomLinear(input_size, hidden_size, default_dtype=default_dtype)
        self.fc2 = CustomLinear(hidden_size, hidden_size, default_dtype=default_dtype)
        self.fc3 = nn.Linear(hidden_size, output_size, dtype=default_dtype)
        self.relu = nn.ReLU()

    @torch.no_grad()
    def start_quantization(self):
        """Quantize the weights of every CustomLinear found in this module tree."""
        custom_layers = [m for m in self.modules() if isinstance(m, CustomLinear)]
        for layer in custom_layers:
            layer.perform_weight_quantization()

    def forward(self, x):
        """Run fc1 -> ReLU -> fc2 -> ReLU -> fc3 and return the fc3 output."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)


In [5]:
# Benchmark setup: the same width is used for the input, hidden and output sizes.
hidden_size = 1024 * 8
batch_size = 512
# Build the model with float16 parameters and move it to the GPU.
ToyModel = CustomModel(hidden_size, hidden_size, hidden_size, default_dtype=torch.float16).cuda()

# Random float16 activations of shape (batch_size, hidden_size), created on the GPU.
input_data = torch.randn(batch_size, hidden_size, device='cuda', dtype=torch.float16)

# Measure the average forward time of the un-quantized model.
# NOTE(review): no torch.no_grad() is used in this cell, so autograd is presumably
# tracking these forward passes — confirm that is intended for an inference benchmark.
torch._dynamo.reset()
# Warm up: run a few untimed iterations before measuring.
for _ in range(10):
    output = ToyModel(input_data)

# CUDA events bracket the timed loop on the GPU timeline.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)    
n_iter = 100
start_event.record()
for _ in range(n_iter):
    output = ToyModel(input_data)
end_event.record()
# Wait for the GPU to finish so both events are complete before reading the elapsed time.
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time = elapsed_time_ms / n_iter
print(f"Average time for forward pass over {n_iter} iterations: {avg_time:.2f} ms")
    
print(output.shape)

Average time for forward pass over 100 iterations: 1.33 ms
torch.Size([512, 8192])


In [6]:
# Profile a single forward pass of the model with the torch profiler.
# CPU and CUDA activity are both recorded, together with input shapes,
# Python stacks, memory usage and FLOP estimates.
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=True,
    profile_memory=True,
    with_flops=True,
) as prof:
    output = ToyModel(input_data)
    
# Show the 20 entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total GFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         1.15%      26.438us        60.76%       1.402ms     700.886us       0.000us         0.00%     777.668us     388.834us           0 

In [7]:
# Quantize every CustomLinear layer in place; each layer prints a confirmation line.
ToyModel.start_quantization()

# Rebind ToyModel to its torch.compile wrapper (dynamic=True requests dynamic-shape compilation).
# From here on, ToyModel refers to the compiled, quantized model.
ToyModel = torch.compile(ToyModel, dynamic=True)

CustomLinear: quantized weights to int8 and deleted float weights.
CustomLinear: quantized weights to int8 and deleted float weights.


In [8]:
# Measure the average forward time of the quantized, compiled model.
# NOTE(review): the warm-up iterations below presumably absorb the torch.compile
# compilation cost so it is not part of the timed loop — confirm.
torch._dynamo.reset()
# Warm up: run a few untimed iterations before measuring.
for _ in range(10):
    output = ToyModel(input_data)
    
# CUDA events bracket the timed loop on the GPU timeline.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)    
n_iter = 100
start_event.record()
for _ in range(n_iter):
    output = ToyModel(input_data)
end_event.record()
# Wait for the GPU to finish so both events are complete before reading the elapsed time.
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time = elapsed_time_ms / n_iter
print(f"Average time for quantized forward pass over {n_iter} iterations: {avg_time:.2f} ms")
print(output.shape)

Average time for quantized forward pass over 100 iterations: 1.00 ms
torch.Size([512, 8192])


In [9]:
# Profile a single forward pass of the quantized, compiled model with the torch profiler.
# CPU and CUDA activity are both recorded, together with input shapes,
# Python stacks, memory usage and FLOP estimates.
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=True,
    profile_memory=True,
    with_flops=True,
) as prof:
    output = ToyModel(input_data)
    
# Show the 20 entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total GFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                  Torch-Compiled Region         4.85%     142.288us        87.37%       2.566ms       2.566ms       0.000us         0.00%       1.030ms       1.030ms           0 