# CKKS-basic #3
# Multiplication

# 今回の目標

**1. CKKS 暗号文の掛け算を理解する**

# 前処理

In [1]:
import numpy as np
from numpy.polynomial import Polynomial
import random

# Parameters

In [2]:
scale = 64
polynomial_moduls_degree = 4
ciphertext_level = 1
ciphertext_modulus_list = [4129, 73]
ciphertext_modulus = 4129 * 73
stddev = 2.0
level1_ciphertext_modulus_inv = 905
level1_prime = 301409
level1_prime_inverse = 37677
level1_relinkey_modulus = ciphertext_modulus * level1_prime

# Subroutine

In [3]:
def coeff_encode(
    z: list,
    scale: int,
) -> Polynomial:
    """Converts a complex vector into a integer coefficient polynomial."""
    coef = []

    for z_i in z:
        # Step1
        scaled_zi = scale * z_i
        
        # Step2
        coef.append(np.round(np.real(scaled_zi)).astype(int))
        coef.append(np.round(np.imag(scaled_zi)).astype(int))

    return Polynomial(coef)


def coeff_decode(
    p: Polynomial,
    scale: int,
) -> np.array:
    """Converts a integer coefficient polynomial into a complex vector."""
    """Inverse operation of coeff_encode"""
    rescaled_p = []

    for i in range(len(p.coef) // 2):
        # Step1 and Step2
        rescaled_p.append(p.coef[2 * i] / scale + 1j * (p.coef[2 * i + 1]) / scale)

    return rescaled_p


def poly_mul(
    poly1: Polynomial, 
    poly2: Polynomial
) -> Polynomial:
    """polynomial multiplication with negative cyclic"""

    mult_result = np.convolve(poly1.coef, poly2.coef)
    n = len(poly1)
    result = np.zeros(n)
    for i in range(n):
        if i + n < len(mult_result):
            result[i] = mult_result[i] - mult_result[i + n]
        else:
            result[i] = mult_result[i]
    return Polynomial(result)


def rem_ciphertext_modulus(
    poly: Polynomial, 
    ciphertext_modulus: int
) -> Polynomial:
    """convert the range of the plaintext into the ciphertext"""
    """reduce each polynomial coefficient modulo the ciphertext modulus"""
    
    result = []
    n = len(poly)
    
    for coef in poly.coef:
        result.append(
            coef % ciphertext_modulus
        )
    
    return Polynomial(result)


def convert_plaintext_modulus(
    poly: Polynomial, 
    ciphertext_modulus: int
) -> Polynomial:
    """convert the range of the ciphertext [0, q - 1] into the plaintext [- q/2, q/2]"""
    
    result = []
    
    for coef in poly.coef:
        if coef > ciphertext_modulus / 2:
            result.append(
                coef - ciphertext_modulus
            )
        else:
            result.append(
                coef
            )
    
    return Polynomial(result)


def generate_secret_key(
    polynomial_moduls_degree: int
) -> Polynomial:
    """generate a secret key"""

    secret_key = []
    for _ in range(polynomial_moduls_degree):
        secret_key.append( 
            random.randint(0, 1)
        )
    
    return Polynomial(secret_key)

def generate_public_key(
    secret_key: Polynomial, 
    polynomial_modulus_degree: int, 
    ciphertext_modulus: int, 
    stddev: int, 
) -> list[Polynomial]:
    """generate a public key."""

    # Step 1
    """generate an rlwe sample."""
    rlwe_sample = []
    for _ in range(polynomial_modulus_degree):
        rlwe_sample.append( 
            random.randint(0, ciphertext_modulus)
        )
    rlwe_sample = Polynomial(np.array(rlwe_sample))

    # Step 2
    """generate a noise."""
    noise = []
    for _ in range(polynomial_modulus_degree):
        noise.append( 
            random.randint(0, stddev)
        )
    noise = Polynomial(np.array(noise))

    public_key = []
    
    # Step 3
    """compute b"""
    ctxt_b = poly_mul(rlwe_sample, secret_key) + noise
    
    public_key.append(
        rem_ciphertext_modulus(
            rlwe_sample, 
            ciphertext_modulus
        )
    )
    public_key.append(
        rem_ciphertext_modulus(
            ctxt_b, 
            ciphertext_modulus
        )
    )

    return public_key


def encrypt(
    public_key: list[Polynomial], 
    plaintext: Polynomial, 
    ciphertext_modulus
) -> Polynomial:
    """transform a plaintext into a ciphertext"""
    
    ciphertext = []

    ciphertext.append(
        rem_ciphertext_modulus(
            public_key[0], 
            ciphertext_modulus
        )
    )
    ciphertext.append(
        rem_ciphertext_modulus(
            public_key[1] + plaintext, 
            ciphertext_modulus
        )
    )

    return ciphertext


def decrypt(
    secret_key: Polynomial, 
    ciphertext: list[Polynomial], 
    ciphertext_modulus: int
) -> Polynomial:
    """transform a ciphertext into a plaintext"""

    decrypted_ciphertext = ciphertext[1] - poly_mul(ciphertext[0], secret_key)
    
    decrypted_ciphertext = rem_ciphertext_modulus(
        decrypted_ciphertext, 
        ciphertext_modulus
    )
    
    decrypted_ciphertext = convert_plaintext_modulus(
        decrypted_ciphertext, 
        ciphertext_modulus
    )
    
    return decrypted_ciphertext


def add(
    ciphertext1: list[Polynomial], 
    ciphertext2: list[Polynomial]
) -> list[Polynomial]:
    """add ciphertexts"""
    ciphertext = []

    for i in range(len(ciphertext1)):   
        ciphertext.append(
            rem_ciphertext_modulus(
                ciphertext1[i] + ciphertext2[i], 
                ciphertext_modulus
            )
        )

    return ciphertext

# 1. Ciphertext Multiplication

## 暗号文の掛け算の概要

再線型化鍵（Relinearization Key, RelinKey）が必要

1. 乗算結果を3成分の暗号文で表現する
1. 3成分の暗号文を RelinKey を用いて、2成分にする
1. Rescale 処理を行い、暗号文のレベルを1下げる

## 1.1 Ciphertext Multiplication wo/RelinKey Algorithm

In [4]:
def mul_without_relinkey(
    ciphertext1: list[Polynomial], 
    ciphertext2: list[Polynomial], 
    ciphertext_modulus: int, 
) -> list[Polynomial]:
    """multiply ciphertext without relin_key"""
    result = []
    
    # Step 1
    result.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[0]
            ), 
            ciphertext_modulus
        )
    )
    result.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[1]
            ) + poly_mul(
                ciphertext1[1], 
                ciphertext2[0]
            ), 
            ciphertext_modulus
        )
    )
    result.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[1], 
                ciphertext2[1]
            ), 
            ciphertext_modulus
        )
    )
    
    return result


def coeff_decode_without_rescale(
    p: Polynomial,
    scale: int,
) -> np.array:
    """Converts a integer coefficient polynomial into a complex vector."""
    """Inverse operation of coeff_encode"""
    rescaled_p = []

    for i in range(len(p.coef) // 2):
        # Step1 and Step2
        rescaled_p.append(p.coef[2 * i] / (scale * scale) + 1j * (p.coef[2 * i + 1]) / (scale * scale))

    return rescaled_p


def generate_secret_key_list(
    secret_key: Polynomial
) -> list[Polynomial]:
    """generate a secret key list"""

    secret_key_list = [secret_key]
    secret_key_list.append(
        poly_mul(
            secret_key, 
            secret_key
        )
    )
    
    return secret_key_list


def decrypt_without_relinkey(
    secret_key_list: list[Polynomial], 
    ciphertext: list[Polynomial], 
    ciphertext_modulus: int
) -> Polynomial:
    """transform a ciphertext into a plaintext"""

    decrypted_ciphertext = ciphertext[2] - poly_mul(ciphertext[1], secret_key_list[0]) + poly_mul(ciphertext[0], secret_key_list[1])
    decrypted_ciphertext = rem_ciphertext_modulus(
        decrypted_ciphertext, 
        ciphertext_modulus
    )
    
    decrypted_ciphertext = convert_plaintext_modulus(
        decrypted_ciphertext, 
        ciphertext_modulus
    )

    return decrypted_ciphertext

## 1.2 Ciphertext Multiplication wo/RelinKey Example

In [5]:
secret_key = generate_secret_key(polynomial_moduls_degree)
public_key = generate_public_key(secret_key, polynomial_moduls_degree, ciphertext_modulus, stddev)
secret_key_list = generate_secret_key_list(secret_key)

cleartext1 = np.array([3 + 4j, 2 - 1j])
cleartext2 = np.array([5 - 2j, 3 + 4j])
print("cleartext1:", cleartext1)
print("cleartext2:", cleartext2)

print()

plaintext1 = coeff_encode(cleartext1, scale)
plaintext2 = coeff_encode(cleartext2, scale)

ciphertext1 = encrypt(public_key, plaintext1, ciphertext_modulus)
ciphertext2 = encrypt(public_key, plaintext2, ciphertext_modulus)

ciphertext = mul_without_relinkey(ciphertext1, ciphertext2, ciphertext_modulus)
print("ciphertext[0]:", ciphertext[0])
print("ciphertext[1]:", ciphertext[1])

print()

decrypted_ciphertext = decrypt_without_relinkey(secret_key_list, ciphertext, ciphertext_modulus)

decoded_decrypted_ciphertext = coeff_decode_without_rescale(decrypted_ciphertext, scale)
print("cleartext1 * cleartext2: ", cleartext1 * cleartext2)
print("decoded_decrypted_ciphertext: ", decoded_decrypted_ciphertext)

cleartext1: [3.+4.j 2.-1.j]
cleartext2: [5.-2.j 3.+4.j]

ciphertext[0]: 275049.0 + 280886.0·x + 140678.0·x² + 62248.0·x³
ciphertext[1]: 144316.0 + 96528.0·x + 134764.0·x² + 212357.0·x³

cleartext1 * cleartext2:  [23.+14.j 10. +5.j]
decoded_decrypted_ciphertext:  [(-8.8115234375+8.90625j), (15.0615234375+15.345703125j)]


In [6]:
cleartext1_poly = Polynomial([3, 4, 2, -1])
cleartext2_poly = Polynomial([5, -2, 3, 4])

cleartext_convolution = poly_mul(
    cleartext1_poly, 
    cleartext2_poly
)

print("cleartext convolution: ", cleartext_convolution)

cleartext convolution:  -9.0 + 9.0·x + 15.0·x² + 15.0·x³


## 1.3 RelinKey Generation Algorithm

In [7]:
def generate_relinearization_key(
    secret_key: Polynomial, 
    polynomial_moduls_degree: int, 
    level1_prime: int, 
    level1_relinkey_modulus: int, 
    stddev: int, 
) -> list[Polynomial]:
    """generate a public key."""

    # Step 1
    """generate an rlwe sample."""
    rlwe_sample = []
    for _ in range(polynomial_moduls_degree):
        rlwe_sample.append( 
            random.randint(0, level1_relinkey_modulus)
        )
    rlwe_sample = Polynomial(np.array(rlwe_sample))

    # Step 2
    """generate a noise."""
    noise = []
    for _ in range(polynomial_moduls_degree):
        noise.append( 
            random.randint(0, stddev)
        )
    noise = Polynomial(np.array(noise))

    relin_key = []
    
    # Step 3
    """compute P s^2"""
    
    relin_plaintext = level1_prime * poly_mul(secret_key, secret_key)
    
    # Step 4
    """compute b"""
    # ctxt_b = poly_mul(rlwe_sample, secret_key) + noise + relin_plaintext
    ctxt_b = poly_mul(rlwe_sample, secret_key) + relin_plaintext
    
    relin_key.append(
        rem_ciphertext_modulus(
            rlwe_sample, 
            level1_relinkey_modulus
        )
    )
    relin_key.append(
        rem_ciphertext_modulus(
            ctxt_b, 
            level1_relinkey_modulus
        )
    )

    return relin_key

## 1.4 RelinKey Generation Example

In [8]:
secret_key = generate_secret_key(polynomial_moduls_degree)
relin_key = generate_relinearization_key(secret_key, polynomial_moduls_degree, level1_prime, level1_relinkey_modulus, stddev)

print("relin_key:", relin_key)

relin_key: [Polynomial([7.69619634e+10, 3.03343478e+10, 8.58256620e+10, 1.17794706e+10], domain=[-1,  1], window=[-1,  1], symbol='x'), Polynomial([6.55392819e+10, 8.40938578e+10, 6.51821914e+10, 1.64465147e+10], domain=[-1,  1], window=[-1,  1], symbol='x')]


## 1.5 Ciphertext Multiplication wo/rescale Algorithm

In [9]:
def modulus_down(
    poly: Polynomial, 
    small_modulus1: int, 
    small_modulus2: int, 
) -> Polynomial:
    """transform poly mod (small_modulus1 * small_modulus2) into poly mod small_modulus1"""
    result = poly - rem_ciphertext_modulus(
        poly, 
        small_modulus2
    )
    result = rem_ciphertext_modulus(
        result,
        small_modulus1
    )
    
    return result



def mul_without_rescale(
    ciphertext1: list[Polynomial], 
    ciphertext2: list[Polynomial], 
    relin_key: list[Polynomial], 
    ciphertext_modulus_list: list[int], 
    ciphertext_modulus: int, 
    level1_prime: int, 
    level1_relinkey_modulus: int, 
    level1_prime_inverse: int
) -> list[Polynomial]:
    """multiply ciphertext"""
    extended_ciphertext = []
    
    # Step 1
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[0]
            ), 
            ciphertext_modulus
        )
    )
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[1]
            ) + poly_mul(
                ciphertext1[1], 
                ciphertext2[0]
            )    , 
            ciphertext_modulus
        )
    )
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[1], 
                ciphertext2[1]
            ), 
            ciphertext_modulus
        )
    )
    
    # Step 2
    result = []
    for i in range(len(relin_key)):
        ith_result = rem_ciphertext_modulus(
            poly_mul(
                relin_key[i], 
                extended_ciphertext[0]
            ), 
            level1_relinkey_modulus
        )
#         ith_result = ith_result - rem_ciphertext_modulus(
#             ith_result, 
#             level1_prime
#         )
#         ith_result = rem_ciphertext_modulus(
#             ith_result,
#             ciphertext_modulus
#         )
        ith_result = modulus_down(
            ith_result, 
            ciphertext_modulus, 
            level1_prime
        )
        ith_result = rem_ciphertext_modulus(
            level1_prime_inverse * ith_result, 
            ciphertext_modulus
        )
        ith_result = rem_ciphertext_modulus(
            ith_result + extended_ciphertext[i + 1], 
            ciphertext_modulus
        )
        result.append(ith_result)
    
    return result

## 1.6 Ciphertext Multiplication wo/rescale Example

In [10]:
secret_key = generate_secret_key(polynomial_moduls_degree)
public_key = generate_public_key(secret_key, polynomial_moduls_degree, ciphertext_modulus, stddev)
relin_key = generate_relinearization_key(secret_key, polynomial_moduls_degree, level1_prime, level1_relinkey_modulus, stddev)

cleartext1 = np.array([3 + 4j, 2 - 1j])
cleartext2 = np.array([5 - 2j, 3 + 4j])
print("cleartext1:", cleartext1)
print("cleartext2:", cleartext2)

print()

plaintext1 = coeff_encode(cleartext1, scale)
plaintext2 = coeff_encode(cleartext2, scale)

ciphertext1 = encrypt(public_key, plaintext1, ciphertext_modulus)
ciphertext2 = encrypt(public_key, plaintext2, ciphertext_modulus)

ciphertext = mul_without_rescale(
    ciphertext1, 
    ciphertext2, 
    relin_key, 
    ciphertext_modulus_list, 
    ciphertext_modulus, 
    level1_prime, 
    level1_relinkey_modulus, 
    level1_prime_inverse
)
print("ciphertext[0]:", ciphertext[0])
print("ciphertext[1]:", ciphertext[1])

print()

decrypted_ciphertext = decrypt(secret_key, ciphertext, ciphertext_modulus)

decoded_decrypted_ciphertext = coeff_decode_without_rescale(decrypted_ciphertext, scale)
print("cleartext convolution: ", cleartext_convolution)
print("decoded_decrypted_ciphertext: ", decoded_decrypted_ciphertext)

cleartext1: [3.+4.j 2.-1.j]
cleartext2: [5.-2.j 3.+4.j]

ciphertext[0]: 187881.0 + 204042.0·x + 25546.0·x² + 53864.0·x³
ciphertext[1]: 71350.0 + 72866.0·x + 177106.0·x² + 44501.0·x³

cleartext convolution:  -9.0 + 9.0·x + 15.0·x² + 15.0·x³
decoded_decrypted_ciphertext:  [(-9.062744140625+8.84375j), (14.90576171875+15.25048828125j)]


## 1.7 Ciphertext Multiplication Algorithm

In [11]:
def modulus_down(
    poly: Polynomial, 
    small_modulus1: int, 
    small_modulus2: int, 
) -> Polynomial:
    """transform poly mod (small_modulus1 * small_modulus2) into poly mod small_modulus1"""
    result = poly - rem_ciphertext_modulus(
        poly, 
        small_modulus2
    )
    result = rem_ciphertext_modulus(
        result,
        small_modulus1
    )
    
    return result


def mul(
    ciphertext1: list[Polynomial], 
    ciphertext2: list[Polynomial], 
    relin_key: list[Polynomial], 
    ciphertext_modulus_list: list[int], 
    ciphertext_modulus: int, 
    level1_ciphertext_modulus_inv: int, 
    level1_prime: int, 
    level1_relinkey_modulus: int, 
    level1_prime_inverse: int
) -> (list[int], int, list[Polynomial]):
    """multiply ciphertext"""
    extended_ciphertext = []
    
    # Step 1
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[0]
            ), 
            ciphertext_modulus
        )
    )
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[0], 
                ciphertext2[1]
            ) + poly_mul(
                ciphertext1[1], 
                ciphertext2[0]
            )    , 
            ciphertext_modulus
        )
    )
    extended_ciphertext.append(
        rem_ciphertext_modulus(
            poly_mul(
                ciphertext1[1], 
                ciphertext2[1]
            ), 
            ciphertext_modulus
        )
    )
    
    # Step 2
    result = []
    for i in range(len(relin_key)):
        ith_result = rem_ciphertext_modulus(
            poly_mul(
                relin_key[i], 
                extended_ciphertext[0]
            ), 
            level1_relinkey_modulus
        )
        ith_result = modulus_down(
            ith_result, 
            ciphertext_modulus, 
            level1_prime
        )
        ith_result = rem_ciphertext_modulus(
            level1_prime_inverse * ith_result, 
            ciphertext_modulus
        )
        ith_result = rem_ciphertext_modulus(
            ith_result + extended_ciphertext[i + 1], 
            ciphertext_modulus
        )
        result.append(ith_result)
    
    # Step 3
    ciphertext_modulus //= ciphertext_modulus_list[-1]
    for i in range(len(result)):
        result[i] = modulus_down(
            result[i], 
            ciphertext_modulus,
            ciphertext_modulus_list[-1]
        )
        result[i] = rem_ciphertext_modulus(
            result[i] * level1_ciphertext_modulus_inv, 
            ciphertext_modulus
        )
    ciphertext_modulus_list.pop(-1)
    
    return (ciphertext_modulus_list, ciphertext_modulus, result)

## 1.8 Ciphertext Multiplication Example

In [12]:
# parametes
scale = 64
polynomial_moduls_degree = 4
ciphertext_level = 1
ciphertext_modulus_list = [4129, 73]
ciphertext_modulus = 4129 * 73
stddev = 2.0
level1_ciphertext_modulus_inv = 905
level1_prime = 301409
level1_prime_inverse = 37677
level1_relinkey_modulus = ciphertext_modulus * level1_prime

# logic
secret_key = generate_secret_key(polynomial_moduls_degree)
public_key = generate_public_key(secret_key, polynomial_moduls_degree, ciphertext_modulus, stddev)
relin_key = generate_relinearization_key(secret_key, polynomial_moduls_degree, level1_prime, level1_relinkey_modulus, stddev)

cleartext1 = np.array([3 + 4j, 2 - 1j])
cleartext2 = np.array([5 - 2j, 3 + 4j])
print("cleartext1:", cleartext1)
print("cleartext2:", cleartext2)

print()

plaintext1 = coeff_encode(cleartext1, scale)
plaintext2 = coeff_encode(cleartext2, scale)

ciphertext1 = encrypt(public_key, plaintext1, ciphertext_modulus)
ciphertext2 = encrypt(public_key, plaintext2, ciphertext_modulus)

(ciphertext_modulus_list, ciphertext_modulus, ciphertext) = mul(
        ciphertext1, 
    ciphertext2, 
    relin_key, 
    ciphertext_modulus_list, 
    ciphertext_modulus, 
    level1_ciphertext_modulus_inv, 
    level1_prime, 
    level1_relinkey_modulus, 
    level1_prime_inverse
)
print("ciphertext[0]:", ciphertext[0])
print("ciphertext[1]:", ciphertext[1])

print()

decrypted_ciphertext = decrypt(secret_key, ciphertext, ciphertext_modulus)

decoded_decrypted_ciphertext = coeff_decode(decrypted_ciphertext, scale)
print("cleartext convolution: ", cleartext_convolution)
print("decoded_decrypted_ciphertext: ", decoded_decrypted_ciphertext)

cleartext1: [3.+4.j 2.-1.j]
cleartext2: [5.-2.j 3.+4.j]

ciphertext[0]: 2413.0 + 4118.0·x + 2221.0·x² + 3178.0·x³
ciphertext[1]: 1902.0 + 486.0·x + 3068.0·x² + 4029.0·x³

cleartext convolution:  -9.0 + 9.0·x + 15.0·x² + 15.0·x³
decoded_decrypted_ciphertext:  [(-7.984375+7.765625j), (13.234375+13.296875j)]
