# Public Key Encryption

In [1]:
from math import sqrt
from typing import Optional, Tuple

## [Diffie-Hellman](https://en.wikipedia.org/wiki/Diffie–Hellman_key_exchange) Key Exchange

In [2]:
# Public parameters, known to both parties (and to any eavesdropper)
p: int = 23  # prime modulus
g: int = 5  # base; 5 is a primitive root modulo 23

In [3]:
# Private, non-shared information
# Private values: each party keeps its own exponent to itself
a: int = 4  # Alice's secret exponent
b: int = 3  # Bob's secret exponent

In [4]:
def alice_enc(p: int, g: int, secret: Optional[int] = None) -> int:
    """Compute Alice's public Diffie-Hellman value ``g**secret mod p``.

    Args:
        p: The shared prime modulus.
        g: The shared base.
        secret: Alice's private exponent. When omitted, the global ``a`` is
            used (looked up at call time), matching the original behaviour.

    Returns:
        Alice's public value, which she sends to Bob.
    """
    exponent = a if secret is None else secret
    # Three-argument pow does modular exponentiation directly, instead of
    # first building the (potentially astronomically large) integer g**exponent.
    return pow(g, exponent, p)

In [5]:
def bob_enc(p: int, g: int, secret: Optional[int] = None) -> int:
    """Compute Bob's public Diffie-Hellman value ``g**secret mod p``.

    Args:
        p: The shared prime modulus.
        g: The shared base.
        secret: Bob's private exponent. When omitted, the global ``b`` is
            used (looked up at call time), matching the original behaviour.

    Returns:
        Bob's public value, which he sends to Alice.
    """
    exponent = b if secret is None else secret
    # Three-argument pow avoids materialising g**exponent before reducing.
    return pow(g, exponent, p)

In [6]:
# Each party publishes g**secret mod p over the open channel.
j: int = alice_enc(p=p, g=g)  # Alice's public value, sent to Bob
k: int = bob_enc(p=p, g=g)  # Bob's public value, sent to Alice

In [7]:
def alice_dec(k: int, modulus: Optional[int] = None, secret: Optional[int] = None) -> int:
    """Derive the shared secret on Alice's side as ``k**secret mod modulus``.

    Args:
        k: The public value received from Bob.
        modulus: The shared prime modulus. When omitted, the global ``p`` is
            read at call time. NOTE: that makes the result depend on whatever
            ``p`` is bound to when this is called (later notebook cells rebind
            ``p``), so pass ``modulus`` explicitly when in doubt.
        secret: Alice's private exponent; defaults to the global ``a``.

    Returns:
        The shared secret. When ``k`` is Bob's public value ``g**b mod p``,
        this equals ``g**(a*b) mod p``.
    """
    mod = p if modulus is None else modulus
    exponent = a if secret is None else secret
    # Modular exponentiation without building k**exponent in full.
    return pow(k, exponent, mod)

# Alice's view of the secret: k**a = (g**b)**a = g**(b*a)  (mod p)
assert alice_dec(k) == pow(k, a, p) == pow(g, b * a, p)

In [8]:
def bob_dec(j: int, modulus: Optional[int] = None, secret: Optional[int] = None) -> int:
    """Derive the shared secret on Bob's side as ``j**secret mod modulus``.

    Args:
        j: The public value received from Alice.
        modulus: The shared prime modulus. When omitted, the global ``p`` is
            read at call time. NOTE: that makes the result depend on whatever
            ``p`` is bound to when this is called (later notebook cells rebind
            ``p``), so pass ``modulus`` explicitly when in doubt.
        secret: Bob's private exponent; defaults to the global ``b``.

    Returns:
        The shared secret. When ``j`` is Alice's public value ``g**a mod p``,
        this equals ``g**(a*b) mod p``.
    """
    mod = p if modulus is None else modulus
    exponent = b if secret is None else secret
    # Modular exponentiation without building j**exponent in full.
    return pow(j, exponent, mod)

# Bob's view of the secret: j**b = (g**a)**b = g**(a*b)  (mod p)
assert bob_dec(j) == pow(j, b, p) == pow(g, a * b, p)

In [9]:
# Both parties should arrive at the same shared secret.
print('Alices number: {}'.format(alice_dec(k)))
print('Bobs number: {}'.format(bob_dec(j)))

Alices number: 18
Bobs number: 18


## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem

In [10]:
# Iterative variant of the algorithm from:
# https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    """Extended Euclidean algorithm.

    Returns ``(g, x, y)`` such that ``a*x + b*y == g``; for non-negative
    inputs ``g == gcd(a, b)``.

    This loop produces the same coefficients as the textbook recursive
    version. It uses constant stack space, whereas the recursive form needs
    one frame per division step. That can exceed Python's default recursion
    limit (1000) for inputs of realistic RSA size (thousands of bits).
    """
    x0, x1, y0, y1 = 0, 1, 1, 0
    while a != 0:
        # One Euclid step: (a, b) <- (b mod a, a), tracking the Bezout
        # coefficients of the current remainder alongside it.
        (q, a), b = divmod(b, a), a
        y0, y1 = y1, y0 - q * y1
        x0, x1 = x1, x0 - q * x1
    return b, x0, y0

def modinv(a: int, m: int) -> int:
    """Return the inverse of ``a`` modulo ``m`` (for positive ``m``, in ``[0, m)``).

    Raises:
        ValueError: if ``a`` and ``m`` are not coprime, so no inverse exists.
            ValueError subclasses Exception, so existing ``except Exception``
            handlers keep working.
    """
    g, x, _ = egcd(a, m)
    if g != 1:
        raise ValueError('modular inverse does not exist')
    return x % m

assert modinv(17, 3120) == 2753
assert egcd(1071, 462) == (21, -3, 7)

In [11]:
# Private, non-shared information.
# NOTE: the RSA primes get their own names so this cell does not overwrite
# the Diffie-Hellman modulus `p` defined earlier; reusing `p` here made the
# Diffie-Hellman cells compute different (wrong) secrets when re-run after
# this cell.
rsa_p: int = 61
rsa_q: int = 53
n: int = rsa_p * rsa_q
phi_n: int = (rsa_p - 1) * (rsa_q - 1)

# Pick the smallest public exponent e >= 12 that is coprime to phi(n).
# The starting value only controls how large e ends up (here: 17).
e: int = 12
while egcd(e, phi_n)[0] != 1:
    e += 1

# d is the inverse of e modulo phi(n), i.e. (e * d) % phi_n == 1.
d: int = modinv(e, phi_n)

secret_key: Tuple[int, int] = (d, n)

# Public, shared information
public_key: Tuple[int, int] = (e, n)

assert secret_key == (2753, 3233)
assert public_key == (17, 3233)

In [12]:
# Encryption needs only the public key...
e, n = public_key

plaintext: int = 42
# Textbook RSA only round-trips messages in the range [0, n).
assert 0 <= plaintext < n
print(f'Message (Plaintext): {plaintext}')

# Three-argument pow does modular exponentiation directly; (m ** e) % n
# would first build the full power, which is infeasible at real key sizes.
ciphertext: int = pow(plaintext, e, n)
print(f'Message (Encrypted): {ciphertext}')

# ...while decryption needs the secret key.
d, n = secret_key
decrypted: int = pow(ciphertext, d, n)

print(f'Message (Decrypted): {decrypted}')
assert decrypted == plaintext

Message (Plaintext): 42
Message (Encrypted): 2557
Message (Decrypted): 42
