# MLToolbox Demonstration
Expected RAM usage: 40 GB.
Expected runtime: 20 minutes (with `debug_mode=True`, as explained below).

Expected runtime on an NVIDIA A100-SXM4-40GB with 10 CPUs: 80 minutes (with `debug_mode=False`).

Note that the notebook can run on different hardware platforms. When multiple GPUs are available, the notebook automatically uses them to speed up the computation.

## Introduction

This demo notebook explores how to prepare a machine learning model for use with Fully Homomorphic Encryption (FHE) through MLToolbox. MLToolbox provides specialized tools designed to convert models into polynomial representations, all while minimizing performance degradation. In the following sections, we'll delve into the process of adapting models to polynomial form and discuss how MLToolbox simplifies this task.

Note: 
 -  Currently supported models: AlexNet, LeNet5, Squeezenet, SqueezenetCHET, ResNet18, ResNet50, ResNet101, ResNet152 and CLIP-RN50. To learn how to add a new model, please refer to: `help(nn_module)`.
 -  Supported datasets: CIFAR10, CIFAR100, COVID_CT, COVID_XRAY, Places205, Places365 and ImageNet. To add a new dataset, please see `help(DatasetWrapper)`.


## Table of Contents

* [Step 1. Training the base model](#train)            
* [Step 2. Transforming the original model into an intermediate, range-aware form](#intermediate)   
* [Step 3. Transforming the range-aware form into a polynomial form](#polynomial)   
* [Step 4. Encrypting the trained model and predicting over encrypted data](#encrypt)        


<a id="train"></a>
## Step 1. Training the base model

In this step, we will train the base model using MLToolbox. This base model will serve as the starting point for the gradual conversion process we will undertake later.

Note: If you already have a pre-trained model, you can skip this step. Simply ensure that the model is saved as demonstrated in [the steps below](#save).

### 1.1. Start with some imports:

In [None]:
import os
# Print only error-level log messages
os.environ["LOG_LEVEL"]="ERROR"
import json
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import utils # notebook helpers (benchmarking, data sets location)
# trange shows the progress of the training loops
from tqdm.notebook import trange
# Magic command that renders figures inline in the notebook
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# FHE-related imports
import pyhelayers
# MLToolbox imports
from pyhelayers.mltoolbox.arguments import Arguments
import pyhelayers.mltoolbox.utils.util as util
from pyhelayers.mltoolbox.poly_activation_converter import starting_point
from pyhelayers.mltoolbox.data_loader.ds_factory import DSFactory
from pyhelayers.mltoolbox.model.DNN_factory import DNNFactory

### 1.2. Convert dataset into ffcv format

We use the ffcv library (https://docs.ffcv.io/index.html) to speed up training. Beyond acceleration, ffcv provides data transformations not found in PyTorch. One notable example is the cutout transformation (https://arxiv.org/abs/1708.04552), which has been shown to improve performance.
MLToolbox can also be used without ffcv; in that case, set the `args.ffcv` argument (presented below) to `False`.
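
To make the idea of cutout concrete: the transformation masks a randomly placed square patch of each training image. The sketch below is a minimal NumPy illustration of that idea only; the `cutout` function and its parameters are our own and do not reproduce ffcv's implementation.

```python
import numpy as np

def cutout(image: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Return a copy of an HxWxC image with one size x size square set to zero.

    The square's center is sampled uniformly over the image, and the square
    is clipped at the image borders.
    """
    h, w = image.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    y0, y1 = max(cy - size // 2, 0), min(cy + size // 2, h)
    x0, x1 = max(cx - size // 2, 0), min(cx + size // 2, w)
    out = image.copy()          # leave the input image untouched
    out[y0:y1, x0:x1] = 0       # mask the patch in all channels
    return out

rng = np.random.default_rng(0)
img = np.full((32, 32, 3), 255, dtype=np.uint8)
masked = cutout(img, size=8, rng=rng)
print(int((masked == 0).all(axis=-1).sum()), "pixels masked")
```

Near a border the square is clipped, so fewer than `size * size` pixels may be masked.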

In [None]:
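from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torchvision import transforms, datasets

# Location of the converted (.beton) training set
path = 'outputs/mltoolbox/train.beton'
os.makedirs(os.path.dirname(path), exist_ok=True)
# Each sample is stored as an RGB image and an integer label
writer = DatasetWriter(path, {'image': RGBImageField(), 'label': IntField()})
# Download the CIFAR10 training set and write it in ffcv format
ds = datasets.CIFAR10('outputs/mltoolbox/cifar_data', train=True, download=True)
writer.from_indexed_dataset(ds)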

### 1.3. Initialize the trainer object.
In the following cell, we initialize the `Arguments` class. This class defines the default values for various parameters related to the training process, allowing for their customization. 
`model`, `dataset_name`, `num_epochs`, `classes` and `data_dir` are required arguments and do not have default values.
The `data_dir` argument specifies the dataset location.

If no GPU is available (training on CPU), the code below sets `debug_mode` to `True`. In debug mode, the model is trained on a small subset of the dataset for only one epoch, so you can test your setup quickly.
When CUDA is available, `debug_mode` is set to `False` and the actual, full-scale training is performed.

By default, MLToolbox uses a fixed seed to ensure reproducibility. To get different random runs, set `args.seed` to different values. Results may still fluctuate slightly when the code runs on a different hardware architecture.
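
As an illustration of what seeding involves, the sketch below seeds the global random number generators of Python, NumPy and PyTorch. The helper `seed_everything` is hypothetical and not part of MLToolbox, which handles seeding internally through `args.seed`. Even with fixed seeds, some GPU operations are not bitwise reproducible across hardware, which is one reason results can differ between architectures.

```python
import random
import numpy as np

def seed_everything(seed: int) -> None:
    """Seed the global RNGs of Python, NumPy and, if installed, PyTorch (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

# Re-seeding with the same value reproduces the same random draws.
seed_everything(123)
run_a = [random.random(), float(np.random.rand())]
seed_everything(123)
run_b = [random.random(), float(np.random.rand())]
print(run_a == run_b)
```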

To start from a pre-trained model checkpoint, set `args.from_checkpoint` to the checkpoint location. To train from scratch instead, set `args.from_checkpoint = ''`.

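The `starting_point` method receives the user arguments and returns three values: `trainer`, `poly_activation_converter` and a third value (named `epoch` in later cells), which is not used in this step. The `trainer` and `poly_activation_converter` objects assist in converting the original model into a polynomial form in the following steps.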

In [None]:
debug_mode = not torch.cuda.is_available()
if debug_mode:
    num_epochs = 1
    batch_size = 10
else:
    num_epochs = 30
    batch_size = 500
    
args = Arguments(model="resnet18", dataset_name="CIFAR10_224", num_epochs=num_epochs, classes=10, data_dir = 'cifar_data')

# After initializing an `Arguments` object, its settings can be customized
args.lr=0.05
args.opt = "sgd"
args.batch_size = batch_size
args.ffcv = True
args.ffcv_train_data_path = path
args.pooling_type = 'max'

baseline_chp_location = os.path.join(utils.get_data_sets_dir(), 'mltoolbox', 'resnet18_baseline_checkpoint.pth.tar')
args.from_checkpoint = baseline_chp_location

# Enables a quick test run. For a real run, either remove this argument or set it to `False`
args.debug_mode = debug_mode

# Create the trainer and the activation converter from the arguments (the third returned value is not used here)
trainer, poly_activation_converter, _ = starting_point(args)

In [None]:
model = trainer.get_model()
print(model)

The supported datasets can be observed using the following call:

In [None]:
DSFactory.print_supported_datasets()

The supported models can be observed using the following:

In [None]:
DNNFactory.print_supported_models()

### 1.4. Training the base model

Run the training loop using the trainer:

In [None]:
# This list will accumulate the validation accuracy for each epoch, so we can later plot it
val_acc = []

# Run the training loop only if the model was not loaded from a checkpoint
if not args.from_checkpoint:
    epochs_range = trange(1, args.num_epochs + 1)
    scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=3, min_lr=0.000001, verbose=True)
    for epoch in epochs_range:
        trainer.train_step(args, epoch, epochs_range)
        # Perform validation; returns the metrics (val_metrics) and the confusion matrix (val_cf)
        val_metrics, val_cf = trainer.validation(args, epoch)
        val_acc.append(val_metrics.get_avg('accuracy'))

        scheduler.step(val_metrics.get_avg('loss'))

    # Save the model
    # The save location is defined by the args.save_dir argument; the default value is outputs/mltoolbox.
    util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)
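
The loop above uses `ReduceLROnPlateau` with `factor=0.5` and `patience=3`: the learning rate is halved once the monitored validation loss has failed to improve for more than 3 consecutive epochs. The standalone sketch below, which uses a dummy parameter and a constant (never-improving) loss, shows when the halving happens; it is independent of MLToolbox. Recent PyTorch releases deprecate the `verbose` argument, so it is omitted here.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to build an optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.05)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6)

lrs = []
for epoch in range(1, 11):
    val_loss = 1.0  # the validation loss never improves
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

Here the learning rate stays at 0.05 for epochs 1 to 4, drops to 0.025 at epoch 5 and to 0.0125 at epoch 9.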


Let's print the test metrics:

In [None]:
test_metrics, test_cf = trainer.test(args, 1)
print({'loss': test_metrics.get_avg('loss'), 'accuracy': test_metrics.get_avg('accuracy')})

We now have a trained model that can later be converted to a polynomial form; its test accuracy is 0.962. When the training loop runs, `util.save_model` saves the model to a file for future use.

Note: as mentioned earlier, you don't have to use a model trained with MLToolbox. You can use a pre-trained model and skip Step 1 of this notebook. Just make sure that your model is saved in the supported form `{'model': model}`, as shown below, and that its architecture is supported by MLToolbox (or that you have extended MLToolbox to support it):

```
import torch

# Save the model object under the 'model' key; file_name is the path of your choice
state = {'model': model}
torch.save(state, file_name)
```
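
For a model trained outside MLToolbox, a minimal round trip through this format might look as follows. The toy network `toy_model` and the temporary file location are placeholders. Because the whole module object is pickled, recent PyTorch releases, in which `torch.load` defaults to `weights_only=True`, need `weights_only=False` to read such a file back; only load files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A small stand-in for a pre-trained network (placeholder; use your own model).
toy_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Flatten())

file_name = os.path.join(tempfile.mkdtemp(), "my_model.pth.tar")
torch.save({'model': toy_model}, file_name)

# The full module is pickled, so weights_only=False is required to unpickle it.
restored = torch.load(file_name, weights_only=False)['model']

x = torch.randn(1, 3, 8, 8)
print(torch.equal(toy_model(x), restored(x)))
```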

In [None]:
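def plot(val_acc):
    # Plot the validation accuracy recorded after each epoch
    epochs = range(1, len(val_acc) + 1)
    plt.plot(epochs, val_acc)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.xticks(np.arange(1, len(val_acc) + 1, step=2))
    plt.grid(True)
    plt.show()

# The validation accuracy is available only if the training loop above was run
# (i.e., the model was not loaded from a checkpoint)
if val_acc:
    plot(val_acc)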

<a id="intermediate"></a>
## Step 2. Transforming the original model into an intermediate, range-aware form

In this step, we aim to minimize the input range of the non-polynomial activation layers. Specifically, the ReLU activation layers keep track of their input range, and a regularization term that penalizes values falling outside this range is added to the loss.
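
A minimal PyTorch sketch of a ReLU layer that keeps track of its input range follows. The class `RangeTrackingReLU` is hypothetical and is not MLToolbox's implementation, which may, for example, track ranges differently or only during training.

```python
import torch
import torch.nn as nn

class RangeTrackingReLU(nn.Module):
    """ReLU that records the smallest and largest input it has seen (hypothetical)."""

    def __init__(self):
        super().__init__()
        # Buffers are saved in the state_dict but are not updated by the optimizer.
        self.register_buffer("min_seen", torch.tensor(float("inf")))
        self.register_buffer("max_seen", torch.tensor(float("-inf")))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():  # range bookkeeping is not part of the graph
            self.min_seen = torch.minimum(self.min_seen, x.min())
            self.max_seen = torch.maximum(self.max_seen, x.max())
        return torch.relu(x)

act = RangeTrackingReLU()
act(torch.tensor([-3.0, 0.5, 2.0]))
act(torch.tensor([-1.0, 7.0]))
print(act.min_seen.item(), act.max_seen.item())
```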

There are two primary reasons for performing this step:

- To enable the approximation of activations by polynomials, it is essential to provide the range of inputs to the activation function beforehand.
- We aim to narrow the input range so that the activations can be approximated more accurately by low-degree polynomials. We use the interval [-10, 10] as an example, but the actual range depends on the specific model and data. A smaller range is preferable because it allows a more precise approximation, whereas a significantly larger range can degrade the accuracy of the polynomial model.
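
The effect of the range on approximation quality can be checked numerically. The sketch below fits a degree-4 polynomial to ReLU by least squares on [-B, B] using NumPy's `Chebyshev.fit` and reports the largest error on that interval. The least-squares fit and the degree are our choices for illustration, not necessarily the method MLToolbox uses. Because ReLU(c * x) = c * ReLU(x) for c > 0, the maximum error of such a fit grows linearly with B, so doubling the range doubles the error.

```python
import numpy as np
from numpy.polynomial import Chebyshev

def max_relu_fit_error(bound: float, degree: int = 4, n_points: int = 2001) -> float:
    """Least-squares fit of ReLU on [-bound, bound]; return the max absolute error there."""
    x = np.linspace(-bound, bound, n_points)
    relu = np.maximum(x, 0.0)
    poly = Chebyshev.fit(x, relu, degree)  # the data interval is mapped to [-1, 1] internally
    return float(np.max(np.abs(poly(x) - relu)))

for bound in (5, 10, 20):
    print(f"[-{bound}, {bound}]: max error {max_relu_fit_error(bound):.3f}")
```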

### 2.1. Define the arguments to represent what we want to do:

In [None]:
if debug_mode:
    num_epochs = 1
    batch_size = 10
else:
    num_epochs = 13
    batch_size = 200
    
args = Arguments(model="resnet18", dataset_name="CIFAR10_224", num_epochs=num_epochs, classes=10, data_dir = 'cifar_data')

args.opt = "sgd"
args.pooling_type = "avg"
args.activation_type= "relu_range_aware"
args.lr=0.005
args.batch_size = batch_size
args.range_awareness_loss_weight=0.002
args.range_aware_train = True
args.save_dir = "outputs/mltoolbox/range_aware"
args.ffcv = True
args.ffcv_train_data_path = path
args.debug_mode = debug_mode

baseline_chp_location = os.path.join(utils.get_data_sets_dir(), 'mltoolbox', 'resnet18_baseline_checkpoint.pth.tar')
args.from_checkpoint = baseline_chp_location

* `args.opt="sgd"`: The optimizer to be used (the other option is `'adam'`, which is the default)

* `args.pooling_type = "avg"`: All the pooling operations will be replaced by average-pooling (the default is `"avg"`).

* `args.activation_type="relu_range_aware"`: The ReLU activations keep track of their input range (the available options are `'trainable_poly'`, `'non_trainable_poly'`, `'approx_relu'`, `'relu'`, `'weighted_relu'`, `'relu_range_aware'` and `'square'`).

* `args.lr=0.005`: Learning rate

* `args.batch_size=batch_size`: Batch size

* `args.range_awareness_loss_weight=0.002`: The weight given to reducing the ranges during training, relative to the CrossEntropyLoss. This value needs to be tuned for the model and data at hand, so that the training does not suffer from too harsh an accuracy degradation.

* `args.range_aware_train=True`: Makes the training range-aware (the default value is `False`).

* `args.save_dir = "outputs/mltoolbox/range_aware"`: The directory where the outputs of this step will be saved.

* `args.ffcv = True`: Use the ffcv library during training to speed it up. In addition, ffcv adds a cutout transformation, which improves the accuracy (the default value is `False`).

* `args.ffcv_train_data_path = path`: The name and location of the ffcv converted data.

* `args.debug_mode = debug_mode`: When the debug mode is set to True, the training uses a small subset of the data. The default value is False.

* `args.from_checkpoint = baseline_chp_location`: The checkpoint to load the model from.
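
To illustrate the role of `range_awareness_loss_weight`, here is a minimal sketch of how a weighted range penalty can be combined with the cross-entropy loss. The quadratic penalty, the fixed bounds [-10, 10] and the function names are assumptions for illustration; MLToolbox's actual regularization term may differ.

```python
import torch
import torch.nn.functional as F

def out_of_range_penalty(x: torch.Tensor, low: float, high: float) -> torch.Tensor:
    """Mean squared distance of the entries of x from [low, high] (0 for entries inside)."""
    below = F.relu(low - x)   # positive only where x < low
    above = F.relu(x - high)  # positive only where x > high
    return (below ** 2 + above ** 2).mean()

def total_loss(logits, targets, activation_inputs, weight=0.002, low=-10.0, high=10.0):
    """Cross-entropy plus the weighted range penalty summed over all activation inputs."""
    ce = F.cross_entropy(logits, targets)
    penalty = sum(out_of_range_penalty(a, low, high) for a in activation_inputs)
    return ce + weight * penalty

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])
acts = [torch.tensor([-12.0, 3.0]), torch.tensor([0.5, 11.0])]
print(total_loss(logits, targets, acts, weight=0.002))
```

A larger weight pushes the activation inputs into the range more aggressively, at the cost of giving less relative weight to classification accuracy.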

### 2.2. Run starting_point again, with the new arguments

In [None]:
trainer, poly_activation_converter, epoch = starting_point(args)

### 2.3. Replacing max-pooling

FHE does not natively support max pooling operations. To address this limitation, we replace max pooling with average pooling. There are two approaches: 1) train the base model from scratch with average pooling; 2) convert the pooling operation to average pooling after the model has been trained, and then continue training for a few additional epochs. To use option 2 with MLToolbox, set the `pooling_type` argument to `'avg'` and run the `make_fhe_friendly` method, as shown below:
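
Average pooling is FHE-friendly because it only needs additions and a multiplication by a constant, linear operations that FHE schemes support, whereas max pooling requires comparisons. The NumPy sketch below (the `pool2x2` helper is our own) checks this linearity for 2x2 average pooling and shows that max pooling does not share it.

```python
import numpy as np

def pool2x2(x: np.ndarray, reduce) -> np.ndarray:
    """Non-overlapping 2x2 pooling of an (H, W) array with even H and W."""
    h, w = x.shape
    blocks = x.reshape(h // 2, 2, w // 2, 2)  # blocks[i, :, j, :] is one 2x2 window
    return reduce(blocks, axis=(1, 3))

a = np.arange(16, dtype=float).reshape(4, 4)
b = np.random.default_rng(0).normal(size=(4, 4))

# Window means of a are 2.5, 4.5, 10.5 and 12.5.
print(pool2x2(a, np.mean))
# Average pooling is linear: pooling a sum equals the sum of the pooled inputs.
print(np.allclose(pool2x2(a + b, np.mean), pool2x2(a, np.mean) + pool2x2(b, np.mean)))
# Max pooling is not linear: pool(a + (-a)) is 0, but pool(a) + pool(-a) is not.
print(np.allclose(pool2x2(a - a, np.max), pool2x2(a, np.max) + pool2x2(-a, np.max)))
```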

In [None]:
model = trainer.get_model()
model.module.make_fhe_friendly(add_bn=False, pooling_type=args.pooling_type) 
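
The call above goes through `model.module`. This suggests (an assumption on our part) that the trainer wraps the network in a container such as `torch.nn.DataParallel`, which is also one way the notebook could use several GPUs. Such a wrapper exposes the original network as `.module` and does not forward methods defined on the network, so architecture changes must go through `.module`. A minimal sketch with a toy network:

```python
import torch.nn as nn

# A toy network standing in for the real model (placeholder).
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.MaxPool2d(2))

# DataParallel splits each batch across the available GPUs; without GPUs it simply wraps the module.
wrapped = nn.DataParallel(net)

# The original network is reachable as `.module`; edit the architecture through it.
wrapped.module[1] = nn.AvgPool2d(2)  # replace max pooling with average pooling
print(wrapped.module)
```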

### 2.4. Transform and train the range-aware model

We train the model for several more epochs. At the beginning of each epoch, the `replace_activations` function performs any replacement scheduled for that epoch. Then, the train step is called. This is the same train step we ran before; the only difference is that the loss function now has an extra term that regularizes the activation inputs, pushing them toward the required range.

In [None]:
val_acc = []
ranges_train = []
ranges_val = []
scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=2, min_lr=args.min_lr, verbose=True)
epochs_range = trange(1,args.num_epochs + 1)

for epoch in epochs_range:
    poly_activation_converter.replace_activations(trainer, epoch, scheduler)
    trainer.train_step(args, epoch, epochs_range)
    ranges_train.append(trainer.get_all_ranges(args))
    
    val_metrics, val_cf = trainer.validation(args, epoch) # perform validation. Returns metrics (val_metrics) and confusion matrix (val_cf)
    val_acc.append(val_metrics.get_avg('accuracy'))
    ranges_val.append(trainer.get_all_ranges(args))
    
util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)

The resulting range-aware model exhibits an overall accuracy of 0.96. Here we set `range_awareness_loss_weight` to 0.002; for your own model, this weight should be carefully tuned. While some degradation is expected in this step, make sure the accuracy does not decrease too significantly as the ranges are reduced.

In [None]:
test_metrics, test_cf = trainer.test(args, epoch)
print({'loss': test_metrics.get_avg('loss'), 'accuracy': test_metrics.get_avg('accuracy')})

plot(val_acc)

In [None]:
print(trainer.get_model())

Let's observe the input ranges of the model's activations:

In [None]:
def plot_ranges(ranges):
    # Each entry of `ranges` holds, per activation, a (min, max) pair for one epoch;
    # epochs in which no ranges were recorded hold an empty entry
    num_epochs = len(ranges)

    # Skip the leading epochs without range data
    start_epoch = next((i for i, epoch_ranges in enumerate(ranges) if epoch_ranges), None)
    if start_epoch is None:
        print("No ranges data found.")
        return

    num_items = len(ranges[-1])
    # matplotlib.cm.get_cmap was removed in recent Matplotlib versions; plt.get_cmap still works
    color_map = plt.get_cmap('tab10')
    # Plot against the actual (1-based) epoch numbers
    epochs = range(start_epoch + 1, num_epochs + 1)

    for item_index in range(num_items):
        min_values = [epoch_ranges[item_index][0] for epoch_ranges in ranges[start_epoch:]]
        max_values = [epoch_ranges[item_index][1] for epoch_ranges in ranges[start_epoch:]]
        # Use the same color for the min and max curves of an activation
        color = color_map(item_index % color_map.N)
        plt.plot(epochs, min_values, label=f'Min- {item_index+1}', color=color)
        plt.plot(epochs, max_values, label=f'Max- {item_index+1}', color=color)

    #plt.legend(loc='upper right')  # Display legend
    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Minimum and Maximum Values over Epochs')
    plt.show()

plot_ranges(ranges_train)

In [None]:
plot_ranges(ranges_val)

Most of the ranges are narrower than [-10, 10], while one activation still has a wider range, reaching approximately [-10, 10].

<a id="polynomial"></a>
## Step 3. Transforming the range-aware form into a polynomial form

The ReLU activations are replaced according to the arguments we've defined.
We start with a partially transformed model: the pooling type is already average, and some batch normalization layers may have been added.
Each ReLU is replaced by a non-trainable polynomial that approximates the ReLU over the input range it has recorded. The remaining epochs are used to improve the model with no additional changes.

At the beginning of each epoch, the `replace_activations` function performs any replacement scheduled for that epoch. Then, the train step is called. This is the same train step we ran before; the only difference is that the loss function has an extra term that regularizes the activation inputs toward the required range.

After the training loop has completed, we save the resulting model.
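
A non-trainable polynomial activation evaluates a polynomial whose coefficients are fixed in advance for the recorded input range. The PyTorch sketch below assumes a degree-2 least-squares fit of ReLU over [-10, 10]; the class name, the degree and the fitting method are our assumptions, and MLToolbox's own approximation may differ.

```python
import numpy as np
import torch
import torch.nn as nn
from numpy.polynomial import Polynomial

class FixedPolyActivation(nn.Module):
    """Evaluates c0 + c1*x + ... + cd*x^d with fixed, non-trainable coefficients."""

    def __init__(self, coeffs):
        super().__init__()
        # A buffer, not a Parameter: the optimizer never updates it.
        self.register_buffer("coeffs", torch.as_tensor(coeffs, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Horner's rule, highest degree first: only additions and multiplications.
        out = torch.zeros_like(x)
        for c in self.coeffs.flip(0):
            out = out * x + c
        return out

# Least-squares fit of ReLU on the assumed input range [-10, 10].
x = np.linspace(-10, 10, 1001)
coeffs = Polynomial.fit(x, np.maximum(x, 0), 2).convert().coef  # power basis, lowest degree first
act = FixedPolyActivation(coeffs)
print(coeffs, list(act.parameters()))
```

Using only additions and multiplications is what makes such an activation evaluable under FHE, and a low degree keeps the number of multiplications small.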

In [None]:
args.pooling_type = "avg"
args.activation_type= "non_trainable_poly"
args.batch_size = 200
args.num_epochs = 25
args.lr=0.005
args.from_checkpoint = "outputs/mltoolbox/range_aware/resnet18_last_checkpoint.pth.tar"
args.range_awareness_loss_weight=0.1
args.range_aware_train = True
args.save_dir = "outputs/mltoolbox/polynomial"

if debug_mode:
    args.num_epochs = 1
    args.batch_size = 10
    
trainer, poly_activation_converter, epoch = starting_point(args)

In [None]:
val_acc = []
ranges_train = []
ranges_val = []
scheduler = ReduceLROnPlateau(trainer.get_optimizer(), factor=0.5, patience=2, min_lr=args.min_lr, verbose=True)
epochs_range = trange(1,args.num_epochs + 1)

for epoch in epochs_range:
    poly_activation_converter.replace_activations(trainer, epoch, scheduler)
    trainer.train_step(args, epoch, epochs_range)
    ranges_train.append(trainer.get_all_ranges(args))
    
    val_metrics, val_cf = trainer.validation(args, epoch) # perform validation. Returns metrics (val_metrics) and confusion matrix (val_cf)
    val_acc.append(val_metrics.get_avg('accuracy'))
    ranges_val.append(trainer.get_all_ranges(args))
    
util.save_model(trainer, poly_activation_converter, args, val_metrics, epoch, val_cf)

In [None]:
plot(val_acc)

The resulting accuracy is 0.962.

In [None]:
plot_ranges(ranges_train)

In [None]:
plot_ranges(ranges_val)

### 3.1. Convert the model into ONNX format
The `fhe_best` checkpoint saved by `util.save_model` is read, and the model is converted into ONNX format (some format changes are applied to the model during the conversion).

In [None]:
path, model = util.save_onnx(args, poly_activation_converter, trainer)

In [None]:
print(model)

### Results

The below table summarizes the results achieved in this notebook:

| Technique              | Accuracy  |
|:-----------------------|:----------|
| **ReLU**               | **0.962** |
| Range-aware polynomial | 0.962     |


<a id="encrypt"></a>
## Step 4. Encrypting the trained model and predicting over encrypted data
We now encrypt the polynomial model trained above and use the encrypted model to make predictions on encrypted data. We also make predictions using the plain (unencrypted) model and compare the two sets of results. This is demonstrated in a separate notebook; see `misc/19_MLToolbox_FHE.ipynb` for more details.
Here, we save a batch of input samples for the prediction process of notebook `misc/19_MLToolbox_FHE.ipynb`.

In [None]:
args.batch_size = 10
trainer, poly_activation_converter, epoch = starting_point(args)

plain_samples, labels = next(iter(trainer.val_generator))
torch.save(plain_samples, 'outputs/mltoolbox/plain_samples.pt')

## References

Moran Baruch, Nir Drucker, Gilad Ezov, Eyal Kushnir, Jenny Lerner, Omri Soceanu, Itamar Zimerman. "Sensitive Tuning of Large Scale CNNs for E2E Secure Prediction using Homomorphic Encryption" (2023). https://arxiv.org/abs/2304.14836