# [Homomorphic Encryption](https://en.wikipedia.org/wiki/Homomorphic_encryption)

Playing around with (fully) homomorphic encryption schemes.

In [1]:
import numpy as np
from math import gcd
from random import randint, randrange
from typing import List, Tuple, NamedTuple
from numpy import ndarray

In [2]:
# Taken from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)

def modinv(a: int, m: int) -> int:
    g, x, y = egcd(a, m)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % m

assert modinv(17, 3120) == 2753
assert egcd(1071, 462) == (21, -3, 7)

## [El Gamal](https://en.wikipedia.org/wiki/ElGamal_encryption)

El Gamal can be used to perform encrypted multiplications.

In [3]:
class PublicKey(NamedTuple):
    p: int
    a: int
    b: int

class SecretKey(NamedTuple):
    d: int

class Ciphertext(NamedTuple):
    r: int
    t: int

In [4]:
def keygen(p: int) -> Tuple[PublicKey, SecretKey]:
    # a: 1 < a < p - 1
    a: int = randint(1, p - 1)
    # d: 2 <= d <= p - 2
    d: int = randint(2, p - 2)
    b: int = (a ** d) % p
    pk: PublicKey = PublicKey(p, a, b)
    sk: SecretKey = SecretKey(d)
    return (pk, sk)

In [5]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    k: int = randint(0, 100)    
    r: int = (pk.a ** k) % pk.p
    t: int = ((pk.b ** k) * message) % pk.p
    return Ciphertext(r, t)

In [6]:
def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:
    # NOTE: This implementation of https://en.wikipedia.org/wiki/Modular_multiplicative_inverse is expensive
    # TODO: One can use the `modinv` function from above but I'll leave this code here as another way to compute it
    return ((c.r ** sk.d) ** (pk.p - 2) * c.t) % pk.p

In [7]:
def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:
    r: int = a.r * b.r
    t: int = a.t * b.t
    return Ciphertext(r, t)

In [8]:
pk, sk = keygen(47)

print('--- Message Encryption / Decryption ---')
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

decrypted: int = decrypt(ciphertext, pk, sk)
print(f'Message (Decrypted): {decrypted}')

assert plaintext == decrypted

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(r=41, t=32)
Message (Decrypted): 42


In [9]:
pk, sk = keygen(47)

print('--- Encrypted Multiplication ---')
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

assert a * b == decrypted

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(r=12, t=18), Ciphertext(r=8, t=38)
Result (Ciphertext): Ciphertext(r=96, t=684)
Result (Decrypted): 30


## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem

RSA can be used to perform encrypted multiplications.

In [10]:
class PublicKey(NamedTuple):
    e: int
    n: int

class SecretKey(NamedTuple):
    d: int
    n: int

class Ciphertext(NamedTuple):
    m: int

In [11]:
def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:
    n: int = p * q
    phi: int = (p - 1) * (q - 1)
    # e must be greater than 1 and smaller than phi
    # furthermore gcd(phi, e) must be 1
    e: int = 2
    while gcd(phi, e) != 1:
        e += 1
    d: int = modinv(e, phi)
    pk: PublicKey = PublicKey(e, n)
    sk: SecretKey = SecretKey(d, n)
    return (pk, sk)
    
assert keygen(61, 53)[0] == PublicKey(7, 3233)
assert keygen(61, 53)[1] == SecretKey(1783, 3233)

In [12]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    return Ciphertext(message ** pk.e % pk.n)

In [13]:
def decrypt(c: Ciphertext, sk: SecretKey) -> int:
    return c.m ** sk.d % sk.n

In [14]:
def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:
    return Ciphertext(m=(a.m * b.m))

In [15]:
pk, sk = keygen(61, 53)

print('--- Message Encryption / Decryption ---')
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

decrypted: int = decrypt(ciphertext, sk)
print(f'Message (Decrypted): {decrypted}')

assert plaintext == decrypted

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(m=240)
Message (Decrypted): 42


In [16]:
pk, sk = keygen(61, 53)

print('--- Encrypted Multiplication ---')
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, sk)
print(f'Result (Decrypted): {decrypted}')

assert a * b == decrypted

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(m=1898), Ciphertext(m=533)
Result (Ciphertext): Ciphertext(m=1011634)
Result (Decrypted): 30


## [Paillier](https://en.wikipedia.org/wiki/Paillier_cryptosystem) Cryptosystem

Paillier can be used to perform encrypted additions.

In [17]:
class PublicKey(NamedTuple):
    n: int
    g: int

class SecretKey(NamedTuple):
    la: int
    mu: int

class Ciphertext(NamedTuple):
    m: int

In [18]:
def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:
    assert p.bit_length() == q.bit_length()
    n: int = p * q
    g: int = n + 1
    la: int = (p - 1) * (q - 1)
    mu: int = modinv(la, n)
    pk: PublicKey = PublicKey(n, g)
    sk: SecretKey = SecretKey(la, mu)
    return (pk, sk)

assert keygen(61, 53)[0] == PublicKey(3233, 3234)
assert keygen(61, 53)[1] == SecretKey(3120, 2718)

In [19]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    r: int = 0
    while gcd(r, pk.n) != 1:
        r: int = randrange(0, pk.n + 1)
    m: int = ((pk.g ** message % pk.n ** 2) * (r ** pk.n % pk.n ** 2)) % pk.n ** 2
    return Ciphertext(m)

In [20]:
def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:
    return ((((c.m ** sk.la) % pk.n ** 2) - 1) // pk.n) * sk.mu % pk.n

In [21]:
def add(a: Ciphertext, b: Ciphertext, pk: PublicKey) -> Ciphertext:
    return Ciphertext(a.m * b.m % (pk.n ** 2))

In [22]:
pk, sk = keygen(61, 53)

print('--- Message Encryption / Decryption ---')
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

decrypted: int = decrypt(ciphertext, pk, sk)
print(f'Message (Decrypted): {decrypted}')

assert plaintext == decrypted

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(m=1548548)
Message (Decrypted): 42


In [23]:
pk, sk = keygen(61, 53)

print('--- Encrypted Addition (with encrypted values) ---')
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a + b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

result: Ciphertext = add(enc_a, enc_b, pk)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

assert a + b == decrypted

print('--- Encrypted Addition (with encrypted and plaintext value) ---')
a: int = 20
b: int = 13
print(f'Numbers (Plaintext): {b}, {b}')
print(f'Result (Plaintext): {a + b}')

enc_a: Ciphertext = encrypt(a, pk)
print(f'Number (Ciphertext): {enc_a}')

# `pk.n + 1` == `g`
result: Ciphertext = Ciphertext(enc_a.m * (pk.n + 1) ** b % (pk.n ** 2))
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

assert a + b == decrypted

print('--- Encrypted Multiplication (with encrypted and plaintext value) ---')
a: int = 2
b: int = 12
print(f'Numbers (Plaintext): {b}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
print(f'Number (Ciphertext): {enc_a}')

result: Ciphertext = Ciphertext(enc_a.m ** b % (pk.n ** 2))
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

assert a * b == decrypted

--- Encrypted Addition (with encrypted values) ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 11
Numbers (Ciphertext): Ciphertext(m=934918), Ciphertext(m=4097088)
Result (Ciphertext): Ciphertext(m=1421243)
Result (Decrypted): 11
--- Encrypted Addition (with encrypted and plaintext value) ---
Numbers (Plaintext): 13, 13
Result (Plaintext): 33
Number (Ciphertext): Ciphertext(m=10310040)
Result (Ciphertext): Ciphertext(m=10436127)
Result (Decrypted): 33
--- Encrypted Multiplication (with encrypted and plaintext value) ---
Numbers (Plaintext): 12, 12
Result (Plaintext): 24
Number (Ciphertext): Ciphertext(m=4026668)
Result (Ciphertext): Ciphertext(m=3272659)
Result (Decrypted): 24


## [Efficient Homomorphic Encryption on Integer Vectors and Its Applications](https://www.rle.mit.edu/sia/wp-content/uploads/2015/04/2014-zhou-wornell-ita.pdf)

**NOTE:** The code written here was produced by following the blog post ["Building Safe A.I."](http://iamtrask.github.io/2017/03/17/safe-ai/) by Andrew Trask.

### Terminology

- **S**: Matrix which represents the secret / private key
- **M**: Public Key (also used to perform Math operations)
- **c**: Vector which contains the encrypted data
- **x**: Plaintext (some papers use the variable **m** instead)
- ***w***: (Weighting) Scalar used to control signal / noise ratio of **x**
- **e**: Random noise (e.g. noise added to the data before encrypting it via the public key) which makes the decryption difficult

Homomorphic Encryption has 4 kind of operations we care about:

1. Public / private keypair generation
1. One-way encryption
1. Decryption
1. Math operations

$$
\textit{S}c = \textit{w}x + e
$$

$$
x = \lceil \frac{Sc}{\textit{w}} \rfloor
$$

In [24]:
def generate_key(w: int, m: int, n: int) -> ndarray:
    S: ndarray = (np.random.rand(m, n) * w / (2 ** 16))
    return S

def encrypt(x: ndarray, S: ndarray, m: int, n: int, w: int) -> ndarray:
    assert len(x) == len(S)
    e: ndarray = (np.random.rand(m))
    c: ndarray = np.linalg.inv(S).dot((w * x) + e)
    return c

def decrypt(c: ndarray, S: ndarray, w) -> ndarray:
    return (S.dot(c) / w).astype('int')

def switch_key(c: ndarray, S: ndarray, m: int, n: int, T) -> (ndarray, ndarray):
    l: int = int(np.ceil(np.log2(np.max(np.abs(c)))))
    c_star: ndarray = get_c_star(c, m, l)
    S_star: ndarray = get_S_star(S, m, n, l)
    n_prime = n + 1
    S_prime = np.concatenate((np.eye(m), T.T), 0).T
    A: ndarray = (np.random.rand(n_prime - m, n * l) * 10).astype('int')
    E: ndarray = (1 * np.random.rand(S_star.shape[0], S_star.shape[1])).astype('int')
    M: ndarray = np.concatenate(((S_star - T.dot(A) + E), A), 0)
    c_prime: ndarray = M.dot(c_star)
    return c_prime, S_prime

def get_c_star(c: ndarray, m: int, l: int) -> ndarray:
    c_star: ndarray = np.zeros(l * m, dtype='int')
    for i in range(m):
        b: ndarray = np.array(list(np.binary_repr(np.abs(c[i]))), dtype='int')
        if (c[i] < 0):
            b *= -1
        c_star[(i * l) + (l - len(b)): (i + 1) * l] += b
    return c_star

def get_S_star(S: ndarray, m: int, n: int, l: int) -> ndarray:
    S_star: List = list()
    for i in range(l):
        S_star.append(S * 2 ** (l - i - 1))
    S_star: ndarray = np.array(S_star).transpose(1, 2, 0).reshape(m, n * l)
    return S_star

def get_T(n: int) -> ndarray:
    n_prime = n + 1
    T: ndarray = (10 * np.random.rand(n, n_prime - n)).astype('int')
    return T

def encrypt_via_switch(x: ndarray, w: int, m: int, n: int, T: ndarray) -> (ndarray, ndarray):
    c, S = switch_key(x * w, np.eye(m), m, n, T)
    return (c, S)

In [25]:
x: ndarray = np.array([0, 1, 2, 5])
    
m: int = len(x)
n: int = m
w: int = 16

S: ndarray = generate_key(w, m, n)
S

array([[3.23752234e-05, 8.71385929e-05, 1.31983522e-04, 2.07309940e-05],
       [1.86839421e-04, 6.44340965e-05, 1.80036139e-04, 1.37908927e-04],
       [3.59239655e-05, 1.13418365e-04, 2.11182783e-04, 3.64218508e-05],
       [1.56050588e-04, 2.07357406e-04, 1.54697291e-04, 4.90344554e-05]])

### Basic addition / multiplication

In [26]:
c: ndarray = encrypt(x, S, m, n, w)
c

array([-20682572.06481609,  15771016.05581861, -10814309.73560343,
        34891263.61949188])

In [27]:
decrypt(c, S, w)

array([0, 1, 2, 5])

In [28]:
decrypt(c + c, S, w)

array([ 0,  2,  4, 10])

In [29]:
decrypt(c * 10, S, w)

array([ 0, 10, 20, 50])

### Key-switching addition / multiplication

In [30]:
T: ndarray = get_T(n)

In [31]:
c, S = encrypt_via_switch(x, w, m, n, T)

In [32]:
decrypt(c, S, w)

array([0, 1, 2, 5])

In [33]:
decrypt(c + c, S, w)

array([ 0,  2,  4, 10])

In [34]:
decrypt(c * 10, S, w)

array([ 0, 10, 20, 50])