In [20]:
import random

In [21]:
def modinv(a, mod):
    """Return the multiplicative inverse of ``a`` modulo ``mod``.

    Uses the iterative extended Euclidean algorithm, so it works for any
    modulus (prime or not) as long as gcd(a, mod) == 1. ``a`` may be negative
    or larger than ``mod``; it is reduced first.

    Raises:
        ValueError: if ``a`` has no inverse modulo ``mod`` (e.g. ``a % mod == 0``).
    """
    # Invariant: old_r ≡ old_s * a (mod mod) and r ≡ s * a (mod mod).
    old_r, r = a % mod, mod
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    # old_r is now gcd(a, mod); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{a} has no inverse modulo {mod}")
    return old_s % mod

In [22]:
def prime_factors(n):
    """Return the set of distinct prime divisors of ``n`` (trial division).

    Values of ``n`` below 2 yield an empty set.
    """
    remaining = n
    divisor = 2
    found = set()
    while divisor * divisor <= remaining:
        if remaining % divisor:
            divisor += 1
            continue
        found.add(divisor)
        # Strip every copy of this prime so later divisors stay prime.
        while remaining % divisor == 0:
            remaining //= divisor
        divisor += 1
    # Whatever is left above 1 has no divisor <= its square root: it is prime.
    if remaining > 1:
        found.add(remaining)
    return found

In [23]:
def find_generator(mod):
    """Return the smallest primitive root (generator) modulo ``mod``.

    Assumes ``mod`` is prime: ``mod - 1`` is used as the group order. A
    candidate g is a generator iff g**((mod-1)/p) != 1 for every prime p
    dividing mod - 1.

    Raises:
        ValueError: if no candidate in [2, mod) qualifies (this includes
            mod = 2, where the search range is empty).
    """
    order = mod - 1
    divisors = prime_factors(order)
    for candidate in range(2, mod):
        if all(pow(candidate, order // p, mod) != 1 for p in divisors):
            return candidate
    raise ValueError(f"No generator found for mod={mod}")

In [None]:
def bit_reverse(a):
    """Permute list ``a`` in place into bit-reversed index order.

    The element at index i is exchanged with the element whose index has
    the same log2(len(a)) bits as i in reverse order (e.g. for len 8,
    index 0b001 <-> 0b100). The list contents are not reversed, only their
    index positions are permuted. Intended for power-of-two lengths; returns None.
    """
    size = len(a)
    rev = 0  # bit-reversed counterpart of idx, maintained incrementally
    for idx in range(1, size):
        # Increment `rev` as if its bits were read from the top down:
        # clear the leading run of set bits, then set the first clear one.
        mask = size >> 1
        while mask & rev:
            rev ^= mask
            mask >>= 1
        rev |= mask
        # Swap each pair only once, from the smaller index.
        if rev > idx:
            a[rev], a[idx] = a[idx], a[rev]

In [25]:
def ntt(poly, root, mod, verbose=True):
    """Iterative radix-2 number-theoretic transform.

    Computes A[k] = sum_j poly[j] * root**(j*k) mod ``mod`` for k in range(n)
    via a bit-reversal permutation followed by log2(n) butterfly stages.
    The input sequence is not modified.

    Args:
        poly: sequence of integer coefficients; its length n must be a power of two.
        root: a primitive n-th root of unity modulo ``mod``. Passing the inverse
            root yields the (unscaled) inverse transform.
        mod: the modulus (presumably prime — TODO confirm for other uses).
        verbose: when True (default, matching the previous behaviour) print the
            bit-reversed input and each stage's twiddle step.

    Returns:
        A new list with the transform, every entry reduced into [0, mod).

    Raises:
        ValueError: if n is not a power of two, or ``root`` does not have
            multiplicative order exactly n modulo ``mod``.
    """
    # Reduce up front so every output (including the n == 1 case) is in [0, mod).
    a = [x % mod for x in poly]
    n = len(a)
    if n & (n - 1):
        raise ValueError(f"NTT length must be a power of two, got {n}")
    # For n = 2**k, root has order exactly n iff root**n == 1 and root**(n/2) != 1.
    if n > 1 and (pow(root, n, mod) != 1 or pow(root, n // 2, mod) == 1):
        raise ValueError(f"root={root} is not a primitive {n}-th root of unity mod {mod}")
    bit_reverse(a)
    if verbose:
        print(f"Bit-reversed: {a}")
    length = 2
    while length <= n:
        half = length // 2
        # Primitive `length`-th root of unity: the twiddle step for this stage.
        wlen = pow(root, n // length, mod)
        if verbose:
            print(f"Length = {length}, wlen = {wlen}")
        for start in range(0, n, length):
            w = 1
            for i in range(start, start + half):
                u = a[i]
                v = a[i + half] * w % mod
                a[i] = (u + v) % mod
                a[i + half] = (u - v) % mod  # Python's % is already non-negative
                w = w * wlen % mod
        length <<= 1
    return a

In [26]:
def intt(poly, inv_root, mod):
    """Inverse NTT of ``poly``.

    Runs the forward ``ntt`` with the inverse root, then scales every entry
    by n**-1 modulo ``mod`` (n = len(poly)).
    """
    size = len(poly)
    spectrum = ntt(poly, inv_root, mod)
    # Nothing to scale for an empty input; avoids asking for the inverse of 0.
    if size == 0:
        return []
    scale = modinv(size, mod)
    return [value * scale % mod for value in spectrum]

In [27]:
def naive_cyclic_convolution(a, b, mod):
    """Reference O(n^2) cyclic convolution of ``a`` and ``b`` modulo ``mod``.

    out[k] = sum of a[i] * b[j] over all i, j with (i + j) % n == k, where
    n = len(a). ``b`` is indexed only up to n - 1.
    """
    size = len(a)
    out = [0] * size
    for i, ai in enumerate(a):
        for j in range(size):
            k = (i + j) % size  # wrap-around index: this is what makes it cyclic
            out[k] = (out[k] + ai * b[j]) % mod
    return out

In [None]:
def generate_test_vectors(N, mod, num_tests=1, seed=None):
    """Generate random NTT test vectors and cross-check them against a naive convolution.

    For each test, two random length-``N`` vectors ``a`` and ``b`` over Z_mod are
    drawn, transformed with ``ntt``, multiplied point-wise, transformed back with
    ``intt`` and compared with ``naive_cyclic_convolution``. All intermediate
    values are printed so they can be copied as reference vectors.

    Args:
        N: transform length; must be a positive power of two.
        mod: modulus with ``(mod - 1) % N == 0``; assumed prime (the generator
            search treats ``mod - 1`` as the group order).
        num_tests: number of independent (a, b) pairs to generate.
        seed: optional seed for a private ``random.Random``. ``None`` (default)
            uses the module-level ``random`` generator, as before.

    Raises:
        ValueError: if ``N`` is not a positive power of two or does not divide ``mod - 1``.
        AssertionError: if the NTT-based product disagrees with the naive one.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if N <= 0 or N & (N - 1):
        raise ValueError(f"N must be a positive power of two, got N={N}")
    if (mod - 1) % N != 0:
        raise ValueError("mod-1 must be divisible by N for an N-th root of unity")
    rng = random if seed is None else random.Random(seed)
    print(f"Generating {num_tests} test vectors (N={N}, mod={mod})")
    gen = find_generator(mod)
    print(f"Generator = {gen}")
    # With gen of order mod-1, gen**((mod-1)/N) has order exactly N.
    root = pow(gen, (mod - 1) // N, mod)
    print(f'Root of unity = {root} (mod {mod})')
    inv_root = modinv(root, mod)
    print(f'Inverse root = {inv_root} (mod {mod})')
    for t in range(num_tests):
        a = [rng.randrange(mod) for _ in range(N)]
        print(f'a = {a}')
        b = [rng.randrange(mod) for _ in range(N)]
        print(f'b = {b}')
        A = ntt(a, root, mod)
        print(f'NTT(a) = {A}')
        B = ntt(b, root, mod)
        print(f'NTT(b) = {B}')
        # Convolution theorem: point-wise product in the NTT domain == cyclic convolution.
        C = [(x * y) % mod for x, y in zip(A, B)]
        # Print C before intt so its label precedes the inverse transform's debug output.
        print(f'Point-wise product C = {C}')
        c = intt(C, inv_root, mod)
        print(f'INTT(C) = {c}')
        c_naive = naive_cyclic_convolution(a, b, mod)
        print(f'Naive cyclic convolution = {c_naive}')
        if c != c_naive:
            raise AssertionError(f"Mismatch on test {t}")
        

In [38]:
N = 512          # transform length = number of coefficients (power of 2); degree is N - 1
mod = 7681       # prime modulus with (mod - 1) % N == 0  (7680 = 2**9 * 15)
num_tests = 1    # number of random (a, b) test pairs to generate
SEED = 2024      # fixed seed so the printed test vectors are reproducible

random.seed(SEED)
generate_test_vectors(N, mod, num_tests)

Generating 1 test vectors (N=512, mod=7681)
Generator = 17
Root of unity = 7146 (mod 7681)
Inverse root = 7480 (mod 7681)
a = [2121, 4231, 2210, 2214, 6330, 1953, 4423, 2105, 4367, 1859, 758, 7178, 4983, 7038, 7145, 5896, 1994, 2826, 2159, 1834, 3923, 310, 3469, 756, 205, 4098, 7312, 3654, 3430, 1598, 1420, 7191, 167, 2693, 6952, 409, 4757, 2249, 3141, 6158, 3417, 2383, 7217, 1127, 1118, 1549, 7601, 2130, 7073, 2017, 2327, 395, 784, 1886, 3522, 3889, 1679, 1715, 5838, 2017, 7239, 1088, 6644, 5727, 4881, 5080, 6720, 4376, 1278, 551, 4985, 3530, 1228, 2070, 823, 3446, 3907, 272, 5197, 2859, 5042, 519, 4691, 183, 5489, 5533, 4980, 4350, 4069, 244, 1121, 1127, 60, 1264, 5285, 2779, 5907, 4676, 2113, 7635, 2555, 3288, 4438, 3030, 4124, 5221, 3865, 5818, 5971, 6288, 546, 3493, 4942, 573, 237, 5719, 4779, 1730, 1763, 3291, 6565, 648, 4222, 2426, 3463, 1152, 2561, 2719, 2279, 2997, 912, 4535, 5479, 2818, 3763, 3020, 33, 30, 1622, 406, 2287, 794, 942, 3872, 6907, 3375, 6837, 6277, 6704, 313, 20

In [None]:
# Known NTT test vectors for N = 16, mod = 17:
# a = [8, 11, 3, 1, 1, 11, 13, 7, 10, 7, 15, 1, 14, 9, 11, 11]
# NTT(a) =[14, 1, 7, 3, 4, 0, 9, 7, 0, 11, 13, 7, 12, 1, 0, 5]

# b = [10, 12, 5, 9, 7, 14, 6, 2, 4, 6, 13, 15, 0, 16, 11, 16]
# NTT(b) =[10, 16, 2, 1, 13, 2, 4, 3, 0, 3, 4, 9, 10, 10, 1, 4]

# Verified with https://www.nayuki.io/page/number-theoretic-transform-integer-dft (leave the M and w fields blank).