<a href="https://colab.research.google.com/github/ManFat84/Git_Security/blob/main/RBE_Scheme.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Define the encryption utility functions**

In [2]:
import random
from typing import List, Tuple


def is_prime(n: int) -> bool:
    """
    Tests whether n is prime by trial division.

    After 2 and 3 are ruled out as divisors, only candidates of the form
    6k - 1 and 6k + 1 are tried, up to the square root of n.

    Args:
        n (int): Number to test.

    Returns:
        bool: True if n is prime, False otherwise.
    """
    # Everything up to 3: only 2 and 3 themselves are prime
    if n <= 3:
        return n > 1

    # Multiples of 2 or 3 are composite
    if n % 2 == 0 or n % 3 == 0:
        return False

    # Remaining possible divisors come in pairs (6k - 1, 6k + 1)
    candidate = 5
    while candidate * candidate <= n:
        if n % candidate == 0:
            return False
        if n % (candidate + 2) == 0:
            return False
        candidate += 6

    return True


def prime_factors(n: int) -> set:
    """
    Collects the distinct prime divisors of n by trial division.

    Args:
        n (int): Number to factor.

    Returns:
        set: The distinct primes dividing n (empty for n <= 1).
    """
    found = set()
    remaining = n
    divisor = 2

    while divisor * divisor <= remaining:
        if remaining % divisor == 0:
            # divisor divides what is left: record it and strip one copy
            found.add(divisor)
            remaining //= divisor
        else:
            divisor += 1

    # Whatever is left above 1 is itself a prime factor
    if remaining > 1:
        found.add(remaining)

    return found


def is_generator(g: int, p: int) -> bool:
    """
    Checks if g is a generator of the group Z_p* (multiplicative group modulo p).

    g generates Z_p* exactly when g is non-zero modulo p and, for every
    prime q dividing the group order p - 1, g^((p - 1) / q) mod p != 1.

    Args:
        g (int): Candidate generator.
        p (int): Prime number (modulus of the group).

    Returns:
        bool: True if g is a generator, False otherwise.

    Raises:
        ValueError: If p is not prime.
    """
    if not is_prime(p):
        raise ValueError("p must be prime to check for generator.")

    # Multiples of p reduce to 0, which is not an element of Z_p*.
    # Without this guard every power of g is 0 (never 1) and the test
    # below would wrongly report a generator.
    if g % p == 0:
        return False

    # Group order is p - 1 for Z_p*
    group_order = p - 1

    # g fails as soon as one of the maximal proper subgroups contains it
    return all(
        pow(g, group_order // factor, p) != 1
        for factor in prime_factors(group_order)
    )


def extended_gcd(a: int, b: int) -> Tuple[int, int, int]:
    """
    Computes the greatest common divisor (GCD) of two integers a and b,
    along with the coefficients x and y such that: a*x + b*y = gcd(a, b).
    This is known as the Extended Euclidean Algorithm.

    The implementation is iterative, so it is not limited by the
    interpreter's recursion depth for very large operands.

    Args:
        a (int): First integer.
        b (int): Second integer.

    Returns:
        Tuple[int, int, int]: A tuple containing (gcd, x, y) where
        gcd is the greatest common divisor of a and b,
        x and y are the coefficients satisfying Bézout's identity.

    Example:
        >>> extended_gcd(30, 20)
        (10, 1, -1)
        # 30*1 + 20*(-1) = 10
    """
    # Forward pass: run the Euclidean reduction (a, b) -> (b % a, a)
    # and remember every quotient b // a along the way.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a

    # Base case of the reduction: gcd(0, b) = b = 0*0 + b*1
    gcd, x, y = b, 0, 1

    # Backward pass: rebuild the coefficients for each earlier pair
    for q in reversed(quotients):
        x, y = y - q * x, x

    return gcd, x, y


def modular_inverse(a: int, m: int) -> int:
    """
    Computes the modular inverse of a under modulus m,
    i.e., x such that (a * x) % m == 1.

    It relies on the built-in three-argument pow with exponent -1, which
    reduces a modulo m first, so negative values of a are handled correctly.

    Args:
        a (int): The number to find the modular inverse of.
        m (int): The modulus.

    Returns:
        int: The modular inverse of a mod m.

    Raises:
        ValueError: If a has no inverse modulo m (a and m are not coprime,
            or m is 0).
    """
    try:
        return pow(a, -1, m)
    except ValueError as err:
        # Keep the module's own error message for callers and logs
        raise ValueError(f"No modular inverse for {a} mod {m}") from err


def select_random_generator(p: int) -> int:
    """
    Picks a random generator of Z_p* by rejection sampling.

    Args:
        p (int): Prime modulus of the group.

    Returns:
        int: A generator of Z_p*. For p >= 5 it is drawn from [2, p - 2];
        for the primes 2 and 3 the only generator (1 and 2 respectively)
        is returned.

    Raises:
        ValueError: If p is not prime.
    """
    if not is_prime(p):
        raise ValueError("p must be prime to select generator.")

    # The sampling range [2, p - 2] is empty for p in {2, 3}, which would
    # make random.randint raise; these groups have a single generator.
    if p == 2:
        return 1
    if p == 3:
        return 2

    # NOTE: `random` is not a cryptographically secure source of randomness.
    while True:
        g = random.randint(2, p - 2)
        if is_generator(g, p):
            return g


def generate_modular_set(l: int, p: int, target_sum: int) -> List[int]:
    """
    Generates a list of length l of random values modulo p
    such that their sum modulo p equals target_sum modulo p.

    Args:
        l (int): Number of values to generate (at least 1).
        p (int): The modulus.
        target_sum (int): The target sum modulo p.

    Returns:
        List[int]: List of l integers in [0, p - 1] satisfying the sum condition.

    Raises:
        ValueError: If l is smaller than 1 (no list of that length can be
            forced to hit an arbitrary target sum).
    """
    if l < 1:
        raise ValueError(f"l must be at least 1, got {l}.")

    # The first l - 1 values are free random residues mod p
    values = [random.randint(0, p - 1) for _ in range(l - 1)]

    # The last value is forced: it makes up the difference to the target
    values.append((target_sum - sum(values)) % p)

    return values


def modular_multiply_sets(
    set1: List[int], set2: List[int], reg: int, p: int, scale: int = None
) -> List[int]:
    """
    Multiplies two equally sized sets modulo p, folding the result back to
    the same size.

    Entry i of a set sits at position i and stands for the exponent i + 1.
    The product of entries i and j therefore belongs to exponent i + j + 2;
    when that exceeds the set length l, the product is multiplied by reg and
    wrapped around to exponent i + j + 2 - l.

    Args:
        set1 (list): The first set of values of the form [k, k^2, ..., k^l].
        set2 (list): The second set of values of the form [k, k^2, ..., k^l].
        reg (int): Regulator value applied to products that wrap around.
        p (int): The modulus for the modular multiplication.
        scale (int): The scale to use in case of encrypting plaintexts. When
            given (and non-zero), every result entry is multiplied by the
            modular inverse of scale.

    Returns:
        List[int]: A new set after modular multiplication and adjustment.

    Raises:
        ValueError: If the two sets differ in length, or if scale has no
            inverse modulo p.
    """
    if len(set1) != len(set2):
        raise ValueError(
            f"Sets must have the same length, got {len(set1)} and {len(set2)}."
        )

    size = len(set1)
    result = [0] * size
    for i, left in enumerate(set1):
        for j, right in enumerate(set2):
            element = (left * right) % p

            # Target slot for exponent i + j + 2
            pos = i + j + 1
            if pos >= size:
                # Past the end of the set: regulate and wrap around
                element = (element * reg) % p
                pos -= size

            result[pos] = (result[pos] + element) % p

    # Case when dealing with floats: remove one factor of scale
    if scale:
        scale_inverse = modular_inverse(scale, p)
        result = [(scale_inverse * x) % p for x in result]

    return result


def encode_real(x: float, scale: int) -> tuple[int, float]:
    """
    Turns a real number into a fixed-point integer for encryption.

    Returns:
        The pair (x * scale rounded to the nearest integer, scale).
    """
    scaled = x * scale
    fixed_point = int(round(scaled))
    return fixed_point, scale


def decode_real(x_int: int, scale: int, P: int = None) -> float:
    """
    Converts a fixed-point integer back to a real number after decryption.

    When a modulus P is given, residues above P // 2 are first mapped to
    their negative representative (centered lifting).
    """
    needs_lift = P is not None and x_int > P // 2
    centered = x_int - P if needs_lift else x_int
    return centered / scale

# **Define the interfaces of the basic RBE**

In [3]:
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Dict

# Type aliases used in the scheme interface signatures below.
# A ciphertext is a list; the element type is not constrained here.
Ciphertext = List
# A plaintext is annotated as a plain integer.
Plaintext = int


class HomomorphicScheme(ABC):
    """Abstract contract shared by all homomorphic encryption schemes."""

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not forwarded
        super().__init__()

    @abstractmethod
    def keygen(self):
        """Create every key and parameter the scheme needs (private key, public key, etc.)."""

    @abstractmethod
    def encrypt(self, m: Plaintext) -> Ciphertext:
        """Turn the message m into a ciphertext."""

    @abstractmethod
    def decrypt(self, c: Ciphertext) -> Plaintext:
        """Recover the plaintext from the ciphertext c."""

    @abstractmethod
    def add(self, c1: Ciphertext, c2: Ciphertext) -> Ciphertext:
        """Add two ciphertexts homomorphically."""

    @abstractmethod
    def multiply(self, c1: Ciphertext, c2: Ciphertext) -> Ciphertext:
        """Multiply two ciphertexts homomorphically."""

# **The encryption scheme**

In [4]:
import random
from typing import List


class ToyFHEncryptor(HomomorphicScheme):
    """
    Toy homomorphic scheme over ciphertexts of l residues modulo a prime P.

    Slot i of a ciphertext carries a share of the (scaled) message multiplied
    by k^(i+1) mod P, where k is the secret base. Decryption divides each slot
    by its power of k and sums the shares.

    NOTE: keys are drawn with the `random` module, which is not
    cryptographically secure; this class is for demonstration only.
    """

    def __init__(self, P: int, l: int, scale: int, *args, **kwargs) -> None:
        """
        Initialize with small modulus P and parameter l.

        Args:
            P (int): Modulus (must be prime).
            l (int): Length of modular set used in the scheme (at least 1).
            scale (int): Scale to convert floats to integers; must be
                non-zero modulo P so that it can be inverted in multiply().

        Raises:
            ValueError: If P is not prime, l < 1, or scale is 0 modulo P.
        """
        super().__init__(*args, **kwargs)

        if not is_prime(P):
            raise ValueError(f"P must be prime, got {P}.")
        if l < 1:
            raise ValueError(f"l must be at least 1, got {l}.")
        if scale % P == 0:
            raise ValueError(f"scale must be non-zero modulo P, got {scale}.")

        self.P = P  # Modulus (prime)
        self.l = l  # Length of sets
        self.scale = scale
        self.k = None  # Base value for the set [k, k^2, ..., k^l]
        self.reg = None  # Regulator k^(-l) mod P, set by keygen()
        self.private_key = None
        self.public_key = None
        self.encrypted_one = None

    def __repr__(self):
        return f"<ToyFHEncryptor P={self.P} l={self.l}>"

    def keygen(self) -> None:
        """
        Generates all key material and stores it on the instance:
          - k: a random generator of Z_P*
          - reg: the inverse of k^l mod P, used to wrap products around
          - private key: [k, k^2, ..., k^l] mod P
          - public key: the private key multiplied slot-wise by random
            values summing to 1 mod P
          - encrypted_one: a re-randomised encryption of one
        """
        self.k = select_random_generator(self.P)
        self.reg = modular_inverse(pow(self.k, self.l, self.P), self.P)
        self.private_key = self.generate_private_key()
        self.public_key = self.generate_public_key()
        self.encrypted_one = self.get_encrypted_one()

    def generate_private_key(self) -> List[int]:
        """
        Generates a private key as a list of powers of k modulo P:
        [k^1 mod P, k^2 mod P, ..., k^l mod P]

        Returns:
            List[int]: A list containing [k, k^2, k^3, ..., k^l] with each value taken modulo P.
        """
        return [pow(self.k, exponent, self.P) for exponent in range(1, self.l + 1)]

    def generate_public_key(self) -> List[int]:
        """
        Generates a public key using the private key and a modular set summing to 1 mod P.

        Formula:
            public_key[i] = (private_key[i] * ones[i]) mod P
            where 'ones' is a random modular set of length l summing to 1 mod P.

        Returns:
            List[int]: The public key, a list of length l.
        """
        # Generate modular set summing to 1 mod P, i.e. e(1)
        ones = generate_modular_set(self.l, self.P, 1)
        return [
            (key_power * share) % self.P
            for key_power, share in zip(self.private_key, ones)
        ]

    def get_encrypted_one(self) -> List[int]:
        """
        Calculates an encryption of one from the public key.

        The public key is multiplied by itself, and the result is then
        multiplied by the public key a random number of further times
        (between 0 and 1000), so the returned set is a power e(1)^x of the
        public key.

        Returns:
            List[int]: A list of length l.
        """
        encrypted_one = modular_multiply_sets(
            self.public_key, self.public_key, self.reg, self.P
        )

        # Random number of extra multiplications by e(1) for added randomness
        extra_rounds = random.randint(0, 1000)
        for _ in range(extra_rounds):
            encrypted_one = modular_multiply_sets(
                self.public_key, encrypted_one, self.reg, self.P
            )

        return encrypted_one

    def encrypt(self, m: int | float) -> List[int]:
        """
        Encrypts a plaintext number into a ciphertext.

        The plaintext is first converted to a fixed-point integer with the
        instance scale.

        Args:
            m (int | float): The plaintext value.

        Returns:
            List[int]: list of length l

        Raises:
            ValueError: If keygen() has not been run yet.
        """
        if self.public_key is None or self.encrypted_one is None:
            raise ValueError("Public key not generated yet.")

        m_int, _ = encode_real(m, self.scale)
        m_set = generate_modular_set(self.l, self.P, m_int)

        # Every slot of encrypted_one is multiplied by each share and the
        # products are summed, so only the total of the shares matters.
        # NOTE(review): the ciphertext therefore depends only on
        # encrypted_one and m, not on the random shares — confirm intended.
        share_total = sum(m_set) % self.P
        return [(slot * share_total) % self.P for slot in self.encrypted_one]

    def decrypt(self, c: List[int]) -> float:
        """
        Decrypt a ciphertext into a plaintext.

        Each slot is divided by its power of k (the matching private key
        entry), the slots are summed modulo P, and the sum is converted back
        from fixed-point form.

        Args:
            c (List[int]): Ciphertext of length l.

        Returns:
            float: plaintext result (sum of decrypted set, divided by scale)

        Raises:
            ValueError: If keygen() has not been run yet.
        """
        if self.private_key is None:
            raise ValueError("Private key not generated yet.")

        total = 0
        for i, key_power in enumerate(self.private_key):
            # private_key[i] is k^(i+1) mod P; its inverse removes the mask
            unmasked = (modular_inverse(key_power, self.P) * c[i]) % self.P
            total = (total + unmasked) % self.P

        return decode_real(total, self.scale, self.P)

    def add(self, c1: List[int], c2: List[int]) -> List[int]:
        """
        Performs addition of two ciphertexts (slot-wise sum modulo P).
        """
        return [(c1[i] + c2[i]) % self.P for i in range(self.l)]

    def multiply(self, c1: List[int], c2: List[int]) -> List[int]:
        """
        Performs multiplication of two ciphertexts.

        The product is divided by scale modulo P so that the result carries
        a single factor of scale, like a fresh ciphertext.
        """
        return modular_multiply_sets(c1, c2, self.reg, self.P, self.scale)

# **Demo of HE**

In [5]:
# End-to-end demo: generate keys, encrypt two reals, then check that
# decryption, homomorphic addition and homomorphic multiplication
# reproduce the plaintext results.
print("=== Simple Homomorphic Encryption Console Demo ===")

# Scheme parameters
P = 28871271685163  # modulus; alternative smaller value: 87803
l = 2  # number of slots per ciphertext
scale = 1000000  # fixed-point scale used to encode reals as integers

toyfhe = ToyFHEncryptor(P, l, scale)
toyfhe.keygen()
print(toyfhe)
print(f"Trapdoor k: {toyfhe.k}")
print(f"Private Key: {toyfhe.private_key}")
print(f"Public Key: {toyfhe.public_key}")

# Plaintexts (one negative, to exercise the centered decoding)
m1 = -2.12
m2 = 12.31

# Encryption
c1 = toyfhe.encrypt(m1)
c2 = toyfhe.encrypt(m2)
print(f"Encryption1: enc({m1}) = {c1}")
print(f"Encryption2: enc({m2}) = {c2}")

# Decryption
dec1 = toyfhe.decrypt(c1)
dec2 = toyfhe.decrypt(c2)
print(f"Decryption1: dec({c1}) = {dec1}")
print(f"Decryption2: dec({c2}) = {dec2}")

# Homomorphic addition
homomorphic_sum = toyfhe.add(c1, c2)
decrypted_sum = toyfhe.decrypt(homomorphic_sum)
print(
    f"Decrypted Sum: dec({c1} + {c2}) = {decrypted_sum} (Should be {m1} + {m2} = {m1 + m2})"
)

# Homomorphic multiplication
homomorphic_product = toyfhe.multiply(c1, c2)
decrypted_product = toyfhe.decrypt(homomorphic_product)
print(
    f"Decrypted Product: dec({c1} x {c2}) = {decrypted_product} (Should be {m1} * {m2} = {m1 * m2})"
)

=== Simple Homomorphic Encryption Console Demo ===
<ToyFHEncryptor P=28871271685163 l=2>
Trapdoor k: 20215719508542
Private Key: [20215719508542, 20024301171882]
Public Key: [16513348382117, 7832187602908]
Encryption1: enc(-2.12) = [25881550651447, 25925234679109]
Encryption2: enc(12.31) = [18449607386725, 10297207406577]
Decryption1: dec([25881550651447, 25925234679109]) = -2.12
Decryption2: dec([18449607386725, 10297207406577]) = 12.31
Decrypted Sum: dec([25881550651447, 25925234679109] + [18449607386725, 10297207406577]) = 10.19 (Should be -2.12 + 12.31 = 10.190000000000001)
Decrypted Product: dec([25881550651447, 25925234679109] x [18449607386725, 10297207406577]) = -26.0972 (Should be -2.12 * 12.31 = -26.0972)
