In [1]:
import torchvision
import torch
import torch.nn as nn
import sys
sys.path.append("../misc");
from utils import DataShaper
import os
import math
import matplotlib.pyplot as plt
import numpy as np
from brevitas.nn import QuantConv2d, QuantIdentity, QuantReLU
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint,Uint8ActPerTensorFixedPoint

torch.manual_seed(0)
design = "bottleneck_cifar_split_vector"

# Host-side XRT helpers live in a sibling utils directory.
sys.path.append("../../../utils")
import xrtutils

xclbin_path = os.path.abspath("../bottleneck_block/" + design + "/build/final.xclbin")
insts_path = os.path.abspath("../bottleneck_block/" + design + "/build/insts.txt")

log_folder = "log/log_" + design

enable_aie = True
aie_is_setup = False
enable_trace = True
trace_file = 'traces/' + design + '.txt'

# NOTE(review): enable_aie, aie_is_setup, in_buf, arg1_buf and out_buf are never read in this
# notebook (app is reassigned by setup_aie below); they are kept only as placeholders.
app = None
in_buf = None
arg1_buf = None
out_buf = None
dtype_in = np.dtype("int8")
dtype_out = np.dtype("uint8")


shape_in_act = (32, 32, 32, 8)
shape_in_wts1 = (8, 32, 1, 1, 8, 8)  # out, in, ky, kx, in8, out8
shape_in_wts2 = (8, 8, 3, 3, 8, 8)   # out, in, ky, kx, in8, out8
shape_in_wts3 = (32, 8, 1, 1, 8, 8)  # out, in, ky, kx, in8, out8
# The three layers' weights are sent to the device back-to-back in one flat buffer, so its
# length is derived from the per-layer shapes instead of being hard-coded (= 69632 here).
shape_total_wts = (int(sum(np.prod(s) for s in (shape_in_wts1, shape_in_wts2, shape_in_wts3))), 1)
shape_out = (32, 32, 32, 8)

# Size in bytes of the trace region appended after the kernel output when tracing is enabled.
trace_size = 16384

def setup_aie(xclbin_path, insts_path,
              in_0_shape, in_0_dtype,
              in_1_shape, in_1_dtype,
              out_buf_shape, out_buf_dtype,
              enable_trace=False,
              kernel_name="MLIR_AIE"):
    """Create an ``xrtutils.AIE_Application`` and register its three host buffers.

    Buffer 2 gets input 0 and buffer 3 gets input 1; in this notebook those are the activations
    and the packed weights. Buffer 4 is the output.

    When ``enable_trace`` is set, buffer 4 is registered as a flat uint8 buffer large enough
    for the output bytes plus ``trace_size`` bytes of trace data that follow them.

    Returns:
        The configured application object.
    """
    app = xrtutils.AIE_Application(xclbin_path, insts_path, kernel_name)
    for group_id, (shape, dtype) in ((2, (in_0_shape, in_0_dtype)),
                                     (3, (in_1_shape, in_1_dtype))):
        app.register_buffer(group_id, shape=shape, dtype=dtype)

    if enable_trace:
        # Output and trace share buffer 4: view it as raw bytes sized output + trace.
        n_out_bytes = np.prod(out_buf_shape) * np.dtype(out_buf_dtype).itemsize
        out_buf_shape, out_buf_dtype = (n_out_bytes + trace_size,), np.uint8
    app.register_buffer(4, shape=out_buf_shape, dtype=out_buf_dtype)
    return app

def extract_trace(out_buf, out_buf_shape, out_buf_dtype, trace_size_bytes=None):
    """Split a combined output+trace buffer into the kernel output and the trace words.

    The last ``trace_size_bytes`` bytes of ``out_buf`` hold trace data. Everything before
    them is the kernel output.

    Args:
        out_buf: buffer read back from the device. It is flattened and viewed as uint32,
            so its total byte size must be a multiple of 4.
        out_buf_shape: shape to give the output part.
        out_buf_dtype: dtype to reinterpret the output bytes as.
        trace_size_bytes: size of the trace region in bytes. Defaults to the notebook-level
            ``trace_size``.

    Returns:
        ``(output, trace)``: the output reshaped to ``out_buf_shape`` with dtype
        ``out_buf_dtype``, and the trace as a 1-D uint32 array.

    Raises:
        ValueError: if the buffer is smaller than the trace region.
    """
    if trace_size_bytes is None:
        trace_size_bytes = trace_size
    trace_size_words = trace_size_bytes // 4
    out_buf_flat = out_buf.reshape((-1,)).view(np.uint32)
    if trace_size_words > out_buf_flat.size:
        raise ValueError(
            f"buffer holds {out_buf_flat.size} words, fewer than the "
            f"{trace_size_words}-word trace region")
    # Explicit split index instead of [:-n]: with n == 0, [:-0] would drop the whole output.
    split = out_buf_flat.size - trace_size_words
    output_prefix = out_buf_flat[:split].view(out_buf_dtype).reshape(out_buf_shape)
    trace_suffix = out_buf_flat[split:]
    return output_prefix, trace_suffix

def write_out_trace(trace, file_name):
    """Write the non-zero trace words to ``file_name``, one 8-digit lowercase hex value per line.

    Zero words are skipped and no trailing newline is written. The parent directory is
    created if it does not exist yet (e.g. ``traces/`` on a fresh checkout).
    """
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    out_str = "\n".join(f"{word:08x}" for word in trace if word != 0)
    with open(file_name, 'w') as f:
        f.write(out_str)

app = setup_aie(xclbin_path, insts_path,
                            shape_in_act, dtype_in,      
                            shape_total_wts,dtype_in,
                            shape_out, dtype_out,enable_trace)

torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

if not os.path.exists(log_folder):
    os.makedirs(log_folder)
   
    
input=torch.randn(1, 256,32,32)
num_classes=10
ds = DataShaper()

# try:
for i in range (0,1): 
    class QuantBottleneck_projected(nn.Module):
        """Brevitas-quantized bottleneck block: 1x1 -> 3x3 -> 1x1 convs plus an identity skip.

        ``quant_id_1`` quantizes the block input and is re-applied to the conv3 output, so
        both addends of the residual sum pass through the same quantizer.
        """

        expansion = 4

        def __init__(self, in_planes=256, planes=64):
            super(QuantBottleneck_projected, self).__init__()
            # Keyword arguments shared by the three quantized convolutions / ReLUs.
            conv_common = dict(
                bit_width=8,
                weight_bit_width=8,
                bias=False,
                weight_quant=Int8WeightPerTensorFixedPoint,
                return_quant_tensor=True,
            )
            relu_common = dict(act_quant=Uint8ActPerTensorFixedPoint, bit_width=8,
                               return_quant_tensor=True)

            # Creation order is kept as-is: conv layers draw their initial weights from the
            # torch RNG, and attribute names are state_dict keys used by callers.
            self.quant_id_1 = QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, bit_width=8,
                                            return_quant_tensor=True)
            self.conv1 = QuantConv2d(in_planes, planes, kernel_size=1, **conv_common)
            self.conv2 = QuantConv2d(planes, planes, kernel_size=3, padding=1,
                                     padding_mode='zeros', **conv_common)
            self.conv3 = QuantConv2d(planes, self.expansion * planes, kernel_size=1,
                                     **conv_common)
            self.relu1 = QuantReLU(**relu_common)
            self.relu2 = QuantReLU(**relu_common)
            self.relu3 = QuantReLU(**relu_common)

        def forward(self, x):
            """Return relu3(quant_id_1(conv3(...)) + quant_id_1(x)) for the bottleneck path."""
            skip = self.quant_id_1(x)
            y = self.relu1(self.conv1(skip))
            y = self.relu2(self.conv2(y))
            y = self.quant_id_1(self.conv3(y))
            return self.relu3(y + skip)
   
    quant_bottleneck_model = QuantBottleneck_projected()

    # Stand-alone copies of the block's layers: their quantized weights feed the AIE weight
    # buffer and are also loaded into quant_bottleneck_model further down.
    # Construction order is kept fixed because conv layers draw their initial weights from the
    # torch RNG seeded above.
    conv_kwargs = dict(bit_width=8, weight_bit_width=8, bias=False,
                       weight_quant=Int8WeightPerTensorFixedPoint, return_quant_tensor=True)
    quant_conv1 = QuantConv2d(256, 64, kernel_size=1, **conv_kwargs)
    quant_conv2 = QuantConv2d(64, 64, kernel_size=3, padding=1, padding_mode='zeros', **conv_kwargs)
    quant_conv3 = QuantConv2d(64, 256, kernel_size=1, **conv_kwargs)

    # Float convolutions (not used further in this notebook).
    simple_conv1 = nn.Conv2d(256, 64, kernel_size=1, bias=False)
    simple_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, padding_mode='zeros', bias=False)
    simple_conv3 = nn.Conv2d(64, 256, kernel_size=1, bias=False)

    # QuantIdentity takes bit_width (not weight_bit_width) since it has no weight tensor.
    quant_id_1 = QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True)
    quant_id_2 = QuantIdentity(act_quant=Uint8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True)
    quant_id_3 = QuantIdentity(act_quant=Uint8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True)
    quant_id_4 = QuantIdentity(act_quant=Uint8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True)

    quant_relu1, quant_relu2, quant_relu3 = (
        QuantReLU(act_quant=Uint8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True)
        for _ in range(3))

    # Switch everything the test uses to inference mode (simple_conv* are left untouched).
    for _module in (quant_bottleneck_model,
                    quant_id_1, quant_id_2, quant_id_3, quant_id_4,
                    quant_conv1, quant_conv2, quant_conv3,
                    quant_relu1, quant_relu2, quant_relu3):
        _module.eval()
    

     # q_inp == int_inp * inp_scale

    
    # Per-tensor activation scales (quantized value == integer value * scale).
    inp_scale1 = quant_id_1.quant_act_scale()   # block input; also the conv3-output scale below
    inp_scale2 = quant_relu1.quant_act_scale()  # after relu1
    inp_scale3 = quant_relu2.quant_act_scale()  # after relu2
    inp_scale4 = quant_relu3.quant_act_scale()  # after relu3 (block output)

    weight_scale1, weight_scale2, weight_scale3 = (
        conv.quant_weight_scale() for conv in (quant_conv1, quant_conv2, quant_conv3))

    # -log2 of each stage's requantization factor (in_scale * w_scale / out_scale). It prints as
    # an integer here; presumably it is the shift the AIE kernels apply (TODO confirm in kernel).
    combined_scale1 = -torch.log2(inp_scale1 * weight_scale1 / inp_scale2)
    combined_scale2 = -torch.log2(inp_scale2 * weight_scale2 / inp_scale3)
    combined_scale3 = -torch.log2(inp_scale3 * weight_scale3 / inp_scale1)
    combined_scale4 = -torch.log2(inp_scale1 / inp_scale4)
    for _label, _shift in (("combined_scale after first conv1x1:", combined_scale1),
                           ("combined_scale after second conv3x3:", combined_scale2),
                           ("combined_scale after third conv1x1:", combined_scale3),
                           ("combined_scale after adding skip connection:", combined_scale4)):
        print(_label, _shift.item())

    int_weight1, int_weight2, int_weight3 = (
        conv.quant_weight().int(float_datatype=True)
        for conv in (quant_conv1, quant_conv2, quant_conv3))

    # Copy the stand-alone conv weights into the reference model so both paths match.
    for _dst, _src in ((quant_bottleneck_model.conv1, quant_conv1),
                       (quant_bottleneck_model.conv2, quant_conv2),
                       (quant_bottleneck_model.conv3, quant_conv3)):
        _dst.load_state_dict(_src.state_dict())

    # Golden reference: integer form of the model output, cast to the AIE output dtype.
    q_bottleneck_out = quant_bottleneck_model(input)
    gold_out = q_bottleneck_out.int(float_datatype=True).data.numpy().astype(dtype_out)
    # Quantize the input with the stand-alone input quantizer and take its integer form.
    q_inp = quant_id_1(input)
    int_inp = q_inp.int(float_datatype=True)

    before_input = int_inp.squeeze().data.numpy().astype(dtype_in)
    before_input.tofile(log_folder + "/before_ifm_mem_fmt_1x1.txt", sep=",", format="%d")
    # ds.reorder_mat(x, target_fmt, source_fmt): argument order inferred from the CYX input
    # here; confirm in DataShaper.
    ifm_mem_fmt = ds.reorder_mat(before_input, 'YCXC8', 'CYX')
    ifm_mem_fmt.tofile(log_folder + "/after_ifm_mem_fmt_1x1.txt", sep=",", format="%d")
    wts1 = ds.reorder_mat(int_weight1.data.numpy().astype(dtype_in), 'OIYXI8O8', 'OIYX')
    wts2 = ds.reorder_mat(int_weight2.data.numpy().astype(dtype_in), 'OIYXI8O8', 'OIYX')
    wts3 = ds.reorder_mat(int_weight3.data.numpy().astype(dtype_in), 'OIYXI8O8', 'OIYX')

    # All three layers' weights go to the device as one flat buffer, in layer order.
    total_wts = np.concatenate((wts1, wts2, wts3), axis=None)
    total_wts.tofile(log_folder + "/weights_mem_fmt_final.txt", sep=",", format="%d")

    # The kernel is run twice with identical inputs (reason not documented here); only the last
    # run's output and trace are kept. `_run` avoids clobbering the outer loop variable `i`.
    for _run in range(2):
        app.buffers[2].write(ifm_mem_fmt)  # input's standard format CYX | scalar YCX
        app.buffers[3].write(total_wts)    # wts's standard format OIYX | scalar OIYX
        app.run()
        output3 = app.buffers[4].read()
        if enable_trace:
            output3, trace = extract_trace(output3, shape_out, dtype_out)
            write_out_trace(trace, trace_file)

    temp_out = output3.reshape(shape_out)
    temp2_out = ds.reorder_mat(temp_out, 'CDYX', 'YCXD')
    # Drop the batch axis of the golden output to get the expected (C, Y, X) shape.
    ofm_mem_fmt = temp2_out.reshape(gold_out.shape[1:])
    ofm_mem_fmt.tofile(log_folder + "/after_ofm_mem_fmt_final.txt", sep=",", format="%d")

    ofm_mem_fmt = torch.from_numpy(ofm_mem_fmt).unsqueeze(0)
    # Same tolerance test as np.allclose, but reports mismatch details and survives `python -O`.
    np.testing.assert_allclose(ofm_mem_fmt, gold_out, rtol=0, atol=2.)
    print("TEST PASS: AIE output matches golden quantized output")


combined_scale after first conv1x1: 10.0
combined_scale after second conv3x3: 11.0
combined_scale after third conv1x1: 11.0
combined_scale after adding skip connection: -1.0


  return super().rename(names)


TEST PASS: AIE output matches golden quantized output


In [2]:
# Show the raw trace words from the last kernel run; `trace` only exists when tracing is enabled.
print(trace if enable_trace else "tracing not enabled")

[2149842944 3690978303 3690978303 ...    8397825 3774931976   15075328]
