# This notebook generates the NTT twiddle factors for Kyber (Montgomery and Plantard domains).

In [1]:
import math
import random
def EX_GCD(a,b,arr): # extended Euclidean algorithm
    """Extended Euclidean algorithm.

    Returns g, the last non-zero remainder of the Euclidean sequence (gcd(a, b)
    up to sign; it can be negative when the inputs are negative). Overwrites
    arr[0] and arr[1] with Bezout coefficients such that
    a*arr[0] + b*arr[1] == g.

    arr must be a mutable sequence with at least two slots.
    """
    if b == 0:
        arr[0] = 1
        arr[1] = 0
        return a
    g = EX_GCD(b, a % b, arr)
    x, y = arr[0], arr[1]
    arr[0] = y
    # Floor division pairs with Python's floor-based `%` in the recursive call.
    # The former int(a / b) truncated toward zero (wrong coefficients for
    # negative operands) and went through a float (inexact beyond 2**53).
    arr[1] = x - (a // b) * y
    return g
def ModReverse(a,n): # solve a*x == 1 (mod n) for x, the multiplicative inverse of a mod n
    """Return x with a*x == 1 (mod n), or -1 when EX_GCD(a, n) does not report 1.

    The inverse is the Bezout coefficient of a produced by EX_GCD.
    """
    bezout = [0, 1]
    if EX_GCD(a, n, bezout) != 1:
        return -1
    # Python's % already returns a representative with the sign of n,
    # so no extra "+ n" normalisation is needed.
    return bezout[0] % n

# Constants for Kyber

In [2]:
# Kyber modulus and the word sizes used by the two reductions.
M = 3329            # Kyber modulus q
n = 16              # half-word width in bits
N = 32              # full-word width in bits
r = 2**n            # Montgomery radix 2^16
R = 2**N            # Plantard radix 2^32
alpha = 3
alpha2 = 2**alpha   # offset added to the high half inside plant_red / plant_mul_ntt
Mprime_mont = 3327          # M * 3327 == -1 (mod 2^16)
Mprime_plant = 1806234369   # M * 1806234369 == 1 (mod 2^32)
Mont_const = r % M          # Montgomery domain factor 2^16 mod M
Mont_const2 = (r**2) % M    # 2^32 mod M
Plant_const = (-R) % M      # Plantard domain factor -2^32 mod M
Plant_const2 = ((-R)**2) % M
root_of_unity = 17          # generator of the twiddle powers in init_ntt_c*
# All-ones masks for 32-/64-bit words. The previous `(1>>32)-1` shifted the
# wrong way and evaluated to -1.
and32 = (1 << 32) - 1
and64 = (1 << 64) - 1
# Inverses of the two domain factors modulo M (ModReverse returns -1 if none exists).
Mont_const_inv = ModReverse(Mont_const, M)
Plant_const_inv = ModReverse(Plant_const, M)

# Diagnostic dump of the derived constants, one "label value" line each.
for _label, _value in (
    ("Mont_const_inv", Mont_const_inv),
    ("Plant_const", Plant_const),
    ("Mont_const", Mont_const),
    ("Plant_const2", Plant_const2),
    ("Mont_const*Mont_const*root_of_unity%M", Mont_const * Mont_const * root_of_unity % M),
    ("128^-1*Mont^2 %M", (ModReverse(128, M) * Mont_const2) % M),
    ("128^-1*Plant^2 %M", (ModReverse(128, M) * Plant_const2) % M),
    ("plant*Mprime_plant", Plant_const * Mprime_plant % R),
):
    print(_label, _value)
del _label, _value

Mont_const_inv 169
Plant_const 1976
Mont_const 2285
Plant_const2 2988
Mont_const*Mont_const*root_of_unity%M 3027
128^-1*Mont^2 %M 1441
128^-1*Plant^2 %M 2208
plant*Mprime_plant 1290168


# Plantard modular arithmetic used to compute and check the NTT twiddle factors

In [3]:

# Plantard reduction
def plant_red(a):
    """Plantard reduction of the integer a, as implemented in this notebook.

    Steps: multiply by Mprime_plant and keep the low word (mod R), take its
    upper half (// r), add the alpha2 offset, multiply by M, and keep the
    upper half again.

    NOTE(review): Python's `% R` is non-negative, so the low word is treated
    as unsigned. A signed 32-bit low word would give a result exactly M smaller
    whenever that word is >= R/2. The two results agree mod M, but this one is
    not centred around zero.
    """
    low_word = (a * Mprime_plant) % R
    high_half = low_word // r
    return (high_half + alpha2) * M // r
def plant_mul(a,b):
    """Plantard product of a and b: plant_red applied to a*b."""
    product = a * b
    return plant_red(product)
def plant_mul_ntt(a,zeta):
    """Plantard multiplication of a by a precomputed twiddle.

    zeta is expected to be already multiplied by Mprime_plant modulo R (the
    form produced by init_ntt_c_plant_from_mont). Only a*zeta is formed here;
    the remaining steps are the same as in plant_red.
    """
    low_word = (a * zeta) % R
    return (low_word // r + alpha2) * M // r

# Generation tree for C-version NTT and INTT

In [4]:
# For C-version NTT
# Exponent order used to pick twiddles: init_ntt_c* take tmp[tree[i]], where
# tmp[k] = (domain constant) * root_of_unity**k mod M.
# The commented table below is the plain bit-reversed order. The active `tree`
# matches it in entries 0..8 and 64..127. Entries 9..63 appear to be the same
# even exponents in a different order — presumably to suit a merged-layer
# schedule (TODO confirm).
# tree = [
#   0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120,
#   4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124,
#   2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122,
#   6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126,
#   1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121,
#   5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125,
#   3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123,
#   7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127]
tree= [0, 64, 32, 96, 16, 80, 48, 112, 8, 4, 68, 2, 66, 34, 98, 72, 36, 100, 18, 82, 50, 114, 40, 20, 84, 10, 74, 42, 106, 104, 52, 116, 26, 90, 58, 122, 24, 12, 76, 6, 70, 38, 102, 88, 44, 108, 22, 86, 54, 118, 56, 28, 92, 14, 78, 46, 110, 120, 60, 124, 30, 94, 62, 126, 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127]
# Original precomputed forward-NTT zetas for Kyber in the Montgomery domain
# (128 entries: 15 for the merged layers 7-4, then 8 groups of 14 for layers
# 3-1, and a trailing 0 pad). test_ntt_twiddle and the driver below convert
# this table into the Plantard domain.
zetas_asm = [
    # 7 & 6 & 5 & 4 layers
    2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 
    # 1st loop of 3 & 2 & 1 layers
    573, 1223, 652, 2226, 430, 555, 843, 2004, 2777, 1015, 2078, 871, 1550, 105,
    # 2nd loop of 3 & 2 & 1 layers
    264, 2036, 1491, 422, 587, 177, 3094, 383, 3047, 1785, 3038, 2869, 1574, 1653,
    # 3rd loop of 3 & 2 & 1 layers
    2500, 516, 3321, 3083, 778, 1159, 3182, 1458, 3009, 2663, 2552, 1483, 2727, 1119,
    # 4th loop of 3 & 2 & 1 layers
    1727, 1711, 2167, 1739, 644, 2457, 349, 3199, 126, 1469, 418, 329, 3173, 3254,
    # 5th loop of 3 & 2 & 1 layers
    2648, 2476, 3239, 817, 1097, 603, 610, 1017, 3058, 830, 1322, 2044, 1864, 384,
    # 6th loop of 3 & 2 & 1 layers
    732, 107, 1908, 2114, 3193, 1218, 1994, 608, 3082, 2378, 2455, 220, 2142, 1670,
    # 7th loop of 3 & 2 & 1 layers
    1787, 2931, 961, 2144, 1799, 2051, 794, 411, 1821, 2604, 1819, 2475, 2459, 478,
    # 8th loop of 3 & 2 & 1 layers
    3124, 448, 2264, 3221, 3021, 996, 991, 1758, 677, 2054, 958, 1869, 1522, 1628, 0
]
# Original precomputed inverse-NTT twiddles for Kyber in the Montgomery domain,
# laid out per the section comments: butterfly twiddles followed by "twist"
# factors for each merged group. 2285 equals Mont_const (1 in Montgomery form).
# NOTE(review): the last group has 8 butterfly entries and 7 twist entries plus a
# trailing 0, unlike the 7/8 split of groups 2-15 — presumably padding or a
# shifted boundary; confirm against the assembly consumer.
zetas_inv_CT_asm =[
    # pad + LAYER 7+6+5+4
    2285, 2285, 758, 2285, 1517, 758, 359, 2285, 3127, 1517, 1907, 758, 3042, 359, 1836,
    # removed first "2285" + LAYER 3+2+1 - 1 - butterfly
    2285, 758, 2285, 1517, 758, 359,
    # LAYER 3+2+1 - 1 - twist
    1441, 3104, 1047, 987, 1932, 2861, 713, 1254,

    # LAYER 3+2+1 - 2 - butterfly
    1861, 1571, 205, 1275, 1065, 2652, 2881,
    # LAYER 3+2+1 - 2 - twist
    2043, 1945, 1824, 1233, 3051, 2714, 2196, 2032,

    # LAYER 3+2+1 - 3 - butterfly
    3127, 1861, 1474, 1571, 2918, 205, 1542,
    # LAYER 3+2+1 - 3 - twist
    316, 1681, 2653, 660, 2921, 3097, 325, 707,

    # LAYER 3+2+1 - 4 - butterfly
    3147, 130, 1602, 1860, 1162, 3203, 1618,
    # LAYER 3+2+1 - 4 - twist
    1781, 1078, 1331, 3172, 3305, 378, 2369, 1804,

    # LAYER 3+2+1 - 5 - butterfly
    1517, 3127, 3042, 1861, 1202, 1474, 2367,
    # LAYER 3+2+1 - 5 - twist
    2063, 1630, 2624, 1949, 1761, 1393, 531, 2456,

    # LAYER 3+2+1 - 6 - butterfly
    1202, 2721, 2597, 951, 1421, 247, 3222,
    # LAYER 3+2+1 - 6 - twist
    513, 1075, 546, 3052, 1866, 2236, 1402, 2886,

    # LAYER 3+2+1 - 7 - butterfly
    1907, 3147, 1752, 130, 1871, 1602, 829,
    # LAYER 3+2+1 - 7 - twist
    226, 1434, 2382, 767, 2068, 719, 2824, 2128,

    # LAYER 3+2+1 - 8 - butterfly
    2707, 2946, 3065, 1544, 1838, 282, 1293,
    # LAYER 3+2+1 - 8 - twist
    2559, 476, 2490, 2395, 3059, 2588, 2516, 321,

    # LAYER 3+2+1 - 9 - butterfly
    758, 1517, 359, 3127, 1907, 3042, 1836,
    # LAYER 3+2+1 - 9 - twist
    738, 28, 2888, 1120, 2334, 1523, 148, 998,

    # LAYER 3+2+1 - 10 - butterfly
    1474, 2918, 1542, 725, 2368, 1508, 398,
    # LAYER 3+2+1 - 10 - twist
    1610, 2939, 1149, 1045, 2683, 1852, 792, 842,

    # LAYER 3+2+1 - 11 - butterfly
    3042, 1202, 2367, 2721, 2312, 2597, 681,
    # LAYER 3+2+1 - 11 - twist
    878, 1152, 1830, 2803, 3291, 2263, 1809, 637,

    # LAYER 3+2+1 - 12 - butterfly
    1752, 1871, 829, 666, 8, 320, 2813,
    # LAYER 3+2+1 - 12 - twist
    2989, 2026, 3045, 1144, 1956, 2483, 1673, 2779,

    # LAYER 3+2+1 - 13 - butterfly
    359, 1907, 1836, 3147, 2707, 1752, 171,
    # LAYER 3+2+1 - 13 - twist
    3309, 315, 2529, 2613, 1290, 1321, 1665, 2905,

    # LAYER 3+2+1 - 14 - butterfly
    2367, 2312, 681, 2499, 90, 271, 853,
    # LAYER 3+2+1 - 14 - twist
    3132, 606, 2107, 937, 1055, 861, 2252, 1150,

    # LAYER 3+2+1 - 15 - butterfly
    1836, 2707, 171, 2946, 1325, 3065, 2756,
    # LAYER 3+2+1 - 15 - twist
    1555, 2973, 2278, 2405, 1237, 2988, 2874, 3005,

    # LAYER 3+2+1 - 16 - butterfly
    171, 1325, 2756, 2314, 2677, 552, 2106, 2833,
    # LAYER 3+2+1 - 16 - twist
    1154, 134, 2883, 2031, 2134, 1344, 2135,0
]


# Twiddle-table generation (init_ntt_*) for the C-version NTT and INTT

In [5]:
# Test if our generation match the original one
def init_ntt_c():
    """Generate the Montgomery-domain twiddle tables for the C-version NTT/INTT.

    powers[k] = Mont_const * root_of_unity**k mod M for k = 0..127.

    - Forward table: zetas[i] = powers[tree[i]] for i = 0..127.
    - Inverse table: walks tree in blocks j = 64..127, 32..63, ..., 1..1 and
      stores -powers[128 - tree[j]] mod M (127 entries).
    - The last inverse entry is then scaled by 128^-1 * Mont_const. One more
      entry, Mont_const^2 * 128^-1 mod M, is appended (128 entries in total).
      zetas_inv_idx has no entry for the appended value.

    Prints the tables and their tree indices; returns (zetas, zetas_inv).
    """
    # Powers of the root of unity, scaled into the Montgomery domain.
    powers = [Mont_const]
    for _ in range(1, 128):
        powers.append(powers[-1] * root_of_unity % M)
    # print("tmp_mont:",powers)

    zetas_idx = [tree[i] for i in range(128)]
    zetas = [powers[k] for k in zetas_idx]
    print("zetas_mont:", zetas)
    print("zetas_mont_idx:", zetas_idx)

    zetas_inv = []
    zetas_inv_idx = []
    block = 64
    while block >= 1:
        for j in range(block, 2 * block):
            exponent = 128 - tree[j]
            zetas_inv.append(-powers[exponent] % M)
            zetas_inv_idx.append(exponent)
        block >>= 1

    # Exact inverse of 128 mod M. The former float expression
    # (M-1)*((M-1)/128) only equals 128^-1 mod M when 128 divides M-1.
    inv128 = ModReverse(128, M)
    zetas_inv[-1] = Mont_const * (zetas_inv[-1] * inv128 % M) % M
    zetas_inv.append(Mont_const * (Mont_const * inv128 % M) % M)
    print("zetas_inv_mont:", zetas_inv)
    print("zetas_inv_mont_idx:", zetas_inv_idx)
    return zetas, zetas_inv

def init_ntt_c_plant():
    """Generate the Plantard-domain twiddle tables for the C-version NTT/INTT.

    Same layout as init_ntt_c, but the powers are scaled by Plant_const:
    powers[k] = Plant_const * root_of_unity**k mod M.

    - The last inverse entry is scaled by 128^-1 * Plant_const.
    - Plant_const^2 * 128^-1 mod M is appended.

    Prints both tables. Returns (zetas, zetas_inv); it previously returned
    None, so the tables could only be read from the printout.
    """
    # Powers of the root of unity, scaled into the Plantard domain.
    powers = [Plant_const]
    for _ in range(1, 128):
        powers.append(powers[-1] * root_of_unity % M)
    # print("tmp_plant:",powers)

    zetas = [powers[tree[i]] for i in range(128)]
    print("zetas_plant:", zetas)

    zetas_inv = []
    block = 64
    while block >= 1:
        for j in range(block, 2 * block):
            zetas_inv.append(-powers[128 - tree[j]] % M)
        block >>= 1

    # Exact inverse of 128 mod M instead of the float (M-1)*((M-1)/128) trick.
    inv128 = ModReverse(128, M)
    zetas_inv[-1] = Plant_const * (zetas_inv[-1] * inv128 % M) % M
    zetas_inv.append(Plant_const * (Plant_const * inv128 % M) % M)
    print("zetas_inv_plant:", zetas_inv)
    return zetas, zetas_inv

def init_ntt_c_plant_from_mont(zetas,type):
    """Convert Montgomery-domain twiddles into premultiplied Plantard twiddles.

    Each zeta (presumably Mont_const * w mod M) becomes Plant_const * w mod M by
    multiplying by Plant_const * Mont_const^-1. It is then multiplied by
    Mprime_plant mod R, which is the form plant_mul_ntt expects.

    Args:
        zetas: Montgomery-domain twiddle list (e.g. zetas_asm).
        type:  0 for the forward table. 1 for the inverse table, whose last
            two entries get one extra Plant_const * Mont_const^-1 factor. Any
            non-zero value is printed under the inverse label.

    Returns the premultiplied table. It is also printed, as before.
    Previously it was only printed, so the test helpers had nothing to use.
    """
    Mont_inv = ModReverse(Mont_const, M)
    # Montgomery -> Plantard domain: divide out Mont_const, multiply in Plant_const.
    plant_ntt_asm = [Plant_const * zeta * Mont_inv % M for zeta in zetas]
    if type == 1:
        # For the inverse NTT, finalize the last two twiddle factors of the
        # final layer as (Plant^2*128^-1)%M.
        plant_ntt_asm[-2] = plant_ntt_asm[-2] * Plant_const * Mont_inv % M
        plant_ntt_asm[-1] = plant_ntt_asm[-1] * Plant_const * Mont_inv % M
    # Premultiply by Mprime_plant so plant_mul_ntt can skip that multiplication.
    plant_ntt_asm_prime = [zeta * Mprime_plant % R for zeta in plant_ntt_asm]
    if type == 0:
        print("plant_ntt_asm", plant_ntt_asm_prime)
    else:
        print("plant_inv_ntt_asm", plant_ntt_asm_prime)
    return plant_ntt_asm_prime
# Success
def test_ntt_twiddle(a=1243, twiddles=None):
    """Check Plantard forward twiddles against the Montgomery table zetas_asm.

    For each Montgomery-domain zeta, a*zeta mod M is computed directly
    (zeta * a * Mont_const_inv % M) and via plant_mul_ntt with the matching
    Plantard twiddle. Mismatches are printed.

    Args:
        a: test operand (1243 as before).
        twiddles: premultiplied Plantard twiddles, index-aligned with
            zetas_asm. When None they are derived from zetas_asm the same way
            init_ntt_c_plant_from_mont(zetas_asm, 0) does. The previous version
            read a global `plant_ntt_asm` that is not defined anywhere in this
            source (NameError).

    Returns the number of mismatches.
    """
    if twiddles is None:
        twiddles = [
            (Plant_const * zeta * Mont_const_inv % M) * Mprime_plant % R
            for zeta in zetas_asm
        ]
    mismatches = 0
    for i, zeta in enumerate(zetas_asm):
        b = zeta * a * Mont_const_inv % M   # reference a*zeta
        d = plant_mul_ntt(a, twiddles[i])   # Plantard a*zeta
        if b != d:
            print("Inequality:", i, b, d, zeta)
            mismatches += 1
    return mismatches
# Success
def test_intt_twiddle(zetas_inv=None, twiddles=None, a=1243):
    """Check Plantard inverse twiddles against a Montgomery-domain inverse table.

    For each Montgomery-domain zeta, a*zeta mod M is computed directly and via
    plant_mul_ntt with the matching Plantard twiddle. Indices 126 and 127 carry
    an extra domain factor, presumably from init_ntt_c_plant_from_mont(..., 1).
    On a mismatch there, both sides are rescaled and printed without a label.
    Any other mismatch prints an "Inequality:" line.

    Args:
        zetas_inv: Montgomery-domain inverse twiddles; defaults to the
            module-level `zetas_inv_asm`.
        twiddles: premultiplied Plantard inverse twiddles; defaults to the
            module-level `plant_inv_ntt_asm`.
        a: test operand (1243 as before).

    NOTE(review): neither `zetas_inv_asm` nor `plant_inv_ntt_asm` is defined in
    the visible source, so calling this without arguments raises NameError;
    pass the tables explicitly.
    """
    if zetas_inv is None:
        zetas_inv = zetas_inv_asm
    if twiddles is None:
        twiddles = plant_inv_ntt_asm
    for i, zeta in enumerate(zetas_inv):
        b = zeta * a * Mont_const_inv % M   # reference a*zeta
        d = plant_mul_ntt(a, twiddles[i])   # Plantard a*zeta
        if b != d:
            if i == 126 or i == 127:
                # Strip one domain factor from each side before comparing by eye.
                b = b * Mont_const_inv % M
                d = d * Plant_const_inv % M
                print(b, d)
            else:
                print("Inequality:", i, b, d, zeta)

# input: twiddle list, the total number of layers, number of the merging layer 
def permute_ntt_to_asm(twiddle, layer, merge_mode):
	"""Reorder a linear forward-NTT twiddle list for merged-layer processing.

	Presumably produces the order expected by the assembly implementation (per
	the function name) — confirm against the consumer.

	twiddle    -- twiddle list, e.g. the `zetas` returned by init_ntt_c().
	layer      -- total number of NTT layers (7 in the commented-out call).
	merge_mode -- number of layers merged per pass (3 in that call).

	Prints the per-pass twiddle counts, the indices used by the last pass and
	the reordered list. Nothing is returned.

	NOTE(review): the remainder pass uses n = 2**(merge_mode*2), and only
	num[0]/num[1] are handled as full passes. This appears to assume exactly
	two full merged passes plus one remainder pass (layer=7, merge_mode=3 gives
	num = [7, 56, 64]). Confirm before using other parameters.
	"""
	# Twiddles per merged group of merge_mode layers: 1 + 2 + ... + 2**(merge_mode-1).
	step=0
	for i in range(0,merge_mode):
		step=step+2**i
	num=[]
	n=0
	# num[k] = twiddles consumed by full merged pass k (cumulative scaling by 2**(merge_mode*k)).
	for i in range(0,int(layer/merge_mode)):
		n = 2**(merge_mode*i)
		step = n*step
		num.append(step)		
	# Remaining layers that do not fill a whole merged pass.
	last_merge=layer%merge_mode
	step=0
	for i in range(0,last_merge):
		step=step+2**i
	n=2**(merge_mode*2)
	step=step*n
	num.append(step)

	# step shows the number of twiddle factors used at each butterfly
	print("Number of Steps:", num)
	i=0
	new_twiddle=[]
	while i<len(twiddle):
		# The first merging layer: indices 0..num[0] (index 0 included) are copied in order.
		if(i<num[0]+1):
			new_twiddle.append(twiddle[i])
			i=i+1
		# Second merged pass: gather num[1] twiddles in interleaved order.
		elif(i<(num[0]+num[1]+1)):
			n1 = 2**merge_mode
			n2=2*n1
			for j in range(0, n1):
				for k in range(0, merge_mode):
					for k1 in range(0,2**k):
						index = i+(2**k)*j+int((k+1)/2)*n1+int((k)/2)*n2+k1
						# print(index)
						new_twiddle.append(twiddle[index])
			i=i+num[1]
		# Remainder pass: gathers the last layer(s) in one go, then stops.
		elif(i<len(twiddle)):
			print("i",i)
			n1 = 2**(merge_mode*2)
			for j in range(0, n1):
				for k in range(0, last_merge):
					for k1 in range(0, 2**k):
						index = i+(2**k)*j+int((k+1)/2)*n1+k1
						print(index)
						new_twiddle.append(twiddle[index])
			break
	print("new_twiddle:", new_twiddle)


def permute_invntt_to_asm(twiddle, layer, merge_mode):
	"""Reorder a linear inverse-NTT twiddle list for merged-layer processing.

	Presumably produces the order expected by the assembly INTT (per the
	function name) — confirm against the consumer. The passes run in the
	opposite order to permute_ntt_to_asm: the remainder layer(s) first, then
	the full merged pass, then the leftover entries one by one.

	twiddle    -- inverse twiddle list, e.g. `zetas_inv` from init_ntt_c().
	layer      -- total number of layers (7 in the commented-out call).
	merge_mode -- number of layers merged per pass (3 in that call).

	Prints counts, every gathered index and the reordered list. Nothing is
	returned.

	NOTE(review): `int(num[1]/7)` hard-codes 7 = 2**3 - 1, i.e. it assumes
	merge_mode == 3. num[2] is only the remainder pass when there are exactly
	two full merged passes.
	"""
	# Twiddles per merged group of merge_mode layers: 1 + 2 + ... + 2**(merge_mode-1).
	step = 0
	for i in range(0, merge_mode):
		step = step+2**i
	num = []
	n = 0
	for i in range(0, int(layer/merge_mode)):
		n = 2**(merge_mode*i)
		step = n*step
		num.append(step)
	last_merge = layer % merge_mode
	step = 0
	for i in range(0, last_merge):
		step = step+2**i
	n = 2**(merge_mode*2)
	step = step*n
	num.append(step)
	print("num",num)
	# For layer=7, merge_mode=3 this prints [7, 56, 64] (an older note said 7, 56, 192).

	# step shows the number of twiddle factors used at each merge_mode of butterflies
	new_twiddle = []
	i = 0
	while i < len(twiddle):
		# Remainder layer(s) first: num[2] twiddles.
		if(i < num[2]):
			n1 = 2**(merge_mode*2+1)
			for j in range(0, num[2]):
				for k in range(0, last_merge):
					k2 = last_merge-k-1
					for k1 in range(0, 2**(k2)):
						index = i+(2**(k2))*j+int((k+1)/2)*n1+k1
						print("index", index)
						new_twiddle.append(twiddle[index])
			print("end of the first two layers")
			i = i+num[2]
		# Then the full merged pass: num[1] twiddles, gathered in groups of 7.
		elif i < (num[2]+num[1]):
			n1 = 2**(merge_mode+1)#16
			n2 = 2*n1#32
			for j in range(0, int(num[1]/7)):
				for k in range(0, merge_mode):
					k2 = merge_mode-k-1
					for k1 in range(0, 2**(k2)):
						index = i+(2**(k2))*j+int((k+1)/2)*n2+int((k)/2)*n1+k1
						print("index",index)
						new_twiddle.append(twiddle[index])
			print("end of the second two layers")
			i = i+num[1]
		# Whatever is left is copied in order.
		elif i < len(twiddle):
			print("index", i)
			new_twiddle.append(twiddle[i])
			i = i+1
	print("new_twiddle:", new_twiddle)

# Driver: by default only the forward-NTT Plantard table is generated.
# Uncomment the calls below to regenerate or check the other tables.
# init_ntt_c()
# init_ntt_c_plant()
init_ntt_c_plant_from_mont(zetas_asm, 0)
# init_ntt_c_plant_from_mont(zetas_inv_asm,1)  # NOTE(review): zetas_inv_asm is not defined in the visible source
# zetas, zetas_inv=init_ntt_c()
# print("zetas:",zetas)
# print("zetas_inv", zetas_inv)
# permute_ntt_to_asm(zetas, 7,3)
# permute_invntt_to_asm(zetas_inv, 7, 3)

# test_ntt_twiddle()
# test_intt_twiddle()

plant_ntt_asm [2230699446, 3328631909, 4243360600, 3408622288, 812805467, 2447447570, 1094061961, 1370157786, 2475831253, 249002310, 1028263423, 3594406395, 4205945745, 734105255, 2252632292, 381889553, 372858381, 427045412, 21932846, 3562152210, 752167598, 3417653460, 3157039644, 4196914574, 2265533966, 2112004045, 932791035, 2951903026, 1419184148, 1727534158, 1544330386, 2972545705, 1817845876, 3434425636, 4233039261, 300609006, 1904287092, 2937711185, 2651294021, 975366560, 2781600929, 3889854731, 3935010590, 3929849920, 838608815, 2550660963, 2197155094, 2130066389, 3598276897, 2308109491, 72249375, 3242190693, 815385801, 2382939200, 1228239371, 1884934581, 3466679822, 2889974991, 3696329620, 42575525, 1211467195, 2977706375, 3144137970, 3080919767, 1719793153, 1703020977, 2470670584, 945692709, 3015121229, 345764865, 826997308, 1839778722, 2991898216, 1851390229, 2043625172, 2964804700, 2628071007, 4154339049, 2701610550, 1041165097, 583155668, 483812778, 3288636719, 2696449880, 

In [6]:
# Scratch tables for studying how the twiddle schedule can be reordered.
# Both are in natural exponent order: entry k = (domain constant) * 17**k mod 3329.
# tmp_mont starts from 2285 (Montgomery domain); tmp_plant starts from 1976
# (Plantard domain). Both used to be assigned to `tmp`, so the Montgomery table
# was immediately overwritten. Each now has its own name, and `tmp` stays bound
# to the Plantard table as before.
tmp_mont = [2285, 2226, 1223, 817, 573, 3083, 2476, 2144, 3158, 422, 516, 2114, 2648, 1739, 2931, 3221, 1493, 2078, 2036, 1322, 2500, 2552, 107, 1819, 962, 3038, 1711, 2455, 1787, 418, 448, 958, 2970, 555, 2777, 603, 264, 1159, 3058, 2051, 1577, 177, 3009, 1218, 732, 2457, 1821, 996, 287, 1550, 3047, 1864, 1727, 2727, 3082, 2459, 1855, 1574, 126, 2142, 3124, 3173, 677, 1522, 2571, 430, 652, 1097, 2004, 778, 3239, 1799, 622, 587, 3321, 3193, 1017, 644, 961, 3021, 1422, 871, 1491, 2044, 1458, 1483, 1908, 2475, 2127, 2869, 2167, 220, 411, 329, 2264, 1869, 1812, 843, 1015, 610, 383, 3182, 830, 794, 182, 3094, 2663, 1994, 608, 349, 2604, 991, 202, 105, 1785, 384, 3199, 1119, 2378, 478, 1468, 1653, 1469, 1670, 1758, 3254, 2054, 1628]
tmp_plant = [1976, 302, 1805, 724, 2321, 2838, 1640, 1248, 1242, 1140, 2735, 3218, 1442, 1211, 613, 434, 720, 2253, 1682, 1962, 64, 1088, 1851, 1506, 2299, 2464, 1940, 3019, 1388, 293, 1652, 1452, 1381, 174, 2958, 351, 2638, 1569, 41, 697, 1862, 1693, 2149, 3243, 1867, 1778, 265, 1176, 18, 306, 1873, 1880, 1999, 693, 1794, 537, 2471, 2059, 1713, 2489, 2365, 257, 1040, 1035, 950, 2834, 1572, 92, 1564, 3285, 2581, 600, 213, 292, 1635, 1163, 3126, 3207, 1255, 1361, 3163, 507, 1961, 47, 799, 267, 1210, 596, 145, 2465, 1957, 3308, 2972, 589, 26, 442, 856, 1236, 1038, 1001, 372, 2995, 980, 15, 255, 1006, 457, 1111, 2242, 1495, 2112, 2614, 1161, 3092, 2629, 1416, 769, 3086, 2527, 3011, 1252, 1310, 2296, 2413, 1073, 1596, 500, 1842]
tmp = tmp_plant


In [7]:
def test_base_mul():
    """Randomised check that ((a*b) % M * zeta) % M == (a * (b*zeta % M)) % M.

    Draws 10000 triples of distinct integers from [-M, M). Prints "Inequal" at
    the first counterexample, or "Equal" if none is found.
    """
    for _ in range(10000):
        a, b, zeta = random.sample(range(-M, M), 3)
        lhs = (a * b % M) * zeta % M
        rhs = a * (b * zeta % M) % M
        if lhs != rhs:
            print("Inequal")
            break
    else:
        print("Equal")
test_base_mul()

Equal
