<h1><center>TP3 - Ex.2</center></h1>
<p><center>Abril 25, 2024</center></p>


### Estruturas Criptográficas

PG53886, Ivo Miguel Alves Ribeiro

A95323, Henrique Ribeiro Fernandes


2. Em Agosto de 2023 a NIST publicou um draf da norma FIPS203  para um Key Encapsulation Mechanism (KEM) derivado dos algoritmos KYBER. O preâmbulo do “draft” 
> A key-encapsulation mechanism (or KEM) is a set of algorithms that, under certain conditions, can be used by two parties to establish a shared secret key over a public channel. A shared secret key that is securely established using a KEM can then be used with symmetric-key cryptographic algorithms to perform basic tasks in secure communications, such as encryption and authentication. This standard specifes a key-encapsulation mechanism called ML-KEM. The security of ML-KEM is related to the computational diffculty of the so-called Module Learning with Errorsproblem. At present, ML-KEM is believed to be secure even against adversaries who possess a quantum computer

Neste trabalho pretende-se implementar em Sagemath um protótipo deste standard parametrizado de acordo com as variantes sugeridas na norma (512, 768 e 1024 bits de segurança)

Hash Functions

In [19]:
import hashlib
import secrets

def Random_32_bytes():
    return secrets.token_bytes(32)

def H(s):
    return hashlib.sha3_256(s).digest()

def J(s):
    return hashlib.shake_256(s).digest(32)

def G(s):
    return hashlib.sha3_512(s).digest()

def XOF(x, i, j):
    return (hashlib.shake_128(x + bytes([i, j])).digest(256*3))

def PRF(data, b, eta):
	return hashlib.shake_256(data + bytes([b])).digest(64 * eta)


def polyADD(a, b, Q):
	return [(x + y) % Q for x, y in zip(a, b)]

def polySUB(a, b, Q):
	return [(x - y) % Q for x, y in zip(a, b)]

Algoritmo 2 BitsToBytes e Algoritmo 3 BytesToBits

In [20]:
def BitsToBytes(bits):
	assert(len(bits) % 8 == 0)
	return bytes(
		sum(bits[i + j] << j for j in range(8))
		for i in range(0, len(bits), 8)
	)

def BytesToBits(data):
	bits = []
	for word in data:
		for i in range(8):
			bits.append((word >> i) & 1)
	return bits

Algoritmo 5 ByteEncode e 6 ByteDecode

In [21]:
def ByteEncode(F, d):
	assert(len(F) == 256)
	bits = []
	for a in F:
		for i in range(d):
			bits.append((a >> i) & 1)
	return BitsToBytes(bits)

def ByteDecode(B, d):
	bits = BytesToBits(B)
	return [sum(bits[i * d + j] << j for j in range(d)) for i in range(256)]

def Compress(x, d, Q):
	return [(((n * 2**d) + Q // 2 ) // Q) % (2**d) for n in x]

def Decompress(x, d, Q):
	return [(((n * Q) + 2**(d-1) ) // 2**d) % Q for n in x]

Algoritmo 6 SampleNTT

In [22]:
def SampleNTT(B, Q):
	res = []
	i = 0
	while len(res) < 256:
		a, b, c = B[i:i+3]
		i+=3
		d1 = ((b & 0xf) << 8) | a
		d2 = c << 4 | b >> 4
		if d1 < Q:
			res.append(d1)
		if d2 < Q and len(res) < 256:
			res.append(d2)
	return res

Algoritmo 7 SamplePolyCBC

In [23]:
def SamplePolyCBC(B, Q, eta):
	assert(len(B) == 64 * eta)
	bits = BytesToBits(B)
	f = []
	for i in range(256):
		x = sum(bits[2*i*eta+j] for j in range(eta))
		y = sum(bits[2*i*eta+eta+j] for j in range(eta))
		f.append((x - y) % Q)
	return f

Algoritmo 8 NTT

In [24]:
def bitrev7(n):
	return int(f"{n:07b}"[::-1], 2)

def NTT(f_in, Q):
	f_out = f_in.copy()
	k = 1
	for log2len in range(7, 0, -1):
		length = 2**log2len
		for start in range(0, 256, 2 * length):
			zeta = pow(17, bitrev7(k), Q)
			k += 1
			for j in range(start, start + length):
				t = (zeta * f_out[j + length]) % Q
				f_out[j + length] = (f_out[j] - t) % Q
				f_out[j] = (f_out[j] + t) % Q
	return f_out

Algoritmo 9 NTT-1

In [25]:
def NTT_inv(f_in, Q):
	f_out = f_in.copy()
	k = 127
	for log2len in range(1, 8):
		length = 2**log2len
		for start in range(0, 256, 2 * length):
			zeta = pow(17, bitrev7(k), Q)
			k -= 1
			for j in range(start, start + length):
				t = f_out[j]
				f_out[j] = (t + f_out[j + length]) % Q
				f_out[j + length] = (zeta * (f_out[j + length] - t)) % Q

	for i in range(256):
		f_out[i] = (f_out[i] * 3303) % Q 

	return f_out

Algoritmo 11 NTT mult

In [26]:
def NTT_mult(a, b, Q):
	c = []
	for i in range(128):
		a0, a1 = a[2 * i: 2 * i + 2]
		b0, b1 = b[2 * i: 2 * i + 2]
		c.append((a0 * b0 + a1 * b1 * pow(17, 2*bitrev7(i)+1, Q)) % Q)
		c.append((a0 * b1 + a1 * b0) % Q)
	return c

# K-PKE Component Scheme

In [27]:
class K_PKE:
    def __init__(self, n, q, k, n1, n2, du, dv):
        self.n = n
        self.q = q
        self.k = k
        self.n1 = n1
        self.n2 = n2
        self.du = du
        self.dv = dv

    def keygen(self):
        d = Random_32_bytes()
        ghash = G(d)
        rho, sigma = ghash[:32], ghash[32:]
        a_hat = []
        for i in range(self.k):
            row = []
            for j in range(self.k):
                row.append(SampleNTT(XOF(rho, i, j), self.q))
            a_hat.append(row)
        
        s_hat = [NTT(SamplePolyCBC(PRF(sigma, i, self.n1), self.q, self.n1), self.q) for i in range(self.k)]
        e_hat = [NTT(SamplePolyCBC(PRF(sigma, i+self.k, self.n1), self.q, self.n1), self.q) for i in range(self.k)]
        
        t_hat = []
        print(s_hat)
        print(e_hat)
        print(a_hat)
        for i in range(self.k):
            sum = e_hat[i]
            for j in range(self.k):
                sum = polyADD(NTT_mult(a_hat[j][i], s_hat[j], self.q), sum, self.q)
            t_hat.append(sum)

        ek_pke = b"".join(ByteEncode(s, 12) for s in t_hat) + rho
        dk_pke = b"".join(ByteEncode(s, 12) for s in s_hat)
        return ek_pke, dk_pke


    def PKE_Encrypt(self, ek_pke, msg, r):
        t_hat = [ByteDecode(ek_pke[i*384:(i+1)*384], 12) for i in range(self.k)]
       
        rho = ek_pke[-32:]
        
        a_hat = []
        for i in range(self.k):
            row = []
            for j in range(self.k):
                row.append(SampleNTT(XOF(rho, i, j), self.q))
            a_hat.append(row)
        
        r_hat = [NTT(SamplePolyCBC(PRF(r, i, self.n1), self.q, self.n1), self.q) for i in range(self.k)]
        e1 = [SamplePolyCBC(PRF(r, i+self.k, self.n2), self.q, self.n2) for i in range(self.k)]
        e2 = SamplePolyCBC(PRF(r, 2*self.k, self.n2), self.q, self.n2)

        u = []
        for i in range(self.k):
            sum = e1[i]
            for j in range(self.k):
                # transposta de A_hat por trocar o i com o j
                sum = polyADD(NTT_inv(NTT_mult(a_hat[i][j], r_hat[j], self.q), self.q),sum, self.q)
            u.append(sum)
        
        µ = Decompress(ByteDecode(msg, 1), 1, self.q)
        
        v = [0]*256
        for i in range(self.k):
            v = polyADD(NTT_inv(NTT_mult(t_hat[i], r_hat[i], self.q), self.q), v, self.q)
        v = polyADD(v, polyADD(e2, µ, self.q), self.q)

        c1 = b"".join(ByteEncode(Compress(u[i], self.du, self.q), self.du) for i in range(self.k))
        c2 = ByteEncode(Compress(v, self.dv, self.q), self.dv)
        return c1 + c2


    def PKE_Decrypt(self, dk_pke, cypher):
        c1 = cypher[:32*self.du*self.k]
        c2 = cypher[32*self.du*self.k:]
        u = [Decompress(ByteDecode(c1[i*32*self.du:(i+1)*32*self.du], self.du), self.du, self.q) for i in range(self.k)]
        v = Decompress(ByteDecode(c2, self.dv), self.dv, self.q)
        s_hat = [ByteDecode(dk_pke[i*384:(i+1)*384], 12) for i in range(self.k)]

        w = [0]*256
        for i in range(self.k):
            w = polyADD(NTT_inv(NTT_mult(s_hat[i], NTT(u[i], self.q), self.q), self.q), w, self.q)
        w = polySUB(v, w, self.q)

        m = ByteEncode(Compress(w, 1, self.q),1)
        return m

# Test K_PKE

In [28]:
test1 = K_PKE(256, 3329, 2, 3, 2, 10, 4)
ek_PKE, dk_PKE = test1.keygen()
print("Encryption Key (ek_PKE):", ek_PKE)
print("Tamanho da Encryption Key: ", len(ek_PKE))
print("Decryption Key (dk_PKE):", dk_PKE)
print("Tamanho da Decryption Key: ", len(dk_PKE))

msg = b"Esta e uma mensagem de 32 bytes!"
print("Tamanho da mensagem: ",len(msg))
r = Random_32_bytes()
cyphertext = test1.PKE_Encrypt(ek_PKE, msg, r)
print("Cyphertext: ", cyphertext)

decrypt_msg = test1.PKE_Decrypt(dk_PKE, cyphertext)
print("Decrypt message: ", decrypt_msg)
assert(decrypt_msg == msg)

[[1626, 304, 146, 2041, 2609, 354, 1062, 2287, 616, 1318, 1666, 2780, 639, 1580, 1347, 2985, 602, 1064, 2981, 834, 815, 477, 2262, 1600, 1762, 2957, 550, 800, 2603, 2480, 359, 2954, 2334, 2786, 1484, 2300, 1304, 1393, 172, 2440, 758, 959, 2536, 199, 1941, 2492, 1243, 851, 3223, 93, 796, 1570, 3180, 1234, 919, 3176, 2601, 89, 875, 2512, 258, 1326, 1369, 980, 582, 310, 2956, 2221, 1397, 323, 210, 1795, 677, 3223, 683, 3215, 3084, 2709, 2428, 974, 670, 2395, 2411, 2182, 1244, 2053, 2556, 1245, 892, 1783, 2662, 1759, 267, 1838, 2231, 3167, 507, 326, 1316, 1409, 3006, 171, 1865, 1683, 1482, 1270, 3313, 2021, 2729, 2432, 1712, 1276, 2912, 1671, 1216, 726, 371, 1940, 1279, 1245, 63, 1541, 467, 1874, 3043, 3279, 3065, 1257, 781, 1376, 3174, 1011, 3063, 947, 1551, 1355, 1047, 998, 682, 1623, 2724, 2921, 1653, 670, 2445, 3073, 463, 2266, 2058, 1417, 2868, 1760, 2551, 2066, 2031, 2160, 2626, 1480, 1151, 1904, 3107, 1815, 1047, 1362, 1521, 813, 1731, 2439, 1854, 2290, 1405, 1388, 3122, 3170, 1484,

# ML-KEM Key-Encapsulation Mechanism

In [29]:
class KEM:
    def __init__(self, kem_type):
        self.n = 256
        self.q = 3329
        self.sharedkeysize = 32
        self.type = kem_type
        if kem_type == 512:
            self.k = 2
            self.n1 = 3
            self.n2 = 2
            self.du = 10
            self.dv = 4
            self.enckeysize = 800
            self.deckeysize = 1632
            self.ciphersize = 768
        if kem_type == 768:
            self.k = 3
            self.n1 = 2
            self.n2 = 2
            self.du = 10
            self.dv = 4
            self.enckeysize = 1184
            self.deckeysize = 2400
            self.ciphersize = 1088
        if kem_type == 1024:
            self.k = 4
            self.n1 = 2
            self.n2 = 2
            self.du = 11
            self.dv = 5
            self.enckeysize = 1568
            self.deckeysize = 3168
            self.ciphersize = 1568
        self.PKE = K_PKE(self.n, self.q, self.k, self.n1, self.n2, self.du, self.dv)

    def KeyGen(self):
        z = Random_32_bytes()
        ek_pke, dk_pke = self.PKE.keygen()
        ek = ek_pke
        dk = dk_pke + ek + H(ek) + z
        # Verificar parametros
        assert len(ek) == self.enckeysize
        assert len(dk) == self.deckeysize
        return ek, dk


    def Encaps(self, ek):
        msg = b"Shared secrect key com 32 bytes!"
        print("Chave partilhada: ", msg)
        print("Tamanho da mensagem: ",len(msg))
        ghash = G(msg + H(ek))
        k = ghash[:32]
        r = ghash[32:]
        c = self.PKE.PKE_Encrypt(ek, msg, r)
        # Verificar parametros
        assert len(k) == self.sharedkeysize
        assert len(c) == self.ciphersize
        return k, c

    def Decaps(self, c, dk):
        dk_pke = dk[:384*self.k]
        ek_pke = dk[384*self.k : 768*self.k + 32]
        h = dk[768*self.k+ 32 : 768*self.k + 64]
        z = dk[768*self.k + 64 : 768*self.k + 96]
        mdash = self.PKE.PKE_Decrypt(dk_pke, c)
        ghash = G(mdash + h)
        kdash = ghash[:32]
        rdash = ghash[32:]
        kbar = J(z + c)
        cdash = self.PKE.PKE_Encrypt(ek_pke, mdash, rdash)
        if cdash != c:
            return kbar
        assert len(kdash) == self.sharedkeysize
        return kdash

# test ML-KEM-512 

In [30]:
kem1 = KEM(512)
ek_KEM, dk_KEM = kem1.KeyGen()
print("Encryption Key (ek_KEM):", ek_KEM)
print("Tamanho da Encryption Key: ", len(ek_KEM))
print("Decryption Key (dk_KEM):", dk_KEM)
print("Tamanho da Decryption Key: ", len(dk_KEM))

shared_key1, cyphertext = kem1.Encaps(ek_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))
print("CypherText: {}\nTamanho da cifra: {}".format(cyphertext, len(cyphertext)))

shared_key2 = kem1.Decaps(cyphertext, dk_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))

if shared_key1 == shared_key2:
    print("ML-KEM key-Encapsulation Mechanism Valid!!")
else:
    print("ML-KEM key-Encapsulation Mechanism Invalid!!")

[[2634, 871, 1630, 82, 1096, 563, 1318, 137, 2234, 2401, 1717, 1250, 432, 821, 2445, 1301, 1786, 716, 1173, 3256, 1857, 162, 420, 2181, 1965, 2759, 334, 519, 48, 473, 229, 625, 750, 2744, 2895, 20, 1567, 2740, 1242, 395, 672, 1901, 3178, 1312, 3209, 2009, 3045, 621, 1550, 2125, 2445, 434, 902, 3183, 732, 1898, 3112, 3101, 224, 1209, 682, 1053, 301, 494, 1777, 1585, 2762, 625, 796, 373, 2762, 1250, 1165, 2002, 2578, 1870, 2259, 1440, 1737, 2756, 1294, 821, 2488, 2549, 1207, 1549, 1153, 1340, 3063, 883, 1911, 719, 1021, 2154, 720, 1449, 2632, 2823, 2627, 1827, 1380, 2087, 205, 1436, 1964, 2272, 2101, 2253, 625, 1066, 457, 2156, 660, 1115, 1096, 793, 1218, 2426, 2856, 2342, 885, 913, 401, 2884, 2732, 1163, 3181, 3191, 1547, 543, 3114, 1409, 35, 781, 2563, 2265, 136, 1999, 2691, 1692, 2688, 3068, 2633, 2042, 1808, 1627, 1756, 337, 3249, 1395, 2404, 770, 2731, 2372, 354, 2985, 2000, 3102, 400, 582, 848, 476, 1975, 2306, 1652, 1178, 578, 977, 2251, 2894, 1845, 478, 1784, 407, 1308, 923, 133,

# test ML-KEM-768 

In [31]:
kem1 = KEM(768)
ek_KEM, dk_KEM = kem1.KeyGen()
print("Encryption Key (ek_KEM):", ek_KEM)
print("Tamanho da Encryption Key: ", len(ek_KEM))
print("Decryption Key (dk_KEM):", dk_KEM)
print("Tamanho da Decryption Key: ", len(dk_KEM))

shared_key1, cyphertext = kem1.Encaps(ek_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))
print("CypherText: {}\nTamanho da cifra: {}".format(cyphertext, len(cyphertext)))

shared_key2 = kem1.Decaps(cyphertext, dk_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))

if shared_key1 == shared_key2:
    print("ML-KEM key-Encapsulation Mechanism Valid!!")
else:
    print("ML-KEM key-Encapsulation Mechanism Invalid!!")

[[2004, 836, 1037, 2604, 1444, 720, 598, 104, 1939, 379, 1936, 2094, 636, 3025, 1018, 2011, 3199, 1309, 1084, 1405, 3267, 728, 2851, 2551, 1074, 2179, 361, 2378, 2814, 555, 217, 3154, 859, 899, 389, 1063, 3211, 912, 1918, 2742, 804, 2855, 2074, 1139, 2463, 2256, 3147, 1723, 1294, 1111, 724, 2078, 2626, 739, 2825, 199, 1739, 1596, 1395, 985, 1128, 3218, 3253, 3087, 1841, 832, 3030, 2948, 3146, 1760, 2190, 1811, 1467, 2814, 368, 1472, 2348, 357, 976, 1351, 1787, 320, 1904, 971, 1212, 1731, 2961, 1015, 438, 356, 1737, 1195, 3301, 2337, 1551, 3314, 3152, 198, 1887, 3224, 2114, 783, 1476, 951, 1868, 2328, 409, 1079, 1925, 1589, 1974, 588, 1153, 683, 1566, 3296, 1287, 285, 1010, 1781, 282, 3269, 2665, 1768, 1482, 2042, 821, 1318, 920, 1617, 3281, 56, 973, 2222, 2002, 3018, 2429, 3270, 2166, 1450, 1429, 3218, 1453, 2847, 2172, 3119, 2267, 352, 2943, 879, 1122, 1119, 2187, 3200, 848, 2699, 359, 887, 2899, 2336, 2388, 2169, 947, 1218, 2494, 2691, 916, 759, 2440, 3262, 943, 1320, 1815, 491, 884,

# test ML-KEM-1024 

In [32]:
kem1 = KEM(1024)
ek_KEM, dk_KEM = kem1.KeyGen()
print("Encryption Key (ek_KEM):", ek_KEM)
print("Tamanho da Encryption Key: ", len(ek_KEM))
print("Decryption Key (dk_KEM):", dk_KEM)
print("Tamanho da Decryption Key: ", len(dk_KEM))

shared_key1, cyphertext = kem1.Encaps(ek_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))
print("CypherText: {}\nTamanho da cifra: {}".format(cyphertext, len(cyphertext)))

shared_key2 = kem1.Decaps(cyphertext, dk_KEM)
print("Shared key: {}\nTamanho da Shared Key: {}".format(shared_key1, len(shared_key1)))

if shared_key1 == shared_key2:
    print("ML-KEM key-Encapsulation Mechanism Valid!!")
else:
    print("ML-KEM key-Encapsulation Mechanism Invalid!!")

[[3178, 985, 2020, 2188, 524, 2590, 246, 2551, 531, 619, 1013, 2227, 2395, 1014, 2896, 1356, 3143, 771, 957, 538, 394, 1491, 170, 3228, 74, 935, 86, 759, 1224, 2448, 3251, 2566, 2033, 811, 1668, 2453, 1388, 2944, 2447, 3252, 893, 1418, 1588, 3078, 2273, 3229, 3019, 655, 1848, 22, 2985, 917, 3072, 2083, 2239, 83, 473, 1248, 640, 2915, 2926, 2519, 2221, 2750, 2094, 1896, 427, 2401, 38, 1666, 1570, 1806, 2586, 2549, 1383, 315, 1370, 2591, 379, 121, 1520, 2661, 2653, 2426, 775, 2794, 748, 1072, 2477, 2072, 1591, 1788, 1870, 1539, 2513, 2491, 409, 149, 936, 999, 2188, 2505, 936, 900, 2201, 2354, 483, 3268, 1998, 2805, 2598, 1056, 2518, 283, 647, 1412, 2884, 577, 2751, 320, 57, 788, 920, 2917, 275, 419, 132, 2274, 546, 964, 906, 2298, 1808, 277, 857, 852, 503, 1260, 1022, 1052, 2651, 924, 2404, 1487, 1728, 1, 1475, 1182, 94, 1342, 893, 264, 1398, 1068, 1647, 1026, 425, 2464, 2452, 1791, 3121, 1402, 2951, 2762, 2356, 1519, 2496, 1723, 911, 2159, 939, 2138, 362, 857, 69, 137, 1809, 2291, 488, 