In [166]:
import random

In [109]:
import numpy as np

# Lattices

In [110]:
def sample_error(n):
    return np.random.randint(-2, 3, n)  # "small" error vector

def generate_key(n, q):
    return np.random.randint(0, q, n)

def encrypt(m, A, s, q):
    n = len(s)
    e = sample_error(n)
    return (np.dot(A, s) + m + e) % q

def decrypt(c, A, s, q):
    n = len(s)
    approx = (c - np.dot(A, s)) % q
    # The decryption step should remove the error and recover the message
    # By rounding to the nearest multiple of (q // 2)
    return (np.round(approx / (q / 2)) % 2 * (q // 2)).astype(int)

In [111]:
n = 30  # Dimension
q = 101  # Modulus

In [112]:
s = generate_key(n, q)
s

array([66, 91, 90, 75, 60, 22, 63, 90, 31, 75, 73,  4, 89, 94, 88, 72,  2,
       33, 17, 60, 84, 77, 53, 92, 75, 70, 14,  3,  5, 67])

In [113]:
# Generate a public key A (this is usually more complex in practice)
A = np.random.randint(0, q, n)
A

array([63, 95, 88, 15, 73, 47, 86,  1, 90, 56, 21, 50, 69, 15, 54, 81, 59,
       13, 54, 30, 57, 29, 78, 67, 22, 88, 89, 12, 98, 91])

In [114]:
# Generate a random binary message and multiply by (q // 2)
m = np.random.randint(0, 2, n) * (q // 2)
m

array([50,  0,  0, 50,  0,  0,  0, 50, 50,  0,  0,  0,  0, 50, 50, 50, 50,
        0, 50,  0, 50,  0,  0, 50,  0, 50, 50, 50, 50, 50])

In [115]:
# this multiplication by q // 2
# "lifts" these bits into the modular arithmetic space defined by q and helps as to deal with this "noise"

# generally, in real worl; algorithm we should use some more advanced technique than just "mapping" these bits
# it looks rudimental but for educational purposes - it works

In [116]:
# Encrypt the message
ciphertext = encrypt(m, A, s, q)

In [117]:
# Decrypt the message
decrypted = decrypt(ciphertext, A, s, q)

In [118]:
# Output the original and decrypted messages
print("Original message:", m)
print("Decrypted message:", decrypted)

Original message: [50  0  0 50  0  0  0 50 50  0  0  0  0 50 50 50 50  0 50  0 50  0  0 50
  0 50 50 50 50 50]
Decrypted message: [50  0  0 50  0  0  0 50 50  0  0  0  0 50 50 50 50  0 50  0 50  0  0 50
  0 50 50 50 50 50]


In [119]:
# Convert back to bit representation for comparison
original_bits = m // (q // 2)
decrypted_bits = decrypted // (q // 2)
print("Original bits:", original_bits)
print("Decrypted bits:", decrypted_bits)

Original bits: [1 0 0 1 0 0 0 1 1 0 0 0 0 1 1 1 1 0 1 0 1 0 0 1 0 1 1 1 1 1]
Decrypted bits: [1 0 0 1 0 0 0 1 1 0 0 0 0 1 1 1 1 0 1 0 1 0 0 1 0 1 1 1 1 1]


# NTRU

In [216]:
from sympy.abc import x
from sympy import ZZ, Poly
import numpy as np
from sympy.polys.polyerrors import NotInvertible

In [239]:
N = 512
p = 2
q = 128

In [240]:
def random_poly(length, d, neg_ones_diff=0):
    return Poly(np.random.permutation(
        np.concatenate((np.zeros(length - 2 * d - neg_ones_diff), np.ones(d), -np.ones(d + neg_ones_diff)))), x).set_domain(ZZ)

def invert_poly(f_poly, R_poly, p):
    inv_poly = None
    if (p & (p - 1) == 0):  # Check for power of 2
        inv_poly = invert(f_poly, R_poly, domain=GF(2))
        e = int(np.log2(p))
        for _ in range(1, e):
            inv_poly = ((2 * inv_poly - f_poly * inv_poly ** 2) % R_poly).trunc(p)
    else:
        raise Exception("Cannot invert polynomial in Z_{}".format(p))
    return inv_poly

In [241]:
def encrypt(msg_poly, rand_poly, h_poly, q, R_poly):
    return (((rand_poly * h_poly).trunc(q) + msg_poly) % R_poly).trunc(q)

def decrypt(msg_poly, f_poly, f_p_poly, q, p, R_poly):
    a_poly = ((f_poly * msg_poly) % R_poly).trunc(q)
    b_poly = a_poly.trunc(p)
    return ((f_p_poly * b_poly) % R_poly).trunc(p)

In [242]:
R_poly = Poly(x ** N - 1, x).set_domain(ZZ)

# Key generation
g_poly = random_poly(N, int(np.sqrt(q)))
tries = 10
h_poly = None
while tries > 0 and (h_poly is None):
    f_poly = random_poly(N, N // 3, neg_ones_diff=-1)
    try:
        f_p_poly = invert_poly(f_poly, R_poly, p)
        f_q_poly = invert_poly(f_poly, R_poly, q)
        p_f_q_poly = (p * f_q_poly).trunc(q)
        h_before_mod = (p_f_q_poly * g_poly).trunc(q)
        h_poly = (h_before_mod % R_poly).trunc(q)
    except NotInvertible:
        tries -= 1

if h_poly is None:
    raise Exception("Couldn't generate invertible f")

In [243]:
msg_poly = random_poly(5, 1)  # for example purposes, low degree
print(f"Original message polynomial: {msg_poly}")

Original message polynomial: Poly(-x**4 + 1, x, domain='ZZ')


In [244]:
rand_poly = random_poly(N, N // 5)  # again, for example purposes
print(f"Randomizing polynomial: {rand_poly}")

Randomizing polynomial: Poly(-x**509 + x**508 - x**507 + x**506 + x**502 + x**501 - x**500 + x**494 - x**493 + x**492 - x**487 - x**480 - x**476 - x**475 + x**474 + x**473 - x**472 - x**471 + x**468 - x**467 + x**464 - x**462 + x**460 + x**459 - x**451 + x**450 - x**447 - x**445 + x**439 - x**437 - x**436 - x**435 + x**426 - x**425 + x**424 + x**423 + x**418 - x**416 - x**415 - x**411 - x**410 + x**407 + x**404 + x**402 - x**401 + x**399 - x**397 + x**395 - x**391 - x**389 - x**388 - x**384 + x**383 - x**380 - x**379 - x**377 - x**376 + x**374 - x**373 + x**370 - x**363 - x**362 + x**361 + x**360 - x**358 + x**354 - x**352 + x**351 + x**350 + x**347 + x**345 + x**342 - x**339 + x**337 + x**336 - x**330 + x**326 - x**318 + x**317 + x**316 + x**315 + x**314 + x**307 + x**304 + x**303 + x**299 - x**297 - x**296 - x**294 + x**284 - x**282 - x**281 - x**280 + x**277 + x**276 + x**274 + x**271 + x**269 - x**268 - x**252 + x**251 + x**250 + x**249 - x**245 - x**243 + x**238 - x**237 + x**235 

In [245]:
encrypted_poly = encrypt(msg_poly, rand_poly, h_poly, q, R_poly)
print(f"Encrypted message: {encrypted_poly}")

Encrypted message: Poly(6*x**511 - 46*x**510 + 64*x**509 + 4*x**508 + 22*x**507 + 16*x**506 - 34*x**505 - 14*x**504 + 16*x**503 + 40*x**502 + 8*x**501 + 46*x**500 + 36*x**499 - 36*x**498 + 16*x**497 - 24*x**496 - 62*x**495 - 56*x**494 + 14*x**493 + 26*x**492 - 10*x**491 - 42*x**490 - 42*x**489 - 18*x**488 - 34*x**487 - 10*x**486 + 14*x**485 + 28*x**484 + 22*x**483 + 64*x**482 + 2*x**481 + 38*x**480 + 4*x**479 - 42*x**478 + 6*x**477 - 52*x**476 - 36*x**475 - 8*x**474 - 50*x**473 - 14*x**472 - 2*x**471 - 28*x**470 + 62*x**469 + 18*x**468 + 48*x**467 + 24*x**466 + 60*x**465 + 62*x**464 - 50*x**463 + 16*x**462 - 4*x**461 - 22*x**460 + 10*x**459 - 28*x**458 - 8*x**457 + 52*x**456 + 6*x**455 + 18*x**454 + 54*x**453 - 34*x**452 + 46*x**451 - 14*x**450 - 36*x**449 - 28*x**448 - 22*x**447 + 62*x**446 + 18*x**445 + 14*x**444 - 48*x**443 + 44*x**442 + 8*x**441 + 54*x**440 - 20*x**439 - 34*x**438 + 4*x**437 - 28*x**436 + 60*x**435 + 64*x**434 + 22*x**433 - 18*x**432 - 40*x**431 - 26*x**430 - 20*x*

In [246]:
decrypted_poly = decrypt(encrypted_poly, f_poly, f_p_poly, q, p, R_poly)
print(f"Decrypted message: {decrypted_poly}")

Decrypted message: Poly(x**4 + 1, x, domain='ZZ')


In [247]:
if decrypted_poly == msg_poly:
    print("Decryption successful, message retrieved!")
else:
    print("Decryption failed, something went wrong!")

Decryption failed, something went wrong!


In [248]:
# ToDo:
# Encrypted/Decrypted always different by 1 sign, like x - 1 