## Helper Functions

In [1]:
import hashlib

In [2]:
def balanced_remainder(a, b):
    """
    Helper function to get a remainder in the range (-b/2, b/2]
    """
    a = a % b
    if a > b // 2:
        return a - b
    else:
        return a

print("done")

done


In [3]:
def extra_secure_poly_generator(d, N, small_modulo):
    """
    Generate key with larger range not just -1, 0 ,1
    """
    coefficients = [0 for _ in range(N)]
    counter = 0
    while counter <= d:
        rand_index = floor(N*random())
        if coefficients[rand_index] == 0:
            coefficients[rand_index] = ZZ.random_element(small_modulo-1)+1
            counter+=1

    counter=0

    while counter < d:
        rand_index = floor(N*random())
        if coefficients[rand_index] == 0:
            coefficients[rand_index] = -(ZZ.random_element(small_modulo-1)+1)
            counter+=1

    return vector(ZZ, coefficients)

print("done")

done


In [4]:
def rand_poly_generator(d, N):
    coefficients = [0 for _ in range(N)]
    counter = 0
    while counter <= d:
        rand_index = floor(N*random())
        if coefficients[rand_index] == 0:
            coefficients[rand_index] = 1
            counter+=1

    counter=0

    while counter < d:
        rand_index = floor(N*random())
        if coefficients[rand_index] == 0:
            coefficients[rand_index] = -1
            counter+=1

    return vector(ZZ, coefficients)

print("done")

done


## Key Generation

In [5]:
def generate_keypair(small_modulo, big_modulo, N, d):
    """
    Generates an NTRU public and private key.
    """
    while True:
        f = extra_secure_poly_generator(d, N, small_modulo)
        f += vector(ZZ, [1] + [0 for _ in range(N-1)])

        # Check for invertibility
        circular_f = matrix.circulant(f)
        cir_f_det = circular_f.determinant()

        if (cir_f_det.gcd(big_modulo) == 1) and (cir_f_det.gcd(small_modulo) == 1):
                break

    g = rand_poly_generator(d, N)
    
    cir_f_big_mod = matrix.circulant(vector(Zmod(big_modulo),list(f)))
    public_key = g*cir_f_big_mod^-1
    public_key = vector(ZZ,list(public_key))

    cir_f_small_mod = matrix.circulant(vector(Zmod(small_modulo), list(f)))
    f_inv_vector = vector(Zmod(small_modulo), [1] + [0]*(N-1)) * cir_f_small_mod.inverse()
    private_key = (f,f_inv_vector)
    return private_key, public_key
    
print("done")

done


## Public Key Encryption (PKE)

In [6]:
def encrypt_PKE(message, public_key, small_modulus, big_modulus, N, d):
    """
    Encrypts message using public key
    """
    blinding_poly = rand_poly_generator(d, N)

    aux = small_modulus * blinding_poly * matrix.circulant(public_key)
    aux = vector([element.mod(big_modulus) for element in list(aux)])

    ciphertext = message + aux

    ciphertext = vector([element.mod(big_modulus) for element in list(ciphertext)])
    
    return ciphertext

print("done")

done


In [7]:
def decrypt_PKE(ciphertext, private_key, small_modulus, big_modulus, N):
    """
    Decrypts a ciphertext.
    """
    f, f_inv_vector = private_key
    
    aux = ciphertext * matrix.circulant(f)
    aux = vector([element.mod(big_modulus) for element in list(aux)])
    
    aux = vector([balanced_remainder(element, big_modulus) for element in list(aux)])
    
    aux = vector([element.mod(small_modulus) for element in list(aux)])
    aux = vector([balanced_remainder(element, small_modulus) for element in list(aux)])
    
    f_vector = matrix.circulant(f_inv_vector)
    message_recovered = aux * f_vector
    
    message_recovered = vector(ZZ, list(message_recovered))
    message_recovered = vector([balanced_remainder(element, small_modulus) for element in list(message_recovered)])
    
    return message_recovered

print("done")

done


## Key Encapsulation Mechanism (KEM)

In [8]:
def encapsulation_KEM(public_key, small_modulus, big_modulus, N, d):
    """
    Generate message inside the function and then encrypt
    """
    message = rand_poly_generator(d, N)
    ciphertext = encrypt_PKE(message, public_key, small_modulus, big_modulus, N, d)
    
    chain_poly = ''
    for i in range(N):
        chain_poly += Integer(message[i]).binary()
        
    shared_secret = Integer('0x' + hashlib.sha256(chain_poly.encode('utf-8')).hexdigest()).binary()
    
    # 4. Return the ciphertext and the hashed secret
    return ciphertext, shared_secret

print("done")

done


In [9]:
def decapsulation_KEM(ciphertext, private_key, small_modulus, big_modulus, N):
    
    message_recovered = decrypt_PKE(ciphertext, private_key, small_modulus, big_modulus, N)
    
    chain_m = ''
    for i in range(N):
        chain_m += Integer(message_recovered[i]).binary()
        
    shared_secret_hash = hashlib.sha256(chain_m.encode('utf-8')).hexdigest()
    shared_secret_binary = Integer('0x' + shared_secret_hash).binary()
    
    return shared_secret_binary

print("done")

done


## Check

In [10]:
def cross_check(decrypted_message, plain_text, N):
    """
    Checks if decryption was successful.
    """
    # We must compare the coefficient lists up to N
    plain_coeffs = plain_text.list()
    plain_coeffs.extend([0] * (N - len(plain_coeffs)))
    
    dec_coeffs = decrypted_message.list()
    dec_coeffs.extend([0] * (N - len(dec_coeffs)))
    
    if plain_coeffs == dec_coeffs:
        print("✅ Successful!")
    else:
        print("❌ Error!!!")
        print(f"  Plaintext: {plain_text}")
        print(f"  Decrypted: {decrypted_message}")

print("done")

done


## Testing

In [11]:
# Standard NTRU parameters (must be co-prime)
N = 11
p = 3
q = 32 # Must be a power of 2
d = N // 3 # d=3, so 3 ones and 3 minus-ones
print(f"--- NTRU Parameters ---")
print(f"N = {N} (dimension)")
print(f"p = {p} (small modulus)")
print(f"q = {q} (large modulus)")
print(f"d = {d} (key 'smallness')\n")

--- NTRU Parameters ---
N = 11 (dimension)
p = 3 (small modulus)
q = 32 (large modulus)
d = 3 (key 'smallness')



In [12]:
# 1. Key Generation
print("Generating keys...")
private_key, public_key = generate_keypair(p, q, d, N)
print(f"Public Key (h): {public_key}")
print(f"Secret Key (f): {private_key[0]}\n")
print(f"Secret Key (g): {private_key[1]}\n")

Generating keys...


KeyboardInterrupt: 

### Testing PKE

In [13]:
# 2. Encryption
print("Encrypting...")
message = rand_poly_generator(d, N)
ciphertext = encrypt_PKE(message, public_key, p, q, N, d)
print("Decrypting...")
decrypted_message = decrypt_PKE(ciphertext, private_key, p, q, N)
print(f"msg = {message} (dimension)")
print(f"dec_msg = {decrypted_message} (dimension)")

Encrypting...
Decrypting...
msg = (1, 1, 0, -1, 1, 0, 0, 1, 0, -1, -1) (dimension)
dec_msg = (1, 1, 0, -1, 1, 0, 0, 1, 0, -1, -1) (dimension)


In [14]:
# 4. Check
cross_check(decrypted_message, message, N)

✅ Successful!


### Testing KEM

In [15]:
print("Encapsulating...")
ciphertext, shared_secret_sent = encapsulation_KEM(public_key, p, q, N, d)
print("Decapsulating...")
shared_secret_received = decapsulation_KEM(ciphertext, private_key, p, q, N)

if (shared_secret_sent == shared_secret_received):
    print("✅ Successful!")
else:
    print("❌ Error!!!")

Encapsulating...
Decapsulating...
✅ Successful!


## LLL Attack

In [20]:
def construct_ntru_lattice_basis(public_key, N, big_modulus):
    """
    Constructs the 2N x 2N NTRU public lattice basis M_h.
    
    M_h = [ I_n  |  cir_pub_key ]
          [ 0_n  |  bigmod*I_n ]
    """

    cir_pub_key = matrix.circulant(public_key)
    I_n = matrix.identity(N)
    O_n = matrix.zero(N)
    bigmod_I_n = big_modulus * I_n

    M_h = block_matrix(ZZ, [
        [I_n, cir_pub_key],
        [O_n, bigmod_I_n]
    ])

    return M_h

print("done")

done


In [22]:
# Standard NTRU parameters (must be co-prime)
N = 11
p = 3
q = 64 # Must be a power of 2
d = N // 3 # d=3, so 3 ones and 3 minus-ones
print(f"--- NTRU Parameters ---")
print(f"N = {N} (dimension)")
print(f"p = {p} (small modulus)")
print(f"q = {q} (large modulus)")
print(f"d = {d} (key 'smallness')\n")

--- NTRU Parameters ---
N = 11 (dimension)
p = 3 (small modulus)
q = 64 (large modulus)
d = 3 (key 'smallness')



In [23]:
print("Generating keys...")
private_key, public_key = generate_keypair(p, q, d, N)

Generating keys...


In [24]:
print("Encrypting...")
message = rand_poly_generator(d, N)
ciphertext = encrypt_PKE(message, public_key, p, q, N, d)

Encrypting...


In [30]:
### Information available => ciphertext, public_key
attacker_message = run_attack_and_decrypt(ciphertext, public_key, p, q, N, d)

if attacker_message is not None:
    print("\n--- Attacker's FINAL RESULT ---")
    attack_success = cross_check(attacker_message, message, N)
    if attack_success:
        print("\n*** LLL ATTACK SUCCEEDED ***")
else:
    print("\n*** LLL ATTACK FAILED ***")

Constructing 2N x 2N lattic...
Performing LLL on 2Nx2N matrix
LLL Complete, extracting shortest vector...
Found f_attack = (4, 4, 1, 0, -4, 2, -1, -1, -1, -1, -4)
computing f_p_inv...
Successfully built attack key...
Attempting decryption...

--- Attacker's FINAL RESULT ---
❌ Error!!!
  Plaintext: (-1, 0, 0, -1, 1, 0, 1, 0, -1, 1, 1)
  Decrypted: (1, 0, 1, 1, 1, -1, -1, 0, -1, -1, 0)


In [21]:
def run_attack_and_decrypt(public_key, ciphertext, small_modulus, big_modulus, N, d):
    """
    Perform entire LLL attack
    """
    print("Constructing 2N x 2N lattic...")
    M_h = construct_ntru_lattice_basis(public_key, N, big_modulus)

    print("Performing LLL on 2Nx2N matrix")
    try:
        M_LLL = m_h.LLL(implementation="fplll")
    except:
        M_LLL = M_h.LLL()

    print("LLL Complete, extracting shortest vector...")
    b1 = M_LLL[0]

    f_attack = vector(ZZ, b1[0:N])

    if f_attack[0] == -1:
        f_attack = -f_attack

    print(f"Found f_attack = {f_attack}")

    print("computing f_p_inv...")
    try:
        cir_f_smallmod_attack = matrix.circulant(vector(Zmod(small_modulus), list(f_attack)))
        f_inv_vector_attack = vector(Zmod(small_modulus), [1] + [0]*(N-1)) * cir_f_smallmod_attack.inverse()

        attacker_private_key = (f_attack, f_inv_vector_attack)
        print("Successfully built attack key...")

        print("Attempting decryption...")
        decrypted_message = decrypt_PKE(ciphertext, attacker_private_key, small_modulus, big_modulus, N)
        return decrypted_message

    except Exception as e:
        print(f"Failed: f_attack not invertible: {e}")
        return None