# Pregunta 2


## Funciones de utilidad
### power
Retorna "a" elevado a "b" en modulo n. Se asume que a y b son positivos

### mcd
Retorna el maximo común divisor entre a y b

In [496]:
from random import randint, randrange

In [497]:
def power(a, b, n):
    res = 1
    a = a % n
    while b > 0:
        if b % 2 == 1:
            res = (res * a) % n
        b = b // 2
        a = (a * a) % n
    return res


def mcd(a, b):
    while True:
        t = a % b
        if t == 0:
            return b
        a = b
        b = t

In [498]:
class RSAReceiver:

    def __init__(self, bit_len: int) -> None:
        self._bit_len = bit_len
        self._P = self._generate_prime()
        Q = self._generate_prime()
        while True:
            if Q != self._P:
                self._Q = Q
                break
            Q = self._generate_prime()
        self._N = self._P * self._Q
        self._public_key = self._make_public_key()
        self._block_length = ((len(bin(self._N)) - 2) // 8) + 1

    
    def _rabin_miller(self, p: int, d: int) -> bool:
        # Elegir testigo
        a = 2 + randint(1, p - 4)
        x = power(a, d, p)

        if x == 1 or x == (p - 1):
            return True
        
        while d != p - 1:
            x = (x * x) % p
            d *= 2
            if x == 1:
                return False
            if x == p - 1:
                return True

        return False

    
    def _is_prime(self, p: int, k: int) -> bool:
        # Casos Borde
        if p <= 1 or p == 4: # negativo o 1 o 4
            return False
        if p <= 3: # 2 o 3
            return True
        
        # Encontrar d para rabin-miller
        # p = d * 2^r + 1, d impar
        d = p - 1
        while d % 2 == 0:
            d //= 2
        
        for i in range(k):
            if self._rabin_miller(p, d) == False:
                return False
        return True

    
    def _generate_prime(self) -> int:
        counter = 0
        while counter <= 2 ** (self._bit_len):
            candidate = randrange(2 ** (self._bit_len - 1), (2 ** self._bit_len) - 1)
            counter += 1
            if self._is_prime(candidate, 10):
                return candidate


    def _make_public_key(self) -> bytearray:
        phi = (self._P - 1) * (self._Q - 1)
        self._e = 2
        while self._e < phi:
            if mcd(self._e, phi) == 1:
                break
            self._e += 1
        
        self._d = (1 + phi) // self._e

        len_e = 1 + (len(bin(self._e)) - 2) // 8
        len_e_bytes = bytearray(len_e.to_bytes(4, "big"))
        e_bytes = bytearray(self._e.to_bytes(len_e, "big"))

        len_n = 1 + (len(bin(self._N)) - 2) // 8
        len_n_bytes = bytearray(len_n.to_bytes(4, "big"))
        n_bytes = bytearray(self._N.to_bytes(len_n, "big"))

        public_key = len_e_bytes + e_bytes + len_n_bytes + n_bytes
        return public_key
    

    def get_public_key(self) -> bytearray:
        return self._public_key

    
    def decrypt(self, ciphertext: bytearray) -> str:
        s = ""
        block_amount = len(ciphertext) // self._block_length
        last_block_length = (len(ciphertext) % (self._block_length + 1))
        has_last_block =  last_block_length != 0

        for i in range(block_amount):
            s += self._decrypt_block(ciphertext[i * self._block_length: (i + 1) * self._block_length], self._block_length)
        if has_last_block:
             s += self._decrypt_block(ciphertext[block_amount * (self._block_length + 1):], last_block_length)
        return s
    

    def _decrypt_block(self, block: bytearray, block_length) -> str:
        c = int.from_bytes(block, "big")
        m = pow(c, self._d, self._N)
        return bytearray(m.to_bytes(self._block_length , "big")).decode()


In [499]:
class RSASender:
    
    def __init__(self, public_key: bytearray) -> None:
        self.public_key = public_key
        self._make_e_and_N()
        self._block_length = (len(bin(self._n)) - 2) // 8


    def _make_e_and_N(self) -> None:
        e_len = int.from_bytes(self.public_key[:4], "big")
        self._e = int.from_bytes(self.public_key[4:4+e_len], "big")
        n_len = int.from_bytes(self.public_key[4 + e_len: 8 + e_len], "big")
        self._n = int.from_bytes(self.public_key[8 + e_len:8 + e_len + n_len], "big")

    
    def encrypt(self, message: str) -> bytearray:
        mess_bytes = bytearray(message, 'utf-8')
        mess_len = len(mess_bytes)
        block_amount = mess_len // self._block_length
        last_block_length = mess_len % self._block_length
        has_last_block = last_block_length != 0
        print(mess_len, block_amount, self._block_length, last_block_length, has_last_block)

        # First block
        c = self._encrypt_block(mess_bytes[:self._block_length], self._block_length)

        # Middle blocks
        for i in range(1, block_amount):
            c += self._encrypt_block(mess_bytes[i * self._block_length: i * (self._block_length + 1)], self._block_length)

        # Last block
        if has_last_block:
            block = mess_bytes[self._block_length * block_amount:]
            number = int.from_bytes(block, "big")
            c = power(number, self._e, self._n)
            print()
            print(int.from_bytes(mess_bytes[self._block_length * block_amount:], "big"))
            c += self._encrypt_block(mess_bytes[self._block_length * block_amount:], last_block_length)
        
        return c
    

    def _encrypt_block(self, block: bytearray, block_length: int) -> bytearray:
        number = int.from_bytes(block, "big")
        c = power(number, self._e, self._n)
        return bytearray(c.to_bytes(block_length + 1, "big"))

In [500]:
if __name__ == "__main__":
    message1 = "this is a message"
    rsaR = RSAReceiver(16)
    rsaS = RSASender(rsaR.get_public_key())
    c = rsaS.encrypt(message1)
    print(c, len(c), len(bytearray(message1, 'utf-8')))
    m = rsaR.decrypt(c)
    print(m)

17 4 4 1 True
bytearray(b'e')
101


OverflowError: int too big to convert