In [1]:
import numpy as np
import time
from multiprocessing import Pool, cpu_count
from functools import reduce
import compress_pickle

In [2]:
### Configuration
# Vector lengths (number of elements) to benchmark.
test_cases = [16384, 65536, 262144]
# test_cases = [4096, 16384, 65536, 262144, 1048576]
# test_cases = [10000, 100000, 1000000]
# test_cases = [1024, 2048, 4096, 8192]
# test_cases = [100000]

# Encryption schemes to benchmark.
# BFV with batching has a size limit, see
# https://github.com/OpenMined/TenSEAL/issues/19#issuecomment-617067414
algos = ["FLASHE", "Paillier", "Paillier+batch", "BFV", "BFV+batch", "CKKS", "CKKS+batch"]
# algos = ["FLASHE+batch"]
# algos = ["BFV+batch"]
# algos = ["BFV", "BFV+batch"]
# algos = ["FLASHE", "FLASHE+batch"]

# Limits observed with 10 operands to aggregate:
# - BFV without batching: element_bits can be at most 61 (bounded by the size of a C long)
# - BFV with batching: element_bits can be at most 27 (possibly bounded by the
#   plaintext modulus 1964769281)
element_bits = 16
num_clients = 10
# Number of chunks (and worker processes) compress_multi splits its input into.
MAGIC_N_JOBS = 50
# Extra bits reserved on top of element_bits, presumably headroom for summing
# num_clients values. For a positive integer n, n.bit_length() equals
# ceil(log2(n + 1)), computed exactly without floating point.
additional_bits = int(num_clients).bit_length()
int_bits = element_bits + additional_bits

In [3]:
### Quantization
import sys
sys.path.append("../")
from federatedml.secureprotol.jzf_aciq import ACIQ
from federatedml.secureprotol.jzf_quantize import \
    _static_quantize_padding, _static_unquantize_padding

def get_alpha_r_max(trainable_list, element_bits, n_clients=None):
    """Compute per-layer ACIQ clipping thresholds shared by all trainables.

    For every layer position, the minimum and maximum are taken across all
    trainables. ``ACIQ(element_bits).get_alpha_gaus`` then turns them, plus
    the layer's element count, into a clipping threshold ``alpha``.

    Args:
        trainable_list: non-empty sequence of trainables, each a sequence of
            numpy arrays (layers). All trainables must have the same number
            of layers. Layer sizes are taken from the first trainable.
        element_bits: bit width handed to ``ACIQ``.
        n_clients: multiplier used for ``r_max``. Defaults to the
            notebook-level ``num_clients``, which the original implementation
            always used.

    Returns:
        (alpha_list, r_max_list): one ``alpha`` and one ``alpha * n_clients``
        per layer position.

    Raises:
        ValueError: if ``trainable_list`` is empty.
    """
    if n_clients is None:
        n_clients = num_clients
    if len(trainable_list) == 0:
        raise ValueError("trainable_list must contain at least one trainable")

    # rows = trainables, columns = layer positions
    local_min = np.array([[np.amin(layer) for layer in trainable]
                          for trainable in trainable_list])
    local_max = np.array([[np.amax(layer) for layer in trainable]
                          for trainable in trainable_list])
    min_list = np.amin(local_min, 0)
    max_list = np.amax(local_max, 0)
    size_list = [layer.size for layer in trainable_list[0]]

    aciq = ACIQ(element_bits)
    alpha_list = []
    r_max_list = []
    for layer_min, layer_max, size in zip(min_list, max_list, size_list):
        alpha = aciq.get_alpha_gaus(layer_min, layer_max, size)
        alpha_list.append(alpha)
        r_max_list.append(alpha * n_clients)
    return alpha_list, r_max_list

def quantize(trainable_list, element_bits):
    """Quantize every layer of every trainable with per-layer ACIQ thresholds.

    Args:
        trainable_list: sequence of trainables, each a sequence of numpy
            arrays (layers) with the same layer structure.
        element_bits: bit width used by the quantizer.

    Returns:
        (quantized, alphas): ``quantized`` stacks, per trainable, the
        quantized layers (each reshaped back to the layer's original shape).
        ``alphas`` holds the per-layer clipping thresholds needed by
        ``unquantize``.

    NOTE(review): the padding quantizer is given ``len(trainable_list)`` as
    its client count, while the notebook unquantizes with ``num_clients``.
    Confirm this asymmetry is intended.
    """
    n_trainables = len(trainable_list)
    alphas, _ = get_alpha_r_max(trainable_list, element_bits)

    def _quantize_layer(layer, alpha):
        # The quantizer works on flat vectors; restore the layer's shape afterwards.
        flat = _static_quantize_padding(layer.flatten(), alpha, element_bits, n_trainables)
        return np.array(flat).reshape(layer.shape)

    stacked = [
        np.array([_quantize_layer(layer, alphas[pos]) for pos, layer in enumerate(trainable)])
        for trainable in trainable_list
    ]
    return np.array(stacked), np.array(alphas)

def unquantize(trainable, alpha_list, element_bits, num_clients):
    """Invert ``_static_quantize_padding`` layer by layer.

    Args:
        trainable: sequence of quantized layers (numpy arrays).
        alpha_list: per-layer clipping thresholds, indexed like ``trainable``.
        element_bits: bit width the values were quantized with.
        num_clients: client count forwarded to the unquantizer.

    Returns:
        numpy array stacking the restored layers, each with its input shape.
    """
    restored = [
        np.array(
            _static_unquantize_padding(layer.flatten(), alpha_list[pos],
                                       element_bits, num_clients)
        ).reshape(layer.shape)
        for pos, layer in enumerate(trainable)
    ]
    return np.array(restored)

In [4]:
from federatedml.secureprotol.jzf_quantize import \
    _static_batching_padding, _static_unbatching_padding

In [5]:
def chunks_idx(l, n):
    """Yield ``n`` contiguous ``(start, stop)`` index pairs covering ``l``.

    Sizes differ by at most one: the first ``len(l) % n`` chunks get one
    extra element. When ``len(l) < n``, the trailing chunks are empty.
    """
    base, extra = divmod(len(l), n)
    start = 0
    for chunk in range(n):
        stop = start + base + (1 if chunk < extra else 0)
        yield start, stop
        start = stop

In [6]:
def _compress(flatten_array, num_bits):
    res = 0
    l = len(flatten_array)
    for element in flatten_array:
        res <<= num_bits
        res += element

    return res, l

In [7]:
def compress_multi(flatten_array, num_bits, n_jobs=None):
    """Pack a 1-D sequence of integers into a big-endian byte string, in parallel.

    The input is split into ``n_jobs`` contiguous chunks with ``chunks_idx``.
    Each chunk is packed by ``_compress`` in a worker process, and the
    partial results are shifted into place so the output equals packing the
    whole sequence at once. Elements should be non-negative and below
    ``2 ** num_bits``; otherwise neighbouring fields overlap and
    ``to_bytes`` may raise ``OverflowError``.

    Args:
        flatten_array: sliceable 1-D sequence of integers.
        num_bits: bits reserved per element.
        n_jobs: number of chunks / worker processes. Defaults to the
            notebook-level ``MAGIC_N_JOBS``.

    Returns:
        (packed_bytes, number_of_elements), where ``packed_bytes`` is
        ``ceil(num_bits * number_of_elements / 8)`` bytes long.
    """
    if n_jobs is None:
        n_jobs = MAGIC_N_JOBS
    total = len(flatten_array)

    bounds = list(chunks_idx(range(total), n_jobs))
    pool_inputs = [[flatten_array[begin:end], num_bits] for begin, end in bounds]

    # The context manager tears the workers down even if a chunk raises.
    with Pool(n_jobs) as pool:
        pool_outputs = pool.starmap(_compress, pool_inputs)
        pool.close()
        pool.join()

    packed = 0
    for (chunk_value, _), (_, end) in zip(pool_outputs, bounds):
        # The (total - end) elements after this chunk occupy the low-order bits.
        # int() guarantees a Python int, so to_bytes below is always available.
        packed += int(chunk_value) << ((total - end) * num_bits)

    num_bytes = (num_bits * total - 1) // 8 + 1
    return packed.to_bytes(num_bytes, 'big'), total

In [8]:
### Encryption
from federatedml.secureprotol.jzf_flashe import FlasheCipher
# FLASHE cipher shared by every FLASHE benchmark below.
# NOTE(review): the constructor argument 20 equals int_bits
# (element_bits + additional_bits) with the current configuration; confirm
# whether it is meant to track that value.
flashe = FlasheCipher(20)
flashe.set_num_clients(num_clients)
flashe.generate_prp_seed()
# Pin the iteration and client index to 0. NOTE(review): presumably so that
# encryption and decryption derive the same masks -- confirm in FlasheCipher.
flashe.set_iter_index(0)
flashe.idx = 0

def flashe_encrypt(value):
    """Encrypt ``value`` with the module-level FLASHE cipher."""
    return flashe.encrypt(value)

def flashe_decrypt(value):
    """Decrypt an aggregate of ``num_clients`` FLASHE ciphertexts.

    Registers ``num_clients`` client indices, all 0 (presumably matching
    ``flashe.idx = 0`` above), in decrypt mode before decrypting.
    """
    flashe.set_idx_list(raw_idx_list=[0] * num_clients, mode="decrypt")
    return flashe.decrypt(value)

from federatedml.secureprotol.jzf_paillier import PaillierCipher
# Paillier cipher; n_length=2048 requests a 2048-bit key. The benchmark
# aggregates its ciphertexts by multiplication modulo n ** 2.
paillier = PaillierCipher()
paillier.generate_key(n_length=2048)

def paillier_encrypt(value):
    """Encrypt ``value`` with the module-level Paillier cipher."""
    return paillier.encrypt(value)

def paillier_decrypt(value):
    """Decrypt ``value`` with the module-level Paillier cipher."""
    return paillier.decrypt(value)

from federatedml.secureprotol.jzf_bfv import BFVCipher
# BFV cipher for the non-batched benchmark. p is the plaintext modulus;
# NOTE(review): m is presumably the polynomial modulus degree -- confirm in BFVCipher.
bfv = BFVCipher(p=128, m=2048)
# # bfv = BFVCipher(p=1964769281, m=8192, flagBatching=True)
bfv.generate_key()

def bfv_encrypt(value):
    """Encrypt ``value`` with the non-batched BFV cipher."""
    return bfv.encrypt(value)

def bfv_decrypt(value):
    """Decrypt ``value`` with the non-batched BFV cipher."""
    return bfv.decrypt(value)

from federatedml.secureprotol.jzf_bfv import BFVCipher
# BFV cipher with batching enabled.
# p (the plaintext modulus) must be prime, and p-1 must be a multiple of 2*m.
# 1964769281 - 1 == 2 * 8192 * 119920, so it satisfies this for m=8192.
# bfv_2 = BFVCipher(p=65537, m=2048, flagBatching=True)
bfv_2 = BFVCipher(p=1964769281, m=8192, flagBatching=True)
bfv_2.generate_key()

def bfv_batch_encrypt(value):
    """Encrypt ``value`` with the batching BFV cipher."""
    return bfv_2.encrypt(value)

def bfv_batch_decrypt(value):
    """Decrypt ``value`` with the batching BFV cipher."""
    return bfv_2.decrypt(value)

# Bit length of the batching plaintext modulus (prints 31).
print(int(1964769281).bit_length())

from federatedml.secureprotol.jzf_ckks import CKKSCipher
# CKKS cipher used for both the batched and non-batched benchmarks.
# NOTE(review): the arguments presumably are (poly modulus degree, coefficient
# modulus sizes, global scale) -- confirm against CKKSCipher.
ckks = CKKSCipher(8192, None, 2 ** 40)

def ckks_batch_encrypt(value):
    """Encrypt ``value`` using CKKSCipher.encrypt (batched variant)."""
    return ckks.encrypt(value)

def ckks_batch_decrypt(value):
    """Decrypt ``value`` using CKKSCipher.decrypt (batched variant)."""
    return ckks.decrypt(value)

def ckks_encrypt(value):
    """Encrypt ``value`` using CKKSCipher.encrypt_no_batch."""
    return ckks.encrypt_no_batch(value)

def ckks_decrypt(value):
    """Decrypt ``value`` using CKKSCipher.decrypt_no_batch."""
    return ckks.decrypt_no_batch(value)

ModuleNotFoundError: No module named 'federatedml.secureprotol.affine'

In [None]:
# Bit length of n**2, the modulus Paillier ciphertexts are reduced by.
print(pow(paillier.get_n(), 2).bit_length())

In [None]:
def test_an_algo(a, a_q, alphas, algo, no_add=False, plaintext_bytes=None):
    """Benchmark one encryption scheme on a single input vector.

    Encrypts, serializes, homomorphically aggregates ``num_clients`` copies
    of the ciphertext, and decrypts. For the integer schemes it then
    unquantizes the result. Timings and sizes are printed along the way.

    Args:
        a: raw float data. Only the CKKS variants encrypt it directly.
        a_q: quantized data (``quantize(...)[0]`` in this notebook, i.e. one
            row per layer).
        alphas: per-layer clipping thresholds matching ``a_q``, used to unquantize.
        algo: one of "FLASHE", "FLASHE+batch", "Paillier", "Paillier+batch",
            "BFV", "BFV+batch", "CKKS", "CKKS+batch".
        no_add: for "BFV" only, skip the homomorphic summation. The decrypted
            value is then a single operand rather than the sum.
        plaintext_bytes: optional serialized plaintext size, printed next to
            the ciphertext size for comparison.

    Returns:
        (encryption_sec, addition_sec, decryption_sec, ciphertext_bytes)

    Raises:
        ValueError: if ``algo`` is not one of the supported schemes.
    """
    import pickle  # the notebook's module-level `import pickle` lives in a later cell

    supported = ("FLASHE", "FLASHE+batch", "Paillier", "Paillier+batch",
                 "BFV", "BFV+batch", "CKKS", "CKKS+batch")
    if algo not in supported:
        raise ValueError(f'unsupported algo {algo!r}; expected one of {supported}')

    print(f'\t{algo}')

    # Width of one ciphertext word. Only FLASHE and Paillier ciphertexts are
    # packed into a bit string; BFV/CKKS ciphertexts are serialized as-is.
    if algo in ("FLASHE", "FLASHE+batch"):
        cipher_bits = flashe.int_bits
    elif algo in ("Paillier", "Paillier+batch"):
        cipher_bits = (paillier.get_n() ** 2).bit_length()
    else:
        cipher_bits = None

    begin = time.perf_counter()
    ### perform encryption
    if algo == "FLASHE":
        a_e = flashe_encrypt(a_q)
    elif algo == "FLASHE+batch":
        shape = a_q.shape
        a_q = _static_batching_padding(a_q, cipher_bits, element_bits, additional_bits)
        a_e = flashe_encrypt(a_q)
    elif algo == "Paillier":
        a_e = paillier_encrypt(a_q)
    elif algo == "Paillier+batch":
        shape = a_q.shape
        a_q = _static_batching_padding(a_q, paillier.key_length, element_bits, additional_bits)
        a_e = paillier_encrypt(a_q)
    elif algo == "BFV":
        shape = a_q.shape
        a_e = bfv_encrypt(a_q)
    elif algo == "BFV+batch":
        a_q = a_q.astype(np.int64)
        shape = a_q.shape
        a_e = bfv_batch_encrypt(a_q.flatten())
    elif algo == "CKKS":
        a_e = ckks_encrypt(a)
    else:  # "CKKS+batch"
        a_e = ckks_batch_encrypt(a)

    t_e = time.perf_counter()
    print(f'\t\tEncryption: {t_e - begin} sec')

    ### measure ciphertext size
    if cipher_bits is not None:
        a_e_c = compress_multi(a_e.flatten().astype(object), cipher_bits)
    else:
        a_e_c = a_e
    l_e_c = len(pickle.dumps(a_e_c))

    if plaintext_bytes is not None:
        print(f'\t\tPlaintext {plaintext_bytes} bytes')
    print(f'\t\tCiphertext {l_e_c} bytes')
    t_c = time.perf_counter()

    ### perform addition
    operands = [a_e] * num_clients
    if algo in ("FLASHE", "FLASHE+batch"):
        # NOTE(review): 2**128 is presumably FLASHE's ciphertext modulus -- confirm.
        mod = 1 << 128
        a_a = reduce(lambda x, y: (x + y) % mod, operands)
    elif algo in ("Paillier", "Paillier+batch"):
        # Paillier adds plaintexts by multiplying ciphertexts modulo n**2.
        mod = paillier.get_n() ** 2
        a_a = reduce(lambda x, y: (x * y) % mod, operands)
    elif algo == "BFV":
        a_a = a_e if no_add else bfv.sum(operands)
    elif algo == "BFV+batch":
        a_a = bfv_2.sum(operands)
    elif algo == "CKKS":
        a_a = ckks.sum_no_batch(operands)
    else:  # "CKKS+batch"
        a_a = ckks.sum(operands)

    t_a = time.perf_counter()
    print(f'\t\tAddition: {t_a - t_c} sec')

    ### perform decryption
    if algo == "FLASHE":
        a_d = flashe_decrypt(a_a)
    elif algo == "FLASHE+batch":
        a_d = flashe_decrypt(a_a)
        a_d = _static_unbatching_padding(a_d, cipher_bits, element_bits, additional_bits)
        # Drop the batching padding and restore the original layout.
        a_d = a_d[:int(np.prod(shape))].reshape(shape)
    elif algo == "Paillier":
        a_d = paillier_decrypt(a_a)
    elif algo == "Paillier+batch":
        a_d = paillier_decrypt(a_a)
        a_d = _static_unbatching_padding(a_d, paillier.key_length, element_bits, additional_bits)
        a_d = a_d[:int(np.prod(shape))].reshape(shape)
    elif algo == "BFV":
        a_d = bfv_decrypt(a_a)
    elif algo == "BFV+batch":
        a_d = np.array(bfv_batch_decrypt(a_a), dtype=np.int64)
        a_d = a_d[:int(np.prod(shape))].reshape(shape)
    elif algo == "CKKS":
        a_d = ckks_decrypt(a_a)
    else:  # "CKKS+batch"
        a_d = ckks_batch_decrypt(a_a)

    t_d = time.perf_counter()
    print(f'\t\tDecryption: {t_d - t_a} sec')

    ### perform unquantization
    if algo in ("CKKS", "CKKS+batch"):
        # CKKS encrypted the raw floats, so there is nothing to unquantize.
        a_u = [a_d]
    else:
        if algo in ("BFV", "BFV+batch"):
            a_d = np.array(a_d).reshape(shape)
        print(f'\t\ta_d = {a_d[0][:5]}')
        print(f'\t\ta_d = {a_d[0][-5:]}')
        a_u = unquantize(a_d, alphas, element_bits, num_clients)

    print(f'\t\ta * {num_clients} = {a_u[0][:5]}')
    print(f'\t\ta * {num_clients} = {a_u[0][-5:]}')

    return t_e - begin, t_a - t_c, t_d - t_a, l_e_c

In [None]:
### for file saving
from collections import defaultdict
import os
import pickle

def rec_d():
    """Build an auto-vivifying dictionary.

    Reading a missing key inserts and returns another ``rec_d()``, so
    ``result_dict[case][algo][metric] = value`` works without creating the
    intermediate levels first.
    """
    nested = defaultdict(rec_d)
    return nested

# Nested benchmark results: result_dict[case][algo][metric] -> value
result_dict = rec_d()

# Results are pickled into the current working directory.
save_path = os.path.join(os.getcwd(), 'big-table.bin')

In [None]:
a_dict = {}
a_q_dict = {}
alphas_dict = {}
l_c_dict = {}
expected_dict = {}
for case in test_cases:
    print(f'[CASE] {case}')
    begin = time.time()

    ### generate data
    a = np.random.random(case)
    print(f'\ta = {a[:5]}')

    t_g = time.time()
    print(f'\tGenerate data: {t_g - begin} sec')

    # quantize() takes a list of trainables, each a list of layers: here a
    # single trainable holding a single layer.
    a_q, alphas = quantize([[a]], element_bits)
    a_q = a_q[0]
    a_c = compress_multi(a_q[0], int_bits)
    #     a_b = compress_pickle.dumps(a_c, 'bz2')
    a_b = pickle.dumps(a_c)
    l_c = len(a_b)

    l_c_dict[case] = l_c

    a_b_float = pickle.dumps(a)
    l_c_float = len(a_b_float)

    expected = (np.array(a) * num_clients).tolist()

    a_q_dict[case] = a_q
    # The thresholds depend on this case's data; keep them per case instead
    # of reusing whatever `alphas` the last iteration left behind.
    alphas_dict[case] = alphas
    a_dict[case] = a
    expected_dict[case] = expected


for algo in algos:
    for case in test_cases:
        print(algo, case)
        expected = expected_dict[case]
        print(f'\ta * {num_clients} = {expected[:5]}')
        print(f'\ta * {num_clients} = {expected[-5:]}')

        # Unbatched BFV and CKKS are only run on the smaller inputs; unbatched
        # BFV also skips the summation on mid-sized inputs.
        if algo == "BFV" and case > 65536:
            continue
        if algo == "CKKS" and case > 16384:
            continue
        no_add = algo == "BFV" and case > 16384

        t_e, t_a, t_d, l_e_c = test_an_algo(a_dict[case], a_q_dict[case], alphas_dict[case], algo, no_add)

        result_dict[case][algo]['t_e'] = t_e
        result_dict[case][algo]['t_a'] = t_a
        result_dict[case][algo]['t_d'] = t_d
        result_dict[case][algo]['l_c'] = l_c_dict[case]
        result_dict[case][algo]['l_e_c'] = l_e_c

In [None]:
# Row order of the results tables.
displayed_algos = [
    "Paillier", "Paillier+batch",
    "BFV", "BFV+batch",
    "CKKS", "CKKS+batch",
    "FLASHE",
]

In [None]:
# Vector lengths shown in the results tables (2**14, 2**16, 2**18 elements).
displayed_test_cases = [2 ** 14, 2 ** 16, 2 ** 18]
# displayed_test_cases = [16384, 65536, 262144, 1048576]

In [None]:
import pandas as pd

cols = {
    'l_c' : 'Plaintext',
    'l_e_c' : 'Ciphertext',
    't_e' : 'Encryption',
    't_d' : 'Decryption',
    't_a' : 'Addition'
}


def _is_skipped(case, algo):
    """Return True for (case, algo) pairs the benchmark loop does not run.

    Must stay in sync with the benchmark loop: unbatched CKKS is only run up
    to 16384 elements and unbatched BFV only up to 65536.
    """
    return (algo == "CKKS" and case > 16384) or (algo == "BFV" and case > 65536)


def _format_bytes(num_bytes):
    """Render a byte count as '<x.xx> KB', '<x.xx> MB' or '<x.xx> GB'."""
    kb = num_bytes / 1024
    if kb < 1024:
        return '{:.2f} KB'.format(kb)
    mb = kb / 1024
    if mb > 1024:
        return '{:.2f} GB'.format(mb / 1024)
    return '{:.2f} MB'.format(mb)


for case in displayed_test_cases:
    print(f'[CASE] {case}')
    # Derive the row labels from the same filter as the rows, so the index
    # cannot drift out of sync with the data if displayed_algos is reordered.
    shown_algos = [algo for algo in displayed_algos if not _is_skipped(case, algo)]
    t = {}
    for c, label in cols.items():
        cells = []
        for algo in shown_algos:
            value = result_dict[case][algo][c]
            if c.startswith('l_'):  # sizes in bytes
                cells.append(_format_bytes(value))
            else:                   # durations in seconds
                cells.append('{:.2f} s'.format(value))
        t[label] = cells
    d = pd.DataFrame(data=t, index=shown_algos)
    print(d)
    
#     print('C/P:')
#     if case > 65536:
#         for algo in displayed_algos[:2] + displayed_algos[3:]:
#             print(f'{algo} {result_dict[case][algo]["l_e_c"] / result_dict[case][algo]["l_c"]}')
#     else:
#         for algo in displayed_algos:
#             print(f'{algo} {result_dict[case][algo]["l_e_c"] / result_dict[case][algo]["l_c"]}')

In [None]:
# Persist the benchmark results. The `with` block closes (and flushes) the
# file; the original left the handle open.
with open(save_path, 'wb') as results_file:
    pickle.dump(result_dict, results_file)

In [None]:
# Ad-hoc ratios printed one per line.
# NOTE(review): what these constants stand for is not documented here.
print(128 / 20, (128 / 5) / 20, 4096 / 20, (4096 / 85) / 20, sep='\n')