In [None]:
# Mixed radix recursion
import itertools as itt
from collections import deque

def MixedRadixPrecomputation(N, modulus_list):
    '''
    Precompute the auxiliary data used by MixedRadixRecursion for inputs of length N.

    Input:
    1. N: the transform length (a Sage integer; it is factored here)
    2. modulus_list: list of reduction moduli; only modulus_list[-1] is used here
    Output: the tuple
        (radix_list, radical, coprime_list, idem_vector, invRoot_list, invRad_list)
    where
        radix_list:   the distinct prime factors of N
        radical:      factorization.radical_value(), the product of those primes
        coprime_list: the prime-power factors p^e of N
        idem_vector:  the CRT basis of coprime_list, as a vector over Zmod(N)
        invRoot_list: each prime of radix_list inverted modulo modulus_list[-1]
        invRad_list:  invRad_list[a] is a vector whose b-th entry is radix_list[a]^-1 mod radix_list[b]
                      (1 on the diagonal)
    '''
    factorization = N.factor()
    radix_list = [ item[0] for item in factorization.radical() ]
    radical = factorization.radical_value()
    coprime_list = list( itt.starmap(pow, factorization))
    idem_vector = vector(Zmod(N), CRT_basis(coprime_list) )
    invRoot_list = [ inverse_mod(radix, modulus_list[-1]) for radix in radix_list ]
    invRad_list = [ vector([inverse_mod(radix1, radix2) if radix1 != radix2 else 1
                   for radix2 in radix_list ]) for radix1 in radix_list ]

    return (radix_list, radical, coprime_list, idem_vector, invRoot_list, invRad_list)

def MixedRadixFourierPrecomputation(p_list, rt_list, inv_list):
    config_list = deque()
    rev_inv_config_list = deque()
    p_length = len(p_list)
    
    L = prod(p_list)
    Perm = None
    inv_Perm = None

    N_list = [1] * (p_length + 1)
    for i in range(p_length):
        N_list[i+1] = N_list[i] * p_list[i]

    M_list = [1] * (p_length + 1)
    for i in range(1, p_length + 1):
        M_list[i] = M_list[i-1] * p_list[-i]
    M_list.reverse()

    match p_length:
        case 0:
            raise ValueError("The list of primes should not be empty")
        case 1:
            raise NotImplemented("Trivial case")
        case 2:
            raise NotImplemented("Trivial case")
        case _:
            Q_cur = range(p_list[0])
            Q_cur_inv = range(p_list[0])
            Perm = range(L)
            Perm_inv = range(L)
            
            Q_next, Q_next_inv = TensorInterchange(p_list[1], N_list[1])
            cur_node = 1
            
            for ptr in range(p_length - 1):
                
                inv_i = inv_list[ptr]
                p_i = p_list[ptr]
                N_i = N_list[ptr]
                M_i = M_list[ptr]
                N_ii = N_list[ptr + 1]
                M_ii = M_list[ptr + 1]
                M_iii = M_list[ptr + 2]
                N_iii = N_list[ptr + 2]
                
                E1, E2, P, P_inv = ThomasGoodPerm( M_ii, p_i)
                
                index = cur_node * E2 % L                
                Vand = matrix( p_i, p_i, lambda i,j: rt_list[ i * j * index % L] )
                Vand_inv = matrix( p_i, p_i, lambda i,j: inv_i * rt_list[ (- i * j * index) % L] )

                config_list.append( ('v', p_i, Vand) )
                rev_inv_config_list.append( ('v', p_i, Vand_inv) )


                Q = CompositionPerm(L,
                                    TensorProdPerm(M_iii, N_iii, range(M_iii), Q_next),
                                    TensorProdPerm(M_ii, N_ii, range(M_ii), Q_cur_inv) )
                Q_inv = InversePerm(L, Q)
                config_list.append( ('p', Q) )
                rev_inv_config_list.append( ('p', Q_inv) )

                # update permutations
                Perm = CompositionPerm(L, TensorProdPerm(M_i, N_i, P, range(N_i)), 
                                       Perm)
                Perm_inv = CompositionPerm(L, Perm_inv, 
                                           TensorProdPerm(M_i, N_i, P_inv, range(N_i)) )
                cur_node = cur_node * E1 % L
                Q_cur = Q_next
                Q_cur_inv = Q_next_inv
                if ptr != p_length - 2:
                    Q_next, Q_next_inv = TensorInterchange(p_list[ptr + 2],  N_iii)
                    
            p_i = p_list[-1]
            inv_i = inv_list[-1]
            Vand = matrix( p_i, p_i, lambda i,j: rt_list[ i * j * cur_node % L] )
            Vand_inv = matrix( p_i, p_i, lambda i,j: inv_i * rt_list[ (- i * j * cur_node) % L] )
            config_list.append( ('v', p_i, Vand) )
            config_list.append( ('p', Q_cur_inv) )
            config_list.appendleft( ('p', Perm) )
            config_list.append(('mp'))
            
            rev_inv_config_list.append( ('v', p_i, Vand_inv) )
            rev_inv_config_list.append( ('p', Q_cur) )
            rev_inv_config_list.appendleft( ('p', Perm_inv) )
            rev_inv_config_list.reverse()

            config_list.extend(rev_inv_config_list)
            return config_list    

def MultGeneric2(conf, vec, modulus_list, lazy=False):
    '''
    Multiply every consecutive chunk of `vec` by the same square matrix.

    Input:
    1. conf: (chunk_length, matrix) -- e.g. the tail of a ('v', p, Vand) configuration step
    2. vec: input whose length is a multiple of chunk_length (grouper raises otherwise)
    3. modulus_list: reduction moduli, applied in order to each product unless lazy
    Output:
    list holding the concatenated matrix-chunk products
    '''
    chunk_len, mat = conf[0], conf[1]
    out = []
    for chunk in grouper(vec, chunk_len):
        product = mat @ np.array(chunk, dtype=object)
        if not lazy:
            for modulus in modulus_list:
                product = product % modulus
        out += list(product)
    return out
            

def MixedRadixBaseCase2(row, vec, total_length, modulus_list, conf_list, lazy=False):
    col = [ row[-i] for i in range(total_length) ]
    merged = False
    # Implement strategy in conf_list
    for conf in conf_list:
        match conf[0]:
            case 'p':
                vec = [ vec[ conf[1][i] ] for i in range(total_length) ]
                if not merged:
                    col = [ col[ conf[1][i] ] for i in range(total_length) ]
            case 'v':
                vec = MultGeneric2(conf[1:], vec, modulus_list, lazy=lazy)
                if not merged:
                    col = MultGeneric2(conf[1:], col, modulus_list, lazy=lazy)
            case 'mp':
                for i in range(total_length):
                    vec[i] = vec[i] * col[i]
                    if not lazy:
                        for modulus in modulus_list:
                            vec[i] = vec[i] % modulus
                merged = True
            case _:
                raise NotImplementedError("Configuration mode not implemented")
    if lazy:
        for i in range(total_length):
            for modulus in modulus_list:
                vec[i] = vec[i] % modulus
    return vec

def MixedRadixBaseCase1(row, vec, length, f, modulus_list,aux):
    '''
    Direct O(length^2) base case of MixedRadixRecursion.

    For m in range(length) computes
        res[m] = sum_n c(m, n) * row[n - m] * vec[n],   c(m, n) = f if m > n else 1
    Negative indices n - m wrap around the end of `row` (Python indexing), which matches
    row[length + n - m] only when len(row) == length.

    Each term is reduced by every modulus in modulus_list[:-1]; the accumulated sum
    is reduced only by modulus_list[-1].
    `aux` is unused.
    '''
    assert len(vec) == length

    leading_moduli = modulus_list[:-1]
    final_modulus = modulus_list[-1]

    def reduce_leading(value):
        # Reduce by all but the last modulus, in order
        for modulus in leading_moduli:
            value = value % modulus
        return value

    res = []
    for m in range(length):
        acc = 0
        for n in range(length):
            term = reduce_leading(row[n - m] * vec[n])
            if m > n:
                # Wrapped-around entries are scaled by f
                term = reduce_leading(term * f)
            acc += term
        res.append(acc % final_modulus)
    return res
    
def MixedRadixRecursion( 
    first_row, vec, modulus_list, N, rt_list, aux1, aux2=None, cur_index=0, cur_inv_vector=None,nodeVal_vector = None, length=None,
    upperMost=True, lazy=False):
    '''
    Mixed-radix recursion for the product of the circulant matrix with first row
    `first_row` and the vector `vec`.

    Input:
    1. first_row, vec: first row of the circulant matrix and the input vector
    2. modulus_list: list of reduction moduli, applied in order
    3. N: length of the inputs at the top level
    4. rt_list: powers of a principal Nth root of unity, indexed by exponent
    5. aux1: the tuple returned by MixedRadixPrecomputation(N, modulus_list)
    6. aux2: not used; optional so that 6-argument calls work
    7. lazy: if True, this level skips intermediate reductions and reduces once before returning
    8. cur_index, cur_inv_vector, nodeVal_vector, length, upperMost: for the recursion only
    Output:
    the product, as a list
    '''
    radix_list, radical, coprime_list, idem_vector, invRoot_list, invRad_list = aux1
    radix_list_len = len(radix_list)

    if upperMost and length is None:
        length = N
        nodeVal_vector = vector([0] * radix_list_len)
        cur_inv_vector = invRad_list[cur_index]

    # Base case: length has been reduced to `radical` (from aux1); multiply directly
    if length == radical:
        f = rt_list[ idem_vector * nodeVal_vector ]
        return MixedRadixBaseCase1(first_row, vec, length, f, modulus_list, [])

    # Recursion step: split along the first radix (from cur_index on) that divides length
    radix = radix_list[cur_index]
    while length % radix != 0:
        cur_index += 1
        radix = radix_list[cur_index]
        cur_inv_vector = invRad_list[cur_index]

    stride = length // radix
    slice_row = [ first_row[ i * stride: (i+1) * stride] for i in range(radix) ]
    slice_vec = [ vec[ i * stride: (i+1) * stride] for i in range(radix) ]

    recur_list = []
    root = cur_inv_vector.pairwise_product(nodeVal_vector)
    root[cur_index] = root[cur_index] // radix
    zeta = vector([0] * radix_list_len)
    zeta[cur_index] = coprime_list[cur_index] // radix

    for m in range(radix):
        new_nodeVal_vector = root + m * zeta

        new_row = [0] * stride
        new_vec = [0] * stride

        for n in range(radix):
            index1 = (n * root + m * n * zeta) * idem_vector
            index2 = ((radix - n - 1) * root - m * n * zeta) * idem_vector

            for i in range(stride):
                new_row[i] = new_row[i] + slice_row[n][i] * rt_list[index1]
                new_vec[i] = new_vec[i] + slice_vec[n][i] * rt_list[index2]
                if not lazy:
                    for modulus in modulus_list:
                        new_row[i] = new_row[i] % modulus
                        new_vec[i] = new_vec[i] % modulus

        # Sub-problems are always called with lazy=False, so their outputs come back reduced
        recur_list.append( MixedRadixRecursion(
                        new_row, new_vec, modulus_list, N, rt_list, aux1, aux2, cur_index=cur_index,
                        cur_inv_vector=cur_inv_vector, nodeVal_vector = new_nodeVal_vector, length=stride,
                        upperMost=False, lazy=False) )

    inv = invRoot_list[cur_index]
    res = []
    for m in range(radix):
        res_tmp = [0] * stride

        for n in range(radix):
            index3 = (m * n * zeta - root * (radix - m - 1) ) * idem_vector

            for i in range(stride):
                res_tmp[i] = res_tmp[i] + recur_list[n][i] * inv * rt_list[index3]
                if not lazy:
                    for modulus in modulus_list:
                        res_tmp[i] = res_tmp[i] % modulus
        res.extend(res_tmp)

    if lazy:
        # Reductions were skipped at this level; reduce once before returning
        for i in range(len(res)):
            for modulus in modulus_list:
                res[i] = res[i] % modulus
    return res

### Unit (correctness) Test

In [None]:
# Preprocessing for correctness & timing test
# NOTE(review): ZZx, x, HenselLiftPrim, RaiseToPowerMod_general and BuildRoot_general are
# defined in the later "Generalized NTT" cell; run that cell first.
orig_mod = 2^8
orig_length = 1500
prime = 2
prime_exponent = 8
eff_degree = 18
eff_length = 3^3 * 7 * 19
R = Zmod(orig_mod)

F22 = GF(prime^eff_degree, name='a', modulus='primitive')
prim_poly = ZZx( F22.modulus() )
mod_poly = HenselLiftPrim(prime, prim_poly, prime_exponent)
modulus_list = [ mod_poly, orig_mod]
# The transform length is the padded length used by the tests below
length = eff_length

# The modulus is primitive, so x has order prime^eff_degree - 1; raising it to
# (prime^eff_degree - 1) / length gives an element of order length (here x^73).
assert (prime^eff_degree - 1) % length == 0
rt = RaiseToPowerMod_general(x, (prime^eff_degree - 1) // length, modulus_list)
rt_list = BuildRoot_general(rt, modulus_list, length)
aux = MixedRadixPrecomputation(length, modulus_list)

print("Done")

In [None]:
# Executing correctness test
row = [ randint(0, orig_mod - 1) for _ in range(eff_length) ]
vec = [ randint(0, orig_mod - 1) for _ in range(eff_length) ]
M = matrix.circulant(row)

res1 = [ ZZ(item) for item in M * vector(R, vec) ]
# MixedRadixRecursion takes (aux1, aux2); aux2 is not used by the recursion, so pass None
res2 = MixedRadixRecursion(row, vec, modulus_list, length, rt_list, aux, None)
print("Unit test passed: {}".format(res1 == res2) )

### Batch timing test

In [None]:
# Executing timing test
import time

total = 100
# Set to True to also time the (slow) dense circulant product for comparison
run_naive = False

row = [ randint(0, orig_mod - 1) for _ in range(orig_length) ]
padded_row = PadCircMat (row, orig_length, eff_length)
# Both methods consume the same inputs, so generate them unconditionally
test_list = [ [randint(0, orig_mod - 1) for _ in range(orig_length)] for _ in range(total)]

if run_naive:
    M = matrix.circulant(row)
    # Timing for the naive method
    start1 = time.perf_counter()
    for i in range(total):
        res1 = [ZZ(item) for item in M * vector(R, test_list[i])]
    end1 = time.perf_counter()

# Timing for the Mixed radix method
start2 = time.perf_counter()
for i in range(total):
    new_vec = PadInput(test_list[i], orig_length, eff_length)
    # aux2 is not used by MixedRadixRecursion; pass None explicitly
    res2 = MixedRadixRecursion(padded_row, new_vec, modulus_list, eff_length, rt_list, aux, None)[: orig_length]
end2 = time.perf_counter()

if run_naive:
    print ( "Naive method, avg time: {:.10} ms".format(float( ( end1 - start1) / total * 1000)))
print ( "Mixed-radix recursion method, avg time: {:.10} ms".format(float( ( end2 - start2) / total * 1000)))

# Our Generalized NTT method
Requires SageMath.

Handles lengths of the form $L = q_1^{e} \cdot q_2 \cdot q_3 \cdots$, where $q_1, q_2, \dots$ are distinct small primes.

In [1]:
# Some preprocessing function
# Sage/IPython magic: render cell outputs as LaTeX
%display latex
import itertools as itt
import numpy as np
from collections import deque
# Sage preparser syntax: ZZx is the polynomial ring ZZ[x] (FLINT backend) and x its generator.
# Both names are read as globals by code below (e.g. HenselLiftPrim uses ZZx and x).
ZZx.<x> = PolynomialRing(ZZ, 'x', implementation='FLINT')

def grouper(iterable, n, *, incomplete='strict', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return itt.zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def HenselLiftPrim(p, poly, exponent):
    '''
    Hensel lift an irreducible factor of x^k - 1 from mod p to mod p^exponent.

    Given poly, an irreducible factor of x^k - 1 mod p, return `result` such that:
    1. result = poly (mod p)
    2. result is an irreducible factor of x^k - 1 mod p^exponent

    Method (reference: https://www.irishmathsoc.org/bull47/R4701.pdf): exponent - 1 times,
    replace the polynomial by the one whose roots are the p-th powers of its roots
    (the resultant R_y(y - x^p, poly(x)); cf. Sage's adams_operator_on_roots), reducing the
    coefficients modulo the next power of p.

    Raises ValueError if p is not prime.
    '''
    if not is_prime(p):
        raise ValueError("Please provide a prime p")

    if p == 2:
        p_recur = p
        poly_recur = poly
        for _ in range(exponent - 1):
            p_recur *= p
            # f(x) = poly(x) * poly(-x) is even, i.e. f(x) = g(x^2), and the roots of g
            # are the squares of the roots of poly
            gen = poly_recur.parent().gen()
            tmp = poly_recur * poly_recur(-gen)
            # For a monic poly of degree d the leading coefficient of tmp is (-1)^d;
            # multiplying by it makes tmp monic
            tmp = tmp * tmp.lc()
            # The even-index coefficients of f are the coefficients of g
            poly_recur = ZZx ( [ item % p_recur for item in list(tmp)[::p] ] )
        return poly_recur
    else:
        p_recur = p
        for _ in range(exponent - 1):
            p_recur *= p
            poly = poly.adams_operator_on_roots(p, monic=True)
            poly = poly.map_coefficients(lambda i : i % p_recur)
        return poly

# Some basic operations on Permutations
def TensorInterchange(M, N):
    '''
    Return (Q, inv_Q), a pair of mutually inverse index permutations for an M x N tensor product.

    For k = i + j*M (0 <= i < M, 0 <= j < N):  Q[k] = j + i*N,
    and inv_Q[j + i*N] = i + j*M.
    Intended relation: if A (M x M) and B (N x N) are square matrices,
    then A tensor B = Q^-1 (B tensor A) Q.
    '''
    Q = [j + i * N for j in range(N) for i in range(M)]
    inv_Q = [i + j * M for i in range(M) for j in range(N)]
    return (Q, inv_Q)

def TensorProdPerm(M, N, perm1, perm2):
    '''
    Return the list representation of the permutation perm1 tensor perm2.

    perm1 has length M and perm2 has length N; the result has length M*N and maps
    index i*N + j to perm1[i]*N + perm2[j].

    Raises ValueError if either permutation has the wrong length.
    '''
    if len(perm1) != M:
        raise ValueError("perm1 has incompatible length")
    if len(perm2) != N:
        raise ValueError("perm2 has incompatible length")
    return [p1 * N + p2 for p1 in perm1 for p2 in perm2]

def CompositionPerm(M, perm1, perm2):
    '''
    Return the composition perm2 * perm1 of two permutations of length M,
    i.e. the list whose i-th entry is perm2[perm1[i]].

    Raises ValueError if either input does not have length M.
    '''
    if len(perm1) != M or len(perm2) != M:
        raise ValueError("Input lengths are incompatible")
    return [perm2[k] for k in perm1]

def InversePerm(M, perm):
    '''
    Return the inverse of the permutation perm of range(M), as a list inv with inv[perm[i]] = i.

    Runs in O(M).
    Raises ValueError if perm does not have length M or is not a permutation of range(M).
    '''
    if len(perm) != M:
        raise ValueError("Input lengths are incompatible")
    inv = [-1] * M
    for i, value in enumerate(perm):
        # Reject out-of-range values (incl. negatives, which would silently wrap) and repeats
        if not 0 <= value < M or inv[value] != -1:
            raise ValueError("perm is not a permutation of range(M)")
        inv[value] = i
    return inv

# Calculate the permutation matrix associated with Thomas-Good Prime Factorization FFT
def ThomasGoodPerm(M, N):
    '''
    Index permutation of the Thomas-Good (prime factor) FFT for length M*N.

    N is the number of blocks and M is the block length; they must be coprime.
    Returns (E1, E2, Q, inv_Q) where
        E1, E2: the CRT idempotents of (M, N), made non-negative
                (E1 = 1 mod M, E1 = 0 mod N; similarly for E2)
        Q:      Q[N*i + j] = (i*E1 + j*E2) mod M*N, for 0 <= i < M, 0 <= j < N
        inv_Q:  the inverse permutation of Q
    Raises ValueError if gcd(M, N) != 1.
    '''
    if gcd(M, N) != 1:
        raise ValueError("M and N must be relatively prime")
    total = M * N
    # E1 =1 mod M, E1 = 0 mod N, similarly for E2
    E1, E2 = crt_basis([M, N])
    # Make every idempotent positive
    if E1 < 0:
        E1 = total + E1
    if E2 < 0:
        E2 = total + E2
    Q = [ (i * E1 + j * E2) % total for i in range(M) for j in range(N) ]
    # By the CRT, Q is a bijection on range(total), so its inverse can be filled in directly
    inv_Q = [0] * total
    for position, value in enumerate(Q):
        inv_Q[value] = position
    return (E1, E2, Q, inv_Q)


# Some useful functions related to the O(NlogN) recursion algorithm
def RaiseToPowerMod_general(base, exponent, modulus_list):
    '''
    Raise `base` to the non-negative integer power `exponent` by square-and-multiply
    (binary exponentiation), reducing by every modulus in modulus_list, in order,
    after each multiplication.
    Pay attention: items in modulus_list must appear in the right order: e.x. [poly, prime] instead of [prime, poly]

    Returns 1 for exponent == 0.
    Raises IndexError if modulus_list is empty, ValueError if exponent is negative.
    '''
    if len(modulus_list) < 1:
        raise IndexError("modulus_list must have length >= 1")
    exponent = int(exponent)
    if exponent < 0:
        raise ValueError("exponent must be non-negative")

    def reduce_all(value):
        for modulus in modulus_list:
            value = value % modulus
        return value

    res = 1
    power = base  # base^(2^k) for the current bit k
    while exponent:
        if exponent & 1:
            res = reduce_all(res * power)
        exponent >>= 1
        # Only square when a higher bit remains (avoids one wasted multiplication)
        if exponent:
            power = reduce_all(power * power)
    return res

def BuildRoot_general(rt, modulus_list, L):
    '''
    Return [rt^0, rt^1, ..., rt^(L-1)], each power (after the first) reduced by every
    modulus in modulus_list, in order.

    Input:
        1. rt: a primitive Lth root of unity
        2. modulus_list: a list of reduction modulus
        3. L: number of powers to produce (an empty list for L <= 0)
    '''
    powers = [1] * min(L, 1)
    while len(powers) < L:
        nxt = powers[-1] * rt
        for modulus in modulus_list:
            nxt = nxt % modulus
        powers.append(nxt)
    return powers

def PrimePowerRecursion( 
    first_row, vec, modulus_list, N, radix, inv, rt_lst, *,
    is_col = False, length=None, node_value = 0, upperMost = True, lazy=False):
    '''
    The core recursion algorithm for prime power length L = q^m: the product of a
    circulant matrix with a vector.
    Input:
    1. first_row, vec: first row of the circulant matrix and input vector.
       If is_col is True, first_row is instead the first column and is converted via
       first_row[n] = first_col[-n] (using N, so only meaningful at the top-level call).
    2. modulus_list: a list of reduction moduli, applied in order
    3. N, radix: length of the input vector, which is a power of radix
    4. inv: radix^(-1) (FFT_execution2 passes inverse_mod(radix, modulus_list[-1]))
    5. rt_lst: a list of powers of the Nth principal root of unity; indices are taken mod N
    6. lazy: if True, intermediate reductions are skipped. In that mode only a top-level
       base case (length == radix) reduces its output; the recursion step returns
       unreduced values, so the caller must reduce them (FFT_execution2 does).
    7. The rest (length, node_value, upperMost): for recursion only
    Output:
    the matrix-vector product, as a list of length `length`
    '''

    if upperMost and length is None:
        length = N
    if is_col:
        first_row = [first_row[-i] for i in range(N) ]

    # Base Case: Simply do a matrix vector product
    # res[m] = sum_n first_row[n-m] * vec[n], with the wrapped entries (m > n) scaled by f
    if length == radix:
        f = rt_lst[node_value]
        res = [0] * length
        for m in range(length):
            res[m] = sum( first_row[n-m] * vec[n] if m<= n else
                         f * first_row[n-m+length] * vec[n] for n in range(length) )
            if not lazy or upperMost:
                for modulus in modulus_list:
                    res[m] = res[m] % modulus
        return res
    # Recursion step
    else:   
        # Split both inputs into `radix` contiguous slices of length `stride`
        stride = length // radix
        slice_row = [ first_row[ i * stride: (i+1) * stride] for i in range(radix) ]
        slice_vec = [ vec[ i * stride: (i+1) * stride] for i in range(radix) ]

        recur_list = []
        for m in range(radix):
            # node value handed to the m-th sub-problem
            new_node = (node_value + m * N) // radix

            # Over n = 0..radix-1, index1 runs through n * new_node (mod N) and
            # index2 through node_value + (m*N)//radix - (n+1) * new_node (mod N)
            index1 = -new_node
            index2 = node_value + m * N // radix
            new_row = [0] * stride
            new_vec = [0] * stride

            for n in range(radix):
                index1 = (index1 + new_node) % N
                index2 = (index2 - new_node) % N

                for i in range(stride):
                    new_row[i] = new_row[i] + slice_row[n][i] * rt_lst[index1]
                    new_vec[i] = new_vec[i] + slice_vec[n][i] * rt_lst[index2]
                    if not lazy:
                        for modulus in modulus_list:
                            new_row[i] = new_row[i] % modulus
                            new_vec[i] = new_vec[i] % modulus
            
            recur_list.append( PrimePowerRecursion
            (new_row, new_vec, modulus_list, N, radix, inv, rt_lst, 
             is_col=False, length=stride, node_value=new_node, upperMost=False, lazy=lazy) )
        
        # Recombine the sub-results: res_tmp[i] = sum_n recur_list[n][i] * inv * rt_lst[index3]
        res = []
        for m in range(radix):
            # NOTE(review): unary minus binds tighter than //, so this is floor(-(...) / radix);
            # it equals -((...) // radix) only when radix divides (...)
            index3 = - (m * N + node_value * (radix - m - 1)) // radix
            incr3 = m * N // radix
            res_tmp = [0] * stride
            
            for n in range(radix):
                index3 = (index3 + incr3) % N

                for i in range(stride):
                    res_tmp[i] = res_tmp[i] + recur_list[n][i] * inv * rt_lst[index3]
                    if not lazy:
                        for modulus in modulus_list:
                            res_tmp[i] = res_tmp[i] % modulus
            res.extend(res_tmp)
        # No final reduction here even when upperMost and lazy
        return res

# pseudo-vectorized implementation of matrix vector product, where the matrix is the vandermonde matrix
# of a suitable root of unity

def MultGeneric3(conf, vec, modulus_list, lazy=False):
    '''
    Apply a small square matrix along the leading axis of `vec` viewed as a 2-D array.

    Input:
    1. conf: a tuple (Mat, shape) -- the tail of a ('v', Vand, (length, stride)) step from
       FFT_precomputation_general -- where
        a) Mat is a (length x length) matrix (there, a Vandermonde matrix of a root of unity)
        b) shape = (length, stride): `vec` is reshaped to this shape before multiplying
    2. vec: input of total size length * stride
    3. modulus_list: reduction moduli, applied in order unless lazy
    4. lazy: if True, return the product unreduced
    Output:
    the flattened product Mat @ reshape(vec, shape), as a list
    '''
    mat, shape = conf[0], conf[1]
    grid = np.reshape(np.array(vec, dtype=object), shape)
    res = (mat @ grid).flatten()

    if not lazy:
        for modulus in modulus_list:
            res = res % modulus

    return list(res)

In [2]:
# Generate precomputation data
def FFT_precomputation_general(factorization, total_length, rt_list, modulus_list):
    '''
    Generate precomputation configuration data
    Input:
    1. Factorization: a tuple ((q_1,e), q_2, q_3, ...)
    2. total_lengt: The length of the input N
    3. rt_list: Powers of the Nth root of unity
    4. modulus_list: a list of reduction modulus
    Output:
    Precomputation configuration data
    '''
    q_1, e_1 = factorization[0]
    partial_length = q_1 ^ e_1
       
    q_list = list(reversed(factorization[1:]))
    inv_list = [ inverse_mod(q, modulus_list[-1]) for q in q_list]

    config_list = deque()
    rev_inv_config_list = deque()
    q_length = len(q_list)

    N_list = [1] * (q_length + 1)
    for i in range(q_length):
        N_list[i+1] = N_list[i] * q_list[i]

    M_list = [partial_length] * (q_length + 1)
    for i in range(1, q_length + 1):
        M_list[i] = M_list[i-1] * q_list[-i]
    M_list.reverse()

    match q_length:
        case 0:
            config_list.append(('r', rt_list, (q_1, e_1)))
            return config_list
        case 1:         
            raise NotImplemented("See FFT1")
        case 2:
            raise NotImplemented("See FFT2")
        case _:
            Q_cur = range(1)
            Q_cur_inv = range(1)
            Perm = range(total_length)
            
            Q_next, Q_next_inv = TensorInterchange(N_list[1], q_list[1])
            cur_node = 1
            
            for ptr in range(q_length - 1):

                q_i = q_list[ptr]
                M_i = M_list[ptr]
                
                E1, E2, P, _ = ThomasGoodPerm( q_list[ptr], M_list[ptr+1])
                
                index = cur_node * E1 % total_length               
                Vand = matrix( q_list[ptr], q_list[ptr], 
                              lambda i,j: rt_list[ i * j * index % total_length] )
                Vand_inv = matrix( q_list[ptr], q_list[ptr], 
                                  lambda i,j: inv_list[ptr] * rt_list[ (- i * j * index) % total_length] )

                #assert Vand * Vand_inv % modulus_list[-1] == matrix.identity(q_list[ptr])
                config_list.append( ('v', Vand, (q_list[ptr], total_length // q_list[ptr])) )
                rev_inv_config_list.append( ('v', Vand_inv, (q_list[ptr], total_length // q_list[ptr])) )

                Q = TensorProdPerm(N_list[ptr + 2], M_list[ptr+2] , Q_next, range(M_list[ptr + 2]))
                if ptr != 0:
                    Q = CompositionPerm(total_length, Q, 
                                        TensorProdPerm(N_list[ptr + 1] ,M_list[ptr + 1], Q_cur_inv, range(M_list[ptr + 1])))

                Q_inv = InversePerm(total_length, Q)
                
                config_list.append( ('p', Q) )
                rev_inv_config_list.append( ('p', Q_inv) )

                # update permutations
                Perm = CompositionPerm(total_length, TensorProdPerm(N_list[ptr], M_list[ptr], range(N_list[ptr]), P), 
                                       Perm)
                cur_node = cur_node * E2 % total_length
                Q_cur = Q_next
                Q_cur_inv = Q_next_inv
                if ptr != q_length - 2:
                    Q_next, Q_next_inv = TensorInterchange(N_list[ptr + 2], q_list[ptr + 2])

            ptr += 1
            Q_cur = TensorProdPerm(total_length // partial_length, partial_length,
                                   Q_cur, range(partial_length))
            Q_cur_inv = InversePerm(total_length, Q_cur)
            q_i = q_list[-1]
            inv_i = inv_list[-1]

            E1, E2, P, _ = ThomasGoodPerm( q_list[ptr], M_list[ptr+1])
            recur_rt_index = cur_node * E2 % total_length
            index = cur_node * E1 % total_length

            Perm = CompositionPerm(total_length, TensorProdPerm(N_list[ptr], M_list[ptr], range(N_list[ptr]), P), 
                                       Perm)
            Perm_inv = InversePerm(total_length, Perm)
            
            Vand = matrix( q_i, q_i, lambda i,j: rt_list[ i * j * index % total_length] )
            Vand_inv = matrix( q_i, q_i, lambda i,j: inv_i * rt_list[ (- i * j * index) % total_length] )
            #assert Vand * Vand_inv % modulus_list[-1] == matrix.identity(q_i)
            
            config_list.append( ('v', Vand, (q_i, total_length // q_i) ) )
            config_list.append( ('p', Q_cur_inv) )
            config_list.appendleft( ('p', Perm) )
            new_rt_list = [ rt_list[recur_rt_index * i % total_length] for i in range(partial_length) ]
            
            #assert new_rt_list[1] * new_rt_list[-1] % modulus_list[0] == 1
            config_list.append(('r', new_rt_list, (q_1, e_1)))
            rev_inv_config_list.append( ('v', Vand_inv, (q_i, total_length // q_i)) )
            rev_inv_config_list.append( ('p', Q_cur) )
            rev_inv_config_list.appendleft( ('p', Perm_inv) )
            rev_inv_config_list.reverse()
            config_list.extend(rev_inv_config_list)
            
            return config_list

In [3]:
# FFT execution
def FFT_execution2(first_col, vec, total_length, modulus_list, config, lazy=False):
    '''
    Execution of general NTT method
    Input:
    1. first_col, vec: The 2 input vectors
    2. total_length: length of the inputs N
    3. modulus_list: list of reduction modulus
    4. config: configuration data from FFT_precomputation_general
    5. lazy: Whether or not the algorithm reduce mod after each arithmetic operation
    Output:
    The circular convolution
    '''
    merged = False
    # Implement FFT strategy based on the config tuple
    for conf in config:
        match conf[0]:
            case 'p':
                vec = [ vec[ conf[1][i] ] for i in range(total_length) ]
                if not merged:
                    first_col = [ first_col[ conf[1][i] ] for i in range(total_length) ]
            case 'v':
                vec = MultGeneric3(conf[1:], vec, modulus_list, lazy=lazy)
                if not merged:
                    first_col = MultGeneric3(conf[1:], first_col, modulus_list, lazy=lazy)
            case 'r':
                #return first_col
                new_rt_lst = conf[1]
                radix, exponent = conf[2]
                length = radix^exponent
                inv = inverse_mod(radix, modulus_list[-1])
    
                # Recursion Algorithm
                res = []
                iterate = zip( grouper(first_col, length), grouper(vec, length) )
    
                res = []
                for col_slice, vec_slice in iterate:
                    res.extend( PrimePowerRecursion(
                        col_slice, vec_slice, modulus_list, length, radix, inv, new_rt_lst, is_col=True, lazy=lazy) )
                vec = res
                merged = True
            case _:
                raise NotImplementedError("Configuration mode not implemented")
    if lazy:
        for i in range(total_length):
            for modulus in modulus_list:
                vec[i] = vec[i] % modulus
    return vec

## Unit (correctness) test

In [None]:
# An example
# Preprocessing for correctness & timing test
orig_mod = 2^32
orig_length = 1500
prime = 2
prime_exponent = 32
eff_degree = 24
# Length factorization ((q_1, e_1), q_2, q_3, ...) consumed by FFT_precomputation_general;
# eff_length is derived from it so the two cannot drift apart
factorization = ((3, 2), 5, 7, 13, 17)
eff_length = factorization[0][0]^factorization[0][1] * prod(factorization[1:])
R = Zmod(orig_mod)

ZZx.<x> = ZZ[]
F22 = GF(prime^eff_degree, name='a', modulus='primitive')
prim_poly = ZZx( F22.modulus() )
mod_poly = HenselLiftPrim(prime, prim_poly, prime_exponent)
modulus_list = [ mod_poly, orig_mod]
# The modulus is primitive, so x has order prime^eff_degree - 1; raising it to
# (prime^eff_degree - 1) / eff_length gives an element of order eff_length (here x^241).
assert (prime^eff_degree - 1) % eff_length == 0
rt = RaiseToPowerMod_general(x, (prime^eff_degree - 1) // eff_length, modulus_list)
rt_list = BuildRoot_general(rt, modulus_list, eff_length)
assert rt_list[1] * rt_list[-1] % mod_poly % orig_mod == 1

config_list = FFT_precomputation_general(factorization, eff_length, rt_list, modulus_list)

print("Done")

In [None]:
# Executing correctness test
col = [randint(0, orig_mod - 1) for i in range(eff_length)]
# First row of the circulant matrix whose first column is col
row = [ col[-i] for i in range(eff_length) ]
vec = [ randint(0, orig_mod - 1) for _ in range(eff_length) ]

M = matrix.circulant(row)
%time res1 = list( M * vector(vec) % orig_mod)
%time res2 = FFT_execution2(col, vec, eff_length, modulus_list, config_list)

assert res1 == res2

# Timing Test

In [8]:
# Some preprocessing function for multimodular NTT
import itertools as itt
from sympy import ntt, intt, convolution

def FindEffLen (N, radix = 2):
    '''
    Return the smallest power of radix that is >= 2 * N - 1.

    A length-N cyclic convolution can be embedded into a cyclic convolution
    of any length >= 2 * N - 1 (see PadCircMat), so this returns the
    smallest such length that is a power of radix. For N <= 1 the result
    is 1 (radix^0).

    Raises ValueError if radix < 2.
    '''
    if radix < 2:
        raise ValueError("radix must be an integer >= 2")
    target = 2 * N - 1
    # Exact integer search: a floating-point log(target, radix) can round to
    # slightly above an integer when target is an exact power of radix, and
    # ceil() would then overshoot by a factor of radix.
    power = 1
    while power < target:
        power *= radix
    return power

def PadCircMat (first_row, org_len, eff_len):
    '''
    Embed the defining vector of an org_len x org_len circulant matrix into
    a defining vector of length eff_len.

    The result is first_row, then eff_len - 2 * org_len + 1 zeros, then
    first_row[1:]. When first_row has length org_len, the result has length
    eff_len. The eff_len-circulant it defines, applied to an input
    zero-padded to eff_len (see PadInput), reproduces the org_len cyclic
    product in its first org_len entries.

    A new list is returned; first_row is not modified.

    Raises ValueError if org_len < 1 or eff_len < 2 * org_len - 1, in which
    case no such embedding exists (the padding would be negative).
    '''
    if org_len < 1:
        raise ValueError("org_len must be at least 1")
    pad0 = eff_len - 2 * org_len + 1
    if pad0 < 0:
        raise ValueError(
            "eff_len must be at least 2 * org_len - 1 to embed the circulant")
    return first_row + [ 0 ] * pad0 + first_row[1:]

def PadInput (vec, org_len, eff_len):
    '''
    Zero-pad the input vector vec (expected to have length org_len) to
    length eff_len.

    A new list is returned; vec is not modified.

    Raises ValueError if eff_len < org_len, since the vector cannot be
    padded to a shorter length.
    '''
    pad0 = eff_len - org_len
    if pad0 < 0:
        raise ValueError("eff_len must be at least org_len")
    return vec + [ 0 ] * pad0

def ModListGen (orig_mod, eff_len, radix=2, full=False):
    '''
    Generate a list of prime moduli for the multimodular recursion.

    Candidates are taken from the progression 1 + k * eff_len (k = 1, 2, ...).
    The primes among them are collected, in increasing order, until their
    product is at least orig_mod^2 * eff_len.

    Parameters:
      orig_mod -- original modulus; sets the bound orig_mod^2 * eff_len
      eff_len  -- transform length; every chosen prime p satisfies
                  p = 1 + k * eff_len, so eff_len divides p - 1
      radix    -- only used when full is True (its inverse mod each prime)
      full     -- if True, also return per-prime root and inverse data

    Returns:
      prime_list if full is False, otherwise (prime_list, root_list, inv_list).
      For the i-th prime p = 1 + k * eff_len:
        root_list[i] -- the first eff_len // 2 powers, as integers, of gen^k,
                        where gen is a multiplicative generator of Z/pZ
        inv_list[i]  -- the inverse of radix modulo p

    Raises ValueError if eff_len < 1 (the candidate search would never end).
    '''
    if eff_len < 1:
        raise ValueError("eff_len must be a positive integer")
    lower = orig_mod^2 * eff_len
    prime_list = []
    root_list = []
    inv_list = []

    modulus = 1
    k = 0
    while modulus < lower:
        k += 1
        candidate = 1 + k * eff_len
        if not is_prime(candidate):
            continue
        prime_list.append(candidate)
        modulus *= candidate
        if full:
            # candidate - 1 = k * eff_len and gen has order candidate - 1,
            # so gen^k has multiplicative order exactly eff_len.
            root = Integers(candidate).multiplicative_generator()^k
            root_list.append([ZZ(w) for w in root.powers(eff_len // 2)])
            inv_list.append(inverse_mod(radix, candidate))

    if full:
        return (prime_list, root_list, inv_list)
    return prime_list

In [4]:
# Search for a big enough smooth divisor for p^m - 1
# Prints the factorization of radix^i - 1 for i = 1..50, so that an extension
# degree and a smooth divisor of radix^degree - 1 (used as eff_length1 in the
# timing cell) can be picked by hand.
radix = 2
for i in range(1, 51):
    print("{}: {}".format(i, (radix^i - 1).factor() ) )

1: 1
2: 3
3: 7
4: 3 * 5
5: 31
6: 3^2 * 7
7: 127
8: 3 * 5 * 17
9: 7 * 73
10: 3 * 11 * 31
11: 23 * 89
12: 3^2 * 5 * 7 * 13
13: 8191
14: 3 * 43 * 127
15: 7 * 31 * 151
16: 3 * 5 * 17 * 257
17: 131071
18: 3^3 * 7 * 19 * 73
19: 524287
20: 3 * 5^2 * 11 * 31 * 41
21: 7^2 * 127 * 337
22: 3 * 23 * 89 * 683
23: 47 * 178481
24: 3^2 * 5 * 7 * 13 * 17 * 241
25: 31 * 601 * 1801
26: 3 * 2731 * 8191
27: 7 * 73 * 262657
28: 3 * 5 * 29 * 43 * 113 * 127
29: 233 * 1103 * 2089
30: 3^2 * 7 * 11 * 31 * 151 * 331
31: 2147483647
32: 3 * 5 * 17 * 257 * 65537
33: 7 * 23 * 89 * 599479
34: 3 * 43691 * 131071
35: 31 * 71 * 127 * 122921
36: 3^3 * 5 * 7 * 13 * 19 * 37 * 73 * 109
37: 223 * 616318177
38: 3 * 174763 * 524287
39: 7 * 79 * 8191 * 121369
40: 3 * 5^2 * 11 * 17 * 31 * 41 * 61681
41: 13367 * 164511353
42: 3^2 * 7^2 * 43 * 127 * 337 * 5419
43: 431 * 9719 * 2099863
44: 3 * 5 * 23 * 89 * 397 * 683 * 2113
45: 7 * 31 * 73 * 151 * 631 * 23311
46: 3 * 47 * 178481 * 2796203
47: 2351 * 4513 * 13264529
48: 3^2 * 5 * 7 *

In [13]:
# Timing Test
import time
# Need to manually change parameters based on setups
radix = 2
orig_mod = 2^32
orig_length = 2000

# Accumulated seconds for: direct generic polynomial product, our generalized
# NTT (FFT_execution2), multimodular NTT, and the FLINT-backed polynomial product.
time1 = 0
time2 = 0
time3 = 0
time4 = 0
Total = 50  # number of random trials averaged below

# Need to manually select a smooth divisor eff_length1 of radix^eff_degree - 1
exponent = 32
eff_degree = 12
eff_length1 = 3^2 * 5 * 7 * 13
# Fail early on inconsistent hand-picked parameters:
# - eff_length1 must divide radix^eff_degree - 1;
# - it must be long enough to embed a length-orig_length cyclic convolution
#   (PadCircMat needs eff_length1 >= 2 * orig_length - 1).
assert (radix^eff_degree - 1) % eff_length1 == 0
assert eff_length1 >= 2 * orig_length - 1
rt_power = (radix^eff_degree - 1) // eff_length1
R = Zmod(orig_mod)

ZZx.<x> = ZZ[]
R1.<y> = PolynomialRing(Integers(orig_mod),'y', implementation="generic")
R2.<z> = PolynomialRing(Integers(orig_mod), 'z')

F22 = GF(radix^eff_degree, name='a', modulus='primitive')
prim_poly = ZZx( F22.modulus() )
mod_poly = HenselLiftPrim(radix, prim_poly, exponent)
modulus_list = [ mod_poly, orig_mod]

rt = RaiseToPowerMod_general(x, rt_power, modulus_list)
rt_list = BuildRoot_general(rt, modulus_list, eff_length1)
assert rt_list[1] * rt_list[-1] % mod_poly % orig_mod in (1, 1 - orig_mod)

# Need to manually input the first argument into the factorization form
config_list = FFT_precomputation_general(((3,2),5, 7, 13), eff_length1, rt_list, modulus_list)

eff_length2 = FindEffLen (orig_length, radix = 2)
mod_list = ModListGen(orig_mod, eff_length2)

# time.perf_counter() is monotonic and high-resolution, which makes it suitable
# for benchmarking. time.time() follows the system clock and can jump.
for _ in range(Total):
    first_col = [randint(0, orig_mod - 1) for _ in range(orig_length)]
    test = [randint(0, orig_mod - 1) for _ in range(orig_length)]

    # Embed the length-orig_length cyclic product into length eff_length1 (untimed).
    col1 = PadCircMat (first_col, orig_length, eff_length1)
    vec1 = PadInput (test, orig_length, eff_length1)

    # Direct method
    tmp = time.perf_counter()
    res1 = list(R1(first_col) * R1(test) % (y^orig_length - 1))
    time1 += time.perf_counter() - tmp

    # Our generalized NTT method
    tmp = time.perf_counter()
    res2 = FFT_execution2(col1, vec1, eff_length1, modulus_list, config_list)
    time2 += time.perf_counter() - tmp

    # Multimodular NTT method
    tmp = time.perf_counter()
    res3 = Multimodulus_Mult2(first_col, test, mod_list, orig_length, orig_mod)
    time3 += time.perf_counter() - tmp

    # Using Flint engine
    tmp = time.perf_counter()
    res4 = list(R2(first_col) * R2(test) % (z^orig_length - 1))
    time4 += time.perf_counter() - tmp

print("Direct: Avg {} ms".format( time1 * 1000 / Total))
print("Our method: Avg {} ms".format(time2 * 1000 / Total))
print("Multimodular NTT: Avg {} ms".format(time3 * 1000 / Total))
print("Flint: Avg {} ms".format(time4 * 1000 / Total))

Direct: Avg 608.2621049880981 ms
Our: Avg 971.5304327011108 ms
MultiFFT: Avg 112.38384246826172 ms
Flint: Avg 6.259965896606445 ms
