In [23]:
from sage.all import *
from fpylll import IntegerMatrix

#file generation

def gen_small(s, n):
	"""
	Return a random sparse ternary coefficient vector of length n.

	Index 0 and index n-1 are always 1. Then max(s-2, 0) further 1s and
	s entries equal to -1 are placed at distinct random interior positions
	(1 .. n-2), drawn with Sage's ZZ.random_element. All other entries are 0.
	For s >= 2 the result therefore has exactly s ones and s minus-ones.

	:param s: sparsity parameter (number of -1 entries)
	:param n: length of the vector (ring degree)
	:returns: list of n ints in {-1, 0, 1}
	:raises ValueError: if n has fewer than max(s-2, 0) + s interior
	                    positions; the rejection sampling would never finish
	"""
	deg = n
	n_ones = max(s-2, 0)
	n_free = max(deg-2, 0)
	if n_ones + s > n_free:
		raise ValueError("gen_small: need %d interior positions but n=%d only has %d"
			% (n_ones + s, deg, n_free))

	coeff_vector = deg*[0]
	coeff_vector[deg-1] = 1
	coeff_vector[0] = 1
	used = {0, deg-1}

	def _place(value, count):
		# rejection-sample `count` unused interior indices and write `value` there
		for _ in range(count):
			while True:
				idx = ZZ.random_element(1, deg-1)
				if idx not in used:
					coeff_vector[idx] = value
					used.add(idx)
					break

	_place(1, n_ones)   # extra 1's (the two endpoints are already 1)
	_place(-1, s)       # -1's
	return coeff_vector



def print_ntru(q, h, variable_x, filename):
	"""
	Write the rotation matrix of h to `filename` and return it.

	Row i is the coefficient list of h * variable_x**i, for i in
	range(len(list(h))). The file gets q on the first line. Each row follows
	on its own line, written as str(row) with commas removed, e.g. "[1 2 3]".

	:param q: modulus, written on the first line
	:param h: ring element; list(h) gives its coefficients
	:param variable_x: generator of the ring h lives in
	:param filename: output path (overwritten)
	:returns: list of the rows (each a list of coefficients)
	"""
	n = len(list(h))
	HMat = []
	# `with` closes the file even if a ring operation or a write fails
	with open(filename, 'w') as f:
		f.write(str(q)+'\n')
		for i in range(n):
			hvector = list(h * variable_x**i)
			HMat.append(hvector)
			f.write(str(hvector).replace(',','') + '\n')
	return HMat

def all_rotations(g, variable_x, q):
	"""
	List every rotation g * variable_x**i and its negation.

	For i in range(len(list(g))), entry 2*i is list(g * variable_x**i) and
	entry 2*i+1 is its element-wise negation. The negation is taken before
	any centering. Afterwards, in both lists, each coefficient equal to q-1
	is replaced by -1.

	:returns: list of 2*len(list(g)) coefficient lists
	"""
	def _center(coeff):
		# map the residue q-1 to its centered value -1, leave others untouched
		return -1 if coeff == q-1 else coeff

	result = []
	for shift in range(len(list(g))):
		rotated = list(g*variable_x**shift)
		negated = [-coeff for coeff in rotated]
		result.append([_center(coeff) for coeff in rotated])
		result.append([_center(coeff) for coeff in negated])
	return result

def gen_ntru_challenge(n,q=0):
	"""
	Generate an NTRU instance over GF(q)[x]/(Phi_{2n}(x)) and write it to disk.

	f and g are sparse ternary vectors from gen_small with sparsity
	ceil(n/3). The public key is h = f/g in the quotient ring.
	Files written:
	  ntru_n_<n>.txt           -- q, then the rows h*x^i (see print_ntru)
	  ntru_n_<n>_solution.txt  -- coefficient lists of f and g

	As a self-check, (g, k) * B == (g, f) is asserted for the basis
	B = [[I, H], [0, q*I]] and the integer vector k = -(g*H - f)/q.

	:param n: ring dimension (Phi_{2n} equals x^n + 1 only when n is a power of two)
	:param q: modulus used to build GF(q); 0 means next_prime(55*n)
	:returns: (h, q)
	"""
	K = CyclotomicField(2*n)
	if q == 0:
		q = next_prime(55*n)

	F = GF(q)
	Fx = PolynomialRing(F, 'x')
	Fx_qou = Fx.quotient(K.polynomial(), 'x')
	variable_x = Fx_qou.gen()

	sparsity = ceil(n/3.)
	f_poly = gen_small(sparsity, n)
	g_poly = gen_small(sparsity, n)
	h = Fx_qou(f_poly)/Fx_qou(g_poly)

	filename = 'ntru_n_'+str(n)+'.txt'
	Hmat = print_ntru(q, h, variable_x, filename)
	Hmat = matrix(ZZ, Hmat)   # lift the GF(q) rotations to integers in [0, q)

	# g*h == f in the quotient ring, so over ZZ every entry of g*H - f is a multiple of q
	qvec = vector(ZZ, g_poly)*Hmat - vector(f_poly)
	assert len(qvec) == n
	qvec_red = [0]*int(n)
	for i in range(n):
		assert qvec[i] % q == 0
		qvec_red[i] = -qvec[i] // q   # exact integer division

	B = matrix(ZZ, 2*n, 2*n)
	for i in range(n):
		B[i,i] = 1
		for j in range(n):
			B[i,j+n] = Hmat[i,j]
		B[i+n,i+n] = q

	f_check = vector(list(g_poly) + list(qvec_red))*B
	assert f_check[:n] == vector(g_poly)
	assert f_check[n:] == vector(f_poly)

	solution_filename = 'ntru_n_'+str(n)+'_solution.txt'
	with open(solution_filename, 'w') as sol_file:
		sol_file.write(str(list(f_poly)).replace(',','')+'\n')
		sol_file.write(str(list(g_poly)).replace(',',''))

	return h, q
	
def ntru_plain_hybrid_basis(A, g, q, nsamples):
	"""
	Build the lattice bases for a plain hybrid attack.

	With n = number of columns of A and ell = n - g, the square
	(ell+nsamples) x (ell+nsamples) basis
	    B = [[I_ell, A[:ell, :nsamples]], [0, q*I_nsamples]]
	is built and LLL-reduced with fpylll. The last g rows of A are copied
	unreduced into Bg (g x n).

	:param A: matrix indexable as A[i, j]; fpylll IntegerMatrix (ncols property)
	          or Sage matrix (ncols() method)
	:param g: number of guessed coordinates, 0 <= g <= n
	:param q: modulus placed on the q-ary diagonal
	:param nsamples: number of columns of A used in B (<= n)
	:returns: (B, Bg) as fpylll IntegerMatrix objects
	"""
	from fpylll import LLL

	n = A.ncols
	if callable(n):   # Sage matrices expose ncols() as a method
		n = n()
	ell = n - g
	print(n, ell, nsamples)
	B = IntegerMatrix((nsamples+ell), (ell+nsamples))
	Bg = IntegerMatrix(g, n)

	for i in range(ell):
		B[i,i] = 1
		for j in range(nsamples):
			# the A-block sits directly right of the ell x ell identity;
			# offsetting by n would index past the last column when g > 0
			B[i,j+ell] = A[i, j]
	for i in range(nsamples):
		B[i+ell, i+ell] = q

	for i in range(g):
		for j in range(n):
			Bg[i,j] = A[i+ell, j]

	B = LLL.reduction(B)
	return B, Bg

# Generate the n=128 challenge with a fixed modulus and show the public key.
n, q = 128, 2256116141
h, q = gen_ntru_challenge(n, q)
print(h, q)

128
1129777394*x^127 + 2169119461*x^126 + 1581545353*x^125 + 1105632046*x^124 + 5892970*x^123 + 2249749989*x^122 + 1286258488*x^121 + 2111351180*x^120 + 569151222*x^119 + 1600347811*x^118 + 1030463914*x^117 + 17900866*x^116 + 589869302*x^115 + 453519340*x^114 + 336382531*x^113 + 1884073484*x^112 + 1542039697*x^111 + 587735917*x^110 + 2086973598*x^109 + 2052939524*x^108 + 874132557*x^107 + 231569860*x^106 + 736008142*x^105 + 1821908646*x^104 + 1778693382*x^103 + 2224381278*x^102 + 782029731*x^101 + 175936435*x^100 + 2190195999*x^99 + 2097014365*x^98 + 1868927158*x^97 + 1750720587*x^96 + 811285197*x^95 + 932903915*x^94 + 1200562219*x^93 + 1718450763*x^92 + 947525053*x^91 + 758908820*x^90 + 581432777*x^89 + 2024092761*x^88 + 1336898485*x^87 + 369243439*x^86 + 507086225*x^85 + 551953368*x^84 + 820972755*x^83 + 1515094530*x^82 + 663171675*x^81 + 551703917*x^80 + 468519859*x^79 + 64293410*x^78 + 2161583510*x^77 + 854159762*x^76 + 2122495832*x^75 + 969007064*x^74 + 1111800624*x^73 + 187193918

In [None]:
from fpylll import *
from fpylll.algorithms.bkz2 import BKZReduction as BKZ2
from sage.modules.free_module_integer import IntegerLattice

def H_delta(x,precision=144):
	"""
	Heuristic root-Hermite factor reached by BKZ with block size x.

	    delta(x) = ( x/(2*pi*e) * (pi*x)^(1/x) )^(1/(2*(x-1)))

	:param x: block size beta (must be > 1)
	:param precision: bits of precision passed to .n()
	:returns: numerical value of delta(x)
	"""
	base = x/(2*pi*e)*(pi*x)**(1/x)
	# .n() must wrap the whole power; applying it only to the exponent
	# leaves a symbolic expression in pi and e
	return (base**(1/(2*(x-1)))).n(precision)

def left_side(q,n,k):
	"""
	Left-hand side of the (k, beta) feasibility test used by find_param.

	Evaluates
	    ( k/2 * log2(k/2/pi/e) + k*log2(q) - n*log2(n/2) ) / (k*(3k-1)),
	which find_param compares against log2(H_delta(beta)).
	"""
	scale = 1/(k*(3*k-1))
	log_k_term = k/2*log(k/2/pi/e,2)
	log_q_term = k*log(q,2)
	log_n_term = n*log(n/2,2)
	return scale*(log_k_term + log_q_term - log_n_term)

def find_param(n, q, old=False,logs=False):
    """
    Search for the smallest BKZ block size beta and a matching dimension k.

    Outer loop: binary search for beta over [35, 2n). A beta counts as
    feasible when some k satisfies left_side(q, n, k) >= log2(H_delta(beta)).
    A feasible beta moves the upper bound down to beta; an infeasible one
    moves the lower bound up to beta + 1.

    Inner search (old=False): searches k in (beta, 2n). On a feasible k it
    records (k, beta) and keeps searching to the left. Otherwise it moves
    toward whichever neighbour (k-1 or k+1) has the larger left_side value.
    NOTE(review): this presumably assumes left_side(q, n, .) is unimodal in
    k, and the k recorded is not guaranteed to be the smallest feasible one.

    old=True: linear scan k = max(beta, min_k), max(beta, min_k)+10, ... < 2n,
    stopping at the first feasible k.

    :param n: lattice/ring dimension
    :param q: modulus
    :param old: use the linear k-scan instead of the inner search
    :param logs: print search progress; NOTE(review): the old=True branch
                 prints each found pair regardless of this flag
    :returns: (min_k, min_beta) recorded for the last feasible beta; if no
              beta was feasible, (int(n/8), Infinity)
    """
    min_beta=Infinity
    min_k = int(n/8)
    L=35; R=2*n
    
    while abs(L-R)>1:
        beta=int(L+(R-L)/2)
        
        if logs:
            print("now:",L,beta,R)
            
        # target value: log2 of the root-Hermite factor for this block size
        delta_beta =log(H_delta(beta),2)
        
        flag_k_found=False
        if not old:
            l=beta; r=2*n

            flag_k_found=False

            while abs(l-r)>1:
                k=int(l+(r-l)/2)
                tmp=left_side(q,n,k)

                # found a feasible k: record it as the new minimum and keep searching no further right
                if tmp >= delta_beta:
                    flag_k_found=True
                    min_beta = beta
                    min_k = k
                    
                    if logs:
                        print(min_k, min_beta)
                        
                    r=k
                # otherwise move toward the side where tmp - delta_beta is larger
                else:
                    left_=left_side(q,n,k-1)-delta_beta
                    right_=left_side(q,n,k+1)-delta_beta
                    if left_-right_<0:
                        l=k+1
                    else:
                        r=k
        
        else:
            # coarse linear scan over k in steps of 10
            for k in range(max(beta,min_k), 2*n, 10):
                tmp=left_side(q,n,k)
                if tmp >=delta_beta:
                    min_beta = beta
                    min_k = k
                    print( min_k, min_beta)
                    flag_k_found=True
                    break
        # if this beta admits a solution, look for a new one no further right
        if flag_k_found:
            R=beta
        # otherwise search to the right
        else:
            L=beta+1
    return min_k, min_beta

def CenteredMod(a, q):
    """
    Reduce a modulo q into the centered range (-q/2, q/2].

    Uses the % operator, so it works for plain Python ints (e.g. entries read
    from an fpylll matrix row) as well as Sage Integers.

    :param a: integer to reduce
    :param q: positive modulus
    :returns: r with r == a (mod q) and r <= floor(q/2) < r + q
    """
    r = a % q  # in [0, q) for positive q
    return r if r <= q // 2 else r - q
    
def read_ntru_from_file(filename):
    """
    Read an NTRU public-key matrix in the format written by print_ntru.

    The first line is q. Each following non-empty line is one row of
    space-separated integers, optionally wrapped in [] or ().

    :param filename: path of the file to read
    :returns: (n, q, H) where n is the number of rows, q = ZZ(first line) and
              H is an n x n Sage matrix whose entries are reduced into [0, q)
    """
    with open(filename, "r") as fh:
        lines = [line.strip() for line in fh]

    q = ZZ(lines[0])
    # parse the rows explicitly instead of eval()-ing file contents
    rows = [line.strip("[]() ").split() for line in lines[1:] if line]

    n = len(rows)
    H_ = matrix(n)
    for t0 in range(n):
        for t1 in range(n):
            H_[t0, t1] = ZZ(rows[t0][t1]) % q

    return n, q, H_

# Load the public key written by gen_ntru_challenge (n comes from the first cell).
path = 'ntru_n_'+str(n)+'.txt'
n, q, H = read_ntru_from_file(path)
print("n=", n, "q=",q)

# Pick the sublattice dimension k and the BKZ block size beta.
k, beta = find_param(n,q)
print("k=", k, "beta=",beta)

import time
t=time.perf_counter()

# Convert the Sage matrix to an fpylll IntegerMatrix (this is what is timed).
H = IntegerMatrix.from_matrix(H)

print(time.perf_counter()-t, "sec taken")
C_Id = H
# Center every entry into (-q/2, q/2], reading the rows back from the fpylll matrix.
C_Id = Matrix(ZZ, [ [ CenteredMod(e, q) for e in C_Id[i] ] for i in range(n) ] )

print("Building basis of NTRU lattice...")
t=time.perf_counter()
# Basis rows: (q*e_i, 0) for i < n, then (C_Id[i], e_i).
B = block_matrix([ [ q*identity_matrix(n), zero_matrix(n) ], [ C_Id, identity_matrix(n) ] ] )
B = IntegerMatrix.from_matrix(B)
# NOTE(review): the saved output of this cell still shows a printed B and a
# warning from find_param, so it appears to come from an earlier version of the cell.
# print(B)
print(time.perf_counter()-t, "sec taken")

n= 128 q= 2256116141
k= 37 beta= 36
0.003351296000005277 sec taken
Building basis of NTRU lattice...
[  2256116141           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0           0          

See https://github.com/sagemath/sage/issues/15114 for details.
  k, beta = find_param(n,q)


In [31]:
import time
### PREPROC PHASE FOR FPYLLL ###
# BKZ-reduce the central 2k x 2k sub-block of the NTRU basis B with fpylll.
# NOTE(review): these hard-coded values override the (k, beta) returned by
# find_param in the previous cell (they match its printed output k=37, beta=36).
k=37; beta=36

print("Building submatrix of size 2*k =", 2*k, "...")
time1 = time.time()
# NOTE(review): the inline timing comments (~30s, ~15 hours) come from a different
# run; the saved output below shows much shorter times for n=128.
# rows and columns n-k .. n+k-1 of B
Bk = B.submatrix(range(n-k,n+k), range(n-k,n+k)) # ~30s
# 180-bit precision for the mpfr Gram-Schmidt object created next
FPLLL.set_precision(180) 
GSO_Bk = GSO.Mat(Bk, float_type='mpfr')
print("Preprocessing...")
Bk_BKZ = BKZ2(GSO_Bk) 
print("Done. time: ", (time.time() - time1))   
# print(Bk)
### BKZ PHASE ###
flags = BKZ.AUTO_ABORT|BKZ.MAX_LOOPS|BKZ.GH_BND|BKZ.VERBOSE
#beta = 9
# MAX_LOOPS is set in flags, so max_loops=8 limits the number of tours
par = BKZ.Param(block_size=beta, strategies=BKZ.DEFAULT_STRATEGY, flags=flags, max_loops=8)
print("BKZ reduction with beta =", beta, "...")
time1 = time.time()
DONE = Bk_BKZ(par) #actual BKZ algorithm; updates Bk in place; ~15 hours
# if it fails because infinite loop in babai, set higher precision in FPLLL.set_precision() and rerun without restarting.
print("Done. time: ", (time.time() - time1))
print(Bk)
### WE WON ###
# Success check: each of the first k rows of the reduced block has norm below 2^24.
# NOTE(review): `2^24` relies on the Sage preparser; in plain Python `^` is XOR.
print( all(Bk[i].norm() < 2^24 for i in range(k)) )

Building submatrix of size 2*k = 74 ...
Preprocessing...
Done. time:  0.003537893295288086
BKZ reduction with beta = 36 ...
Done. time:  6.220838785171509
[       89       0    -248    -179    -105     -102     -320      52      22      415     -218      77      171      310      64     -139     -296      400      255      250      159      23        7     -121     171      -64     -155      67      75       28      86      16      -5      152     408      293       1    -203        9      47     207    -506    -677    -397      12      69       0     -118     159    -118    -199    -420    -275     153      96    -119     -62     169    -177    -237    -374    -198    -198     167     378     -81      16    -104    -371     277     736    -210    -117     -95 ]
[       34     245     -87      61      66     -224      -79    -312     123     -343      104     -94      -37     -431     132      -74      452     -277      241     -196       32     175      107       -7    -218      311  

In [26]:
def write_basis(B, q, filename):
    """
    Write a lattice basis to a text file.

    The first line is q. Every row of B follows on its own line, written as
    str(row) with commas removed (e.g. "[1 -2 3]" for a list row).

    :param B: iterable of rows
    :param q: modulus written on the first line
    :param filename: output path (overwritten)
    """
    # `with` closes the file even if converting a row to str fails
    with open(filename, 'w') as f:
        f.write(str(q) + '\n')
        for row in B:
            f.write(str(row).replace(',', '') + '\n')

# Save the BKZ-reduced block next to the challenge files.
basis_filename = 'BKZ_n' + str(n) + '.txt'
write_basis(Bk, q, basis_filename)