In [1]:
from Polynomial import Polynomial, poly_round, PolynomialTensor
import numpy as np
from dataclasses import dataclass

In [2]:
SecretKey = Polynomial
PrivateKey = tuple[Polynomial, Polynomial]
RelinearizationKey = tuple[Polynomial, Polynomial]

In [235]:
### RLK V1 polynom calc
@dataclass
class BfvConfig:
    poly_len: int
    modulus: int
    base: int


def decompose_len(modulus: int, base: int) -> int:
    return int(np.floor(np.emath.logn(base, modulus))) + 1


def to_base(poly: Polynomial, base: int) -> np.ndarray:
    l = decompose_len(poly.modulus, base)

    def _to_base(arr: np.ndarray, base: int, index=1) -> list:
        if index >= l:
            return arr
        return [arr[-1] % base] + _to_base([arr[-1] // base], base, index + 1)

    return np.asarray(_to_base(poly.poly_mat, base))


class BfvMessage:

    def __init__(self, config: BfvConfig, rlk: RelinearizationKey,
                 u: Polynomial, v: Polynomial):
        self.config = config
        self.u = u
        self.v = v
        self.rlk = rlk

    def __add__(self, other: "BfvMessage") -> "BfvMessage":
        assert self.config == other.config
        return BfvMessage(self.config,
                          self.rlk,
                          u=(self.u + other.u) % self.config.modulus,
                          v=(self.v + other.v) % self.config.modulus)

    def __mul__(self, other: "BfvMessage") -> "BfvMessage":
        assert self.config == other.config
        l = decompose_len(self.config.modulus, self.config.base)

        def round_with_poly(poly: Polynomial) -> Polynomial:
            return Polynomial((np.round(poly.poly_mat)) % self.config.modulus,
                              self.config.modulus)

        v = round_with_poly((2 / self.config.modulus) * (self.v @ other.v))

        u = round_with_poly(
            (2 / self.config.modulus) * (self.v @ other.u + self.u @ other.v))

        uv = round_with_poly((2 / self.config.modulus) * (self.u @ other.u))

        uvu = Polynomial(
            (
                self.rlk[0].mul_matrix(axis=1).reshape(l, self.config.poly_len, self.config.poly_len)
                @ to_base(uv, self.config.base).reshape(l, self.config.poly_len, 1)
             ).sum(axis=0).T,
             self.config.modulus
        ) # yapf: disable
        uvv = Polynomial(
            (
                self.rlk[1].mul_matrix(axis=1).reshape(l, self.config.poly_len, self.config.poly_len)
                @ to_base(uv, self.config.base).reshape(l, self.config.poly_len, 1)
            ).sum(axis=0).T,
            self.config.modulus
        ) # yapf: disable

        return BfvMessage(self.config,
                          self.rlk,
                          u=(u + uvu) % self.config.modulus,
                          v=(v + uvv) % self.config.modulus)


class BFV:

    @staticmethod
    def keygen(
            conf: BfvConfig
    ) -> tuple[SecretKey, PrivateKey, RelinearizationKey]:
        # Key Generation
        s = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
        e = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
        A = Polynomial.random_polynomial(conf.poly_len, conf.modulus)
        b = (-1 * (A @ s + e)) % conf.modulus

        # RLK Generation
        l = decompose_len(conf.modulus, conf.base)
        ra = PolynomialTensor.random_polynomial_matrix(conf.poly_len,
                                                       conf.modulus, (l, ))
        re = PolynomialTensor.random_polynomial_matrix(conf.poly_len,
                                                       conf.modulus, (l, ), 0,
                                                       2)

        ra_s = PolynomialTensor(
            (ra.mul_matrix(axis=1) @ s.poly_mat.T).reshape(l, conf.poly_len),
            conf.modulus)
        T_is = np.asarray([[conf.base**i for i in range(l)]]).T
        ti_s2 = PolynomialTensor((s @ s).poly_mat * T_is, conf.modulus)
        rb = (-1 * (ra_s + re) + ti_s2) % conf.modulus

        return (s, (A, b), (ra, rb))

    @staticmethod
    def encrypt(conf: BfvConfig, pk: PrivateKey, rlk: RelinearizationKey,
                message: list) -> BfvMessage:
        assert isinstance(message, list)
        A, b = pk
        e1 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
        e2 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
        r = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 2)

        v = (b @ r + e1 + Polynomial(
            np.asarray([message]) *
            (conf.modulus // 2), conf.modulus)) % conf.modulus
        u = (A @ r + e2) % conf.modulus

        return BfvMessage(conf, rlk, u=u, v=v)

    def decrypt(sk: SecretKey, m_enc: BfvMessage) -> list:
        return (np.round(
            ((2 / m_enc.config.modulus) *
             ((m_enc.v + m_enc.u @ sk) % m_enc.config.modulus)).poly_mat) %
                2).tolist()[0]


In [68]:
#### RLK V2

# @dataclass
# class BfvConfig:
#     poly_len: int
#     modulus: int
#     p: int


# class BfvMessage:

#     def __init__(self, config: BfvConfig, rlk: RelinearizationKey,
#                  u: Polynomial, v: Polynomial):
#         self.config = config
#         self.u = u
#         self.v = v
#         self.rlk = rlk

#     def __add__(self, other: "BfvMessage") -> "BfvMessage":
#         assert self.config == other.config
#         return BfvMessage(self.config,
#                           self.rlk,
#                           u=(self.u + other.u) % self.config.modulus,
#                           v=(self.v + other.v) % self.config.modulus)

#     def __mul__(self, other: "BfvMessage") -> "BfvMessage":
#         assert self.config == other.config

#         def round_with_poly(poly: Polynomial) -> Polynomial:
#             return Polynomial((np.round(poly.poly_mat)) % self.config.modulus,
#                               self.config.modulus)

#         v = round_with_poly((2 / self.config.modulus) * (self.v @ other.v))

#         u = round_with_poly(
#             (2 / self.config.modulus) * (self.v @ other.u + self.u @ other.v))

#         uv = round_with_poly(
#             (2 / self.config.modulus) * (self.u @ other.u)).change_modulus(
#                 self.config.p * self.config.modulus)
#         uvu = round_with_poly((uv @ self.rlk[0]) * (1 / self.config.p))
#         uvv = round_with_poly((uv @ self.rlk[1]) * (1 / self.config.p))

#         return BfvMessage(self.config,
#                           self.rlk,
#                           u=(u + uvu) % self.config.modulus,
#                           v=(v + uvv) % self.config.modulus)


# class BFV:

#     @staticmethod
#     def keygen(
#             conf: BfvConfig
#     ) -> tuple[SecretKey, PrivateKey, RelinearizationKey]:
#         s = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
#         e = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
#         A = Polynomial.random_polynomial(conf.poly_len, conf.modulus)

#         rs = Polynomial(s.poly_mat, conf.p * conf.modulus)
#         re = Polynomial.random_polynomial(conf.poly_len, conf.p * conf.modulus,
#                                           0, 1)
#         rA = Polynomial.random_polynomial(conf.poly_len, conf.p * conf.modulus)

#         return (s, (A, (-1 * (A @ s + e)) % conf.modulus),
#                 (rA, (-1 * (rA @ rs + re) + conf.p * (rs @ rs)) %
#                  (conf.p * conf.modulus)))

#     @staticmethod
#     def encrypt(conf: BfvConfig, pk: PrivateKey, rlk: RelinearizationKey,
#                 message: list) -> BfvMessage:
#         assert isinstance(message, list)
#         A, b = pk
#         e1 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
#         e2 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
#         r = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 2)

#         v = (b @ r + e1 + Polynomial(
#             np.asarray([message]) *
#             (conf.modulus // 2), conf.modulus)) % conf.modulus
#         u = (A @ r + e2) % conf.modulus

#         return BfvMessage(conf, rlk, u=u, v=v)

#     def decrypt(sk: SecretKey, m_enc: BfvMessage) -> list:
#         return (np.round(
#             ((2 / m_enc.config.modulus) *
#              ((m_enc.v + m_enc.u @ sk) % m_enc.config.modulus)).poly_mat) %
#                 2).tolist()[0]


In [193]:
conf = BfvConfig(10, 1000, 2)
sk, pk, rlk = BFV.keygen(conf)

In [236]:
# conf = BfvConfig(8, 1000, 1000**3)
conf = BfvConfig(3, 10000, 2)
sk, pk, rlk = BFV.keygen(conf)
# m1 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
# m_e1 = BFV.encrypt(conf, pk, rlk, m1.poly_mat[0].tolist())
# for i in range(10000):
#     m2 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
#     m_e2 = BFV.encrypt(conf, pk, rlk, m2.poly_mat[0].tolist())
#     m_e1 = m_e1*m_e2
#     m1 = (m1 @ m2)%2
#     assert BFV.decrypt(sk, m_e1) == m1.poly_mat[0].tolist(), f"{i}: {m1} -- {BFV.decrypt(sk, m_e1)}"


In [237]:
m1 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
m_e1 = BFV.encrypt(conf, pk, rlk, m1.poly_mat[0].tolist())
m1

1.0 + 0.0·x + 1.0·x²

In [238]:
%timeit m_e1*m_e1

1.18 ms ± 24.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [233]:
BFV.decrypt(sk, m_e1*m_e1)

[1.0, 0.0, 0.0]

In [230]:
m1 @ m1 % 2

1.0 + 1.0·x + 0.0·x²

In [240]:
%load_ext line_profiler
%load_ext cProfile


The cProfile module is not an IPython extension.


In [244]:
%prun m_e1*m_e1

 

         1178 function calls (1152 primitive calls) in 0.011 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       32    0.001    0.000    0.001    0.000 {method 'astype' of 'numpy.ndarray' objects}
       32    0.001    0.000    0.004    0.000 _special_matrices.py:221(circulant)
       32    0.001    0.000    0.001    0.000 {method 'copy' of 'numpy.ndarray' objects}
       64    0.001    0.000    0.001    0.000 {built-in method numpy.arange}
        4    0.001    0.000    0.001    0.000 {method 'flatten' of 'numpy.ndarray' objects}
        6    0.001    0.000    0.007    0.001 Polynomial.py:84(<listcomp>)
       32    0.001    0.000    0.001    0.000 stride_tricks.py:38(as_strided)
        6    0.000    0.000    0.000    0.000 shape_base.py:292(hstack)
       72    0.000    0.000    0.000    0.000 {built-in method numpy.asarray}
       32    0.000    0.000    0.000    0.000 {method 'outer' of 'numpy.ufunc' objects}
     2

In [252]:
def mul(m1: "BfvMessage", m2: "BfvMessage") -> "BfvMessage":
    l = decompose_len(m1.config.modulus, m1.config.base)

    def round_with_poly(poly: Polynomial) -> Polynomial:
        return Polynomial((np.round(poly.poly_mat)) % m1.config.modulus,
                            m1.config.modulus)

    v = (2 / m1.config.modulus) * (m1.v @ m2.v)
    v = round_with_poly(v)

    u = (2 / m1.config.modulus) * (m1.v @ m2.u + m1.u @ m2.v)
    u = round_with_poly(u)

    uv = (2 / m1.config.modulus) * (m1.u @ m2.u)
    uv = round_with_poly(uv)
    uv_base = to_base(uv, m1.config.base).reshape(l, m1.config.poly_len, 1)

    rA = m1.rlk[0].mul_matrix(axis=1)
    rA = rA.reshape(l, m1.config.poly_len, m1.config.poly_len)
    rb = m1.rlk[1].mul_matrix(axis=1)
    rb = rb.reshape(l, m1.config.poly_len, m1.config.poly_len)
    uvu = Polynomial(
        (
            rA @ uv_base
            ).sum(axis=0).T,
            m1.config.modulus
    ) # yapf: disable
    uvv = Polynomial(
        (
            rb @ uv_base
        ).sum(axis=0).T,
        m1.config.modulus
    ) # yapf: disable

    return BfvMessage(m1.config,
                        m1.rlk,
                        u=(u + uvu) % m1.config.modulus,
                        v=(v + uvv) % m1.config.modulus)

m1 = Polynomial.random_polynomial(conf.poly_len, conf.modulus, 0, 1)
m_e1 = BFV.encrypt(conf, pk, rlk, m1.poly_mat[0].tolist())

%lprun -f mul mul(m_e1, m_e1)

Timer unit: 1e-09 s

Total time: 0.0031485 s
File: /tmp/ipykernel_18904/3485604266.py
Function: mul at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def mul(m1: "BfvMessage", m2: "BfvMessage") -> "BfvMessage":
     2         1     153684.0 153684.0      4.9      l = decompose_len(m1.config.modulus, m1.config.base)
     3                                           
     4         1       1690.0   1690.0      0.1      def round_with_poly(poly: Polynomial) -> Polynomial:
     5                                                   return Polynomial((np.round(poly.poly_mat)) % m1.config.modulus,
     6                                                                       m1.config.modulus)
     7                                           
     8         1     200575.0 200575.0      6.4      v = (2 / m1.config.modulus) * (m1.v @ m2.v)
     9         1      19731.0  19731.0      0.6      v = round_with_poly(v)
    10       