In [1]:
# CRYSTALS-KYBER PKE Step-by-Step Implementation - Simple (without the CCA transform, NTT, ciphertext compression, and symmetric primitives)
#
# Thesis: CRYSTALS-KYBER: Post-Quantum Secure Encryption
# By Stefanos Anastasiades
#
# School of Informatics
# Aristotle University of Thessaloniki
# Supervisor: Konstantinos Draziotis
# February 2025

In [2]:
# 1. Parameter Setup
#
# KYBER512 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 2
# - Error Parameters η1 = 3, η2 = 2
#
# KYBER768 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 3
# - Error Parameters η1 = 2, η2 = 2
#
# KYBER1024 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 4
# - Error Parameters η1 = 2, η2 = 2

# Named KYBER parameter sets; n and q are shared by all three.
KYBER_PARAMETER_SETS = {
    "KYBER512": {"k": 2, "eta1": 3, "eta2": 2},
    "KYBER768": {"k": 3, "eta1": 2, "eta2": 2},
    "KYBER1024": {"k": 4, "eta1": 2, "eta2": 2},
}
# Change this single value to run the whole notebook at another security level.
PARAMETER_SET = "KYBER512"

n = 256
q = 3329
k = KYBER_PARAMETER_SETS[PARAMETER_SET]["k"]
eta1 = KYBER_PARAMETER_SETS[PARAMETER_SET]["eta1"]
eta2 = KYBER_PARAMETER_SETS[PARAMETER_SET]["eta2"]

print(f"KYBER Parameters:\n- Polynomial Degree: n = {n}\n- Modulus: q = {q}\n- Module Rank: k = {k}\n- Error Parameters: eta1 = {eta1}, eta2 = {eta2}")

KYBER Parameters:
- Polynomial Degree: n = 256
- Modulus: q = 3329
- Module Rank: k = 2
- Error Parameters: eta1 = 3, eta2 = 2


In [3]:
# 2. Polynomial Ring Setup
#
# We work in the polynomial ring R_q = Z_q[x]/(x^n + 1)

# Coefficient ring Z_q, then Z_q[x] with generator x, then the quotient
# R_q = Z_q[x]/(x^n + 1) whose generator (the image of x) is named xbar.
# Explicit-constructor form of the `R.<x> = ...` / `QR.<xbar> = ...` shorthand.
Zq = IntegerModRing(q)
R = PolynomialRing(Zq, 'x')
x = R.gen()
QR = R.quotient(x**n + 1, 'xbar')  # Quotient ring
xbar = QR.gen()

print("Sample polynomial:", QR.random_element())

Sample polynomial: 3317*xbar^255 + 2736*xbar^254 + 1938*xbar^253 + 2227*xbar^252 + 1000*xbar^251 + 448*xbar^250 + 1653*xbar^249 + 1324*xbar^248 + 305*xbar^247 + 1099*xbar^246 + 402*xbar^245 + 1810*xbar^244 + 2842*xbar^243 + 3283*xbar^242 + 2328*xbar^241 + 677*xbar^240 + 3039*xbar^239 + 373*xbar^238 + 2768*xbar^237 + 709*xbar^236 + 2128*xbar^235 + 3115*xbar^234 + 1417*xbar^233 + 877*xbar^232 + 628*xbar^231 + 2714*xbar^230 + 2602*xbar^229 + 94*xbar^228 + 3223*xbar^227 + 746*xbar^226 + 3144*xbar^225 + 1513*xbar^224 + 1023*xbar^223 + 95*xbar^222 + 307*xbar^221 + 395*xbar^220 + 2616*xbar^219 + 2386*xbar^218 + 261*xbar^217 + 389*xbar^216 + 191*xbar^215 + 1390*xbar^214 + 2765*xbar^213 + 2463*xbar^212 + 2357*xbar^211 + 2120*xbar^210 + 2595*xbar^209 + 564*xbar^208 + 3103*xbar^207 + 1574*xbar^206 + 2727*xbar^205 + 836*xbar^204 + 606*xbar^203 + 780*xbar^202 + 2186*xbar^201 + 1048*xbar^200 + 1313*xbar^199 + 2225*xbar^198 + 1521*xbar^197 + 1125*xbar^196 + 2491*xbar^195 + 137*xbar^194 + 2177*xbar^19

In [4]:
# 3. Centered Binomial Distribution (CBD) Sampling
#
# Used to sample the secret and error polynomials in KYBER

# Sample polynomial coefficients from centered binomial distribution (CBD)
def sample_binomial(eta):
    """Return a polynomial in QR whose n coefficients are drawn from CBD(eta).

    Each coefficient is a - b, where a and b each count the 1s among eta
    random bits, so its integer value lies in [-eta, eta]. Negative values are
    stored modulo q (e.g. -1 appears as q - 1 = 3328).

    NOTE(review): randint is presumably Sage's global, non-cryptographic PRNG
    helper -- fine for this demonstration, not for generating real keys.
    """
    def coin_flips():
        # Number of 1s among eta independent random bits.
        return sum(randint(0, 1) for _ in range(eta))

    # Python evaluates the left operand first, so a is drawn before b.
    return QR([Zq(coin_flips() - coin_flips()) for _ in range(n)])

# Test sampling
test_sampling = sample_binomial(eta1)
print("Sampled polynomial coefficients:", test_sampling.list())

# CBD(eta1) values lie in [-eta1, eta1]. Lift each stored residue to its
# centered representative in (-q/2, q/2] and check that bound.
centered_coefficients = [
    int(c) if int(c) <= q // 2 else int(c) - q for c in test_sampling.list()
]
assert all(abs(c) <= eta1 for c in centered_coefficients), "CBD sample outside [-eta1, eta1]"

Sampled polynomial coefficients: [0, 0, 3328, 0, 1, 1, 1, 2, 0, 1, 3328, 1, 3327, 2, 0, 1, 1, 0, 0, 1, 0, 3327, 3328, 3328, 0, 1, 1, 3328, 3328, 3328, 2, 2, 1, 1, 2, 0, 2, 0, 1, 3328, 1, 3328, 3327, 0, 0, 0, 3327, 3327, 3328, 0, 1, 0, 3326, 1, 1, 1, 1, 3328, 0, 2, 3328, 1, 3327, 3328, 3328, 3327, 2, 3328, 3328, 0, 3327, 0, 3328, 0, 1, 3326, 2, 1, 0, 1, 0, 0, 3328, 2, 0, 3328, 3328, 0, 1, 1, 3328, 0, 1, 0, 0, 0, 0, 1, 1, 0, 3328, 2, 3326, 0, 1, 1, 3, 3328, 1, 3327, 1, 2, 1, 3328, 1, 3328, 0, 3327, 0, 2, 0, 1, 2, 0, 3328, 0, 1, 0, 0, 1, 2, 3327, 0, 3327, 1, 0, 3328, 3328, 0, 3328, 0, 3, 1, 2, 0, 0, 0, 3327, 3327, 3328, 3328, 2, 0, 0, 1, 1, 3328, 3328, 2, 2, 3, 2, 3327, 3328, 3328, 3, 1, 0, 0, 3328, 3328, 0, 1, 3327, 0, 3328, 3328, 0, 1, 0, 2, 0, 1, 0, 3328, 3328, 3327, 3328, 1, 0, 3327, 0, 0, 3326, 2, 2, 0, 2, 1, 2, 1, 0, 3328, 1, 3328, 0, 3328, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 3328, 0, 3, 3328, 3328, 3327, 0, 0, 0, 1, 1, 3327, 1, 1, 2, 1, 3328, 0, 1, 3328, 0, 1, 1, 2, 1, 0, 3327, 0, 1, 332

In [5]:
# 4. Message Encoding/Decoding
#
# - Encode bits to polynomial using q/2 scaling
# - Decode using threshold rounding

# Encode 256-bit message to polynomial
def encode_message(bits):
    """Encode a list of message bits as a polynomial in QR (Kyber's Decompress_q(m, 1)).

    Bit b at position i becomes the coefficient b * (q + 1) // 2 of xbar^i:
    0 stays at 0 and 1 moves to round(q/2) = 1665, the point farthest from 0
    modulo q = 3329. Positions beyond len(bits) (up to n) are zero.

    Args:
        bits: iterable of at most n values, each 0 or 1.

    Returns:
        An element of QR.

    Raises:
        ValueError: if more than n bits are given (the extra coefficients would
            be folded back by x^n = -1 and corrupt the message), or if any
            entry is not 0 or 1.
    """
    bits = list(bits)
    if len(bits) > n:
        raise ValueError(f"message has {len(bits)} bits; at most n = {n} fit in one polynomial")
    if any(b not in (0, 1) for b in bits):
        raise ValueError("message bits must be 0 or 1")
    # (q + 1) // 2 = round(q/2) as in the Kyber spec/reference implementation;
    # q // 2 = 1664 would sit half a step below the true midpoint q/2 = 1664.5.
    scale = (q + 1) // 2
    return QR([Zq(b * scale) for b in bits])

# Decode polynomial to 256-bit message
def decode_message(polynomial):
    """Decode a polynomial in QR back to message bits (Kyber's Compress_q(x, 1)).

    Each coefficient c, taken in [0, q), becomes round(2c / q) mod 2. That is 1
    when c is closer to q/2 than to 0 or q, and 0 otherwise. For q = 3329 the
    1-interval is [833, 2496], symmetric about q/2 = 1664.5. The previous
    threshold centered on q // 2 = 1664 gave the asymmetric [833, 2495].

    Args:
        polynomial: an element of QR.

    Returns:
        A list of Python ints (0/1), one per entry of polynomial.list().
    """
    bits = []
    for coefficient in polynomial.list():
        c = Integer(coefficient)
        # Exact integer form of round(2c/q) mod 2 for odd q, matching the
        # reference implementation: floor((2c + floor(q/2)) / q) mod 2.
        bits.append(int(((2 * c + q // 2) // q) % 2))
    return bits

# Test encoding/decoding
test_bits = [0, 1, 0, 1, 1] + [0] * (n - 5)
test_polynomial = encode_message(test_bits)
decoded_bits = decode_message(test_polynomial)
print("Original bits:", test_bits)
print("Decoded bits:", decoded_bits)

# With no noise added, decoding must invert encoding exactly.
assert decoded_bits == test_bits, "encode/decode round trip failed"

Original bits: [0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Decoded bits: [0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [6]:
# 5. Key Generation
#
# Generates:
# - Public key: (A, t)
# - Secret key: s

def keygen():
    """Generate a key pair for the simplified KYBER PKE.

    Draws a uniformly random k x k matrix A over QR, then a secret column s and
    an error column e (each k x 1, coefficients from CBD(eta1)). Sets t = A*s + e.

    Returns:
        ((A, t), s): the public key (A, t) and the secret key s.
    """
    def noise_column(eta):
        # k x 1 column of independent CBD(eta) polynomials.
        return matrix(QR, k, 1, [sample_binomial(eta) for _ in range(k)])

    uniform_entries = [QR.random_element() for _ in range(k * k)]
    A = matrix(QR, k, k, uniform_entries)

    # s must be drawn before e to keep the random-draw order.
    secret = noise_column(eta1)
    error = noise_column(eta1)

    public_key = (A, A * secret + error)
    return public_key, secret

# Generate keys
pk, sk = keygen()
# Unpack public key (A, t) and secret key s in one step for inspection below.
(A, t), s = pk, sk

print("Public key matrix A dimensions:", A.dimensions())
print("Secret key s dimensions:", s.dimensions())
print("First element of t:", t[0, 0].list(), "...")

Public key matrix A dimensions: (2, 2)
Secret key s dimensions: (2, 1)
First element of t: [1569, 1172, 2160, 1297, 1179, 1710, 2497, 1734, 3271, 1335, 3023, 2358, 369, 949, 653, 325, 1801, 1959, 2363, 84, 1751, 3317, 3201, 2251, 1989, 3017, 1871, 3003, 293, 2032, 3002, 428, 3257, 3216, 929, 1775, 1419, 1891, 3219, 1486, 915, 2232, 241, 2791, 791, 2687, 1583, 1237, 465, 1905, 423, 1828, 850, 1560, 1584, 2818, 2286, 2026, 417, 2933, 158, 1847, 1795, 2664, 2172, 188, 3033, 474, 352, 2842, 949, 2122, 1183, 174, 3205, 3085, 1270, 327, 2236, 1239, 2103, 880, 213, 99, 1169, 46, 1612, 1106, 488, 936, 1902, 1731, 2228, 1809, 273, 2522, 1174, 90, 394, 438, 1087, 1091, 1115, 2191, 2491, 502, 1784, 1859, 334, 1148, 266, 1548, 670, 2467, 45, 2093, 2797, 2731, 3321, 2744, 1955, 2820, 3181, 1759, 1947, 1623, 1550, 3112, 3220, 22, 1862, 1305, 1224, 1543, 297, 3148, 2905, 2011, 2422, 1265, 2073, 2819, 2209, 189, 2061, 2453, 2879, 1647, 1387, 1908, 257, 1311, 2051, 2121, 2183, 107, 678, 598, 887, 1151,

In [7]:
# 6. Encryption Process
#
# Encrypts message m into ciphertext (u, v)

def encrypt(pk, m_bits):
    """Encrypt message bits under public key pk (Kyber CPAPKE.Enc, no compression).

    Samples r <- CBD(eta1)^k, e1 <- CBD(eta2)^k, e2 <- CBD(eta2) and computes
        u = A^T r + e1
        v = t^T r + e2 + Encode(m)

    Args:
        pk: tuple (A, t) as returned by keygen(): A is a k x k matrix over QR,
            t is a k x 1 matrix over QR.
        m_bits: message bits, as accepted by encode_message.

    Returns:
        Ciphertext (u, v): u is a k x 1 matrix over QR, v is an element of QR.
    """
    A, t = pk
    m = encode_message(m_bits)

    # The ephemeral secret r is drawn with eta1 (like s in keygen); only the
    # errors e1 and e2 use eta2. The two differ for KYBER512 (eta1=3, eta2=2).
    r = matrix(QR, k, 1, [sample_binomial(eta1) for _ in range(k)])
    e1 = matrix(QR, k, 1, [sample_binomial(eta2) for _ in range(k)])
    e2 = sample_binomial(eta2)

    u = A.transpose() * r + e1
    v = (t.transpose() * r)[0, 0] + e2 + m
    return (u, v)

# Test encryption
# Encrypt a random n-bit message under the key pair generated above.
message = [randint(0, 1) for _bit_index in range(n)]
ct = encrypt(pk, message)
u = ct[0]
v = ct[1]

u_shape = u.dimensions()
v_coefficients = v.list()
print("Ciphertext component u dimensions:", u_shape)
print("All coefficients of v:", v_coefficients)

Ciphertext component u dimensions: (2, 1)
All coefficients of v: [349, 1370, 2692, 2708, 1452, 463, 1208, 282, 220, 1562, 2239, 3291, 1556, 2786, 331, 2794, 2543, 1424, 848, 1396, 732, 86, 3198, 1060, 3100, 1198, 1656, 2011, 1461, 564, 1049, 857, 2940, 2541, 154, 1616, 3060, 781, 1734, 690, 571, 253, 2728, 917, 2276, 3086, 2071, 147, 2622, 2341, 864, 3065, 1222, 939, 954, 3036, 1884, 1410, 460, 740, 874, 2689, 2874, 1854, 84, 1037, 2282, 647, 1659, 2961, 1684, 862, 547, 1259, 1009, 1771, 132, 587, 3124, 1851, 1552, 2036, 1129, 1346, 1354, 1810, 1978, 1280, 2910, 2964, 1689, 2869, 1587, 2781, 434, 391, 3044, 86, 792, 1630, 565, 1211, 1687, 1578, 930, 432, 1452, 1751, 65, 2581, 1455, 2896, 2987, 2644, 3205, 1173, 2199, 3039, 2607, 2698, 2311, 90, 2128, 2741, 2077, 1269, 3137, 1071, 2407, 1257, 502, 2494, 1434, 988, 2174, 1015, 3041, 731, 806, 3220, 955, 755, 923, 2474, 1371, 2001, 1568, 196, 1222, 2921, 850, 1047, 3291, 1589, 1229, 3098, 1161, 2912, 1066, 2479, 419, 1970, 1411, 902, 2463

In [8]:
# 7. Decryption Process
#
# Recovers message from ciphertext using secret key

def decrypt(sk, ct):
    """Recover message bits from ciphertext ct = (u, v) with secret key sk = s.

    v - s^T u expands to Encode(m) + e^T r + e2 - s^T e1, i.e. the encoded
    message plus a small noise term. decode_message rounds each coefficient
    back to a bit.
    """
    u, v = ct
    masked_term = (sk.transpose() * u)[0, 0]
    noisy_message = v - masked_term
    return decode_message(noisy_message)

# Test decryption
decrypted_bits = decrypt(sk, ct)

# Count positions where the recovered bit differs from the encrypted one.
errors = sum(
    original != recovered
    for original, recovered in zip(message, decrypted_bits)
)
print("Decryption errors:", errors)
print("Sample original bits:", message)
print("Sample decrypted bits:", decrypted_bits)

Decryption errors: 0
Sample original bits: [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0]
Sample decrypted bits: [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0,

In [9]:
# 8. Full Workflow Example
#
# Putting all components together

# Generate new key pair
pk, sk = keygen()

# A fresh random n-bit message for this end-to-end run.
original_message = [randint(0, 1) for _position in range(n)]
print("Original message:", original_message)

# Round trip: encrypt with the public key, decrypt with the secret key.
ciphertext = encrypt(pk, original_message)
received_message = decrypt(sk, ciphertext)
print("Received message:", received_message)

# Successful iff no position disagrees.
has_mismatch = any(o != r for o, r in zip(original_message, received_message))
correct = not has_mismatch
print("\nDecryption successful?", correct)

Original message: [0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1]
Received message: [0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0,