# Mini Projekt - Baby Kyber


## Pierścień $\mathbb{Z}_{17}[X]/(X^4+1)$


In [121]:
import math

import numpy as np
import secrets as sc

In [122]:
q = 17
n = 4
k = 2
N = 17

In [123]:
from functools import total_ordering


@total_ordering
class Zn:
    def __init__(self, n, mod=q):
        self.mod = mod
        self.n = int(n) % mod

    def _make_other(self, other):
        if isinstance(other, Zn) and other.mod == self.mod:
            return other
        elif isinstance(other, int):
            return Zn(other, self.mod)
        elif isinstance(other, float):
            return Zn(int(other), self.mod)
        else:
            raise ValueError(
                "Operation not supported between Zn with different moduli or non-integer types")

    def __add__(self, other):
        other = self._make_other(other)
        return Zn(self.n + other.n, self.mod)

    def __radd__(self, other):
        return self.__add__(other)

    def __mul__(self, other):
        other = self._make_other(other)
        return Zn(self.n * other.n, self.mod)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __pow__(self, other):
        other = self._make_other(other)
        return Zn(self.n ** other.n, self.mod)

    def __rpow__(self, other):
        return self.__pow__(other)

    def __sub__(self, other):
        other = self._make_other(other)
        return Zn(self.n - other.n, self.mod)

    def __rsub__(self, other):
        other = self._make_other(other)
        return Zn(other.n - self.n, self.mod)

    def __str__(self):
        return str(self.n)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        other = self._make_other(other)
        return self.n == other.n

    def __lt__(self, other):
        other = self._make_other(other)
        return self.n < other.n


def divide_polynomials(p1, p2):
    return np.polydiv(p1, p2)


class ZnW:
    def __init__(self, wsp: list[int], n=N, W=[1, 0, 0, 0, 1]):
        _, wsp = divide_polynomials(
            [int(x.n) if isinstance(x, Zn) else x for x in wsp], W)
        self.wsp = [Zn(x, n) for x in wsp]
        self.n = n
        self.W = W

    def __add__(self, other: 'ZnW'):
        if not isinstance(other, ZnW):
            raise TypeError("Operand must be an instance of ZnW")
        len_of_longer = max(len(self.wsp), len(other.wsp))
        wsp1 = [0] * (len_of_longer - len(self.wsp)) + self.wsp
        wsp2 = [0] * (len_of_longer - len(other.wsp)) + other.wsp
        wsp = [x + y for x, y in zip(wsp1, wsp2)]
        return ZnW(wsp, self.n, self.W)

    def __mul__(self, other):
        if isinstance(other, int):
            wsp = [x * other for x in self.wsp]
            return ZnW(wsp, self.n, self.W)

        if not isinstance(other, ZnW):
            raise TypeError("Operand must be an instance of ZnW")

        wsp = np.convolve([x.n for x in self.wsp], [x.n for x in other.wsp])
        return ZnW(wsp, self.n, self.W)

    def __rmul__(self, other):
        if isinstance(other, int):
            wsp = [x * other for x in self.wsp]
            return ZnW(wsp, self.n, self.W)
        if not isinstance(other, ZnW):
            raise TypeError("Operand must be an instance of ZnW")

    def __repr__(self):
        return str(self.wsp)

## Baby Kyber

Zaimplementuj poniższe elementy kryptosystemu Baby Kyber tak, aby osiągnąć jak największą skuteczność w testach (przy niezerowych błędach). Wymagana minimalna skuteczność to 60%.


In [125]:
def weighted_random_selection(r=[-1, 0, 1], p=[0.1, 0.8, 0.1]):
    r = [Zn(el) for el in r]
    return ZnW(np.random.choice(r, size=n, p=p))

### Generowanie klucza

Zaimplementuj funkcję `key_gen()` realizującą generowanie klucza w kryptosystemie Baby Kyber. Funkcja ma zwracać `A,t,s`. Przetestuj, czy dla podanych $A,s,e$ otrzymasz poprawny wielomian $t$.

$A=\left[\begin{matrix}
    6x^3+16x^2+16x+11&9x^3+4x^2+6x+3\\
    5x^3+3x^2+10x+1&6x^3+x^2+9x+15
\end{matrix}\right]$

$\mathbf{s}=(-x^3-x^2+x,-x^3-x)$

$\mathbf{e}=(x^2,x^2-x)$

$\mathbf{t}=A\mathbf{s}+\mathbf{e}:\ \ \mathbf{t}=(16x^3+15x^2+7,10x^3+12x^2+11x+6)$


In [126]:
def key_gen():
    # A = np.array([
    #     np.array([ZnW([6, 16, 16, 11]), ZnW([9, 4, 6, 3])]),
    #     np.array([ZnW([5, 3, 10, 1]), ZnW([6, 1, 9, 15])])])

    # s = np.array([ZnW([-1, -1, 1, 0]), ZnW([-1, 0, -1, 0])])
    # e = [ZnW([1, 0, 0]), ZnW([1, -1, 0])]

    A = np.array([[weighted_random_selection(range(q), p=None)
                 for _ in range(k)] for _ in range(k)])

    s = np.array([weighted_random_selection([-1, 0, 1]) for _ in range(k)])
    e = np.array([weighted_random_selection([-1, 0, 1]) for _ in range(k)])

    t = A @ s + e

    return A, t, s


A, t, s = key_gen()
A, t, s

(array([[[16, 14, 9, 14], [15, 12, 16, 12]],
        [[1, 14, 4, 14], [4, 6, 11, 13]]], dtype=object),
 array([[9, 8, 0, 2], [9, 15, 6, 1]], dtype=object),
 array([[16, 0, 0], [16]], dtype=object))

### Szyfrowanie

Zaimplementuj funkcję `encrypt(A,t,m)` realizującą szyfrowanie w kryptosystemie Baby Kyber a gdzie wejściowe `m` jest w postaci listy. Funkcja ma zwracać szyfrogram `c`. Przetestuj poprawność działania na poniższych danych.

$m=1\cdot x^3+0\cdot x^2+1\cdot x+1=x^3+x+1$

$\mathbf{r}=(-x^3+x^2,x^3+x^2-1)$

$\mathbf{e_1}=(x^2+x,x^2)$

$e_2=-x^3-x^2$

$\mathbf{u}=A^T\mathbf{r}+\mathbf{e_1}:\ \ \mathbf{u}=(11x^3+11x^2+10x+3,4x^3+4x^2+13x+11)$

$v=\mathbf{t}^T\mathbf{r}+e_2+\lfloor\frac{q}{2}\rceil m:\ \ v=8x^3+6x^2+9x+16$

$\mathbf{c}=(\mathbf{u},v):\ \ \mathbf{c}=((11x^3+11x^2+10x+3,4x^3+4x^2+13x+11),8x^3+6x^2+9x+16)$


In [130]:
def encrypt(A, t, m):
    # r = np.array([ZnW([-1, 1, 0, 0]), ZnW([1, 1, 0, -1])])
    # e1 = np.array([ZnW([1, 1, 0]), ZnW([1, 0, 0])])
    # e2 = ZnW([-1, -1, 0, 0])

    r = np.array([weighted_random_selection() for _ in range(k)])
    e1 = np.array([weighted_random_selection() for _ in range(k)])
    e2 = weighted_random_selection()

    u = A.T @ r + e1
    v = t.T @ r + e2 + ZnW([math.ceil((q / 2)) * el for el in ZnW(m).wsp])

    return u, v


m = [1, 0, 1, 1]
c = encrypt(A, t, m)
c

(array([[11, 16, 1, 16], [6, 1, 6, 14]], dtype=object), [2, 1, 10, 14])

### Deszyfrowanie

Zaimplementuj funkcję `decrypt(c,s)` realizującą deszyfrowanie w kryptosystemie Baby Kyber. Funkcja ma zwracać ostateczną odszyfrowaną wiadomość `m_n`. Przetestuj działanie na poniższych danych.

$m_n=v-\mathbf{s}^T\mathbf{u}:\ \ m_n=8x^3+14x^2+8x+6$

$m_n=1\cdot x^3+0\cdot x^2+1\cdot x+1$


In [131]:
def decrypt(c, s):
    u, v = c
    # m_n = ZnW([8, 14, 8, 6])
    m_n = v + -1 * s.T @ u

    rounded_m_n = []
    for x in m_n.wsp:
        dist_to_0 = abs(x.n - 0)
        dist_to_9 = abs(x.n - 9)
        dist_to_17 = abs(x.n - 17)

        if dist_to_9 < min(dist_to_0, dist_to_17):
            rounded_m_n.append(1)
        else:
            rounded_m_n.append(0)

    return rounded_m_n


decrypt(c, s)

[1, 0, 1, 1]

### Testy


In [135]:
import secrets as sc

success = 0
for i in range(1000):
    output = []
    A, t, s = key_gen()

    m = [sc.choice((0, 1)) for k in range(4)]

    c = encrypt(A, t, m)
    m_n = decrypt(c, s)

    print(f'm: {m}')
    print(f'm_n: {m_n}')

    if m_n == m:
        success += 1

print(f'Success rate: {success * 100 / 1000} %')

m: [1, 0, 0, 0]
m_n: [1, 0, 0, 0]
m: [1, 1, 0, 0]
m_n: [1, 1, 0, 0]
m: [1, 1, 0, 1]
m_n: [1, 1, 0, 1]
m: [0, 1, 1, 0]
m_n: [1, 1, 0]
m: [0, 0, 0, 0]
m_n: [0]
m: [1, 1, 1, 0]
m_n: [1, 1, 1, 0]
m: [0, 1, 0, 0]
m_n: [1, 0, 0]
m: [0, 0, 1, 1]
m_n: [0, 0, 1, 1]
m: [0, 1, 0, 1]
m_n: [1, 0, 1]
m: [1, 1, 0, 1]
m_n: [1, 1, 0, 1]
m: [0, 1, 1, 1]
m_n: [0, 1, 1, 1]
m: [0, 1, 1, 1]
m_n: [0, 1, 1, 1]
m: [0, 0, 1, 0]
m_n: [0, 0, 1, 0]
m: [1, 0, 0, 0]
m_n: [1, 0, 0, 0]
m: [1, 0, 0, 1]
m_n: [1, 0, 0, 1]
m: [1, 1, 1, 1]
m_n: [1, 1, 1, 1]
m: [0, 0, 1, 0]
m_n: [0, 0, 1, 0]
m: [1, 1, 1, 1]
m_n: [1, 1, 1, 1]
m: [0, 0, 0, 0]
m_n: [0, 0, 0, 0]
m: [0, 0, 0, 0]
m_n: [0, 0, 0]
m: [0, 1, 0, 1]
m_n: [1, 0, 1]
m: [0, 1, 0, 0]
m_n: [1, 0, 0]
m: [1, 1, 1, 0]
m_n: [1, 1, 1, 0]
m: [1, 1, 0, 1]
m_n: [1, 1, 0, 1]
m: [1, 1, 1, 0]
m_n: [1, 1, 1, 0]
m: [0, 0, 1, 0]
m_n: [1, 0]
m: [0, 1, 0, 1]
m_n: [0, 1, 0, 1]
m: [0, 1, 1, 1]
m_n: [0, 1, 1, 1]
m: [1, 1, 1, 0]
m_n: [1, 1, 1, 0]
m: [1, 1, 1, 1]
m_n: [1, 1, 1, 1]
m: [1, 0, 1, 