# ML-KEM (Module-Lattice-Based Key-Encapsulation Mechanism) Standard

Implementation of [FIPS 203](https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.203.pdf), the ML-KEM standard published by NIST.

In [9]:
import numpy as np

In [46]:
# Global variables (ML-KEM parameters; values as given in FIPS 203)
# q: modulus. ByteDecode reduces its output modulo q when d is 12.
q = 3329
# n: presumably the number of coefficients per polynomial; the functions
# below currently hard-code 256 instead of reading n -- TODO confirm/unify.
n = 256

In [14]:
def BitsToBytes(b):
    """
    Converts a bit array (of a length that is a multiple of eight) into an array of bytes.

    Bits are little-endian inside each byte: b[8 * i + j] contributes 2 ** j to byte i.

    Parameters
    ----------
    b : array-like of 0/1 values whose length is a multiple of eight.

    Returns
    -------
    numpy.ndarray of integer dtype with len(b) // 8 entries, each in [0, 255].

    Raises
    ------
    ValueError
        If the number of bits is not a multiple of eight.
    """
    bits = np.asarray(b)
    if bits.size % 8 != 0:
        raise ValueError(
            f"bit array length must be a multiple of 8, got {bits.size}"
        )

    # packbits needs an integer/bool input; callers may hand in float 0./1. bits.
    packed = np.packbits(bits.astype(np.uint8), bitorder="little")

    # Widen from uint8 so that later arithmetic on the bytes (for example
    # 256 * B[i]) cannot overflow an 8-bit scalar.
    return packed.astype(np.int64)

In [16]:
def BytesToBits(B):
    """
    Inverse of BitsToBytes: expands a byte array into its bit array.

    Byte i yields bits 8 * i .. 8 * i + 7, least-significant bit first.
    The result is a NumPy array of length 8 * len(B); the input is left untouched.
    """
    count = len(B)
    bits = np.zeros(8 * count)
    # The extraction below repeatedly halves each byte in place, so it works
    # on a private copy instead of destroying the caller's array.
    scratch = np.copy(B)

    for idx in range(count):
        for k in range(8):
            # divmod gives (byte // 2, byte % 2): keep the quotient, emit the remainder.
            scratch[idx], bits[8 * idx + k] = divmod(scratch[idx], 2)

    return bits

In [18]:
# Round-trip sanity check: packing eight bits into one byte and unpacking it
# again should reproduce the input. The bits are least-significant first, so
# this pattern is the byte value 2**4 + 2**5 = 48.
# Note: `b` stays bound at notebook level and is reused by a later cell.
b = np.array([0, 0, 0, 0, 1, 1, 0, 0])
print(BytesToBits(BitsToBytes(b)))

[0. 0. 0. 0. 1. 1. 0. 0.]


In [64]:
# FIPS 203 writes this as the family ByteEncode_d; the subscript d is an
# ordinary keyword argument here, so no templating is needed.
def ByteEncode(F, d=11):
    """
    Encodes an array of 𝑑-bit integers into a byte array for 1 ≤ 𝑑 ≤ 12.

    OR: converts an array of 𝑛 = 256 integers modulo 𝑚 into a corresponding array of bytes.

    Each of the first 256 entries of F is written as d bits, least-significant
    bit first, and the resulting 256 * d bits are packed with BitsToBytes.

    Parameters
    ----------
    F : array-like with at least 256 integer-valued entries (only the first
        256 are encoded). Entries are expected to fit in d bits.
    d : bit width of each entry, 1 ≤ d ≤ 12. Defaults to 11, the value that
        was previously hard-coded.

    Returns
    -------
    The byte array produced by BitsToBytes, of length 32 * d.

    Raises
    ------
    ValueError
        If d is outside 1..12 or F has fewer than 256 entries.
    """
    if not 1 <= d <= 12:
        raise ValueError(f"d must satisfy 1 <= d <= 12, got {d}")

    values = np.asarray(F)
    if len(values) < 256:
        raise ValueError(f"F must hold at least 256 entries, got {len(values)}")

    # Work on integers so bit extraction is exact (F may arrive as floats).
    values = values[:256].astype(np.int64)

    # Row i holds the d bits of F[i], least-significant first; flattening
    # row-major puts bit j of entry i at position i * d + j.
    shifts = np.arange(d, dtype=np.int64)
    bits = (values[:, None] >> shifts) & 1

    return BitsToBytes(bits.reshape(-1))

In [65]:
def ByteDecode(B, d=11):
    """
    Decodes a byte array into an array of 𝑑-bit integers for 1 ≤ 𝑑 ≤ 12.

    OR: converts an array of bytes into an array of integers modulo 𝑚.

    The bytes are expanded with BytesToBits, the first 256 * d bits are read as
    256 little-endian d-bit integers, and each is reduced modulo m, where
    m = 2 ** d for d < 12 and m = q for d = 12.

    Parameters
    ----------
    B : byte array with at least 32 * d entries (extra bytes are ignored).
    d : bit width of each decoded integer, 1 ≤ d ≤ 12. Defaults to 11, the
        value that was previously hard-coded.

    Returns
    -------
    numpy.ndarray of 256 integers (integer dtype), each in [0, m).

    Raises
    ------
    ValueError
        If d is outside 1..12 or B is shorter than 32 * d bytes.
    """
    if not 1 <= d <= 12:
        raise ValueError(f"d must satisfy 1 <= d <= 12, got {d}")

    m = 2 ** d if d < 12 else q

    bits = np.asarray(BytesToBits(B))
    needed = 256 * d
    if bits.size < needed:
        raise ValueError(
            f"B must hold at least {32 * d} bytes for d={d}, got {bits.size // 8}"
        )

    # Row i holds the d bits of integer i, least-significant first.
    bit_rows = bits[:needed].astype(np.int64).reshape(256, d)
    weights = 1 << np.arange(d, dtype=np.int64)

    return (bit_rows @ weights) % m

In [75]:
# Round trip through ByteDecode / ByteEncode.
# `b` is whatever bit array an earlier cell left bound; BitsToBytes turns it
# into len(b) // 8 bytes.
B = BitsToBytes(b)
# Zero-pad to 32 * 11 = 352 bytes, i.e. the 256 * 11 bits that ByteDecode
# reads for d = 11.
res = np.zeros(32 * 11)
res[:B.shape[0]] = B
B = res
F = ByteDecode(B)
# Re-encode the decoded integers. Bres is neither compared with B nor
# displayed in this cell.
Bres = ByteEncode(F)