In [59]:
# CRYSTALS-KYBER PKE Step-by-Step Implementation - Advanced
#
# Thesis: CRYSTALS-KYBER: Post-Quantum Secure Encryption
# By Stefanos Anastasiades
#
# School of Informatics
# Aristotle University of Thessaloniki
# Supervisor: Konstantinos Draziotis
# February 2025

In [60]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512, shake_128, shake_256

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# 1. Parameter Setup
#
# KYBER512 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 2
#
# KYBER768 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 3
#
# KYBER1024 Parameters:
# - Polynomial Degree n = 256
# - Modulus q = 3329
# - Module Rank k = 4

# Active parameter set: KYBER512 (n = 256, q = 3329, k = 2).
# q_bytes is the width, in bytes, of each coefficient when a polynomial is
# serialized for hashing. q = 3329 < 2^12, so 2 bytes would suffice; 16 is
# kept as-is because changing it changes every derived hash value.
n, q, q_bytes, k = 256, 3329, 16, 2

In [61]:
# 2. Polynomial Ring Setup
#
# We work in the polynomial quotient ring R_q = Z_q[x]/(x^n + 1)

# Z_q[x] over the prime field GF(q), then the quotient R_q = Z_q[x]/(x^n + 1).
ZQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = ZQ.gen()
# `**` is used so this line also works without the Sage preparser (where `^` is XOR).
RQ = ZQ.quotient(x ** n + 1)

# Sample from RQ, the ring defined just above (the name `QR` does not exist
# on a fresh kernel).
print("Sample polynomial:", RQ.random_element())

Sample polynomial: 135*xbar^255 + 2592*xbar^254 + 434*xbar^253 + 2253*xbar^252 + 1018*xbar^251 + 543*xbar^250 + 619*xbar^249 + 1676*xbar^248 + 533*xbar^247 + 1801*xbar^246 + 581*xbar^245 + 947*xbar^244 + 88*xbar^243 + 2928*xbar^242 + 901*xbar^241 + 649*xbar^240 + 3151*xbar^239 + 1022*xbar^238 + 652*xbar^237 + 2018*xbar^236 + 351*xbar^235 + 1304*xbar^234 + 2296*xbar^233 + 998*xbar^232 + 1782*xbar^231 + 1410*xbar^230 + 1988*xbar^229 + 2197*xbar^228 + 2831*xbar^227 + 143*xbar^226 + 2710*xbar^225 + 2802*xbar^224 + 1294*xbar^223 + 460*xbar^222 + 2556*xbar^221 + 2686*xbar^220 + 2506*xbar^219 + 2092*xbar^218 + 2035*xbar^217 + 264*xbar^216 + 336*xbar^215 + 2070*xbar^214 + 2411*xbar^213 + 781*xbar^212 + 1378*xbar^211 + 808*xbar^210 + 1110*xbar^209 + 2947*xbar^208 + 3326*xbar^207 + 2019*xbar^206 + 1473*xbar^205 + 2205*xbar^204 + 642*xbar^203 + 1619*xbar^202 + 587*xbar^201 + 542*xbar^200 + 2627*xbar^199 + 819*xbar^198 + 2688*xbar^197 + 2423*xbar^196 + 884*xbar^195 + 1868*xbar^194 + 885*xbar^193 +

In [62]:
# 3. Supporting Functions
#
# Helper routines for sampling, serialization, and message (de)compression used by the PKE and KEM below

def generate_random_list(size, apply_cbd=False):
    """Return `size` uniform residues drawn with randrange(q).

    When `apply_cbd` is true, the residues are passed through FCBD before
    being returned (mapping each to a small value in -2..2).
    """
    samples = []
    for _ in range(size):
        samples.append(randrange(q))
    if apply_cbd:
        return FCBD(samples)
    return samples

def FCBD(values: list):
    """Map each value to (value mod 5) - 2, i.e. into the range -2..2.

    This is the notebook's simplified stand-in for a centered binomial
    sampler; it is a deterministic reduction, not a true CBD.
    """
    return [(value % 5) - 2 for value in values]

def generate_uniform_polynomial(size):
    """Return an element of RQ whose first `size` coefficients are uniform in [0, q)."""
    coefficients = generate_random_list(size)
    return RQ(coefficients)

def generate_cbd_polynomial(size) -> polynomial:
    """Return an element of RQ with `size` small (-2..2) FCBD-mapped coefficients."""
    small_coefficients = generate_random_list(size, apply_cbd=True)
    return RQ(small_coefficients)

def create_cbd_list(size) -> list:
    """Return a plain list of `size` small (-2..2) FCBD-mapped values."""
    small_values = generate_random_list(size, apply_cbd=True)
    return small_values

def calculate_bytes_needed(bit_count: int) -> int:
    """Return the number of whole bytes needed to hold `bit_count` bits (ceil(bit_count / 8))."""
    # Floor division after adding 7 rounds up to the next multiple of 8.
    return (bit_count + 7) // 8

def generate_random_integer(bit_count: int) -> int:
    """Return a random non-negative Sage Integer below 2**bit_count.

    Reads ceil(bit_count / 8) bytes from os.urandom, interprets them
    big-endian, and masks the result down to exactly `bit_count` bits.
    """
    byte_count = calculate_bytes_needed(bit_count)
    random_bytes = urandom(byte_count)
    result = Integer(int.from_bytes(random_bytes, 'big'))
    # Mask with the requested width, not the global n, so the function honors
    # its own argument for any bit_count.
    result &= (2 ** bit_count) - 1
    return result

def convert_polynomial_to_bytes(poly: polynomial) -> bytes:
    """Serialize a ring element, or every entry of a matrix of ring elements, to bytes.

    Each coefficient is written as a `q_bytes`-wide big-endian integer, in
    increasing degree order. For a matrix argument (such as the k x 1 vector t),
    entries are serialized in row-major order and concatenated, so the whole
    matrix contributes to any hash computed over the result.

    Previously only the first nonzero matrix entry was serialized (via
    `.coefficients()[0]`), so H(t) ignored t[1..k-1].
    """
    # Sage matrices expose nrows(); ring elements do not.
    entries = poly.list() if hasattr(poly, 'nrows') else [poly]
    byte_sequence = bytearray()
    for entry in entries:
        for coeff in entry.list():
            byte_sequence += int(coeff).to_bytes(q_bytes, 'big')
    return bytes(byte_sequence)

def convert_bytes_to_bit_list(byte_data: bytes) -> list:
    """Expand bytes into a flat list of bits, least significant bit of each byte first."""
    return [(octet >> shift) & 1 for octet in byte_data for shift in range(8)]

def compress_polynomial(poly: polynomial):
    """Scale a 0/1 message polynomial by ceil(q / 2).

    Despite the name, this corresponds to Kyber's Decompress_1 step: each
    message bit b becomes b * ceil(q/2), placing 1-bits near q/2.
    """
    # For integer q, (q + 1) // 2 == ceil(q / 2).
    half_q_rounded_up = (q + 1) // 2
    return poly * half_q_rounded_up

def decompress_polynomial(poly: list) -> list:
    """Round each coefficient back to a message bit.

    Corresponds to Kyber's Compress_1 step: a coefficient lifted to [0, q)
    maps to 1 when it lies strictly between q/4 and 3q/4 (i.e. closer to q/2
    than to 0), otherwise to 0.

    Args:
        poly: iterable of coefficients (e.g. the `.list()` of a ring element).
    Returns:
        list of 0/1 ints, in the same order as the input coefficients.
    """
    threshold_lower = q / 4
    threshold_upper = 3 * (q / 4)
    # Integer(...) lifts a GF(q) element to an integer so it can be compared
    # against the rational thresholds.
    return [1 if threshold_upper > Integer(i) > threshold_lower else 0 for i in poly]

In [63]:
def CPAPKE_KeyGen() -> (matrix, matrix, matrix):
    """Generate a toy Kyber CPA-PKE key pair, printing every intermediate value.

    Samples a uniform k x k matrix A over RQ and two k x 1 vectors s and e of
    small (FCBD-mapped) polynomials, then forms t = A * s + e.

    Returns:
        (A, t, s): A and t form the public key; s is the secret key.
    """
    print('===== Generating KPKE Key =====')

    # Rows outermost, columns innermost: the same sampling order as a nested i/j loop.
    A_matrix = matrix([[generate_uniform_polynomial(n) for _col in range(k)]
                       for _row in range(k)])
    print('A:', A_matrix)

    # Secret vector s: k rows, one small polynomial each.
    s_vector = matrix([[generate_cbd_polynomial(n)] for _row in range(k)])
    print('s:', s_vector)

    # Error vector e: same shape and distribution as s.
    e_vector = matrix([[generate_cbd_polynomial(n)] for _row in range(k)])
    print('e:', e_vector)

    t_vector = A_matrix * s_vector + e_vector
    print('t:', t_vector)

    return A_matrix, t_vector, s_vector

In [64]:
def _derive_noise_polynomial(randomness, nonce):
    """Build one RQ noise polynomial from the shared randomness list and a nonce.

    Adds `nonce` to each of the first n randomness values and maps the sums
    through FCBD. Distinct nonces therefore yield distinct but deterministic
    polynomials. This is the notebook's simplified sampler, not the
    PRF-based sampler of the Kyber specification.
    """
    return RQ(FCBD([randomness[j] + nonce for j in range(n)]))


def CPAPKE_Enc(A: matrix, t: matrix, message: int, randomness: list) -> (matrix, matrix):
    """Encrypt an at-most-n-bit message under the public key (A, t).

    Message bit i (least significant first) becomes coefficient i of a 0/1
    polynomial, which is scaled by compress_polynomial. The vectors rr and e1
    and the polynomial e2 are derived deterministically from `randomness`
    with the nonce schedule 0..k-1 (rr), k..2k-1 (e1), 2k (e2), so that the
    same randomness always reproduces the same ciphertext.

    Args:
        A: k x k public matrix.
        t: k x 1 public vector.
        message: non-negative integer (Sage Integer or Python int) of at most n bits.
        randomness: list of at least n integers used as the encryption seed.
    Returns:
        (u, v) where u = A^T rr + e1 (k x 1) and v = t^T rr + e2 + m_scaled (1 x 1).
    Raises:
        ValueError: if the message is negative or too long, or randomness is too short.
    """
    print('===== KPKE Encryption =====')

    # Integer(...) is a no-op for Sage Integers and lets plain ints expose .bits().
    message = Integer(message)
    if message < 0:
        raise ValueError('Message must be non-negative!')
    message_bits = message.bits()  # least significant bit first
    if len(message_bits) > n:
        raise ValueError('Message has more bits than allowed!')
    if len(randomness) < n:
        raise ValueError('Randomness must contain at least n values!')

    print('Message bits:', message_bits)

    # Pad to exactly n coefficients so bit i sits at x^i.
    padded_message = RQ(message_bits + [0] * (n - len(message_bits)))
    print('Padded message polynomial:', padded_message)

    compressed_message = compress_polynomial(padded_message)
    print('Compressed message:', compressed_message)

    rr_matrix = matrix([[_derive_noise_polynomial(randomness, nonce)] for nonce in range(k)])
    print('rr:', rr_matrix)

    error1 = matrix([[_derive_noise_polynomial(randomness, k + i)] for i in range(k)])
    print('e1:', error1)

    error2 = _derive_noise_polynomial(randomness, 2 * k)
    print('e2:', error2)

    u_component = A.transpose() * rr_matrix + error1
    v_component = t.transpose() * rr_matrix + error2 + compressed_message

    print('u:', u_component)
    print('v:', v_component)

    return u_component, v_component

In [65]:
def CPAPKE_Dec(cipher_u: matrix, cipher_v: matrix, secret_s: matrix) -> int:
    """Decrypt a ciphertext (u, v) with the secret vector s.

    Computes v - s^T u, whose single entry is the scaled message plus small
    noise, rounds every coefficient back to a bit with decompress_polynomial,
    and reassembles the integer with coefficient i as bit i (least
    significant first), mirroring the encoding in CPAPKE_Enc.

    Returns:
        The recovered message as a Python int.
    """
    print('===== KPKE Decryption =====')

    # v - s^T u is a 1x1 matrix; index its entry directly. Matrix.coefficients()
    # lists only nonzero entries, so `.coefficients()[0]` raised IndexError
    # whenever the noisy polynomial was exactly zero.
    noisy_message = (cipher_v - secret_s.transpose() * cipher_u)[0, 0]

    print('Noisy recovered message:', noisy_message)

    # Coefficients in increasing degree order -> bits, least significant first.
    recovered_bits = decompress_polynomial(noisy_message.list())

    print('Decompressed message:', recovered_bits)

    # Coefficient i contributes 2**i.
    recovered_message = sum(bit << i for i, bit in enumerate(recovered_bits))

    return recovered_message

In [66]:
# --- Round trip of the bare CPA-secure PKE ---

# Alice: generate her key pair (public (A, t), private s)
public_A, public_t, private_s = CPAPKE_KeyGen()

# Bob: having received Alice's public key, encrypt a fresh random n-bit message,
# using n FCBD-mapped values as the encryption randomness
message = generate_random_integer(n)
random_values = create_cbd_list(n)
cipher_u, cipher_v = CPAPKE_Enc(public_A, public_t, message, random_values)

# Alice: decrypt the ciphertext (u, v) that Bob sent her
recovered_message = CPAPKE_Dec(cipher_u, cipher_v, private_s)

print("Original message (Bob):", message)
print("Recovered message (Alice):", recovered_message)

# A mismatch means accumulated noise pushed some coefficient across a rounding threshold
if not (message == recovered_message):
    raise ValueError("Mismatch detected! Decompression step may have failed. Consider increasing the prime q.")

===== Generating KPKE Key =====
A: [456*xbar^255 + 2431*xbar^254 + 1237*xbar^253 + 2384*xbar^252 + 1964*xbar^251 + 1846*xbar^250 + 2501*xbar^249 + 1522*xbar^248 + 2512*xbar^247 + 1085*xbar^246 + 48*xbar^245 + 1639*xbar^244 + 1032*xbar^243 + 1814*xbar^242 + 1794*xbar^241 + 3191*xbar^240 + 1554*xbar^239 + 1367*xbar^238 + 113*xbar^237 + 1007*xbar^236 + 2726*xbar^235 + 88*xbar^234 + 1564*xbar^233 + 3117*xbar^232 + 321*xbar^231 + 1221*xbar^230 + 2130*xbar^229 + 2990*xbar^228 + 3031*xbar^227 + 1310*xbar^226 + 3106*xbar^225 + 1277*xbar^224 + 1583*xbar^223 + 1117*xbar^222 + 3244*xbar^221 + 388*xbar^220 + 1996*xbar^219 + 792*xbar^218 + 911*xbar^217 + 2439*xbar^216 + 720*xbar^215 + 36*xbar^214 + 2417*xbar^213 + 2115*xbar^212 + 2741*xbar^211 + 1265*xbar^210 + 2945*xbar^209 + 2443*xbar^208 + 1947*xbar^207 + 991*xbar^206 + 1354*xbar^205 + 375*xbar^204 + 1000*xbar^203 + 3020*xbar^202 + 763*xbar^201 + 1625*xbar^200 + 146*xbar^199 + 1193*xbar^198 + 2788*xbar^197 + 613*xbar^196 + 2973*xbar^195 + 1669*x

In [67]:
def CCAKEM_KeyGen() -> ((polynomial, polynomial), (polynomial, polynomial, int)):
    """Generate a KEM key pair on top of CPAPKE_KeyGen.

    Returns:
        public_key:  (A, t)
        private_key: (s, public_key, SHA3-256 of serialized t, random n-bit value)
    The fourth private-key element is presumably the implicit-rejection value z
    of the Kyber KEM; it is generated here but only used if decapsulation
    consumes it.
    """
    print('===== MLKEM Key Generation =====')

    # Drawn before the PKE keys, as in the original ordering.
    secret_random = generate_random_integer(n)

    matrix_A, vector_t, secret_s = CPAPKE_KeyGen()
    public_key = (matrix_A, vector_t)

    serialized_t = convert_polynomial_to_bytes(vector_t)
    hash_t = sha3_256(serialized_t).digest()
    print('SHA3-256 Hash of t:', hash_t.hex(), '\n')

    return public_key, (secret_s, public_key, hash_t, secret_random)

In [68]:
def CCAKEM_Encaps(public_key: (polynomial, polynomial)) -> (bytes, bytes):
    """Encapsulate a fresh shared secret against `public_key` = (A, t).

    Steps: pick a random n-bit message m, compute (K, r) = SHA3-512(m || H(t))
    with H = SHA3-256, expand r into a bit list and map it through FCBD, then
    encrypt m with that randomness.

    Returns:
        (K, ciphertext) where K is the first ceil(n/8) bytes of the SHA3-512
        digest and ciphertext is the (u, v) pair from CPAPKE_Enc.
    """
    print('===== MLKEM Encapsulation =====')

    random_message = generate_random_integer(n)
    print('Random message:', random_message)

    matrix_A, vector_t = public_key

    hash_t = sha3_256(convert_polynomial_to_bytes(vector_t)).digest()
    print('SHA3-256 Hash of t:', hash_t.hex())

    split_at = calculate_bytes_needed(n)
    message_bytes = int(random_message).to_bytes(split_at, 'big')
    digest = sha3_512(message_bytes + hash_t).digest()
    key_k = digest[:split_at]
    seed_r = digest[split_at:]

    randomness_r = FCBD(convert_bytes_to_bit_list(seed_r))
    print('Randomness r:', randomness_r, '\n')

    return key_k, CPAPKE_Enc(matrix_A, vector_t, random_message, randomness_r)

In [69]:
def CCAKEM_Decaps(ciphertext: (polynomial, polynomial), secret_key: (polynomial, (polynomial, polynomial), bytes, int)) -> bytes:
    """Recover the shared secret from `ciphertext` using `secret_key`.

    Fujisaki-Okamoto style decapsulation:
      1. decrypt to m' with the PKE secret s;
      2. derive (K', r') = SHA3-512(m' || H(t)) exactly as encapsulation does;
      3. re-encrypt m' with r' and compare against the received (u, v).
    If the re-encryption matches, K' is returned. Otherwise an implicit-rejection
    key SHAKE256(z || c) is returned, where z is the fourth secret-key element
    and c is the serialized ciphertext. A tampered or mis-decrypted ciphertext
    thus yields an unrelated key.

    Returns:
        ceil(n/8) bytes of shared secret.
    """
    print('===== MLKEM Decapsulation =====')

    sk_s, enc_key, hash_t, secret_z = secret_key
    matrix_A, vector_t = enc_key
    cipher_u, cipher_v = ciphertext
    key_len = calculate_bytes_needed(n)

    recovered_message = CPAPKE_Dec(cipher_u, cipher_v, sk_s)
    print("Recovered message:", recovered_message)

    # Same derivation as encapsulation: (K', r') = SHA3-512(m' || H(t)).
    recomputed_kr = sha3_512(int(recovered_message).to_bytes(key_len, 'big') + hash_t).digest()
    derived_key, derived_r = recomputed_kr[:key_len], recomputed_kr[key_len:]

    derived_r = FCBD(convert_bytes_to_bit_list(derived_r))
    print("Derived randomness r':", derived_r)

    # CPAPKE_Enc is deterministic in its randomness argument, so an honest
    # ciphertext is reproduced exactly.
    reconstructed_u, reconstructed_v = CPAPKE_Enc(matrix_A, vector_t, Integer(recovered_message), derived_r)

    print("Reconstructed u':", reconstructed_u)
    print("Reconstructed v':", reconstructed_v, '\n')

    if reconstructed_u == cipher_u and reconstructed_v == cipher_v:
        return derived_key

    # Implicit rejection: serialize every coefficient of u and v and bind them to z.
    print('Re-encryption mismatch: returning implicit-rejection key')
    ciphertext_bytes = b''.join(
        int(coeff).to_bytes(q_bytes, 'big')
        for component in (cipher_u, cipher_v)
        for entry in component.list()
        for coeff in entry.list()
    )
    z_bytes = int(secret_z).to_bytes(key_len, 'big')
    return shake_256(z_bytes + ciphertext_bytes).digest(key_len)

In [70]:
# --- Key encapsulation round trip ---

# Alice: create her KEM key pair
public_key_A, secret_key_A = CCAKEM_KeyGen()

# Bob: encapsulate against Alice's public key, obtaining his copy of the
# shared secret (ssB) and the ciphertext (c) to send back
shared_secret_B, ciphertext = CCAKEM_Encaps(public_key_A)

# Alice: decapsulate c with her secret key to obtain her copy (ssA)
shared_secret_A = CCAKEM_Decaps(ciphertext, secret_key_A)

print("Bob's Shared Secret:", shared_secret_B.hex())
print("Alice's Shared Secret:", shared_secret_A.hex())
print("Shared secrets match?", shared_secret_A == shared_secret_B)

if not (shared_secret_A == shared_secret_B):
    raise ValueError("The shared keys do not match!")


===== MLKEM Key Generation =====
===== Generating KPKE Key =====
A: [                               267*xbar^255 + 1620*xbar^254 + 3264*xbar^253 + 2120*xbar^252 + 832*xbar^251 + 678*xbar^250 + 803*xbar^249 + 2547*xbar^248 + 1322*xbar^247 + 2366*xbar^246 + 2044*xbar^245 + 2574*xbar^244 + 2181*xbar^243 + 1031*xbar^242 + 672*xbar^241 + 2255*xbar^240 + 1933*xbar^239 + 1668*xbar^238 + 2335*xbar^237 + 1886*xbar^236 + 351*xbar^235 + 577*xbar^234 + 1792*xbar^233 + 236*xbar^232 + 1353*xbar^231 + 3129*xbar^230 + 2656*xbar^229 + 1721*xbar^228 + 980*xbar^227 + 1234*xbar^226 + 338*xbar^225 + 677*xbar^224 + 248*xbar^223 + 1893*xbar^222 + 1420*xbar^221 + 2335*xbar^220 + 2977*xbar^219 + 207*xbar^218 + 2933*xbar^217 + 2504*xbar^216 + 1522*xbar^215 + 552*xbar^214 + 2689*xbar^213 + 1628*xbar^212 + 1554*xbar^211 + 1003*xbar^210 + 1940*xbar^209 + 1240*xbar^208 + 1673*xbar^207 + 1053*xbar^206 + 1756*xbar^205 + 105*xbar^204 + 3064*xbar^203 + 1763*xbar^202 + 2574*xbar^201 + 1381*xbar^200 + 2345*xbar^199 + 276