# Model evaluation with AdaPT on the MNIST dataset

In this notebook you can evaluate different approximate multipliers on models for the MNIST dataset.

Steps:
* Select the number of threads to use
* Load the dataset
* Load the AdaPT layers
* Choose the approximate multiplier
* Define the model
* Run model calibration for quantization
* Evaluate


**Note**:
* This notebook should be run on an x86 machine

* Please make sure you have run the installation steps first

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim

from torch.utils.data import TensorDataset
from torch.autograd import Variable
from torchvision import transforms
from torchvision.datasets import MNIST
import tqdm

## Select number of threads to use

For optimal performance, set this to the number of CPU threads on your machine (logical CPUs, not physical cores).

In [2]:
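threads = 40  # set to the number of logical CPUs (hardware threads) on your machine
torch.set_num_threads(threads)

# Optional OpenMP tuning that may improve performance. The OpenMP runtime usually reads
# these variables when it initializes, so they may have no effect once torch is imported.
%env OMP_PLACES=cores
%env OMP_PROC_BIND=close
%env OMP_WAIT_POLICY=active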

env: OMP_PLACES=cores
env: OMP_PROC_BIND=close
env: OMP_WAIT_POLICY=active
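The thread recommendation above can also be handled from Python. A minimal sketch, assuming it runs in a fresh kernel before `torch` is imported, since OpenMP runtimes generally read these variables once at start-up. `configure_threads` is a hypothetical helper; `os.cpu_count()` reports logical CPUs (hardware threads), which is the count recommended above.

```python
import os

def configure_threads(omp_settings=None):
    """Hypothetical helper: set OpenMP variables, return the logical CPU count.

    Call it before `import torch` so the OpenMP runtime sees the variables.
    """
    settings = {
        "OMP_PLACES": "cores",
        "OMP_PROC_BIND": "close",
        "OMP_WAIT_POLICY": "active",
    }
    if omp_settings:
        settings.update(omp_settings)
    for key, value in settings.items():
        os.environ.setdefault(key, value)  # keep any value already exported by the user
    # Logical CPUs (hardware threads), not physical cores; fall back to 1 if unknown
    return os.cpu_count() or 1

threads = configure_threads()
print("threads:", threads, "| OMP_PLACES =", os.environ["OMP_PLACES"])
```

After this runs, `import torch` and pass the result to `torch.set_num_threads(threads)`.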


## Load Dataset


In [3]:
# ToTensor scales uint8 pixels from [0, 255] to float values in [0.0, 1.0].
# Note: the calibration cell and the accuracy cell below read dataset.data directly
# (raw 0-255 pixels cast to float), which bypasses this transform.
train = MNIST('./datasets/mnist_data/data', train=True, download=True,
              transform=transforms.Compose([transforms.ToTensor()]))

test = MNIST('./datasets/mnist_data/data', train=False, download=True,
             transform=transforms.Compose([transforms.ToTensor()]))

# Create DataLoaders
dataloader_args = dict(shuffle=True, batch_size=64, num_workers=0, pin_memory=False)
train_loader = dataloader.DataLoader(train, **dataloader_args)
test_loader = dataloader.DataLoader(test, **dataloader_args)
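For reference, `ToTensor` converts a `uint8` image by dividing its pixel values by 255, a fixed rescale to [0.0, 1.0] rather than a per-image min-max normalization, and for a grayscale image it adds a leading channel dimension. A plain-Python sketch of that behavior; `to_tensor_like` is a hypothetical helper, not part of torchvision.

```python
def to_tensor_like(image):
    """Mimic ToTensor on a uint8 grayscale image given as rows of ints:
    add a channel dimension, (H, W) -> (1, H, W), and divide every pixel by 255."""
    return [[[pixel / 255.0 for pixel in row] for row in image]]

raw = [[0, 51], [128, 255]]
out = to_tensor_like(raw)
print(len(out), len(out[0]), len(out[0][0]))  # 1 2 2
print(out[0])
```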

## Load AdaPT Layers

In [4]:
#Load ADAPT layers
from adapt.approx_layers import axx_layers as approxNN

## Choose approximate multiplier 

Two approximate multipliers are already provided:

**mul8s_acc** (header file: mul8s_acc.h), the default

**mul8s_1L2H** (header file: mul8s_1L2H.h)

To use your own multiplier, generate its C++ header with the provided tool (LUT_generator) and place it in the adapt/cpu-kernels/axx_mults folder. The axx_mult name set below must match the header file name (without the .h extension). The same axx_mult is used in all layers.

Tip: to use a different axx_mult for each layer, set it explicitly in the model definition through the axx_mult argument of each AdaPT layer (AdaPT_Linear in this model, AdaPT_Conv2d for convolutional models).
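The idea behind a LUT-based multiplier can be sketched in a few lines: for 8-bit signed operands there are only 256 x 256 input pairs, so every product can be precomputed and each multiplication replaced by a table lookup. This sketch is illustrative only. The table layout (operands offset by 128) and the `truncated_mul` approximation (clearing the low product bits) are assumptions, not the format LUT_generator emits or the behavior of mul8s_1L2H.

```python
def exact_mul(a, b):
    return a * b

def truncated_mul(a, b, dropped_bits=4):
    """Hypothetical approximate multiplier: rounds the exact product down
    (toward negative infinity) to a multiple of 2**dropped_bits."""
    mask = ~((1 << dropped_bits) - 1)
    return (a * b) & mask

def build_lut(mul):
    """Precompute mul(a, b) for all int8 pairs; row/column index = operand + 128."""
    return [[mul(a, b) for b in range(-128, 128)] for a in range(-128, 128)]

def lut_mul(lut, a, b):
    return lut[a + 128][b + 128]

exact_lut = build_lut(exact_mul)
approx_lut = build_lut(truncated_mul)

# Error of the approximate multiplier over all 65,536 input pairs
errors = [abs(lut_mul(approx_lut, a, b) - a * b)
          for a in range(-128, 128) for b in range(-128, 128)]
print("max error:", max(errors), "mean error:", sum(errors) / len(errors))
```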

In [5]:
axx_mult = 'mul8s_1L2H'

## Define Model

The C++ extensions for the approximate multiplier are JIT-compiled and loaded on the fly; then the PyTorch model is defined and its pretrained weights are loaded.

In [6]:
#set flag for use of AdaPT custom layers or vanilla PyTorch
use_adapt=True

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        if use_adapt:
            self.fc1 = approxNN.AdaPT_Linear(784, 548, axx_mult = axx_mult)
        else:
            self.fc1 = nn.Linear(784, 548)
        
        self.bc1 = nn.BatchNorm1d(548)

        if use_adapt:
            self.fc2 = approxNN.AdaPT_Linear(548, 252, axx_mult = axx_mult)
        else:    
            self.fc2 = nn.Linear(548, 252)
            
        self.bc2 = nn.BatchNorm1d(252)     
        
        if use_adapt:
            self.fc3 = approxNN.AdaPT_Linear(252, 10, axx_mult = axx_mult)
        else:
            self.fc3 = nn.Linear(252, 10)
                
    def forward(self, x):
        x = x.view((-1, 784))
        h = self.fc1(x)
        h = self.bc1(h)
        h = F.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        
        h = self.fc2(h)
        h = self.bc2(h)
        h = F.relu(h)
        h = F.dropout(h, p=0.2, training=self.training)
        
        h = self.fc3(h)
        out = F.log_softmax(h,-1)
        return out

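model = Model()
model.cpu()
# inference only: disable dropout and use the BatchNorm running statistics
model.eval()

# load pretrained weights (map_location keeps them on the CPU)
model.load_state_dict(torch.load('models/state_dicts/mnist.pt', map_location='cpu'))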

#optimizer = optim.Adam(model.parameters(), lr=0.001)

Using /root/.cache/torch_extensions as PyTorch extensions root...
Emitting ninja build file /root/.cache/torch_extensions/PyInit_linear_mul8s_1L2H/build.ninja...
Building extension module PyInit_linear_mul8s_1L2H...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module PyInit_linear_mul8s_1L2H...


<All keys matched successfully>
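Because every multiplication inside the AdaPT layers goes through the selected approximate multiplier, it can help to know how many there are per image. The arithmetic below counts only the weight multiplications of the three linear layers defined above (784 -> 548 -> 252 -> 10); BatchNorm, bias additions and activations are not counted.

```python
# Layer sizes from the Model class above: (in_features, out_features)
layers = [(784, 548), (548, 252), (252, 10)]

# A fully connected layer performs in_features * out_features multiplications per sample
mults_per_layer = [fan_in * fan_out for fan_in, fan_out in layers]
total = sum(mults_per_layer)

print(mults_per_layer)  # [429632, 138096, 2520]
print(total)            # 570248
print(total * 10_000)   # 5702480000 for the full 10,000-image test set
```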

## Run model calibration for quantization

This step calibrates the quantization parameters (the per-tensor amax, and hence the scale, of each TensorQuantizer).

Re-run it every time the model changes.

In [52]:
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect calibration statistics.

    Note: num_batches is currently unused; the whole test set is fed in a single
    pass as raw pixel values (dataset.data), bypassing the DataLoader and its transform.
    """
    # Enable calibrators (and disable quantization) on every TensorQuantizer
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    evaluate_x = data_loader.dataset.data.float().cpu()
    model(evaluate_x)

    # Disable calibrators and re-enable quantization
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    """Load the calibrated amax into each TensorQuantizer and print it."""
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cpu()


# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model, test_loader, num_batches=2)
    compute_amax(model, method="percentile", percentile=99.99)

    # optional - test different calibration methods
    # compute_amax(model, method="mse")
    # compute_amax(model, method="entropy")
    

W0822 16:33:04.832009 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.832573 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.832999 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.833375 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.833750 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.834131 140603895383872 tensor_quantizer.py:173] Disable HistogramCalibrator
W0822 16:33:04.835496 140603895383872 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0822 16:33:04.836484 140603895383872 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0822 16:33:04.837378 140603895383872 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0822 16:33:04.838256 140603895383872 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).

fc1.quantizer                           : TensorQuantizer(8bit per-tensor amax=254.8755 calibrator=HistogramCalibrator scale=0.4982825219631195 quant)
fc1.quantizer_w                         : TensorQuantizer(8bit per-tensor amax=0.5369 calibrator=HistogramCalibrator scale=236.53900146484375 quant)
fc2.quantizer                           : TensorQuantizer(8bit per-tensor amax=5.3846 calibrator=HistogramCalibrator scale=22.013214111328125 quant)
fc2.quantizer_w                         : TensorQuantizer(8bit per-tensor amax=0.4709 calibrator=HistogramCalibrator scale=269.6973571777344 quant)
fc3.quantizer                           : TensorQuantizer(8bit per-tensor amax=5.8338 calibrator=HistogramCalibrator scale=20.485553741455078 quant)
fc3.quantizer_w                         : TensorQuantizer(8bit per-tensor amax=0.5777 calibrator=HistogramCalibrator scale=219.846435546875 quant)
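The printed quantizers follow the same symmetric 8-bit rule as `symmetric_int8_quant` later in this notebook: `scale = 127 / amax` (for fc1.quantizer, 127 / 254.8755 ≈ 0.49828). The sketch below illustrates percentile calibration with an exact nearest-rank percentile of |x|. This is a simplification: pytorch_quantization's HistogramCalibrator derives the percentile from a histogram instead. All helper names are hypothetical.

```python
import math

def percentile_amax(values, percentile):
    """Nearest-rank percentile of |x| (simplified stand-in for histogram calibration)."""
    mags = sorted(abs(v) for v in values)
    rank = max(0, math.ceil(percentile / 100.0 * len(mags)) - 1)
    return mags[rank]

def scale_from_amax(amax, num_bits=8):
    """Symmetric scale: the magnitude amax maps to 2**(num_bits - 1) - 1."""
    return (2 ** (num_bits - 1) - 1) / amax

def fake_quant(x, scale):
    """Quantize to int8 and back; values beyond amax are clipped."""
    q = max(-128, min(127, round(x * scale)))
    return q / scale

samples = [float(v) for v in range(-100, 101)]
amax = percentile_amax(samples, 99.0)
print(amax, scale_from_amax(amax))
print(scale_from_amax(254.8755))  # agrees with fc1.quantizer's printed scale (0.49828...)
```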


## Evaluate

In [53]:
# force TensorQuantizers to load buffers on CPU instead of .cuda()
def force_quantizer_cpu():
    import torch
    from pytorch_quantization.nn import modules

    orig_fn = modules.tensor_quantizer.TensorQuantizer._load_from_state_dict

    def new_fn(self, state_dict, prefix, *args, **kwargs):
        key = prefix + '_amax'
        if key in state_dict:
            buf = state_dict[key].data.cpu()   # force CPU
            self.register_buffer("_amax", buf)
        else:
            # fallback to original
            orig_fn(self, state_dict, prefix, *args, **kwargs)

    modules.tensor_quantizer.TensorQuantizer._load_from_state_dict = new_fn

force_quantizer_cpu()
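`force_quantizer_cpu` patches `TensorQuantizer._load_from_state_dict` for the rest of the session. The sketch below shows one way to scope the same save, wrap, and fall-back pattern with a context manager so the original method is always restored. It runs on a stand-in class; `patched_method`, `Quantizer`, and `cpu_first` are hypothetical names, and applying it to pytorch_quantization (passing `TensorQuantizer` and `"_load_from_state_dict"`) is an untested assumption.

```python
from contextlib import contextmanager

@contextmanager
def patched_method(cls, name, make_replacement):
    """Temporarily replace cls.<name>; make_replacement receives the original function."""
    original = getattr(cls, name)
    setattr(cls, name, make_replacement(original))
    try:
        yield
    finally:
        setattr(cls, name, original)  # always restore, even if the body raises

class Quantizer:  # stand-in for TensorQuantizer
    def load(self, state_dict, prefix):
        return "original:" + prefix

def cpu_first(orig_fn):
    def new_fn(self, state_dict, prefix):
        key = prefix + "_amax"
        if key in state_dict:
            return "patched:" + key  # e.g. register the buffer on CPU here
        return orig_fn(self, state_dict, prefix)  # fall back to the original
    return new_fn

q = Quantizer()
with patched_method(Quantizer, "load", cpu_first):
    print(q.load({"fc1._amax": 1.0}, "fc1."))  # patched:fc1._amax
    print(q.load({}, "fc2."))                  # original:fc2.
print(q.load({"fc1._amax": 1.0}, "fc1."))      # original:fc1.
```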

In [54]:
from adapt.approx_layers.systolic_utils import compare_exact_vs_approx

acc_baseline, acc_exact, acc_approx, delta = compare_exact_vs_approx(
    model, test_loader, axx_mult="mul8s_acc", device="cpu"
)

print(f'Baseline (normal CPU): {acc_baseline:.4f}')
print(f'Exact systolic:        {acc_exact:.4f}')
print(f'Approx systolic:       {acc_approx:.4f}')
print(f'Δ (Exact - Approx):    {delta:.4f}')


Using /root/.cache/torch_extensions as PyTorch extensions root...
No modifications detected for re-loaded extension module PyInit_linear_mul8s_1L2H, skipping build step...
Loading extension module PyInit_linear_mul8s_1L2H...

In [10]:
# Raw 0-255 pixels cast to float, as in the calibration cell (bypasses ToTensor)
with torch.no_grad():
    evaluate_x = test_loader.dataset.data.float().cpu()
    evaluate_y = test_loader.dataset.targets.cpu()
    output = model(evaluate_x)

pred = output.argmax(dim=1)
d = pred.eq(evaluate_y)
accuracy = d.sum() / d.size(0)

print('Accuracy:', accuracy)

Accuracy: tensor(0.9788)
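The cell above pushes all 10,000 test images through the model in one call. A batched variant keeps memory bounded; the sketch below shows only the counting logic in plain Python, with `predict_batch` a hypothetical callable standing in for `model(x).argmax(dim=1)` on one batch.

```python
def batched_accuracy(batches, predict_batch):
    """Accumulate correct predictions over (inputs, labels) batches."""
    correct = 0
    total = 0
    for inputs, labels in batches:
        preds = predict_batch(inputs)
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total if total else 0.0

def argmax(row):
    return max(range(len(row)), key=row.__getitem__)

# Toy example: each "input" is already a row of class scores
batches = [
    ([[0.1, 0.9], [0.8, 0.2]], [1, 0]),
    ([[0.3, 0.7]], [0]),
]
acc = batched_accuracy(batches, lambda xs: [argmax(r) for r in xs])
print(acc)  # 0.6666666666666666
```

If `test_loader` were used for the batches, note that it applies `ToTensor` (inputs in [0, 1]), whereas the cell above feeds raw 0 to 255 values, so the inputs would need rescaling by 255 to match.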


In [30]:
import os, time, shutil, importlib
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load, _get_build_directory

# =========================
# Config
# =========================
torch.manual_seed(0)
DEVICE = "cpu"

# Multiplier selection (matches header: adapt/axx_mults/<AXX_MULT>.h)
AXX_MULT = "mul8s_acc"    # e.g., "mul8s_acc"
USE_EXACT = False         # True -> exact multiply; False -> LUT

# =========================
# Resolve paths
# =========================
cwd = os.getcwd()
parent = os.path.abspath(os.path.join(cwd, ".."))          # one directory above the notebook
adapt_root = os.path.join(parent, "adapt")                 # /workspace/adapt/adapt
src_dir = os.path.join(adapt_root, "cpu-kernels")          # /workspace/adapt/adapt/cpu-kernels
include_root = adapt_root                                  # so #include "axx_mults/<AXX_MULT>.h" works

linear_src_candidates = [
    os.path.join(src_dir, "axx_linear_systolic.cpp"),
    os.path.join(src_dir, "axx_linear.cpp"),  # fallback if you only have the non-systolic file
]
conv2d_src = os.path.join(src_dir, "axx_conv2d_systolic.cpp")

print("Current path:", cwd)
print("Parent path:", parent)
print("ADaPT root:", adapt_root)
print("SRC_DIR:", src_dir)
print("Linear candidates:", linear_src_candidates)
print("Conv2d src:", conv2d_src)

# =========================
# Try load C++ extensions
# =========================
cpp_linear = None
cpp_conv2d = None
USE_CPP_LINEAR = True
USE_CPP_CONV2D = True


def _base_mod_name_for_src(src_path: str) -> str:
    base = os.path.basename(src_path)
    if base == "axx_linear_systolic.cpp":
        return "PyInit_linear_systolic_"
    elif base == "axx_linear.cpp":
        return "PyInit_linear_"
    return "PyInit_linear_"

def _build_mod_name(src_path: str, exact: bool, axx_mult: str, tag: str = "") -> str:
    base = _base_mod_name_for_src(src_path)
    if exact:
        base += "exact_"
    name = base + axx_mult
    if tag:
        name += "_" + tag
    return name

def _clean_build_dir(mod_name: str):
    try:
        bdir = _get_build_directory(mod_name, verbose=True)
        if os.path.isdir(bdir):
            print(f"[clean] Removing build dir: {bdir}")
            shutil.rmtree(bdir)
    except Exception as e:
        print(f"[clean] Could not locate/remove build dir for {mod_name}: {e}")

def try_load_linear_kernel():
    global cpp_linear, USE_CPP_LINEAR

    src = next((p for p in linear_src_candidates if os.path.isfile(p)), None)
    if src is None:
        print("[warn] No linear .cpp found; falling back to Python.")
        USE_CPP_LINEAR = False
        return

    # 1) First attempt: deterministic name
    mod_name = _build_mod_name(src, USE_EXACT, AXX_MULT)
    cflags = [
        f"-DAXX_MULT={AXX_MULT}",
        "-O3", "-fopenmp", "-march=native", "-std=c++17",
        f"-I{include_root}",
    ]
    if USE_EXACT:
        cflags.append("-DUSE_EXACT")

    def _try(name):
        print(f"[build] Trying to load: {name} from {src}")
        return load(
            name=name,
            sources=[src],
            extra_cflags=cflags,
            extra_ldflags=["-lgomp"],
            verbose=True,
        )

    try:
        cpp_linear = _try(mod_name)
        print(f"[info] Loaded C++ linear kernel: {mod_name}")
        return
    except Exception as e:
        print(f"[warn] First attempt failed for {mod_name}: {e}")

    # 2) Clean cache for that name and retry same name once
    _clean_build_dir(mod_name)
    try:
        cpp_linear = _try(mod_name)
        print(f"[info] Loaded C++ linear kernel after clean: {mod_name}")
        return
    except Exception as e:
        print(f"[warn] Retry after clean failed for {mod_name}: {e}")

    # 3) Force a fresh unique module name (timestamp) to bypass any caching issues
    unique_tag = f"t{int(time.time())}"
    mod_name2 = _build_mod_name(src, USE_EXACT, AXX_MULT, tag=unique_tag)
    try:
        cpp_linear = _try(mod_name2)
        print(f"[info] Loaded C++ linear kernel with unique name: {mod_name2}")
        return
    except Exception as e:
        print(f"[error] Could not load linear kernel at all. Falling back to Python.\n  Reason: {e}")
        USE_CPP_LINEAR = False
        cpp_linear = None

def try_load_conv2d_kernel():
    global cpp_conv2d, USE_CPP_CONV2D
    if not os.path.isfile(conv2d_src):
        print(f"[warn] Conv2d source not found: {conv2d_src}; falling back to Python.")
        USE_CPP_CONV2D = False
        return
    mod_name = ("PyInit_conv2d_systolic_exact_" if USE_EXACT else "PyInit_conv2d_systolic_") + AXX_MULT
    cflags = [f"-DAXX_MULT={AXX_MULT}", "-O3", "-fopenmp", "-march=native", f"-I{include_root}", "-std=c++17"]
    if USE_EXACT:
        cflags.append("-DUSE_EXACT")
    try:
        cpp_conv2d = load(
            name=mod_name,
            sources=[conv2d_src],
            extra_cflags=cflags,
            extra_ldflags=["-lgomp"],
            verbose=False,
        )
        print(f"[info] Loaded C++ conv2d kernel: {mod_name} from {conv2d_src}")
    except Exception as e:
        print(f"[warn] Could not load conv2d kernel. Falling back to Python.\n  Reason: {e}")
        USE_CPP_CONV2D = False
        cpp_conv2d = None

try_load_linear_kernel()
try_load_conv2d_kernel()

# =========================
# Quantization helpers
# =========================
def symmetric_int8_quant(x: torch.Tensor, amax: float):
    max_value = 127.0
    scale = max_value / max(1e-12, amax)
    q = torch.clamp((x * scale).round(), -128, 127).to(torch.int8)
    return q, scale

def dequant_from_int32(acc_i32: torch.Tensor, scale_in: float, scale_w: float):
    return acc_i32.to(torch.float32) / (scale_in * scale_w)

# =========================
# Python systolic simulators
# =========================
def systolic_linear_python(A_i8: torch.Tensor, B_i8: torch.Tensor, collect=False):
    M, K = A_i8.shape
    N = B_i8.shape[0]
    C = torch.zeros((M, N), dtype=torch.int32)
    snaps = []
    for t in range(K):
        a_col = A_i8[:, t].to(torch.int16)
        b_col = B_i8[:, t].to(torch.int16)
        C += (a_col.view(M, 1) * b_col.view(1, N)).to(torch.int32)
        if collect:
            snaps.append(C.clone())
    return C, snaps

def systolic_conv2d_python(X_i8: torch.Tensor, W_i8: torch.Tensor,
                           kernel_size=(3,3), stride=(1,1), padding=(0,0),
                           collect=False):
    N, C, H, W = X_i8.shape
    O, Cw, Kh, Kw = W_i8.shape
    assert C == Cw
    Sh, Sw = stride
    Ph, Pw = padding
    Ho = (H + 2*Ph - Kh)//Sh + 1
    Wo = (W + 2*Pw - Kw)//Sw + 1

    Xf = X_i8.to(torch.float32)  # unfold sometimes expects float
    cols = torch.nn.functional.unfold(
        Xf, kernel_size=(Kh, Kw), stride=(Sh, Sw), padding=(Ph, Pw)
    ).to(torch.int8)  # [N, C*Kh*Kw, Ho*Wo]

    K = C*Kh*Kw
    L = Ho*Wo
    W_flat = W_i8.view(O, K)

    Y = torch.zeros((N, O, L), dtype=torch.int32)
    snaps = []
    for n in range(N):
        for t in range(K):
            x_t = cols[n, t, :]
            w_t = W_flat[:, t]
            Y[n] += (w_t.view(O, 1).to(torch.int16) * x_t.view(1, L).to(torch.int16)).to(torch.int32)
        if collect:
            snaps.append(Y[n].clone())

    return Y.view(N, O, Ho, Wo), snaps

# =========================
# References using int8->int32->dequant
# =========================
def reference_linear_int8(A_f: torch.Tensor, B_f: torch.Tensor):
    amax_x = A_f.abs().max().item()
    amax_w = B_f.abs().max().item()
    A_i8, s_x = symmetric_int8_quant(A_f, amax_x)
    B_i8, s_w = symmetric_int8_quant(B_f, amax_w)
    C_i32, snaps = systolic_linear_python(A_i8, B_i8, collect=True)
    C_f = dequant_from_int32(C_i32, s_x, s_w)
    return C_f, (A_i8, B_i8, C_i32, snaps), (s_x, s_w)

def reference_conv2d_int8(X_f: torch.Tensor, W_f: torch.Tensor, stride=(1,1), padding=(0,0)):
    amax_x = X_f.abs().max().item()
    amax_w = W_f.abs().max().item()
    X_i8, s_x = symmetric_int8_quant(X_f, amax_x)
    W_i8, s_w = symmetric_int8_quant(W_f, amax_w)
    Y_i32, snaps = systolic_conv2d_python(X_i8, W_i8, kernel_size=W_f.shape[-2:], stride=stride, padding=padding, collect=True)
    Y_f = dequant_from_int32(Y_i32, s_x, s_w)
    return Y_f, (X_i8, W_i8, Y_i32, snaps), (s_x, s_w)

# =========================
# Tests (print inputs/outputs)
# =========================
def test_linear(M=4, K=6, N=5):
    print("\n=== Linear correctness test ===")
    A = torch.randn(M, K)
    B = torch.randn(N, K)  # [N,K]
    print("Input A (M x K):\n", A)
    print("Input B (N x K):\n", B)

    # FP32 reference
    ref_fp32 = A @ B.t()
    print("FP32 Reference Output (M x N):\n", ref_fp32)

    # Int8 systolic reference (python)
    ref_int8, (A_i8, B_i8, C_i32_py, snaps), (s_x, s_w) = reference_linear_int8(A, B)
    print("Quantized A (int8):\n", A_i8)
    print("Quantized B (int8):\n", B_i8)
    print("Int32 Accumulator (Python) C:\n", C_i32_py)
    print("Dequantized INT8 Output (Python):\n", ref_int8)

    # Optional: C++ kernel
    if USE_CPP_LINEAR and cpp_linear is not None:
        C_i32_cpp = cpp_linear.forward(A_i8, B_i8)
        C_cpp_f = dequant_from_int32(C_i32_cpp, s_x, s_w)
        print("C++ Kernel Int32 Output:\n", C_i32_cpp)
        print("C++ Kernel Dequantized Output:\n", C_cpp_f)
        diff_cpp = (C_cpp_f - ref_fp32).abs().max().item()
        diff_cpp_int8 = (C_cpp_f - ref_int8).abs().max().item()
        print(f"CPP vs FP32 | max abs diff: {diff_cpp:.6f}")
        print(f"CPP vs INT8-dequant | max abs diff: {diff_cpp_int8:.6f}")
    else:
        print("[info] C++ linear kernel not used; compared Python systolic to FP32.")

    diff_int8 = (ref_int8 - ref_fp32).abs().max().item()
    print(f"INT8-dequant vs FP32 | max abs diff: {diff_int8:.6f}")

def test_conv(N=1, C=2, H=7, W=7, O=3, Kh=3, Kw=3, stride=(1,1), padding=(1,1)):
    print("\n=== Conv2d correctness test ===")
    X = torch.randn(N, C, H, W)
    Wt = torch.randn(O, C, Kh, Kw)
    print("Input X (N x C x H x W):\n", X)
    print("Weights W (O x C x Kh x Kw):\n", Wt)

    # FP32 reference
    ref_fp32 = F.conv2d(X, Wt, bias=None, stride=stride, padding=padding)
    print("FP32 Reference Output (N x O x Ho x Wo):\n", ref_fp32)

    # Int8 systolic reference (python)
    ref_int8, (X_i8, W_i8, Y_i32_py, snaps), (s_x, s_w) = reference_conv2d_int8(X, Wt, stride=stride, padding=padding)
    print("Quantized Input X (int8):\n", X_i8)
    print("Quantized Weights W (int8):\n", W_i8)
    print("Int32 Accumulator (Python) Y:\n", Y_i32_py)
    print("Dequantized INT8 Output (Python):\n", ref_int8)

    # Optional: C++ kernel
    if USE_CPP_CONV2D and cpp_conv2d is not None:
        C_i32_cpp = cpp_conv2d.forward(X_i8, W_i8, [Kh, Kw], list(stride), list(padding))
        C_cpp_f = dequant_from_int32(C_i32_cpp, s_x, s_w)
        print("C++ Kernel Int32 Output:\n", C_i32_cpp)
        print("C++ Kernel Dequantized Output:\n", C_cpp_f)
        diff_cpp = (C_cpp_f - ref_fp32).abs().max().item()
        diff_cpp_int8 = (C_cpp_f - ref_int8).abs().max().item()
        print(f"CPP vs FP32 | max abs diff: {diff_cpp:.6f}")
        print(f"CPP vs INT8-dequant | max abs diff: {diff_cpp_int8:.6f}")
    else:
        print("[info] C++ conv2d kernel not used; compared Python systolic to FP32.")

    diff_int8 = (ref_int8 - ref_fp32).abs().max().item()
    print(f"INT8-dequant vs FP32 | max abs diff: {diff_int8:.6f}")

# =========================
# Run
# =========================
if __name__ == "__main__":
    test_linear(M=4, K=6, N=5)
    test_conv(N=1, C=2, H=7, W=7, O=3, Kh=3, Kw=3, stride=(1,1), padding=(1,1))
    print("\nDone.")


Current path: /workspace/adapt/examples
Parent path: /workspace/adapt
ADaPT root: /workspace/adapt/adapt
SRC_DIR: /workspace/adapt/adapt/cpu-kernels
Linear candidates: ['/workspace/adapt/adapt/cpu-kernels/axx_linear_systolic.cpp', '/workspace/adapt/adapt/cpu-kernels/axx_linear.cpp']
Conv2d src: /workspace/adapt/adapt/cpu-kernels/axx_conv2d_systolic.cpp
[build] Trying to load: PyInit_linear_systolic_mul8s_acc from /workspace/adapt/adapt/cpu-kernels/axx_linear_systolic.cpp
Using /root/.cache/torch_extensions as PyTorch extensions root...
No modifications detected for re-loaded extension module PyInit_linear_systolic_mul8s_acc_v2, skipping build step...
Loading extension module PyInit_linear_systolic_mul8s_acc_v2...
[warn] First attempt failed for PyInit_linear_systolic_mul8s_acc: No module named 'PyInit_linear_systolic_mul8s_acc_v2'
Using /root/.cache/torch_extensions as PyTorch extensions root...
[clean] Removing build dir: /root/.cache/torch_extensions/PyInit_linear_systolic_mul8s_acc
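The quantization helpers in the cell above (`symmetric_int8_quant`, `dequant_from_int32`) and the Python systolic simulator implement the following arithmetic, with per-tensor ranges $a_x$ (activations) and $a_w$ (weights):

$$
s_x = \frac{127}{a_x}, \quad s_w = \frac{127}{a_w}, \quad
q_x = \operatorname{clamp}\big(\operatorname{round}(s_x x), -128, 127\big), \quad
q_w = \operatorname{clamp}\big(\operatorname{round}(s_w w), -128, 127\big)
$$

$$
C^{\text{int32}}_{mn} = \sum_{k=0}^{K-1} q_{x,mk}\, q_{w,nk}, \qquad
\hat{C}_{mn} = \frac{C^{\text{int32}}_{mn}}{s_x\, s_w} \approx \sum_{k=0}^{K-1} x_{mk}\, w_{nk}
$$

`systolic_linear_python` builds the same sum one $k$ per step, adding the outer product of column $k$ of $q_x$ (length $M$) and column $k$ of $q_w$ (length $N$) to the accumulator. Each int8 product has magnitude at most $128 \cdot 128 = 2^{14}$, so the int32 accumulator cannot overflow for $K < 2^{31} / 2^{14} = 2^{17}$.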


In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- Quant helpers for PyTorch quantized path ----------
def quantize_activation_pt(x: torch.Tensor):
    """
    PyTorch-style per-tensor affine activation quantization (quint8).
    scale = (xmax - xmin)/255, zp = round(-xmin/scale) clamped to [0,255]
    """
    xmin = x.min().item()
    xmax = x.max().item()
    if xmax == xmin:
        xmax = xmin + 1e-6
    scale = (xmax - xmin) / 255.0
    zp = int(round(-xmin / scale))
    zp = max(0, min(255, zp))
    xq = torch.quantize_per_tensor(x, scale=scale, zero_point=zp, dtype=torch.quint8)
    return xq, scale, zp

def quantize_weight_pt(w: torch.Tensor):
    """
    PyTorch-style symmetric weight quantization (qint8): zero_point = 0
    scale = max(|w|)/127
    """
    amax = w.abs().max().item()
    if amax < 1e-12:
        amax = 1e-12
    scale = amax / 127.0
    wq = torch.quantize_per_tensor(w, scale=scale, zero_point=0, dtype=torch.qint8)
    return wq, scale


# ---------- Compare Linear ----------
def compare_linear_systolic_vs_pytorch_int8(M=4, K=6, N=5):
    print("\n=== Compare Linear: Systolic INT8 vs PyTorch INT8 ===")
    torch.manual_seed(0)
    A = torch.randn(M, K)          # activations
    B = torch.randn(N, K)          # weights [out, in]

    # FP32 reference
    ref_fp32 = A @ B.t()

    # -- PyTorch INT8 path --
    Aq, s_x_pt, zp_x_pt = quantize_activation_pt(A)   # quint8
    Wq, s_w_pt = quantize_weight_pt(B)                # qint8 (zp=0)

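    # Quantized Linear. Note: kwarg is bias_, not bias.
    qlin = nn.quantized.Linear(in_features=K, out_features=N, bias_=False, dtype=torch.qint8)
    qlin.set_weight_bias(Wq, None)
    # The module requantizes its int32 accumulator to a quint8 output, so the output
    # scale/zero_point must cover the output range. Using the accumulator scale
    # s_x_pt * s_w_pt with zero_point 0 clips negatives to 0 and saturates at
    # 255 * s_x_pt * s_w_pt. Here the range is taken from the FP32 reference,
    # standing in for a calibration observer.
    out_min = min(ref_fp32.min().item(), 0.0)
    out_max = max(ref_fp32.max().item(), 0.0)
    qlin.scale = (out_max - out_min) / 255.0
    qlin.zero_point = int(round(-out_min / qlin.scale))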

    out_q = qlin(Aq)          # quantized output tensor
    out_pt = out_q.dequantize()

    # -- Your systolic INT8 path (symmetric int8) --
    A_i8, s_x_sys = symmetric_int8_quant(A, A.abs().max().item())
    B_i8, s_w_sys = symmetric_int8_quant(B, B.abs().max().item())
    if 'cpp_linear' in globals() and cpp_linear is not None:
        C_i32_sys = cpp_linear.forward(A_i8, B_i8)
    else:
        C_i32_sys, _ = systolic_linear_python(A_i8, B_i8, collect=False)
    out_sys = dequant_from_int32(C_i32_sys, s_x_sys, s_w_sys)

    # -- Metrics --
    e_sys_vs_fp32 = (out_sys - ref_fp32).abs()
    e_pt_vs_fp32  = (out_pt  - ref_fp32).abs()
    e_sys_vs_pt   = (out_sys - out_pt ).abs()

    print(f"FP32 out shape: {ref_fp32.shape}")
    print(f"Sys vs FP32:  max={e_sys_vs_fp32.max().item():.6f}, mean={e_sys_vs_fp32.mean().item():.6f}")
    print(f"PT  vs FP32:  max={e_pt_vs_fp32.max().item():.6f}, mean={e_pt_vs_fp32.mean().item():.6f}")
    print(f"Sys vs PT :   max={e_sys_vs_pt.max().item():.6f},  mean={e_sys_vs_pt.mean().item():.6f}")


# ---------- Compare Conv2d ----------
def compare_conv_systolic_vs_pytorch_int8(N=1, C=2, H=7, W=7, O=3, Kh=3, Kw=3, stride=(1,1), padding=(1,1)):
    print("\n=== Compare Conv2d: Systolic INT8 vs PyTorch INT8 ===")
    torch.manual_seed(1)
    X = torch.randn(N, C, H, W)
    Wt = torch.randn(O, C, Kh, Kw)

    # FP32 reference
    ref_fp32 = F.conv2d(X, Wt, bias=None, stride=stride, padding=padding)

    # -- PyTorch INT8 path via quantized op --
    Xq, s_x_pt, zp_x_pt = quantize_activation_pt(X)  # quint8 input
    Wq, s_w_pt = quantize_weight_pt(Wt)              # qint8 weight (zp=0)
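    # The op requantizes its int32 accumulator to a quint8 output, so choose the output
    # scale/zero_point from the output range (taken from the FP32 reference here, as a
    # stand-in for a calibration observer) rather than the accumulator scale s_x * s_w.
    out_min = min(ref_fp32.min().item(), 0.0)
    out_max = max(ref_fp32.max().item(), 0.0)
    out_scale = (out_max - out_min) / 255.0
    out_zp = int(round(-out_min / out_scale))

    # Recent PyTorch versions take prepacked weights:
    # quantized::conv2d(input, packed_params, output_scale, output_zero_point)
    packed = torch.ops.quantized.conv2d_prepack(Wq, None, list(stride), list(padding), [1, 1], 1)
    out_q = torch.ops.quantized.conv2d(Xq, packed, out_scale, out_zp)
    out_pt = out_q.dequantize()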

    # -- Your systolic INT8 path (symmetric int8) --
    X_i8, s_x_sys = symmetric_int8_quant(X, X.abs().max().item())
    W_i8, s_w_sys = symmetric_int8_quant(Wt, Wt.abs().max().item())
    if 'cpp_conv2d' in globals() and cpp_conv2d is not None:
        Y_i32_sys = cpp_conv2d.forward(X_i8, W_i8, [Kh, Kw], list(stride), list(padding))
    else:
        Y_i32_sys, _ = systolic_conv2d_python(X_i8, W_i8, kernel_size=(Kh, Kw), stride=stride, padding=padding, collect=False)
    out_sys = dequant_from_int32(Y_i32_sys, s_x_sys, s_w_sys)

    # -- Metrics --
    e_sys_vs_fp32 = (out_sys - ref_fp32).abs()
    e_pt_vs_fp32  = (out_pt  - ref_fp32).abs()
    e_sys_vs_pt   = (out_sys - out_pt ).abs()

    print(f"FP32 out shape: {ref_fp32.shape}")
    print(f"Sys vs FP32:  max={e_sys_vs_fp32.max().item():.6f}, mean={e_sys_vs_fp32.mean().item():.6f}")
    print(f"PT  vs FP32:  max={e_pt_vs_fp32.max().item():.6f}, mean={e_pt_vs_fp32.mean().item():.6f}")
    print(f"Sys vs PT :   max={e_sys_vs_pt.max().item():.6f},  mean={e_sys_vs_pt.mean().item():.6f}")
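Unlike the symmetric int8 scheme of the systolic path, `quantize_activation_pt` uses asymmetric (affine) uint8 quantization: a zero point shifts the grid so that the observed range [xmin, xmax] spans all 256 codes, and 0.0 is represented exactly when the range contains it. A plain-Python sketch mirroring the formulas in its docstring; the helper names are hypothetical.

```python
def affine_params(xmin, xmax):
    """Mirror of quantize_activation_pt: scale = (xmax - xmin) / 255,
    zero_point = round(-xmin / scale) clamped to [0, 255]."""
    if xmax == xmin:
        xmax = xmin + 1e-6
    scale = (xmax - xmin) / 255.0
    zero_point = max(0, min(255, int(round(-xmin / scale))))
    return scale, zero_point

def quantize(x, scale, zero_point):
    return max(0, min(255, int(round(x / scale)) + zero_point))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp = affine_params(-1.0, 3.0)
for x in (-1.0, 0.0, 1.234, 3.0):
    q = quantize(x, scale, zp)
    print(f"x={x:+.3f} -> q={q:3d} -> x_hat={dequantize(q, scale, zp):+.4f}")
```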


In [48]:
compare_linear_systolic_vs_pytorch_int8(M=4, K=6, N=5)
#compare_conv_systolic_vs_pytorch_int8(N=1, C=2, H=7, W=7, O=3, Kh=3, Kw=3, stride=(1,1), padding=(1,1))



=== Compare Linear: Systolic INT8 vs PyTorch INT8 ===
FP32 out shape: torch.Size([4, 5])
Sys vs FP32:  max=0.035596, mean=0.013366
PT  vs FP32:  max=7.577308, mean=2.424562
Sys vs PT :   max=7.593278,  mean=2.427414
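The large PyTorch-vs-FP32 error in this output (mean about 2.4, against about 0.013 for the systolic path) is consistent with output requantization: `nn.quantized.Linear` returns a quint8 tensor, so when its output scale is set to the accumulator scale `s_x_pt * s_w_pt` with zero point 0, every negative output is clipped to 0 and positive outputs saturate at 255 times that scale. The plain-Python sketch below reproduces the clipping; the scale and range values are illustrative assumptions, not the ones from the run.

```python
def requantize_quint8(value, scale, zero_point=0):
    """Float -> quint8 -> float, as applied to a quantized module's output."""
    q = max(0, min(255, round(value / scale) + zero_point))
    return (q - zero_point) * scale

acc_scale = 0.016 * 0.016  # illustrative s_x * s_w (assumption)
for y in (-2.0, 0.03, 2.5):
    print(f"{y:+.3f} -> {requantize_quint8(y, acc_scale):+.4f}")

# Choosing the output qparams from the output range instead avoids the clipping
out_min, out_max = -3.0, 3.0  # illustrative output range (assumption)
out_scale = (out_max - out_min) / 255.0
out_zp = round(-out_min / out_scale)
for y in (-2.0, 0.03, 2.5):
    print(f"{y:+.3f} -> {requantize_quint8(y, out_scale, out_zp):+.4f}")
```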
