# CKKS

## Imports

In [None]:
import numpy as np
from numpy.polynomial import Polynomial
import sympy as sympy

## Parameters and some functions

### Neal's NTL library

In [None]:
def is_power_of_two(n: int) -> bool:
    ''' INPUT: n>0 an integer
    OUTPUT: True iff n is a power of two
    NOTE: n & (n-1) clears the lowest set bit of n, which leaves zero exactly
    when n has a single set bit (n=0 also yields True) '''
    lowest_bit_cleared = n & (n - 1)
    return not lowest_bit_cleared

def size(n: int) -> int:
    ''' INPUT: n>=0 an integer (Python int or NumPy integer scalar)
    OUTPUT: returns number of bits used to express n in binary (0 for n=0)
    IMPLEMENTATION: int.bit_length(), so arbitrarily large n works without
    hitting Python's recursion limit
    RAISES: ValueError for negative n, TypeError for non-integers
    >>> size(7)==3 and size(8)==4 and size(9)==4
    True
    '''
    import operator
    # operator.index accepts NumPy integer scalars but rejects floats
    n = operator.index(n)
    if n < 0:
        raise ValueError("size() requires a non-negative integer")
    return n.bit_length()

# integer arithmetic
def isqrt(k: int) -> int:
    ''' returns the integer floor of sqrt(k) for k>0
    IMPLEMENTATION: Newton's method on integers, starting from k and stopping
    as soon as the next estimate no longer decreases
    NOTE: only used in is_prime()
    >>> isqrt(15)==3 and isqrt(16)==4
    True
    '''
    assert k>0, "input must be positive integer"
    guess = k
    better = (guess + 1) // 2
    while better < guess:
        # tuple assignment: the new 'better' is computed from the old 'better'
        guess, better = better, (better + k // better) // 2
    return guess

# functions related to prime numbers
def is_prime(x: int) -> bool:
    ''' INPUT: x an integer
    OUTPUT: returns True iff x is prime (False for every x < 2)
    IMPLEMENTATION: brute-force check of all 2<=d<=isqrt(x) for divisorship
    into x, stopping at the first divisor found
    >>> is_prime(17)==True and is_prime(341)==False
    True
    '''
    # 0 and 1 are not prime; this also keeps isqrt() (which asserts k>0)
    # from being called with x<=0
    if x < 2:
        return False
    # generator (not list) so all() short-circuits on the first divisor
    return all(x % divisor != 0 for divisor in range(2, isqrt(x) + 1))

def gen_good_prime(n: int, k: int, num_bits: int) -> int:
    ''' INPUT: n a power of 2, k>0, num_bits the desired bit length of the prime
    OUTPUT: the kth "NTT-friendly" num_bits-bit prime for n, i.e. the kth prime
    among the candidates 2**(num_bits-1) + 1 + t*n (t = 0, 1, 2, ...);
    returns 0 if fewer than k such primes have exactly num_bits bits.
    NOTE: each candidate minus 1 equals 2**(num_bits-1) + t*n, so n divides q-1
    whenever n divides 2**(num_bits-1) (n a power of 2 no larger than it).
    '''
    assert k>0, "k should be a positive integer"
    candidate = (1 << (num_bits - 1)) + 1
    # The start value itself counts as the first candidate, so a prime start
    # is the 1st good prime for every k (not only for k == 1).
    while size(candidate) == num_bits:
        if is_prime(candidate):
            k -= 1
            if k == 0:
                return candidate
        candidate += n
    return 0

# modular arithmetic
def pow_mod(base: int, exp: int, q: int) -> int:
    ''' OUTPUT: the modular power (base**exp) % q, always reduced into [0, q)
    IMPLEMENTATION: Python's built-in three-argument pow() (C-implemented
    modular exponentiation). A negative exp yields the corresponding power of
    the modular inverse of base (ValueError if base is not invertible mod q).
    >>> pow_mod(9, 4, 17)
    16
    '''
    import operator
    # operator.index converts NumPy integer scalars to int; three-argument
    # pow() only accepts Python ints
    return pow(operator.index(base), operator.index(exp), operator.index(q))

def inverse_mod(x: int, q: int) -> int:
    ''' INPUT: q a prime, 0<x<q
    OUTPUT: returns inverse of x modulo q
    IMPLEMENTATION: Fermat's Little Theorem gives x**(q-1) == 1 (mod q),
    so x**(q-2) is the multiplicative inverse of x
    >>> inverse_mod(1,2)==1
    True
    '''
    fermat_exponent = q - 2
    return pow_mod(x, fermat_exponent, q)

# functions related to primitive roots of unity modulo q
def is_primitive(x: int, n: int, q: int) -> bool:
    ''' INPUT: n a power of 2, q prime, satisfies n|(q-1)
    OUTPUT: returns True iff x is a prim nth root of unity modulo q
    NOTE: for n a power of 2 and q prime, x**(n/2) == -1 (mod q) holds exactly
    when x has multiplicative order n
    >>> is_primitive(9, 8, 17)==True and is_primitive(13, 8, 17)==False
    True
    '''
    minus_one = q - 1
    half_order = n // 2
    return pow_mod(x, half_order, q) == minus_one

def gen_primitive_root(n: int, q: int) -> int:
    ''' INPUT: q prime, n a power of 2, with n dividing q-1
    OUTPUT: returns a primitive nth root of unity modulo q, namely the first
    candidate**((q-1)/n) (candidate = 2, 3, ...) whose order is exactly n
    REFERENCE: https://crypto.stackexchange.com/questions/63614
    >>> gen_primitive_root(8, 17)==9 and gen_primitive_root(16, 97)==8
    True
    '''
    assert is_prime(q), "q must be prime"
    assert is_power_of_two(n), "n must be power of 2"
    assert (q-1) % n == 0, "n must divide q-1"
    cofactor = (q - 1) // n
    for candidate in range(2, q):
        # root**n == candidate**(q-1) == 1, so the order of root divides n;
        # since n is a power of 2 the order is exactly n iff root**(n/2) != 1
        root = pow_mod(candidate, cofactor, q)
        if pow_mod(root, n // 2, q) != 1:
            return root
    assert False, "primitive root not found. there exists a critical error in implementation"

# bit-reverse functions (used in NTT)
def bit_reverse(a: int, bit_length: int):
    ''' bit-reversal for (bit_length)-bit integer a: the lowest bit_length bits
    of a are emitted in reverse order (higher bits of a are ignored)
    >>> bit_reverse(5, 4)
    10
    '''
    reversed_bits = 0
    for _ in range(bit_length):
        # shift the accumulated result left and append a's current lowest bit
        reversed_bits = (reversed_bits << 1) | (a & 1)
        a >>= 1
    return reversed_bits

def bit_reverse_permute(a: np.array):
    '''in-place bit reverse permutation: a[i] is swapped with a[rev(i)], where
    rev reverses the log2(len(a)) low bits of i; returns the same object a
    >>> bit_reverse_permute([0, 1, 2, 3, 4, 5, 6, 7])
    [0, 4, 2, 6, 1, 5, 3, 7]
    '''
    length = len(a)
    assert is_power_of_two(length), "a must have length a power of 2"
    index_bits = size(length) - 1
    for index in range(length):
        partner = bit_reverse(index, index_bits)
        # each pair is swapped once, by its smaller index
        if partner > index:
            a[index], a[partner] = a[partner], a[index]
    return a

# Number Theoretic Transform and Poly Mult (aka negacyclic convolution)
def ntt(a: np.array, omega_n: int, n: int, q: int) -> np.array:
    ''' NTT implemented via "decimation-in-time" (DIT) FFT, computed in place
    INPUT: coefficient vector a = [a0, a1,...,a(n-1)] (list or np.array; overwritten)
    n a power of 2, q a prime with n dividing q-1
    omega_n a primitive nth root of unity mod q
    OUTPUT: the same object a, now holding the transform
    (every entry is reduced mod q when n >= 2)
    REFERENCE: ITERATIVE-FFT on page 917 in CLRS
    >>> ntt([1,2,3,7,5,4,1,2], 9, 8, 17)
    [8, 11, 14, 2, 12, 16, 7, 6]
    '''
    # BIT-REVERSE-COPY(a, A) (see CLRS, page 917), done in place
    bit_reverse_permute(a)
    log_n = size(n) - 1
    # each stage merges pairs of length-(span/2) transforms into length-span ones
    for stage in range(log_n):
        span = 2 << stage  # 2, 4, 8, ..., n
        half = span // 2
        # primitive span-th root of unity: omega_n ** (n / span)
        omega_span = pow_mod(omega_n, 1 << (log_n - stage - 1), q)
        for start in range(0, n, span):
            twiddle = 1
            for offset in range(half):
                top = start + offset
                bottom = top + half
                # Cooley--Tukey butterfly
                t = twiddle * a[bottom] % q
                u = a[top]
                a[top] = (u + t) % q
                a[bottom] = (u - t) % q
                twiddle = (twiddle * omega_span) % q
    return a

def intt(a: np.array, omega_n: int, n: int, q: int) -> np.array:
    ''' inverse NTT: undoes ntt() for the same omega_n, n and q
    INPUT: vector a of length n (list or np.array); it is NOT modified
    n a power of 2, q a prime with n dividing q-1 (q < 2**31 so that
    products of two reduced values fit in int64)
    OUTPUT: a new np.array (int64) with entries in [0, q)
    IMPLEMENTATION: scale by n^{-1} mod q, then run ntt() with omega_n^{-1}
    '''
    # Build a fresh int64 array (ntt() works in place, so the caller's 'a'
    # stays untouched) and reduce both before and after scaling. Every
    # product is then of two values < q and cannot overflow int64.
    reduced = np.asarray(a, dtype=np.int64) % q
    scaled = (inverse_mod(n, q) * reduced) % q
    return ntt(scaled, inverse_mod(omega_n, q), n, q)

def poly_mult(a: np.array, b: np.array, n: int, q: int) -> np.array:
    ''' INPUT: two coefficient vectors 'a' and 'b' of length n (lists or np.arrays)
    n power of 2, q prime with (2n) dividing q-1 (q < 2**31 so that
    products of two reduced values fit in int64)
    OUTPUT: negacyclic convolution of 'a' and 'b', i.e. the coefficients of
    a(X)*b(X) mod (X^n + 1, q), as an np.array with entries in [0, q)
    IMPLEMENTATION: weight by powers of psi (a primitive 2n-th root), NTT,
    pointwise product, inverse NTT, unweight by powers of psi^{-1}
    '''
    psi = gen_primitive_root(2*n, q) # prim (2n)th root of unity
    ipsi = inverse_mod(psi, q) # inverse of psi modulo q
    omega = (psi*psi)%q # prim nth root of unity used by the length-n NTT

    # generate PowMul_psi and PowMul_ipsi: [psi^0, ..., psi^(n-1)] and likewise for ipsi
    pow_mul_psi = [1]*n
    pow_mul_ipsi = [1]*n
    for i in range(n-1):
        pow_mul_psi[i+1] = (pow_mul_psi[i] * psi) % q
        pow_mul_ipsi[i+1] = (pow_mul_ipsi[i] * ipsi) % q
    pow_mul_psi = np.array(pow_mul_psi, dtype=np.int64)
    pow_mul_ipsi = np.array(pow_mul_ipsi, dtype=np.int64)

    # Reduce the inputs first (accepts lists and negative/unreduced entries) so
    # every product below is of two values < q and cannot overflow int64.
    a = np.asarray(a, dtype=np.int64) % q
    b = np.asarray(b, dtype=np.int64) % q

    # modular hadamard products with PowMul_psi
    A = (a*pow_mul_psi) % q
    B = (b*pow_mul_psi) % q
    # NTT's (ntt() transforms the fresh arrays A and B in place and returns them)
    ntt_a = ntt(A, omega, n, q)
    ntt_b = ntt(B, omega, n, q)
    # modular hadamard product
    C = (ntt_a*ntt_b) % q
    # inverse NTT
    intt_c = intt(C, omega, n, q)
    # modular hadamard products with PowMul_ipsi
    product = (pow_mul_ipsi*intt_c) % q
    return product

### Parameters

In [None]:
# Scaling factor applied to message vectors before rounding (CKKSEncoder.encode)
scale = 64

# Ring degree N: plaintexts/ciphertexts are length-N coefficient vectors and
# poly_mult multiplies them modulo X^N + 1. M = 2N is the order of the root xi.
N = 2**10
M = 2*N
# num_bits = number of bits in the prime modulus q
num_bits = 30
# set q equal to a good prime; i.e., q = (2N)*k + 1 for some integer k
q = gen_good_prime(2*N, 3, num_bits) # q is the 3rd good 30-bit prime
# gen_good_prime signals failure by returning 0; stop here instead of letting
# every later "% q" fail with ZeroDivisionError
if q == 0:
    raise ValueError("could not find the 3rd NTT-friendly %d-bit prime for 2N=%d" % (num_bits, 2*N))

# primitive M-th root of unity (CKKSEncoder computes its own copy from M)
xi = np.exp(2 * np.pi * 1j / M)

### Helpful functions

In [None]:
def round_coordinates(coordinates):
    """Return the fractional part x - floor(x) of every coordinate (values in [0, 1))."""
    return coordinates - np.floor(coordinates)

def coordinate_wise_random_rounding(coordinates):
    """Rounds coordinates randomly to a neighbouring integer.

    Each coordinate x becomes ceil(x) with probability x - floor(x) and
    floor(x) otherwise, so the expected value of the result is x.
    The input is flattened; the result is a list of Python ints.
    """
    values = np.asarray(coordinates, dtype=float).reshape(-1)
    floors = np.floor(values)
    fractions = values - floors
    # round up with probability equal to the fractional part
    round_up = np.random.random(values.shape) < fractions
    # floors is integral, so adding 0/1 and converting with int() is exact.
    # Truncating a float such as 0.9999999999999999 with int() would
    # give the wrong neighbour.
    return [int(v) for v in floors + round_up]

## Encoder/Decoder

In [None]:
class CKKSEncoder:
    """Streamlined CKKS encoder to encode complex vectors into polynomials.

    M fixes the root of unity xi = exp(2*pi*i/M); the ring degree is N = M // 2,
    messages are complex vectors of length M // 4, and encoded plaintexts are
    lists of N integer polynomial coefficients.
    """

    def __init__(self, M, scale):
        """Store M and scale and precompute xi and the sigma(R) basis."""
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.create_sigma_R_basis()
        self.scale = scale

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> np.array:
        """Computes the Vandermonde matrix from a m-th root of unity.

        Row i holds root**0, ..., root**(N-1) for root = xi**(2i+1), with
        N = M // 2. Returned as a list of lists.
        """
        N = M //2
        matrix = []
        # We will generate each row of the matrix
        for i in range(N):
            # For each row we select a different root
            root = xi ** (2 * i + 1)
            row = []

            # Then we store its powers
            for j in range(N):
                row.append(root ** j)
            matrix.append(row)
        return matrix

    def sigma_inverse(self, b: np.array) -> Polynomial:
        """Encodes the vector b in a polynomial using an M-th root of unity.

        Solves V @ coeffs = b for this encoder's Vandermonde matrix V, so that
        sigma(result) reproduces b up to floating-point error.
        """
        # self.M (not a module-level M): the matrix must match this encoder's xi
        A = CKKSEncoder.vandermonde(self.xi, self.M)
        coeffs = np.linalg.solve(A, b)
        return Polynomial(coeffs)

    def sigma(self, p: Polynomial) -> np.array:
        """Decodes a polynomial by applying it to the M-th roots of unity.

        Returns [p(xi**1), p(xi**3), ..., p(xi**(2N-1))] with N = M // 2.
        """
        N = self.M // 2
        roots = self.xi ** (2 * np.arange(N) + 1)
        # Polynomial evaluates element-wise on an array: one vectorized call
        # instead of N separate Python-level evaluations
        return np.asarray(p(roots))

    def pi(self, z: np.array) -> np.array:
        """Projects a vector of H into C^{N/2} (keeps the first M // 4 entries)."""
        N = self.M // 4
        return z[:N]

    def pi_inverse(self, z: np.array) -> np.array:
        """Expands a vector of C^{N/2} by appending its reversed complex conjugate."""
        z_conjugate = z[::-1]
        z_conjugate = [np.conjugate(x) for x in z_conjugate]
        return np.concatenate([z, z_conjugate])

    def create_sigma_R_basis(self):
        """Creates the basis (sigma(1), sigma(X), ..., sigma(X** N-1)) as the rows of
        self.sigma_R_basis (the transposed Vandermonde matrix)."""
        self.sigma_R_basis = np.array(self.vandermonde(self.xi, self.M)).T

    def compute_basis_coordinates(self, z):
        """Computes the coordinates of a vector with respect to the orthogonal lattice basis.

        Coordinate j is Re(<z, b_j>) / <b_j, b_j> for basis row b_j.
        """
        output = np.array([np.real(np.vdot(z, b) / np.vdot(b,b)) for b in self.sigma_R_basis])
        return output

    def sigma_R_discretization(self, z):
        """Projects a vector on the lattice using coordinate wise random rounding."""
        coordinates = self.compute_basis_coordinates(z)
        y = coordinate_wise_random_rounding(coordinates)
        return y

    def encode(self, z: np.array) -> list:
        """Encodes a vector by expanding it first to H, scaling it, and
        projecting it on the lattice of sigma(R). The resulting lattice
        coordinates are the integer polynomial coefficients (list of ints).
        """
        pi_z = self.pi_inverse(z)
        scaled_pi_z = self.scale * pi_z
        p = self.sigma_R_discretization(scaled_pi_z)
        return p

    def decode(self, p: np.array) -> np.array:
        """Decodes a polynomial by removing the scale, 
        evaluating on the roots, and project it on C^(N/2).

        p holds signed integer coefficients; values reduced into [0, q) must be
        center-lifted first or the result is meaningless.
        """
        rescaled_p = Polynomial(p) / self.scale
        z = self.sigma(rescaled_p)
        pi_z = self.pi(z)
        return pi_z

In [None]:
encoder = CKKSEncoder(M, scale)
# Round trip: encode a constant vector of N/2 slots, then decode it again
z = np.full(N // 2, 1 + 1j)
print("z: ", z[:5])
p = encoder.encode(z)
print("p: ", p[:5])
d = encoder.decode(p)
print("d: ", d[:5])

## Encrypter

In [None]:
class CKKSEncrypter:
    '''CKKS encrypter over Z_q[X]/(X^N + 1).

    All methods are static and use the module-level parameters N and q and
    the negacyclic product poly_mult(). Keys are length-N integer arrays and
    ciphertexts are pairs of them.
    '''

    @staticmethod
    def key_gen() -> tuple:
        '''Generate a key pair.

        OUTPUT: (s, (b, a)) with s and a uniform in [0, q)^N, noise e with
        coefficients in {0, 1}, and b = -a*s + e mod q.
        '''
        s = np.random.randint(q, size=(N))
        a = np.random.randint(q, size=(N))
        e = np.random.randint(2, size=(N))
        # reduce after adding e so every coefficient of b lies in [0, q)
        b = (poly_mult((-a) % q, s, N, q) + e) % q
        return (s, (b, a))

    @staticmethod
    def encrypt(msg: np.array, p: tuple) -> tuple:
        '''Encrypt an encoded plaintext msg (N integer coefficients, possibly
        negative) under public key p = (b, a); returns (b + msg mod q, a).

        NOTE(review): msg is added to the public b without fresh randomness,
        so anyone holding p recovers msg as c0 - b mod q, and all ciphertexts
        share the same second component. The usual CKKS scheme instead samples
        a small u and noise e0, e1 and outputs (u*b + e0 + msg, u*a + e1). That
        would also need a small secret s in key_gen to keep noise bounded.
        '''
        return (np.add(p[0], msg) % q, p[1])

    @staticmethod
    def decrypt(c: tuple, s: np.array) -> np.array:
        '''Decrypt c = (c0, c1) with secret key s.

        OUTPUT: c0 + c1*s mod q, lifted to the centered range (-q/2, q/2] so that
        negative plaintext coefficients come back as negative integers.
        '''
        m = np.add(c[0], poly_mult(c[1], s, N, q)) % q
        # decode() evaluates the polynomial at roots of unity, where a
        # representative shifted by q would swamp the message
        return np.where(m > q // 2, m - q, m)

## Homomorphic operations

### Addition

In [None]:
def hom_add(c1: tuple, c2: tuple) -> tuple:
    '''Homomorphic addition: add two ciphertexts component-wise modulo q.'''
    first = (c1[0] + c2[0]) % q
    second = (c1[1] + c2[1]) % q
    return (first, second)

## Trial

In [None]:
encoder = CKKSEncoder(M, scale)
z1 = np.array([1+1j for n in range(N//2)])

print("z1: ",z1[0:5])
z2 = np.array([2+2j for n in range(N//2)])
print("z2: ",z2[0:5])
p1 = encoder.encode(z1)
p2 = encoder.encode(z2)

sk, pk = CKKSEncrypter.key_gen()

c1 = CKKSEncrypter.encrypt(p1,pk)
# NOTE: p1 is rebound here to the decrypted plaintext of c1
p1 =  CKKSEncrypter.decrypt(c1, sk)
d1 = encoder.decode(p1)
print("d1: ",d1[0:5])

c2 = CKKSEncrypter.encrypt(p2,pk)

# c2 is a (c0, c1) pair of length-N arrays; show the first coefficients of c0
print("c: ",c2[0][0:5])

cres = hom_add(c1,c2)

p = CKKSEncrypter.decrypt(cres, sk)

print("p: ",p[0:5])

d = encoder.decode(p)
print("zres: ",d[0:5])