# Generate the Plantard-domain NTT twiddle factors for Kyber from the Montgomery-domain reference values.

In [2]:
import math
import random
def EX_GCD(a, b, arr):  # extended Euclidean algorithm
    """Extended Euclidean algorithm.

    Returns g, the last nonzero remainder of the Euclidean sequence. This is
    gcd(a, b) when b > 0; when b == 0 it is ``a`` itself. On return,
    ``arr[0]`` and ``arr[1]`` hold Bezout coefficients such that
    ``a * arr[0] + b * arr[1] == g``.

    Args:
        a, b: integers.
        arr: mutable sequence of length >= 2. Its first two slots are
            overwritten in place with the coefficients.
    """
    if b == 0:
        arr[0] = 1
        arr[1] = 0
        return a
    g = EX_GCD(b, a % b, arr)
    x, y = arr[0], arr[1]
    arr[0] = y
    # Floor division pairs with Python's floor-mod `%`, so that
    # a == (a // b) * b + a % b holds exactly for negative and arbitrarily
    # large operands. int(a / b) went through a float: it lost precision
    # above 2**53 and truncated toward zero.
    arr[1] = x - (a // b) * y
    return g
def ModReverse(a, n):
    """Return x in [0, n) with a*x == 1 (mod n), or -1 if a has no inverse.

    Uses EX_GCD's Bezout coefficient for ``a``. The inverse exists only
    when gcd(a, n) == 1.
    """
    coeffs = [0, 1]
    if EX_GCD(a, n, coeffs) != 1:
        return -1  # not invertible modulo n
    # a * coeffs[0] is congruent to 1 mod n; Python's % already lands in [0, n).
    return coeffs[0] % n

# Constants for Kyber

In [3]:
M = 3329                          # Kyber modulus q
n = 16                            # half-word size in bits
N = 32                            # full word size in bits (2n)
r = 2**n                          # Montgomery radix 2^n
R = 2**N                          # Plantard radix 2^2n
alpha = 3
alpha2 = 2**alpha                 # 2^alpha, used by the Plantard reduction
Mprime_mont = ModReverse(-M, r)   # -q^-1 mod 2^n
Mprime_plant = ModReverse(M, R)   # q^-1 mod 2^2n

Mont_const = r % M                # 2^n mod q (Montgomery factor)
Plant_const = (-R) % M            # -2^2n mod q (Plantard factor)

root_of_unity = 17                # not referenced in the code shown here

# All-ones 32/64-bit masks: must be 1 << k (1 >> k is 0, giving -1).
and32 = (1 << 32) - 1
and64 = (1 << 64) - 1
print("q", M)
print("qa", M * alpha2)
print("qinv_plant=q^-1 mod 2^2n", Mprime_plant, hex(Mprime_plant))
print("qinv_mont=-q^-1 mod 2^n", Mprime_mont)
print("Mont_const==2^n mod q", Mont_const)
print("Plant_const=-2^2n mod q", Plant_const)
print("plant*qinv_plant= q^-1 (-2^2n) mod 2^2n", Plant_const * Mprime_plant % R)

q 3329
qa 26632
qinv_plant=q^-1 mod 2^2n 1806234369 0x6ba8f301
qinv_mont=-q^-1 mod 2^n 3327
Mont_const==2^n mod q 2285
Plant_const=-2^2n mod q 1976
plant*qinv_plant= q^-1 (-2^2n) mod 2^2n 1290168


# Basic modular arithmetic for computing NTT twiddle factors

In [4]:

# Plantard reduction and multiplication helpers.
def _plant_fold(x):
    """Shared Plantard step: ((((x mod 2^2n) // 2^n) + 2^alpha) * q) // 2^n.

    Reads the module-level constants R (2^2n), r (2^n), alpha2 (2^alpha)
    and M (q).
    """
    high = (x % R) // r
    return (high + alpha2) * M // r


def plant_red(a):
    """Plantard-reduce ``a``, first multiplying it by qinv_plant (Mprime_plant)."""
    return _plant_fold(a * Mprime_plant)


def plant_mul(a, b):
    """Plantard product of ``a`` and ``b``, i.e. plant_red(a * b)."""
    return plant_red(a * b)


def plant_mul_ntt(a, zeta):
    """Plantard product with a precomputed twiddle factor.

    ``zeta`` is expected to already include the qinv_plant factor
    (zeta_plant * Mprime_plant mod 2^2n), so no extra multiply is done here.
    """
    return _plant_fold(a * zeta)

# Index tree and twiddle-factor conversion for the C-version NTT and INTT

In [5]:
# For C-version NTT.
# The commented-out table below is the plain bit-reversed index order:
# tree = [
#   0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120,
#   4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124,
#   2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122,
#   6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126,
#   1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121,
#   5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125,
#   3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123,
#   7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127]
# The active `tree` is a reordering of that table. Its first half differs,
# presumably to match the merged-layer loop order of the C/asm code
# (verify against that code). It is not referenced in the code shown here.
tree= [0, 64, 32, 96, 16, 80, 48, 112, 8, 4, 68, 2, 66, 34, 98, 72, 36, 100, 18, 82, 50, 114, 40, 20, 84, 10, 74, 42, 106, 104, 52, 116, 26, 90, 58, 122, 24, 12, 76, 6, 70, 38, 102, 88, 44, 108, 22, 86, 54, 118, 56, 28, 92, 14, 78, 46, 110, 120, 60, 124, 30, 94, 62, 126, 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127]
# Original precomputed Kyber NTT zetas in the Montgomery domain (z * 2^n mod q).
# Only the forward table is defined here; it has 128 entries, the last one a 0.
# Each comment marks which merged-layer loop of the asm NTT the row serves.
zetas_asm = [
    # 7 & 6 & 5 & 4 layers
    2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 
    # 1st loop of 3 & 2 & 1 layers
    573, 1223, 652, 2226, 430, 555, 843, 2004, 2777, 1015, 2078, 871, 1550, 105,
    # 2nd loop of 3 & 2 & 1 layers
    264, 2036, 1491, 422, 587, 177, 3094, 383, 3047, 1785, 3038, 2869, 1574, 1653,
    # 3rd loop of 3 & 2 & 1 layers
    2500, 516, 3321, 3083, 778, 1159, 3182, 1458, 3009, 2663, 2552, 1483, 2727, 1119,
    # 4th loop of 3 & 2 & 1 layers
    1727, 1711, 2167, 1739, 644, 2457, 349, 3199, 126, 1469, 418, 329, 3173, 3254,
    # 5th loop of 3 & 2 & 1 layers
    2648, 2476, 3239, 817, 1097, 603, 610, 1017, 3058, 830, 1322, 2044, 1864, 384,
    # 6th loop of 3 & 2 & 1 layers
    732, 107, 1908, 2114, 3193, 1218, 1994, 608, 3082, 2378, 2455, 220, 2142, 1670,
    # 7th loop of 3 & 2 & 1 layers
    1787, 2931, 961, 2144, 1799, 2051, 794, 411, 1821, 2604, 1819, 2475, 2459, 478,
    # 8th loop of 3 & 2 & 1 layers
    3124, 448, 2264, 3221, 3021, 996, 991, 1758, 677, 2054, 958, 1869, 1522, 1628, 0
]

def init_ntt_c_plant_from_mont(zetas, type=0):
    """Convert Montgomery-domain twiddle factors into Plantard NTT constants.

    Each input is taken to be z * 2^n mod q (Montgomery form). It is
    re-scaled to z * (-2^2n) mod q (Plant_const) and then multiplied by
    qinv_plant mod 2^2n (Mprime_plant). That precomputed form lets
    plant_mul_ntt(a, zeta) match plant_red(a * z_plant).

    Args:
        zetas: iterable of Montgomery-domain twiddle factors.
        type: 0 for the forward NTT table (the default) and 1 for the inverse
            NTT table. For 1, the last two factors receive one extra
            Plant_const * Mont^-1 scaling. Any other value is converted like
            0 but printed under the inverse label, as before.

    Returns:
        The list of converted constants. The list is also printed.
    """
    mont_inv = ModReverse(Mont_const, M)  # (2^n)^-1 mod q
    # z * 2^n * (2^n)^-1 * (-2^2n) == z * (-2^2n)  (mod q)
    plant_ntt_asm = [Plant_const * zeta * mont_inv % M for zeta in zetas]
    if type == 1:
        # For the inverse NTT, finalize the last two twiddle factors of the
        # final layer as (Plant^2 * 128^-1) % M. This presumes the input
        # already carries the 128^-1 scaling there; verify against the source table.
        plant_ntt_asm[-2] = plant_ntt_asm[-2] * Plant_const * mont_inv % M
        plant_ntt_asm[-1] = plant_ntt_asm[-1] * Plant_const * mont_inv % M
    # Fold in q^-1 mod 2^2n so the NTT multiply can skip that step.
    plant_ntt_asm_prime = [zeta * Mprime_plant % R for zeta in plant_ntt_asm]
    if type == 0:
        print("plant_ntt_asm", plant_ntt_asm_prime)
    else:
        print("plant_inv_ntt_asm", plant_ntt_asm_prime)
    return plant_ntt_asm_prime


# The original call omitted the required `type` argument (TypeError).
plant_ntt_asm_table = init_ntt_c_plant_from_mont(zetas_asm, type=0)