In [1]:
from sympy import mod_inverse, lcm, isprime
import random

In [2]:
def generate_large_prime(bits):
    """Generate a large prime number."""
    while True:
        num = random.getrandbits(bits)
        if isprime(num):
            return num

In [3]:
def keygen(bits):
    # Choose two large prime numbers
    p = generate_large_prime(bits)
    q = generate_large_prime(bits)

    # Compute n and lambda
    n = p * q
    lambda_n = lcm(p - 1, q - 1)

    # Choose g
    g = n + 1

    # Compute mu
    n2 = n ** 2
    g_lambda_mod_n2 = pow(g, int(lambda_n), n2)  # Convert lambda_n to int
    L = (g_lambda_mod_n2 - 1) // n
    mu = mod_inverse(L, n)

    return (n, g), (int(lambda_n), int(mu))  # Ensure all values are standard integers

In [4]:
public_key, private_key = keygen(16)
print("Public Key:", public_key)
print("Private Key:", private_key)

Public Key: (132128933, 132128934)
Private Key: (2539680, 109774635)


In [5]:
def encrypt(public_key, plaintext):
    n, g = public_key
    r = random.randint(1, n-1)
    n2 = n ** 2
    
    # Compute ciphertext
    c = (pow(g, plaintext, n2) * pow(r, n, n2)) % n2
    return c

In [6]:
plaintext = 10
ciphertext = encrypt(public_key, plaintext)
print("Plaintext:", plaintext)
print("Ciphertext:", ciphertext)

Plaintext: 10
Ciphertext: 6505047243130652


In [7]:
def decrypt(private_key, public_key, ciphertext):
    lambda_n, mu = private_key
    n, _ = public_key
    n2 = n ** 2
    
    # Compute m
    c_lambda = pow(ciphertext, lambda_n, n2)
    L = (c_lambda - 1) // n
    m = (L * mu) % n
    
    return m

In [8]:
decrypted_message = decrypt(private_key, public_key, ciphertext)
print("Decrypted Message:", decrypted_message)

Decrypted Message: 10


In [9]:
def homomorphic_addition(public_key, c1, c2, c3):
    """Perform homomorphic addition of two ciphertexts."""
    n, _ = public_key
    n2 = n ** 2
    # Homomorphic addition: c1 * c2 * c3 % n^2
    c_sum = (c1 * c2 * c3) % n2
    return c_sum

In [10]:
# Encrypt two plaintext numbers
gradient_1 = 40
gradient_2 = 58
gradient_3 = 20

encr_gradient_1 = encrypt(public_key, gradient_1)
encr_gradient_2 = encrypt(public_key, gradient_2)
encr_gradient_3 = encrypt(public_key, gradient_3)

print("Encrypted Gradient 1:", encr_gradient_1)
print("Encrypted Gradient 2:", encr_gradient_2)
print("Encrypted Gradient 3:", encr_gradient_3)

Encrypted Gradient 1: 4411436230418666
Encrypted Gradient 2: 10760337853812060
Encrypted Gradient 3: 482927698409710


In [11]:
# Homomorphic addition
encr_gradient_sum = homomorphic_addition(public_key, encr_gradient_1, encr_gradient_2, encr_gradient_3)
print("Encrypted Gradient Sum:", encr_gradient_sum)

Encrypted Gradient Sum: 12402911223688676


In [12]:
# Decrypt the result of the homomorphic addition
decr_gradient_sum = decrypt(private_key, public_key, encr_gradient_sum)
print("Decrypted Gradient Sum:", decr_gradient_sum)
print("Decrypted Gradient Average:", decr_gradient_sum/3)

Decrypted Gradient Sum: 118
Decrypted Gradient Average: 39.333333333333336
