In [1]:
import numpy as np

from ckks.ckks_decryptor import CKKSDecryptor
from ckks.ckks_encoder import CKKSEncoder
from ckks.ckks_encryptor import CKKSEncryptor
from ckks.ckks_evaluator import CKKSEvaluator
from ckks.ckks_key_generator import CKKSKeyGenerator
from ckks.ckks_parameters import CKKSParameters
from util.plaintext import Plaintext
from util.polynomial import Polynomial

# Test 1: CKKS Roundtrip

In [2]:
# Test 1 setup: encode -> encrypt -> decrypt -> decode a constant message and
# check that the round trip recovers it within tolerance.
#
# NOTE(review): the saved output of this cell ends in
#   TypeError: can only concatenate list (not "Polynomial") to list
# raised from inside the ckks/util library, right after a
# "Polynomial input 1 mod ..." debug print. Tests 2 and 3 (poly_degree = 8192)
# show neither the print nor the error. The root cause is not visible in this
# notebook — TODO investigate.

poly_degree = 1024  # Must be a power of 2
ciph_modulus = 1 << 40
big_modulus = 1 << 1200  # Used for bootstrapping
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)

# Message length must be half of poly_degree
message = [4.0 + 0j] * (poly_degree // 2)
print("Message length:", len(message))
print("First 5 elements:", message[:5])

encoder = CKKSEncoder(params)  # Contains encode and decode functions
poly = encoder.encode(message, params.scaling_factor)

key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key

encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly = encryptor.encrypt(poly)

decryptor = CKKSDecryptor(params, secret_key)
decrypted_poly = decryptor.decrypt(encrypted_poly)

decoded_message = encoder.decode(decrypted_poly)
print("Decoded message length:", len(decoded_message))
print("First 5 elements of decoded message:", decoded_message[:5])

assert len(decoded_message) == len(message), "Decoded message has incorrect length"
# Print the result for the reader, then assert so a lossy round trip fails the
# cell (and Restart & Run All) instead of silently printing False.
roundtrip_ok = np.allclose(decoded_message, message, atol=1e-1, rtol=1e-1)
print("Decoded message matches original within tolerance:", roundtrip_ok)
assert roundtrip_ok, "Decoded message does not match original within tolerance"


Message length: 512
First 5 elements: [(4+0j), (4+0j), (4+0j), (4+0j), (4+0j)]
Polynomial input 1 mod 576460752303439873:  465615904642905654x^1023 + 378109183894462655x^1022 + 188630342986866802x^1021 + 528902832412567865x^1020 + 172952820719485409x^1019 + 10754387630094075x^1018 + 499864752076264351x^1017 + 518591562641816106x^1016 + 176657881568950763x^1015 + 260477528354086788x^1014 + 277908825274068639x^1013 + 372910114468142086x^1012 + 471415748829626441x^1011 + 282302523286423767x^1010 + 484451770808802479x^1009 + 312885459737938868x^1008 + 485450541436525142x^1007 + 86210424759762056x^1006 + 138136348395079888x^1005 + 43625828474445928x^1004 + 231737145844734705x^1003 + 402526781500404069x^1002 + 32138205121227308x^1001 + 154167235312493128x^1000 + 120382304306489049x^999 + 290054855328289300x^998 + 341712784656520137x^997 + 426309148761537395x^996 + 18782368767733864x^995 + 130212650434312798x^994 + 30676628200849678x^993 + 119453220045079778x^992 + 118193248430626075x^991 + 5

TypeError: can only concatenate list (not "Polynomial") to list

# Test 2: Addition Is Homomorphic

In [None]:
# Test 2 setup: encrypt two constant messages, add the ciphertexts with
# evaluator.add, and check that the decrypted sum equals the element-wise sum
# of the plaintext messages within tolerance.

poly_degree = 8192  # Must be a power of 2
ciph_modulus = 1 << 40
big_modulus = 1 << 1200  # Used for bootstrapping
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)

# Message length must be half of poly_degree
message1 = [4.0 + 0j] * (poly_degree // 2)
message2 = [3.0 + 0j] * (poly_degree // 2)
print("Message1 length:", len(message1))
print("Message2 length:", len(message2))
print("First 5 elements of message1:", message1[:5])
print("First 5 elements of message2:", message2[:5])

key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key

encoder = CKKSEncoder(params)  # Contains encode and decode functions
plain_poly1 = encoder.encode(message1, params.scaling_factor)
plain_poly2 = encoder.encode(message2, params.scaling_factor)

encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly1 = encryptor.encrypt(plain_poly1)
encrypted_poly2 = encryptor.encrypt(plain_poly2)

evaluator = CKKSEvaluator(params)
encrypted_message_sum = evaluator.add(encrypted_poly1, encrypted_poly2)

decryptor = CKKSDecryptor(params, secret_key)
decrypted_message_sum = decryptor.decrypt(encrypted_message_sum)

decoded_message_sum = encoder.decode(decrypted_message_sum)

# Reference result computed on the plaintexts, slot by slot.
true_message_sum = [m1 + m2 for m1, m2 in zip(message1, message2)]
print("True message sum length:", len(true_message_sum))
print("First 5 elements of true sum:", true_message_sum[:5])
print("First 5 elements of decoded sum:", decoded_message_sum[:5])

assert len(decoded_message_sum) == len(true_message_sum), "Decoded message has incorrect length"
# Print the result for the reader, then assert so a mismatch fails the cell
# instead of silently printing False.
addition_ok = np.allclose(decoded_message_sum, true_message_sum, rtol=1e-2, atol=1e-1)
print("Addition is homomorphic:", addition_ok)
assert addition_ok, "Decrypted sum does not match the plaintext sum within tolerance"

Message1 length: 4096
Message2 length: 4096
First 5 elements of message1: [(4+0j), (4+0j), (4+0j), (4+0j), (4+0j)]
First 5 elements of message2: [(3+0j), (3+0j), (3+0j), (3+0j), (3+0j)]
True message sum length: 4096
First 5 elements of true sum: [(7+0j), (7+0j), (7+0j), (7+0j), (7+0j)]
First 5 elements of decoded sum: [(6.999998500662044-6.755260389074877e-07j), (7.000005640644219-2.998539399498932e-06j), (6.999999784578556-4.941830954177475e-06j), (7.000001408346991-3.117330884167564e-06j), (6.9999968471263045+5.778340193867105e-07j)]
Addition is homomorphic: True


# Test 3: Multiplication Is Homomorphic

In [None]:
# Test 3 setup: encrypt two constant messages, multiply the ciphertexts with
# evaluator.multiply (which takes relin_key), and check that the decrypted
# product equals the element-wise product of the plaintext messages within
# tolerance.
#
# Every input used below is created in this cell, so it runs correctly on a
# fresh kernel and does not pick up plain_poly1/plain_poly2 from Test 2.

poly_degree = 8192  # Must be a power of 2
# NOTE: a ciph_modulus of 1 << 40 is too small and produces inaccurate results
ciph_modulus = 1 << 800
big_modulus = 1 << 1200
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)

# Message length must be half of poly_degree
message1 = [4.0 + 0j] * (poly_degree // 2)
message2 = [3.0 + 0j] * (poly_degree // 2)
print("Message1 length:", len(message1))
print("Message2 length:", len(message2))
print("First 5 elements of message1:", message1[:5])
print("First 5 elements of message2:", message2[:5])

key_generator = CKKSKeyGenerator(params)
public_key = key_generator.public_key
secret_key = key_generator.secret_key
relin_key = key_generator.relin_key  # passed to evaluator.multiply below

encoder = CKKSEncoder(params)  # Contains encode and decode functions
plain_poly1 = encoder.encode(message1, params.scaling_factor)
plain_poly2 = encoder.encode(message2, params.scaling_factor)

encryptor = CKKSEncryptor(params, public_key, secret_key)
encrypted_poly1 = encryptor.encrypt(plain_poly1)
encrypted_poly2 = encryptor.encrypt(plain_poly2)

evaluator = CKKSEvaluator(params)
encrypted_poly_product = evaluator.multiply(encrypted_poly1, encrypted_poly2, relin_key)

decryptor = CKKSDecryptor(params, secret_key)
decrypted_poly_product = decryptor.decrypt(encrypted_poly_product)

decoded_message_product = encoder.decode(decrypted_poly_product)
print("Decoded message product length:", len(decoded_message_product))
print("First 5 elements of decoded product:", decoded_message_product[:5])

# Reference result computed on the plaintexts, slot by slot.
true_message_product = [m1 * m2 for m1, m2 in zip(message1, message2)]
print("True message product length:", len(true_message_product))
print("First 5 elements of true product:", true_message_product[:5])

assert len(decoded_message_product) == len(true_message_product), "Decoded message has incorrect length"
# Print the result for the reader, then assert so a mismatch fails the cell
# instead of silently printing False.
multiplication_ok = np.allclose(decoded_message_product, true_message_product, rtol=1e-2, atol=1e-1)
print("Multiplication is homomorphic:", multiplication_ok)
assert multiplication_ok, "Decrypted product does not match the plaintext product within tolerance"

Message1 length: 4096
Message2 length: 4096
First 5 elements of message1: [(4+0j), (4+0j), (4+0j), (4+0j), (4+0j)]
First 5 elements of message2: [(3+0j), (3+0j), (3+0j), (3+0j), (3+0j)]
Decoded message product length: 4096
First 5 elements of decoded product: [(12.000017330717952+1.8000333968012113e-05j), (12.000018105488307-2.2941719496983527e-05j), (11.999992375930299-5.300478966964159e-06j), (12.00003150215366-3.5393945333136374e-05j), (11.999979730315504+2.6520856035226214e-06j)]
True message product length: 4096
First 5 elements of true product: [(12+0j), (12+0j), (12+0j), (12+0j), (12+0j)]
Multiplication is homomorphic: True
