# MLToolbox demonstration
Expected RAM usage: 40 GB.
Expected runtime: 90 minutes.

Measured runtime on an NVIDIA A100-SXM4-40GB GPU with 10 CPUs: 14 minutes.

Note that the notebook can also run on different hardware. When multiple GPUs are available, the notebook automatically uses them to speed up the computation.

## Introduction

This demo notebook focuses on preparing a model for efficient use with Fully Homomorphic Encryption (FHE) using MLToolbox. MLToolbox offers specialized tools that make models FHE-friendly while minimizing performance degradation. Let's delve into the process of making models FHE-friendly and learn how MLToolbox simplifies this task.

Note: 
 -  Currently supported models: AlexNet, LeNet5, Squeezenet, SqueezenetCHET, ResNet18, ResNet50, ResNet101, ResNet152 and CLIP-RN50. To learn how to add a new model, please refer to: `help(nn_module)`.
 -  Supported datasets: CIFAR10, CIFAR100, COVID_CT, COVID_XRAY, Places205, Places365 and ImageNet. To add a new dataset, please see `help(DatasetWrapper)`.


## Table of Contents

* [Step 1. Training the base model](#train)
* [Step 2. Transforming the original model into an intermediate, range-aware form](#intermidate)   
* [Step 3. Transforming the range-aware form into an FHE-Friendly form](#polynomial)   
* [Step 4. Encrypting the trained model and predicting over encrypted data](#encrypt)        


<a id="train"></a>
## Step 1. Training the base model

In this step, we will train the base model using MLToolbox. This base model will serve as the starting point for the gradual conversion process we will undertake later.

Note: If you already have a pre-trained model, you can skip this step. Simply ensure that the model is saved as demonstrated in [the steps below](#save).

### 1.1. Start with some imports:

In [None]:
import os
# Print only error-level log messages
os.environ["LOG_LEVEL"]="ERROR"
import json
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import utils # used for benchmarking
# trange is used to show the progress of the training loops
from tqdm.notebook import trange
# Magic command that renders figures inline in the notebook
%matplotlib inline 
import matplotlib.pyplot as plt
import matplotlib.cm as cm

#FHE related import
import pyhelayers
#mltoolbox imports
from pyhelayers.mltoolbox.arguments import Arguments
import pyhelayers.mltoolbox.utils.util as util
from pyhelayers.mltoolbox.poly_activation_converter import starting_point
from pyhelayers.mltoolbox.data_loader.ds_factory import DSFactory
from pyhelayers.mltoolbox.model.DNN_factory import DNNFactory

### 1.2. Initialize the trainer object.
In the following cell, we initialize the `Arguments` class. This class defines default values for the various training parameters and allows them to be customized.
`model`, `dataset_name`, `num_epochs`, `classes` and `data_dir` are required arguments with no default values.
The `data_dir` argument specifies the dataset location.

To keep the tutorial short, we train for only 25 epochs; training for more epochs (and possibly tuning other hyperparameters) would likely achieve better accuracy.

MLToolbox uses a fixed seed by default to ensure reproducibility. To get different random runs, set `args.seed` to different values. Even so, results may fluctuate slightly when the code runs on a different architecture.
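The effect of a fixed seed can be illustrated with Python's standard `random` module. This is only a sketch of the general principle; MLToolbox handles its own seeding internally, and the `sample_run` helper below is hypothetical:

```python
import random

def sample_run(seed, n=5):
    """Draw n pseudo-random numbers from a generator initialized with `seed`."""
    rng = random.Random(seed)  # independent generator; does not touch the global state
    return [rng.random() for _ in range(n)]

run_a = sample_run(seed=42)
run_b = sample_run(seed=42)  # same seed: identical sequence (reproducible)
run_c = sample_run(seed=7)   # different seed: different sequence

print(run_a == run_b)  # True
print(run_a == run_c)  # False
```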

The `starting_point` function receives the user arguments and returns the `trainer` and `poly_activation_converter` objects, which assist in converting the original model into an FHE-friendly form in the following steps. It also returns a third value (named `epoch` in later cells), which we ignore here.

In [None]:
args = Arguments(model="lenet5", dataset_name="CIFAR10", num_epochs=25, classes=10, data_dir='cifar_data')

# After initializing an `Arguments` object, its settings can be customized
args.lr = 0.001
args.batch_size = 200

# Create the trainer and the activation converter from the arguments (the third return value is not needed here)
trainer, poly_activation_converter, _ = starting_point(args)

In [None]:
model = trainer.get_model()
print(model)

The supported datasets can be observed using the following call:

In [None]:
DSFactory.print_supported_datasets()

The supported models can be observed using the following:

In [None]:
DNNFactory.print_supported_models()

### 1.3. Training the original model

Run the training loop using the trainer:

In [None]:
epochs_range = trange(1, args.num_epochs + 1)

# This list accumulates the validation accuracy of each epoch, so we can plot it later
val_acc = []
scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=2, min_lr=0.000001, verbose=True)
for epoch in epochs_range:
    trainer.train_step(args, epoch, epochs_range)
    # Perform validation; returns the metrics (val_metrics) and the confusion matrix (val_cf)
    val_metrics, val_cf = trainer.validation(args, epoch)
    val_acc.append(val_metrics.get_avg('accuracy'))

    # Reduce the learning rate when the validation loss stops improving
    scheduler.step(val_metrics.get_avg('loss'))
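The scheduler above halves the learning rate whenever the validation loss has not improved for more than `patience=2` consecutive epochs, never going below `min_lr`. The following is a simplified, pure-Python re-implementation of that rule for illustration only: it mimics `ReduceLROnPlateau` in its default `mode='min'` with a relative threshold of `1e-4` and ignores `cooldown`. It is not the PyTorch source, and `plateau_lr_schedule` is a hypothetical helper:

```python
def plateau_lr_schedule(losses, lr=0.001, factor=0.5, patience=2, min_lr=1e-6, threshold=1e-4):
    """Return the learning rate in effect after each epoch, given per-epoch validation losses."""
    best = float("inf")
    num_bad_epochs = 0
    lrs = []
    for loss in losses:
        # A loss counts as an improvement only if it beats the best one by a relative margin
        if loss < best * (1 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # After more than `patience` epochs without improvement, reduce the learning rate
        if num_bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs

# The loss improves for two epochs, then plateaus
print(plateau_lr_schedule([1.0, 0.9, 0.95, 0.95, 0.95, 0.95]))
# [0.001, 0.001, 0.001, 0.001, 0.0005, 0.0005]
```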

Let's print the test metrics:

In [None]:
test_metrics, test_cf = trainer.test(args, epoch)
print({'loss': test_metrics.get_avg('loss'), 'accuracy': test_metrics.get_avg('accuracy')})

We now have a trained model that can be converted to the FHE-friendly form later on. In our run, the test accuracy of this model was 0.65; training for more epochs should improve it. Once the model is ready, we save it to a file for future use:

<a id="save"></a> 
Note: as mentioned earlier, you don't have to use a model trained with MLToolbox. You can use a pre-trained model and skip Step 1 of this notebook; just make sure your model is saved in the supported form `{'model': model}` as shown below, and that its architecture is supported by MLToolbox (or that you have extended MLToolbox to support it):

    state = {'model': model}
    torch.save(state, file_name)

Below we save the model we just trained as a checkpoint. The save location is defined by the `args.save_dir` argument, whose default value is `outputs/mltoolbox`.

In [None]:
util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)

In [None]:
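def plot(val_acc):
    # Derive the epochs from the data, so the function also works after args.num_epochs changes
    epochs = range(1, len(val_acc) + 1)
    plt.plot(epochs, val_acc)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.xticks(np.arange(1, len(val_acc) + 1, step=2))
    plt.grid(True)
    plt.show()

plot(val_acc)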

### 1.4 Replacing max-pooling

FHE does not natively support max pooling, because computing a maximum requires comparisons, which are not polynomial operations. To address this limitation, we replace max pooling with average pooling. There are two ways to do this: 1) train the base model from scratch with average pooling; or 2) convert the pooling layers to average pooling after the model has been trained, and then continue training for a few additional epochs. This notebook uses option #2: set the `pooling_type` argument before running `starting_point` and the next training steps:
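To see why average pooling is FHE-friendly while max pooling is not, compare the two on a small 4×4 input. This is a plain-Python illustration, not MLToolbox code, and `pool2x2` is a hypothetical helper: average pooling needs only additions and a multiplication by a constant, whereas max pooling needs comparisons.

```python
def pool2x2(image, reduce):
    """Apply non-overlapping 2x2 pooling to a 2-D list using the given reduce function."""
    rows, cols = len(image), len(image[0])
    return [
        [reduce([image[r][c], image[r][c + 1], image[r + 1][c], image[r + 1][c + 1]])
         for c in range(0, cols, 2)]
        for r in range(0, rows, 2)
    ]

def average(window):
    # Only additions and one multiplication by the plaintext constant 1/4
    return sum(window) * 0.25

image = [
    [1.0, 2.0, 0.0, -1.0],
    [3.0, 4.0, 2.0, 1.0],
    [-2.0, 0.0, 5.0, 5.0],
    [1.0, 1.0, 5.0, 5.0],
]

print(pool2x2(image, max))      # [[4.0, 2.0], [1.0, 5.0]]  (requires comparisons)
print(pool2x2(image, average))  # [[2.5, 0.5], [0.0, 5.0]]
```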

In [None]:
args.pooling_type = "avg"

<a id="intermidate"></a>
## Step 2. Transforming the original model into an intermediate, range-aware form

In this step, we aim to minimize the input range of the non-polynomial activation layers. Specifically, each ReLU activation layer keeps track of its input range, and a regularization term is added to the loss, penalizing values that fall outside of this range.
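One simple way to express such a regularization term is to penalize only the part of each value that falls outside a target interval `[low, high]`, and add it to the classification loss with a weight. The sketch below assumes a squared out-of-range penalty and uses made-up numbers; the exact term MLToolbox uses may differ, and `range_penalty` and `total_loss` are hypothetical helpers:

```python
import numpy as np

def range_penalty(values, low=-10.0, high=10.0):
    """Mean squared distance of the values from the interval [low, high] (0 inside it)."""
    values = np.asarray(values, dtype=float)
    below = np.clip(low - values, 0.0, None)   # how far each value lies below `low`
    above = np.clip(values - high, 0.0, None)  # how far each value lies above `high`
    return float(np.mean(below ** 2 + above ** 2))

def total_loss(classification_loss, activation_inputs, weight):
    # `weight` plays the role of args.range_awareness_loss_weight
    return classification_loss + weight * range_penalty(activation_inputs)

activation_inputs = [-12.0, -3.0, 0.5, 8.0, 14.0]  # hypothetical ReLU inputs
print(range_penalty(activation_inputs))            # (2**2 + 4**2) / 5 = 4.0
print(total_loss(1.2, activation_inputs, weight=0.01))
```

A larger weight pushes the inputs into the interval more aggressively, at the cost of giving the classification loss relatively less attention.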

There are two primary reasons for performing this step:

- To enable the approximation of activations by polynomials, it is essential to provide the range of inputs to the activation function beforehand.
- Our goal is to optimize the input range so that the activation can be accurately approximated by a low-degree polynomial. We use the interval [-10, 10] as an example, but the actual range depends on the specific model and data. A smaller range allows a more precise approximation, whereas a significantly larger range can degrade the accuracy of the polynomial model.
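The effect of the range on approximation quality can be checked numerically. The sketch below fits a degree-4 polynomial to ReLU by least squares on a uniform grid (an illustrative choice; MLToolbox's approximation method may differ) and reports the maximum error over the fitted interval. Because ReLU satisfies relu(a·x) = a·relu(x) for a > 0, doubling the range roughly doubles the error for a fixed degree:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def max_fit_error(bound, degree=4, num_points=2001):
    """Least-squares fit a polynomial to ReLU on [-bound, bound]; return the max absolute error there."""
    x = np.linspace(-bound, bound, num_points)
    coeffs = np.polyfit(x, relu(x), degree)  # coefficients, highest degree first
    return float(np.max(np.abs(np.polyval(coeffs, x) - relu(x))))

for bound in (5, 10, 20):
    print(bound, max_fit_error(bound))
```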

### 2.1. Define the arguments to represent what we want to do:

In [None]:
args.pooling_type = "avg"
args.activation_type= "relu_range_aware"
args.num_epochs = 12
args.lr=0.001
args.from_checkpoint = "outputs/mltoolbox/lenet5_last_checkpoint.pth.tar"
args.range_awareness_loss_weight=0.01
args.range_aware_train = True
args.save_dir = "outputs/mltoolbox/range_aware"

* `args.pooling_type = "avg"`: All the pooling operations will be replaced by average-pooling (the default is `"max"`, in which case the pooling layers are not replaced).

* `args.activation_type="relu_range_aware"`: Each ReLU activation keeps track of its input range (other options are `'trainable_poly'`, `'non_trainable_poly'`, `'approx_relu'`, `'relu'`, `'weighted_relu'` and `'square'`).

* `args.num_epochs=12`: The number of epochs to train.

* `args.lr=0.001`: The learning rate.

* `args.from_checkpoint = "outputs/mltoolbox/lenet5_last_checkpoint.pth.tar"`: The checkpoint to load the model from.

* `args.range_awareness_loss_weight=0.01`: The weight of the range-awareness term relative to the CrossEntropyLoss, i.e., how much attention training gives to shrinking the ranges. Tune this value for your model and data so that training does not suffer from severe accuracy degradation.

* `args.range_aware_train=True`: Enables range-aware training; set it to `False` to turn it off.

* `args.save_dir = "outputs/mltoolbox/range_aware"`: The directory where the output of this step is saved.

### 2.2. Run starting_point again, with the new arguments

In [None]:
trainer, poly_activation_converter, epoch = starting_point(args)

### 2.3. Transform and train the range-aware model

We train the model for several more epochs. At the beginning of each epoch, the `replace_activations` function replaces whatever needs to be replaced in the current epoch. Then the train step is called. This is the same train step we ran before; the only difference is that the loss function now has an extra term that regularizes the activation inputs, pushing them towards the required range.
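Conceptually, keeping track of an activation's input range amounts to updating a running minimum and maximum over every batch that passes through the layer. The `RangeTracker` class below is a hypothetical, framework-free sketch of that bookkeeping; it is not the MLToolbox implementation behind `trainer.get_all_ranges`, which may track ranges differently:

```python
class RangeTracker:
    """Hypothetical helper: records the min/max of all inputs seen by one activation layer."""

    def __init__(self):
        self.low = float("inf")
        self.high = float("-inf")

    def update(self, batch):
        # Widen the recorded range so that it covers every value in this batch
        self.low = min(self.low, min(batch))
        self.high = max(self.high, max(batch))

    def get_range(self):
        return (self.low, self.high)

tracker = RangeTracker()
for batch in ([0.5, -2.0, 3.0], [7.5, -1.0], [-4.25, 1.0]):  # made-up activation inputs
    tracker.update(batch)
print(tracker.get_range())  # (-4.25, 7.5)
```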

In [None]:
val_acc = []
ranges_train = []
ranges_val = []
scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=2, min_lr=args.min_lr, verbose=True)
epochs_range = trange(1,args.num_epochs + 1)

for epoch in epochs_range:
    poly_activation_converter.replace_activations(trainer, epoch, scheduler)
    trainer.train_step(args, epoch, epochs_range)
    ranges_train.append(trainer.get_all_ranges(args))
    
    val_metrics, val_cf = trainer.validation(args, epoch) # perform validation. Returns metrics (val_metrics) and confusion matrix (val_cf)
    val_acc.append(val_metrics.get_avg('accuracy'))
    ranges_val.append(trainer.get_all_ranges(args))
    
util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)

test_metrics, test_cf = trainer.test(args, epoch)
print({'loss': test_metrics.get_avg('loss'), 'accuracy': test_metrics.get_avg('accuracy')})

In our run, the resulting range-aware model achieved an overall accuracy of 0.67. This is slightly higher than the accuracy of the original model because we did not fully train the original model, leaving room for further improvement in the current step. We set `range_awareness_loss_weight` relatively low in this attempt, since the ranges were not large to begin with. When running your own model, this weight should be tuned carefully.

In [None]:
plot(val_acc)

In [None]:
print(trainer.get_model())

Let's observe the input ranges of the model's activations:

In [None]:
def plot_ranges(ranges):
    # ranges[e][i] holds the (min, max) input range of activation i recorded in epoch e
    # Skip the leading epochs in which no ranges were recorded
    start_epoch = next((i for i, epoch_ranges in enumerate(ranges) if epoch_ranges), None)
    if start_epoch is None:
        print("No ranges data found.")
        return

    num_items = len(ranges[-1])
    # x-axis: the 1-based epoch numbers of the plotted epochs
    epochs = range(start_epoch + 1, len(ranges) + 1)
    # plt.get_cmap replaces cm.get_cmap, which was removed in recent matplotlib versions
    color_map = plt.get_cmap('tab10')

    for item_index in range(num_items):
        min_values = [epoch_ranges[item_index][0] for epoch_ranges in ranges[start_epoch:]]
        max_values = [epoch_ranges[item_index][1] for epoch_ranges in ranges[start_epoch:]]
        # Draw the min and max curves of the same activation in the same color
        color = color_map(item_index % color_map.N)
        plt.plot(epochs, min_values, label=f'Min - {item_index + 1}', color=color)
        plt.plot(epochs, max_values, label=f'Max - {item_index + 1}', color=color)

    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Minimum and Maximum Values over Epochs')
    plt.show()

plot_ranges(ranges_train)

In [None]:
plot_ranges(ranges_val)

The ranges stay within roughly [-10, 10], so there is no need to shrink them further. If we wanted to, we could increase `args.range_awareness_loss_weight` to give more attention to shrinking the ranges, or train this range-aware step for more epochs to allow a slower, more gradual change.
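To check this numerically rather than visually, the ranges recorded in the last epoch can be reduced to a single bound. The helpers below are hypothetical and assume that each entry of `ranges_val[-1]` is a `(min, max)` pair per activation, as `plot_ranges` assumes:

```python
def widest_bound(layer_ranges):
    """Return the smallest B such that every (min, max) pair fits inside [-B, B]."""
    return max(max(abs(low), abs(high)) for low, high in layer_ranges)

def fits_in(layer_ranges, bound=10.0):
    # True if every activation's observed input range lies within [-bound, bound]
    return widest_bound(layer_ranges) <= bound

example_ranges = [(-6.2, 9.1), (-3.0, 4.5), (-9.8, 2.0)]  # made-up values
print(widest_bound(example_ranges))  # 9.8
print(fits_in(example_ranges))       # True
```

With the notebook's data, the assumed usage would be `widest_bound(ranges_val[-1])`.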

<a id="polynomial"></a>
## Step 3. Transforming the range-aware form into an FHE-Friendly form

The ReLU activations are replaced according to the arguments we've defined.
We start with a partially transformed model: the pooling type is already average, and some batch normalization layers may have been added.
Each ReLU is replaced by a non-trainable polynomial that approximates ReLU over the input range recorded for it. The remaining epochs are used to improve the model with no additional changes.

At the beginning of each epoch, the `replace_activations` function replaces whatever needs to be replaced in the current epoch. Then the train step is called. This is the same train step we ran before; the loss function again includes an extra term that regularizes the activation inputs to stay within the required range.

After the training loop has completed, we save the resulting model.
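Once its coefficients are fixed, a polynomial activation can be evaluated using only additions and multiplications, the operations FHE schemes support natively. A common way to evaluate a polynomial is Horner's rule, sketched below in plain Python. The coefficients are illustrative placeholders, not the ones MLToolbox computes:

```python
def horner(coeffs, x):
    """Evaluate c0 + c1*x + ... + cd*x**d with d multiplications (coefficients low to high)."""
    result = 0.0
    for c in reversed(coeffs):
        result = result * x + c
    return result

# Hypothetical ReLU-like polynomial on [-10, 10]: 0.1 + 0.5*x + 0.05*x**2
coeffs = [0.1, 0.5, 0.05]
print([horner(coeffs, x) for x in (-10.0, 0.0, 10.0)])
```

Because the coefficients are fixed (non-trainable), the same polynomial is applied at every training step, and only the network weights adapt to it.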

In [None]:
args.pooling_type = "avg"
args.activation_type= "non_trainable_poly"
args.num_epochs = 20 # 10 epochs may be enough
args.lr=0.001
args.from_checkpoint = "outputs/mltoolbox/range_aware/lenet5_best_checkpoint.pth.tar"
args.range_awareness_loss_weight=0.1
args.range_aware_train = True
args.save_dir = "outputs/mltoolbox/polynomial"

trainer, poly_activation_converter, epoch = starting_point(args)

In [None]:
val_acc = []
ranges_train = []
ranges_val = []
scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=2, min_lr=args.min_lr, verbose=True)
epochs_range = trange(1,args.num_epochs + 1)

for epoch in epochs_range:
    poly_activation_converter.replace_activations(trainer, epoch, scheduler)
    trainer.train_step(args, epoch, epochs_range)
    ranges_train.append(trainer.get_all_ranges(args))
    
    val_metrics, val_cf = trainer.validation(args, epoch) # perform validation. Returns metrics (val_metrics) and confusion matrix (val_cf)
    val_acc.append(val_metrics.get_avg('accuracy'))
    ranges_val.append(trainer.get_all_ranges(args))
    
util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)

In [None]:
plot(val_acc)

In our run, the resulting accuracy was 0.66, which is above the accuracy of the base model we started with (0.65).

In [None]:
plot_ranges(ranges_train)

In [None]:
plot_ranges(ranges_val)

### 3.1. Convert the model into ONNX format
The `fhe_best` checkpoint saved by `util.save_model` is loaded, and the model is converted into ONNX format (some format changes are applied to the model during the conversion).

In [None]:
path, model = util.save_onnx(args, poly_activation_converter, trainer)

In [None]:
print(model)

### Results

The below table summarizes the results achieved in this notebook:

| Technique    | Accuracy   | 
|:-------------|:-----------|
| **ReLU**     | **0.66**       |
| range-aware polynomial  | 0.66     |

*With more training, we were able to achieve an accuracy above 0.9 for both the ReLU and the polynomial models.

<a id="encrypt"></a>
## Step 4. Encrypting the model and predicting over encrypted data
We now encrypt the FHE-friendly model trained above and use the encrypted model to make predictions on encrypted data. We'll also make predictions using the plain (unencrypted) model, and compare the two sets of results.

### 4.1 Extract a batch from the validation set
First, we extract a batch from the validation set. This is the data on which we'll run and compare the encrypted and plain models. We set the batch size to a smaller value so that the encrypted prediction runs faster.

In [None]:
args.batch_size = 10
trainer, poly_activation_converter, epoch = starting_point(args)

plain_samples, labels = next(iter(trainer.val_generator))
batch_size = len(plain_samples)

### 4.2 Perform prediction using the plain model trained above
We load the best saved checkpoint of the plain model and use it to run predictions over the batch extracted above.
The resulting labels are computed as the argmax of the model outputs.

In [None]:
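# Load the checkpoint on the CPU; weights_only=False is required on recent PyTorch versions,
# because the checkpoint stores a full pickled model rather than only a state dict
checkpoint = torch.load('outputs/mltoolbox/polynomial/lenet5_best_checkpoint.pth.tar', map_location='cpu', weights_only=False)
model = checkpoint['model']
model.eval()  # inference mode, e.g. batch normalization uses its running statistics
with torch.no_grad():
    plain_model_predictions = model(plain_samples).numpy()
plain_predicted_labels = np.argmax(plain_model_predictions, 1)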

### 4.3 Load NN architecture and weights using the FHE library

We use the `init_from_files` function of the `NeuralNetPlain` class to load the NN model from the ONNX file saved in Step 3.

In [None]:
model_path = 'outputs/mltoolbox/polynomial/lenet5.onnx'
nnp = pyhelayers.NeuralNetPlain()
hyper_params = pyhelayers.PlainModelHyperParams()
nnp.init_from_files(hyper_params,[model_path])

### 4.4 Compile

Using HE can require configuring complex and non-intuitive parameters. Luckily, helayers has an `Optimizer` tool that offers an automatic optimization process: it analyzes the model and tunes various HE parameters to work best for the given scenario.

The optimizer runs when we call `compile` on the plain model. It also receives run requirements: simple and intuitive input from the user (e.g., the desired security level). As output, it produces a `profile` object that contains all the details related to the HE configuration and packing. These details are selected automatically to optimize performance given the user's requirements.

In [None]:
he_run_req = pyhelayers.HeRunRequirements()
# Use the HEaaN context as the underlying FHE
he_run_req.set_he_context_options([pyhelayers.HeaanContext()])
# The encryption is at least as strong as 128-bit encryption.
he_run_req.set_security_level(128)
# Our numbers are theoretically stored with a precision of about 2^-40.
he_run_req.set_fractional_part_precision(40)
# Optimize for the batch size of the samples extracted above
he_run_req.optimize_for_batch_size(batch_size)
# The model weights are kept unencrypted (in the plain)
he_run_req.set_model_encrypted(False)

# Compile - run the optimizer
profile = pyhelayers.HeModel.compile(nnp, he_run_req)

profile_as_json = profile.to_string()
# Profile supports I/O operations and can be stored on file.
print(json.dumps(json.loads(profile_as_json), indent=4))
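As a rough intuition for `set_fractional_part_precision(40)`, consider a fixed-point encoding that scales each real number by 2^40 and rounds it to an integer: ignoring the encryption noise that the scheme adds on top, the rounding error is then at most 2^-41. The sketch below assumes this correspondence and shows only the rounding step; it is not how helayers encodes data internally, and `encode`/`decode` are hypothetical helpers:

```python
import math

SCALE = 2 ** 40  # assumed scale, matching 40 bits of fractional precision

def encode(x):
    # Scale the real number and round it to the nearest integer
    return round(x * SCALE)

def decode(n):
    return n / SCALE

for x in (math.pi, -0.123456789, 1e-9):
    err = abs(decode(encode(x)) - x)
    print(f"{x!r}: rounding error {err:.2e} (bound {0.5 / SCALE:.2e})")
```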

### 4.5 Initialize the context, and encrypt the NN
Now we initialize the context object and encrypt the neural network using our profile object.

In [None]:
context=pyhelayers.HeModel.create_context(profile)
nn = pyhelayers.NeuralNet(context)
nn.encode(nnp, profile)

### 4.6 Encrypt the input samples
Here, we encrypt the samples on which we'll run inference. The data is encrypted by the `iop` (input/output processor) object, which contains only the model metadata and can process the model's inputs and outputs.

In [None]:
iop=nn.create_io_processor()
encrypted_samples = pyhelayers.EncryptedData(context)
iop.encode_encrypt_inputs_for_predict(encrypted_samples, [plain_samples])

### 4.7 Run prediction over encrypted data, using the encrypted model

In [None]:
encrypted_predictions = pyhelayers.EncryptedData(context)
with utils.elapsed_timer('predict', batch_size) as timer:
    nn.predict(encrypted_predictions, encrypted_samples)

### 4.8 Decrypt the prediction result
The final labels are computed as the argmax of the 10 predicted class scores.

In [None]:
fhe_model_predictions = iop.decrypt_decode_output(encrypted_predictions)
fhe_predicted_labels = np.argmax(fhe_model_predictions, 1)

### 4.9 Compare the predictions of the encrypted model with the predictions of the plain model
The assertion below checks that the labels predicted by the FHE model match those produced by the plain model.

In [None]:
print('labels predicted by the FHE model: ', fhe_predicted_labels)
print('labels predicted by the plain model: ', plain_predicted_labels)
np.testing.assert_array_equal(fhe_predicted_labels, plain_predicted_labels)
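Because CKKS-based schemes such as HEaaN compute on approximate numbers, the decrypted scores differ slightly from the plain ones, and in rare cases two nearly tied classes can swap order. A more forgiving check, sketched below with NumPy on made-up scores, reports the label agreement rate and the largest score difference instead of requiring an exact match. The `compare_predictions` helper is hypothetical:

```python
import numpy as np

def compare_predictions(fhe_scores, plain_scores):
    """Return (fraction of matching argmax labels, max absolute score difference)."""
    fhe_scores = np.asarray(fhe_scores, dtype=float)
    plain_scores = np.asarray(plain_scores, dtype=float)
    agreement = float(np.mean(np.argmax(fhe_scores, 1) == np.argmax(plain_scores, 1)))
    max_diff = float(np.max(np.abs(fhe_scores - plain_scores)))
    return agreement, max_diff

# Made-up scores for 3 samples and 3 classes; the third sample has two nearly tied classes
plain = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.00001, -2.0]]
fhe = [[0.1000002, 1.9999997, -1.0000001], [1.5000001, 0.2, 0.2999999], [0.000005, 0.000003, -2.0]]
print(compare_predictions(fhe, plain))  # agreement 2/3, max difference about 7e-6
```

In this notebook it could be called as `compare_predictions(fhe_model_predictions, plain_model_predictions)`.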

## References:


Moran Baruch, Nir Drucker, Gilad Ezov, Eyal Kushnir, Jenny Lerner, Omri Soceanu, Itamar Zimerman. "Sensitive Tuning of Large Scale CNNs for E2E Secure Prediction using Homomorphic Encryption" (2023)
https://arxiv.org/abs/2304.14836