In [17]:
from hashlib import shake_128, shake_256, sha3_256, sha3_512
from secrets import token_bytes
from sage.modules.free_module_element import FreeModuleElement_generic_dense
from sage.matrix.matrix_generic_dense import Matrix_generic_dense
import os

class Kyber512():
    def __init__(self, **kwargs):
        self.n = kwargs.get('n', 256)
        self.q = kwargs.get('q', 3329)
        self.k = kwargs.get('k', 2)
        self.eta_1 = kwargs.get('eta_1', 3)
        self.eta_2 = kwargs.get('eta_2', 2)
        self.d_u = kwargs.get('d_u', 10)
        self.d_v = kwargs.get('d_v', 4)
        self.secret_key_size = kwargs.get('secret_key_size', 32)
        self.zeta = kwargs.get('zeta', 17)

        PRq.<x> = PolynomialRing(GF(self.q))
        self.modulus = x^self.n + 1
        self.Rq = PRq.quotient(self.modulus)
        self.x = x
        self.zeta_powers = [int(self.Rq(self.zeta^i)) for i in range(self.n)]

    def __bytes_to_bits(self, B):
        return [B[i // 8] >> (7 - (i % 8)) & 1 for i in range(8 * len(B))]

    def __bits_to_bytes(self, B):
        return bytes([sum([B[i + j] << (7 - j) for j in range(8)]) for i in range(0, len(B), 8)])

    def __coeffs_to_poly(self, coeffs):
        return sum(self.Rq(coeffs[i] * x^i) for i in range(len(coeffs)))

    def __CBD(self, B, eta):
        beta = self.__bytes_to_bits(B)
        coeffs = [0] * 256
        for i in range(256):
            a = sum(beta[2 * i * eta + j] for j in range(eta))
            b = sum(beta[(2 * i + 1) * eta + j] for j in range(eta))
            coeffs[i] = a - b
        return self.__coeffs_to_poly(coeffs)

    def __PRF(self, seed, nonce, output_length):
        nonce_bytes = int(nonce).to_bytes(1, 'little')
        input_data = seed + nonce_bytes
        shake = shake_256(input_data)
        return shake.digest(output_length)
    
    def __decode_single(self, B, L):
        beta = self.__bytes_to_bits(B)
        coeffs = []

        for i in range(256):
            f_i = sum(beta[i * L + j] * (2 ** j) for j in range(L))
            coeffs.append(f_i)

        return self.__coeffs_to_poly(coeffs)

    def __encode_single(self, f, L):
        beta = []

        for i in range(256):
            beta_i = [(int(f[i]) >> j) & 1 for j in range(L)]
            beta.extend(beta_i)

        return self.__bits_to_bytes(beta)

    def __decode(self, byte_string, L):
        ret = []
        for i in range(0, len(byte_string), 256 * L // 8):
            ret.append(self.__decode_single(byte_string[i:i + 256 * L // 8], L))

        if len(ret) == 1:
            return ret[0]
        else:
            return vector(self.Rq, ret)

    def __encode(self, vec, L):
        if not isinstance(vec, (list, FreeModuleElement_generic_dense)):
            return self.__encode_single(vec, L)
        else:
            ret = b''
            for v in vec:
                ret += self.__encode_single(v, L)
            return ret

    def __G(self, d):
        shake = sha3_512(d).digest()
        rho = shake[:32]
        sigma = shake[32:64]
        return rho, sigma

    def __H(self, m):
        return sha3_256(m).digest()

    def __KDF(self, m, output_length = None):
        if output_length is None:
            output_length = self.secret_key_size
        shake = shake_256(m)
        return shake.digest(output_length)

    def __XOF(self, rho, i, j, output_length):
        i_bytes = int(i).to_bytes(1, 'little')
        j_bytes = int(j).to_bytes(1, 'little')
        input_data = rho + i_bytes + j_bytes
        shake = shake_128(input_data)
        return shake.digest(output_length)
    
    def __PRF(self, seed, nonce, output_length):
        nonce_bytes = int(nonce).to_bytes(1, 'little')
        input_data = seed + nonce_bytes
        shake = shake_256(input_data)
        return shake.digest(output_length)

    def bit_reverse(self, num, bits=7):
        result = 0
        for _ in range(bits):
            result = (result << 1) | (num & 1)
            num >>= 1
        return result

    def __NTT_single(self, f):
        PRq.<x> = PolynomialRing(GF(self.q))
        ret = self.Rq(0)
        for i in range(128):
            modulus = x^2 - self.zeta_powers[2*i+1]
            ntt_modulo = PRq.quotient(modulus)
            f_hat = ntt_modulo(f)
            ret += self.Rq(f_hat[1] * x^(2*i+1) + f_hat[0] * x^(2*i)) 
        return ret

    def __NTT_inv_single(self, f):
        PRq.<x> = PolynomialRing(GF(self.q))
        ret = self.Rq(0)
        for i in range(128):
            r = f[2*i+1]*x + f[2*i] 
            modulus = x^2 - self.zeta_powers[2*i+1]
            M = self.modulus // modulus
            inv_mod = M.inverse_mod(modulus)
            ret += self.Rq(r * M * inv_mod)
        return ret

    def __NTT(self, f):
        if not isinstance(f, (list, FreeModuleElement_generic_dense)):
            return self.__NTT_single(f)
        else:
            ret = []
            for v in f:
                ret.append(self.__NTT_single(v))
            return vector(self.Rq, ret)

    def __NTT_inv(self, f):
        if not isinstance(f, (list, FreeModuleElement_generic_dense)):
            return self.__NTT_inv_single(f)
        else:
            ret = []
            for v in f:
                ret.append(self.__NTT_inv_single(v))
            return vector(self.Rq, ret)

    def __NTT_product_element(self, f, g):
        PRq.<x> = PolynomialRing(GF(self.q))
        ret = self.Rq(0)
        for i in range(128):
            rf = f[2*i+1]*x + f[2*i] 
            rg = g[2*i+1]*x + g[2*i]
            modulus = x^2 - self.zeta_powers[2*i+1]
            res = (rf * rg) % modulus
            ret += self.Rq(res[1] * x^(2*i+1) + res[0] * x^(2*i))
        return ret

    def NTT_product(self, F, G):
        if type(F) == Matrix_generic_dense: # matrix
            ret = [0] * F.nrows()
            for i in range(F.nrows()):
                for j in range(F.ncols()):
                    ret[i] += self.__NTT_product_element(F[i][j], G[j])
            ret = vector(self.Rq, ret)
        elif type(F) == FreeModuleElement_generic_dense:
            ret = self.Rq(0)
            for i in range(len(F)):
                ret += self.__NTT_product_element(F[i], G[i])
        return ret

    def __parse(self, byte_stream):
        i = 0
        j = 0
        a = [0] * self.n

        while j < self.n and i+2 < len(byte_stream):
            d1 = byte_stream[i] + 256 * (byte_stream[i + 1] % 16)
            d2 = (byte_stream[i + 1] // 16) + 16 * byte_stream[i + 2]

            if d1 < self.q:
                a[j] = d1
                j += 1

            if d2 < self.q and j < self.n:
                a[j] = d2
                j += 1

            i += 3

        return self.__NTT(sum(self.Rq(a[i] * self.x^i) for i in range(self.n)))

    def __compress_single(self, f, d):
        """
        Compress_q(x, d) = floor((2^d / q) * x) mod 2^d
        """

        scale_factor = (2 ** d) / self.q
        new_f = []
        
        for i in range(256):
            new_val = round(scale_factor * int(f[i])) % (2 ** d)
            new_f.append(new_val)
       
        return self.__coeffs_to_poly(new_f)

    def __decompress_single(self, f, d):
        """
        Decompress_q(x, d) = floor((q / 2^d) * x)
        """

        scale_factor = self.q / (2 ** d)
        new_f = []

        for i in range(256):
            new_val = round(scale_factor * int(f[i]))
            new_f.append(new_val)

        return self.__coeffs_to_poly(new_f)

    def __compress(self, f, d):
        if not isinstance(f, (list, FreeModuleElement_generic_dense)):
            return self.__compress_single(f, d)
        else:
            ret = []
            for e in f:
                ret.append(self.__compress_single(e, d))
            return vector(self.Rq, ret)

    def __decompress(self, f, d):
        if not isinstance(f, (list, FreeModuleElement_generic_dense)):
            return self.__decompress_single(f, d)
        else:
            ret = []
            for e in f:
                ret.append(self.__decompress_single(e, d))
            return vector(self.Rq, ret)

    def generate_key(self):
        d = token_bytes(32)
        rho, sigma = self.__G(d)
        N = 0
        A_hat = Matrix([[self.__parse(self.__XOF(rho, j, i, 64 * self.eta_1)) for j in range(self.k)] for i in range(self.k)])

        s = []
        for i in range(self.k):
            s.append(self.__CBD(self.__PRF(sigma, N, 64 * self.eta_1), self.eta_1))
            N += 1
        
        e = []
        for i in range(self.k):
            e.append(self.__CBD(self.__PRF(sigma, N, 64 * self.eta_1), self.eta_1))
            N += 1

        s_hat = self.__NTT(s)
        e_hat = self.__NTT(e)
        t_hat = self.NTT_product(A_hat, s_hat) + e_hat

        pk = self.__encode(t_hat, 12) + rho
        sk = self.__encode(s_hat, 12)
        return pk, sk

    def encrypt(self, pk, m, r):
        t = pk[:-32]
        rho = pk[-32:]

        t_hat = self.__decode(t, 12)
        A_hat = Matrix([[self.__parse(self.__XOF(rho, i, j, 64 * self.eta_1)) for j in range(self.k)] for i in range(self.k)])

        N = 0
        r_vec = []
        for i in range(self.k):
            r_vec.append(self.__CBD(self.__PRF(r, N, 64 * self.eta_1), self.eta_1))
            N += 1
        
        e1_vec = []
        for i in range(self.k):
            e1_vec.append(self.__CBD(self.__PRF(r, N, 64 * self.eta_2), self.eta_2))
            N += 1

        e1_hat = vector(self.Rq, e1_vec)
        e2 = self.__CBD(self.__PRF(r, N, 64 * self.eta_2), self.eta_2)
            
        r_hat = self.__NTT(r_vec)
        u_hat = self.__NTT_inv(self.NTT_product(A_hat, r_hat)) + e1_hat
        v = self.__NTT_inv(self.NTT_product(t_hat, r_hat)) + e2 + self.__decompress(self.__decode(m, 1), 1)

        comp_u = [self.__compress(u_hat[i],  self.d_u) for i in range(len(u_hat))]
        comp_v = self.__compress(v, self.d_v)
        c1 = self.__encode(comp_u,  self.d_u)
        c2 = self.__encode(comp_v,  self.d_v)
        ciphertext = c1 + c2

        return ciphertext

    def decrypt(self, sk, c):
        c1 = c[:-256 * self.d_v // 8]
        c2 = c[-256 * self.d_v // 8:]
        u_hat = self.__decompress(self.__decode(c1, self.d_u), self.d_u)
        v = self.__decompress(self.__decode(c2, self.d_v), self.d_v)
        s_hat = self.__decode(sk, 12)
        m = self.__encode(self.__compress(v - self.__NTT_inv(self.NTT_product(s_hat, self.__NTT(u_hat))), 1), 1)
        return m

    def ccakem_generate_key(self):
        z = token_bytes(32)
        pk, sk_prime = self.generate_key()
        sk = sk_prime + pk + self.__H(pk) + z
        return pk, sk

    def ccakem_encrypt(self, pk):
        m = token_bytes(32)
        m = self.__H(m)
        K, r = self.__G(m + self.__H(pk))
        c = self.encrypt(pk, m, r)
        K = self.__KDF(K + self.__H(c))
        return c, K

    def ccakem_decrypt(self, c, sk):
        pk_start_idx = 12 * self.k * self.n // 8
        h_start_idx = 24 * self.k * self.n // 8 + 32
        z_start_idx = h_start_idx + 32
        pk = sk[pk_start_idx:h_start_idx]
        h = sk[h_start_idx:z_start_idx]
        z = sk[z_start_idx:]
        sk = sk[:pk_start_idx]
        m_prime = self.decrypt(sk, c)
        K_prime, r_prime = self.__G(m_prime + h)
        c_prime = self.encrypt(pk, m_prime, r_prime)
        if c == c_prime:
            return self.__KDF(K_prime + self.__H(c))
        else:
            return self.__KDF(z + self.__H(c))    

In [18]:
# from kyber import Kyber512
kyber = Kyber512()

In [19]:
pk, sk = kyber.ccakem_generate_key()
print(pk)
print(sk)

b'z]I6\x97!\x18\xe3\x13\xd6\xe8\xf1b=~\x1ae=>\xe5\xa0\xd8?\xcd\xf4\xed\xf9\xd6\xab)\xc9VHK\x84\xb9\xa9\x0f\x88\xcf\x03\xbef\xe2\xb6,h\x81\xfbLDf`\x8a\xf2@\x81FHU\x17\x99#)\xd0l<\x15*l"\x04\xfa\xa5\xbe\x0e, +\xab\x804\x07\xca\x14\xe6B\x90\x9dp]\x88\xa4\xa8Y^\xe7\xcb\xd1\xd6.,SCL\xc1/~\x9b\xccM\xd8?\x06\x13\xeap\x9a\xd3f9@3\xca\xc5\x9ee\xca0pbS\xab\x0c\xa8I%h\xab\xe7\xa2W\x8a\xb4\xf1\x10p\x98&z\xff\x92u\xc4\x13@\xb7\x0b\rnHBM\x8ch\xa9\xa29m\xa1\xe3\xcc\x96\xd4\x10\x92\x85N\x03\xb6\xc4\x9c\xc6\x94H\xa9\xdf4\xde\x95\x8bf\x99\xab^\xb1E&$Z\x06v@\x08\x92AU\x99em\xcdi\\\xa6?\xfd\xd1\xc1\xb9\xf5\xad\xee\xa2\x97\xf0\xaa \xca\xb9\xaa\xba\xa2\'$\x8d\x8a&\r1\x85\x1fXJ\xf6a&\xf8\xd0\xf0\xf6\x16\xc6I\x98$\xdc\xe8h\x9c\x91\xd4H\x9aA\xe4\xce\x16{\xe0H\x7f\xd3u\x9d\x9dt\xaf\x94\xd0u \x90)$\x01\x0c\xecr\x0eV#\x04M\x1e\xd8S\xc4\x86\x16\xf3/\xe1\xa3\xe9\xec\xde<\x89\xb5\xc2Q\xcd\xca?\x9caf]3I\x8c3j\x16HTn\xf9\\\x9c\x02\x9ee\xa6\x97m\x8bd\xc4\xe7\x86bp1p/Fp\xf2eU\xf6b\xc1\xab\x9bhu\xde\xb4\x01\x91R-\xa2\xc2

In [20]:
c, K1 = kyber.ccakem_encrypt(pk)
print(c)
print(K1)

381*xbar^255 + 1211*xbar^254 + 2610*xbar^253 + 502*xbar^252 + 1361*xbar^251 + 1905*xbar^250 + 1942*xbar^249 + 589*xbar^248 + 2614*xbar^247 + 1576*xbar^246 + 3305*xbar^245 + 2388*xbar^244 + 3177*xbar^243 + 434*xbar^242 + 930*xbar^241 + 666*xbar^240 + 2949*xbar^239 + 1285*xbar^238 + 2595*xbar^237 + 3163*xbar^236 + 2570*xbar^235 + 2821*xbar^234 + 3112*xbar^233 + 2130*xbar^232 + 2806*xbar^231 + 1717*xbar^230 + 2030*xbar^229 + 50*xbar^228 + 2506*xbar^227 + 2008*xbar^226 + 788*xbar^225 + 2475*xbar^224 + 1078*xbar^223 + 2081*xbar^222 + 2169*xbar^221 + 2864*xbar^220 + 2898*xbar^219 + 2512*xbar^218 + 266*xbar^217 + 1221*xbar^216 + 2119*xbar^215 + 568*xbar^214 + 2436*xbar^213 + 1047*xbar^212 + 1122*xbar^211 + 2039*xbar^210 + 1964*xbar^209 + 3240*xbar^208 + 1339*xbar^207 + 2968*xbar^206 + 1754*xbar^205 + 34*xbar^204 + 438*xbar^203 + 700*xbar^202 + 2631*xbar^201 + 2772*xbar^200 + 1316*xbar^199 + 2315*xbar^198 + 2521*xbar^197 + 348*xbar^196 + 2014*xbar^195 + 1201*xbar^194 + 650*xbar^193 + 1353*xbar

In [21]:
K2 = kyber.ccakem_decrypt(c, sk)
print(K2)
print(K1==K2)

381*xbar^255 + 1211*xbar^254 + 2610*xbar^253 + 502*xbar^252 + 1361*xbar^251 + 1905*xbar^250 + 1942*xbar^249 + 589*xbar^248 + 2614*xbar^247 + 1576*xbar^246 + 3305*xbar^245 + 2388*xbar^244 + 3177*xbar^243 + 434*xbar^242 + 930*xbar^241 + 666*xbar^240 + 2949*xbar^239 + 1285*xbar^238 + 2595*xbar^237 + 3163*xbar^236 + 2570*xbar^235 + 2821*xbar^234 + 3112*xbar^233 + 2130*xbar^232 + 2806*xbar^231 + 1717*xbar^230 + 2030*xbar^229 + 50*xbar^228 + 2506*xbar^227 + 2008*xbar^226 + 788*xbar^225 + 2475*xbar^224 + 1078*xbar^223 + 2081*xbar^222 + 2169*xbar^221 + 2864*xbar^220 + 2898*xbar^219 + 2512*xbar^218 + 266*xbar^217 + 1221*xbar^216 + 2119*xbar^215 + 568*xbar^214 + 2436*xbar^213 + 1047*xbar^212 + 1122*xbar^211 + 2039*xbar^210 + 1964*xbar^209 + 3240*xbar^208 + 1339*xbar^207 + 2968*xbar^206 + 1754*xbar^205 + 34*xbar^204 + 438*xbar^203 + 700*xbar^202 + 2631*xbar^201 + 2772*xbar^200 + 1316*xbar^199 + 2315*xbar^198 + 2521*xbar^197 + 348*xbar^196 + 2014*xbar^195 + 1201*xbar^194 + 650*xbar^193 + 1353*xbar

In [12]:
pk, sk = kyber.generate_key()
print(pk)
print(sk)
print(len(pk))
print(len(sk))

b')\xee\x1e\xa2\x1a\x1d\xff\xdc\x9d\xe9Nm\xb6:\xb8T5C\xf2\x91\x04\xe95(\x089Y\xf0#p\x1c\x11\xba\xeb\xa6\x0c\x98e!\xc2\x90\xe3E\r\xae\nH\x8a\xf3J\x13<9$\xeem\xf1\xb5\x04\x94x\xd6\x86\xbe\xd7\x11\x84\x9daj\xa8~\x00\xdfL\xd8\xe0\xbd\xd1?>\xabb\xc5>Oz"\xc6\xf9\xd3E\x1a\xc0\x12\xd1\x82\x8aU\xff\x8d\x95\xdc\x06\xa1\xe9\xdf\xe9\x86\xd2\xaae9\xbe\x10\x96\xf1]\xc8\xaeK\x19\xfc\x0bc\x92\xb5S0t.\x8cK\x8c\xf4e\x19\xda\xc8\x8b%Q\r\x93\x0c\x12@X 1\x1bW\x99&(\xe3\xf7\x17,6\xe35\xb5K#\xe3S.T\x11\x1e.\x06\xe2\x92\xa5D\x94j\xea\xe9\xe1t\xee\x8f\x1e4\x86B<\x88\xf0\xe7\xc0\xee\xfa\xe4%\xacH\xf4A\x0cSy@\xcc\xaa\xa6\xb5\x81\x98\xb2\xa6\t\x18\xbf8`\x8a\x88\xc1,S\xe5\'\xa7x\xd8\xaeT\t\xa5y>\x1fb\xbe.\xed\xf8\xaa\x01\x7fg6\x83\xd6\xce\x96\xc0f)\xc3\xf6\xe99\xa0\xe1+\xa5\xcf1\x02\xbb\xc4\x0c\xc9\x9fqp\xcf\xbcc\x00\xa1\x1d\x07\xd6\x0c\x01t\x198\xe4\x97\x16.\xdb4\x80^`\x99Ni\xa8\xfd\xac\xfa\x1aS9|\x85deb\xa4\xd1O!D\xab\xfc#\x04\x9eD\x92\xbdT\x15\x0e\x0e+\xcd\x00\x16Y\xe5\xc5\xc1*\xac\xe54\xc9\xfc\xca^\xa6\x04\xea

In [6]:
m = b"I love Cryptography!"
m = m.ljust(32, b'\0')
print(m)
r = os.urandom(32)
c = kyber.encrypt(pk, m, r)
print(c)
print(len(c))

b'I love Cryptography!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\xb3\x1f>kc\x8f0\x8a\xe2\xdc1\xa6\x8fma\x18\xc3"\xd0o\xf29[\x18\xf4\xcc\x90\xa5\xca\xc0\x00T\x882\xfd9NH\xeaL\x8aCF0\x1f\x97*)\x02\xf06\xccv\xee\xe4\xae\xe2\x85Q\xdf\xcb]\xb8\xdb\x8e\xc8\xe5 pV\xbf`\x8cfi.#\xad\x1d\x90\xe5a^J\xbeK\xe7\xf6\xaa\xd7\xae%<d%\x83\x19\x9c\x87\xe5\xab}\x86\x16S\x16}\x0cT\xf7\xf5\x044\xcdi\xac[\xfe\xcc.\xdc\x13\x04\xa7\xc1\n4\xed-\x0c\xe2\t\xd2\x1a\xf6\x94\xb5\x8f\xd9\xbem\xdd\x15{\x03"2\x9d\xc5\xbe^\xf0g\x16\xcc\xef\x96}\x8a\xad\xa9"\x91\x8f\x1e?i\xca\xb3\x18\xb9H\xd3\x0eOZ\xce\r\xd2n\xe6\xa6\xc1\xda\x98\xb1\xbd\\h\xb6\x15\xead^\x00\xa3\x94\xff\xcc\x1b4\xc5r\xf4\x87\xcdd\xf5\xe3:\x84\xe3[\xaaB\xb5cLaH\x8bBjv*\xc1\x1e\x87\xdf\t\t\xf9\x18?\xb2#.\x90\xfc\xc6VA\x18\x90)hr\x10L\x9d\xb8\x93\x05c\xb55{\x8ak\x81\x8d0\x07\xbe\xd7c\xe1M(\xe8\xd2\xae_\xf1\x8e\x9a^\xc1U\xb2\\rqF\xca\xcf\xffj\xa0\xbff\x0c\xc0@^\xf0\xfa\xac\x03z\x9f\x89\xb4\x96\xfa\xc6S^/U3V\xbb\xcaf\xce\xfd\xceJ;\xa1\x98?\xa7\x18\xe

In [7]:
m_prime = kyber.decrypt(sk, c)
print(m_prime)

b'I love Cryptography!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
