In [32]:
import torch
from torchvision import datasets, transforms
from naive_mnist import NaiveModel
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

In [33]:
torch.set_default_dtype(torch.float32)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 9-bit weights, inputs and activations (signed, per-tensor symmetric).
num_bits = 9
# Make sure this matches the quantization config from MNIST_CNN_Training.
configure_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': num_bits, 'input': num_bits},
    'quant_start_step': 2,
    'op_names': ['conv1', 'conv2']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': num_bits},
    'quant_start_step': 2,
    'op_names': ['relu1', 'relu2', 'relu3']
}, {
    'quant_types': ['output', 'weight', 'input'],
    'quant_bits': {'output': num_bits, 'weight': num_bits, 'input': num_bits},
    'quant_start_step': 2,
    'op_names': ['fc1', 'fc2'],
}]

set_quant_scheme_dtype('weight', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')

# Load MNIST dataset with train/test split sets.
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=trans),
    batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=trans),
    batch_size=100, shuffle=True)

# Input geometry from one transformed training sample, shaped (channels, H, W).
# Indexing the dataset directly avoids building a shuffled DataLoader iterator
# (and loading a full batch) just to read a shape.
sample_image, _ = train_loader.dataset[0]
ifmap = sample_image.size()[0]  # number of input channels
idim = sample_image.size()[1]   # image height; used for both spatial dims below
fc2_nodes = len(torch.unique(train_loader.dataset.targets))

# Create a NaiveModel object and apply QAT_Quantizer setup
model_path = "models/mnist_model_9bit.pth"
qmodel = NaiveModel().to(device)
dummy_input = torch.randn(1, ifmap, idim, idim).to(device)
optimizer = torch.optim.SGD(qmodel.parameters(), lr=0.01, momentum=0.5)
# To enable batch normalization folding in the training process, you should
# pass dummy_input to the QAT_Quantizer.
quantizer = QAT_Quantizer(qmodel, configure_list, optimizer, dummy_input=dummy_input)
quantizer.compress()

# Load trained model (from MNIST_CNN_Training step).
# torch.load unpickles the file: only load checkpoints from a trusted source.
state = torch.load(model_path, map_location=device)
qmodel.load_state_dict(state, strict=True)
qmodel.eval();

In [34]:
# Linear (affine) quantization helpers, after Jacob et al. (2018),
# "Quantization and Training of Neural Networks for Efficient
# Integer-Arithmetic-Only Inference".
def quantize(real_value, scale, zero_point, qmin, qmax):
    """Map real values onto the integer grid ``round(zero_point + r / scale)``.

    Unlike fake quantization, the result is NOT mapped back to the real domain:
    the returned tensor holds integer-valued floats in ``[qmin, qmax]``.

    Args:
        real_value: tensor of real values to quantize.
        scale: real-valued step size per integer level; must be non-zero.
        zero_point: integer that real 0.0 maps to.
        qmin: smallest representable integer level (inclusive).
        qmax: largest representable integer level (inclusive).

    Returns:
        Tensor of rounded, clamped quantized values.
    """
    transformed_val = zero_point + real_value / scale
    # Clamping before rounding gives the same result as rounding first,
    # because qmin and qmax are integers.
    clamped_val = torch.clamp(transformed_val, qmin, qmax)
    quantized_val = torch.round(clamped_val)
    return quantized_val

def scale_quant(real_value, num_bits):
    """Symmetric per-tensor quantization of ``real_value`` to ``num_bits`` signed bits.

    Uses the restricted symmetric range ``[-(2**(b-1) - 1), 2**(b-1) - 1]``
    (e.g. [-255, 255] for 9 bits) with zero_point 0, so the largest magnitude
    in ``real_value`` maps to +/-qmax.

    Args:
        real_value: tensor to quantize (must be non-empty).
        num_bits: signed bit width; must be at least 2.

    Returns:
        ``(scale, quant)`` where ``real_value ~= quant * scale``.
        For an all-zero input, scale is 1.0 and quant is all zeros.

    Raises:
        ValueError: if ``num_bits < 2`` (no non-zero level would exist).
    """
    if num_bits < 2:
        raise ValueError(f"num_bits must be >= 2, got {num_bits}")
    qmax = 2 ** (num_bits - 1) - 1
    qmin = -qmax
    abs_max = torch.abs(real_value).max()
    # (qmax - qmin) / 2 == qmax, so abs_max lands exactly on the top level.
    scale = abs_max / (float(qmax - qmin) / 2)
    if abs_max == 0:
        # All-zero input: any non-zero scale quantizes to zeros; a zero scale
        # would produce 0/0 = NaN inside quantize().
        scale = torch.ones_like(scale)
    zero_point = 0
    quant = quantize(real_value, scale, zero_point, qmin, qmax)
    return scale, quant

In [35]:
# Quantize conv1's weights and the raw MNIST test images with the same
# symmetric scheme, then pick one 2-D kernel / image pair for the testbench.
# Reuses num_bits from the configuration cell so the testbench bit width
# always matches the QAT configuration.
conv = qmodel.conv1.module

# detach(): the quantized weights are only used as constants for the
# testbench, so autograd tracking is unnecessary (and slows later loops).
sw_conv, filters_conv = scale_quant(conv.weight.detach().cpu(), num_bits)
# Conv2d weight is (out_channels, in_channels, kH, kW): take the first
# output/input channel pair as a single 2-D kernel.
kernel = filters_conv[0][0]

# NOTE(review): `.data` is the raw dataset tensor and bypasses the Normalize
# transform used for training, so these are not the model's actual conv1
# inputs — fine for exercising the MAC hardware, but confirm that is intended.
s_in, x = scale_quant(test_loader.dataset.data, num_bits)

image = x[0]

In [40]:
TB_TEMPLATE_PATH = "templates/TB_template.v"
TB_OUTPUT_PATH = "../../../Hardware/Verilog/mac_TB.v"
TB_MARKER = "// Start of custom"
# Width of every sized Verilog literal emitted (MULT operands and CHECK_ACCUM).
# NOTE(review): a sum of many 9-bit x 9-bit products can exceed 2**15, and
# Verilog truncates oversized literals — confirm the accumulator width.
VERILOG_LITERAL_WIDTH = 15


def verilog_literal(value, width=VERILOG_LITERAL_WIDTH):
    """Format an integer as a sized Verilog decimal literal, e.g. 15'd7 or -15'd7.

    Verilog has no ``15'd-7`` form, so a negative value is written as a
    negated unsigned literal.
    """
    sign = "-" if value < 0 else ""
    return f"{sign}{width}'d{abs(value)}"


def mac_testbench_lines(image, kernel, width=VERILOG_LITERAL_WIDTH):
    """Build testbench statements for a valid 2-D cross-correlation of image with kernel.

    For every output position, one ``MULT(pixel, weight)`` is emitted per kernel
    tap, followed by ``CHECK_ACCUM(expected_sum)``. Output positions are
    visited column by column (outer loop over columns), and taps likewise.

    Args:
        image: 2-D tensor of integer-valued pixels (rows, cols).
        kernel: 2-D tensor of integer-valued weights (rows, cols).
        width: bit width of the emitted Verilog literals.

    Returns:
        List of strings, each holding one ``#4;`` delay plus one task call.
    """
    # Plain Python ints: exact sums, no autograd graph, no per-element .item().
    pixels = image.detach().round().long().tolist()
    weights = kernel.detach().round().long().tolist()
    img_h, img_w = len(pixels), len(pixels[0])
    k_h, k_w = len(weights), len(weights[0])

    lines = []
    # A valid convolution has (size - kernel + 1) positions per dimension.
    for col in range(img_w - k_w + 1):
        for row in range(img_h - k_h + 1):
            acc = 0
            for kc in range(k_w):
                for kr in range(k_h):
                    pixel = pixels[row + kr][col + kc]
                    weight = weights[kr][kc]
                    acc += pixel * weight
                    lines.append(
                        f"        #4;\n        MULT({verilog_literal(pixel, width)}, "
                        f"{verilog_literal(weight, width)});\n")
            lines.append(f"        #4;\n        CHECK_ACCUM({verilog_literal(acc, width)});\n")
    return lines


with open(TB_TEMPLATE_PATH, "r") as f_in:
    file_data = f_in.readlines()

start_point = next(
    (index + 1 for index, line in enumerate(file_data) if TB_MARKER in line), None)
if start_point is None:
    raise ValueError(f"Marker {TB_MARKER!r} not found in {TB_TEMPLATE_PATH}")

# Splice the generated statements in directly after the marker line.
file_data[start_point:start_point] = mac_testbench_lines(image, kernel)

with open(TB_OUTPUT_PATH, 'w') as f_out:
    f_out.write("".join(file_data))