In [1]:
from ckks.ckks_parameters import CKKSParameters
from ckks.ckks_key_generator import CKKSKeyGenerator
from util.polynomial import Polynomial
from ckks.ckks_decryptor import CKKSDecryptor
from ckks.ckks_encryptor import CKKSEncryptor
from ckks.ckks_evaluator import CKKSEvaluator
from util.plaintext import Plaintext
import numpy as np

# Test 1: CKKS Roundrip

In [2]:
# Test setup
from ckks.ckks_encoder import CKKSEncoder


poly_degree = 4
ciph_modulus = 1 << 40
big_modulus = 1 << 1200 # Used for bootstrapping
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)
message = [4.0 + 0j, 3 + 0j]
print("Message:", message)


encoder = CKKSEncoder(params) # Contains encode and decode functions
poly = encoder.encode(message, params.scaling_factor)
# print("Encoded polynomial:", str(poly))


key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key


encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly = encryptor.encrypt(poly)
# print("Encrypted polynomial:", str(encrypted_poly) )

decryptor = CKKSDecryptor(params, secret_key)
decrypted_poly = decryptor.decrypt(encrypted_poly)
# print("Decrypted polynomial:", str(decrypted_poly) )

decoded_message = encoder.decode(poly)
print("Decoded message:", decoded_message)


assert len(decoded_message) == len(message), "Decoded message has incorrect length"
print("Decoded message matches original within tolerance:", np.allclose(decoded_message, message, atol=1e-1, rtol=1e-1))


Message: [(4+0j), (3+0j)]
Decoded message: [(3.999999998686854+6.585445244677857e-10j), (3.000000001313146-6.585445244677857e-10j)]
Decoded message matches original within tolerance: True


# Test 2: Addition Is Homomorphic

In [None]:
# Test setup

poly_degree = 4
ciph_modulus = 1 << 40
big_modulus = 1 << 1200 # Used for bootstrapping
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)
message1 = [4.0 + 0j, 3 + 0j]
print("Message1:", message)
message2 = [4.0 + 0j, 3 + 0j]
print("Message2:", message)



key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key

encoder = CKKSEncoder(params) # Contains encode and decode functions
plain_poly1 = encoder.encode(message1, params.scaling_factor)
plain_poly2 = encoder.encode(message2, params.scaling_factor)


encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly1 = encryptor.encrypt(plain_poly1)
encrypted_poly2 = encryptor.encrypt(plain_poly2)

evaluator = CKKSEvaluator(params)
encrypted_message_sum = evaluator.add(encrypted_poly1, encrypted_poly2)

decryptor = CKKSDecryptor(params, secret_key)
decrypted_message_sum = decryptor.decrypt(encrypted_message_sum)

decoded_message_sum = encoder.decode(decrypted_message_sum)

true_message_sum = [m1 + m2 for m1, m2 in zip(message1, message2)]
print("True message sum:", true_message_sum)


print("Decoded message Sum:", str(decoded_message_sum) )

assert len(decoded_message_sum) == len(true_message_sum), "Decoded message has incorrect length"
print("Addition is homomorphic:", np.allclose(decoded_message_sum, true_message_sum,  rtol=1e-2, atol=1e-1))  

Message1: [(4+0j), (3+0j)]
Message2: [(4+0j), (3+0j)]
True message sum: [(8+0j), (6+0j)]
Decoded message Sum: [(7.999999999894898+1.5898670713276886e-09j), (6.000000003830392+2.7277807790326847e-10j)]
Addition is homomorphic: True


# Test 3: Multiplication Is Homomorphic

In [None]:
# Test setup


poly_degree = 4
# NOTE: ciph_modulus of 1 << 40 is too small and produces innacurate results
ciph_modulus = 1 << 800
big_modulus = 1 << 1200
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)
message1 = [4.0 + 0j, 3 + 0j]
print("Message1:", message)
message2 = [4.0 + 0j, 3 + 0j]
print("Message2:", message)


poly_product = plain_poly1.poly.multiply_naive(plain_poly2.poly) # NOTE: for some reason modding using ciph_modulus doesnt work
plain_poly_product = Plaintext(poly_product, params.scaling_factor)

key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key
relin_key = key_generator.relin_key

encoder = CKKSEncoder(params) # Contains encode and decode functions
plain_poly1 = encoder.encode(message1, params.scaling_factor)
plain_poly2 = encoder.encode(message2, params.scaling_factor)



encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly1 = encryptor.encrypt(plain_poly1)
encrypted_poly2 = encryptor.encrypt(plain_poly2)

evaluator = CKKSEvaluator(params)
encrypted_poly_product = evaluator.multiply(encrypted_poly1, encrypted_poly2, relin_key)

decryptor = CKKSDecryptor(params, secret_key)
decrypted_poly_product = decryptor.decrypt(encrypted_poly_product)

decoded_message_product = encoder.decode(decrypted_poly_product)
print("Decoded message product:", str(decoded_message_product) )

true_message_product = [m1 * m2 for m1, m2 in zip(message1, message2)]
print("True message product:", true_message_product)

assert len(decoded_message_product) == len(true_message_product), "Decoded message has incorrect length"
print("Multiplication is homomorphic:", np.allclose(decoded_message_product, true_message_product,  rtol=1e-2, atol=1e-1))  

Message1: [(4+0j), (3+0j)]
Message2: [(4+0j), (3+0j)]
Decoded message product: [(15.99999999212901+1.3170890154554082e-08j), (9.00000000590324-9.878167618951328e-09j)]
True message product: [(16+0j), (9+0j)]
Multiplication is homomorphic: True
