In [1]:
import numpy as np

## Definitions:

### Primitive root of unity
Any complex number that yields 1 when raised to some positive integer power n.
1. $ \zeta^n = 1 $
2. $ \zeta^k \neq 1 $ for all integers $ k $ such that $ 0 < k < n $.

$\zeta^n = e^{i(2\pi \frac{k}{n})}$

For 2N-th roots of unity, this means any number $\zeta$ that satisfies $\zeta^{2N}=1$

### Cyclotomic Polynomial
A **cyclotomic polynomial** is a special type of polynomial defined as the unique irreducible polynomial with integer coefficients whose roots are the **primitive roots of unity** $e^{i(2\pi\frac{k}{n})}$, where k runs over the positive integers less than n and coprime to n. The formula is equal to:

$$
\Phi_{n}(x) = \prod_{\substack{1 \leq k \leq n \\ \gcd(k, n) = 1}} \left(x - e^{2i\pi \frac{k}{n}}\right)
$$

**Special Case**: $\Phi_M(X) = X^N + 1$

The specific form $\Phi_M(X) = X^N + 1$ arises when:

- $M = 2N$, i.e., $M$ is an even number and M is a power of 2.
- The roots of $\Phi_{2N}(X)$ are the primitive $2N$-th roots of unity.


#### 2.1. Roots of $X^N + 1$

The equation $X^N + 1 = 0$ has $N$ distinct roots in the complex plane. These roots are the **odd powers** of $\xi$, a primitive $2N$-th root of unity:

$$
\xi = e^{2\pi i / (2N)}, \quad \text{so } \xi^{2k+1} \text{ for } k = 0, 1, \dots, N-1.
$$



#### 2.2. Factorization

The polynomial $X^N + 1$ factors as:

$$
X^N + 1 = \prod_{k=0}^{N-1} \left( X - \xi^{2k+1} \right),
$$

where $\xi^{2k+1}$ are the odd powers of the primitive $2N$-th root of unity $\xi$.



#### 2.3. Minimal Polynomial

When $M = 2N$, the primitive $2N$-th roots of unity are precisely the roots of $X^N + 1$. Hence, the minimal polynomial of the primitive $2N$-th roots of unity is:

$$
\Phi_{2N}(X) = X^N + 1.
$$


### Isomorphism

An isomorphism between two algebraic structures is a bijective homomorphism.

**Addition Preservation**:
$\sigma(m1 + m2) = \sigma(m1) + \sigma(m2)$

**Multiplication Preservation**:
$\sigma(m1 * m2) = \sigma(m1) \circ \sigma(m2)$

- Injectivity: No two distinct polynomials map to the same vector under σ.
- Surjectivity: Every vector in ${C^N}$ (satisfying the necessary properties) is the image of some polynomial under σσ.




In [296]:
from math import sin,cos,pi
import numpy as np

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials."""
    
    def __init__(self, M: int):
        """Initialization of the encoder for M a power of 2. 
        
        xi, which is an M-th root of unity will, be used as a basis for our computations.
        """
        # self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.N = M // 2


    def pi_inverse(self, z: np.array) -> np.array:
        """Expands a vector of C^{N/2} by expanding it with its
        complex conjugate."""
        
        z_conjugate = z[::-1]
        z_conjugate = [np.conjugate(x) for x in z_conjugate]
        return np.concatenate([z, z_conjugate])



    def ifft(self, z):
    
        xi_inv = np.exp(2 * np.pi * 1j / self.N)

        a = self.fft(z, xi_inv)
    
        for j in range(self.N):
            a[j] /= self.N
    
        return a

    # Recursive function of FFT
    def fft(self, a, xi):

        n = len(a)
          
        # if input contains just one element
        if n==1:
            return [a[0]]
        
        # Separe coefficients
        Aeven = a[0::2]
        Aodd  = a[1::2]
     
        # Recursive call for even indexed coefficients
        Yeven = self.fft(Aeven, xi**2) 
     
        # Recursive call for odd indexed coefficients
        Yodd = self.fft(Aodd, xi**2)
     
        # for storing values of y0, y1, y2, ..., yn-1.
        Y = [0]*n
    
        for k in range(n//2):
            twiddle = xi**k
             
            Y[k] =  Yeven[k] + twiddle *  Yodd[k]
            Y[k + n//2] =  Yeven[k]  -  twiddle * Yodd[k]
         
        return Y

In [313]:
# EXAMPLE

import numpy as np
from numpy.polynomial import Polynomial
from sympy import ntt

M = 8
N = M // 2



def sigma(p: Polynomial, xi) -> np.array:
    """Decodes a polynomial by applying it to the M-th roots of unity."""

    outputs = []
    N = 4

    # We simply apply the polynomial on the roots
    for i in range(N):
        root = xi ** (2*i + 1)
        output = p(root)
        outputs.append(output)
    return np.array(outputs)

# # pol = np.array([2.5+4.440892098500626e-16j,-4.996003610813204e-16+0.7071067811865479j,-3.4694469519536176e-16+0.5000000000000003j, -8.326672684688674e-16+0.7071067811865472j])

# prime = 5

encoder = CKKSEncoder(M)

# print(pol)
xi = np.exp(-2 * np.pi * 1j / M)
pol = np.array([2.5, 1.421875, 2.5, 0.703125])
print(encoder.fft(pol, xi))

# transform = np.fft.fft(pol)
# print ("FFT : ", transform)

# pol = Polynomial([2.5, 1.421875, 2.5, 0.703125])
# xi = np.exp(-2 * np.pi * 1j / M)
# res = sigma(pol, xi)
# print("new: ", res)


# p = encoder.sigma_inverse(b)
# p
# b_reconstructed = encoder.sigma(p)
# b


# encoder.ifft(pol)




[np.complex128(7.125+0j), np.complex128(0.508232998977831-0.508232998977831j), np.complex128(2.875+0j), np.complex128(-0.508232998977831+0.508232998977831j)]


In [286]:
# M = 8
# N = M // 2


# # b = np.array([1, 2, 3, 4])

# delta = 64

# # input vector
# inp = np.array([3+4j, 2-1j])


# # append the conjugate since N=4
# inp = encoder.pi_inverse(inp)

# # scaling 
# inp = inp * delta



# # p = encoder.ifft(inp)
# # print(p)
# # np.dot(delta, p)



# # p
# # xi = np.exp(2 * np.pi * 1j / N)
# # d = encoder.fft(p, xi)
# # d

In [377]:
def bitReverse(vals, size):
    """Rearrange elements of vals based on bit-reversed indices."""
    j = 0
    for i in range(1, size):
        bit = size >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            vals[i], vals[j] = vals[j], vals[i]
 

In [353]:
def scale_down_to_real(x, logp):
    """
    Scale down a large integer `x` by reducing its precision by `logp`.

    Parameters:
        x (int): The large integer value to scale down.
        logp (int): The logarithm of the precision to scale down by.

    Returns:
        float: The scaled-down floating-point value.
    """
    from decimal import Decimal  # For handling large numbers precisely

    xp = Decimal(x)  # Convert the integer to a high-precision decimal
    xp /= Decimal(2 ** logp)  # Scale down by shifting `logp` bits
    return float(xp)  # Convert back to a double (floating-point)


In [465]:
from decimal import Decimal  # For handling large numbers precisely
def scale_up_to_zz(x, logp):
    """
    Scales up a floating-point value x by left-shifting it by logp.
    
    Parameters:
        x (float): The input floating-point value.
        logp (int): The power of 2 for scaling (precision).
    
    Returns:
        int: The scaled-up integer value.
    """
    # Convert x to high precision using Decimal for better accuracy
    scaled_value = Decimal(x) * (2 ** logp)
    
    # Convert to integer and return
    return int(scaled_value)

In [429]:
import random

def random_complex_array(size, bound=1.0):
    """
    Generates an array of random complex numbers with real and imaginary parts 
    in the range (0, bound).
    
    Parameters:
        size (int): Number of complex numbers to generate.
        bound (float): Upper bound for the real and imaginary parts (default: 1.0).
    
    Returns:
        list of complex: An array of random complex numbers.
    """
    return [complex(random.uniform(0, bound), random.uniform(0, bound)) for _ in range(size)]


In [556]:
class Enc:
    def __init__(self, logN, logQ):
        self.N = 1 << logN # N is a power-of-two that corresponds to the ring Z[X]/(X^N + 1)
        self.Nh = self.N >> 1 # Nh = N / 2
        self.logHh = logN - 1 # logHn = logN - 1
        self.M = self.N << 1
        self.logQQ = logQ << 1 # logQQ = log of PQ
        self.Q = 2**(logQ) # the highest modulous
        self.QQ = 2**(self.logQQ) # PQ = Q*Q
        self.rotGroups = []
        self.ksiPows = []

        fivePows = 1
        for i in range(self.Nh):
            self.rotGroups.append(fivePows)
            fivePows *= 5
            fivePows %= self.M

        for j in range(self.M):
            angle = 2.0 * np.pi * 1j / self.M
            self.ksiPows.append(np.exp(1j * angle))

        self.ksiPows.append(self.ksiPows[0]) # Mth element

        
    
    def fftSpecial(self, vals):
        siz = len(vals)
        
        bitReverse(vals, siz)
    
        len_ = 2
        i = 0
        while len_ <= siz:
            while i  < siz:
                lenh = len_ >> 1
                lenq = len_ << 2
                for j in range(lenh):
                    idx = ((self.rotGroups[j] % lenq)) * self.M // lenq
                    u = vals[i + j]
                    v = vals[i + j + lenh] * self.ksiPows[idx]
                    vals[i+j] = u + v
                    vals[i+j+lenh] = u-v

                i += len_
            len_ <<= 1
                    
        return vals

    def fftSpecialInvLazy(self, vals):
        siz = 4 # TODO: CHANGE
        
        i = 0
        len_ = siz
        while len_ >= 1:
            while i < siz:
                lenh = len_ >> 1
                lenq = len_ << 2
                for j in range(lenh):
                    idx = (lenq - (self.rotGroups[j] % lenq))* self.M // lenq
                    u = vals[i+j] + vals[i+j+lenh]
                    v = vals[i+j] - vals[i+j+lenh]
            
                    v *= self.ksiPows[idx]
                    vals[i+j] = u
                    vals[i+j+lenh] = v
                
                i += len_
            len_ >>= 1
        bitReverse(vals, siz)

        return vals

    def fftSpecialInv(self, vals):
        vals = self.fftSpecialInvLazy(vals)
        siz_ = len(vals)
        for i in range(siz_):
            vals[i] /= siz_
        return vals
    
    def decode(self, mx, slots, logp, logq):
        q = 2**(logq)
        gap = self.Nh / slots
        idx = 0
        res = [0] * slots  # Initialize the result array

        # print(self.Nh, slots)

        # print("gap", gap)
        
        for i in range(slots):
            idx = int(i * gap)
            tmp = mx[idx] % q
        
            if tmp.bit_length() == logq:  # Check NumBits(tmp) == logq
                tmp -= q

            real_part = scale_down_to_real(tmp, logp)

            tmp = mx[idx + Nh] % q  # rem(tmp, mx[idx + Nh], q)
            if tmp.bit_length() == logq:  # Check NumBits(tmp) == logq
                tmp -= q
            imag_part = scale_down_to_real(tmp, logp)

            res[i] = complex(real_part, imag_part)

        return self.fftSpecial(res)
        return self.fftSpecial(mx)

    def encode(self, vals, slots, logp):

        uvals = vals[:slots] + [0] * (slots - len(vals))
        uvals = self.fftSpecialInv(uvals)
        print(uvals)

        
        mx = [0] * N
        gap = self.Nh // slots

        for i in range(slots):
            idx = i * gap
            jdx = self.Nh + idx
            
            # Real and imaginary parts
            mx[idx] = scale_up_to_zz(uvals[i].real, logp)
            mx[jdx] = scale_up_to_zz(uvals[i].imag, logp)

        
        
        return mx
        
        
            
    

In [557]:
# mx = np.array([160, 91, 160, 45])


logN = 2  # N = 4
logQ = 65  # Q = 2^65
logp = 30  # Precision of scaling
slots = 1 << logN


bound = 1.0
# random_array1 = random_complex_array(slots, bound)

random_array = [complex(3+4j), complex(2-1j), complex(3-4j), complex(2+1j)]

print(random_array)
# print(random_array1)
print("\n")
# print(len(random_array))

enc = Enc(logN, logQ)
encoded = enc.encode(random_array, slots, logp)
print(encoded)
enc.decode(encoded, slots, logp, logQ)



# Mock the mx array (encoded data)
# mx is of size Nh = N / 2, containing values encoded in frequency domain



# decoded_data = enc.decode(mx, slots, logp, logq)

# print("Decoded Data:")
# for i, val in enumerate(decoded_data):
    # print(f"Slot {i}: {val}")

[(3+4j), (2-1j), (3-4j), (2+1j)]


[(1.5+0j), np.complex128(0.9118762555319925j), (1+0j), np.complex128(-0.22796906388299812j)]
[0, 0, -244779918, 0]


[np.complex128(-0.10393978797610298j),
 np.complex128(0.10393978797610298j),
 np.complex128(-0.10393978797610298j),
 np.complex128(0.10393978797610298j)]