In [60]:
from sage.all import *
from fpylll import IntegerMatrix

#file generation

def gen_small(s, n):
    """
    Sample a sparse coefficient vector of length ``n`` with entries in {-1, 0, 1}.

    Positions 0 and n-1 are always set to 1. A further ``s-2`` ones and then
    ``s`` minus-ones are placed at distinct interior positions, so for s >= 2
    the result holds s ones and s minus-ones; every other entry is 0.
    Interior positions are drawn with ``ZZ.random_element(1, n-1)`` and
    redrawn while they collide with an occupied position.

    :param s: number of -1 entries (and, for s >= 2, of +1 entries)
    :param n: length of the returned vector
    :returns: list of n coefficients
    :raises ValueError: if there are fewer free interior positions than
        entries to place (the rejection loop would otherwise never terminate)
    """
    deg = n
    ones_to_place = max(s - 2, 0)
    minus_ones_to_place = max(s, 0)
    # Positions 0 and deg-1 are taken by the fixed ones, leaving deg-2 slots.
    free_slots = max(deg - 2, 0)
    if ones_to_place + minus_ones_to_place > free_slots:
        raise ValueError(
            "cannot place %s non-zero entries in %s free positions"
            % (ones_to_place + minus_ones_to_place, free_slots)
        )

    coeff_vector = deg*[0]
    coeff_vector[deg-1] = 1
    coeff_vector[0] = 1
    index_set = {0, deg-1}

    # First the remaining +1 entries, then the -1 entries (same order as before,
    # so the stream of random draws is consumed identically).
    for value, count in ((1, ones_to_place), (-1, minus_ones_to_place)):
        for _ in range(count):
            while True:
                index = ZZ.random_element(1, deg-1)
                if index not in index_set:
                    coeff_vector[index] = value
                    index_set.add(index)
                    break
    return coeff_vector



def print_ntru(q, h, variable_x, filename):
    """
    Write q and the n rows ``h * variable_x**i`` (i = 0..n-1) to ``filename``.

    The first line of the file is ``str(q)``. Each following line is the
    coefficient list of one row, printed as a Python list with the commas
    removed (e.g. ``[1 2 3]``).

    :param q: modulus, written verbatim on the first line
    :param h: element whose coefficient list has length n
    :param variable_x: element that h is multiplied by, raised to the power i
    :param filename: path of the file to (over)write
    :returns: list of the n coefficient lists that were written
    """
    n = len(list(h))
    HMat = []
    # The context manager closes the file even if a row fails to compute.
    with open(filename, 'w') as f:
        f.write(str(q)+'\n')
        for i in range(n):
            hvector = list(h * variable_x**i)
            HMat.append(hvector)
            f.write(str(hvector).replace(',', '') + '\n')
    return HMat

def all_rotations(g, variable_x, q):
    """
    Return the 2*n coefficient lists of +/- g*variable_x**i for i in range(n).

    n is the length of g's coefficient list. Row 2*i holds the coefficients
    of g*variable_x**i and row 2*i+1 their entry-wise negation (taken before
    any replacement). In both rows every entry that compares equal to q-1 is
    then replaced by -1.
    """
    n = len(list(g))
    rotations = []
    for shift in range(n):
        positive = list(g * variable_x**shift)
        negative = [-coeff for coeff in positive]
        for row in (positive, negative):
            rotations.append([-1 if coeff == q-1 else coeff for coeff in row])
    return rotations

def gen_ntru_challenge(n,q=0):
    """
    Generate an NTRU instance h = f/g and write instance and solution to disk.

    Arithmetic is done in GF(q)[x] modulo the defining polynomial of
    CyclotomicField(2*n). f and g come from ``gen_small(ceil(n/3.), n)``.
    The public rows h*x^i are written to 'ntru_n_<n>.txt' by ``print_ntru``
    and the pair (f, g) to 'ntru_n_<n>_solution.txt'. Before returning, the
    function checks that (g, c) * [[I, H], [0, q*I]] == (g, f) for the integer
    vector c = -(g*H - f)/q.

    :param n: length of f and g (the code relies on the quotient ring having
        exactly n coefficients per element)
    :param q: modulus; 0 (default) selects next_prime(55*n)
    :returns: tuple (h, q)
    :raises AssertionError: if g*H - f is not divisible by q or the basis
        check fails
    NOTE(review): the division f/g presumably raises when g is not invertible
    in the quotient ring; no retry is attempted here.
    """
    K = CyclotomicField(2*n)

    if (q==0):
        q = next_prime(55*n)

    F = GF(q)
    Fx = PolynomialRing(F, 'x')
    Fx_qou = Fx.quotient(K.polynomial(), 'x')
    variable_x = Fx_qou.gen()

    sparsity = ceil(n/3.)
    f_poly = gen_small(sparsity, n)
    g_poly = gen_small(sparsity, n)
    h = Fx_qou(f_poly)/Fx_qou(g_poly)

    filename = 'ntru_n_'+str(n)+'.txt'
    Hmat = print_ntru(q, h, variable_x, filename)
    Hmat = matrix(ZZ, Hmat)

    g_vec = vector(ZZ, g_poly)
    f_vec = vector(ZZ, f_poly)
    print(len(g_vec))

    # g*H == f (mod q), so every coordinate of g*H - f must be a multiple of q.
    qvec = g_vec*Hmat - f_vec
    assert(len(qvec) == n)
    qvec_red = [0]*int(n)
    for i in range(n):
        # Check the actual residue (the old code tested the zero-filled output list).
        assert qvec[i] % q == 0
        # Exact integer division keeps the coefficient vector over ZZ.
        qvec_red[i] = -qvec[i] // q

    # Basis [[I, H], [0, q*I]] of the NTRU lattice.
    B = matrix(ZZ, 2*n, 2*n)
    for i in range(n):
        B[i,i] = 1
        for j in range(n):
            B[i,j+n] = Hmat[i, j]
        B[i+n, i+n] = q

    # (g, c) * B must reproduce the short vector (g, f).
    f_check = vector(ZZ, list(g_poly) + qvec_red)*B
    assert(f_check[:n]==g_vec)
    assert(f_check[n:]==f_vec)

    filename = 'ntru_n_'+str(n)+'_solution.txt'
    with open(filename, 'w') as f:
        f.write(str(list(f_poly)).replace(',','')+'\n')
        f.write(str(list(g_poly)).replace(',',''))

    return h, q
	
def ntru_plain_hybrid_basis(A, g, q, nsamples):
    """
    Construct and LLL-reduce an NTRU lattice basis, splitting off g rows of A.

    With n the number of columns of A and ell = n - g, the matrix B has
    ell + nsamples rows and columns and the block form

        [ I_ell   A[0:ell, 0:nsamples] ]
        [ 0       q * I_nsamples       ]

    Bg is a g x n matrix holding rows ell..n-1 of A.

    :param A: matrix indexable as A[i, j]; ``ncols`` may be an attribute
        (fpylll) or a method (Sage)
    :param g: number of trailing rows of A copied to Bg instead of B
    :param q: modulus placed on the lower-right diagonal
    :param nsamples: number of columns of A used in B
    :returns: tuple (B, Bg), B being the result of ``LLL.reduction``
    """
    # The defining cell only imports IntegerMatrix, so fetch LLL locally.
    from fpylll import LLL

    n = A.ncols() if callable(A.ncols) else A.ncols
    ell = n - g
    print(n, ell, nsamples)
    B = IntegerMatrix((nsamples+ell), (ell+nsamples))
    Bg = IntegerMatrix(g, n)

    for i in range(ell):
        B[i,i] = 1
        for j in range(nsamples):
            # The A block starts right after the ell identity columns; the
            # previous offset n = ell + g ran past the last column when g > 0.
            B[i,j+ell] = A[i, j]
    for i in range(nsamples):
        B[i+ell, i+ell] = q

    for i in range(g):
        for j in range(n):
            Bg[i,j] = A[i+ell, j]

    B = LLL.reduction(B)
    return B, Bg

# Challenge parameters: ring dimension and the first prime above 15.6 * 2^32.
n = 256
q = next_prime(int(2^32*15.6))

# Writes the instance and solution files as a side effect.
h, q = gen_ntru_challenge(n, q)
print(h,q)

256
16689515869*x^255 + 44233824686*x^254 + 3261045813*x^253 + 25562826804*x^252 + 11134957798*x^251 + 60754311658*x^250 + 26433963256*x^249 + 29843829995*x^248 + 40178814242*x^247 + 28685695762*x^246 + 66161877269*x^245 + 5939189066*x^244 + 56056646893*x^243 + 61069301517*x^242 + 16031106285*x^241 + 2241381340*x^240 + 57836680140*x^239 + 2654235351*x^238 + 55871924917*x^237 + 41442017973*x^236 + 65427074089*x^235 + 17457525347*x^234 + 52181962580*x^233 + 5313062873*x^232 + 59677077313*x^231 + 40953108393*x^230 + 49172868554*x^229 + 52544305682*x^228 + 37054617141*x^227 + 57949001455*x^226 + 34521729385*x^225 + 46478930697*x^224 + 58986750107*x^223 + 64838048985*x^222 + 36220214766*x^221 + 50563243386*x^220 + 60990357399*x^219 + 63113258725*x^218 + 53787125645*x^217 + 22189309820*x^216 + 62562307395*x^215 + 11368063489*x^214 + 24098781993*x^213 + 63558935792*x^212 + 51739566392*x^211 + 13916846681*x^210 + 27610322831*x^209 + 60402653785*x^208 + 46504122935*x^207 + 30422010917*x^206 + 4

In [61]:
from fpylll import *
from fpylll.algorithms.bkz2 import BKZReduction as BKZ2
from sage.modules.free_module_integer import IntegerLattice

def H_delta(x,precision=144):
    """
    Heuristic for the beta-root Hermite factor at block size ``x``.

    Evaluates ( x/(2*pi*e) * (pi*x)^(1/x) )^(1/(2*(x-1))) numerically.

    :param x: block size (x = 1 divides by zero)
    :param precision: number of bits passed to ``.n()``
    :returns: the value approximated to ``precision`` bits

    The approximation is applied to the whole power. Previously ``.n`` bound
    only to the exponent, because a method call binds tighter than the power
    operator, so the result stayed symbolic.
    """
    base = x/(2*pi*e)*(pi*x)^(1/x)
    exponent = 1/(2*(x-1))
    return (base^exponent).n(precision)

def left_side(q,n,k):
    """
    Evaluate 1/(k*(3k-1)) * ( k/2*log2(k/(2*pi*e)) + k*log2(q) - n*log2(n/2) ).

    find_param compares this value with log2(H_delta(beta)).
    """
    scale = 1/(k*(3*k-1))
    bracket = k/2*log(k/2/pi/e,2) + k*log(q,2) - n*log(n/2,2)
    return scale*bracket

def find_param(n, q, old=False,logs=False):
    """
    Search for a pair (k, beta) with left_side(q, n, k) >= log2(H_delta(beta)).

    The outer loop bisects beta on the interval [35, 2*n]: when some k
    satisfies the inequality for the current beta the upper bound moves down
    to beta, otherwise the lower bound moves up to beta+1.

    For each beta, k is searched in one of two ways:
      * old=False: bisection on k over [max(beta, n/2), 2*n]. A hit records
        (k, beta) and continues to the left; a miss moves towards the
        neighbour (k-1 or k+1) whose left_side value is larger.
      * old=True: linear scan of k from max(beta, min_k) up to 2*n in steps
        of 10, stopping at the first hit. The hit is printed unconditionally.

    :param n: dimension parameter passed through to left_side
    :param q: modulus passed through to left_side
    :param old: select the linear scan instead of the inner bisection
    :param logs: print search progress
    :returns: (min_k, min_beta); when no hit is ever found these keep their
        initial values int(n/8) and Infinity
    NOTE(review): presumably beta is a BKZ block size and k a sublattice
    half-dimension -- confirm against the callers.
    """
    min_beta=Infinity
    min_k = int(n/8)
    L=35; R=2*n

    while abs(L-R)>1:
        beta=int(L+(R-L)/2)

        if logs:
            print("now:",L,beta,R)

        delta_beta =log(H_delta(beta),2)

        flag_k_found=False
        if not old:
            l=max([beta,1/2*n]); r=2*n

            flag_k_found=False

            while abs(l-r)>1:
                k=int(l+(r-l)/2)
                tmp=left_side(q,n,k)

                # if we found a new minimal k, the search continues no further to the right
                if tmp >= delta_beta:
                    flag_k_found=True
                    min_beta = beta
                    min_k = k

                    if logs:
                        print(min_k, min_beta)

                    r=k
                # otherwise the search moves to the side where tmp-delta_beta is larger
                else:
                    left_=left_side(q,n,k-1)-delta_beta
                    right_=left_side(q,n,k+1)-delta_beta
                    if left_-right_<0:
                        l=k+1
                    else:
                        r=k

        else:
            for k in range(max(beta,min_k), 2*n, 10):
                tmp=left_side(q,n,k)
                if tmp >=delta_beta:
                    min_beta = beta
                    min_k = k
                    print( min_k, min_beta)
                    flag_k_found=True
                    break
        # if a solution exists for this beta, look for a new one no further to the right
        if flag_k_found:
            R=beta
        # otherwise look for a solution further to the right
        else:
            L=beta+1
    return min_k, min_beta

def CenteredMod(a, q):
    """
    Reduce ``a`` modulo ``q`` and shift the result down by q when it exceeds floor(q/2).

    ``a`` must provide a ``.mod(q)`` method.
    """
    residue = a.mod(q)
    return residue if residue <= floor(q/2) else residue - q
    
def read_ntru_from_file(filename):
    """
    Read a modulus and a square integer matrix from a text file.

    Expected layout (as produced by print_ntru): the first line holds q and
    every following non-blank line holds one row as whitespace-separated
    integers, optionally wrapped in square brackets and/or comma-separated.

    :param filename: path of the file to read
    :returns: tuple (n, q, H_) where n is the number of rows, q the modulus
        as a ZZ element and H_ an n x n matrix over ZZ with every entry
        reduced modulo q
    :raises ValueError: if the file is empty or some row does not have n entries
    """
    # Close the handle deterministically and ignore blank (e.g. trailing) lines.
    with open(filename, "r") as handle:
        lines = [line.strip() for line in handle]
    lines = [line for line in lines if line]
    if not lines:
        raise ValueError("%s contains no modulus line" % filename)

    q = ZZ(lines[0])

    # Parse the rows explicitly; the previous eval() executed file contents
    # as code and mis-parsed a file with a single row.
    rows = [
        [ZZ(token) for token in line.strip('[]').replace(',', ' ').split()]
        for line in lines[1:]
    ]

    n = len(rows)
    if any(len(row) != n for row in rows):
        raise ValueError("%s does not contain a square matrix" % filename)

    Zq = ZZ.quo(q)
    # Reduce through Zq, then lift back so the matrix lives over ZZ.
    H_ = matrix(ZZ, n, n, [[ZZ(Zq(value)) for value in row] for row in rows])

    return n, q, H_

# Load the instance file for the current n.
path = 'ntru_n_' + str(n) + '.txt'
n, q, H = read_ntru_from_file(path)
print("n=", n, "q=",q)

# Parameter estimate for the reduction.
k, beta = find_param(n, q)
print("k=", k, "beta=",beta)

import time

# Time the conversion of H to an fpylll IntegerMatrix.
t = time.perf_counter()
H = IntegerMatrix.from_matrix(H)
print(time.perf_counter()-t, "sec taken")

# Centre every entry of H with CenteredMod and rebuild it as a Sage matrix.
C_Id = H
C_Id = Matrix(
    ZZ,
    [[CenteredMod(entry, q) for entry in C_Id[row]] for row in range(n)],
)

print("Building basis of NTRU lattice...")
t = time.perf_counter()
# Block layout: [[q*I, 0], [C_Id, I]].
B = block_matrix([
    [q*identity_matrix(n), zero_matrix(n)],
    [C_Id, identity_matrix(n)],
])
B = IntegerMatrix.from_matrix(B)
print(time.perf_counter()-t, "sec taken")

n= 256 q= 67001489827
k= 129 beta= 36
13.09608709999975 sec taken
Building basis of NTRU lattice...
53.800578100000166 sec taken


In [None]:
### Preprocessing for fpylll ###
# Hand-picked (k, beta), replacing whatever values these names held before.
k = 127
beta = 38

print("Building submatrix of size 2*k =", 2*k, "...")
TIME = time.time()
# Central 2k x 2k block of B (author's note: ~30s).
Bk = B.submatrix(range(n-k, n+k), range(n-k, n+k))
FPLLL.set_precision(180)
GSO_Bk = GSO.Mat(Bk, float_type='mpfr')
print("Preprocessing...")
Bk_BKZ = BKZ2(GSO_Bk)
print("Done. time: ", (time.time() - TIME))

### BKZ phase ###
flags = (
    BKZ.AUTO_ABORT
    | BKZ.GH_BND
    | BKZ.VERBOSE
    | BKZ.MAX_LOOPS
)
par = BKZ.Param(
    block_size=beta,
    strategies=BKZ.DEFAULT_STRATEGY,
    flags=flags,
    max_loops=10,
)
print("BKZ reduction with beta =", beta, "...")
TIME = time.time()
# The actual BKZ run; updates Bk in place (author's note: ~15 hours).
# If it fails with an infinite loop in babai, raise the precision via
# FPLLL.set_precision() and rerun without restarting.
DONE = Bk_BKZ(par)
print("Done. time: ", (time.time() - TIME))

### Success check ###
# True when each of the first k rows of Bk has norm below 2^24.
print( all(Bk[i].norm() < 2^24 for i in range(k)) )

Building submatrix of size 2*k = 254 ...


In [55]:
def write_basis(B, q, filename):
    """
    Write q followed by the rows of B to ``filename``.

    The first line is ``str(q)``; each row of B is written on its own line as
    ``str(row)`` with all commas removed.

    :param B: iterable of rows
    :param q: value written on the first line
    :param filename: path of the file to (over)write
    """
    # The context manager closes the file even if formatting a row fails.
    with open(filename, 'w') as f:
        f.write(str(q)+'\n')
        for row in B:
            f.write(str(row).replace(',', '') + '\n')

# Persist the current contents of Bk, with q on the first line, to BKZ_n<n>.txt.
write_basis(Bk,q,'BKZ_n'+str(n)+'.txt')

In [5]:
# Disabled cell: the code below sits inside a string literal, so none of it
# runs; the notebook only echoes the text.
"""H=H_delta(beta).n()
for t in range(1, Bk.ncols):
    tmp=sqrt(Bk[t].norm()/Bk[t-1].norm())
    print("Hermit: ", tmp , "ratio: ", tmp/H)"""

'H=H_delta(beta).n()\nfor t in range(1, Bk.ncols):\n    tmp=sqrt(Bk[t].norm()/Bk[t-1].norm())\n    print("Hermit: ", tmp , "ratio: ", tmp/H)'

In [6]:
# Disabled cell: string literal only, nothing is evaluated.
"""left=H^(-k*(3*k-1)) * Bk[0].norm()^k
left"""

'left=H^(-k*(3*k-1)) * Bk[0].norm()^k\nleft'

In [7]:
# Disabled cell: string literal only, nothing is evaluated.
"""right=product([Bk[i].norm() for i in range(k)])
right"""

'right=product([Bk[i].norm() for i in range(k)])\nright'

In [41]:
#print(GSO_Bk.get_r(0,0))

3174.0


In [56]:
# GSO: squared norm of the orthogonalisation (translated from the original note)
# NOTE(review): GSO_Bk_new is created here but never read; the loop below
# queries the older GSO_Bk object -- confirm which one was intended.
GSO_Bk_new=GSO.Mat(Bk, float_type='mpfr')

# Rebinds H, which earlier held the matrix read from file, to the numeric
# value of H_delta at the current beta.
H=H_delta(beta).n()
H_exeremental=1

# tmp is sqrt(r(t-1,t-1) / r(t,t)) for consecutive diagonal GSO entries;
# the printed "ratio" is tmp^2 divided by H.
for t in range(1, Bk.ncols):
    tmp=sqrt(GSO_Bk.get_r(t-1,t-1)/GSO_Bk.get_r(t,t))
    H_exeremental*=tmp
    print("Hermit: ", tmp , "ratio: ", tmp^2/H)
# Geometric mean of the Bk.ncols-1 ratios accumulated above.
H_exeremental=H_exeremental^(1/(Bk.ncols-1))

Hermit:  1.126937791581575 ratio:  1.25499543096611
Hermit:  1.0184072304901113 ratio:  1.02490876359841
Hermit:  1.0002898418796093 ratio:  0.988767027405618
Hermit:  0.8728184237773522 ratio:  0.752818127871812
Hermit:  1.1365947519386599 ratio:  1.27659621177334
Hermit:  1.0195349280801231 ratio:  1.02717981395682
Hermit:  1.0251685270285387 ratio:  1.03856286061576
Hermit:  0.9866375709089247 ratio:  0.961961203536871
Hermit:  1.0035904615626652 ratio:  0.995302989431407
Hermit:  0.982292250035598 ratio:  0.953506578584927
Hermit:  0.8860261151729321 ratio:  0.775774149541277
Hermit:  1.0978604431689518 ratio:  1.19106793549991
Hermit:  1.0322914209639542 ratio:  1.05304491338598
Hermit:  1.0461940526360685 ratio:  1.08160018326502
Hermit:  0.9991130108803101 ratio:  0.986441846939943
Hermit:  1.0275748338444164 ratio:  1.04344407537634
Hermit:  1.0000055190917014 ratio:  0.988205012213731
Hermit:  1.004017086045371 ratio:  0.996149372277789
Hermit:  1.0231850309176846 ratio:  1.03

In [57]:
# Earlier variant of the formula, kept disabled:
#left=( log(H,2)*(-k*(3*k-1))+log(Bk[0].norm(),2)*k).n()
# NOTE(review): the second argument of log is the base, so log(H,1/2) equals
# -log2(H); the expression printed below uses log(H_exeremental,2) and no
# division by 2 in the first term -- confirm the two are meant to differ.
left=( log(H,1/2)*(-k*(3*k-1)/2)+log(Bk[0].norm(),2)*k/2).n()
print(left)

# Half the sum of log2 r(i,i) over the GSO diagonal entries k .. 2k-1.
right=sum([log(GSO_Bk.get_r(i+k,i+k),2)/2 for i in range(k)]).n()
left_=( log(H_exeremental,2)*(-k*(3*k-1))+log(Bk[0].norm(),2)*k/2).n()
print(left_, right)

649.350343720210
1415.75259619613 1348.39555229282


In [58]:
# Show the measured geometric-mean ratio next to the reciprocal of H.
H_exeremental, 1/H

(0.9826664508988533, 0.988194104315869)

In [59]:
# Half the sum of log2 r(i,i) over GSO diagonal entries 0 .. k-1, minus k * log2 of the norm of Bk[0].
sum([log(GSO_Bk.get_r(i,i),2)/2 for i in range(k)]).n()-log(Bk[0].norm(),2).n()*k

-62.6449351348290