# Appendix B
# Multiprocessing with encrypted data
Author:

Arthur Pignet - apignet@equancy.com

Because we use the CKKS encryption scheme, computations on ciphertexts are slow. To speed up the training phase, we want to take advantage of multiple cores. This notebook demonstrates that this is feasible and suggests an implementation.

The main difficulty lies in providing the data to the processes. The data needs to be serializable, which is not really the case for TenSEAL contexts and vectors.

### Classic idea of multiprocessing

Let us say we have N cores available. To make the best use of them, we want to create N subprocesses.
For each subprocess, Linux creates (forks) a child process with its own isolated memory space, as a copy of the parent process. So in theory, the child has a copy of all the parent's variables in its own memory. To save memory, Linux uses copy-on-write: the child process sees in its virtual memory every variable that was in the parent's memory space when the child was created, but Linux actually copies a variable from the parent's memory to the child's memory only the first time the child tries to write to it.
With TenSEAL, we cannot rely on this mechanism for CKKS vectors and contexts, because data handed to the workers through the multiprocessing module is "copied" via pickle serialization, which these TenSEAL objects do not support.

### TenSEAL serialization

We are interested in serializing two TenSEAL objects: the context and the CKKS vector. The context is the main TenSEAL object. When created, it holds the secret and public keys, so it is needed to encrypt a vector. We can add the relinearization and Galois keys to this context. These keys are mandatory for every computation, so every vector encrypted with a given context holds a pointer to that context as an attribute. Moreover, if the context is private (i.e. holds the secret key), the decrypt() method of a vector does not need to be given the secret key, as it can be accessed through vector.context.secret_key.
This is really convenient for computations but raises an issue for serialization. When we try to pickle a vector, the pickle module tries to serialize every attribute of the vector, so it would serialize a copy of the context for every vector, which is very memory-consuming. To tackle this issue, TenSEAL vectors and contexts come with their own serialization methods, built on top of Google Protocol Buffers.
So when it is serialized, a CKKS vector drops its context. Therefore, to deserialize a CKKS vector, we need a copy of its original context.

All in all, the protocol to transfer CKKS vectors can be written as follows:
- serialize the context
- serialize the vectors
- send the binary (i.e. serialized) vectors and context
- deserialize the context
- deserialize the vectors using the deserialized context
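
The five steps above can be illustrated without TenSEAL. The sketch below uses two hypothetical stand-in classes, `ToyContext` and `ToyVector` (neither is part of TenSEAL), to show why a vector that drops its context when serialized needs that context back at deserialization. In this notebook, the corresponding TenSEAL calls are `context.serialize()`, `ts.context_from()`, `vector.serialize()` and `ts.ckks_vector_from()`.

```python
import json


class ToyContext:
    """Hypothetical stand-in for a TenSEAL context (holds the 'keys')."""

    def __init__(self, keys):
        self.keys = keys

    def serialize(self):
        return json.dumps(self.keys).encode()

    @classmethod
    def deserialize(cls, data):
        return cls(json.loads(data.decode()))


class ToyVector:
    """Hypothetical stand-in for a CKKS vector: values plus a context pointer."""

    def __init__(self, context, values):
        self.context = context
        self.values = list(values)

    def serialize(self):
        # Only the payload is written; the context is deliberately dropped.
        return json.dumps(self.values).encode()

    @classmethod
    def deserialize(cls, context, data):
        # The context must be supplied again to rebuild a usable vector.
        return cls(context, json.loads(data.decode()))


# Sender side: one context shared by every vector.
context = ToyContext({"galois": "g" * 1000, "relin": "r" * 1000})
vectors = [ToyVector(context, [i, i + 1]) for i in range(3)]
b_context = context.serialize()               # 1. serialize the context
b_vectors = [v.serialize() for v in vectors]  # 2. serialize the vectors
# 3. b_context and b_vectors are plain bytes, so they can be sent as-is.

# Receiver side.
local_context = ToyContext.deserialize(b_context)        # 4. context first
local_vectors = [ToyVector.deserialize(local_context, b)  # 5. then the vectors
                 for b in b_vectors]
```

The context is transferred once, whatever the number of vectors, which is the memory saving described above.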

### Multiprocessing implementation

The full multiprocessing protocol, which is tested below, can be described as follows.

We create N processes.
We initialize these N processes by passing them a function initialization(), which takes as parameters the serialized context and 1/N of the data, also serialized. The initialization() function creates global variables for the context and the dataset by calling tenseal.context_from() and tenseal.ckks_vector_from(). As initialization() is executed by the process itself, these global variables live in that process's own memory. Then the main process puts into the input queues the function to perform, together with its parameters, which must be pickle-serializable. In our case, the function will be a batch forward-backward propagation, and the parameters will be the weights. The weights are therefore serialized. Each process deserializes the weights using its local context, computes the gradient on its local dataset, serializes it, and puts the serialized gradient on a queue to hand it back to the parent process. The parent process then deserializes the partial gradients, sums them to get the final gradient, and computes the descent direction in order to update the weights and bias.
We choose to copy the context N times, so that every worker has its own set of keys, which are mandatory for computation.
The context could also have been stored in shared memory to save memory space. But as every worker needs to access the keys every time it computes something, this would probably have raised concurrent-access issues.
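
The aggregation step can be checked on plaintext data. The sketch below is a toy, unencrypted illustration: the least-squares model, the data and the helper name `partial_gradient` are assumptions made for the example, not this notebook's training code. Each of the N shards yields a partial gradient, and the parent sums them to recover the gradient over the whole dataset.

```python
def partial_gradient(weights, shard):
    """Sum of per-sample least-squares gradients over one shard (toy model)."""
    grad = [0.0] * len(weights)
    for x, y in shard:
        error = sum(w * xi for w, xi in zip(weights, x)) - y
        for j, xi in enumerate(x):
            grad[j] += error * xi
    return grad


N = 4  # number of (simulated) workers
data = [([1.0, float(i)], 2.0 * i + 1.0) for i in range(10)]  # y = 1 + 2 * x
weights = [0.0, 0.0]

# Round-robin split: sample i goes to shard i % N.
shards = [data[k::N] for k in range(N)]

# Each worker would compute this on its own shard...
partials = [partial_gradient(weights, shard) for shard in shards]
# ...and the parent sums the partial gradients.
total = [sum(p[j] for p in partials) for j in range(len(weights))]

# Plain gradient-descent update with an arbitrary learning rate.
learning_rate = 0.01
new_weights = [w - learning_rate * g for w, g in zip(weights, total)]
```

Because the gradient of a sum over samples is the sum of the per-sample gradients, `total` matches the gradient computed over `data` in one pass.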

For the present demonstration, we will only compute the sum of the whole dataset.

We use Python's native multiprocessing module, [multiprocessing](https://docs.python.org/fr/3/library/multiprocessing.html), which uses the pickle module to serialize data.

In [1]:
import logging 
import multiprocessing
import time
import os
import sys

import numpy as np
import tenseal as ts

os.chdir("/home/apignet/homomorphic-encryption/ckks_titanic/")
from src.features import build_features

In [2]:
%load_ext memory_profiler

# Definition of parameters

N is the number of workers. The work will be split into N blocks.


In [3]:
N = 8

### Paths

In [4]:
WORKING_DIR = "/home/apignet/homomorphic-encryption/ckks_titanic/"
DATA_PATH = WORKING_DIR+"/data/raw/"   # whole data set
SUBMISSION_PATH = WORKING_DIR+"/report/submission/"
#DATA_PATH = WORKING_DIR+"/data/quick_demo/"   # subset of the data set, with 15 train_samples and 5 test_samples

LOG_PATH = WORKING_DIR+"/reports/log/"
LOG_FILENAME = "test_multi"

In [5]:
fileHandler = logging.FileHandler("{0}/{1}.log".format(LOG_PATH, LOG_FILENAME))
streamHandler = logging.StreamHandler(sys.stdout)
logging.basicConfig(format="%(asctime)s  [%(levelname)-8.8s]  %(message)s", 
                    datefmt='%m/%d/%Y %I:%M:%S %p', 
                    level = logging.INFO, 
                    handlers=[fileHandler, streamHandler]
                   )

## Static functions

To encrypt the data

In [6]:
def crytp_array(X, local_context):
    """
    Encrypt a list of vectors.

    :parameters
    ------------

    :param X: list of lists, interpreted as a list of vectors to encrypt
    :param local_context: TenSEAL context object used for encryption

    :returns
    ------------

    list: list of CKKS ciphertexts

    """
    res = []
    for i in range(len(X)):
        res.append(ts.ckks_vector(local_context, X[i]))
        if i == len(X) // 4:
            logging.info("25 % ...")
        elif i == len(X) // 2 :
            logging.info("50 % ...")
        elif i == 3* len(X)//4:
            logging.info("75% ...")
    return res

# Loading and processing the data

As mentioned, the details of the data processing and feature engineering can be found in the notebook Appendix A.

In [7]:
%%memit
logging.info(os.getcwd())
raw_train, raw_test = build_features.data_import(DATA_PATH)
train, submission_test = build_features.processing(raw_train, raw_test)
#del submission_test

08/18/2020 03:36:06 PM  [INFO    ]  /home/apignet/homomorphic-encryption/ckks_titanic
08/18/2020 03:36:06 PM  [INFO    ]  loading the data into memory (pandas df)
08/18/2020 03:36:06 PM  [INFO    ]  Done
08/18/2020 03:36:06 PM  [INFO    ]  making final data set from raw data
08/18/2020 03:36:06 PM  [INFO    ]  Done
08/18/2020 03:36:06 PM  [INFO    ]  /home/apignet/homomorphic-encryption/ckks_titanic
08/18/2020 03:36:06 PM  [INFO    ]  loading the data into memory (pandas df)
08/18/2020 03:36:06 PM  [INFO    ]  Done
08/18/2020 03:36:06 PM  [INFO    ]  making final data set from raw data
08/18/2020 03:36:06 PM  [INFO    ]  Done
peak memory: 146.06 MiB, increment: 13.70 MiB


In [8]:
%%memit
train_labels = train.Survived
train_features = train.drop("Survived", axis=1)

peak memory: 146.31 MiB, increment: 0.01 MiB


# Definition of safety parameters

See Notebooks 0 or 1 for more details on the choice of these parameters.

Here we use the same parameters as in the real case, but only to evaluate the true memory size of the dataset and the context. In fact, the operation performed here (a sum) would work with less restrictive parameters.
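
As a quick sanity check on such parameters, the total bit size of the coefficient modulus can be compared with the maximum allowed for a given polynomial modulus degree. The sketch below does this for the three parameter sets that appear in the next cell (one active, two commented out). The limits in `MAX_COEFF_MODULUS_BITS` are assumed to be the 128-bit security bounds used by Microsoft SEAL, the library underlying TenSEAL; they are not stated in this notebook.

```python
# Assumed upper bounds on the total coefficient-modulus bit count for
# 128-bit security (values from Microsoft SEAL's default tables).
MAX_COEFF_MODULUS_BITS = {8192: 218, 16384: 438, 32768: 881}

parameter_sets = {
    8192: [40] + [21] * 6 + [40],
    16384: [60] + [40] * 7 + [60],   # the set actually used in this notebook
    32768: [60] + [40] * 14 + [60],
}

budget = {}
for degree, bit_sizes in parameter_sets.items():
    total_bits = sum(bit_sizes)
    budget[degree] = (total_bits, total_bits <= MAX_COEFF_MODULUS_BITS[degree])
    print(degree, total_bits, "bits, within bound:", budget[degree][1])
```

Under these assumed bounds, all three sets fit within their budget; the active one uses 400 of the 438 bits available for degree 16384.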

In [9]:
%%memit
logging.info('Definition of safety parameters...')
timer = time.time()
# context = ts.context(ts.SCHEME_TYPE.CKKS, 32768,
# coeff_mod_bit_sizes=[60, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 60])
#context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[40, 21, 21, 21, 21, 21, 21, 40])

context = ts.context(ts.SCHEME_TYPE.CKKS, 16384, coeff_mod_bit_sizes=[60, 40, 40, 40, 40, 40, 40,40, 60])
context.global_scale = pow(2, 40)
logging.info("Done. " + str(round(time.time() - timer, 2)) + " seconds")


logging.info('Generation of the secret key...')
timer = time.time()
secret_key = context.secret_key()
context.make_context_public() #drop the relin keys, the galois keys, and the secret keys. 
logging.info("Done. " + str(round(time.time() - timer, 2)) + " seconds")
logging.info('Generation of the Galois Key...')
timer = time.time()
context.generate_galois_keys(secret_key)
logging.info("Done. " + str(round(time.time() - timer, 2)) + " seconds")
logging.info('Generation of the Relin Key...')
timer = time.time()
context.generate_relin_keys(secret_key)
logging.info("Done. " + str(round(time.time() - timer, 2)) + " seconds")
if context.is_public():
    logging.info("The context is now public, the context do not hold the secret key anymore, and decrypt \
    methods need the secret key to be provide,")
logging.info("serialization of the context started...")
timer = time.time()
b_context = context.serialize()
logging.info("Serialization done in %s seconds." %(time.time()-timer))


08/18/2020 03:36:07 PM  [INFO    ]  Definition of safety parameters...
08/18/2020 03:36:07 PM  [INFO    ]  Done. 0.31 seconds
08/18/2020 03:36:07 PM  [INFO    ]  Generation of the secret key...
08/18/2020 03:36:07 PM  [INFO    ]  Done. 0.0 seconds
08/18/2020 03:36:07 PM  [INFO    ]  Generation of the Galois Key...
08/18/2020 03:36:10 PM  [INFO    ]  Done. 3.3 seconds
08/18/2020 03:36:10 PM  [INFO    ]  Generation of the Relin Key...
08/18/2020 03:36:10 PM  [INFO    ]  Done. 0.13 seconds
08/18/2020 03:36:10 PM  [INFO    ]  The context is now public, the context do not hold the secret key anymore, and decrypt     methods need the secret key to be provide,
08/18/2020 03:36:10 PM  [INFO    ]  serialization of the context started...
08/18/2020 03:37:00 PM  [INFO    ]  Serialization done in 49.10720896720886 seconds.
peak memory: 1572.49 MiB, increment: 1426.18 MiB


# Data encryption
 

In [10]:
%%memit
logging.info("Data encryption...")
timer = time.time()
encrypted_X = crytp_array(train_features.to_numpy().tolist(), context)
encrypted_Y = crytp_array(train_labels.to_numpy().reshape((-1, 1)).tolist(), context)
logging.info("Done. " + str(round(time.time() - timer, 2)) + " seconds")

08/18/2020 03:37:00 PM  [INFO    ]  Data encryption...
08/18/2020 03:37:05 PM  [INFO    ]  25 % ...
08/18/2020 03:37:10 PM  [INFO    ]  50 % ...
08/18/2020 03:37:15 PM  [INFO    ]  75% ...
08/18/2020 03:37:25 PM  [INFO    ]  25 % ...
08/18/2020 03:37:30 PM  [INFO    ]  50 % ...
08/18/2020 03:37:35 PM  [INFO    ]  75% ...
08/18/2020 03:37:40 PM  [INFO    ]  Done. 40.55 seconds
peak memory: 4679.26 MiB, increment: 3562.12 MiB


# Worker functions

Definition of the function that each subprocess runs continuously.

Definition of the function that initializes the subprocess by creating, in its own memory space, the context and the subprocess's sub-dataset.

Definition of the addition function, which each worker executes on its own partial dataset.

In [11]:
def worker(input_queue, output_queue):
    """
    Keep the process running until the string 'STOP' is read from input_queue.

    Each item taken from input_queue is a pair (function, arguments of the function);
    the function is called and its result is put into output_queue.
    """
    for func, args in iter(input_queue.get, 'STOP'):
        result = func(*args)
        output_queue.put(result)

        
def initialization(*args):
    """
    First task put in the input queue of each process.

    :param args: tuple (b_context, b_X)
            b_context: binary representation of the context, i.e. context.serialize()
            b_X: list of binary representations of the samples (serialized CKKS vectors)

    The context is deserialized first and stored as a global variable,
    in the memory space allocated to the process.
    The batch is then deserialized with this context, producing the list of
    CKKS vectors (the encrypted samples) on which the process will work.
    """
    b_context = args[0]
    b_X = args[1]
    global context_
    context_ = ts.context_from(b_context)
    global local_X_
    local_X_ = [ts.ckks_vector_from(context_, i) for i in b_X]
    return 'Initialization done for process %s. Len of data : %i' % (
        multiprocessing.current_process().name, len(local_X_))


def addition(*arg):
    """Sum the worker's local CKKS vectors and return the serialized result."""
    res = 0
    for i in local_X_:
        res = i + res
    return res.serialize()
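
The `iter(input_queue.get, 'STOP')` idiom used in `worker()` calls `input_queue.get()` repeatedly until it returns the sentinel `'STOP'`. The sketch below reproduces that loop in a single process, using the standard-library `queue.Queue` as a stand-in for `multiprocessing.Queue` and two hypothetical tasks (`add` and `square` are illustrative and not part of this notebook).

```python
import queue


def worker(input_queue, output_queue):
    # Same loop as in the notebook: stop as soon as 'STOP' is read.
    for func, args in iter(input_queue.get, 'STOP'):
        output_queue.put(func(*args))


def add(a, b):
    return a + b


def square(x):
    return x * x


input_queue, output_queue = queue.Queue(), queue.Queue()
input_queue.put((add, (1, 2)))
input_queue.put((square, (4,)))
input_queue.put('STOP')         # sentinel: ends the worker loop
input_queue.put((add, (5, 5)))  # never processed, it comes after 'STOP'

worker(input_queue, output_queue)
results = [output_queue.get() for _ in range(output_queue.qsize())]
print(results)  # [3, 16]
```

Anything queued after the sentinel is left unread, which is why `'STOP'` is only sent once all results have been collected.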


## Sum without multiprocessing

We first try this out without multiprocessing, on a single core.

In [12]:
timer =time.time()
res = 0 
for x in encrypted_X:
    res = x + res
wct = time.time() -timer
logging.info('sum performed in %s without multiprocessing' %(time.time()-timer
    ))

08/18/2020 03:37:42 PM  [INFO    ]  sum performed in 1.3563103675842285 without multiprocessing


# Sum with multiprocessing. 

We first split the dataset into N batches. At the same time, we serialize the data.

In [13]:
timer = time.time()
# Round-robin split: sample i goes to batch i % N and is serialized on the fly.
b_X = [[] for _ in range(N)]
logging.info("Starting serialization of data")
for i in range(len(encrypted_X)):
    b_X[i % N].append(encrypted_X[i].serialize())
ser_time = time.time() - timer
logging.info("Data serialization done in %s seconds" % (ser_time))

08/18/2020 03:37:42 PM  [INFO    ]  Starting serialization of data
08/18/2020 03:40:54 PM  [INFO    ]  Data serialization done in 191.8620753288269 seconds


Then we use the multiprocessing module to create the N workers, the N input queues and the single output queue.
As soon as a worker starts, its initialization begins.

In [14]:
logging.info("Initialization of %d workers" % N)
init_work_timer = time.time()
list_queue_in = []
queue_out = multiprocessing.Queue()
init_tasks = [(initialization, (b_context, x)) for x in b_X]
for init in init_tasks:
    list_queue_in.append(multiprocessing.Queue())
    list_queue_in[-1].put(init)
    multiprocessing.Process(target=worker, args=(list_queue_in[-1], queue_out)).start()
log_out = []
for _ in range(N):
    log_out.append(queue_out.get())
    logging.info(log_out[-1])
    
init_time =time.time() - init_work_timer
logging.info("Initialization done in %s seconds" % (init_time))

 

08/18/2020 03:40:54 PM  [INFO    ]  Initialization of 8 workers
08/18/2020 03:41:01 PM  [INFO    ]  Initialization done for process Process-8. Len of data : 112
08/18/2020 03:41:02 PM  [INFO    ]  Initialization done for process Process-9. Len of data : 112
08/18/2020 03:41:02 PM  [INFO    ]  Initialization done for process Process-11. Len of data : 111
08/18/2020 03:41:02 PM  [INFO    ]  Initialization done for process Process-10. Len of data : 112
08/18/2020 03:41:02 PM  [INFO    ]  Initialization done for process Process-13. Len of data : 111
08/18/2020 03:41:03 PM  [INFO    ]  Initialization done for process Process-12. Len of data : 111
08/18/2020 03:41:03 PM  [INFO    ]  Initialization done for process Process-14. Len of data : 111
08/18/2020 03:41:03 PM  [INFO    ]  Initialization done for process Process-15. Len of data : 111
08/18/2020 03:41:03 PM  [INFO    ]  Initialization done in 9.765893459320068 seconds
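
The batch sizes reported in the log above follow directly from the round-robin split `i % N`. The sketch below recomputes them, assuming 891 training samples, which is the total of the eight sizes logged by the workers (3 × 112 + 5 × 111).

```python
from collections import Counter

N = 8            # number of workers, as in this notebook
n_samples = 891  # assumption: sum of the batch sizes reported in the log

# Sample i is appended to batch i % N.
batch_sizes = Counter(i % N for i in range(n_samples))
sizes = [batch_sizes[k] for k in range(N)]
print(sizes)  # [112, 112, 112, 111, 111, 111, 111, 111]
```

The first `n_samples % N` batches receive one extra sample, so the workload is balanced to within one vector per worker.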


If everything went well, we now have N workers listening for tasks. Each of them has its own local context and data.

We can now put the task in the input queues and get the results back from the output queue.

In [15]:
timer = time.time()
for q in list_queue_in:
    q.put((addition, ()))
result = 0
for _ in range(N):
    child_process_ans = queue_out.get()
    result = ts.ckks_vector_from(context, child_process_ans) + result
computation_time = time.time() -timer
for q in list_queue_in:
    q.put('STOP')

logging.info("Computation done in %s seconds" % (computation_time))


08/18/2020 03:41:04 PM  [INFO    ]  Computation done in 1.0061945915222168 seconds


## Comparison of results

For such a quick operation, multiprocessing is not worth it. The initialization, because of the large amount of serialization involved, slows down the overall process.

In [16]:
logging.info("Without multiprocessing : %s seconds"% wct)
logging.info("With multiprocessing, computatiton: %s seconds"% computation_time)
logging.info("With multiprocessing, serialization: %s seconds"% ser_time)
logging.info("With multiprocessing, initialization: %s seconds"% init_time)
logging.info("With multiprocessing, Overall: %s seconds"% (ser_time+init_time+computation_time))



08/18/2020 03:41:04 PM  [INFO    ]  Without multiprocessing : 1.3562681674957275 seconds
08/18/2020 03:41:04 PM  [INFO    ]  With multiprocessing, computatiton: 1.0061945915222168 seconds
08/18/2020 03:41:04 PM  [INFO    ]  With multiprocessing, serialization: 191.8620753288269 seconds
08/18/2020 03:41:04 PM  [INFO    ]  With multiprocessing, initialization: 9.765893459320068 seconds
08/18/2020 03:41:04 PM  [INFO    ]  With multiprocessing, Overall: 202.6341633796692 seconds


In the end, however, the result is the same. In a real scenario, with a time-consuming operation (gradient computations), the speed-up is expected to approach the number of processes once the one-off data serialization and initialization costs are amortized.

In [17]:
np.abs(np.array(res.decrypt(secret_key))-np.array(result.decrypt(secret_key))).mean()

0.0
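
Using the timings logged in the comparison above, one can estimate how many repetitions of the operation are needed before the one-off serialization and initialization costs are recovered. This is only a back-of-the-envelope sketch: it assumes every repetition costs the same as the measured sum and ignores any per-iteration serialization of weights and gradients.

```python
import math

# Timings (in seconds) copied from the log output of this notebook.
wct = 1.3562681674957275               # sum without multiprocessing
computation_time = 1.0061945915222168  # sum with 8 workers
ser_time = 191.8620753288269           # one-off data serialization
init_time = 9.765893459320068          # one-off worker initialization

overhead = ser_time + init_time
gain_per_run = wct - computation_time

# Smallest number of runs for which multiprocessing is faster overall.
break_even_runs = math.ceil(overhead / gain_per_run)
print(round(overhead, 2), round(gain_per_run, 2), break_even_runs)
```

For the sum, the gain per run is small, so the break-even point is several hundred runs away. With a heavier operation such as a gradient computation, the gain per run is expected to be larger and the overhead to be amortized sooner.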