In [11]:
import random
import math
import time
from base64 import b64encode, b64decode


# https://github.com/marceloarenassaavedra/IIC2283-2-21/blob/master/codigos%20de%20clases/alg_teoria_numeros.py
# https://www.section.io/engineering-education/rsa-encryption-and-decryption-in-python/
# https://coderoasis.com/implementing-rsa-from-scratch-in-python/


_primers = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 
                     83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 
                     179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 
                     271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 
                     379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 
                     479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 
                     599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 
                     701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 
                     823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 
                     941, 947, 953, 967, 971, 977, 983, 991, 997]


def _exp(a: int, b: int) -> int:
    """
    Argumentos :
        a: int
        b: int - b >= 0
    Retorna :
        int - a**b
    """
    if b == 0:
        return 1
    else:
        res = 1
        pot = a
        while b > 0:
            if b % 2 == 1:
                res = pot * res
            b = b // 2
            pot = pot * pot
        return res



def _exp_mod(a: int, b: int, n: int) -> int:
    """
    Argumentos :
        a: int
        b: int
        n: int - n > 0
    Retorna :
        int - a**b en modulo n
    """
    if b == 0:
        return 1
    elif b > 0:
        res = 1
        pot = a
        while b > 0:
            if b % 2 == 1:
                res = (pot * res) % n
            b = b // 2
            pot = (pot * pot) % n
        return res
    else:
        return _exp_mod(_inverso(a,n),-b,n)

    

def _mcd(a: int, b: int) -> int:
    """
    Argumentos :
        a: int
        b: int - a > 0 o b > 0
    Retorna :
        maximo comun divisor entre a y b,
    """
    while b > 0:
        temp = b
        b = a % b
        a = temp
    return a



def _alg_ext_euclides(a: int, b: int) -> (int, int, int):
    """
    Argumentos :
        a: int
        b: int - a >= b >= 0 y a > 0
    Retorna :
        (int , int , int) - maximo comun divisor MCD(a, b) entre a y b,
        y numeros enteros s y t tales que MCD(a, b) = s*a + t*b
    """
    r_0 = a
    s_0 = 1
    t_0 = 0
    r_1 = b
    s_1 = 0
    t_1 = 1
    while r_1 > 0:
        r_2 = r_0 % r_1
        s_2 = s_0 - (r_0 // r_1) * s_1
        t_2 = t_0 - (r_0 // r_1) * t_1
        r_0 = r_1
        s_0 = s_1
        t_0 = t_1
        r_1 = r_2
        s_1 = s_2
        t_1 = t_2
    return r_0, s_0, t_0



def _inverso(a: int, n: int) -> int:
    """
    Modular multiplicative inverse.

    Arguments:
        a: int - a >= 1
        n: int - n >= 2; a and n must be coprime
    Returns:
        int - the x in [0, n) with (a * x) % n == 1
    Raises:
        ValueError - if gcd(a, n) != 1, i.e. a has no inverse modulo n
    """
    g, s, _ = _alg_ext_euclides(a, n)
    if g != 1:
        # Without this check s % n would be returned even though a*s == g != 1 (mod n).
        raise ValueError(f"{a} has no inverse modulo {n} (gcd = {g})")
    return s % n



def _es_potencia(n: int) -> bool:
    """
    Argumentos :
        n: int - n >= 1
    Retorna :
        bool - True si existen numeros naturales a y b tales que n = (a**b),
        donde a >= 2 y b >= 2. En caso contrario retorna False.       
    """
    if n <= 3:
        return False
    else:
        k = 2
        lim = 4
        while lim <= n:
            if _tiene_raiz_entera(n, k):
                return True
            k = k + 1
            lim = lim * 2
        return False


    
def _tiene_raiz_entera(n: int, k: int) -> bool:
    """
    Argumentos :
        n: int - n >= 1
        k: int - k >= 2
    Retorna :
        bool - True si existe numero natural a tal que n = (a**k),
        donde a >= 2. En caso contrario retorna False.       
    """
    if n <= 3:
        return False
    else:
        a = 1
        while _exp(a,k) < n:
            a = 2*a
        return _tiene_raiz_entera_intervalo(n, k, a//2, a)


    
def _tiene_raiz_entera_intervalo(n: int, k: int, i: int, j: int) -> bool:
    """
    Argumentos :
        n: int - n >= 1
        k: int - k >= 2
        i: int - i >= 0
        j: int - j >= 0
    Retorna :
        bool - True si existe numero natural a tal que n = (a**k),
        donde i <= a <= j. En caso contrario retorna False.       
    """
    while i <= j:
        if i==j:
            return n == _exp(i,k)
        else:
            p = (i + j)//2 
            val = _exp(p,k)
            if n == val:
                return True
            elif val < n:
                i = p+1
            else:
                j = p-1
    return False


def _get_prime(max_size):
    """
    Draw a random probable prime for an RSA modulus.

    Arguments:
        max_size: int - bit size parameter of the modulus N = P * Q.
    Returns:
        int - a number in [2**(max_size // 2), 2**(max_size // 2 + 1)] that
        passes _test_miller_rabin with k = 100.
    """
    # Key material must come from a cryptographically secure source:
    # random.SystemRandom reads os.urandom, whereas the module-level
    # random.randint uses the predictable Mersenne Twister.
    rng = random.SystemRandom()
    low = 2 ** (max_size // 2)
    high = 2 ** (max_size // 2 + 1)
    while True:
        candidate = rng.randint(low, high)
        if _test_miller_rabin(candidate, 100):
            return candidate

        
        
def _test_primalidad(n: int, k: int) -> bool:
    """
    Argumentos :
        n: int - n >= 1
        k: int - k >= 1
    Retorna :
        bool - True si n es un numero primo, y False en caso contrario.
        La probabilidad de error del test es menor o igual a 2**(-k),
        y esta basado en el test de primalidad de Solovay–Strassen
    """
    if n == 1:
        return False
    elif n == 2:
        return True
    elif n%2 == 0:
        return False
    elif _es_potencia(n):
        return False
    else:
        neg = 0
        for i in range(1,k+1):
            a = random.randint(2,n-1)
            if _mcd(a,n) > 1:
                return False
            else:
                b = _exp_mod(a,(n-1)//2,n)
                if b == n - 1:
                    neg = neg + 1
                elif b != 1:
                    return False
        if neg > 0:
            return True
        else:
            return False


def random_mcd(_size, _phi):
    """
    Pick a random exponent coprime to ``_phi``.

    Arguments:
        _size: int - bit size parameter; candidates are drawn uniformly from
               [2**(_size // 2), 2**(_size // 2 + 1)].
        _phi: int - modulus the result must be coprime to (phi(N) for RSA).
    Returns:
        int - a d in that range with gcd(d, _phi) == 1.
    """
    # This value becomes the RSA private exponent, so draw it from the OS
    # CSPRNG rather than the predictable default generator of `random`.
    rng = random.SystemRandom()
    low = 2 ** (_size // 2)
    high = 2 ** (_size // 2 + 1)
    while True:
        d = rng.randint(low, high)
        if math.gcd(d, _phi) == 1:
            return d
        
        
def _test_miller_rabin(n: int, k: int) -> bool :
    """
    Argumentos :
        n: int - n >= 1
        k: int - k >= 1
    Retorna :
        bool - True si n es un numero primo, y False en caso contrario.
        La probabilidad de error del test es menor o igual a 2**(-k),
        e implementa el test de primalidad de Miller-Rabin.
    """
    if n == 1:
        return False
    elif n==2:
        return True
    elif n%2 == 0:
        return False
    else:
        s = 0
        d = n-1
        while d%2==0:
            s = s + 1
            d = d//2
        num = k//2 + 1
        for i in range(0,num):
            a = random.randint(2,n-1)
            pot = _exp_mod(a,d,n)
            if pot != 1 and pot != n-1:
                pasar = False
                for j in range(0,s):
                    pot = (pot*pot) % n
                    if pot == n-1:
                        pasar = True
                        break
                if pasar == False:
                    return False
        return True
    
    
"""

Receiver

"""

class RSAReceiver:
    """
    Holds an RSA key pair and decrypts ciphertexts made with its public key.

    The private exponent d is drawn at random (coprime to phi) and the public
    exponent e is derived from it as the inverse of d modulo phi.
    """

    def __init__(self, bit_len):
        """
        Arguments:
            bit_len: A lower bound for the number of bits of N,
            the second argument of the public and secret key.
            Must be at least 2 so that two distinct primes can be drawn.
        Raises:
            ValueError: if bit_len < 2.
        """
        if bit_len < 2:
            # For bit_len < 2 the prime range used by _get_prime holds only the
            # prime 2, so two distinct primes P and Q cannot be found.
            raise ValueError("bit_len must be at least 2")
        self.bit_len = bit_len

        self.P = _get_prime(bit_len)
        self.Q = _get_prime(bit_len)
        # (P-1)(Q-1) is Euler's totient of N only when P != Q. With small bit_len
        # the two draws collide often, which would make decryption fail.
        while self.Q == self.P:
            self.Q = _get_prime(bit_len)

        self.n = self.P * self.Q
        self.phi = (self.P - 1) * (self.Q - 1)
        self.d = random_mcd(self.bit_len, self.phi)
        # Bytes per ciphertext block: ceil(bits(N) / 8). Integer arithmetic
        # avoids the float rounding of math.log(n, 2), which can be off by one
        # byte when N lies just below 2**(8*m).
        self._lenght_n = (self.n.bit_length() + 7) // 8

    @staticmethod
    def _int_to_bytes(value):
        """Encode a non-negative int big-endian in the fewest bytes (at least one)."""
        return value.to_bytes(max(1, (value.bit_length() + 7) // 8), 'big')

    def get_public_key(self):
        """
        Returns:
            public_key

            The public key (e, N) as a Python 'bytearray' in a simple
            length-prefixed layout (not actual PEM/ASN.1):
            (1) The number of bytes of e (4 bytes, big-endian)
            (2) the number e (as many bytes as indicated in (1))
            (3) The number of bytes of N (4 bytes, big-endian)
            (4) the number N (as many bytes as indicated in (3))
        """
        # e is the inverse of the randomly chosen private exponent d modulo phi.
        e = _inverso(self.d, self.phi)
        # Sized from bit_length(): len(bin(x)) would also count the "0b" prefix.
        e_bytes = self._int_to_bytes(e)
        n_bytes = self._int_to_bytes(self.n)
        return bytearray(
            len(e_bytes).to_bytes(4, 'big') + e_bytes
            + len(n_bytes).to_bytes(4, 'big') + n_bytes
        )

    def decrypt(self, ciphertext):
        """
        Arguments:
            ciphertext: The ciphertext to decrypt, a concatenation of
            big-endian blocks of self._lenght_n bytes each.
        Returns:
            message: The original message, decoded from UTF-8.
        """
        block_size = self._lenght_n
        total = len(ciphertext)
        plain = bytearray()
        for start in range(0, total, block_size):
            value = int.from_bytes(ciphertext[start:start + block_size], 'big')
            chunk = pow(value, self.d, self.n).to_bytes(block_size - 1, 'big')
            if start + block_size >= total:
                # The sender's final plaintext chunk can be shorter than a full
                # chunk, so it comes back left-padded with zero bytes; drop them
                # (an all-zero chunk is kept as is).
                stripped = chunk.lstrip(b'\x00')
                if stripped:
                    chunk = stripped
            plain += chunk
        # Decode once at the end: blocks split the UTF-8 byte stream at arbitrary
        # byte offsets, so a multi-byte character may straddle two blocks.
        return plain.decode('utf-8')
    
"""

RSA Sender

"""
class RSASender:
    """
    Encrypts UTF-8 text with an RSA public key (e, N) given in the
    length-prefixed layout produced by RSAReceiver.get_public_key.
    """

    def __init__(self, public_key):
        """
        Arguments:
            public_key: The public key that will be used to encrypt messages:
            4-byte length of e, e, 4-byte length of N, N (all big-endian).
        """
        # PK
        self.public_key = public_key

        # E, N
        length_E = int.from_bytes(self.public_key[0:4], 'big')
        self.e = int.from_bytes(self.public_key[4:4 + length_E], 'big')
        length_N = int.from_bytes(self.public_key[4 + length_E:8 + length_E], 'big')
        self.n = int.from_bytes(self.public_key[8 + length_E:8 + length_E + length_N], 'big')

        # Number of bits of N. This is exact integer arithmetic; for the odd,
        # non-power-of-two moduli used here it equals ceil(log2(N)).
        self.bit_len = self.n.bit_length()

    def encrypt(self, message):
        """
        Arguments:
            message: The plaintext message to encrypt (str, encoded as UTF-8).
        Returns:
            ciphertext: The encrypted message as a bytearray of concatenated
            big-endian blocks of ceil(bits(N) / 8) bytes each.
        Raises:
            ValueError: if N is too small to encrypt even one byte per block.
        """
        # A ciphertext block must hold any value below N: ceil(bits(N) / 8) bytes.
        # Plaintext chunks are one byte shorter, so each chunk's integer value is
        # < 2**(8 * (block - 1)) <= 2**(bits(N) - 1) <= N and survives RSA intact.
        block = (self.n.bit_length() + 7) // 8
        chunk = block - 1
        if chunk < 1:
            raise ValueError("modulus N is too small to encrypt whole bytes")

        data = message.encode('utf-8')
        ciphertext = bytearray()
        for start in range(0, len(data), chunk):
            m = int.from_bytes(data[start:start + chunk], 'big')
            ciphertext += pow(m, self.e, self.n).to_bytes(block, 'big')
        return ciphertext