# Estruturas Criptográficas - TP2-2
### PG53721 - Carlos Machado
### PG54249 - Tiago Oliveira
### Enunciado - Implementação KEM FIPS203

Neste problema, pretende-se implementar um protótipo parametrizado do standard de um **Key Encapsulation Mechanism (KEM)**, de acordo com as variantes definidas na norma **FIPS 203**: ML-KEM-512, ML-KEM-768 e ML-KEM-1024 (categorias de segurança NIST 1, 3 e 5, respetivamente).

In [21]:
import hashlib
import hmac
import os
from functools import reduce

## Inicialização

In [22]:
def select_method(method_number):
    """Return the (DU, DV) ciphertext-compression parameters for a menu choice.

    Choices 1, 2 and 3 correspond to ML-KEM-512, ML-KEM-768 and ML-KEM-1024;
    any other value yields None.
    """
    shared_512_768 = (10, 4)
    params = dict.fromkeys((1, 2), shared_512_768)
    params[3] = (11, 5)
    return params.get(method_number)

N = 256
Q = 3329
# NOTE(review): k_value and ETA1 are fixed to the ML-KEM-768 values for every
# menu choice (the author flagged them as wrong). FIPS 203 Table 2 uses
# (k, eta1) = (2, 3), (3, 2), (4, 2) for ML-KEM-512/768/1024. Making them
# per-variant requires every byte offset in this file to be expressed in
# terms of k_value first.
k_value = 3
ETA1 = 2
ETA2 = 2

print("Escolha um dos seguintes métodos:")
print("1. ML-KEM-512")
print("2. ML-KEM-768")
print("3. ML-KEM-1024")

method_number = int(input("Escreva o número correspondente ao método: "))

_selected = select_method(method_number)
if _selected is None:
    # select_method returns None for unknown choices; fail with a clear message
    # instead of "cannot unpack non-iterable NoneType object".
    raise ValueError(f"Método inválido: {method_number} (esperado 1, 2 ou 3)")
DU, DV = _selected

Choose from the following numbers:
1. ML-KEM-512
2. ML-KEM-768
3. ML-KEM-1024


In [23]:
# Show the selected compression parameters (DU, then DV), one per line.
print(DU, DV, sep="\n")

10
4


## Algoritmos auxiliares

### Conversões e compressões

In [24]:
def bits_to_bytes(b):
    """Pack a little-endian list of bits into bytes (FIPS 203 BitsToBytes).

    Bit ``b[i]`` becomes bit ``i % 8`` of output byte ``i // 8``. ``len(b)``
    is expected to be a multiple of 8; leftover bits raise IndexError.
    """
    out = bytearray(len(b) // 8)
    for idx, bit in enumerate(b):
        out[idx // 8] += bit << (idx % 8)
    return bytes(out)

def bytes_to_bits(B):
    """Unpack bytes into a little-endian list of bits (FIPS 203 BytesToBits).

    Element ``8 * i + j`` of the result is bit ``j`` of ``B[i]``.
    """
    return [(byte >> j) & 1 for byte in B for j in range(8)]

def byte_encode(d, F):
    """Serialize 256 integers into ``d`` bits each (FIPS 203 ByteEncode_d).

    Coefficient ``F[i]`` contributes its ``d`` low-order bits, least
    significant first, starting at bit position ``i * d``.
    """
    bits = []
    for i in range(256):
        coeff = F[i]
        for _ in range(d):
            coeff, bit = divmod(coeff, 2)
            bits.append(bit)
    return bits_to_bytes(bits)

def byte_decode(d, B):
    """Deserialize bytes into 256 integers of ``d`` bits (FIPS 203 ByteDecode_d).

    Each coefficient is read from ``d`` consecutive bits, least significant
    first, and reduced modulo ``m``: ``m = 2**d`` for ``d < 12`` and ``m = Q``
    for ``d == 12``.
    """
    m = 2 ** d if d < 12 else Q
    b = bytes_to_bits(B)
    F = [0] * 256
    for i in range(256):
        # Reduce the whole d-bit value, not each term: for d == 12 a field can
        # hold up to 4095 >= Q, which must wrap into [0, Q).
        F[i] = sum(b[i * d + j] * (2 ** j) for j in range(d)) % m
    return F
    
def compress(d, x):
    """Map each coefficient of ``x`` from Z_Q to Z_(2^d) (FIPS 203 Compress_d).

    Computes ``round(2**d * n / Q) mod 2**d`` with integer arithmetic:
    adding ``Q // 2`` before the floor division performs the rounding.
    """
    scale = 2 ** d
    half_q = Q // 2
    return [((n * scale + half_q) // Q) % scale for n in x]

def decompress(d, x):
    """Map each value of ``x`` from Z_(2^d) back to Z_Q (FIPS 203 Decompress_d).

    Computes ``round(Q * n / 2**d)`` with integers (adding ``2**(d-1)`` before
    the floor division rounds halves up), then reduces modulo Q.
    """
    denom = 2 ** d
    bias = 2 ** (d - 1)
    return [((n * Q + bias) // denom) % Q for n in x]

### Sampling

In [25]:
def sample_ntt(B):
    """Rejection-sample a polynomial in the NTT domain (FIPS 203 SampleNTT).

    ``B`` is consumed three bytes at a time. Each triple yields two 12-bit
    candidates, and those below ``Q`` are kept until 256 coefficients have
    been collected.

    Raises:
        ValueError: if ``B`` is exhausted before 256 coefficients are
            accepted, rather than returning a partially zero polynomial.
    """
    a_nr = [0] * 256
    i = 0
    j = 0
    # Only read complete 3-byte groups: a trailing partial group would index
    # past the end of B.
    while j < 256 and i + 3 <= len(B):
        d1 = B[i] + 256 * (B[i + 1] % 16)
        d2 = (B[i + 1] // 16) + 16 * B[i + 2]

        if d1 < Q:
            a_nr[j] = d1
            j += 1

        if d2 < Q and j < 256:
            a_nr[j] = d2
            j += 1

        i += 3

    if j < 256:
        raise ValueError("sample_ntt: not enough input bytes to sample 256 coefficients")
    return a_nr

def sample_poly_cbd(B, eta):
    """Sample a polynomial from the centered binomial distribution CBD_eta.

    Implements FIPS 203 SamplePolyCBD_eta. ``B`` must provide at least
    ``64 * eta`` bytes. Coefficient ``i`` uses a ``2 * eta``-bit window: ``x``
    is the number of set bits in its first half and ``y`` in its second half.
    The coefficient is ``(x - y) mod Q``, so before reduction it lies in
    ``[-eta, eta]``.
    """
    b = bytes_to_bits(B)
    f = [0] * 256
    for i in range(256):
        base = 2 * i * eta
        # Plain bit sums (Hamming weights). Weighting bits by 2**j would give
        # uniform values in [0, 2**eta) instead of a binomial distribution.
        x = sum(b[base + j] for j in range(eta))
        y = sum(b[base + eta + j] for j in range(eta))
        f[i] = (x - y) % Q
    return f

### NTT

In [26]:
def bit_rev_7(r):
    """Reverse the zero-padded 7-bit binary representation of ``r``.

    ``r`` is written with at least 7 binary digits, the digit string is
    reversed, and the result is parsed back as an integer.
    """
    digits = format(r, "07b")
    return int(digits[::-1], 2)

# Twiddle-factor tables for the NTT, for idx in 0..127:
#   ZETA[idx]  = 17 ** bit_rev_7(idx) mod Q
#   GAMMA[idx] = 17 ** (2 * bit_rev_7(idx) + 1) mod Q
# (17 is the root of unity zeta fixed by FIPS 203.) The loop variable is
# `idx` so it is not confused with the module-level `k_value`.
ZETA = [pow(17, bit_rev_7(idx), Q) for idx in range(128)]
GAMMA = [pow(17, 2 * bit_rev_7(idx) + 1, Q) for idx in range(128)]

def ntt(f):
    """Compute the number-theoretic transform of a 256-coefficient polynomial.

    Implements FIPS 203 NTT using the precomputed ``ZETA`` table. Returns a
    new list with coefficients in ``[0, Q)``; the caller's ``f`` is not
    modified.
    """
    # Copy: the spec's "f_hat <- f" is an assignment of values, and mutating
    # the caller's list in place is a hidden side effect.
    fc = list(f)
    k = 1
    length = 128

    while length >= 2:
        for start in range(0, 256, 2 * length):
            zeta = ZETA[k]
            k += 1
            for j in range(start, start + length):
                t = (zeta * fc[j + length]) % Q
                fc[j + length] = (fc[j] - t) % Q
                fc[j] = (fc[j] + t) % Q
        length //= 2

    return fc

def ntt_inv(fc):
    """Compute the inverse number-theoretic transform (FIPS 203 NTT^-1).

    Returns a new list of 256 coefficients in ``[0, Q)``; the caller's ``fc``
    is not modified.
    """
    # Copy so the caller's list is not mutated in place.
    f = list(fc)
    k = 127
    length = 2

    while length <= 128:
        for start in range(0, 256, 2 * length):
            zeta = ZETA[k]
            k -= 1
            for j in range(start, start + length):
                t = f[j]
                f[j] = (t + f[j + length]) % Q
                f[j + length] = (zeta * (f[j + length] - t)) % Q
        length *= 2

    # 3303 is 128^-1 mod Q (128 * 3303 = 127 * Q + 1); it completes the inverse.
    return [(coeff * 3303) % Q for coeff in f]


def base_case_multiply(a0, a1, b0, b1, gamma):
    """Multiply ``a0 + a1*X`` by ``b0 + b1*X`` modulo ``X^2 - gamma``.

    Implements FIPS 203 BaseCaseMultiply. Returns ``(c0, c1)`` reduced into
    ``[0, Q)`` so callers never see unreduced products.
    """
    c0 = (a0 * b0 + a1 * b1 * gamma) % Q
    c1 = (a0 * b1 + a1 * b0) % Q

    return c0, c1

def multiply_ntt_s(fc, gc):
    """Multiply two polynomials in the NTT domain (FIPS 203 MultiplyNTTs).

    The 256 coefficients are processed as 128 consecutive pairs. Pair ``p``
    of both inputs is combined with ``base_case_multiply`` using ``GAMMA[p]``.
    """
    hc = []
    for pair in range(128):
        lo = 2 * pair
        hc.extend(base_case_multiply(fc[lo], fc[lo + 1], gc[lo], gc[lo + 1], GAMMA[pair]))
    return hc

#### Funções auxiliares

In [27]:
def XOF(rho, i, j):
    """Return 1536 bytes of SHAKE128 output for sampling matrix entry A_hat[i][j].

    FIPS 203 samples ``A_hat[i, j]`` from ``SampleNTT(rho || j || i)``: the
    column index ``j`` is absorbed before the row index ``i``. K-PKE key
    generation and encryption both call ``XOF(rho, i, j)`` to build entry
    ``[i][j]``, so this byte order yields the standard (non-transposed)
    matrix in both places.

    1536 bytes (512 three-byte groups) is far more than rejection sampling
    normally needs for 256 coefficients.
    """
    return hashlib.shake_128(rho + bytes([j]) + bytes([i])).digest(1536)

def PRF(eta, s, b):
    """Pseudorandom function PRF_eta(s, b): SHAKE256(s || b) truncated to 64*eta bytes."""
    seed = s + b
    out_len = 64 * eta
    return hashlib.shake_256(seed).digest(out_len)

def vector_add(ac, bc):
    """Add two coefficient lists element-wise modulo Q (stops at the shorter one)."""
    result = []
    for x, y in zip(ac, bc):
        result.append((x + y) % Q)
    return result

def vector_sub(ac, bc):
    """Subtract ``bc`` from ``ac`` element-wise modulo Q (stops at the shorter one)."""
    return list(map(lambda x, y: (x - y) % Q, ac, bc))

def G(c):
    """Hash function G: SHA3-512 of ``c``, split into two 32-byte halves."""
    digest = hashlib.sha3_512(c).digest()
    first, second = digest[:32], digest[32:]
    return first, second

def H(c):
    """Hash function H: the 32-byte SHA3-256 digest of ``c``."""
    hasher = hashlib.sha3_256()
    hasher.update(c)
    return hasher.digest()

def J(s, l):
    """Hash function J: ``l`` bytes of SHAKE256 output over ``s``."""
    hasher = hashlib.shake_256()
    hasher.update(s)
    return hasher.digest(l)

### K-PKE

In [28]:
def k_pke_keygen():
    """Generate a K-PKE key pair (FIPS 203 K-PKE.KeyGen).

    Returns:
        (ek_PKE, dk_PKE): the encryption key ``ByteEncode12(t_hat) || rho``
        (384 * k_value + 32 bytes) and the decryption key
        ``ByteEncode12(s_hat)`` (384 * k_value bytes).
    """
    d = os.urandom(32)
    # FIPS 203 (final) hashes d || k, so seeds for different parameter sets
    # are domain-separated.
    rho, sigma = G(d + bytes([k_value]))
    counter = 0  # PRF nonce; not named N to avoid shadowing the global N (= 256)

    Ac = [[sample_ntt(XOF(rho, i, j)) for j in range(k_value)] for i in range(k_value)]

    s = []
    for _ in range(k_value):
        s.append(sample_poly_cbd(PRF(ETA1, sigma, bytes([counter])), ETA1))
        counter += 1

    e = []
    for _ in range(k_value):
        e.append(sample_poly_cbd(PRF(ETA1, sigma, bytes([counter])), ETA1))
        counter += 1

    sc = [ntt(poly) for poly in s]
    ec = [ntt(poly) for poly in e]
    # t_hat = A_hat o s_hat + e_hat
    tc = [
        reduce(vector_add, [multiply_ntt_s(Ac[i][j], sc[j]) for j in range(k_value)] + [ec[i]])
        for i in range(k_value)
    ]

    ek_PKE = b"".join(byte_encode(12, poly) for poly in tc) + rho
    dk_PKE = b"".join(byte_encode(12, poly) for poly in sc)

    return ek_PKE, dk_PKE

def k_pke_encrypt(ek_PKE, m, rand):
    """Encrypt a 32-byte message under a K-PKE encryption key (FIPS 203 K-PKE.Encrypt).

    Args:
        ek_PKE: encryption key ``ByteEncode12(t_hat) || rho``.
        m: 32-byte message.
        rand: 32-byte seed for the PRF that samples r, e1 and e2.

    Returns:
        Ciphertext ``c1 || c2`` of ``32 * (DU * k_value + DV)`` bytes.
    """
    counter = 0  # PRF nonce; not named N to avoid shadowing the global N (= 256)
    # Each encoded polynomial is 384 bytes (256 coefficients x 12 bits),
    # regardless of k_value.
    tc = [byte_decode(12, ek_PKE[384 * i : 384 * (i + 1)]) for i in range(k_value)]
    rho = ek_PKE[384 * k_value : 384 * k_value + 32]
    Ac = [[sample_ntt(XOF(rho, i, j)) for j in range(k_value)] for i in range(k_value)]

    r = []
    for _ in range(k_value):
        r.append(sample_poly_cbd(PRF(ETA1, rand, bytes([counter])), ETA1))
        counter += 1

    e1 = []
    for _ in range(k_value):
        e1.append(sample_poly_cbd(PRF(ETA2, rand, bytes([counter])), ETA2))
        counter += 1

    e2 = sample_poly_cbd(PRF(ETA2, rand, bytes([counter])), ETA2)
    rc = [ntt(poly) for poly in r]

    # u = NTT^-1(A_hat^T o r_hat) + e1  (transpose via the Ac[j][i] index)
    u = [
        vector_add(
            ntt_inv(reduce(vector_add, [multiply_ntt_s(Ac[j][i], rc[j]) for j in range(k_value)])),
            e1[i],
        )
        for i in range(k_value)
    ]

    mu = decompress(1, byte_decode(1, m))

    # v = NTT^-1(t_hat^T o r_hat) + e2 + mu
    v = vector_add(
        ntt_inv(reduce(vector_add, [multiply_ntt_s(tc[i], rc[i]) for i in range(k_value)])),
        vector_add(e2, mu),
    )

    c1 = b"".join(byte_encode(DU, compress(DU, poly)) for poly in u)
    c2 = byte_encode(DV, compress(DV, v))

    return c1 + c2

def k_pke_decrypt(dk_PKE, c):
    """Decrypt a K-PKE ciphertext (FIPS 203 K-PKE.Decrypt).

    Splits ``c`` into ``c1`` (32 * DU * k_value bytes) and ``c2`` (32 * DV
    bytes), decompresses them into ``u'`` and ``v'``, and returns the 32-byte
    message ``ByteEncode1(Compress1(v' - NTT^-1(s_hat^T o NTT(u'))))``.
    """
    u_chunk = 32 * DU
    c1_len = u_chunk * k_value
    c1 = c[:c1_len]
    c2 = c[c1_len : c1_len + 32 * DV]

    u_prime = [
        decompress(DU, byte_decode(DU, c1[u_chunk * i : u_chunk * (i + 1)]))
        for i in range(k_value)
    ]
    v_prime = decompress(DV, byte_decode(DV, c2))
    s_hat = [byte_decode(12, dk_PKE[384 * i : 384 * (i + 1)]) for i in range(k_value)]

    products = [multiply_ntt_s(s_hat[i], ntt(u_prime[i])) for i in range(k_value)]
    w = vector_sub(v_prime, ntt_inv(reduce(vector_add, products)))

    return byte_encode(1, compress(1, w))

### ML-KEM Key-Encapsulation Mechanism

In [29]:
def ml_kem_keygen():
    """Generate an ML-KEM key pair (FIPS 203 ML-KEM.KeyGen).

    Returns:
        (ek, dk): ``ek`` is the K-PKE encryption key and
        ``dk = dk_PKE || ek || H(ek) || z``, where ``z`` is 32 random bytes
        used for implicit rejection during decapsulation.
    """
    z = os.urandom(32)
    ek, dk_PKE = k_pke_keygen()
    return ek, b"".join((dk_PKE, ek, H(ek), z))


def ml_kem_encaps(ek):
    """Encapsulate a fresh shared secret under ``ek`` (FIPS 203 ML-KEM.Encaps).

    Returns:
        (K, c): the 32-byte shared key and the K-PKE ciphertext of a random
        32-byte message, encrypted with coins derived from ``G(m || H(ek))``.
    """
    m = os.urandom(32)
    shared_key, coins = G(m + H(ek))
    return shared_key, k_pke_encrypt(ek, m, coins)


def ml_kem_decaps(c, dk):
    """Recover the shared key from ciphertext ``c`` (FIPS 203 ML-KEM.Decaps).

    ``dk`` is laid out as ``dk_PKE || ek_PKE || H(ek_PKE) || z`` with
    ``dk_PKE`` of 384 * k_value bytes and ``ek_PKE`` of 384 * k_value + 32
    bytes. If re-encrypting the decrypted message does not reproduce ``c``,
    the implicit-rejection key ``J(z || c)`` is returned instead.

    Raises:
        ValueError: if ``c`` or ``dk`` has the wrong length, or the stored
            ``H(ek)`` does not match ``ek`` (FIPS 203 decapsulation input checks).
    """
    pke_len = 384 * k_value
    ek_end = 2 * pke_len + 32  # == 768 * k_value + 32

    if len(c) != 32 * (DU * k_value + DV):
        raise ValueError("ml_kem_decaps: ciphertext has the wrong length")
    if len(dk) != 2 * pke_len + 96:
        raise ValueError("ml_kem_decaps: decapsulation key has the wrong length")

    dk_PKE = dk[:pke_len]
    ek_PKE = dk[pke_len:ek_end]
    h = dk[ek_end : ek_end + 32]
    z = dk[ek_end + 32 : ek_end + 64]

    if H(ek_PKE) != h:
        raise ValueError("ml_kem_decaps: decapsulation key hash check failed")

    ml = k_pke_decrypt(dk_PKE, c)
    Kl, rl = G(ml + h)
    Kb = J(z + c, 32)
    cl = k_pke_encrypt(ek_PKE, ml, rl)
    # compare_digest avoids the early-exit timing of `!=` on the
    # accept/reject decision.
    if not hmac.compare_digest(c, cl):
        Kl = Kb

    return Kl

## Teste

In [30]:
# End-to-end check of K-PKE and ML-KEM with the parameters selected above.
rand = os.urandom(32)
message = b'Este e um exemplo de mensagem !!'
ek_PKE, dk_PKE = k_pke_keygen()

print('message:', message)
cipher = k_pke_encrypt(ek_PKE, message, rand)
print(cipher)

message2 = k_pke_decrypt(dk_PKE, cipher)
print('message:', message2)
print('K-PKE round trip ok:', message2 == message)

ek, dk = ml_kem_keygen()
K, c = ml_kem_encaps(ek)
Kl = ml_kem_decaps(c, dk)

# Both parties must derive the same key; a mismatch means decapsulation fell
# back to the implicit-rejection key.
print('ML-KEM shared keys match:', K == Kl)

message: b'Este e um exemplo de mensagem !!'
b'\x96\xe3\xb6\x8d1R\x8b=\xaa\xf8O\xed.\x9f\xd7\x95\xb4Z\x0e;\x8e\xe5?\xee\xaeh+\xe62\x8e\xf5h4l\xb1\xbb\xda\xa0\x87\xe8\xcb\x1e\xe3\xc3z\xbf\xdb\xbe\xd3\xc8\x0f\x9aP\x9d\xf8\xec=\x8c\x9bS\xa8E\xcb\xdaE\xc5\xbb\xa5[4\xde\x9e\xaa\x05\x8f\x90\xee\xe1\x94\x97\x9b\xea\xc6c\x19\xc9\xa1Hue\xc2`0\x95&\xe3\x97\x87\xe5a\xf1[\x8c6\xc7\xae\xabuF\xb3)BS\xdb\xcdX\xd2\x8e\xb9\xce5\xdaR\xfc\xebJ7\x82\xe3\xdf\x809\xa4T\x05\xf7\xd9\xc1\x95\xcaB/O!7\xd1#\x0c\xa1\x19\x12+\xda#\x1c\xec\xfc\xd8\xbb\xffdi\xc5\x99`W\xb5\x03\xb3\x01O\xb4\xe2\x1c\xa4\xc1\x97\xa2\xe3n\x04P\'\x01\x05\x88\xd3\xc67\xd2\xfa\xe8\xc5\xe6h\xea\r\x00\x980\x8a\x7f\x1c&\xe9\xae\xf2&\x0c}\xce8c\x02}\xe6aM0\xc3\rP\xee5\xd2\xa6\x80\x13\x8c#\xfa/\xd3\x9fL\xbb\x8e\x0b\x1cO\x1d\xf2%]m\xf4\xc6*!\x01\xdec\xf4Z\xea\x89\xb7\xff=\x7f\xa8\x8b\xdb\xee\x99\x89\xcd-I&\x9dS3\xac\xab\x89\xa9D\xf6\xdd7\xf8\x83F\xf6\xc7\x9f\xfd%w\xe5\xde\x01[\xa8\xca9\xecH\x03&N\xf0\x86\n?|\x93\x90]\xab\x82g\xb2\x82\xff\xe9\xddC