In [8]:
from brevitas.inject.enum import *
from brevitas.core.bit_width import BitWidthImplType
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import FloatToIntImplType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.quant.solver import WeightQuantSolver, ActQuantSolver
from brevitas.quant.base import *
from brevitas.core.function_wrapper.ops_ste import CeilSte

In [9]:
class Int8WeightPerTensorFloatScratch(WeightQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.STATS # scale based on statistics
    scaling_stats_op = StatsOp.MAX # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.FP # scale factor is a floating point value
    scaling_per_output_channel = False # scale is per tensor
    bit_width = 8 # bit width is 8
    signed = True # quantization range is signed
    narrow_range = True # quantization range is [-127,127] rather than [-128, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.
    
class Int8ActPerTensorFloatScratch(ActQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS # scale is a parameter initialized from statistics
    scaling_stats_op = StatsOp.PERCENTILE # scale statistics is a percentile of the abs value
    high_percentile_q = 99.999 # percentile is 99.999
    collect_stats_steps = 300  # statistics are collected for 300 forward steps before switching to a learned parameter
    restrict_scaling_type = RestrictValueType.FP # scale is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True # quantization range is signed
    narrow_range = False # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.

In [12]:
import torch
from brevitas.nn import QuantLinear, QuantIdentity
quant_linear = QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFloatScratch ,bias=True)
quant_identity = QuantIdentity(act_quant = Int8ActPerTensorFloatScratch, return_quant_tensor=True)
print(quant_identity)



QuantIdentity(
  (input_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
  )
  (act_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
    (fused_activation_quant_proxy): FusedActivationQuantProxy(
      (activation_impl): Identity()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClamp()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): ParameterFromRuntimeStatsScaling(
          (stats_input_view_shape_impl): OverTensorView()
          (stats): _Stats(
            (stats_impl): AbsPercentile()
          )
          (restrict_scaling): _RestrictValue(
            (restrict_value_impl): FloatRestrictValue()
          )
          (clamp_scaling): _ClampValue(
            (clamp_min_ste): Identity()
          )
          (restrict_inplace_preprocess): Identit

In [4]:
import torch

# torch random array
arr = torch.rand(1,5)
weight_quant = quant_linear.weight_quant
tensor_quant = weight_quant.tensor_quant # instance RescalingIntQuant
scaling_impl = quant_linear.weight_quant.tensor_quant.scaling_impl # threshold, StatsFromParameterScaling
int_scaling_impl = quant_linear.weight_quant.tensor_quant.int_scaling_impl # int threshold

first_tracked_param = scaling_impl.parameter_list_stats.first_tracked_param

# random torch tensor

print("Original tensor: ", arr)
print("Weight Quant: ", weight_quant(arr))
print("Tensor Quant: ", tensor_quant(arr))
threshold = scaling_impl(arr)
print("Scaling: ", threshold)
int_threshold = int_scaling_impl(arr)
print("Int Scaling: ", int_scaling_impl(arr))
print("Scale Factor: ", threshold/int_threshold)

Original tensor:  tensor([[0.1168, 0.4111, 0.9212, 0.8079, 0.9039]])
Weight Quant:  QuantTensor(value=tensor([[0.1155, 0.4095, 0.6667, 0.6667, 0.6667]], grad_fn=<MulBackward0>), scale=tensor(0.0052, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))
Tensor Quant:  (tensor([[0.1155, 0.4095, 0.6667, 0.6667, 0.6667]], grad_fn=<MulBackward0>), tensor(0.0052, grad_fn=<DivBackward0>), tensor(0.), tensor(8.))
Scaling:  tensor(0.6667, grad_fn=<ViewBackward0>)
Int Scaling:  tensor([[-0.4578, -0.3352, -0.0532, -0.1247, -0.0644]])
Scale Factor:  tensor([[ -1.4562,  -1.9892, -12.5399,  -5.3476, -10.3496]],
       grad_fn=<DivBackward0>)


In [6]:
quant_linear
print("first_tracked_param: ", first_tracked_param())

first_tracked_param:  tensor([ 0.4354,  0.3994,  0.1296, -0.1740,  0.5606,  0.6667, -0.6104, -0.1527],
       grad_fn=<ViewBackward0>)


In [7]:
scaling_impl.tracked_parameter_list

[Parameter containing:
 tensor([[ 0.4354,  0.3994],
         [ 0.1296, -0.1740],
         [ 0.5606,  0.6667],
         [-0.6104, -0.1527]], requires_grad=True)]

# QuantIdentity

In [13]:
input_arr = torch.rand(4,4)
print(input_arr)
print(quant_identity(input_arr))

tensor([[0.0574, 0.4205, 0.8934, 0.9537],
        [0.2962, 0.5842, 0.3586, 0.7235],
        [0.0826, 0.7101, 0.4055, 0.6815],
        [0.7516, 0.0194, 0.9429, 0.2446]])
QuantTensor(value=tensor([[0.0596, 0.4172, 0.8940, 0.9462],
        [0.2980, 0.5811, 0.3576, 0.7227],
        [0.0820, 0.7078, 0.4023, 0.6780],
        [0.7525, 0.0224, 0.9462, 0.2459]], grad_fn=<MulBackward0>), scale=tensor(0.0075, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))


In [19]:
from brevitas.core.scaling.standalone import ConstScaling 

from brevitas.core.restrict_val import PowerOfTwoRestrictValue
int_scaling_impl = ConstScaling(3.0)
pot_scaling_impl = ConstScaling(0.5, restrict_scaling_impl=PowerOfTwoRestrictValue())
print(pot_scaling_impl(torch.empty(1)))
print(int_scaling_impl(torch.empty(1)))

tensor(0.5000)
tensor(3.)


In [20]:
from brevitas.core.restrict_val import PowerOfTwoRestrictValue
import torch
arr = torch.ones(5)
restrict_scaling_impl = PowerOfTwoRestrictValue()
print(restrict_scaling_impl.restrict_init_tensor(arr))
print(restrict_scaling_impl.restrict_init_tensor(arr*2))

print(restrict_scaling_impl.restrict_init_tensor(arr*3))

print(restrict_scaling_impl.restrict_init_tensor(arr*4))


tensor([0., 0., 0., 0., 0.])
tensor([1., 1., 1., 1., 1.])
tensor([1.5850, 1.5850, 1.5850, 1.5850, 1.5850])
tensor([2., 2., 2., 2., 2.])


# QuantIdentity Inference

DependencyError: 'Int8ActPerTensorFloatInference' can not resolve attribute 'max_val' while building 'scaling_init_impl'

In [23]:
class Int8ActPerTensorFloatInference(ActQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS # scale is a parameter initialized from statistics
    scaling_stats_op = StatsOp.PERCENTILE # scale statistics is a percentile of the abs value
    high_percentile_q = 99.999 # percentile is 99.999
    collect_stats_steps = 300  # statistics are collected for 300 forward steps before switching to a learned parameter
    restrict_scaling_type = RestrictValueType.FP # scale is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True # quantization range is signed
    narrow_range = False # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.

quant_identity_inference = QuantIdentity(act_quant=Int8ActPerTensorFloatInference, return_quant_tensor=True)

In [None]:
arr = torch.rand()