In [2]:
import torch
import torch.nn as nn
import numpy as np
from brevitas.nn import QuantIdentity, QuantLinear, QuantReLU
from brevitas.quant import Int8WeightPerTensorFloat, Int8Bias
from brevitas.quant_tensor import QuantTensor
import brevitas

# Print Brevitas version
print(f"Brevitas version: {brevitas.__version__}")

# Constants
N_INPUT = 2
N_HIDDEN_1 = 64
N_HIDDEN_2 = 80
SCALE = 1/32000.

def quantize_input_tensor(in_float, scale, bit_width=16, zero_point=0.0, training=False, device=torch.device('cpu')):
    int_value = in_float / scale
    quant_value = (int_value - zero_point) * scale
    quant_tensor_input = QuantTensor(
        quant_value,
        scale=torch.tensor(scale),
        zero_point=torch.tensor(zero_point),
        bit_width=torch.tensor(float(bit_width)),
        signed=True,
        training=training).to(device)
    return quant_tensor_input

class Digital_twin(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes, scale, device):
        super(Digital_twin, self).__init__()
        self.fc1 = QuantLinear(obs_dim, hidden_sizes[0], bias=True, 
                               weight_quant=Int8WeightPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.fc1.cache_inference_quant_bias=True
        self.relu1 = QuantReLU(act_quant=None, return_quant_tensor=True)
        self.fc2 = QuantLinear(hidden_sizes[0], hidden_sizes[1], bias=True, 
                               weight_quant=Int8WeightPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.fc2.cache_inference_quant_bias=True
        self.relu2 = QuantReLU(act_quant=None, return_quant_tensor=True)
        self.fc3 = QuantLinear(hidden_sizes[1], 2 * act_dim, bias=True, 
                               weight_quant=Int8WeightPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.fc3.cache_inference_quant_bias=True
        self.device = device
        self.scale = scale
        
    def forward(self, obs):
        obs = quantize_input_tensor(obs, self.scale, training=self.training, device=self.device)
        net_out = self.fc1(obs)
        net_out = self.relu1(net_out)
        net_out = self.fc2(net_out)
        net_out = self.relu2(net_out)
        net_out = self.fc3(net_out)
        return net_out

def create_input_tensor(dim, scale, device):
    random_numbers = np.random.uniform(low=10., high=16000., size=(1, dim)).astype(np.int16)    
    tensor = torch.tensor(random_numbers, dtype=torch.float32, device=device)
    return tensor * scale

# Software-only forward pass
def software_forward(X, W1, b1, W2, b2, W3, b3):
    def relu(x):
        return np.maximum(x, 0)
    
    A1 = relu(np.dot(W1, X.T) + b1)
    A2 = relu(np.dot(W2, A1) + b2)
    Y = np.dot(W3, A2) + b3
    return Y.T

# Main execution
if __name__ == "__main__":
    device = torch.device('cpu')
    hidden_sizes = [N_HIDDEN_1, N_HIDDEN_2]
    model = Digital_twin(N_INPUT, 1, hidden_sizes, SCALE, device).to(device)
    
    # Set model to evaluation mode
    model.eval()
    
    # Create input tensor
    in_float = create_input_tensor(N_INPUT, SCALE, device)
    
    print("Input tensor:")
    print(in_float)
    
    # Forward pass
    with torch.no_grad():
        out_brevitas = model(in_float)
    
    print("\nModel output (before int conversion):")
    print(out_brevitas.value)
    
    # Int32 conversion (will overflow)
    out_brevitas_int32 = out_brevitas.int()
    print("\nModel output (after int32 conversion):")
    print(out_brevitas_int32)
    
    # Int64 conversion
    out_brevitas_int64 = (out_brevitas.value / out_brevitas.scale[0,0]).to(torch.int64)
    print("\nModel output (after int64 conversion):")
    print(out_brevitas_int64)
    
    # Software-only forward pass
    W1 = model.fc1.quant_weight().int().cpu().numpy().astype(np.int64)
    b1 = model.fc1.quant_bias().int().cpu().numpy().astype(np.int64).reshape(-1, 1)
    W2 = model.fc2.quant_weight().int().cpu().numpy().astype(np.int64)
    b2 = model.fc2.quant_bias().int().cpu().numpy().astype(np.int64).reshape(-1, 1)
    W3 = model.fc3.quant_weight().int().cpu().numpy().astype(np.int64)
    b3 = model.fc3.quant_bias().int().cpu().numpy().astype(np.int64).reshape(-1, 1)
    
    X_sw = (in_float / SCALE).cpu().numpy().astype(np.int64)
    Y_sw = software_forward(X_sw, W1, b1, W2, b2, W3, b3)
    
    print("\nSoftware-only output:")
    print(Y_sw)
    
    # Compare results with int64
    print("\nDifference between int64 Brevitas and software-only:")
    diff_int64 = out_brevitas_int64.cpu().numpy() - Y_sw
    print(diff_int64)
    print(f"Max absolute difference (int64): {np.abs(diff_int64).max()}")
    
    # Compare results with int32
    print("\nDifference between int32 Brevitas and software-only:")
    diff_int32 = out_brevitas_int32.cpu().numpy() - Y_sw
    print(diff_int32)
    print(f"Max absolute difference (int32): {np.abs(diff_int32).max()}")

Brevitas version: 0.10.3
Input tensor:
tensor([[0.1538, 0.2910]])

Model output (before int conversion):
tensor([[-0.0076, -0.0328]])

Model output (after int32 conversion):
tensor([[-2147483648, -2147483648]], dtype=torch.int32)

Model output (after int64 conversion):
tensor([[ -50867355648, -221071736832]])

Software-only output:
[[ -50867339837 -221071725940]]

Difference between int64 Brevitas and software-only:
[[-15811 -10892]]
Max absolute difference (int64): 15811

Difference between int32 Brevitas and software-only:
[[ 48719856189 218924242292]]
Max absolute difference (int32): 218924242292
