## Plantard arithmetic
Input: $A,B, P, R, n$ with $0 \leq A,B \leq P$ and $R = P^{−1} \bmod 2^{2n}$

Output: $C$ with $0 \leq C < P$ and $C = AB(−2^{−2n}) \bmod P$

$C ← [([[ABR]_{2n}]^n + 1)P]^n$

if $C=P$ return 0

else return $C$

In [2]:
import math
import random
def EX_GCD(a,b,arr): #扩展欧几里得
    if b == 0:
        arr[0] = 1
        arr[1] = 0
        return a
    g = EX_GCD(b, a % b, arr)
    t = arr[0]
    arr[0] = arr[1]
    arr[1] = t - int(a / b) * arr[1]
    return g
def ModReverse(a,n): #ax=1(mod n) 求a模n的乘法逆x
    arr = [0,1,]
    gcd = EX_GCD(a,n,arr)
    if gcd == 1:
        return (arr[0] % n + n) % n
    else:
        return -1

# Constants for Plantard arithmetic

In [21]:
# NTTRU Parameters
m=7681
n=16
alpha=2
w=20

m2=m**2 # m^2
alpha2=2**alpha # 2^(\alpha)
alpha4=alpha2**2 # 2^(2\alpha)
R=2**(2*n)
r=2**n
minv=ModReverse(m,R) # m^-1 mod R
minv_mont=ModReverse(m,r)
winv=ModReverse(w,m)
print("w^-1 mod m",winv)
# minv1=ModReverse(-m,r)
PLANT_CONST=(-R)%m # Like Mont, Plant_const = -(2^2n) mod m
Mont_Const=(r)%m
pm=(PLANT_CONST*minv)%R
PLANT_CONST2=(PLANT_CONST**2)%m #Use to calculate a*PLANT_CONST%m=Plant_Mul(a,PLANT_CONST2), in NTT precomputation.
from_plant_cont=(PLANT_CONST2*minv)%R
print("p",m, hex(m))
print("p^-1",minv, hex(minv))
print("p^-1 mont",minv_mont, hex(minv_mont))
# print("-p^-1",minv1, hex(minv))
print("Mont_Const 2^n mod p",Mont_Const,hex(Mont_Const))
print("PLANT_CONST -2^2n mod p",PLANT_CONST,hex(PLANT_CONST))
print("2^n mod p",Mont_Const,hex(Mont_Const))
print("-2^2n*p^-1 mod R",pm,hex(pm))
print("PLANT_CONST2",PLANT_CONST2,hex(PLANT_CONST2))
print("from_plant_cont",from_plant_cont,hex(from_plant_cont))
print(ModReverse(2868,m))

w^-1 mod m 7297
p 7681 0x1e01
p^-1 2340676097 0x8b83e201
p^-1 mont 57857 0xe201
Mont_Const 2^n mod p 4088 0xff8
PLANT_CONST -2^2n mod p 2112 0x840
2^n mod p 4088 0xff8
-2^2n*p^-1 mod R 559168 0x88840
PLANT_CONST2 5564 0x15bc
from_plant_cont 1180962236 0x46640dbc
5043


In [4]:
# Montgomery things.
Mont_minv=ModReverse(-m,r) #-q^-1

print("p^-1 mod 2^16",Mont_minv, hex(Mont_minv))
print(m*Mont_minv%r)

print(r%m)
print(r*r%m)
# Barrett_const=log(m)-1



p^-1 mod 2^16 7679 0x1dff
65535
4088
5569


# Modular Multiplication;
Input: $a,b\in [-2^{\alpha}m,2^{\alpha}m]$
Output: $ab(-2^{-2n}) \bmod m$

In [5]:
def Plant_Mul(a,b):
	c=(((((a*b*minv)%R)//r)+alpha2)*m//r)
	return c

# Modular Reduction
Input: $a\in [-2^{2\alpha}m^2,2^{2\alpha}m^2]$; This range is much bigger than Montgomery reduction which is: $2^{n-1}*p$;
6.5 times bigger than Montgomery reduction in Kyber. Maybe bigger in Dilithium. $2^{2*8}*m^2$compared to $2^{31}*m$; 255 times bigger.

Output: $a(-2^{-2n}) \bmod m$

In [6]:
def Plant_Red(a):
	c=(((((a*minv)%R)//r)+alpha2)*m//r)
	return c

# Modular Reduction with normal input and normal output.
Input: $a\in [-2^{2\alpha}m,2^{2\alpha}m]$; 

Output: $a \bmod m$

This is achieved by first multiply a term:$pm=(PLANT\_CONST*minv)\bmod R$

In [7]:
def Plant_Red_normal(a):
	c=(((((a*pm)%R)//r)+alpha2)*m//r)
	return c

# Test if the Plant_mul algorithm is right.

In [8]:
def generateMulTest(num):
    for i in range(num):
        [a,b]=random.sample(range(-alpha2*m,alpha2*m),2)
        c1=(a*b)%m
        c2=Plant_Mul(a,b)
        c2=(c2*PLANT_CONST)%m #AB
        if c1 != c2:
            print("ERROR: {0} * {1} mod {2} = {3} BUT GET {4} INSTEAD." .format(a,b,m,c1,c2))
        if (i % 100000)==99999:
            print("ROUND {0} TEST FINISHED.".format(i+1))

generateMulTest(100000)

ROUND 100000 TEST FINISHED.


# Test if Plantard Reduction is right

In [9]:
def generateRedTest(num):
    
    for i in range(num):
        [a]=random.sample(range(-alpha4*m2,alpha4*m2),1)
        c1=(a)%m
        c2=Plant_Red(a)
        c2=(c2*PLANT_CONST)%m #AB
        if c1 != c2:
            print("ERROR: {0}(-2^(-2n)) mod {1} = {2} BUT GET {3} INSTEAD." .format(a,m,c1,c2))
        if (a== alpha4*m2-1):
            print("ROUND {0} TEST FINISHED.".format(i+1))
        # a=a+1

generateRedTest(100000)

In [10]:
def generateRed_normalTest(num):
    
    for i in range(num):
        [a]=random.sample(range(-alpha4*m,alpha4*m),1)
        c1=(a)%m
        c2=Plant_Red_normal(a)
        if c1 != c2:
            print("ERROR: {0} mod {1} = {2} BUT GET {3} INSTEAD." .format(a,m,c1,c2))
        if (a== alpha4*m2-1):
            print("ROUND {0} TEST FINISHED.".format(i+1))
        # a=a+1

generateRed_normalTest(100000)

ERROR: -61448 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -122896 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -15362 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -122896 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -15362 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -7681 mod 7681 = 0 BUT GET 7681 INSTEAD.


# Test NTT butterfly computation.

In [11]:
# testing NTT mul scheme: a * btp^-1 =ab
def Butterfly_Mul(a,twiddle):
	c=(((((a*twiddle)%R)//r)+alpha2)*m//r)
	return c

def generateNTTTest(num):
    for i in range(num):
        [b]=random.sample(range(-alpha2*m,alpha2*m),1)
        [a]=random.sample(range(-2**(n-1),2**(n-1)),1)
        w=(((b*PLANT_CONST)%m)*minv)%R # Twiddle factor*PLANT_CONST*minv
        # w=Plant_Mul(b,PLANT_CONST2) # b*PLANT_CONST %m
        # w=(w*minv)%R
        c1=Butterfly_Mul(a,w) #
        c2=(a*b)%m #AB
        if c1 != c2:
            print("ERROR: {0} * {1} mod {2} = {3} BUT GET {4} INSTEAD." .format(a,b,m,c2,c1))
        if (i % 100000)==99999:
            print("ROUND {0} TEST FINISHED.".format(i+1))

generateNTTTest(100000)


ERROR: -7681 * 7122 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -30724 * 14466 mod 7681 = 0 BUT GET 7681 INSTEAD.
ERROR: -7681 * 29614 mod 7681 = 0 BUT GET 7681 INSTEAD.
ROUND 100000 TEST FINISHED.


# Test from Plantard

In [12]:
# testing NTT mul scheme: a * btp^-1 =ab

def from_plant(a):
	c=(((((a*from_plant_cont)%R)//r)+alpha2)*m//r)
	return c

def generateNTTTest(num):
    for i in range(num):
        [a]=random.sample(range(-2**(n-1),2**(n-1)),1)
        c2=Plant_Mul(a,1)
        c1=from_plant(c2)
        if (a%m) != c1:
            print("ERROR: {0} * {1} = {2} BUT GET {3} INSTEAD." .format(a,from_plant_cont,a,c1))
        if (i % 100000)==99999:
            print("ROUND {0} TEST FINISHED.".format(i+1))

generateNTTTest(100000)

ROUND 100000 TEST FINISHED.


# Test Final reduction to reduce 16-bit integer back to positive integers.
input: $a\in[-2^{n-1},2^{n-1}-1]$
output: $c\in [0,2^n]$?

In [13]:
def Plant_Red_Positive(a):
	c=((((((a+10*m)*minv)%R)//r)+alpha2)*m//r)
	return c

# Test if there exists ab, s.t. [[[ab]_2n p']_2n]^n+2^alpha>2^n-1

In [14]:
def generateBoundTest():
	a=-R
	while(a<R):
		# [a]=random.sample(range(-R,R),1)
		c=(((a*minv)%R)//r)
		t=(c>>3) & 1
		c=c>>3
		i=3
		while(c>0 and t!=0):
			t=t&(c>>1)
			c=c>>1
			i=i+1
		if(t==1 and i==15):
			print("Error: {0}*{1}+alpha2={2}" .format(hex(a),hex(minv),hex(c)))
		a=a+1
# generateBoundTest()


In [15]:
# ab=a*b%m
# print("AB",c*PLANT_CONST%m)
# print("AB",ab)
# print("A=AB/B",(ab*ModReverse(b,m))%m)
# print("B=AB/A",(ab*ModReverse(a,m))%m)
minus2_ninv=ModReverse(-R,m)
# print("(-2^-2n) mod p",hex(minus2_ninv))
# print("R*R-1", R*minus2_ninv%m)
# print("c=AB(-2^-2n) mod p",c)
# test=a*b*minus2_ninv%m
# print("AB(-2^-2n) mod p",test)

# # test if ((ab mod q)c) mod q== (abc) mod q
# print("test_c",test_c%m)
# c1=c*PLANT_CONST%m # AB
# print("AB",c1)
# c1=(((((c1*test_c*minv)%R)//r)+1)*m//r)
# product=((a*b%R)*test_c)%R
# c2=(((((product*minv)%R)//r)+1)*m//r)
# print("((ab mod q)c) mod q",c)
# print("(abc) mod q",c2)

In [16]:
# Test if negative inputs A,B work?
i=4*(m**2)
count=0
a=-3328
b=453
# [a,b]=random.sample(range(-m,m),2)
c=((((((a*b*minv)%R)//r)+1)*m//r))%m
test=(a*b*minus2_ninv)
print("test",test,hex(test),hex(test%m))
print("abP^-1",a*b*minv,hex(a*b*minv),hex(a*b*minv%R),hex(((a*b*minv%R)>>16)+1))

abpinv=a*b*minv//R+1
print("abpinv/R",abpinv)
tmp=a*b*minv-abpinv*R
print(hex(tmp))

k=((tmp)>>16)+1
print("k*P",hex(k*m),hex((k*m)>>16),hex(((k*m)>>16)%m))
a=-3328
b=454
# [a,b]=random.sample(range(-m,m),2)
c=((((((a*b*minv)%R)//r)+1)*m//r))%m
test=(a*b*minus2_ninv)%m
print("test",test,hex(test),hex(test%m))
print("abP^-1",a*b*minv,hex(a*b*minv),hex(a*b*minv%R),hex(((a*b*minv%R)>>16)+1))
abpinv=a*b*minv
print("abpinv/R",abpinv/R)
print("abpinv//R",abpinv//R)
k=((a*b*minv%R)>>16)+1
print("k*P",hex(k*m),hex((k*m)>>16),hex(((k*m)>>16)%m))


test -6310746624 -0x178265a00 0x565
abP^-1 -3528765833019648 -0xc8964d1f90100 0x2e06ff00 0x2e07
abpinv/R -821604
-0xd1f90100
k*P -0x189bfff9 -0x189c 0x565
test 3707 0xe7b 0xe7b
abP^-1 -3536555603070464 -0xc907a84730e00 0x7b8cf200 0x7b8d
abpinv/R -823418.5173805952
abpinv//R -823419
k*P 0xe7b018d 0xe7b 0xe7b


In [17]:
# New Reduction to replace Barrett reduction
import math
const=(PLANT_CONST*minv)%R #R^(-1)*(-2^2n mod P)
print(hex(const))
def red(c):
	d=(((((c*const)%R)//r)+8)*m//r)
	if(d==m):
		return d-m
	else:
		return d

k=3
n=16
c=-2**(n)
while(c<2**(n)):
	# [c]=random.sample(range(-2**(n-1),2**(n-1)),1)
	c1=red(c)
	c2=c%m
	if(c1!=c2):
		print("false; Expect:{0}, get:{1}" .format(c2,c1))
		break
	c=c+1
if(c==2**n):
	print("true")

0x88840
true


1. double new reduction on Cortex-M4
<!-- const=R^(-1)*(-2^(2n) mod p) mod 2^(2n) -->
```
smlawb d1, const, c, (8<<16)
smlawt d2, const, c, (8<<16)
smlawb d1, p, d1, 0
smlawb d2, p, d2, 0
pkhbt c, d1, d2, lsr #16
```
Compared to the 8-instruction Barrett reduction, we only need 5 instructions. We can replace Barrett reduction in LBC.

1. New modular multiplication on NTT 

In [18]:
# # Test negative modulo 2^2n: input: (-2^n,0)
# def twos_comp(val, bits=32):
# 	"""compute the 2's complement of int value val"""
# 	if (val & (1 << (bits - 1))) != 0:	# if sign bit is set e.g., 8bit: 128-255
# 		val = val - (1 << bits)			# compute negative value
# 	return val							# return positive value as is

# c=-R
# while(c<0):
# 	print("c=",bin(c))
# 	r=c%r
# 	print("c%m=",bin(r))
# 	c=c+1

Test for Base mul

In [19]:
#test if (ab mod q+cd mod q)=(ab+cd) mod q
[a,b,c,d]=random.sample(range(0,m),4)
ab1=(((((a*b*minv)%R)//r)+1)*m//r)
cd1=(((((c*d*minv)%R)//r)+1)*m//r)
ab1=(ab1+cd1)%m
ab2=a*b+c*d
ab2=(((((ab2*minv)%R)//r)+1)*m//r)
print("ab1",ab1)
print("ab2",ab2)

ab1 6245
ab2 6245
