In [8]:
import itertools, numbers, math
import random, unittest

In [9]:
# Transform an integer number to a coefficient vector in radix-base
def int2vec(number, base):
    vec = []
    while number > 0:
        number, digit = divmod(number, base)
        vec.append(int(digit))
    return vec

In [10]:
# Returns silently if the given value is an integer, otherwise raises a TypeError.
def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

In [11]:
# Returns floor(sqrt(n)) for the given integer n >= 0.
def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

In [12]:
# Tests whether the given integer n >= 2 is a prime number.
def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

In [13]:
# Returns a list of unique prime factors of the given integer in
# ascending order. For example, unique_prime_factors(60) = [2, 3, 5].
def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

In [14]:
# Returns the multiplicative inverse of n modulo mod. The inverse x has the property that
# 0 <= x < mod and (x * n) % mod = 1. The inverse exists if and only if gcd(n, mod) = 1.
def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

In [15]:
# Returns a vector with zero padding on the higher indexes of the input
# vector, which satisfies that its length is power of 2.
# Example: If the length of vec is 6, we need to pad 2 zeroes to make the length
# to 8=2^3, and then double the length by zero padding agin to 16 in order to 
# avoid overflow; If the length of vec is 8, we just need to pad 8 zeroes.
def zero_padding(vec):
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    i = 0
    v = len(vec)
    while(2**i<=v):
        i+=1
        
    padding_len = 2**i * 2 - v
    for i in range(padding_len):
        vec.append(0)
    return vec

In [16]:
#Return the minimum modulus according to the input vectors vec0 and vec1. Assume vec0 and vec1 have the same length.
#When we have convolution on two vectors of length 𝑛 where each input coefficient value is at most 𝑚 (i.e. 0≤"x(t),y(t)"<"m"), 
#the upper bound on each output value is 𝑛𝑚^2. Choosing a minimum modulus 𝑀=𝑛𝑚^2+1.
def find_minmod(vec0, vec1):
    n = len(vec0)
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * n + 1
    return minmod

In [17]:
# Returns the smallest modulus n such that n = i * veclen + 1
# for some integer i >= 1, n > veclen, and n >= minmod is prime.
# Although the loop might run for a long time and create arbitrarily large numbers,
# Dirichlet's theorem guarantees that such a prime number must exist.
def find_modulus(veclen, minmod):
    check_int(veclen)
    check_int(minmod)
    if veclen < 1 or minmod < 1:
        raise ValueError()
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    i = 1
    n = i*veclen + 1
    while n<minmod or not is_prime(n):
        i+=1
        n = i*veclen + 1
    return n

In [18]:
# Tests whether val generates the multiplicative group of integers modulo mod. totient
# must equal the Euler phi function of mod. In other words, the set of numbers
# {val^1 % mod, val^2 % mod, ..., val^totient % mod} is equal to the set of all
# numbers in the range (0, mod) that are coprime to mod. If mod is prime, then
# totient = mod - 1, and powers of a generator produces all integers in the range [1, mod).
# Return True or False.
def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    ps =  unique_prime_factors(totient)
    is_gene = 1
    if val**totient % mod == 1:
        for p in ps:
            if val**(totient//p) % mod == 1:
                is_gene = 0
                break
    else:
        is_gene = 0
    return True if is_gene else False

In [19]:
# Returns an arbitrary generator of the multiplicative group of integers modulo mod.
# totient must equal the Euler phi function of mod. If mod is prime, an answer must exist.
def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")


In [20]:
# Returns an arbitrary primitive degree-th root of unity modulo mod.
# totient must be a multiple of degree. If mod is prime, an answer must exist.
# Hint: function find_generator() is required.
def find_primitive_root(degree, totient, mod):
    print ('degree, totient, mod',degree, totient, mod)
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    g = find_generator(totient,mod)
    c = totient // degree
    root = g**c % mod
    return root

In [21]:
# Returns the forward number-theoretic transform of the given vector with
# respect to the given primitive nth root of unity under the given modulus.
def dft_ntt(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()
    
    outvec = []
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    for k in range(len(invec)):
        temp = 0
        for t in range(len(invec)):
            temp += invec[t]*root**(t*k)
        outvec.append(temp%mod)
    return outvec

In [22]:
# Returns the inverse number-theoretic transform of the given vector with
# respect to the given primitive nth root of unity under the given modulus.
def idft_ntt(invec, root, mod):
    outvec = dft_ntt(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

In [23]:
def bit_reverse(x, bits):
    y = 0
    for i in range(bits):
        y = (y<<1) | (x&1)
        x >>= 1
    return y

In [24]:
a = [0,1,2,3,4,5,6,7]
[bit_reverse(i,3) for i in a]

[0, 4, 2, 6, 1, 5, 3, 7]

In [25]:
# Computes the forward number-theoretic transform in in-place FFT structure of the given vector,
# with respect to the given primitive nth root of unity under the given modulus.
# The length of the vector must be a power of 2.
def dft_fft(vector, root, mod):
    n = len(vector)
    levels = n.bit_length() - 1
    if 1 << levels != n:
        raise ValueError("Length is not a power of 2")
    
    #---------------------------------
    # please provide your codes here
    #---------------------------------
    idx = [bit_reverse(i, levels) for i in range(n)]
    x = [vector[i] for i in idx]
    lgn = math.log(n,2)
    X =  [0] * n
    
    for j in range(int(lgn)):
        for k in range(int(n/2)):
            p = math.floor(k/(2**(lgn-1-j)))*(2**(lgn-1-j))
            X[k] = (x[2*k] + x[2*k+1]*root**p) % mod
            X[k+int(n/2)] = (x[2*k] - x[2*k+1]*root**p) % mod
        if j != lgn - 1:
            for k in range(n):
                x[k] = X[k]
                
    return X

In [26]:
# Returns the inverse number-theoretic transform (FFT structure) of the given vector with
# respect to the given primitive nth root of unity under the given modulus.
def idft_fft(invec, root, mod):
    outvec = dft_fft(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

In [52]:
def fft_mult(in0, in1):
    
    #transform integer to vector
    base = 10
    vec0 = int2vec(in0,base)
    vec1 = int2vec(in1,base)

    #input validation check
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()

    #zero padding the vectors to length of power of 2
    vec0 = zero_padding(vec0);
    vec1 = zero_padding(vec1);

    #parameter selction
    n = len(vec0)
    minmod = find_minmod(vec0, vec1)
    mod = find_modulus(n, minmod)
    root = find_primitive_root(n, mod-1, mod)

    #forward transforms
    vec0 = dft_ntt(vec0, root, mod)
    vec1 = dft_ntt(vec1, root, mod)

    #point-wise multiplication in frequency domain
    vec2 = [(x * y % mod) for (x, y) in zip(vec0, vec1)]
    print(vec2)

    #inverse transform
    result = idft_ntt(vec2, root, mod)

    #transform vector to integer
    product = 0
    for i in range(len(result)):
        product += result[i] * base**i
    return product

#### Test find_modulus() function: 

In [28]:
#test set 1
veclen = 32
minimum = 2593
expect = 2593
actual = find_modulus(veclen, minimum)
print ('test find modulus pass?', actual == expect)

test find modulus pass? True


In [29]:
#test set 2
veclen = 16
minimum = 785
expect = 881
actual = find_modulus(veclen, minimum)
print ('test find modulus pass?', actual == expect)

test find modulus pass? True


In [30]:
#test set 3
veclen = 8
minimum = 393
expect = 401
actual = find_modulus(veclen, minimum)
print ('test find modulus pass?', actual == expect)

test find modulus pass? True


#### Test `is_generator()` function: 

In [31]:
#test set 1
val = 2
totient = 400
mod = 401
expect = False
actual = is_generator(val, totient, mod)
print ('test is generator pass?', actual == expect)

test is generator pass? True


In [32]:
#test set 2
val = 3
totient = 400
mod = 401
expect = True
actual = is_generator(val, totient, mod)
print ('test is generator pass?', actual == expect)

test is generator pass? True


In [33]:
#test set 3
val = 10
totient = 1296
mod = 1297
expect = True
actual = is_generator(val, totient, mod)
print ('test is generator pass?', actual == expect)

test is generator pass? True


#### Test `find_primitive_root()` function: 

In [34]:
#test set 1
degree = 16
totient = 1296
mod = 1297
expect = 355
actual = find_primitive_root(degree, totient, mod)
print ('test find primitive root pass?', actual == expect)

degree, totient, mod 16 1296 1297
test find primitive root pass? True


In [35]:
#test set 2
degree = 32
totient = 2592
mod = 2593
expect = 1997
actual = find_primitive_root(degree, totient, mod)
print ('test find primitive root pass?', actual == expect)

degree, totient, mod 32 2592 2593
test find primitive root pass? True


In [36]:
#test set 3
degree = 64
totient = 5440
mod = 5441
expect = 1638
actual = find_primitive_root(degree, totient, mod)
print ('test find primitive root pass?', actual == expect)

degree, totient, mod 64 5440 5441
test find primitive root pass? True


#### Test `forward_transform()` function: 

In [46]:
#actual = dft_ntt([6, 0, 10, 7, 2], 3, 11)
#expect = [3, 7, 0, 5, 4]
actual = dft_fft([1,3,5,6,0,0,0,0], 609, 673)
expect = [15, 36, 495, 305, 670, 546, 170, 463]
print ('test forward transform pass?', actual == expect)

test forward transform pass? True


#### Test `inverse_transform()` function: 

In [47]:
#actual = idft_ntt([3, 7, 0, 5, 4], 3, 11)
#expect = [6, 0, 10, 7, 2]
actual = idft_fft([15, 36, 495, 305, 670, 546, 170, 463], 609, 673)
expect = [1,3,5,6,0,0,0,0]
print ('test inverse transform pass?',  actual == expect)

test inverse transform pass? True


#### Test `fft_mult()` function for FFT: 

In [56]:
def fft_mult(in0, in1):
    
    #transform integer to vector
    base = 10
    vec0 = int2vec(in0,base)
    vec1 = int2vec(in1,base)

    #input validation check
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()

    #zero padding the vectors to length of power of 2
    vec0 = zero_padding(vec0);
    vec1 = zero_padding(vec1);

    #parameter selction
    n = len(vec0)
    minmod = find_minmod(vec0, vec1)
    mod = find_modulus(n, minmod)
    root = find_primitive_root(n, mod-1, mod)

    #forward transforms
    vec0 = dft_fft(vec0, root, mod)
    vec1 = dft_fft(vec1, root, mod)

    #point-wise multiplication in frequency domain
    vec2 = [int(x * y % mod) for (x, y) in zip(vec0, vec1)]
    print(vec2)

    #inverse transform
    result = idft_fft(vec2, root, mod)

    #transform vector to integer
    product = 0
    for i in range(len(result)):
        product += result[i] * base**i
    return product

In [57]:
def main():
    print ('test_main')
    input0 = 1276582236958546324
    input1 = 1475489236589542321
    result = fft_mult(input0, input1)
    print (result)
    print (input0 * input1)
    print ('test FFT multiplication pass?', result == input0 * input1)

main()

test_main
degree, totient, mod 64 5440 5441
[2303, 4387, 5344, 180, 3924, 1809, 2574, 4580, 5281, 1097, 4730, 3306, 3353, 1250, 1682, 1171, 2725, 4599, 2032, 4136, 3263, 58, 1970, 153, 1926, 4483, 703, 1289, 4859, 2783, 2917, 806, 16, 3549, 764, 1353, 4377, 2666, 2574, 4580, 5281, 1097, 4730, 3306, 3353, 1250, 1682, 1171, 2725, 4599, 2032, 4136, 3263, 58, 1970, 153, 1926, 4483, 703, 1289, 4859, 2783, 2917, 806]
1.5001644288678464e+65
1883583350253735734193281844636978004
test FFT multiplication pass? False
