In [None]:
import numpy as np
import tensorflow as tf

# %matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as colors 

import ipysh
import Hunch_utils  as Htls
import Hunch_lsplot as Hplt
import Hunch_tSNEplot as Hsne

%aimport Dataset_QSH

In [50]:
import torch
import torch.utils.data
from torch import nn, optim
import torch.nn.functional as F
import brevitas.nn as qnn

# Select the compute device: the first CUDA device when CUDA is usable,
# otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [37]:
from brevitas.core.bit_width import BitWidthImplType
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp
from brevitas.nn import QuantConv2d, QuantHardTanh, QuantLinear
# Quant common
# Settings shared by the layer/activation factory helpers defined in this cell.
BIT_WIDTH_IMPL_TYPE = BitWidthImplType.CONST  # passed as bit_width_impl_type / weight_bit_width_impl_type
SCALING_VALUE_TYPE = RestrictValueType.LOG_FP  # passed as restrict_scaling_type to QuantHardTanh
SCALING_IMPL_TYPE = ScalingImplType.PARAMETER  # NOTE(review): not referenced by any helper visible in this notebook
NARROW_RANGE_ENABLED = True  # passed as narrow_range to QuantHardTanh

# Weight quant common
STATS_OP = StatsOp.MEAN_LEARN_SIGMA_STD  # NOTE(review): not referenced in the visible cells; helpers receive stats_op as an argument
BIAS_ENABLED = False  # passed as bias to QuantLinear
WEIGHT_SCALING_IMPL_TYPE = ScalingImplType.STATS  # NOTE(review): not referenced in the visible cells
SIGMA = 0.001  # passed as weight_scaling_stats_sigma to QuantLinear

# QuantHardTanh configuration
HARD_TANH_MIN = -1.0  # passed as min_val
HARD_TANH_MAX = 1.0  # passed as max_val
ACT_PER_OUT_CH_SCALING = False  # passed as scaling_per_channel

def get_stats_op(quant_type):
    """Map a quantization type to the StatsOp member used with it.

    Returns ``StatsOp.AVE`` when ``quant_type`` equals ``QuantType.BINARY``
    and ``StatsOp.MAX`` for every other value.
    """
    return StatsOp.AVE if quant_type == QuantType.BINARY else StatsOp.MAX


def get_quant_type(bit_width):
    """Choose the QuantType member that corresponds to a bit width.

    Returns ``QuantType.FP`` when ``bit_width`` is ``None``,
    ``QuantType.BINARY`` when it equals 1, and ``QuantType.INT`` otherwise.
    """
    # None is checked by identity first so it never reaches the == comparison.
    if bit_width is None:
        return QuantType.FP
    return QuantType.BINARY if bit_width == 1 else QuantType.INT


def get_act_quant(act_bit_width, act_quant_type):
    """Build a QuantHardTanh activation configured from the module constants.

    Args:
        act_bit_width: value forwarded as ``bit_width``.
        act_quant_type: value forwarded as ``quant_type``; it also selects the
            scaling implementation.

    Returns:
        The ``QuantHardTanh`` instance.
    """
    # Only QuantType.INT gets PARAMETER scaling; every other type uses CONST.
    scaling_impl = (ScalingImplType.PARAMETER
                    if act_quant_type == QuantType.INT
                    else ScalingImplType.CONST)
    hard_tanh_kwargs = dict(
        quant_type=act_quant_type,
        bit_width=act_bit_width,
        bit_width_impl_type=BIT_WIDTH_IMPL_TYPE,
        min_val=HARD_TANH_MIN,
        max_val=HARD_TANH_MAX,
        scaling_impl_type=scaling_impl,
        restrict_scaling_type=SCALING_VALUE_TYPE,
        scaling_per_channel=ACT_PER_OUT_CH_SCALING,
        narrow_range=NARROW_RANGE_ENABLED,
    )
    return QuantHardTanh(**hard_tanh_kwargs)


def get_quant_linear(in_features, out_features, per_out_ch_scaling, bit_width, quant_type, stats_op):
    """Build a QuantLinear layer configured from the arguments and module constants.

    Args:
        in_features: forwarded as ``in_features``.
        out_features: forwarded as ``out_features``.
        per_out_ch_scaling: forwarded as ``weight_scaling_per_output_channel``.
        bit_width: forwarded as ``weight_bit_width``.
        quant_type: forwarded as ``weight_quant_type``.
        stats_op: forwarded as ``weight_scaling_stats_op``.

    Returns:
        The ``QuantLinear`` instance; bias, bit-width implementation and stats
        sigma come from ``BIAS_ENABLED``, ``BIT_WIDTH_IMPL_TYPE`` and ``SIGMA``.
    """
    linear_kwargs = dict(
        bias=BIAS_ENABLED,
        in_features=in_features,
        out_features=out_features,
        weight_quant_type=quant_type,
        weight_bit_width=bit_width,
        weight_bit_width_impl_type=BIT_WIDTH_IMPL_TYPE,
        weight_scaling_per_output_channel=per_out_ch_scaling,
        weight_scaling_stats_op=stats_op,
        weight_scaling_stats_sigma=SIGMA,
    )
    return QuantLinear(**linear_kwargs)




In [46]:
from brevitas.core.quant import QuantType
from brevitas.core.quant import RescalingIntQuant

# Model-level configuration read by the TestModel cell.
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True  # passed as per_out_ch_scaling when building fc1
LAST_FC_PER_OUT_CH_SCALING = False  # NOTE(review): not referenced in the visible cells
IN_DROPOUT = 0.2  # NOTE(review): not referenced in the visible cells
HIDDEN_DROPOUT = 0.2  # NOTE(review): not referenced in the visible cells


# All three quant types are derived from a bit width of 1, i.e. QuantType.BINARY.
weight_quant_type = get_quant_type(1)
act_quant_type = get_quant_type(1)  # NOTE(review): not referenced in the visible cells
in_quant_type = get_quant_type(1)  # NOTE(review): not referenced in the visible cells
# For QuantType.BINARY get_stats_op returns StatsOp.AVE.
stats_op = get_stats_op(weight_quant_type)
# Single quantized linear layer; with the default bit_width=1 the weights are binary.
class TestModel(torch.nn.Module):
    """Minimal model made of one quantized fully connected layer.

    The weight quantization type and the scaling stats op are both derived
    from ``bit_width`` (via ``get_quant_type`` / ``get_stats_op``) so the
    three settings can never disagree with each other.

    Args:
        in_features: input width of the layer (default 30).
        out_features: output width of the layer (default 30).
        bit_width: weight bit width (default 1); ``None`` selects the
            quant type that ``get_quant_type`` returns for ``None``.
    """

    def __init__(self, in_features=30, out_features=30, bit_width=1):
        super().__init__()
        # One source of truth: quant type and stats op follow the bit width.
        quant_type = get_quant_type(bit_width)
        self.fc1 = get_quant_linear(in_features=in_features,
                                    out_features=out_features,
                                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
                                    bit_width=bit_width,
                                    quant_type=quant_type,
                                    stats_op=get_stats_op(quant_type))

    def forward(self, x):
        """Return the output of the single quantized linear layer for ``x``."""
        return self.fc1(x)

In [21]:
# Create the dataset object and load it from the .npy file located in the
# ipysh build directory.
qsh = Dataset_QSH.Dataset_QSH()
file = ipysh.abs_builddir+'/te_db_r15_clean_shuffle.npy'
qsh.load(file)

# Configure the loaded dataset before a torch dataset is derived from it.
# NOTE(review): the meaning of `dim`, `set_null` and `set_normal_positive` is
# defined in Dataset_QSH, which is not visible here — confirm there.
qsh.dim = 15
qsh.set_null(np.nan)
qsh.set_normal_positive()


In [22]:
# Keyword arguments handed to qsh.get_torch_dataset.
params = dict(batch_size=1, shuffle=True, num_workers=6)
ds = qsh.get_torch_dataset(**params)

In [47]:
# Smoke test: push one random 1x30 input through a freshly built TestModel and
# display the result. No seed is set in the visible cells, so the displayed
# values are expected to differ from run to run.
m = TestModel()
x = torch.randn(1, 30, requires_grad=True)
X = m(x)
X

tensor([[ 0.5300, -0.1324,  0.0214, -0.3962, -0.5261,  0.4405, -0.2946,  0.1931,
         -0.1001, -0.1119,  0.2916, -1.3516, -0.1463,  0.5977, -0.3714, -0.6793,
          0.3242,  0.0601, -0.0283,  0.1863, -0.2546,  0.4349,  0.5386, -0.0071,
          0.2733, -0.3356,  0.2590,  0.4032, -0.0746,  0.0745]],
       grad_fn=<MmBackward>)

In [52]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the summed binary cross-entropy plus the KL term.

    Args:
        recon_x: reconstruction with values in [0, 1]; its last dimension is
            the feature width that ``x`` is reshaped to.
        x: target tensor; viewed as ``(-1, recon_x.size(-1))`` so it lines up
            with ``recon_x`` (previously hard-coded to 784 columns).
        mu: tensor used in the KL term.
        logvar: tensor used in the KL term (same shape as ``mu``).

    Returns:
        Scalar tensor ``BCE + KLD`` where
        ``KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    # Take the feature width from the reconstruction instead of assuming 784,
    # so the loss works for any input size; 784-wide inputs behave as before.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [48]:
import brevitas.onnx as bo
# Export the model `m` in FINN-ONNX form to /tmp/test_bo.onnx; the (1, 30)
# input shape matches the 30 input features of TestModel.fc1.
bo.export_finn_onnx(m,(1,30),'/tmp/test_bo.onnx')

