In [1]:
import numpy as np
import time
from multiprocessing import Pool, cpu_count
from functools import reduce
import pickle
import compress_pickle
import pandas as pd
from collections import defaultdict
import os
import sys

sys.path.append("../")

> The given result is produced by an AWS EC2 `c5.4xlarge` machine.

## 1. Global Variables

### 1.1 Configurable

In [2]:
# Aggregation task in a case
num_clients = 10  # how many clients' data to aggregate
element_bits = 16  # number of bits per element in a client's data

# Test cases
plaintext_lengths = [16384, 65536, 262144]
algorithms = ["Paillier", "Paillier+batch", "BFV", "BFV+batch", "CKKS", "CKKS+batch", "FLASHE"]

# Format of the results
cols = ['Plaintext', 'Ciphertext', 'Encryption', 'Decryption', 'Addition']

### 1.2 Predefined

In [3]:
# Global variables that one should not change its definition
# Affecting the efficiency
MAGIC_N_JOBS = 50

# To account for the worst-case overflow encountered during aggregation
additional_bits = int(np.ceil(np.log2(num_clients + 1)))
# After aggregation, number of bits per element in the server's sum
actual_element_bits = element_bits + additional_bits

## 2. Core Functions

### 2.1 Quantization Interface

In [4]:
from federatedml.secureprotol.jzf_aciq import ACIQ
from federatedml.secureprotol.jzf_quantize import \
    _static_quantize_padding, _static_unquantize_padding

def get_alpha_r_max(trainable_list, element_bits):
    local_min_list = []
    local_max_list = []
    size_list = []
    for idx, trainable in enumerate(trainable_list):
        local_min = []
        local_max = []
        for layer in trainable:
            local_min.append(np.amin(layer))
            local_max.append(np.amax(layer))

            if idx == 0:
                size_list.append(layer.size)

        local_min_list.append(np.array(local_min))
        local_max_list.append(np.array(local_max))

    local_min_list = np.array(local_min_list)
    local_max_list = np.array(local_max_list)

    min_list = np.amin(local_min_list, 0)
    max_list = np.amax(local_max_list, 0)

    n = len(trainable_list)
    aciq = ACIQ(element_bits)

    alpha_list = []
    r_max_list = []
    layer_cnt = 0
    for min, max in zip(min_list, max_list):
        alpha = aciq.get_alpha_gaus(min, max, size_list[layer_cnt])

        r_max = alpha * num_clients

        alpha_list.append(alpha)
        r_max_list.append(r_max)
        layer_cnt += 1

    return alpha_list, r_max_list

def quantize(trainable_list, element_bits):
    n = len(trainable_list)
    alpha_list, r_max_list = get_alpha_r_max(trainable_list,
                                            element_bits)
    quantized = []
    for _, trainable in enumerate(trainable_list):
        quantized_layers = []
        for idx, layer in enumerate(trainable):
            shape = layer.shape
            layer_flatten = layer.flatten()
            ret = _static_quantize_padding(layer_flatten,
                                           alpha_list[idx],
                                           element_bits,
                                           n)
            ret = np.array(ret).reshape(shape)
            quantized_layers.append(ret)
        quantized.append(np.array(quantized_layers))
    return np.array(quantized), np.array(alpha_list)

def unquantize(trainable, alpha_list, element_bits, num_clients):
    layers = []
    for idx, layer in enumerate(trainable):
        shape = layer.shape
        layer_flatten = layer.flatten()
        ret = _static_unquantize_padding(layer_flatten,
                                         alpha_list[idx],
                                         element_bits,
                                         num_clients)
        ret = np.array(ret).reshape(shape)
        layers.append(ret)

    layers = np.array(layers)
    return layers

### 2.2 Batching Interface

In [5]:
from federatedml.secureprotol.jzf_quantize import \
    _static_batching_padding, _static_unbatching_padding

### 2.3 Encryption Interface

#### 2.3.1 FLASHE

In [6]:
from federatedml.secureprotol.jzf_flashe import FlasheCipher
flashe = FlasheCipher(actual_element_bits)
flashe.set_num_clients(num_clients)
flashe.generate_prp_seed()
flashe.set_iter_index(0)
flashe.idx = 0

def flashe_encrypt(value):
    return flashe.encrypt(value)

def flashe_decrypt(value):
    flashe.set_idx_list(raw_idx_list=[0] * num_clients, mode="decrypt")
    return flashe.decrypt(value)

#### 2.3.2 Paillier

In [7]:
from federatedml.secureprotol.jzf_paillier import PaillierCipher
paillier = PaillierCipher()
paillier.generate_key(n_length=2048)

def paillier_encrypt(value):
    return paillier.encrypt(value)

def paillier_decrypt(value):
    return paillier.decrypt(value)

#### 2.3.3 BFV

In [8]:
from federatedml.secureprotol.jzf_bfv import BFVCipher
bfv = BFVCipher(p=128, m=2048)
# # bfv = BFVCipher(p=1964769281, m=8192, flagBatching=True)
bfv.generate_key()

def bfv_encrypt(value):
    return bfv.encrypt(value)

def bfv_decrypt(value):
    return bfv.decrypt(value)

from federatedml.secureprotol.jzf_bfv import BFVCipher
# p needs to be prime and p-1 must be multiple of 2*m
# bfv_2 = BFVCipher(p=65537, m=2048, flagBatching=True)
bfv_2 = BFVCipher(p=1964769281, m=8192, flagBatching=True)
bfv_2.generate_key()

def bfv_batch_encrypt(value):
    return bfv_2.encrypt(value)

def bfv_batch_decrypt(value):
    return bfv_2.decrypt(value)

#### 2.3.4 CKKS

In [9]:
from federatedml.secureprotol.jzf_ckks import CKKSCipher
ckks = CKKSCipher(8192, None, 2 ** 40)

def ckks_batch_encrypt(value):
    return ckks.encrypt(value)

def ckks_batch_decrypt(value):
    return ckks.decrypt(value)

def ckks_encrypt(value):
    return ckks.encrypt_no_batch(value)

def ckks_decrypt(value):
    return ckks.decrypt_no_batch(value)

### 2.4 Testing Logic

In [10]:
def test_an_algorithm(plaintext, quantized_plaintext, alphas, algorithm, plaintext_size, no_add=False):
    print(f'\tStart testing {algorithm}')
    
    ### Step 0/5: Configuration
    if algorithm == "FLASHE" or algorithm == "FLASHE+batch":
        num_bits_per_batch = flashe.int_bits
    elif algorithm == "Paillier" or algorithm == "Paillier+batch":
        num_bits_per_batch = (paillier.get_n() ** 2).bit_length()
    else:
        num_bits_per_batch = None
    
    begin_time = time.time()

    ### Step 1/5: Encryption
    if algorithm == "FLASHE":
        ciphertext = flashe_encrypt(quantized_plaintext)
    elif algorithm == "FLASHE+batch":
        shape = quantized_plaintext.shape
        ciphertext = _static_batching_padding(
            quantized_plaintext,
            num_bits_per_batch,
            element_bits,
            additional_bits
        )
        ciphertext = flashe_encrypt(quantized_plaintext)
    elif algorithm == "Paillier":
        ciphertext = paillier_encrypt(quantized_plaintext)
    elif algorithm == "Paillier+batch":
        shape = quantized_plaintext.shape
        quantized_plaintext = _static_batching_padding(
            quantized_plaintext,
            paillier.key_length,
            element_bits,
            additional_bits
        )
        ciphertext = paillier_encrypt(quantized_plaintext)
    elif algorithm == "BFV":
        shape = quantized_plaintext.shape
        ciphertext = bfv_encrypt(quantized_plaintext)
    elif algorithm == "BFV+batch":
        quantized_plaintext = quantized_plaintext.astype(np.int64)
        shape = quantized_plaintext.shape
        quantized_plaintext = quantized_plaintext.flatten()
        ciphertext = bfv_batch_encrypt(quantized_plaintext)
    elif algorithm == "CKKS":
        ciphertext = ckks_encrypt(plaintext)
    elif algorithm == "CKKS+batch":
        ciphertext = ckks_batch_encrypt(plaintext)
        
    encryption_time = time.time()
    encryption_duration = round(encryption_time - begin_time, 4)
    print(f'\t\tEncryption: {encryption_duration} sec')
    
    ### Step 2/5: Measure ciphertext size
    if algorithm == "FLASHE" or algorithm == "FLASHE+batch":
        compressed_ciphertext = compress_multi(ciphertext.flatten().astype(object), num_bits_per_batch)
    elif algorithm == "Paillier" or algorithm == "Paillier+batch":
        compressed_ciphertext = compress_multi(ciphertext.flatten().astype(object), num_bits_per_batch)
    elif algorithm == "BFV" or algorithm == "BFV+batch" or algorithm == "CKKS" or algorithm == "CKKS+batch":
        compressed_ciphertext = ciphertext
#     a_e_b = compress_pickle.dumps(a_e_c, 'bz2')
    compressed_ciphertext_bytes = pickle.dumps(compressed_ciphertext)
    ciphertext_size = len(compressed_ciphertext_bytes)
    
    print(f'\t\tPlaintext {plaintext_size} bytes')
    print(f'\t\tCiphertext {ciphertext_size} bytes')
    t_c = time.time()
    
    ### Step 3/5: Addition
    ciphertext_list = [ciphertext] * num_clients
    if algorithm == "FLASHE" or algorithm == "FLASHE+batch":
        mod = 1 << 128
        aggregated_ciphertext = reduce(lambda x, y: (x + y) % mod, ciphertext_list)
    elif algorithm == "Paillier" or algorithm == "Paillier+batch":
        mod = paillier.get_n() ** 2
        aggregated_ciphertext = reduce(lambda x, y: (x * y) % mod, ciphertext_list)  # * instead of + !
    elif algorithm == "BFV":
        if no_add:
            aggregated_ciphertext = ciphertext  # skip aggregation as indicated by no_add
        else:
            aggregated_ciphertext = bfv.sum(ciphertext_list)
    elif algorithm == "BFV+batch":
        aggregated_ciphertext = bfv_2.sum(ciphertext_list)
    elif algorithm == "CKKS":
        aggregated_ciphertext = ckks.sum_no_batch(ciphertext_list)
    elif algorithm == "CKKS+batch":
        aggregated_ciphertext = ckks.sum(ciphertext_list)
    
    addition_time = time.time()
    addition_duration = round(addition_time - encryption_time, 4)
    print(f'\t\tAddition: {addition_duration} sec')
    
    ### Step 4/5: Decryption
    if algorithm == "FLASHE":
        quantized_sum = flashe_decrypt(aggregated_ciphertext)
    elif algorithm == "FLASHE+batch":
        quantized_sum = flashe_decrypt(aggregated_ciphertext)
        quantized_sum = _static_unbatching_padding(
            quantized_sum, num_bits_per_batch, element_bits, additional_bits
        )
        quantized_sum = quantized_sum[:(int(np.prod(shape)))]
        quantized_sum = quantized_sum.reshape(shape)
    elif algorithm == "Paillier":
        quantized_sum = paillier_decrypt(aggregated_ciphertext)
    elif algorithm == "Paillier+batch":
        quantized_sum = paillier_decrypt(aggregated_ciphertext)
        quantized_sum = _static_unbatching_padding(
            quantized_sum, paillier.key_length, element_bits, additional_bits
        )
        quantized_sum = quantized_sum[:(int(np.prod(shape)))]
        quantized_sum = quantized_sum.reshape(shape)
    elif algorithm == "BFV":
        quantized_sum = bfv_decrypt(aggregated_ciphertext)
    elif algorithm == "BFV+batch":
        quantized_sum = bfv_batch_decrypt(aggregated_ciphertext)
        quantized_sum = np.array(quantized_sum, dtype=np.int64)
        quantized_sum = quantized_sum[:(int(np.prod(shape)))]
        quantized_sum = quantized_sum.reshape(shape)
    elif algorithm == "CKKS":
        quantized_sum = ckks_decrypt(aggregated_ciphertext)
    elif algorithm == "CKKS+batch":
        quantized_sum = ckks_batch_decrypt(aggregated_ciphertext)
    
    decryption_time = time.time()
    decryption_duration = round(decryption_time - addition_time, 4)
    print(f'\t\tDecryption: {decryption_duration} sec')
    
    ### Step 5/5: Dequantization
    if algorithm == "CKKS":
        final_sum = [quantized_sum.squeeze()]
    elif algorithm == "CKKS+batch":
        final_sum = [quantized_sum]
    else:
        if algorithm == "BFV" or algorithm == "BFV+batch":
            quantized_sum = np.array(quantized_sum).reshape(shape)
        print(f'\t\tQuantized sum (first 5): {quantized_sum[0][:5]}')
        print(f'\t\tQuantized sum (last 5): = {quantized_sum[0][-5:]}')
        final_sum = unquantize(quantized_sum, alphas, element_bits, num_clients)
    
    print(f'\t\tFinal sum (first 5): {[round(e, 4) for e in final_sum[0][:5]]}')
    print(f'\t\tFinal sum (last 5): = {[round(e, 4) for e in final_sum[0][-5:]]}')
    return encryption_duration, addition_duration, decryption_duration, ciphertext_size

## 3. Utilities

In [11]:
def rec_d():
    return defaultdict(rec_d)

def chunks_idx(l, n):
    d, r = divmod(len(l), n)
    for i in range(n):
        si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r)
        yield si, si+(d+1 if i < r else d)

def _compress(flatten_array, num_bits):
    res = 0
    l = len(flatten_array)
    for element in flatten_array:
        res <<= num_bits
        res += element

    return res, l

def compress_multi(flatten_array, num_bits):
    l = len(flatten_array)
    
    pool_inputs = []
    sizes = []
    pool = Pool(MAGIC_N_JOBS)
    
    for begin, end in chunks_idx(range(l), MAGIC_N_JOBS):
        sizes.append(end - begin)
        
        pool_inputs.append([flatten_array[begin:end], num_bits])

    pool_outputs = pool.starmap(_compress, pool_inputs)
    pool.close()
    pool.join()
    
    res = 0

    for idx, output in enumerate(pool_outputs):
        res +=  output[0] << (int(np.sum(sizes[idx + 1:])) * num_bits)
    
    num_bytes = (num_bits * l - 1) // 8 + 1
    res = res.to_bytes(num_bytes, 'big')
    return res, l

## 4. Evaluation

In [12]:
# Preprocessing: Data Generation and Quantization

plaintext_dict = {}
plaintext_size_dict = {}
quantized_plaintext_dict = {}
baseline_sum_dict = {}
for plaintext_length in plaintext_lengths:
    print(f'[CASE] {plaintext_length}')

    ### Generate a random array as plaintext
    begin_time = time.time()
    plaintext = np.random.random(plaintext_length)
    print(f'\tPlaintext (first 5): {[round(e, 4) for e in plaintext[:5]]}')  # for preview
    print(f'\tPlaintext (last 5): {[round(e, 4) for e in plaintext[-5:]]}')
    generate_time = time.time()
    print(f'\tGenerate data: {round(generate_time - begin_time, 4)} sec')

    ### Quantize the array
    ### (Need to reshape the data to fit the interface of `quantize`)
    quantized_plaintext, alphas = quantize([[plaintext]], element_bits)
    quantized_plaintext = quantized_plaintext[0]

    ### Measure the size of **the sum of** such quantized arrays as the plaintext size
    ### a) We do padding (using actual_elements_bits instead of elements_bits) 
    ### to be consistent with the ciphertext which can be the sum and overflow
    ### b) We do compression to measure the precise size
    big_number = compress_multi(quantized_plaintext[0], actual_element_bits) 
    #     a_b = compress_pickle.dumps(big_number, 'bz2')
    bytes_representation = pickle.dumps(big_number)
    plaintext_size = len(bytes_representation)

    ### Then calculate the baseline sum, using plaintext addition
    baseline_sum = (np.array(plaintext) * num_clients).tolist()

    ### Bookkeep what we have computed so far
    plaintext_dict[plaintext_length] = plaintext
    quantized_plaintext_dict[plaintext_length] = quantized_plaintext
    plaintext_size_dict[plaintext_length] = plaintext_size
    baseline_sum_dict[plaintext_length] = baseline_sum

[CASE] 16384
	Plaintext (first 5): [0.5032, 0.6625, 0.8493, 0.6363, 0.7524]
	Plaintext (last 5): [0.3907, 0.8782, 0.3004, 0.6378, 0.8549]
	Generate data: 0.0004 sec
[CASE] 65536
	Plaintext (first 5): [0.7883, 0.3875, 0.4672, 0.3907, 0.2929]
	Plaintext (last 5): [0.8693, 0.7826, 0.9076, 0.8279, 0.6174]
	Generate data: 0.0008 sec
[CASE] 262144
	Plaintext (first 5): [0.6993, 0.8836, 0.1541, 0.825, 0.1842]
	Plaintext (last 5): [0.4418, 0.0194, 0.418, 0.8075, 0.4643]
	Generate data: 0.0034 sec


In [13]:
# Actual experiments: Encryption, Aggregation, Decryption
result_dict = rec_d()
for algorithm in algorithms:
    for plaintext_length in plaintext_lengths:
        print(algorithm, plaintext_length)
        
        # Show the baseline results
        # and the actual results should be close to them
        baseline_sum = baseline_sum_dict[plaintext_length]
        print(f'\tBaseline sum (first 5): {[round(e, 2) for e in baseline_sum[:5]]}')
        print(f'\tBaseline sum (last 5): {[round(e, 2) for e in baseline_sum[-5:]]}')
        
        # Skip those experiments that are observed to render OOM errors
        no_add = False
        if plaintext_length > 65536:
            if algorithm == "BFV":
                continue
        if plaintext_length > 16384:
            if algorithm == "CKKS":
                continue
            if algorithm == "BFV":
                no_add = True

        # Unit test: Encryption -> Aggregation -> Decryption
        plaintext_size = plaintext_size_dict[plaintext_length]
        encrypt_time, aggregate_time, decrypt_time, ciphertext_size \
            = test_an_algorithm(
                plaintext=plaintext_dict[plaintext_length],
                quantized_plaintext=quantized_plaintext_dict[plaintext_length],
                alphas=alphas,
                algorithm=algorithm,
                plaintext_size=plaintext_size,
                no_add=no_add
            )

        # Bookkeeping the results
        result_dict[plaintext_length][algorithm]['Encryption'] = encrypt_time
        result_dict[plaintext_length][algorithm]['Addition'] = aggregate_time
        result_dict[plaintext_length][algorithm]['Decryption'] = decrypt_time
        result_dict[plaintext_length][algorithm]['Plaintext'] = plaintext_size
        result_dict[plaintext_length][algorithm]['Ciphertext'] = ciphertext_size

Paillier 16384
	Baseline sum (first 5): [5.03, 6.62, 8.49, 6.36, 7.52]
	Baseline sum (last 5): [3.91, 8.78, 3.0, 6.38, 8.55]
	Start testing Paillier
		Encryption: 23.8365 sec
		Plaintext 40976 bytes
		Ciphertext 8386576 bytes
		Addition: 6.7235 sec
		Decryption: 13.8273 sec
		Quantized sum (first 5): [226440 298120 327670 286340 327670]
		Quantized sum (last 5): = [175800 327670 135200 287020 327670]
		Final sum (first 5): [4.4377, 5.8425, 6.4216, 5.6116, 6.4216]
		Final sum (last 5): = [3.4453, 6.4216, 2.6496, 5.625, 6.4216]
Paillier 65536
	Baseline sum (first 5): [7.88, 3.87, 4.67, 3.91, 2.93]
	Baseline sum (last 5): [8.69, 7.83, 9.08, 8.28, 6.17]
	Start testing Paillier
		Encryption: 94.8211 sec
		Plaintext 163858 bytes
		Ciphertext 33546258 bytes
		Addition: 28.0184 sec
		Decryption: 55.1706 sec
		Quantized sum (first 5): [327670 186390 224740 187960 140890]
		Quantized sum (last 5): = [327670 327670 327670 327670 297010]
		Final sum (first 5): [6.4216, 3.6528, 4.4044, 3.6836, 2.76

The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
		Encryption: 0.3331 sec
		Plaintext 655378 bytes
		Ciphertext 27677454 bytes
		Addition: 0.9521 sec
		Decryption: 0.2311 sec
		Quantized sum: [array([6.99319265, 8.83625495, 1.54147346, ..., 4.18025343, 8.07500203,
       4.64332092])]
		Final sum (first 5): [6.9932, 8.8363, 1.5415, 8.2502, 1.8424]
		Final sum (last 5): = [4.4176, 0.194, 4.1803, 8.075, 4.6433]
FLASHE 16384
	Baseline sum (first 5): [5.03, 6.62, 8.49, 6.36, 7.52]
	Baseline sum (last 5): [3.91, 8.78, 3.0, 6.38, 8.55]
	Start testing FLASHE
		Encryption: 2.6336 sec
		Plaintext 40976 bytes
		Ciphertext 40976 bytes
		Addition: 7.1162 sec
		Decryption: 2.4022 sec
		Quantized sum (first 5): [226440 298120 327670 286340 327670]
		Quantized sum (last 5): = [175800 327670 135200 287020 327670]
		Final sum (first 5): [4.4377, 5

## 5. Summary of Results

In [14]:
save_path = os.path.join(os.getcwd(), 'big-table.bin')
pickle.dump(result_dict, open(save_path, 'wb'))

In [15]:
for plaintext_length in plaintext_lengths:
    print(f'[CASE] {plaintext_length}')
    
    table = {}
    for col in cols:
        l = []
        for algorithm in algorithms:
            # Skip those experiments that are observed to render OOM errors
            if plaintext_length > 16384:
                if algorithm == "CKKS":
                    continue
            if plaintext_length > 65536:
                if algorithm == "BFV":
                    continue
                    
            # Formatting the results
            if col in ['Plaintext', 'Ciphertext']:  # Size of objects
                kb = result_dict[plaintext_length][algorithm][col] / 1024
                if kb >= 1024:
                    mb = kb / 1024
                    if mb > 1024:
                        gb = mb / 1024
                        l.append('{:.2f} GB'.format(gb))
                    else:
                        l.append('{:.2f} MB'.format(mb))
                else:
                    l.append('{:.2f} KB'.format(kb))
            else:  # Time in seconds
                l.append('{:.2f} s'.format(result_dict[plaintext_length][algorithm][col]))
        table[col] = l

    if plaintext_length > 65536:
        data_frame = pd.DataFrame(data=table, index=algorithms[:2] + algorithms[3:4] + algorithms[5:])
    elif plaintext_length > 16384:
        data_frame = pd.DataFrame(data=table, index=algorithms[:4] + algorithms[5:])
    else:
        data_frame = pd.DataFrame(data=table, index=algorithms)
    print(data_frame)

[CASE] 16384
               Plaintext Ciphertext Encryption Decryption  Addition
Paillier        40.02 KB    8.00 MB    23.84 s    13.83 s    6.72 s
Paillier+batch  40.02 KB   96.49 KB     0.49 s     0.38 s    0.71 s
BFV             40.02 KB  513.09 MB    35.62 s    35.28 s    7.49 s
BFV+batch       40.02 KB    1.00 MB     1.15 s     1.14 s    0.01 s
CKKS            40.02 KB    6.60 GB    76.28 s    52.79 s  212.57 s
CKKS+batch      40.02 KB    1.65 MB     0.02 s     0.01 s    0.06 s
FLASHE          40.02 KB   40.02 KB     2.63 s     2.40 s    7.12 s
[CASE] 65536
                Plaintext Ciphertext Encryption Decryption Addition
Paillier        160.02 KB   31.99 MB    94.82 s    55.17 s  28.02 s
Paillier+batch  160.02 KB  385.92 KB     1.33 s     0.83 s   0.73 s
BFV             160.02 KB    2.00 GB   146.45 s   138.92 s   1.32 s
BFV+batch       160.02 KB    4.00 MB     1.33 s     1.25 s   0.05 s
CKKS+batch      160.02 KB    6.60 MB     0.08 s     0.06 s   0.22 s
FLASHE          160.02