In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import brevitas
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat
from typing import Dict, List, Set, Optional, Callable
from collections import OrderedDict
import concrete
from concrete.numpy.compilation import Configuration
from concrete.ml.torch.compile import compile_brevitas_qat_model
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report the installed Concrete-ML version. `concrete.ml` is reachable through
# the bare `concrete` import because the imports cell also imports from
# `concrete.ml.torch.compile`.
print(concrete.ml.__version__)

0.6.1


# FP32 Model

In [3]:
class DenseClassifier(nn.Module):
    """Floating-point dense network with one hidden layer and a single output unit.

    ``hparams`` must provide ``'n_feats'`` (input width of ``dense1``) and
    ``'hidden_dim'`` (width of the hidden layer). Layers are applied in the
    order ``dense1 -> dp1 (dropout, p=0.1) -> act1 (ReLU) -> dense2``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        n_feats, hidden_dim = hparams['n_feats'], hparams['hidden_dim']

        # Hidden block.
        self.dense1 = nn.Linear(n_feats, hidden_dim)
        self.dp1 = nn.Dropout(0.1)
        self.act1 = nn.ReLU()

        # Output head with one unit.
        self.dense2 = nn.Linear(hidden_dim, 1)

    def forward(self, src):
        """Pass ``src`` through the hidden block and return the ``dense2`` output."""
        hidden = self.act1(self.dp1(self.dense1(src)))
        return self.dense2(hidden)

In [4]:
# Hyper-parameters shared by the fp32 and the quantized classifier:
# number of input features and width of the hidden layer.
config = dict(n_feats=12, hidden_dim=32)

In [5]:
# Floating-point reference model; its state_dict is later copied into the quantized model.
model_fp32 = DenseClassifier(config)

# Quantized Model (Brevitas)

In [6]:
class QDenseClassifier(nn.Module):
    """Brevitas version of the dense classifier, built from quantized layers.

    Layers are applied in the order ``dense1 -> dp1 -> act1 -> dense2``.

    Args:
        hparams (dict): Must provide ``'n_feats'`` and ``'hidden_dim'``.
        bits (int): Value passed as ``weight_bit_width`` to both linear layers.
        act_quant: Quantizer handed to the ``QuantReLU`` activation.
        weight_quant: Quantizer handed to both ``QuantLinear`` layers.

    NOTE(review): ``bits`` only reaches the two weight quantizers; the
    activation layer receives nothing but ``act_quant`` -- confirm that the
    activation bit width is meant to be independent of ``bits``.
    """

    def __init__(self,
                 hparams: dict,
                 bits: int,
                 act_quant: brevitas.quant = Int8ActPerTensorFloat,
                 weight_quant: brevitas.quant = Int8WeightPerTensorFloat):
        super().__init__()
        self.hparams = hparams

        # Keyword arguments common to both quantized linear layers.
        linear_kwargs = dict(
            weight_bit_width=bits,
            weight_quant=weight_quant,
            bias=True,
            return_quant_tensor=True,
        )

        self.dense1 = qnn.QuantLinear(hparams['n_feats'], hparams['hidden_dim'], **linear_kwargs)
        self.dp1 = qnn.QuantDropout(0.1)
        self.act1 = qnn.QuantReLU(act_quant=act_quant)

        self.dense2 = qnn.QuantLinear(hparams['hidden_dim'], 1, **linear_kwargs)

    def forward(self, src):
        """Pass ``src`` through the quantized layers and return the ``dense2`` output."""
        hidden = self.act1(self.dp1(self.dense1(src)))
        return self.dense2(hidden)

In [7]:
# Quantized model with 7-bit weights; same layer sizes as the fp32 model.
model_quant = QDenseClassifier(hparams=config, bits=7)

In [8]:
# From concrete-ml VGG notebook
def mapping_keys(pretrained_weights: Dict, model: nn.Module, device: str) -> nn.Module:
    """
    Initialize the quantized model with pre-trained fp32 weights.

    Keys are paired by position, not by name: the i-th entry of
    ``pretrained_weights`` is stored under the i-th key of
    ``model.state_dict()``. Pairing stops at the shorter of the two key
    sequences.

    Args:
        pretrained_weights (Dict): The state_dict of the pre-trained fp32 model.
        model (nn.Module): The Brevitas model.
        device (str): Device type.
    Returns:
        nn.Module: ``model`` after loading the remapped state_dict and moving
        it to ``device``.
    """
    # Brevitas requirement to ignore missing keys
    brevitas.config.IGNORE_MISSING_KEYS = True

    # Re-key every pre-trained tensor with the name found at the same
    # position in the target model's state_dict.
    remapped_state = OrderedDict(
        (target_key, pretrained_weights[source_key])
        for source_key, target_key in zip(
            list(pretrained_weights.keys()), list(model.state_dict().keys())
        )
    )

    model.load_state_dict(remapped_state)
    return model.to(device)

In [9]:
# Compare the first layer's bias of both models before the fp32 weights are
# copied over.
torch.equal(model_quant.dense1.bias, model_fp32.dense1.bias)

False

In [10]:
# Copy the fp32 state_dict into the quantized model (keys are paired by
# position inside mapping_keys) and place the result on the CPU.
model_quant = mapping_keys(model_fp32.state_dict(), model_quant, 'cpu')

In [11]:
# Same comparison after the weight transfer: the first layer's bias of the
# quantized model is checked against the fp32 one.
torch.equal(model_quant.dense1.bias, model_fp32.dense1.bias)

True

In [12]:
# From concrete-ml VGG notebook
def fhe_compatibility(model: Callable, bit: int, data: DataLoader) -> Callable:
    """Test if the model is FHE-compatible.

    Compiles ``model`` with ``compile_brevitas_qat_model`` using the virtual
    library (``use_virtual_lib=True``), writes the exported ONNX graph to
    ``test.onnx`` and clears the cell output afterwards.

    Args:
        model (Callable): The Brevitas model.
        bit (int): Bit of quantization; used for model inputs, op inputs,
            op weights and model outputs alike.
        data (DataLoader): The input set forwarded as ``torch_inputset``.
            NOTE(review): the notebook passes a ``torch.Tensor`` here, not a
            ``DataLoader`` -- confirm which type the annotation should name.
    Returns:
        Callable: Quantized model.
    """
    # `clear_output` is not imported anywhere at notebook level, so the call
    # at the end of this function raised NameError as soon as compilation
    # succeeded. Import it locally to keep the function self-contained.
    from IPython.display import clear_output

    configuration = Configuration(
        dump_artifacts_on_unexpected_failures=False,
        # This is for our tests only, never use that in prod.
        enable_unsafe_features=True,
        # This is for our tests only, never use that in prod.
        use_insecure_key_cache=True,
        insecure_key_cache_location="ConcreteNumpyKeyCache",
        jit=False,
        p_error=None,
        global_p_error=None,
    )

    qmodel = compile_brevitas_qat_model(
        model.to("cpu"),
        # Training
        torch_inputset=data,
        n_bits={"model_inputs": bit, "op_inputs": bit, "op_weights": bit, "model_outputs": bit},
        configuration=configuration,
        show_mlir=False,
        # Set use_virtual_lib to False to use the real FHE execution.
        use_virtual_lib=True,
        # Concrete-ML uses table lookup (TLU) to represent any non-linear operation.
        # This TLU is implemented through the Programmable Bootstrapping (PBS).
        # A single PBS operation has P_ERROR chances of being incorrect.
        # Default value = 6.3342483999973e-05.
        p_error=6.3342483999973e-05,
        output_onnx_file="test.onnx",
    )

    # Drop the compilation logs from the cell output.
    clear_output()

    return qmodel

In [13]:
batch_size = 64
# Draw the random input set from a dedicated, seeded generator so the same
# data is produced on every Restart & Run All, without touching the global
# RNG state.
data_rng = torch.Generator().manual_seed(0)
data = torch.randn((batch_size, config['n_feats']), generator=data_rng)
data.shape

torch.Size([64, 12])

In [14]:
# Compile the quantized model on the random input set, then report the
# largest integer bit-width found in the compiled circuit.
model_quant2 = fhe_compatibility(model_quant, 7, data)
max_bit_width = model_quant2.forward_fhe.graph.maximum_integer_bit_width()
print(f"The maximum bit-width in the circuit = {max_bit_width}")

onnx.brevitas.Quant




UnboundLocalError: local variable 'node_integer_inputs' referenced before assignment