In [19]:
from sage.all import *
from fpylll import IntegerMatrix

#file generation

def gen_small(s, n):
	"""
	s+1 entries of 1s and s entries of -1s
	"""
	deg = n
	coeff_vector = deg*[0]
	coeff_vector[deg-1] = 1
	coeff_vector[0] = 1
	index_set = set({0,deg-1})
	for i in range(s-2):
	# add 1's
		while True:
			index1 = ZZ.random_element(1,deg-1)
			if not index1 in index_set:
				coeff_vector[index1] = 1
				index_set = index_set.union({index1})
				break
	# add -1's
	for i in range(s):
		while True:
			index2 = ZZ.random_element(1,deg-1)
			if not index2 in index_set:
				coeff_vector[index2] = -1
				index_set = index_set.union({index2})
				break
	return coeff_vector



def print_ntru(q, h, variable_x, filename):
	"""
	Write the modulus and the rotation matrix of h to a text file.

	The file starts with a line containing q. It is followed by one line per
	row, where row i is the coefficient list of h * variable_x**i printed as
	a Python list with the commas removed (e.g. "[1 2 3]").

	Parameters:
		q          -- modulus, written on the first line
		h          -- element whose coefficient list defines the rows
		variable_x -- element that h is multiplied by to produce each rotation
		filename   -- path of the file to (over)write

	Returns:
		list of the rows that were written (each a list of coefficients)
	"""
	n = len(list(h))
	HMat = []
	# The context manager closes the file even if a row fails to serialise.
	with open(filename, 'w') as f:
		f.write(str(q)+'\n')
		for i in range(n):
			hvector = list(h* variable_x**i)
			HMat.append(hvector)
			f.write( str(hvector).replace(',','') +'\n')

	return HMat

def all_rotations(g, variable_x, q):
	"""
	List every rotation of g together with its negation.

	For each shift i the coefficient list of g * variable_x**i is stored at
	position 2*i and its entry-wise negation at position 2*i+1. In both
	lists any entry equal to q-1 is replaced by -1.

	Returns a list with 2*len(list(g)) coefficient lists.
	"""
	def recentre(coeffs):
		# Map the representative q-1 to -1; leave every other entry untouched.
		return [-1 if c == q-1 else c for c in coeffs]

	result = []
	for shift in range(len(list(g))):
		rotated = list(g*variable_x**shift)
		# Negate before recentring, exactly as the pair is built from the raw rotation.
		negated = [-c for c in rotated]
		result.append(recentre(rotated))
		result.append(recentre(negated))
	return result

def gen_ntru_challenge(n,q=0):
	"""
	Generate a random NTRU instance and write it, and its solution, to disk.

	Two sparse ternary polynomials f and g are sampled with gen_small
	(sparsity ceil(n/3)) and the public key h = f/g is computed in
	GF(q)[x] modulo the defining polynomial of CyclotomicField(2*n).

	Side effects:
		* 'ntru_n_<n>.txt' receives q and the rotation matrix of h
		  (written by print_ntru);
		* 'ntru_n_<n>_solution.txt' receives f on the first line and g on
		  the second, as space-separated lists;
		* the length of g is printed.

	Parameters:
		n -- ring dimension. NOTE(review): the assertions below assume the
		     quotient ring has exactly n coefficients, which presumably
		     requires n to be a power of two -- confirm.
		q -- modulus; 0 (the default) selects next_prime(55*n)

	Returns:
		(h, q)

	Raises:
		AssertionError if the self-checks on the constructed lattice basis fail.
	"""

	K = CyclotomicField(2*n)

	if (q==0):
		q = next_prime(55*n)


	F = GF(q)
	Fx = PolynomialRing(F, 'x')
	Fx_qou = Fx.quotient(K.polynomial(), 'x')
	variable_x = Fx_qou.gen()

	sparsity = ceil(n/3.)
	f_poly = (gen_small(sparsity, n))
	g_poly = (gen_small(sparsity, n))
	h = Fx_qou(f_poly)/Fx_qou(g_poly)

	filename = 'ntru_n_'+str(n)+'.txt'
	Hmat = print_ntru(q, h, variable_x, filename)
	Hmat = matrix(ZZ,[hrow for hrow in Hmat])

	print(len(vector(ZZ,g_poly)))

	# g * Hmat equals f modulo q, so every entry of the difference must be
	# a multiple of q; qvec_red holds the corresponding (negated) quotients.
	qvec = vector(ZZ,g_poly)*Hmat - vector(f_poly)
	assert(len(qvec) == n)
	qvec_red = [0]*int(n)
	for i in range(n):
		# Check the actual difference (the old code tested the still-zero
		# qvec_red[i], which made this assertion vacuous).
		assert qvec[i] % q == 0
		qvec_red[i]  = -qvec[i] / q

	# Basis [[I, Hmat], [0, q*I]] of the NTRU lattice.
	B = matrix(ZZ, 2*n, 2*n)

	for i in range(n):
		B[i,i] = 1
		for j in range(n):
			B[i,j+n] = Hmat[i, j]
		B[i+n, i+n] = q

	# Sanity check: the coefficient vector (g | qvec_red) maps to (g | f).
	f_check = vector(list(g_poly) + list(qvec_red))*B
	assert(f_check[:n]==vector(g_poly))
	assert(f_check[n:]==vector(f_poly))

	filename = 'ntru_n_'+str(n)+'_solution.txt'
	with open(filename, 'w') as f:
		f.write(str(list(f_poly)).replace(',','')+'\n')
		f.write(str(list(g_poly)).replace(',',''))

	return h, q
	
def ntru_plain_hybrid_basis(A, g, q, nsamples):
	"""
	Construct ntru lattice basis

	With n the number of columns of A and ell = n - g, two matrices are built:

		B  -- an (ell+nsamples) x (ell+nsamples) matrix laid out as
		      [[ I_ell, A[0:ell, 0:nsamples] ],
		       [   0  , q * I_nsamples       ]],
		      which is then LLL-reduced;
		Bg -- a g x n matrix holding the last g rows of A.

	Parameters:
		A        -- matrix supporting A[i, j]; `ncols` may be an attribute
		            (fpylll IntegerMatrix) or a method (Sage matrix)
		g        -- number of trailing rows of A split off into Bg
		q        -- modulus placed on the diagonal of the lower-right block
		nsamples -- number of columns of A used in B

	Returns:
		(B, Bg)
	"""
	# The defining cell only imports IntegerMatrix from fpylll, so bring LLL
	# into scope here instead of relying on a later cell having been run.
	from fpylll import LLL

	n = A.ncols() if callable(A.ncols) else A.ncols
	ell = n - g
	print(n, ell, nsamples)
	B = IntegerMatrix((nsamples+ell), (ell+nsamples))
	Bg = IntegerMatrix(g, n)


	for i in range(ell):
		B[i,i] = 1
		for j in range(nsamples):
			# The right-hand block starts at column ell. The previous offset
			# j+n ran past the last column whenever g > 0 (n == ell + g).
			B[i,j+ell] = A[i, j]
	for i in range(nsamples):
		B[i+ell, i+ell] = q

	for i in range(g):
		for j in range(n):
			Bg[i,j] = A[i+ell, j]

	B = LLL.reduction(B)
	return B, Bg

# Challenge parameters: dimension n and the first prime above int(5.6 * 2^20)
# (`^` is exponentiation here because the cell goes through the Sage preparser).
n=128; q=next_prime(int(2^20*5.6))
# Generates the instance, writes the ntru_n_<n>*.txt files and returns (h, q).
h,q = gen_ntru_challenge(n,q)
print(h,q)

128
5548353*x^127 + 2884362*x^126 + 2303708*x^125 + 2897051*x^124 + 4491212*x^123 + 2473130*x^122 + 4961550*x^121 + 4501500*x^120 + 2238198*x^119 + 4579360*x^118 + 5630237*x^117 + 4185467*x^116 + 4884891*x^115 + 5701388*x^114 + 1377364*x^113 + 5216988*x^112 + 3142762*x^111 + 4244963*x^110 + 2885939*x^109 + 3793312*x^108 + 5765537*x^107 + 3257657*x^106 + 3261449*x^105 + 5448123*x^104 + 1671897*x^103 + 1365807*x^102 + 840350*x^101 + 2544333*x^100 + 252806*x^99 + 5611816*x^98 + 762126*x^97 + 3831895*x^96 + 2587222*x^95 + 1166572*x^94 + 4620179*x^93 + 2839810*x^92 + 3124008*x^91 + 2715307*x^90 + 4568662*x^89 + 3647587*x^88 + 452404*x^87 + 4313497*x^86 + 5775221*x^85 + 2125263*x^84 + 2339096*x^83 + 5382382*x^82 + 892100*x^81 + 3755459*x^80 + 4014838*x^79 + 5378389*x^78 + 2097353*x^77 + 1565925*x^76 + 4814216*x^75 + 3871545*x^74 + 4072907*x^73 + 5705776*x^72 + 2219330*x^71 + 2584710*x^70 + 5694114*x^69 + 5169218*x^68 + 5732411*x^67 + 202021*x^66 + 4061449*x^65 + 951578*x^64 + 2485867*x^63 + 

In [20]:
from fpylll import *
from fpylll.algorithms.bkz2 import BKZReduction as BKZ2
from sage.modules.free_module_integer import IntegerLattice

def H_delta(x,precision=144):
	"""
	Heuristic root Hermite factor for block size x:

		( x/(2*pi*e) * (pi*x)^(1/x) ) ^ (1/(2*(x-1)))

	The cell used to define this function twice; the second definition
	silently replaced the first, so only that signature is kept.

	Parameters:
		x         -- block size (x = 1 divides by zero)
		precision -- number of bits passed to .n() for the numerical result

	Returns:
		the numerical value of the expression above
	"""
	base = x/(2*pi*e)*(pi*x)**(1/x)
	exponent = 1/(2*(x-1))
	# Evaluate the whole power numerically. Previously `.n(precision)` was
	# attached to the exponent only, because a method call binds tighter
	# than the power operator.
	return (base**exponent).n(precision)

def left_side(q,n,k):
	"""
	Evaluate
		1/(k*(3k-1)) * ( k/2*log2(k/(2*pi*e)) + k*log2(q) - n*log2(n/2) ),
	the quantity that find_param compares against log2(H_delta(beta)).
	"""
	scale = 1/(k*(3*k-1))
	volume_term = k/2*log(k/2/pi/e,2)
	modulus_term = k*log(q,2)
	secret_term = n*log(n/2,2)
	return scale*(volume_term+modulus_term-secret_term)

def find_param(n, q, old=False,logs=False):
    """
    Search for a small block size beta and a matching k.

    The outer loop bisects beta over [35, 2*n]. For each candidate beta it
    looks for a k with left_side(q, n, k) >= log2(H_delta(beta)):

      * old=False: bisection over k in [max(beta, n/2), 2*n]; when the probe
        fails, the search moves towards the neighbour (k-1 or k+1) where
        left_side is larger.
      * old=True: linear scan over k from max(beta, min_k) to 2*n in steps
        of 10, stopping (and printing) at the first k that satisfies it.

    If some k works for beta, smaller values of beta are tried next,
    otherwise larger ones.

    Parameters:
        n, q -- forwarded to left_side
        old  -- select the linear scan instead of the bisection over k
        logs -- print the search intervals and every improvement

    Returns:
        (min_k, min_beta), the last pair that satisfied the condition. If no
        pair was found, min_beta is still Infinity and min_k is int(n/8).
    """
    min_beta=Infinity
    min_k = int(n/8)
    L=35; R=2*n

    while abs(L-R)>1:
        beta=int(L+(R-L)/2)

        if logs:
            print("now:",L,beta,R)

        # Threshold that left_side has to reach for this beta.
        delta_beta =log(H_delta(beta),2)

        flag_k_found=False
        if not old:
            l=max([beta,1/2*n]); r=2*n

            flag_k_found=False

            while abs(l-r)>1:
                k=int(l+(r-l)/2)
                tmp=left_side(q,n,k)

                # found a new smallest k: keep searching, but no further to the right
                if tmp >= delta_beta:
                    flag_k_found=True
                    min_beta = beta
                    min_k = k

                    if logs:
                        print(min_k, min_beta)

                    r=k
                # otherwise continue on the side where tmp - delta_beta is larger
                else:
                    left_=left_side(q,n,k-1)-delta_beta
                    right_=left_side(q,n,k+1)-delta_beta
                    if left_-right_<0:
                        l=k+1
                    else:
                        r=k

        else:
            for k in range(max(beta,min_k), 2*n, 10):
                tmp=left_side(q,n,k)
                if tmp >=delta_beta:
                    min_beta = beta
                    min_k = k
                    print( min_k, min_beta)
                    flag_k_found=True
                    break
        # a solution exists for this beta: look for the next one no further to the right
        if flag_k_found:
            R=beta
        # otherwise look for a solution further to the right
        else:
            L=beta+1
    return min_k, min_beta

def CenteredMod(a, q):
    """
    Reduce a modulo q and return the representative that is at most
    floor(q/2): residues above that bound are shifted down by q.

    a must provide a .mod(q) method.
    """
    residue = a.mod(q)
    upper = floor(q/2)
    return residue - q if residue > upper else residue
    
def read_ntru_from_file(filename):
    """
    Read an NTRU instance in the text format produced by print_ntru.

    The first line holds the modulus q; every following non-empty line holds
    one matrix row such as "[1 2 3]".

    Parameters:
        filename -- path of the file to read

    Returns:
        (n, q, H_) where n is the number of rows read, q the modulus and H_
        an n x n integer matrix whose entries are the file's values reduced
        modulo q.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(filename, "r") as fh:
        data = fh.readlines()
    q = ZZ(data[0])

    # Rows are parsed explicitly. The previous implementation eval()-ed the
    # file contents, which executes arbitrary code from the file and also
    # mis-parsed a file with a single row (eval then returned a flat list).
    H = []
    for line in data[1:]:
        tokens = line.replace('[', ' ').replace(']', ' ').replace(',', ' ').split()
        if tokens:
            H.append([ZZ(tok) for tok in tokens])

    n=len(H)
    H_=matrix(n)

    Zq = ZZ.quo(q)

    # Going through Zq stores each entry reduced modulo q.
    for t0 in range(0, n):
        for t1 in range(0, n):
            H_[t0,t1]=Zq(H[t0][t1])

    return n, q, H_

# Load the instance from disk. The file name is derived from `n`, which must
# already be defined in the kernel before this cell runs.
path = 'ntru_n_'+str(n)+'.txt'
n, q, H = read_ntru_from_file(path)
print("n=", n, "q=",q)

# Estimate the sub-lattice parameter k and the BKZ block size beta.
k, beta = find_param(n,q)


print("k=", k, "beta=",beta)

import time
t=time.perf_counter()

# NOTE(review): H is converted to an fpylll IntegerMatrix here and rebuilt as
# a Sage Matrix two statements below; the round trip looks redundant -- confirm.
H = IntegerMatrix.from_matrix(H)

print(time.perf_counter()-t, "sec taken")
C_Id = H
# Replace every entry by the value CenteredMod returns for it modulo q.
C_Id = Matrix(ZZ, [ [ CenteredMod(e, q) for e in C_Id[i] ] for i in range(n) ] )

print("Building basis of NTRU lattice...")
t=time.perf_counter()
# 2n x 2n basis laid out as [[q*I, 0], [C_Id, I]].
B = block_matrix([ [ q*identity_matrix(n), zero_matrix(n) ], [ C_Id, identity_matrix(n) ] ] )
B = IntegerMatrix.from_matrix(B)
print(time.perf_counter()-t, "sec taken")

n= 128 q= 5872033
k= 65 beta= 36
3.3385662999999113 sec taken
Building basis of NTRU lattice...
13.353306099999827 sec taken


In [None]:
### PREPROC PHASE FOR FPYLLL ###
# Hard-coded parameters; they replace whatever k and beta currently hold in the kernel.
k=120; beta=48

print("Building submatrix of size 2*k =", 2*k, "...")
TIME = time.time()
# Central block of B: rows and columns n-k .. n+k-1.
Bk = B.submatrix(range(n-k,n+k), range(n-k,n+k)) # ~30s
FPLLL.set_precision(180) 
# Gram-Schmidt object over Bk using multi-precision (mpfr) floats.
GSO_Bk = GSO.Mat(Bk, float_type='mpfr')
print("Preprocessing...")
Bk_BKZ = BKZ2(GSO_Bk) 
print("Done. time: ", (time.time() - TIME))   

### BKZ PHASE ###
flags = BKZ.AUTO_ABORT|BKZ.GH_BND|BKZ.VERBOSE#|BKZ.MAX_LOOPS
#beta = 9
par = BKZ.Param(block_size=beta, strategies=BKZ.DEFAULT_STRATEGY, flags=flags)
print("BKZ reduction with beta =", beta, "...")
TIME = time.time()
DONE = Bk_BKZ(par) #actual BKZ algorithm; updates Bk in place; ~15 hours
# if it fails because infinite loop in babai, set higher precision in FPLLL.set_precision() and rerun without restarting.
print("Done. time: ", (time.time() - TIME))

### WE WON ###
# Success criterion used here: each of the first k rows of Bk has norm below 2^24.
print( all(Bk[i].norm() < 2^24 for i in range(k)) )

Building submatrix of size 2*k = 240 ...


In [22]:
def write_basis(B, q, filename):
    """
    Write q and the rows of B to a text file.

    The first line holds q; each following line holds str(row) with the
    commas removed.

    Parameters:
        B        -- iterable of rows
        q        -- value written on the first line
        filename -- path of the file to (over)write
    """
    # The context manager closes the file even if a row fails to serialise.
    with open(filename, 'w') as f:
        f.write(str(q)+'\n')
        for t in B:
            f.write(str(t).replace(',','') +'\n')

# Save the current contents of Bk (with q on the first line) to BKZ_n<n>.txt.
write_basis(Bk,q,'BKZ_n'+str(n)+'.txt')

In [36]:
# NOTE: this rebinds H (previously the public-key matrix) to the numerical
# value of H_delta(beta).
H=H_delta(beta).n()
# Square root of the ratio between the norms of consecutive rows of Bk,
# printed next to its quotient by H. These are norms of the basis rows
# themselves, not of their Gram-Schmidt vectors.
for t in range(1, Bk.ncols):
    tmp=sqrt(Bk[t].norm()/Bk[t-1].norm())
    print("Hermit: ", tmp , "ratio: ", tmp/H)

Hermit:  1.0108492551522574 ratio:  0.998332694940375
Hermit:  1.0059811278930328 ratio:  0.993524845915169
Hermit:  1.0474821503139071 ratio:  1.03451199344981
Hermit:  0.9798832060483738 ratio:  0.967750074340947
Hermit:  0.9885216932386113 ratio:  0.976281597862264
Hermit:  1.0971499617506488 ratio:  1.08356480700307
Hermit:  0.9429164278515607 ratio:  0.931241027010315
Hermit:  0.951921630587547 ratio:  0.940134725323965
Hermit:  1.0535530971069913 ratio:  1.04050776843000
Hermit:  1.0670304529166326 ratio:  1.05381824462368
Hermit:  0.9585312497308951 ratio:  0.946662502693613
Hermit:  1.032675905030071 ratio:  1.01988908238684
Hermit:  0.9611115368424171 ratio:  0.949210840116464
Hermit:  1.0357665532368052 ratio:  1.02294146150036
Hermit:  1.0014690799115078 ratio:  0.989068667114881
Hermit:  1.048214585983902 ratio:  1.03523535993851
Hermit:  0.9422441403293548 ratio:  0.930577063901674
Hermit:  1.0603512026038187 ratio:  1.04722169827319
Hermit:  0.9631306889690524 ratio:  0.9

In [37]:
# H^(-k*(3k-1)) * ||Bk[0]||^k, using whatever H, k and Bk currently hold in the kernel.
left=H^(-k*(3*k-1)) * Bk[0].norm()^k
left

9.68676861254463e-24

In [38]:
# Product of the norms of the first k rows of Bk.
right=product([Bk[i].norm() for i in range(k)])
right

9.697855910196352e+224

In [41]:
# First diagonal GSO coefficient r_00 of Bk.
print(GSO_Bk.get_r(0,0))

3174.0


In [55]:
# GSO: get_r(t,t) is the squared norm of the t-th orthogonalised vector

H=H_delta(beta).n()
# Square root of the ratio between consecutive diagonal GSO coefficients,
# printed next to its quotient by H.
for t in range(1, Bk.ncols):
    tmp=sqrt(GSO_Bk.get_r(t,t)/GSO_Bk.get_r(t-1,t-1))
    print("Hermit: ", tmp , "ratio: ", tmp/H)

Hermit:  1.0132119410301212 ratio:  1.00066612551642
Hermit:  0.9398698051468325 ratio:  0.928232128265250
Hermit:  0.9882019107354859 ratio:  0.975965774977188
Hermit:  0.9389615925992886 ratio:  0.927335161407386
Hermit:  1.0140584759622078 ratio:  1.00150217846477
Hermit:  0.9850793205232008 ratio:  0.972881849371133
Hermit:  1.0072688891701274 ratio:  0.994796661845832
Hermit:  0.9661134394411576 ratio:  0.954150808045191
Hermit:  0.9953910855292505 ratio:  0.983065931810340
Hermit:  0.9574961540481439 ratio:  0.945640223795730
Hermit:  1.0277371437519007 ratio:  1.01501147394882
Hermit:  0.9685898424434988 ratio:  0.956596547674997
Hermit:  0.9956278413123016 ratio:  0.983299756030649
Hermit:  0.9544333140240092 ratio:  0.942615308537713
Hermit:  0.9667666335487715 ratio:  0.954795914158144
Hermit:  0.9660129671090244 ratio:  0.954051579783811
Hermit:  1.0614628848265304 ratio:  1.04831961539947
Hermit:  0.8997504618396927 ratio:  0.888609551586374
Hermit:  0.9865342018377384 rati

In [70]:
# Base-2 logarithm of H^(-k*(3k-1)) * ||Bk[0]||^k.
left=( log(H,2)*(-k*(3*k-1))+log(Bk[0].norm(),2)*k).n()
print(left)

# Sum of log2(r_ii)/2 over the diagonal GSO coefficients with index k .. 2k-1.
right=sum([log(GSO_Bk.get_r(i+k,i+k),2)/2 for i in range(k)]).n()
print(right)

-76.4502587968110
2259.03821715398


In [54]:
# Show sqrt(r_00) next to the norm of the first row of Bk for comparison.
sqrt(GSO_Bk.get_r(0,0)), Bk[0].norm()

(56.3382640840131, 56.3382640840131)