# Dependencies

In [42]:
import math
import time
import os
import random
import itertools
import numpy as np
from numpy.fft import rfft, irfft
from numpy import multiply, prod
from sympy import GF, invert
from sympy.abc import x
from sympy import ZZ, QQ, FF, Poly
from sympy import intt, ntt, fft, ifft, simplify
from sympy.polys import polytools
import pickle

# Number theory functions

In [43]:
def is_prime(n):
    for i in range(2, int(n ** 0.5) + 1):
        if n % i == 0:
            return False
    return True

def gcd(p, q):
    if p == 0:
        return q
    return gcd(q%p,p)

def find_inverse(p, m):
    if p % m == 1:
        return 1
    for candidate in range(2, m):
        # print((p*candidate) % m)
        if (p*candidate) % m == 1:
            return candidate
    raise Exception("no inverse found")

def get_coprime(p_range, q_range):
    for p_candidate in p_range:
        for q_candidate in q_range:
            print(p_candidate, q_candidate, gcd(p_candidate, q_candidate))
            if gcd(p_candidate, q_candidate) == 1:
                yield p_candidate, q_candidate

def SieveOfEratosthenes(n, path="./primes.pkl"): # Taken from https://www.geeksforgeeks.org/python-program-for-sieve-of-eratosthenes/
        # Create a boolean array "prime[0..n]" and initialize 
        # all entries it as true. A value in prime[i] will 
        # finally be false if i is Not a prime, else true. 
        prime = [True for i in range(n + 1)] 
        p = 2
        while (p * p <= n): 

            # If prime[p] is not changed, then it is a prime 
            if (prime[p] == True): 

                # Update all multiples of p 
                for i in range(p * 2, n + 1, p): 
                    prime[i] = False
            p += 1
        prime[0]= False
        prime[1]= False

        f = open(path, "wb")
        pickle.dump(prime, f)

def generate_primes(start_range, end_range, path="./primes.pkl"):
    f = open(path, "rb")
    primes = np.array(pickle.load(f))
    if end_range>=len(primes):
        print("checkpoint")
        SieveOfEratosthenes(end_range)
        f = open(path, "rb")
        primes = np.array(pickle.load(f))
    prime_subset = primes[start_range:end_range]
    final_set = np.array((prime_subset == True).nonzero()) + start_range
    return final_set

### Precompute primes

In [44]:
# precompute a few primes
SieveOfEratosthenes(10000000)

# FFT function

In [45]:
def fft_polynomial_multiply(expression_1,expression_2):  #fft based real-valued polynomial multiplication
    
    def next_power_of_two(x):
        return int(math.ceil(math.log(x) / math.log(2)))
    
    expression_1_coeffs = expression_1.all_coeffs()
    expression_2_coeffs = expression_2.all_coeffs()
    expression_1_coeffs.reverse()
    expression_2_coeffs.reverse()
    
    print("EXPRESSION COEFFS")
    print(expression_1_coeffs)
    print(expression_2_coeffs)
    
    highest_coeff_power = max(len(expression_1_coeffs), len(expression_2_coeffs))
    
    fft_length = 2**(next_power_of_two(highest_coeff_power) + 1)
    
    print("FFT LENGTH")
    print(fft_length)
    
    padded_expression_1_coeffs = expression_1_coeffs + [0]*(fft_length - len(expression_1_coeffs))
    padded_expression_2_coeffs = expression_2_coeffs + [0]*(fft_length - len(expression_2_coeffs))
    
    print("PADDED COEFFS")
    print(padded_expression_1_coeffs)
    print(padded_expression_2_coeffs)
    
    point_value_representation_1=fft(padded_expression_1_coeffs)
    point_value_representation_2=fft(padded_expression_2_coeffs)

    print("POINT VAL REP, PADDED")
    print(point_value_representation_1)
    print(point_value_representation_2)
    multiplied_point_value_representation = np.multiply(point_value_representation_1, point_value_representation_2)
    
    multiplied_coefficients = ifft(multiplied_point_value_representation)
    multiplied_coefficients.reverse()
    
    print("IFFT Applied")
    print(multiplied_coefficients)
    
    return Poly(multiplied_coefficients, x)

### FFT Tester

In [46]:
j, k = (Poly([3842,-23,-23525,2],x), Poly([-129,1,-1598918,-7],x))

print(fft_polynomial_multiply(j,k))
print(j*k)

EXPRESSION COEFFS
[2, -23525, -23, 3842]
[-7, -1598918, 1, -129]
FFT LENGTH
8
PADDED COEFFS
[2, -23525, -23, 3842, 0, 0, 0, 0]
[-7, -1598918, 1, -129, 0, 0, 0, 0]
POINT VAL REP, PADDED
[-19704, -27367*sqrt(2)/2 + 2 - 19683*sqrt(2)*I/2 - 23*I, 25 - 27367*I, 2 + 27367*sqrt(2)/2 - 19683*sqrt(2)*I/2 + 23*I, 19662, 2 + 27367*sqrt(2)/2 - 23*I + 19683*sqrt(2)*I/2, 25 + 27367*I, -27367*sqrt(2)/2 + 2 + 23*I + 19683*sqrt(2)*I/2]
[-1599053, -1598789*sqrt(2)/2 - 7 - 1599047*sqrt(2)*I/2 + I, -8 - 1598789*I, -7 + 1598789*sqrt(2)/2 - 1599047*sqrt(2)*I/2 - I, 1599041, -7 + 1598789*sqrt(2)/2 + I + 1599047*sqrt(2)*I/2, -8 + 1598789*I, -1598789*sqrt(2)/2 - 7 - I + 1599047*sqrt(2)*I/2]
IFFT Applied
[0, -495618, 6809, -6140008254, 36724437, 37614546113, -3033161, -14]
Poly(-495618*x**6 + 6809*x**5 - 6140008254*x**4 + 36724437*x**3 + 37614546113*x**2 - 3033161*x - 14, x, domain='ZZ')
Poly(-495618*x**6 + 6809*x**5 - 6140008254*x**4 + 36724437*x**3 + 37614546113*x**2 - 3033161*x - 14, x, domain='ZZ')


# CRT function

In [47]:

def crt_multiply(expression_1, expression_2, p_list, multiplication_function = None):
    # d is number of factors
    M_over_mi_dictionary = {}
    M_over_mi_inverse_dictionary = {}
    component_dictionary = {}
    M = prod(p_list)
    accumulator = 0
    
    # Precompute constants
    for prime_factor in p_list:
        M_over_mi_dictionary[prime_factor] = (M//prime_factor)%prime_factor
        M_over_mi_inverse_dictionary[prime_factor] = find_inverse(M_over_mi_dictionary[prime_factor], prime_factor)
        # print(M_over_mi_inverse_dictionary[prime_factor])
        component_dictionary[prime_factor] = (expression_1.trunc(prime_factor), expression_2.trunc(prime_factor))    
        
    for prime_factor in p_list:
        if multiplication_function is None:
            accumulator += component_dictionary[prime_factor][0]*component_dictionary[prime_factor][1]*M_over_mi_inverse_dictionary[prime_factor]*M/prime_factor
        else:
            current_product = multiplication_function(component_dictionary[prime_factor][0],component_dictionary[prime_factor][1])
            print("CURRENT CRT INNER PRODUCT")
            print(current_product)
            accumulator += current_product*M_over_mi_inverse_dictionary[prime_factor]*M/prime_factor
    # print(M_over_mi_dictionary)
    # print(M_over_mi_inverse_dictionary)
    return polytools.trunc(accumulator, M)
    

### CRT Tester

In [48]:
a = Poly(x**20 + x**12 + x**11 + 5069*x**10-34*x**7+x**4+54*x**2+x-1, x)
b = Poly(x**23 + x**11 + 5038*x**9-26*x**7+x**6-4*x**5+x**3-x+69, x)

c = Poly(x**4+x**2-x+1, x)
d = Poly(x**4-x**2+x-1, x)
primes = [2**4+1, 2**8+1]
primes = [18223, 18229]
M = prod(primes)

print("RESULTS")
print(crt_multiply(a, b, primes))

print("ACTUAL")
print(a*b.trunc(M))

RESULTS
x**43 + x**35 + x**34 + 5069*x**33 + x**31 - 34*x**30 + 5038*x**29 - 25*x**27 + x**26 + 50*x**25 + x**24 + x**23 + x**22 + 10106*x**21 + 5107*x**20 + 25537596*x**19 - 59*x**18 - 131797*x**17 - 166227*x**16 - 20274*x**15 + 885*x**14 + 10126*x**13 + 205*x**12 + 267025*x**11 + 354766*x**10 - 6446*x**9 + 62*x**8 - 2534*x**7 - 5*x**6 + 57*x**5 + 70*x**4 - 55*x**3 + 3725*x**2 + 70*x - 69
ACTUAL
Poly(x**43 + x**35 + x**34 + 5069*x**33 + x**31 - 34*x**30 + 5038*x**29 - 25*x**27 + x**26 + 50*x**25 + x**24 + x**23 + x**22 + 10106*x**21 + 5107*x**20 + 25537596*x**19 - 59*x**18 - 131797*x**17 - 166227*x**16 - 20274*x**15 + 885*x**14 + 10126*x**13 + 205*x**12 + 267025*x**11 + 354766*x**10 - 6446*x**9 + 62*x**8 - 2534*x**7 - 5*x**6 + 57*x**5 + 70*x**4 - 55*x**3 + 3725*x**2 + 70*x - 69, x, domain='ZZ')


# Parameter generation methods

In [49]:
### Parameter generation
# Generate f, g
# Generate N, p, q
# Check invertibility of f mod p, q
# Create h from p, f, g

def apply_dual_modulus(expression, q, N):
    return expression.trunc(q) % Poly(x**N-1,x).set_domain(ZZ)

def generate_polynomial(terms, ones=None, zeros=None):
    # Note terms = degree - 1
    if ones is None and zeros is None:
        ones = terms//3
        zeros = terms//3
    elif ones is None:
        ones = (terms - zeros)//2
    elif zeros is None:
        zeros = (terms - ones)//2
    
    if terms < ones + zeros:
        raise Exception("N must be greater than total ones and zeros")
    
    ones_vector = np.ones(ones)
    zeros_vector = np.zeros(zeros)
    negative_ones_vector = -np.ones(terms-ones-zeros)
    unordered_polynomial_coefficients = np.concatenate((ones_vector, zeros_vector, negative_ones_vector))
    return Poly(np.random.permutation(unordered_polynomial_coefficients),x).set_domain(ZZ)

def invert_polynomial(f_polynomial, N, p):
    modulus_polynomial=Poly(x**N-1,x).set_domain(ZZ)
    return invert(f_polynomial, modulus_polynomial, domain=GF(p))

def invert_poly(f_poly, N, p):
    R_poly = Poly(x**N-1,x).set_domain(ZZ)
    inv_poly = None
    if is_prime(p):
        log.debug("Inverting as p={} is prime".format(p))
        inv_poly = invert(f_poly, R_poly, domain=GF(p))
    elif is_2_power(p):
        log.debug("Inverting as p={} is 2 power".format(p))
        inv_poly = invert(f_poly, R_poly, domain=GF(2))
        e = int(math.log(p, 2))
        for i in range(1, e):
            log.debug("Inversion({}): {}".format(i, inv_poly))
            inv_poly = ((2 * inv_poly - f_poly * inv_poly ** 2) % R_poly).trunc(p)

def generate_parameters(N_range=range(250, 2500), p_range=range(250, 2500), q_range=range(3, 4)):
    # Generate p, q, N such that
    # N is a a large prime representing the degree of our modulo for our residue class
    # p is a small prime modulus coprime to q
    # q is a large modulus coprime to p
    
    # Todo: enforce q >> p

    ATTEMPTS = 100

    def sample_range(number_range, looking_for_prime=True):
        if looking_for_prime:
            for candidate_prime in number_range:
                # print(candidate_prime)
                if is_prime(candidate_prime):
                    return candidate_prime
            raise ValueError("no primes found in defined range")
        else:
            return random.choice(number_range)

    def get_coprime(p_range, q_range):
        for p_candidate in p_range:
            for q_candidate in q_range:
                if gcd(p_candidate, q_candidate) == 1:
                    # print(p_candidate, q_candidate, gcd(p_candidate, q_candidate))
                    yield p_candidate, q_candidate
    
    def get_f_g_h(p, q, N):
        current_attempt=0
        for current_attempt in range(ATTEMPTS):
            f = generate_polynomial(N)
#             print("Candidate")
#             print(f)
            try:
                f_p = invert_polynomial(f, N, p) # TODO: Replace to utilize crt_multiply
            except:
                continue
            try:
                f_q = invert_polynomial(f, N, q) # TODO: Replace to utilize crt_multiply
            except:
                continue
            g = generate_polynomial(N)
            h = (p*f_q*g).trunc(q)
            return f, f_p, f_q, g, h 
        raise Exception("Cannot find invertible f")

    N = sample_range(N_range, looking_for_prime=True)
    print("checkpoint 1")
    f = None
    for p, q in get_coprime(p_range, q_range):
        # print("Got coprime p, q: ", p, q)
        try:
            f, f_p, f_q, g, h = get_f_g_h(p, q, N)
            
            if f is not None:
                return p, q, f, f_p, f_q, g, h
        except:
            pass
    raise Exception("invertible f not found")

# End-to-end example

In [50]:
### Generate message
# enforce r must be small
N = 401
m = generate_polynomial(N)
print(m)
r = generate_polynomial(N//4)
print(r)

Poly(-x**399 + x**398 - x**397 - x**396 - x**394 - x**392 + x**391 - x**390 + x**389 - x**388 - x**386 + x**384 - x**382 - x**381 + x**380 + x**379 + x**377 + x**376 + x**375 + x**372 - x**371 - x**370 - x**369 + x**367 + x**365 - x**364 + x**362 + x**361 + x**359 - x**358 - x**357 + x**356 + x**355 - x**354 - x**353 - x**352 + x**351 + x**349 - x**347 + x**346 + x**345 + x**344 - x**343 + x**341 + x**340 - x**339 + x**338 + x**335 + x**334 - x**332 - x**331 + x**330 - x**326 + x**324 - x**322 + x**321 - x**320 - x**318 + x**317 + x**316 - x**314 - x**312 - x**311 + x**310 - x**309 - x**308 + x**307 + x**306 + x**305 + x**304 + x**302 - x**299 + x**298 + x**297 - x**296 + x**295 + x**294 + x**293 - x**292 + x**291 - x**290 + x**289 - x**285 - x**282 + x**280 - x**279 - x**278 - x**277 + x**275 + x**274 - x**273 + x**271 + x**270 + x**269 + x**268 + x**263 - x**262 - x**261 - x**259 + x**257 - x**256 + x**253 + x**251 + x**249 - x**247 + x**245 - x**244 - x**243 - x**242 - x**240 - x**2

### KeyGen

In [27]:
# q > (6d + 1) p
# N > 2n^2p

N = 401
p, q, f, f_p, f_q, g, h = generate_parameters(range(N, N+1), range(3, 4), range(2001, 2997))



checkpoint 1


### Generate primes for P-List

In [None]:
p_list_all = generate_primes(10000000000, 10000010000)
p_list = random.sample(p_list_all, 3)

### Encrypt

In [57]:
# compute e=r*h+m mod q
# IEEE STANDARD?
# Ideas: run on 3 systems to compare performance
# Write script to create test cases with bins of cases based on difficulty.

def encrypt(r, h, q, m, N, use_fft_crt=False, p_list=[257,569]):
    if use_fft_crt:
        print("# Checkpoint 1")
        product = crt_multiply(r, h, p_list=[], multiplication_function = fft_polynomial_multiply)
        print("# Checkpoint 2")
        return (((r*h).trunc(q)+m).trunc(q) % Poly(x**N-1,x)).trunc(q)
    else:
        return (((r*h).trunc(q)+m).trunc(q) % Poly(x**N-1,x)).trunc(q)

p_list = [3, 5, 7]

print("r", r)
print("h", h)
r*h

cipher_text = encrypt(r, h, q, m, N, False, p_list)

### Decrypt

In [58]:
# compute a=f*e

def decrypt(cipher_text, f, f_p, N, p, q, use_fft_crt=False):
    if use_fft_crt:
        pass
    else:
        a = ((f * cipher_text) % Poly(x**N-1,x)).trunc(q)
        b = a.trunc(p)
        return ((f_p * b)% Poly(x**N-1,x)).trunc(p)

decrypted_text = decrypt(cipher_text, f, f_p, N, p, q)
print(decrypted_text)

### Verify Message

In [None]:
print(m)

# Execution time testing

In [54]:
execution_times = {} # list of execution times

In [26]:
# change to create and store examples in different files
# run once for every paramset you want to store multiple examples for
examples_path = "./112_bit_examples.pkl" 
if os.path.isfile(examples_path):
    f = open(examples_path, "rb")
    parameter_sets = pickle.load(f)
else:
    parameter_sets = []

In [37]:
### KeyGen
# q > (6d + 1) p
# N > 2n^2p

N = 401
p, q, f, f_p, f_q, g, h = generate_parameters(range(N, N+1), range(3, 4), range(2001, 2997))

checkpoint 1


In [38]:
print(p, q, f, f_p, f_q, g, h)

parameter_set = {
    'p': p,
    'q': q,
    'f': f,
    'f_p': f_p,
    'f_q': f_q,
    'g': g,
    'h': h
}

parameter_sets.append(parameter_set)

3 2003 Poly(-x**400 + x**398 + x**397 - x**395 + x**394 - x**392 - x**391 + x**390 - x**389 - x**388 - x**387 + x**385 - x**383 + x**381 - x**380 - x**379 + x**378 + x**377 - x**376 + x**375 - x**373 - x**372 - x**370 + x**369 + x**367 - x**364 - x**362 - x**361 + x**360 + x**359 + x**358 - x**357 + x**356 - x**355 - x**353 + x**352 + x**351 - x**350 - x**348 - x**346 + x**345 + x**344 + x**343 - x**341 + x**340 + x**339 - x**338 + x**337 - x**335 - x**334 - x**333 - x**332 + x**330 - x**329 + x**327 - x**325 - x**323 + x**320 - x**319 + x**318 + x**316 - x**313 - x**312 - x**310 + x**309 + x**306 - x**305 + x**303 + x**302 - x**299 + x**298 + x**297 - x**296 - x**295 + x**292 + x**291 + x**289 - x**288 - x**287 + x**286 - x**282 - x**280 - x**279 - x**278 + x**277 + x**276 + x**275 + x**274 - x**273 + x**272 - x**269 + x**268 + x**267 - x**265 - x**264 + x**263 + x**261 - x**260 + x**259 - x**257 - x**256 - x**253 - x**252 + x**251 - x**249 + x**248 - x**246 + x**245 + x**244 - x**243

In [39]:
f = open(examples_path, "wb")
pickle.dump(parameter_sets, f)

In [55]:
execution_times[examples_path] = []

In [61]:
for parameter_set_example in parameter_sets:
    p = parameter_set_example['p']
    q = parameter_set_example['q']
    f = parameter_set_example['f']
    f_p = parameter_set_example['f_p']
    f_q = parameter_set_example['f_q']
    g = parameter_set_example['g']
    h = parameter_set_example['h']
    
    
    ### Encrypt
    # compute e=r*h+m mod q
    print("ENCRYPT STARTING")
    encrypt_start_time = time.time()
    cipher_text = encrypt(r, h, q, m, N, False, [3,5,7,11])
    encrypt_execution_time = time.time() - encrypt_start_time
    print("ENCRYPT EXECUTION TIME: ", encrypt_execution_time)

    ### Decrypt
    # compute a=f*e
    print("DECRYPT STARTING")
    decrypt_start_time = time.time()
    print(decrypt(cipher_text, f, f_p, N, p, q))
    decrypt_execution_time = time.time() - decrypt_start_time
    print("DECRYPT EXECUTION TIME: ", decrypt_execution_time)
    
    execution_times[examples_path].append({'encrypt': encrypt_execution_time, 'decrypt': decrypt_execution_time})
    


ENCRYPT STARTING
ENCRYPT EXECUTION TIME:  0.11510705947875977
DECRYPT STARTING
Poly(-x**399 + x**398 - x**397 - x**396 - x**394 - x**392 + x**391 - x**390 + x**389 - x**388 - x**386 + x**384 - x**382 - x**381 + x**380 + x**379 + x**377 + x**376 + x**375 + x**372 - x**371 - x**370 - x**369 + x**367 + x**365 - x**364 + x**362 + x**361 + x**359 - x**358 - x**357 + x**356 + x**355 - x**354 - x**353 - x**352 + x**351 + x**349 - x**347 + x**346 + x**345 + x**344 - x**343 + x**341 + x**340 - x**339 + x**338 + x**335 + x**334 - x**332 - x**331 + x**330 - x**326 + x**324 - x**322 + x**321 - x**320 - x**318 + x**317 + x**316 - x**314 - x**312 - x**311 + x**310 - x**309 - x**308 + x**307 + x**306 + x**305 + x**304 + x**302 - x**299 + x**298 + x**297 - x**296 + x**295 + x**294 + x**293 - x**292 + x**291 - x**290 + x**289 - x**285 - x**282 + x**280 - x**279 - x**278 - x**277 + x**275 + x**274 - x**273 + x**271 + x**270 + x**269 + x**268 + x**263 - x**262 - x**261 - x**259 + x**257 - x**256 + x**253

Poly(-x**399 + x**398 - x**397 - x**396 - x**394 - x**392 + x**391 - x**390 + x**389 - x**388 - x**386 + x**384 - x**382 - x**381 + x**380 + x**379 + x**377 + x**376 + x**375 + x**372 - x**371 - x**370 - x**369 + x**367 + x**365 - x**364 + x**362 + x**361 + x**359 - x**358 - x**357 + x**356 + x**355 - x**354 - x**353 - x**352 + x**351 + x**349 - x**347 + x**346 + x**345 + x**344 - x**343 + x**341 + x**340 - x**339 + x**338 + x**335 + x**334 - x**332 - x**331 + x**330 - x**326 + x**324 - x**322 + x**321 - x**320 - x**318 + x**317 + x**316 - x**314 - x**312 - x**311 + x**310 - x**309 - x**308 + x**307 + x**306 + x**305 + x**304 + x**302 - x**299 + x**298 + x**297 - x**296 + x**295 + x**294 + x**293 - x**292 + x**291 - x**290 + x**289 - x**285 - x**282 + x**280 - x**279 - x**278 - x**277 + x**275 + x**274 - x**273 + x**271 + x**270 + x**269 + x**268 + x**263 - x**262 - x**261 - x**259 + x**257 - x**256 + x**253 + x**251 + x**249 - x**247 + x**245 - x**244 - x**243 - x**242 - x**240 - x**2

In [63]:
parameter_sets

[{'p': 3,
  'q': 2003,
  'f': Poly(-x**400 - x**399 - x**398 + x**397 + x**396 + x**395 - x**394 + x**393 + x**390 - x**388 - x**387 - x**385 - x**384 - x**380 - x**379 + x**378 - x**377 + x**376 + x**373 + x**372 - x**371 + x**369 + x**367 - x**364 - x**363 + x**361 + x**359 - x**357 - x**355 - x**354 + x**353 + x**352 - x**350 + x**348 + x**346 - x**344 - x**343 + x**342 + x**340 + x**339 + x**336 - x**333 - x**332 - x**331 - x**329 - x**327 + x**325 - x**324 - x**323 + x**322 - x**321 - x**320 + x**319 + x**317 + x**316 + x**315 + x**314 + x**312 - x**310 + x**308 + x**306 + x**305 - x**304 + x**302 - x**299 + x**298 - x**296 - x**292 + x**291 - x**288 - x**287 + x**284 - x**283 + x**282 + x**281 - x**280 + x**279 - x**278 + x**275 - x**274 - x**273 + x**272 + x**270 - x**269 - x**268 - x**267 + x**265 - x**264 + x**263 + x**262 - x**261 + x**260 + x**259 - x**257 + x**255 - x**251 + x**249 - x**247 + x**245 - x**243 - x**242 + x**241 - x**238 + x**236 - x**235 - x**234 + x**233 - x

In [None]:
execution_times