In [None]:
import torch
import torch.nn as nn
from torchinfo import summary
import pandas as pd

class NDT4HLS(torch.nn.Module):
    """Transformer encoder with a context mask that is precomputed at construction.

    The module wraps an ``nn.TransformerEncoder`` built from ``config`` and stores the
    attention mask for a ``trial_length`` x ``trial_length`` window in ``self.src_mask``.
    ``forward`` always applies that stored mask instead of regenerating one per call.

    Args:
        config: Object exposing the attributes read by this class: ``FULL_CONTEXT``,
            ``EMBED_DIM``, ``NUM_HEADS``, ``HIDDEN_SIZE``, ``DROPOUT``, ``ACTIVATION``,
            ``PRE_NORM``, ``NUM_LAYERS``, ``CONTEXT_FORWARD``, ``CONTEXT_BACKWARD`` and
            ``CONTEXT_WRAP_INITIAL``.
        trial_length: Side length of the square mask built at construction.
        num_neurons: Used to derive the model width: ``num_neurons`` when
            ``config.EMBED_DIM == 0``, otherwise ``config.EMBED_DIM * num_neurons``.
        device: Device on which the encoder, the final norm and the masks are created.
        max_spikes: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, config, trial_length, num_neurons, device, max_spikes):
        super().__init__()
        self.config = config
        self.trial_length = trial_length
        self.num_neurons = num_neurons
        self.device = device

        # Cache of generated masks. It is kept separate from ``self.src_mask`` because
        # ``self.src_mask`` is overwritten below with the mask tensor that ``forward``
        # uses; sharing one attribute for both made every later mask lookup fail.
        self._mask_cache = {}

        if config.EMBED_DIM == 0:
            self.num_input = num_neurons
        else:
            self.num_input = config.EMBED_DIM * num_neurons

        self._init_transformer()
        # ``None`` when config.FULL_CONTEXT is set, otherwise an additive float mask.
        self.src_mask = self._get_or_generate_context_mask(
            torch.zeros(self.trial_length, self.trial_length, device=device)
        )
        # NOTE(review): this mask is filled with random bits and is not read by
        # ``forward``; it looks like a placeholder - confirm before relying on it.
        self.padding_mask = torch.randint(
            0, 2, (self.trial_length, self.trial_length), device=device
        ).bool()

    def _init_transformer(self):
        """Build ``self.transformer_encoder`` (encoder stack plus final LayerNorm)."""
        # Create the final norm on the same device as the encoder layers so the
        # module does not end up split across devices.
        norm = nn.LayerNorm(self.num_input, device=self.device)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.num_input,
                                       nhead=self.config.NUM_HEADS,
                                       dim_feedforward=self.config.HIDDEN_SIZE,
                                       dropout=self.config.DROPOUT,
                                       activation=self.config.ACTIVATION,
                                       norm_first=self.config.PRE_NORM,
                                       device=self.device),
            self.config.NUM_LAYERS,
            norm=norm
        )

    def _get_or_generate_context_mask(self, src, do_convert=True, expose_ic=True):
        """Return the context mask for ``src``, building and caching it on first use.

        Args:
            src: Tensor whose first dimension gives the mask size and whose device is
                used for the mask.
            do_convert: When true, return an additive mask (0.0 where attention is
                allowed, ``-inf`` elsewhere); otherwise return the 1.0/0.0 float mask.
            expose_ic: When true (and ``CONTEXT_WRAP_INITIAL`` is set with a positive
                ``CONTEXT_BACKWARD``), additionally open the upper triangle of the
                leading ``CONTEXT_BACKWARD`` x ``CONTEXT_BACKWARD`` corner.

        Returns:
            ``None`` when ``config.FULL_CONTEXT`` is set, otherwise a
            ``(size, size)`` float tensor.
        """
        if self.config.FULL_CONTEXT:
            return None

        # Instances restored from an older pickle may lack the cache attribute.
        cache = getattr(self, "_mask_cache", None)
        if cache is None:
            cache = {}
            self._mask_cache = cache

        size = src.size(0)  # T
        # Key on everything that changes the result, not just the device.
        key = (str(src.device), size, bool(do_convert), bool(expose_ic))
        if key in cache:
            return cache[key]

        context_forward = self.config.CONTEXT_FORWARD
        if self.config.CONTEXT_FORWARD < 0:
            context_forward = size
        mask = (torch.triu(torch.ones(size, size, device=src.device), diagonal=-context_forward) == 1).transpose(0, 1)
        if self.config.CONTEXT_BACKWARD > 0:
            back_mask = (torch.triu(torch.ones(size, size, device=src.device), diagonal=-self.config.CONTEXT_BACKWARD) == 1)
            mask = mask & back_mask
        if expose_ic and self.config.CONTEXT_WRAP_INITIAL and self.config.CONTEXT_BACKWARD > 0:
            # Expose initial segment for IC
            initial_mask = torch.triu(torch.ones(self.config.CONTEXT_BACKWARD, self.config.CONTEXT_BACKWARD, device=src.device))
            mask[:self.config.CONTEXT_BACKWARD, :self.config.CONTEXT_BACKWARD] |= initial_mask
        mask = mask.float()
        if do_convert:
            # Binary (1 = visible) -> additive attention mask (0 = visible, -inf = hidden).
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

        cache[key] = mask
        return mask

    def forward(self, src):
        """Run the encoder on ``src`` using the mask stored at construction time."""
        output = self.transformer_encoder(src, mask=self.src_mask)
        return output
    
# The checkpoint is unpickled as a full module object (its attributes are used
# directly further down), so the NDT4HLS class above must already be defined.
# weights_only=False is passed explicitly because recent PyTorch releases default
# to weights_only=True, which rejects pickled module classes.
# SECURITY: this executes arbitrary pickle code - only load checkpoints you trust.
model = torch.load('./model/NDT4HLS.pth', map_location=torch.device('cpu'), weights_only=False)
print(model)

In [None]:
import hls4ml
from pprint import pprint
# Ask hls4ml to derive a configuration for the loaded model, with one entry per
# named layer (granularity='name'), targeting the Vitis backend and using
# ap_fixed<18,8,AP_RND_CONV> as the default precision.
# NOTE(review): input_shapes is a single [1, 180, 182] entry - presumably
# (batch, time steps, features); confirm against the trained model.
config = hls4ml.utils.config_from_pytorch_model(model, 
                                                granularity='name',
                                                backend='Vitis',
                                                input_shapes=[[1, 180, 182]], 
                                                default_precision='ap_fixed<18,8,AP_RND_CONV>', 
                                                inputs_channel_last=True, 
                                                transpose_outputs=False)

In [None]:
from typing import Tuple, List
import math
from einops import rearrange
import torch
import numpy as np
# Print tensors with 15 digits of precision so small quantization differences are visible.
torch.set_printoptions(precision=15)
class TorchQuantizer(torch.nn.Module):
    """Fixed-point quantizer emulation with an optional range-calibration mode.

    In quantization mode a value is scaled by ``q = 2**bitwidth / 2**int_bitwidth``,
    rounded to an integer code, brought into the representable code range and scaled
    back. In calibration mode the input is returned unchanged and only range
    statistics are recorded.

    Args:
        bitwidth: Total number of bits.
        int_bitwidth: Number of integer bits.
        signed: Whether the code range is ``[-2**(bitwidth-1), 2**(bitwidth-1) - 1]``
            (signed) or ``[0, 2**bitwidth - 1]`` (unsigned).
        rounding: ``'CONVERGENT'`` uses ``torch.round`` (round half to even); any other
            value uses ``torch.trunc`` (toward zero).
        saturation: ``'WRAP'`` wraps out-of-range codes modulo the range size; any
            other value clamps them to the range.
        calibration: When true, record statistics instead of quantizing.
        quantize: When false, ``forward``/``forward_inplace`` return the input as is.
        dtype: Accepted for interface compatibility; not used by this class.

    Attributes:
        max_int_bits: Largest ``ceil(log2(|x|))`` (plus one when signed) seen over the
            nonzero calibration inputs; ``-inf`` until something has been observed.
        min_frac_bits: Smallest such value seen; ``+inf`` until something is observed.
    """

    def __init__(self,
                 bitwidth=18,
                 int_bitwidth=8,
                 signed=True,
                 rounding='CONVERGENT',
                 saturation='WRAP',
                 calibration=False,
                 quantize=True,
                 dtype=torch.float64):
        super(TorchQuantizer, self).__init__()
        self.bitwidth = bitwidth
        self.int_bitwidth = int_bitwidth
        self.signed = signed
        # In calibration mode the scale factors are not needed, so they collapse to 1.
        self.m = pow(2, self.bitwidth) if not calibration else 1
        self.m_i = pow(2, self.int_bitwidth) if not calibration else 1
        self.q = float(self.m / self.m_i)
        self.lower_bound = -self.m/2 if self.signed else 0
        self.upper_bound = self.m/2-1 if self.signed else self.m-1
        self.rounding = rounding
        self.saturation = saturation
        self.calibration = calibration
        self.quantize = quantize
        self.max_int_bits = torch.tensor(-torch.inf)
        self.max_value = torch.tensor(-torch.inf)
        self.min_frac_bits = torch.tensor(torch.inf)

    def _observe(self, x):
        """Update ``max_int_bits`` / ``min_frac_bits`` from the nonzero entries of ``x``."""
        nonzero = x.flatten()
        nonzero = nonzero[nonzero != 0]  # log2(0) is -inf; zeros carry no range information
        if nonzero.nelement() == 0:
            return
        exponents = torch.ceil(torch.log2(torch.abs(nonzero)))
        sign_bit = 1 if self.signed else 0
        max_int_bits = exponents.max() + sign_bit
        min_frac_bits = exponents.min() + sign_bit
        self.max_int_bits = torch.max(max_int_bits, self.max_int_bits).int()
        self.min_frac_bits = torch.min(min_frac_bits, self.min_frac_bits).int()

    def forward(self, x):
        """Return a quantized copy of ``x`` (or ``x`` itself when disabled/calibrating).

        Raises:
            Exception: If the quantized result contains NaN.
        """
        if not self.quantize:
            return x
        if self.calibration:
            self._observe(x)
            return x

        scaled = x * self.q
        codes = torch.round(scaled) if self.rounding == 'CONVERGENT' else torch.trunc(scaled)
        if self.saturation == 'WRAP':
            span = self.upper_bound - self.lower_bound + 1
            codes = (codes - self.lower_bound) % span + self.lower_bound
        else:
            codes = torch.clamp(codes, self.lower_bound, self.upper_bound)
        qx = codes / self.q

        if torch.isnan(qx).any():
            print("x:",x)
            raise Exception("Quantized value is NaN")
        return qx

    def forward_inplace(self, x):
        """Quantize ``x`` in place and return it (no NaN check is performed here)."""
        if not self.quantize:
            return x
        if self.calibration:
            self._observe(x)
            return x

        x.mul_(self.q)
        if self.rounding == 'CONVERGENT':
            x.round_()
        else:
            x.trunc_()
        if self.saturation == 'WRAP':
            span = self.upper_bound - self.lower_bound + 1
            x.sub_(self.lower_bound).remainder_(span).add_(self.lower_bound)
        else:
            x.clamp_(self.lower_bound, self.upper_bound)
        x.div_(self.q)
        return x

class QLinear(torch.nn.Linear):
    """``nn.Linear`` whose weight, bias, input and output each pass through a quantizer.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer has a bias term.
        device: Device for the layer parameters.
        dtype: Data type for the layer parameters.
        quant_config: Mapping with the keys ``'weight'``, ``'bias'``, ``'input'`` and
            ``'output'``; each value is expanded as keyword arguments of
            ``TorchQuantizer``. Required despite the ``None`` default.
        calibration: Forwarded to every quantizer.
    """

    def __init__(self,
                 in_features:int,
                 out_features:int,
                 bias:bool=True,
                 device=None,
                 dtype=torch.float64,
                 quant_config:dict=None,
                 calibration=False):
        super(QLinear, self).__init__(in_features, out_features, bias, device=device, dtype=dtype)
        self.quant_config = quant_config
        self.calibration = calibration
        self.weight_qtzr = TorchQuantizer(**quant_config['weight'], calibration=calibration)
        self.bias_qtzr = TorchQuantizer(**quant_config['bias'], calibration=calibration)
        self.input_qtzr = TorchQuantizer(**quant_config['input'], calibration=calibration)
        self.output_qtzr = TorchQuantizer(**quant_config['output'], calibration=calibration)
        self.dtype = dtype
        # Historical misspelling of ``dtype``; kept as an alias so existing readers
        # of the old attribute name keep working.
        self.dtpye = dtype

    def forward(self, x):
        """Compute ``quant(quant(x) @ quant(W).T + quant(b))``."""
        qw = self.weight_qtzr(self.weight)
        qx = self.input_qtzr(x)
        qy = torch.matmul(qx, qw.t())
        if self.bias is not None:
            # The accumulation is quantized only once, after the bias has been added.
            qy += self.bias_qtzr(self.bias)
        qy = self.output_qtzr(qy)
        return qy
    
class QFlashMultiheadAttention(torch.nn.MultiheadAttention):
    """Tiled self-attention with quantizers around the exp / reciprocal / accumulate steps.

    ``forward`` processes the score matrix tile by tile (``token_tile_size`` rows by
    ``token_tile_size`` columns). For every query row it keeps a running maximum and a
    running sum of exponentials, rescales the partial output whenever the maximum
    changes, and divides by the final sum after the last column tile.

    Args:
        embed_dim: Model width.
        num_heads: Number of heads; ``embed_dim // num_heads`` is the per-head width.
        bias: Whether the input/output projections have a bias.
        batch_first: Stored and forwarded to the parent class. Note that ``forward``
            always unpacks its input as ``(tgt_len, bsz, embed_dim)``.
        device: Device for the projection parameters.
        dtype: Data type of the projections and of the running max/sum buffers.
        quant_config: Mapping with the keys ``'in_proj'``, ``'out_proj'``, ``'scale'``,
            ``'row_sum'``, ``'exp_input'``, ``'exp_output'``, ``'inv_input'`` and
            ``'inv_output'``. Required despite the ``None`` default.
        token_tile_size: Tile size along both the query and the key axis.
        embed_tile_size: Stored only; not read by ``forward``.
        head_tile_size: Stored only; not read by ``forward``.
        max_neg_value: Initial value of the running row maximum.
        calibration: Forwarded to every quantizer.
    """

    def __init__(self,
                 embed_dim:int,
                 num_heads:int,
                 bias:bool=True,
                 batch_first:bool=False,
                 device=None,
                 dtype=torch.float64,
                 quant_config:dict=None,
                 token_tile_size:int=1,
                 embed_tile_size:int=1,
                 head_tile_size:int=1,
                 max_neg_value:float=-8.0,
                 calibration=False):
        super(QFlashMultiheadAttention, self).__init__(embed_dim,
                                                  num_heads,
                                                  bias=bias,
                                                  add_bias_kv=False,
                                                  add_zero_attn=False,
                                                  kdim=None,
                                                  vdim=None,
                                                  batch_first=batch_first,
                                                  device=device,
                                                  dtype=dtype)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        self.scale = torch.tensor(1.0 / math.sqrt(self.head_dim))
        # Fused Q/K/V projection; its output is split into three chunks in forward().
        self.in_proj = QLinear(embed_dim,
                               3*embed_dim,
                               bias=bias,
                               device=device,
                               dtype=dtype,
                               quant_config=quant_config['in_proj'], calibration=calibration)
        self.scale_qtzr = TorchQuantizer(**quant_config['scale'], calibration=calibration)
        self.token_tile_size = token_tile_size
        self.embed_tile_size = embed_tile_size
        self.head_tile_size = head_tile_size
        self.max_neg_value = max_neg_value
        self.row_sum_qtzr = TorchQuantizer(**quant_config['row_sum'], calibration=calibration)
        self.exp_input_qtzr = TorchQuantizer(**quant_config['exp_input'], rounding='TRUNCATE', saturation='SAT', calibration=calibration)
        self.exp_output_qtzr = TorchQuantizer(**quant_config['exp_output'], saturation='SAT', calibration=calibration)
        self.inv_input_qtzr = TorchQuantizer(**quant_config['inv_input'], rounding='TRUNCATE', saturation='SAT', calibration=calibration)
        self.inv_output_qtzr = TorchQuantizer(**quant_config['inv_output'], saturation='SAT', calibration=calibration)
        # The attention accumulator shares the out_proj input format.
        self.attn_out_qtzr = TorchQuantizer(**quant_config['out_proj']['input'], calibration=calibration)
        self.out_proj = QLinear(embed_dim,
                                embed_dim,
                                bias=bias,
                                device=device,
                                dtype=dtype,
                                quant_config=quant_config['out_proj'], calibration=calibration)
        self.device = device
        self.dtype = dtype

    def forward(self, query, attn_mask=None):
        """Self-attention over ``query``.

        Args:
            query: Tensor unpacked as ``(tgt_len, bsz, embed_dim)``.
            attn_mask: ``None`` or a 2-D mask. It is converted with ``.bool()``;
                entries that convert to ``True`` are excluded from attention.

        Returns:
            Tensor of shape ``(tgt_len, bsz, embed_dim)``.

        Raises:
            ValueError: If ``attn_mask`` is given but is not 2-D.
        """
        q, k, v = self.in_proj(query).chunk(3, dim=-1)

        tgt_len, bsz, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        o = torch.zeros_like(q)
        # Allocate the running statistics next to the data they are combined with,
        # rather than on the device recorded at construction time.
        all_row_sums = torch.zeros((bsz * self.num_heads, tgt_len, 1), dtype = self.dtype, device = q.device)
        all_row_maxes = torch.full((bsz * self.num_heads, tgt_len, 1), self.max_neg_value, dtype = self.dtype, device = q.device)

        num_tiles = math.ceil(tgt_len / self.token_tile_size)
        if attn_mask is None:
            col_masks = (None,) * num_tiles
            mask = (col_masks,) * num_tiles
        else:
            if attn_mask.ndim != 2:
                raise ValueError(f"attn_mask must be 2-D, got {attn_mask.ndim}-D")
            mask = attn_mask.bool()
            # Split into row tiles, then each row tile into column tiles; a size-1
            # axis is repeated for every tile instead of being split.
            mask = ((mask,) * num_tiles) if mask.shape[-2] == 1 else mask.split(self.token_tile_size, dim = -2)
            mask = tuple(((row_mask,) * num_tiles) if row_mask.shape[-1] == 1 else row_mask.split(self.token_tile_size, dim = -1) for row_mask in mask)

        scale = self.scale_qtzr(self.scale)
        # The split() results are views, so in-place updates below write into
        # ``o``, ``all_row_sums`` and ``all_row_maxes``.
        row_splits = zip(
            q.split(self.token_tile_size, dim = -2),
            o.split(self.token_tile_size, dim = -2),
            mask,
            all_row_sums.split(self.token_tile_size, dim = -2),
            all_row_maxes.split(self.token_tile_size, dim = -2),
        )
        for qc, oc, row_mask, row_sums, row_maxes in row_splits:
            col_splits = zip(
                k.split(self.token_tile_size, dim = -2),
                v.split(self.token_tile_size, dim = -2),
                row_mask
            )
            for kc, vc, col_mask in col_splits:
                attn_weights = torch.einsum('... i d, ... j d -> ... i j', qc, kc) * scale
                if col_mask is not None:
                    attn_weights.masked_fill_(col_mask, -1000000)
                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(row_maxes, block_row_maxes)
                # Shift by the running maximum before exponentiating.
                shifted_weights = attn_weights - new_row_maxes
                shifted_weights = self.exp_input_qtzr(shifted_weights)
                exp_weights = torch.exp(shifted_weights)
                exp_weights = self.exp_output_qtzr(exp_weights)
                if col_mask is not None:
                    exp_weights.masked_fill_(col_mask, 0.0)
                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = 1e-10)
                exp_values = torch.einsum('... i j, ... j d -> ... i d', exp_weights, vc)
                # Factor that rescales what was accumulated under the previous maximum.
                exp_row_max_diff = self.exp_input_qtzr(row_maxes - new_row_maxes)
                exp_row_max_diff = self.exp_output_qtzr(torch.exp(exp_row_max_diff))
                new_row_sums = self.row_sum_qtzr(exp_row_max_diff * row_sums + block_row_sums)
                oc.mul_(exp_row_max_diff)
                oc.add_(exp_values)
                self.attn_out_qtzr.forward_inplace(oc)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)
            # Normalize the row tile by its final sum of exponentials.
            new_row_sums = self.inv_input_qtzr(new_row_sums)
            oc.mul_(self.inv_output_qtzr(torch.reciprocal(new_row_sums + 1e-10)))
            self.attn_out_qtzr.forward_inplace(oc)
        attn_output = o.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
        

    
class QLayerNorm(torch.nn.LayerNorm):
    """LayerNorm computed step by step with a quantizer after each stage.

    Args:
        normalized_shape: Forwarded to ``nn.LayerNorm``; its last entry is the
            divisor used for the mean and the mean of squares.
        quant_config: Mapping with the keys ``'input'``, ``'scale'``, ``'bias'``,
            ``'output'``, ``'var_input'`` and ``'var_output'``. Required despite the
            ``None`` default.
        calibration: Forwarded to every quantizer.
    """

    def __init__(self,
                 normalized_shape:Tuple[int, ...],
                 quant_config:dict=None,
                 calibration=False):
        super().__init__(normalized_shape)

        def build(key, **overrides):
            # One quantizer per pipeline stage, configured from quant_config[key].
            return TorchQuantizer(**quant_config[key], calibration=calibration, **overrides)

        self.input_qtzr = build('input')
        self.scale_qtzr = build('scale')
        self.bias_qtzr = build('bias')
        self.output_qtzr = build('output')
        # The mean reuses the input format.
        self.mean_qtzr = build('input')
        self.var_input_qtzr = build('var_input', rounding='TRUNCATE', saturation='SAT')
        self.var_output_qtzr = build('var_output', saturation='SAT')
        self.dim_qtzr = TorchQuantizer(bitwidth=18, int_bitwidth=0, signed=False, calibration=calibration)
        # 1 / N is passed through dim_qtzr once and then used as a multiplier.
        self.inv_embed_dim = self.dim_qtzr(torch.tensor(1.0 / self.normalized_shape[-1]))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the affine parameters."""
        x = self.input_qtzr(x)

        mean = x.sum(dim=-1, keepdim=True)
        mean.mul_(self.inv_embed_dim)
        mean = self.mean_qtzr(mean)

        mean_of_squares = (x**2).sum(dim=-1, keepdim=True)
        mean_of_squares.mul_(self.inv_embed_dim)

        # Variance as E[x^2] - E[x]^2, then its quantized inverse square root.
        inv_std = self.var_input_qtzr(mean_of_squares - mean**2)
        inv_std = 1.0 / torch.sqrt(inv_std+1e-8)
        inv_std = self.var_output_qtzr(inv_std)

        normalized = (x - mean) * inv_std
        normalized.mul_(self.scale_qtzr(self.weight))
        normalized.add_(self.bias_qtzr(self.bias))
        return self.output_qtzr(normalized)
    
class QFeedForward(nn.Module):
    """Two quantized linear layers with a ReLU in between.

    Args:
        embed_dim: Input and output width.
        hidden_dim: Width between the two projections.
        bias: Whether both projections have a bias.
        device: Device for the projection parameters.
        dtype: Data type for the projection parameters.
        quant_config: Mapping with the keys ``'in_proj'`` and ``'out_proj'``.
            Required despite the ``None`` default.
        calibration: Forwarded to both projections.
    """

    def __init__(self,
                 embed_dim: int,
                 hidden_dim: int,
                 bias: bool = True,
                 device: str = 'cpu',
                 dtype: torch.dtype = torch.float64,
                 quant_config: dict = None,
                 calibration: bool = False):
        super().__init__()

        def projection(n_in, n_out, key):
            # Both projections share every option except their sizes and config entry.
            return QLinear(n_in,
                           n_out,
                           bias=bias,
                           device=device,
                           dtype=dtype,
                           quant_config=quant_config[key], calibration=calibration)

        self.in_proj = projection(embed_dim, hidden_dim, 'in_proj')
        self.activation = nn.ReLU()
        self.out_proj = projection(hidden_dim, embed_dim, 'out_proj')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``out_proj(relu(in_proj(x)))``."""
        hidden = self.activation(self.in_proj(x))
        return self.out_proj(hidden)
    
class QTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Encoder layer assembled from the quantized attention, feed-forward and norm blocks.

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads.
        hidden_dim: Feed-forward hidden width.
        dropout: Dropout probability applied to both residual branches.
        activation: Forwarded to the parent constructor only; it is not passed to the
            ``QFeedForward`` built here.
        norm_first: ``True`` applies the norms before attention / feed-forward,
            ``False`` applies them after the residual additions.
        device: Device for the quantized sub-modules.
        dtype: Data type for the quantized sub-modules.
        quant_config: Mapping with the keys ``'self_attn'``, ``'ffn'``, ``'norm1'``,
            ``'norm2'`` and ``'input'``. Required despite the ``None`` default.
        calibration: Forwarded to every quantized sub-module.
        src_mask: Stored on the instance; ``forward`` uses its own ``src_mask``
            argument instead.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 hidden_dim: int,
                 dropout: float = 0.0,
                 activation: str = 'relu',
                 norm_first: bool = True,
                 device: str = 'cpu',
                 dtype: torch.dtype = torch.float64,
                 quant_config: dict = None,
                 calibration: bool = False,
                 src_mask: torch.Tensor = None):
        # Keyword arguments are required here: in the parent signature the sixth
        # positional parameter is ``layer_norm_eps``, so passing ``norm_first``
        # positionally would silently be taken as the epsilon.
        super(QTransformerEncoderLayer, self).__init__(embed_dim,
                                                       num_heads,
                                                       dim_feedforward=hidden_dim,
                                                       dropout=dropout,
                                                       activation=activation,
                                                       norm_first=norm_first)
        # Replace the parent's sub-modules with their quantized counterparts.
        self.self_attn = QFlashMultiheadAttention(embed_dim,
                                                    num_heads,
                                                    device=device,
                                                    dtype=dtype,
                                                    quant_config=quant_config['self_attn'], calibration=calibration)
        self.feedforward = QFeedForward(embed_dim,
                                    hidden_dim,
                                    device=device,
                                    dtype=dtype,
                                    quant_config=quant_config['ffn'], calibration=calibration)
        self.norm1 = QLayerNorm(embed_dim,
                                quant_config=quant_config['norm1'], calibration=calibration)
        self.norm2 = QLayerNorm(embed_dim,
                                quant_config=quant_config['norm2'], calibration=calibration)
        self.dropout = nn.Dropout(dropout)
        self.norm_first = norm_first
        self.input_qtzr = TorchQuantizer(**quant_config['input'], calibration=calibration)
        self.src_mask = src_mask

    def forward(self,
                src: torch.Tensor,
                src_mask: torch.Tensor = None) -> torch.Tensor:
        """Apply attention and feed-forward with residual connections.

        Args:
            src: Layer input; it is quantized with ``input_qtzr`` first.
            src_mask: Optional mask handed to the attention block as ``attn_mask``.
        """
        src = self.input_qtzr(src)
        if self.norm_first:
            src_norm = self.norm1(src)
            src2 = self.self_attn(src_norm, attn_mask=src_mask)
            src = src + self.dropout(src2)
            src_norm = self.norm2(src)
            src2 = self.feedforward(src_norm)
            src = src + self.dropout(src2)
        else:
            src2 = self.self_attn(src, attn_mask=src_mask)
            src = src + self.dropout(src2)
            src = self.norm1(src)
            src2 = self.feedforward(src)
            src = src + self.dropout(src2)
            src = self.norm2(src)
        return src

class QTransformerEncoder(nn.TransformerEncoder):
    """Runs a caller-supplied list of encoder layers followed by a final norm.

    Args:
        encoder_layer: The layers to run, in order. The first one is also handed to
            the parent constructor, which builds its own ``self.layers`` clones;
            ``forward`` here does not use those clones.
        num_layers: Forwarded to the parent constructor.
        norm: Final normalization module.
        dtype: Data type that ``transfer_weights`` converts the copied weights to.
    """

    def __init__(self,
                 encoder_layer: "List[QTransformerEncoderLayer]",
                 num_layers: int,
                 norm: "QLayerNorm",
                 dtype: torch.dtype = torch.float64):
        super(QTransformerEncoder, self).__init__(encoder_layer[0], num_layers, norm)
        # A plain Python list is invisible to nn.Module: eval()/train()/to() and
        # parameters() would skip these layers. ModuleList registers them while
        # still supporting iteration and indexing.
        self.layer_list = nn.ModuleList(encoder_layer)
        self.norm = norm
        self.dtype = dtype

    def forward(self,
                src: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Pass ``src`` through every layer in ``layer_list`` and then ``norm``."""
        output = src
        for mod in self.layer_list:
            output = mod(output, src_mask=mask)
        output = self.norm(output)
        return output

    def transfer_weights(self,
                         model: nn.Module):
        """Copy weights from ``model.transformer_encoder`` into the quantized layers.

        The source is read through the attribute names used below (``norm1``,
        ``norm2``, ``self_attn.in_proj_weight`` / ``in_proj_bias``,
        ``self_attn.out_proj``, ``linear1``, ``linear2`` and the final ``norm``).
        Each tensor is converted to ``self.dtype`` and copied, so the two models do
        not share storage afterwards.
        """
        def load(dst: torch.Tensor, src: torch.Tensor) -> None:
            dst.data = src.detach().to(self.dtype, copy=True)

        source = model.transformer_encoder
        for i, layer in enumerate(self.layer_list):
            ref = source.layers[i]
            pairs = (
                (layer.norm1.weight, ref.norm1.weight),
                (layer.norm1.bias, ref.norm1.bias),
                (layer.norm2.weight, ref.norm2.weight),
                (layer.norm2.bias, ref.norm2.bias),
                (layer.self_attn.in_proj.weight, ref.self_attn.in_proj_weight),
                (layer.self_attn.in_proj.bias, ref.self_attn.in_proj_bias),
                (layer.self_attn.out_proj.weight, ref.self_attn.out_proj.weight),
                (layer.self_attn.out_proj.bias, ref.self_attn.out_proj.bias),
                (layer.feedforward.in_proj.weight, ref.linear1.weight),
                (layer.feedforward.in_proj.bias, ref.linear1.bias),
                (layer.feedforward.out_proj.weight, ref.linear2.weight),
                (layer.feedforward.out_proj.bias, ref.linear2.bias),
            )
            for dst, src in pairs:
                load(dst, src)
        load(self.norm.weight, source.norm.weight)
        load(self.norm.bias, source.norm.bias)

def calibrate_transformer(qmodel: "QTransformerEncoder",
                          quant_config: dict,
                          calibration_data: torch.Tensor,
                          calibration_mask: torch.Tensor
                          ) -> dict:
    """Run one calibration pass and copy the observed integer widths into ``quant_config``.

    ``qmodel`` is switched to eval mode and evaluated once on ``calibration_data``
    under ``torch.no_grad()``. Afterwards the ``max_int_bits`` of each calibrated
    quantizer is written to the matching ``['int_bitwidth']`` entry. Covered are:
    per layer ``norm1`` / ``norm2``, the attention ``in_proj`` / ``out_proj``,
    ``row_sum``, ``inv_input``, ``inv_output`` and the feed-forward ``in_proj`` /
    ``out_proj``; plus the final ``norm``. The attention ``exp_input`` /
    ``exp_output`` entries are intentionally left as configured.

    A quantizer that never observed a nonzero value still holds ``-inf``; its config
    entry is left untouched instead of being overwritten with a non-finite width.

    Args:
        qmodel: Model built in calibration mode; must provide ``layer_list`` and ``norm``.
        quant_config: Config indexed by layer number plus ``'norm'``. Updated in place.
        calibration_data: Input passed to ``qmodel``.
        calibration_mask: Passed to ``qmodel`` as ``mask``.

    Returns:
        The same ``quant_config`` object.
    """
    norm_fields = ('input', 'scale', 'bias', 'output', 'var_input', 'var_output')
    linear_fields = ('input', 'weight', 'bias', 'output')
    attn_fields = ('row_sum', 'inv_input', 'inv_output')

    def record(entry: dict, quantizer) -> None:
        observed = quantizer.max_int_bits
        if bool(torch.isfinite(observed)):
            entry['int_bitwidth'] = observed.item()

    def record_fields(cfg: dict, module, fields) -> None:
        # Config key ``name`` corresponds to the attribute ``<name>_qtzr``.
        for name in fields:
            record(cfg[name], getattr(module, name + '_qtzr'))

    with torch.no_grad():
        qmodel.eval()
        qmodel(calibration_data, mask=calibration_mask)
        for i, layer in enumerate(qmodel.layer_list):
            layer_cfg = quant_config[i]
            record_fields(layer_cfg['norm1'], layer.norm1, norm_fields)
            record_fields(layer_cfg['norm2'], layer.norm2, norm_fields)

            attn_cfg = layer_cfg['self_attn']
            record_fields(attn_cfg['in_proj'], layer.self_attn.in_proj, linear_fields)
            record_fields(attn_cfg['out_proj'], layer.self_attn.out_proj, linear_fields)
            record_fields(attn_cfg, layer.self_attn, attn_fields)

            ffn_cfg = layer_cfg['ffn']
            record_fields(ffn_cfg['in_proj'], layer.feedforward.in_proj, linear_fields)
            record_fields(ffn_cfg['out_proj'], layer.feedforward.out_proj, linear_fields)
        record_fields(quant_config['norm'], qmodel.norm, norm_fields)
    return quant_config

In [None]:
import json
import copy
#load quantization config
def load_transformer_quant_config(quant_config_path: str = "./quant_config.json",
                                  norm_quant_config_path: str = "./norm_quant_config.json",
                                  num_layers: int = 1) -> dict:
    """Assemble the quantization config for a whole transformer encoder.

    Args:
        quant_config_path: JSON file holding the quantization config of ONE
            encoder layer; it is replicated for every layer.
        norm_quant_config_path: JSON file holding the quantization config of
            the final normalization layer.
        num_layers: number of encoder layers to create entries for.

    Returns:
        A dict with integer keys ``0 .. num_layers-1`` (one independent deep
        copy of the layer config each) plus the key ``'norm'`` for the final
        normalization config.

    Raises:
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if either file is not valid JSON.
    """
    with open(quant_config_path, 'r') as f:
        quant_config = json.load(f)
    # Use the caller-supplied path; it used to be ignored in favour of a
    # hard-coded "./norm_quant_config.json".
    with open(norm_quant_config_path, 'r') as f:
        norm_quant_config = json.load(f)

    # Deep copies so that tuning one layer never leaks into another.
    transformer_quant_config = {i: copy.deepcopy(quant_config) for i in range(num_layers)}
    transformer_quant_config['norm'] = copy.deepcopy(norm_quant_config)
    return transformer_quant_config


# Build the quantization config for a 4-layer encoder from the default JSON
# paths and print it for inspection.
# NOTE(review): `pprint` is not imported in this cell — presumably it comes
# from an earlier cell; confirm so the notebook survives Restart & Run All.
transformer_quant_config = load_transformer_quant_config(num_layers=4)
pprint(transformer_quant_config)

In [None]:
from typing import Dict

def gen_init_state(num_layers:int) -> Dict[str, int]:
    """Return the starting bit-width / table-size state for the search.

    The result maps ``'<layer>.<...>'`` keys to integers: every
    ``Precision.*`` entry starts at 18 and every ``*TableSize`` entry at 10.
    Per-layer keys are prefixed with ``layers_<i>_`` for ``i`` in
    ``range(num_layers)``; the final normalization layer uses the bare
    ``norm`` prefix and is appended last.
    """
    # (key suffix, initial value) for a single encoder layer, in emission order.
    per_layer_defaults = (
        ('norm1.Precision.var_table', 18),
        ('norm1.VarTableSize', 10),
        ('norm1.Precision.result', 18),
        ('self_attn.Precision.exp_table', 18),
        ('self_attn.ExpTableSize', 10),
        ('self_attn.Precision.inv_table', 18),
        ('self_attn.InvTableSize', 10),
        ('self_attn.Precision.in_proj_out', 18),
        ('self_attn.Precision.out_proj_in', 18),
        ('self_attn.Precision.in_proj_weight', 18),
        ('self_attn.Precision.out_proj_weight', 18),
        ('self_attn.Precision.result', 18),
        ('add1.Precision.result', 18),
        ('norm2.Precision.var_table', 18),
        ('norm2.VarTableSize', 10),
        ('norm2.Precision.result', 18),
        ('ffn.Precision.in_proj_weight', 18),
        ('ffn.Precision.out_proj_weight', 18),
        ('ffn.Precision.hidden', 18),
        ('ffn.Precision.result', 18),
        ('add2.Precision.result', 18),
    )
    # Entries of the final normalization layer.
    final_norm_defaults = (
        ('norm.Precision.var_table', 18),
        ('norm.VarTableSize', 10),
        ('norm.Precision.result', 18),
    )

    state = {}
    for layer_idx in range(num_layers):
        prefix = 'layers_' + str(layer_idx) + '_'
        for suffix, initial in per_layer_defaults:
            state[prefix + suffix] = initial
    for name, initial in final_norm_defaults:
        state[name] = initial
    return state

import re
def sync_hls_config(hls_config:dict,
                    state:dict) -> dict:
    """Write the bit-widths / table sizes in ``state`` into an hls4ml config.

    ``state`` keys are dot-separated paths below ``hls_config['LayerName']``:

    * three components (``layer.Group.var``): the existing fixed-point type
      string is re-emitted with its total width replaced by the state value.
      Integer bits are kept. Entries whose variable name contains ``'table'``
      keep every optional argument that was present (rounding, saturation,
      saturation bits); all other entries keep only the rounding mode.
    * two components (``layer.Option``): the option is set to
      ``2 ** value`` (the state stores the log2 of the table size).

    Keys with any other number of components are ignored.

    Args:
        hls_config: hls4ml configuration dict; modified in place.
        state: mapping from path string to integer value.

    Returns:
        The same ``hls_config`` object.

    Raises:
        ValueError: if an existing precision entry is not a recognisable
            ``(ap_)(u)fixed<...>`` type string.
    """
    layer_configs = hls_config['LayerName']
    for key, value in state.items():
        subkey = key.split('.')
        if len(subkey) == 3:
            layer, group, var = subkey
            current = layer_configs[layer][group][var]
            match = re.match(r'(ap_ufixed|ap_fixed|fixed|ufixed)<(\d+),(-?\d+)(?:,(\w+)(?:,(\w+)(?:,(\d+))?)?)?>', current)
            if match is None:
                raise ValueError(
                    f"cannot parse fixed-point type {current!r} for state key {key!r}"
                )
            base_type, _total_bits, integer_bits, rounding, saturation, sat_bits = match.groups()
            if 'table' in var:
                optional = (rounding, saturation, sat_bits)
            else:
                optional = (rounding,)
            # Optional regex groups are None when absent from the original
            # string; skip them instead of emitting the literal text "None".
            params = [str(value), integer_bits]
            params.extend(p for p in optional if p is not None)
            layer_configs[layer][group][var] = f"{base_type}<{','.join(params)}>"
        elif len(subkey) == 2:
            layer_configs[subkey[0]][subkey[1]] = int(2**value)
    return hls_config

def sync_quant_config(quant_config:dict,
                      state:dict) -> dict:
    """Write the bit-widths in ``state`` into the PyTorch quantization config.

    ``state`` keys look like ``'layers_<i>_<block>.Precision.<var>'``,
    ``'layers_<i>_<block>.<TableSize>'`` or the same with the bare ``'norm'``
    prefix for the final normalization layer. ``<block>`` is one of
    ``norm1``, ``norm2``, ``self_attn``, ``ffn``, ``add1``, ``add2``.
    Each recognised entry sets one or more ``['bitwidth']`` leaves in
    ``quant_config``; unrecognised variable names are ignored.

    Args:
        quant_config: dict keyed by layer index (int) plus ``'norm'``;
            modified in place.
        state: mapping from key string to integer bit-width.

    Returns:
        The same ``quant_config`` object.
    """
    for key, bits in state.items():
        subkey = key.split('.')
        layername = subkey[0]
        # Resolve the variable name up front so that every branch, including
        # the final-norm one, sees the name belonging to THIS key rather than
        # a stale value from the previous iteration.
        if subkey[1] == 'Precision':
            varname = subkey[2]
        else:
            varname = subkey[1]

        if layername == 'norm':
            # Final normalization layer: not nested under a layer index.
            if varname == 'VarTableSize':
                quant_config[layername]['var_input']['bitwidth'] = bits
            elif varname == 'var_table':
                quant_config[layername]['var_output']['bitwidth'] = bits
            elif varname == 'result':
                quant_config[layername]['output']['bitwidth'] = bits
            continue

        # 'layers_<i>_<block>' -> (i, block); 'self_attn' itself contains an
        # underscore, so it cannot be recovered with a plain split.
        layeridx = int(layername.split('_')[1])
        if layername.endswith('self_attn'):
            layername = 'self_attn'
        else:
            layername = layername.split('_')[2]

        if 'norm' in layername:
            if varname == 'VarTableSize':
                quant_config[layeridx][layername]['var_input']['bitwidth'] = bits
            elif varname == 'var_table':
                quant_config[layeridx][layername]['var_output']['bitwidth'] = bits
            elif varname == 'result':
                quant_config[layeridx][layername]['output']['bitwidth'] = bits
        elif 'self_attn' in layername:
            if varname == 'ExpTableSize':
                quant_config[layeridx][layername]['exp_input']['bitwidth'] = bits
            elif varname == 'exp_table':
                quant_config[layeridx][layername]['exp_output']['bitwidth'] = bits
            elif varname == 'InvTableSize':
                quant_config[layeridx][layername]['inv_input']['bitwidth'] = bits
            elif varname == 'inv_table':
                quant_config[layeridx][layername]['inv_output']['bitwidth'] = bits
            elif varname == 'in_proj_out':
                quant_config[layeridx][layername]['in_proj']['output']['bitwidth'] = bits
            elif varname == 'out_proj_in':
                quant_config[layeridx][layername]['out_proj']['input']['bitwidth'] = bits
            elif varname == 'in_proj_weight':
                quant_config[layeridx][layername]['in_proj']['weight']['bitwidth'] = bits
            elif varname == 'out_proj_weight':
                quant_config[layeridx][layername]['out_proj']['weight']['bitwidth'] = bits
            elif varname == 'result':
                quant_config[layeridx][layername]['out_proj']['output']['bitwidth'] = bits
        elif 'ffn' in layername:
            if varname == 'in_proj_weight':
                quant_config[layeridx][layername]['in_proj']['weight']['bitwidth'] = bits
            elif varname == 'out_proj_weight':
                quant_config[layeridx][layername]['out_proj']['weight']['bitwidth'] = bits
            elif varname == 'hidden':
                # The hidden activation is both the first projection's output
                # and the second projection's input.
                quant_config[layeridx][layername]['in_proj']['output']['bitwidth'] = bits
                quant_config[layeridx][layername]['out_proj']['input']['bitwidth'] = bits
            elif varname == 'result':
                quant_config[layeridx][layername]['out_proj']['output']['bitwidth'] = bits
        elif 'add1' in layername:
            # The first residual sum feeds norm2 of the same layer.
            if varname == 'result':
                quant_config[layeridx]['norm2']['input']['bitwidth'] = bits
        elif 'add2' in layername:
            # The second residual sum leaves the layer: it feeds norm1 of the
            # next layer, or the final norm after the last layer.
            if varname == 'result':
                if quant_config.get(layeridx+1) is not None:
                    quant_config[layeridx+1]['norm1']['input']['bitwidth'] = bits
                elif quant_config.get('norm') is not None:
                    quant_config['norm']['input']['bitwidth'] = bits
    return quant_config

In [None]:
import numpy as np
# Load the stored mask labels and input batch straight into torch tensors;
# the input is promoted to float64.
mask_labels = torch.tensor(np.load('./model/mask_labels.npy'))
src = torch.tensor(np.load('./model/src.npy'), dtype=torch.float64)

In [None]:
# Pass 1: build the 4-layer quantized encoder in calibration mode.
# `model` and the Q* classes come from earlier cells.
# NOTE(review): the positional args (182, 2, 128) are presumably
# (d_model, nhead, dim_feedforward) — confirm against QTransformerEncoderLayer.
qmodel = QTransformerEncoder([QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[0], calibration=True),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[1], calibration=True),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[2], calibration=True),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[3], calibration=True)], 
                             4, 
                             QLayerNorm(182, quant_config=transformer_quant_config['norm'], calibration=True),
                             dtype=torch.float64)
qmodel.transfer_weights(model)
# Calibration: run the first sample (first two axes swapped by permute) and
# rebind transformer_quant_config to whatever calibrate_transformer returns.
with torch.no_grad():
    qmodel.eval()
    transformer_quant_config = calibrate_transformer(qmodel, transformer_quant_config, src[0:1].permute(1,0,2), model.src_mask)

# Pass 2: rebuild the encoder with calibration disabled so it uses the
# config returned by the calibration step above.
qmodel = QTransformerEncoder([QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[0], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[1], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[2], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[3], calibration=False)], 
                             4, 
                             QLayerNorm(182, quant_config=transformer_quant_config['norm'], calibration=False),
                             dtype=torch.float64)
qmodel.transfer_weights(model)
# NOTE(review): this forward pass uses mask=mask_labels, whereas calibration
# above and the later comparison cells pass model.src_mask — confirm which
# mask is intended here.
with torch.no_grad():
    qmodel.eval()
    qy = qmodel(src[0:1].permute(1,0,2), mask=mask_labels)
    print(qy)

In [None]:
import random
# Baseline check: push the initial state into both configs, rebuild the HLS
# model and the quantized PyTorch model, and require matching outputs.
state = gen_init_state(4)

# Manual override hook for a single entry; 18 equals the value gen_init_state
# already assigns, so this line currently changes nothing.
state['layers_0_norm2.Precision.result'] = 18
# `config` (the hls4ml config dict) is expected to exist from an earlier cell.
config = sync_hls_config(config, state)
transformer_quant_config = sync_quant_config(transformer_quant_config, state)
pprint(transformer_quant_config)
pprint(config)
# Remove the previous HLS project so the conversion starts from a clean directory.
!rm -rf ./hls/ndt
hls_model = hls4ml.converters.convert_from_pytorch_model(
                                                            model,
                                                            [[1, 180, 182]],
                                                            output_dir='./hls/ndt',
                                                            project_name='myproject',
                                                            backend='Vitis',
                                                            #part='xcu250-figd2104-2L-e',
                                                            part='xcu55c-fsvh2892-2L-e',
                                                            #board='alveo-u55c',
                                                            hls_config=config,
                                                            io_type='io_tile_stream',
                                                        )
hls_model.compile()
# Quantized PyTorch reference built from the synced config (calibration off).
qmodel = QTransformerEncoder([QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[0], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[1], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[2], calibration=False),
                              QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[3], calibration=False)], 
                             4, 
                             QLayerNorm(182, quant_config=transformer_quant_config['norm'], calibration=False),
                             dtype=torch.float64)
qmodel.transfer_weights(model)
import os 
# Delete leftover text dumps from a previous run, if present.
# NOTE(review): presumably debug traces written during the forward pass — confirm who writes them.
if os.path.exists('./src_norm2.txt'):
    os.remove('./src_norm2.txt')
if os.path.exists('./src_ffn.txt'):
    os.remove('./src_ffn.txt')
if os.path.exists('./src_norm1_in.txt'):
    os.remove('./src_norm1_in.txt')
with torch.no_grad():
    qmodel.eval()
    qy = qmodel(src[0:1].permute(1,0,2), mask=model.src_mask)
    print(qy)
    hls_y = hls_model.predict(src[0:1].numpy())
    print(hls_y)
    # Check that the two outputs are close enough (absolute tolerance 1e-5).
    assert np.allclose(qy.flatten().detach().numpy(), hls_y, atol=1e-5)

In [None]:
# This assert always raises, so nothing below runs; remove it to run the sweep.
# NOTE(review): presumably a deliberate guard to keep "Run All" from starting
# the expensive sweep — confirm.
assert False
import random
state = gen_init_state(4)
# For every state entry: perturb it by a random offset in [-3, 3], then rebuild
# and compare both models. `state` is not reset between iterations, so the
# perturbations accumulate across the loop.
# NOTE(review): no random seed is set, so the sweep is not reproducible.
for key in state.keys():
    print('---------------------------------')
    print("key", key)
    print("bits", state[key])
    state[key] += random.randint(-3, 3)
    print("bits", state[key])
#state['layers_0_ffn.Precision.result'] = 11
    config = sync_hls_config(config, state)
    transformer_quant_config = sync_quant_config(transformer_quant_config, state)
    hls_model = hls4ml.converters.convert_from_pytorch_model(
                                                                model,
                                                                [[1, 180, 182]],
                                                                output_dir='./hls/ndt',
                                                                project_name='myproject',
                                                                backend='Vitis',
                                                                #part='xcu250-figd2104-2L-e',
                                                                part='xcu55c-fsvh2892-2L-e',
                                                                #board='alveo-u55c',
                                                                hls_config=config,
                                                                io_type='io_tile_stream',
                                                            )
    hls_model.compile()

    # Quantized PyTorch reference for the perturbed configuration.
    qmodel = QTransformerEncoder([QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[0], calibration=False),
                                  QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[1], calibration=False),
                                  QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[2], calibration=False),
                                  QTransformerEncoderLayer(182, 2, 128, quant_config=transformer_quant_config[3], calibration=False)], 
                                 4, 
                                 QLayerNorm(182, quant_config=transformer_quant_config['norm'], calibration=False),
                                 dtype=torch.float64)
    qmodel.transfer_weights(model)

    with torch.no_grad():
        qmodel.eval()
        qy = qmodel(src[0:1].permute(1,0,2), mask=model.src_mask)
        print(qy)
        hls_y = hls_model.predict(src[0:1].numpy())
        print(hls_y)
        # Check that the two outputs are close enough (absolute tolerance 1e-5).
        assert np.allclose(qy.flatten().detach().numpy(), hls_y, atol=1e-5)