In [1]:
import numpy as np
import parameters as p
from sympy.ntheory import n_order
from numpy import int32

def n_order_verification():
    """Check that ``p.n_root`` generates the multiplicative group mod ``p.prime``.

    Prints the multiplicative order of ``p.n_root`` modulo ``p.prime``.

    Returns:
        bool: True when that order equals ``p.prime - 1``, i.e. ``p.n_root`` is a
        primitive root (assuming ``p.prime`` is prime).
    """
    # n_order is not a cheap lookup; compute it once and reuse it for both
    # the report and the return value.
    order = n_order(p.n_root, p.prime)
    print(f"order is {order}, prime is {p.prime}")
    return order == p.prime - 1

def calculate_omega(n_root=None, prime=None, N=None):
    """Return omega = n_root^((prime-1)/N) mod prime, an N-th root of unity.

    If ``n_root`` is a primitive root mod ``prime`` (see ``n_order_verification``),
    omega is a *primitive* N-th root of unity.

    Uses exact modular exponentiation. A float exponent and float power lose
    precision once n_root**((prime-1)/N) exceeds 2**53, and then silently yield
    a wrong omega.

    Args:
        n_root, prime, N: override ``p.n_root``, ``p.prime`` and ``p.N``; each
            defaults to the value in the ``parameters`` module.

    Returns:
        int in [0, prime).

    Raises:
        ValueError: if N is not a positive divisor of prime - 1 (no N-th root
            of unity of that form exists).
    """
    n_root = p.n_root if n_root is None else n_root
    prime = p.prime if prime is None else prime
    N = p.N if N is None else N
    if N <= 0 or (prime - 1) % N != 0:
        raise ValueError(f"N={N} must be a positive divisor of prime-1={prime - 1}")
    return pow(int(n_root), (prime - 1) // N, prime)

def N_roots_of_unit(omega, N=None, prime=None):
    """Return the NTT root vector A with A[k] = omega^k mod prime, k in [0, N).

    When omega is a primitive N-th root of unity these are all N distinct
    N-th roots of unity used for a degree-N polynomial NTT.

    The running product is kept as a Python int. Multiplying an int32 element
    by omega can overflow int32 under NumPy 2 promotion rules once prime > 46341.

    Args:
        omega: root of unity (int or NumPy integer).
        N, prime: override ``p.N`` and ``p.prime``.

    Returns:
        int32 array of length N with entries in [0, prime); prime must be < 2**31.
    """
    N = p.N if N is None else N
    prime = p.prime if prime is None else prime
    omega = int(omega)
    A = np.ones(N, dtype='int32')
    power = 1
    for k in range(1, N):
        power = power * omega % prime
        A[k] = power
    return A

def ntt_mat(omega, N=None, prime=None):
    """Build the N x N NTT matrix F with F[x, y] = omega^(x*y) mod prime.

    Row 0 and column 0 are all ones. Row i is generated by repeated
    multiplication with omega^i.

    All products use Python ints. The previous int32 * int32 recurrence
    overflowed once prime > 46341.

    Args:
        omega: root of unity (int or NumPy integer).
        N, prime: override ``p.N`` and ``p.prime``.

    Returns:
        int32 array of shape (N, N) with entries in [0, prime); prime must be
        < 2**31.
    """
    N = p.N if N is None else N
    prime = p.prime if prime is None else prime
    omega = int(omega)
    F = np.ones((N, N), dtype='int32')
    row_root = 1  # omega^i for the current row i
    for i in range(1, N):
        row_root = row_root * omega % prime
        entry = 1
        for j in range(1, N):
            entry = entry * row_root % prime  # omega^(i*j)
            F[i, j] = entry
    return F

def Initilization(A, r, N=None, prime=None, bitwise=None):
    """Evaluate the multilinear extension A_F(r, y) of the NTT matrix at every y.

    With F(x, y) = omega^(x*y) and A[k] = omega^k (omega^N == 1 mod prime),
    this fills, for each column j in [0, N),

        A_F[j] = prod_{i=0}^{bitwise-1} ((1 - r[i]) + r[i] * omega^(j * N / 2^(i+1)))  (mod prime)

    That equals sum_x F(x, j) * eq(bits(x), r), with r[0] paired with the most
    significant bit of x (the convention of ``test_A_F``). Round i doubles the
    number of filled entries, so the table costs about 2N products.

    Args:
        A: length-N vector, A[k] = omega^k mod prime (e.g. ``N_roots_of_unit``).
        r: sequence of at least ``bitwise`` integers, the evaluation point.
        N, prime, bitwise: override ``p.N``, ``p.prime``, ``p.bitwise``; N is
            assumed to equal 2**bitwise.

    Returns:
        int32 array of length N with entries in [0, prime).
    """
    N = p.N if N is None else N
    prime = p.prime if prime is None else prime
    bitwise = p.bitwise if bitwise is None else bitwise
    A_F = np.ones(N, dtype='int32')
    for i in range(bitwise):
        half = 1 << i          # entries [0, half) hold the previous round's values
        stride = N >> (i + 1)  # A[j * stride] = omega^(j * N / 2^(i+1))
        r_i = int(r[i])
        # Walk j downwards: entries j >= half read A_F[j % half] (< half), which
        # must still hold the previous round's value when read.
        for j in range(2 * half - 1, -1, -1):
            # Python ints: no int32/int64 overflow for large primes or r.
            factor = ((1 - r_i) + r_i * int(A[j * stride])) % prime
            A_F[j] = int(A_F[j % half]) * factor % prime
    return A_F

In [2]:
def test_A_F(F, r):
    A_F = np.zeros(p.N, dtype='int32')
    for j in range(p.N):
        for i in range(p.N):
            v = np.array([int(char) for char in np.binary_repr(i, width=p.bitwise)])
            for idx,k in enumerate(v):
                F[i][j] = F[i][j] * ((1-k)*(1-r[idx]) + k*r[idx]) % p.prime
            A_F[j] = (A_F[j] + F[i][j]) % p.prime
    return A_F

In [3]:
def main(seed=None):
    """Build the NTT tables, evaluate A_F two ways, and compare the results.

    Steps:
    - Verify that ``p.n_root`` is a primitive root mod ``p.prime``.
    - Derive omega and print the root vector and NTT matrix.
    - Compute the multilinear extension with the fast ``Initilization`` and the
      brute-force ``test_A_F`` at the same random point ``r``.
    - Print a message if the two disagree.

    Args:
        seed: optional int. When given, ``r`` is drawn from a private
            ``RandomState(seed)`` so the run is reproducible. ``None`` (default)
            draws from NumPy's global RNG as before.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Random evaluation point: one coordinate per bit of the row index.
    r = rng.randint(30, size=p.bitwise)

    # omega is only guaranteed to be a *primitive* N-th root of unity when
    # n_root generates the whole multiplicative group.
    if not n_order_verification():
        print(f"warning: n_root {p.n_root} is not a primitive root mod {p.prime}")
    omega = calculate_omega()
    print(f"polynomial degree is {p.N}, primitive root omega is {omega}")
    A = N_roots_of_unit(omega)
    print(f"ntt roots of unit vector is \n{A}")
    F = ntt_mat(omega)
    print(f"ntt roots of unit matrix is \n{F}")

    A_F = Initilization(A, r)
    print(f"our result is \n{A_F}")
    # Hand the reference a copy so it cannot alter F.
    A_F_t = test_A_F(F.copy(), r)
    print(f"reference result is \n{A_F_t}")

    if not np.array_equal(A_F, A_F_t):
        print('test failed, not equal !!!')

# Entry point: run the end-to-end check. A notebook kernel also executes cells
# with __name__ == "__main__", so running this cell triggers main() directly.
if __name__ == "__main__":
    main()


order is 768, prime is 769
polynomial degree is 256, primitive root omega is 562
ntt roots of unit vector is 
[  1 562 554 672  85  92 181 214 304 130   5 503 463 284 425 460 136 301
 751 650  25 208   8 651 587 762 680 736 679 174 125 271  40 179 628 734
 324 604 319 101 625 586 200 126  64 594  82 713  57 505  49 623 231 630
 320 663 410 489 285 218 245  39 386  74  62 239 512 138 656 321 456 195
 392 370 310 426 253 690 204  67 742 206 422 312  12 592 496 374 251 335
 634 261 572  22  60 653 173 332 486 137  94 536 553 110 300 189  96 122
 123 685 470 373 458 550 731 176 480 610 615 349  43 327 752 443 579 111
  93 743 768 207 215  97 684 677 588 555 465 639 764 266 306 485 344 309
 633 468  18 119 744 561 761 118 182   7  89  33  90 595 644 498 729 590
 141  35 445 165 450 668 144 183 569 643 705 175 687  56 712 264 720 146
 538 139 449 106 359 280 484 551 524 730 383 695 707 530 257 631 113 448
 313 574 377 399 459 343 516  79 565 702  27 563 347 457 757 177 273 395
 518 434 135 5