In [1]:
from fpylll import *
from fpylll.algorithms.bkz2 import BKZReduction as BKZ2
from sage.modules.free_module_integer import IntegerLattice

def gen_small(s, n):
	"""
	Return a length-n ternary coefficient list with s entries equal to 1 and
	s entries equal to -1 (all others 0), for s >= 2.

	Coefficients 0 and n-1 are always 1. The remaining max(s-2, 0) ones and the
	s minus-ones are placed at distinct positions drawn with
	ZZ.random_element(1, n-1), i.e. from the interior 1..n-2. For s < 2 the two
	fixed ones are still set.

	Note: an earlier docstring said "s+1 entries of 1s"; the code has always
	produced s ones (for s >= 2).

	Raises ValueError when the entries cannot fit in the interior positions
	(2*s > n for s >= 2); the rejection sampling below would otherwise loop
	forever.
	"""
	deg = n
	coeff_vector = deg*[0]
	coeff_vector[deg-1] = 1
	coeff_vector[0] = 1
	index_set = {0, deg-1}

	n_ones = max(s-2, 0)
	n_minus = max(s, 0)
	n_free = max(deg-2, 0)
	if n_ones + n_minus > n_free:
		raise ValueError("gen_small: cannot place %d ones and %d minus-ones in %d free positions" % (n_ones, n_minus, n_free))

	def _place(value):
		# Rejection-sample an unused interior index and store `value` there.
		while True:
			idx = ZZ.random_element(1, deg-1)
			if idx not in index_set:
				coeff_vector[idx] = value
				index_set.add(idx)
				return

	# add 1's
	for _ in range(n_ones):
		_place(1)
	# add -1's
	for _ in range(n_minus):
		_place(-1)
	return coeff_vector



def print_ntru(q, h, variable_x, filename):
	"""
	Write the public key h to `filename` and return its rotation rows.

	File format: the first line is q; then n = len(list(h)) lines, line i being
	the coefficient list of h * variable_x**i written as "[c0 c1 ... c_{n-1}]"
	(Python list repr with the commas removed).

	Returns a list of n coefficient lists; row i holds the coefficients of
	h * variable_x**i.
	"""
	n = len(list(h))
	rows = []
	# `with` guarantees the file is closed even if a multiplication fails.
	with open(filename, 'w') as f:
		f.write(str(q) + '\n')
		for i in range(n):
			hvector = list(h * variable_x**i)
			rows.append(hvector)
			f.write(str(hvector).replace(',', '') + '\n')
	return rows

def all_rotations(g, variable_x, q):
	"""
	List every rotation g * variable_x**k (k = 0..n-1) together with its negation.

	The result has 2n entries: index 2k is the coefficient list of
	g * variable_x**k and index 2k+1 is its elementwise negation. In both lists,
	any coefficient equal to q-1 is replaced by the integer -1.
	"""
	def _wrap(coeffs):
		return [-1 if c == q-1 else c for c in coeffs]

	rotations = []
	for k in range(len(list(g))):
		shifted = list(g * variable_x**k)
		negated = [-c for c in shifted]
		rotations.append(_wrap(shifted))
		rotations.append(_wrap(negated))
	return rotations

def gen_ntru_challenge(n, q=0, file_index=0):
	"""
	Generate an NTRU instance h = f/g over GF(q)[x]/(Phi_{2n}(x)) and write it to disk.

	f and g are ternary polynomials from gen_small with sparsity ceil(n/3);
	g is redrawn until it is invertible in the quotient ring.

	Files written:
	  * '<file_index>_ntru_n_<n>_q_<q>.txt' — q and the rotation rows of h
	    (see print_ntru);
	  * 'ntru_n_<n>_solution.txt' — f on the first line, g on the second.

	Sanity check: with H the rotation matrix of h and B = [[I, H], [0, q*I]],
	asserts that (g, (f - g*H)/q) * B == (g, f), i.e. (g, f) is in the lattice.

	Args:
		n: ring dimension; the modulus polynomial is CyclotomicField(2*n).polynomial()
		   (x^n + 1 when n is a power of two — the length checks below assume
		   its degree is n).
		q: prime modulus; 0 (default) picks next_prime(55*n).
		file_index: prefix of the challenge filename. Previously the prefix was
		   str(i) with no local `i`, which in Sage resolves to the imaginary unit.

	Returns:
		(h, q)
	"""
	K = CyclotomicField(2*n)
	if q == 0:
		q = next_prime(55*n)

	F = GF(q)
	Fx = PolynomialRing(F, 'x')
	Fx_qou = Fx.quotient(K.polynomial(), 'x')
	variable_x = Fx_qou.gen()

	sparsity = ceil(n/3.)
	f_poly = gen_small(sparsity, n)
	# Redraw g until it is a unit; dividing by a non-unit raises ZeroDivisionError.
	while True:
		g_poly = gen_small(sparsity, n)
		try:
			h = Fx_qou(f_poly)/Fx_qou(g_poly)
			break
		except ZeroDivisionError:
			continue

	filename = str(file_index)+'_ntru_n_'+str(n)+'_q_'+str(q)+'.txt'
	Hmat = matrix(ZZ, print_ntru(q, h, variable_x, filename))

	# g*H == f (mod q), so g*H - f is an integer multiple of q in every coordinate.
	qvec = vector(ZZ, g_poly)*Hmat - vector(ZZ, f_poly)
	assert len(qvec) == n
	assert all(c % q == 0 for c in qvec)
	qvec_red = [-c // q for c in qvec]

	# B = [[I_n, H], [0, q*I_n]]
	B = matrix(ZZ, 2*n, 2*n)
	for r in range(n):
		B[r, r] = 1
		for c in range(n):
			B[r, c+n] = Hmat[r, c]
		B[r+n, r+n] = q

	f_check = vector(list(g_poly) + qvec_red)*B
	assert f_check[:n] == vector(g_poly)
	assert f_check[n:] == vector(f_poly)

	# NOTE(review): this name does not include q or file_index, so repeated
	# calls with the same n overwrite the previous solution.
	filename = 'ntru_n_'+str(n)+'_solution.txt'
	with open(filename, 'w') as out:
		out.write(str(list(f_poly)).replace(',','')+'\n')
		out.write(str(list(g_poly)).replace(',',''))

	return h, q
	
def ntru_plain_hybrid_basis(A, g, q, nsamples):
	"""
	Construct the NTRU lattice basis used by a hybrid attack.

	With n = number of columns of A and ell = n - g, returns (B, Bg) where
	  B  = LLL-reduced [[I_ell, A[:ell, :nsamples]], [0, q*I_nsamples]],
	       an (ell+nsamples) x (ell+nsamples) fpylll IntegerMatrix;
	  Bg = A[ell:, :], the last g rows of A (g x n IntegerMatrix).

	A may be an fpylll IntegerMatrix (ncols is a property) or a Sage matrix
	(ncols() is a method); only A.ncols and A[i, j] are used.
	Requires nsamples <= number of columns of A.
	"""
	n = A.ncols() if callable(A.ncols) else A.ncols
	ell = n - g
	print(n, ell, nsamples)
	dim = ell + nsamples
	B = IntegerMatrix(dim, dim)
	Bg = IntegerMatrix(g, n)

	for i in range(ell):
		B[i, i] = 1
		for j in range(nsamples):
			# The A block starts right after the ell identity columns; the old
			# offset j+n ran past the last column whenever g > 0.
			B[i, ell + j] = A[i, j]
	for i in range(nsamples):
		B[ell + i, ell + i] = q

	for i in range(g):
		for j in range(n):
			Bg[i, j] = A[ell + i, j]

	B = LLL.reduction(B)
	return B, Bg

def CenteredMod(a, q):
    """Return the representative of a modulo q in the centered range.

    The result r satisfies r == a (mod q) and -ceil(q/2) < r <= floor(q/2)
    (for odd q: |r| <= (q-1)/2). Assumes q > 0.

    Uses `%` / `//` rather than `a.mod(q)` / `floor(q/2)` so plain Python ints
    (e.g. entries read out of an fpylll IntegerMatrix) work as well as Sage
    Integers; for a positive modulus both forms give the residue in [0, q).
    """
    r = a % q
    return r if r <= q // 2 else r - q
    
def read_ntru_from_file(filename):
    """Read an NTRU public key in the format written by print_ntru.

    File format: the first line is the modulus q; each following non-blank
    line is a row "[c0 c1 ... c_{n-1}]" of whitespace-separated integers.

    Returns (n, q, H) where n is the number of rows and H is an n x n integer
    matrix with every entry reduced into [0, q).

    Raises ValueError if a row does not have exactly n entries.

    Rows are parsed directly instead of through eval(), so a malformed or
    untrusted file cannot execute code.
    """
    with open(filename, "r") as fh:
        lines = fh.readlines()

    q = ZZ(lines[0].strip())

    rows = []
    for line in lines[1:]:
        line = line.strip()
        if not line:
            continue
        rows.append([ZZ(tok) for tok in line.strip("[]").split()])

    n = len(rows)
    for idx, row in enumerate(rows):
        if len(row) != n:
            raise ValueError("row %d of %s has %d entries, expected %d" % (idx, filename, len(row), n))

    H_ = matrix(ZZ, n, n, [[v % q for v in row] for row in rows])
    return n, q, H_

def write_basis(B, q, filename):
    """Write q and the rows of B to `filename`.

    The first line is q; each following line is str(row) with commas removed,
    one line per row yielded by iterating B.
    """
    # `with` closes the file even if formatting a row raises.
    with open(filename, 'w') as f:
        f.write(str(q) + '\n')
        for row in B:
            f.write(str(row).replace(',', '') + '\n')

def NTRU_Tests(n, log_q_min, log_q_max, ticks):
    """Generate NTRU challenges for several moduli and BKZ-reduce each one.

    For step i in range(ticks), q = next_prime(int(2**e)), where the exponent e
    moves linearly from `log_q_max` (i = 0) to `log_q_min` (i = ticks-1).
    For each q this:
      1. calls gen_ntru_challenge(n, q) to get the public key h;
      2. builds C, the centered-lifted rotation matrix of h (row j holds the
         coefficients of h * x^j);
      3. builds the 2n x 2n basis [[q*I, 0], [C, I]];
      4. BKZ-reduces its central 2k x 2k sub-block (k, beta from find_param);
      5. writes the reduced sub-block to '<i>_BKZ_n<n>_q_<q>.txt'.

    Args:
        n: ring dimension passed to gen_ntru_challenge.
        log_q_min, log_q_max: bounds on log2(q) — exponents, not moduli.
        ticks: number of moduli to try.
    """
    import time

    delta = (log_q_max - log_q_min) / (ticks - 1) if ticks > 1 else 0

    for i in range(ticks):
        q = next_prime(int(2**(log_q_max - delta*i)))
        h, q = gen_ntru_challenge(n, q)
        print("h, q =", h, q)

        # The rotation matrix is rebuilt from h in memory. The old code re-read
        # it from str(i)+'_ntru_n_...', but gen_ntru_challenge does not name its
        # file with this loop's index, so that path never existed.
        x = h.parent().gen()
        print("n=", n, "q=", q)

        # NOTE(review): find_param is not defined in this notebook section.
        k, beta = find_param(n, q)
        print("k=", k, "beta=", beta)

        t = time.perf_counter()
        # Centered lift of every coefficient of h * x^j into (-q/2, q/2].
        C_Id = Matrix(ZZ, [[CenteredMod(ZZ(c), q) for c in list(h * x**j)] for j in range(n)])
        print(time.perf_counter()-t, "sec taken")

        print("Building basis of NTRU lattice...")
        t = time.perf_counter()
        B = block_matrix([[q*identity_matrix(n), zero_matrix(n)], [C_Id, identity_matrix(n)]])
        B = IntegerMatrix.from_matrix(B)
        print(time.perf_counter()-t, "sec taken")

        # NOTE(review): this raises k to at least n-2, while the sub-block below
        # needs k <= n. Confirm whether max (not min) is intended.
        k = max(k, n-2)

        print("Building submatrix of size 2*k =", 2*k, "...")
        TIME = time.time()
        Bk = B.submatrix(range(n-k, n+k), range(n-k, n+k))
        FPLLL.set_precision(180)
        GSO_Bk = GSO.Mat(Bk, float_type='mpfr')
        print("Preprocessing...")
        Bk_BKZ = BKZ2(GSO_Bk)
        print("Done. time: ", (time.time() - TIME))

        ### BKZ PHASE ###
        flags = BKZ.AUTO_ABORT | BKZ.MAX_LOOPS | BKZ.GH_BND | BKZ.VERBOSE
        par = BKZ.Param(block_size=beta, strategies=BKZ.DEFAULT_STRATEGY, flags=flags, max_loops=10)
        print("BKZ reduction with beta =", beta, "...")
        TIME = time.time()
        # Updates Bk in place. If it fails because of an infinite loop in Babai,
        # raise FPLLL.set_precision() above and rerun without restarting.
        Bk_BKZ(par)
        print("Done. time: ", (time.time() - TIME))

        # Success heuristic: do the first k reduced rows all have norm < 2**24?
        print(all(Bk[r].norm() < 2**24 for r in range(k)))

        write_basis(Bk, q, str(i)+'_BKZ_n'+str(n)+'_q_' + str(q) + '.txt')
        # Report completed steps (i is 0-based); float() works for both
        # Python and Sage numeric types, unlike .n().
        print(i+1, "of ", ticks, "steps done...", float(100*(i+1)/ticks), "% done")

In [None]:
# NTRU_Tests takes log2(q) bounds, not moduli: 16777421 ~ 2^24 and
# 4295010253 ~ 2^32, so pass the exponents 24 and 32.
# NOTE(review): `ticks` was left blank; 9 gives one modulus per integer
# exponent 32, 31, ..., 24. Adjust as needed (each BKZ run can take hours).
NTRU_Tests(128, 24, 32, ticks=9)