## Roots of Unity

In [154]:
def get_primitive_root_of_unity(F, n):
    
    # To check that candidate root is a primitive root we need to verify that:
    # For any a | n, candidate ^ a != 1
    # This is equivelent to:
    # For a, b: a * b = n, candidate ^ a != 1 and candidate ^ b != 1
    #
    # Therefor we will look for a first a | n, and then evaluate:
    # candidate ^ a != 1 and candidate ^ (n / a) != 1
    # If the checks pass, this is the primitive root

    # For power we assume to use quick power in O(log(n))
    # Here we assume to use number sieve algorythm 
    # And we look for a devider of n
    divider = n
    for i in range(2, n):
        if n % i == 0:
            divider = i

    
    
    # Get order of field F
    p = F.order()
    
    while True:
        
        # Generate a random x until we find primitive root
        x = F.random_element()
        
        # Root as: candidate ^ n = (x ^ ((p - 1) / n)) ^ n = x ^ (p - 1) = 1
        candidate = x ^ ((p - 1) / n)
        
        # We need to check if it is primitive
        if candidate == 1:
            continue
            
        
        # If n is prime, then any root is primitive
        if divider == n:
            return candidate
        
        if candidate ^ divider != 1 and candidate ^ (n / divider) != 1:
            return candidate

In [181]:
def get_nth_roots_of_unity(n, primitive_w):
    
    # We assume to have primitive root of unity 
    # To get the other n - 1 roots we just exponenciate the root:
    
    w_arr = [1]
    
    cur_w = 1
    for i in range(n - 1):
        cur_w *= primitive_w
        w_arr += [cur_w]
        
    return w_arr
        

In [182]:
q = 41
F = GF(q)
n = 4

In [183]:
primitive_w = get_primitive_root_of_unity(F, n)
print("primitive_w:", primitive_w)

primitive_w: 32


In [184]:
w_arr = get_nth_roots_of_unity(n, primitive_w)
print(f"{n}th roots of unity: {w_arr}")

4th roots of unity: [1, 32, 40, 9]


## Fast Fourier Transform

In [185]:
w = get_primitive_root_of_unity(F, n)

In [205]:
def ft_setup(n, w):
    
    w_arr = get_nth_roots_of_unity(n, w)

    
    acc = [1 for i in range(n)]
    matrix_elms = [acc]
    
    for i in range(n - 1):
        acc = [acc[i] * w_arr[i] for i in range(n)]
        matrix_elms += [acc]
        
    F = matrix(matrix_elms)
    
    w_inv = ~w
    w_inv_arr = get_nth_roots_of_unity(n, w_inv)
    
    acc = [1 for i in range(n)]
    inv_matrix_elms = [acc]
    
    for i in range(n - 1):
        acc = [acc[i] * w_inv_arr[i] for i in range(n)]
        inv_matrix_elms += [acc]
        
        
    G = matrix(inv_matrix_elms) / n
    assert(G == ~F * n, "Failed in calculating F matrix inverse.")
    
    return w_arr, F, G

In [206]:
w_arr, FT, IFT = ft_setup(n, w)

In [207]:
print("nth roots of unity:", w_arr)
print("Vandermonde matrix:")
print(FT)
print("Inverse Vandermonde matrix:")
print(IFT)

nth roots of unity: [1, 9, 40, 32]
Vandermonde matrix:
[ 1  1  1  1]
[ 1  9 40 32]
[ 1 40  1 40]
[ 1 32 40  9]
Inverse Vandermonde matrix:
[31 31 31 31]
[31  8 10 33]
[31 10 31 10]
[31 33 10  8]


In [208]:
def ft(fa_coef, w):
    
    n = len(fa_coef)
    
    w_arr = get_nth_roots_of_unity(n, w)

    
    acc = [1 for i in range(n)]
    matrix_elms = [acc]
    
    for i in range(n - 1):
        acc = [acc[i] * w_arr[i] for i in range(n)]
        matrix_elms += [acc]
        
    F = matrix(matrix_elms)
        
    return F * fa_coef


def ft_inv(fa_eval, w):
    
    n = len(fa_eval)
    
    # We do matrix inversion manually to save computation 
    # As well as because the formula of the inverse is availble
    w_inv = ~w
    w_inv_arr = get_nth_roots_of_unity(n, w_inv)
    
    acc = [1 for i in range(n)]
    inv_matrix_elms = [acc]
    
    for i in range(n - 1):
        acc = [acc[i] * w_inv_arr[i] for i in range(n)]
        inv_matrix_elms += [acc]
        
        
    G = matrix(inv_matrix_elms) / n
    
        
    return G * fa_eval

In [209]:
# We define a polynomial by its evaluations:
fa_eval = vector([9,4,5,9])
print("fa_eval:", fa_eval)

fa_eval: (9, 4, 5, 9)


In [210]:
# we can interpolate it to the coefficients of f_a(x) by
# multiplying it by ft_inv:
fa_coef = ft_inv(fa_eval, w)

In [211]:
# we can now create a polynomial from fa coefficients:
P.<x> = PolynomialRing(F)
fa = P(list(fa_coef))
# and if we print fa we get: 19*x^3 + 9*x^2 + 21*x + 36
print("f_a(x):", fa)

f_a(x): 31*x^2 + 2*x + 17


In [212]:
# we can check that evaluating fa(x) at the roots of
# unity returns the expected values of a
for i in range(len(fa_eval)):
    assert fa(w_arr[i]) == fa_eval[i]
print("Interpolated polynomial evaluation check passed.")

Interpolated polynomial evaluation check passed.


In [213]:
# We can also go in the opposite direction: from coefficient form
# to evaluation form (from fa(x) to fa_eval):
fa_eval_rec = ft(fa_coef, w)

In [214]:
# and it should get us to the same values than our original fa_eval vector:
assert fa_eval == fa_eval_rec
print("Interpolated polynomial coeficient check passed.")

Interpolated polynomial coeficient check passed.


In [280]:
def fft(coef, w):
    
    n = len(coef)
    
    if n == 1:
        return vector(coef)
    
    if n == 2:
        return vector([coef[0] + coef[1], coef[0] - coef[1]])
    
    if n % 2 == 0:
        odd_coef = coef[1::2]
        even_coef = coef[0::2]
        
        odd_ev = fft(odd_coef, w ^ 2)
        even_ev = fft(even_coef, w ^ 2)
        
        w_parr = vector(get_nth_roots_of_unity(n / 2, w))
        
        w_kd2 = w_parr[-1] * w
        
        p_one = even_ev + w_parr.pairwise_product(odd_ev)
        p_two = even_ev + (vector(w_parr) * w_kd2).pairwise_product(odd_ev)
        
        return vector(list(p_one) + list(p_two))
    
    if n % 2 != 0:
        
        return vector(ft(coef, w))
    
def fft_inv(coef, w):
    
    # HOW TO IMPLEMENT
    
#     n = len(coef)
    
#     if n == 1:
#         return vector(coef)
    
#     if n == 2:
#         return vector([coef[0] + coef[1], coef[0] - coef[1]])
    
#     if n % 2 == 0:
#         odd_coef = coef[1::2]
#         even_coef = coef[0::2]
        
#         odd_ev = fft(odd_coef, w ^ 2)
#         even_ev = fft(even_coef, w ^ 2)
        
#         w_parr = vector(get_nth_roots_of_unity(n / 2, ~w))
        
#         w_kd2 = w_parr[-1] * (~w)
        
#         p_one = even_ev - w_parr.pairwise_product(odd_ev)
#         p_two = even_ev - (vector(w_parr) * w_kd2).pairwise_product(odd_ev)
        
#         return vector(list(p_one) + list(p_two))
    
#     if n % 2 != 0:
        
#         return vector(ft(coef, w))

In [281]:
fa_coef

(17, 2, 31, 0)

In [282]:
fft_inv(fft(fa_coef, w), w)

(1, 0, 27, 8)

In [253]:
fa_eval

(9, 4, 5, 9)

In [243]:
a = [1, 2, 3, 4]
a_e = a[1::2]
a_e

[2, 4]

In [131]:
get_nth_roots_of_unity(2, primitive_w)

[1, 9]