# Pregunta 2: RSA

## Imports

In [None]:
from random import randint

## Test de primalidad de Rabin-Miller

In [None]:
def rabin_miller(n: int, k: int) -> bool:
    """
    Determines if a number is prime using the Rabin-Miller primality test.

    Args:
        n (int): number to test for primality (values < 2 are reported as not prime)
        k (int): number >= 1, number of random witnesses to try (error threshold parameter)

    Returns:
        bool: True if n is probably prime, False if n is certainly not prime.
            A composite n is wrongly reported as prime with probability
            <= 4**(-k), which is in particular <= 2**(-k).
    """
    # Small cases must be settled before the parity test (2 is even but prime)
    # and before drawing witnesses (randint(2, n - 2) needs n >= 4).
    if n < 2:
        return False
    if n in (2, 3):
        return True
    if n % 2 == 0:
        return False

    # Write n - 1 == 2**r * d with d odd
    r, d = 0, n - 1
    while d % 2 == 0:
        d //= 2
        r += 1

    # Witness loop: each random base a either proves n composite or not
    for _ in range(k):
        a = randint(2, n - 2)
        x = exp_mod(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = exp_mod(x, 2, n)
            if x == n - 1:
                # a is not a witness of compositeness
                break
        else:
            # Never reached n - 1: a is a witness, so n is definitely composite
            return False
    # No witness found among k bases: n is probably prime
    return True


def generar_primo(l: int) -> int:
    """
    Draws random candidates until one passes the Rabin-Miller test.

    Args:
        l (int): number >= 1, number of decimal digits of the prime
            (values below 1 behave like 1)

    Returns:
        int: a (probable) prime drawn uniformly from [10**(l-1), 10**l], so it
            has exactly l digits for l >= 1. 100 Rabin-Miller rounds keep the
            error probability <= 2**(-100).
    """
    # Same bounds as multiplying 10 by itself l - 1 more times (at least once)
    upper = 10 ** max(l, 1)
    lower = upper // 10
    while True:
        candidate = randint(lower, upper)
        if rabin_miller(candidate, 100):
            return candidate


## Algoritmo extendido de Euclides

In [None]:
def alg_ext_euclides(a: int, b: int) -> tuple[int, int, int]:
    """
    Extended Euclidean algorithm.

    Args:
        a (int): number >= 0
        b (int): number >= 0 (a < b is allowed; the first step swaps them)

    Returns:
        tuple[int, int, int]: (GCD(a, b), s, t), where GCD(a, b) is the greatest
            common divisor of a and b and s, t are integers with
            GCD(a, b) == s*a + t*b.
    """
    # Invariant: old_r == old_s*a + old_t*b and r == s*a + t*b.
    # Starting from (1, 0) / (0, 1) also makes b == 0 return (a, 1, 0).
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    return old_r, old_s, old_t


## Exponenciación rápida

In [None]:
def exp_mod(a: int, b: int, n: int) -> int:
    """
    Computes a**b mod n by square-and-multiply.

    Args:
        a (int): base, number >= 0 (reduced mod n first)
        b (int): exponent, number >= 0
        n (int): modulus, number > 0

    Returns:
        int: a**b mod n
    """
    base = a % n
    # Trivial cases: everything is 0 mod 1, x**0 == 1, 0**b == 0 for b > 0
    if n == 1:
        return 0
    if b == 0:
        return 1
    if base == 0:
        return 0

    result, exponent = 1, b
    while exponent > 0:
        # Multiply in the current power of the base when this bit is set
        if exponent & 1:
            result = result * base % n
        base = base * base % n
        exponent >>= 1
    return result


## Inverso modular

In [None]:
def inverso(a: int, n: int) -> "int | None":
    """
    Modular inverse.

    Args:
        a (int): number >= 1
        n (int): number >= 2, coprime with a (GCD(a, n) == 1)

    Returns:
        int | None: the x in [0, n) with a*x == 1 (mod n), or None when
            GCD(a, n) != 1 and therefore no inverse exists.
    """
    gcd, s, _ = alg_ext_euclides(a, n)
    if gcd != 1:
        return None
    # s*a + t*n == 1, so s is an inverse of a mod n; s may be negative,
    # hence the reduction into [0, n).
    return s % n


## Generar clave 

In [None]:
def generar_clave(l: int) -> tuple[tuple[int, int], tuple[int, int]]:
    """
    Build an RSA public and private key whose modulus has roughly l digits.

    Two distinct random primes p and q of l // 2 decimal digits each are
    generated, so N = p*q has 2*(l//2) - 1 or 2*(l//2) digits when l >= 2.

    Args:
        l (int): approximate number of decimal digits of the modulus N

    Returns:
        tuple: ((e, N), (d, N)), the public and the private key.

    Saves the public key (e, N) in public_key.txt and the private key (d, N) in
    private_key.txt, each as two lines holding the exponent and then N, e.g.
    for the private key:
        d
        N
    """
    # Generate two distinct primes
    half = l // 2
    p = generar_primo(half)
    q = generar_primo(half)
    # Redraw q while the pair is unusable. If p == q, (p-1)*(q-1) is not
    # phi(p**2) = p*(p-1) and decryption breaks. If phi(n) < 3, there is no
    # exponent in [2, phi(n) - 1] and randint below would raise.
    while q == p or (p - 1) * (q - 1) < 3:
        q = generar_primo(half)

    # Get n and phi(n)
    n = p * q
    phi_n = (p - 1) * (q - 1)

    # Pick a random public exponent e coprime with phi(n); d is its inverse mod phi(n)
    e = randint(2, phi_n - 1)
    while alg_ext_euclides(e, phi_n)[0] != 1:
        e = randint(2, phi_n - 1)
    d = inverso(e, phi_n)

    # Store keys
    with open('public_key.txt', 'w') as key_file:
        key_file.write(f'{e}\n{n}\n')
    with open('private_key.txt', 'w') as key_file:
        key_file.write(f'{d}\n{n}\n')
    return (e, n), (d, n)


## Encryption and decryption algorithms

In [None]:
def _read_key(path: str) -> tuple[int, int]:
    """
    Read a key file written by generar_clave: the exponent, then the modulus N.

    Args:
        path (str): path of the key file

    Returns:
        tuple[int, int]: (exponent, N)
    """
    with open(path, 'r') as key_file:
        # split() tolerates surrounding whitespace and blank lines
        exponent, n = (int(token) for token in key_file.read().split())
    return exponent, n


def enc(m: int) -> int:
    """
    Encrypts a message m with the public key stored in public_key.txt.

    Args:
        m (int): 0 <= m <= N-1, message to encrypt

    Returns:
        int: m encrypted with public key (e, N), i.e. m**e mod N

    Raises:
        ValueError: if m is outside [0, N-1]; such m cannot be recovered by dec.
    """
    e, n = _read_key('public_key.txt')
    if not 0 <= m < n:
        raise ValueError('message out of range: requires 0 <= m <= N-1')
    return exp_mod(m, e, n)


def dec(m: int) -> int:
    """
    Decrypts a message m with the private key stored in private_key.txt.

    Args:
        m (int): 0 <= m <= N-1, message to decrypt

    Returns:
        int: m decrypted with private key (d, N), i.e. m**d mod N

    Raises:
        ValueError: if m is outside [0, N-1].
    """
    d, n = _read_key('private_key.txt')
    if not 0 <= m < n:
        raise ValueError('message out of range: requires 0 <= m <= N-1')
    return exp_mod(m, d, n)


## Message to int

In [None]:
def string_to_int(string_):
    """
    Transform an ASCII string into its numeric representation.

    Each character takes 7 bits, with the first character in the least
    significant position: the result is the sum of ord(c_i) << (7*i).
    Trailing NUL characters contribute no bits, so int_to_string cannot
    recover them.

    Args:
        string_ (str): an ASCII string to transform

    Returns:
        int: numeric representation of string_ (0 for the empty string)

    Raises:
        ValueError: if string_ contains a non-ASCII character. Its code point
            would not fit in 7 bits and would corrupt the neighbouring characters.
    """
    if not string_.isascii():
        raise ValueError('string_to_int only supports ASCII strings')
    return sum(ord(char) << (7 * i) for i, char in enumerate(string_))

def int_to_string(n):
    """
    Inverse of string_to_int: decode an integer into its ASCII string.

    Reads 7-bit groups from the least significant end, one character per
    group, until no set bits remain.

    Args:
        n (int): numeric representation produced by string_to_int

    Returns:
        str: the decoded string ("" when n <= 0). Trailing NUL characters of
            the original string are not recovered, since they contribute no bits.
    """
    # Collect characters in a list and join once instead of repeated +=
    chars = []
    while n > 0:
        chars.append(chr(n & 0x7f))
        n >>= 7
    return ''.join(chars)

## Tests

### Generate key

In [None]:
if __name__ == "__main__":
    # Create a fresh key pair from two 100-digit primes; generar_clave writes
    # it to public_key.txt and private_key.txt.
    generar_clave(l=200)

### Encrypt and decrypt message

In [None]:
if __name__ == "__main__":
    # Round-trip a short ASCII message through the stored key files.
    message = "Hello World!"
    cod_m = string_to_int(message)
    # Number of decimal digits of the encoded message (must stay below N's)
    print("Message_length:", len(str(cod_m)))

    encrypted = enc(cod_m)
    decrypted = dec(encrypted)
    dec_message = int_to_string(decrypted)

    print(dec_message)
    assert dec_message == message