In [2]:
import hashlib

def bytes2int(r):
    """Decode the byte string ``r`` as an unsigned big-endian integer."""
    return int.from_bytes(r, byteorder='big')

def kdf(salt, m):
    """Derive a key from secret bytes ``m`` and ``salt`` using PBKDF2-HMAC-SHA256.

    Uses 100000 iterations and the default derived-key length (the SHA-256 digest
    size, 32 bytes).
    """
    iterations = 100000
    derived = hashlib.pbkdf2_hmac('sha256', m, salt, iterations)
    return derived

def int2bytes(size, r):
    """Encode ``int(r)`` as exactly ``size`` big-endian bytes (OverflowError if it does not fit)."""
    as_int = int(r)
    return as_int.to_bytes(size, byteorder='big')

def primerandom(l):
    """Return a random prime in [2**(l-1), 2**l - 1], i.e. a prime of exactly ``l`` bits.

    Wraps Sage's ``random_prime(n, proof, lbound)``; ``proof=None`` keeps Sage's default.
    """
    upper = 2**l - 1
    lower = 2**(l-1)
    return random_prime(upper, None, lower)

#-----------------------------------------------------

def byte_xor(ba1, ba2):
    """XOR two byte strings element-wise.

    Like ``zip``, the result is silently truncated to the length of the shorter input.
    """
    return bytes(x ^ y for x, y in zip(ba1, ba2))

def mgf(seed, l):
    """MGF1 mask generation function instantiated with SHA-256.

    Concatenates SHA-256(seed || C) for a 4-byte big-endian counter
    C = 0, 1, ..., ceil(l / hLen) - 1 and returns the first ``l`` bytes.

    :param seed: seed bytes.
    :param l: requested mask length in bytes.
    :returns: exactly ``l`` bytes of mask (``b''`` when ``l <= 0``).
    :raises ValueError: if ``l > 2**32 * hLen`` (mask too long).
    """
    hs = hashlib.sha256
    hTam = hs().digest_size

    if l > 2**32 * hTam:
        raise ValueError('Máscara demasiado grande')

    # ceil(l / hTam) digests are needed to cover l bytes; integer ceiling avoids
    # float rounding (round(l/hTam) - 1 undercounts, e.g. 0 blocks for l == hTam).
    blocks = -(-l // hTam)

    t = b''.join(hs(seed + counter.to_bytes(4, 'big')).digest() for counter in range(blocks))
    return t[:l]

In [9]:
class Rsa:
    """Textbook RSA key pair with a toy KEM (``encaps``/``decap``) and OAEP-style padding.

    Depends on Sage globals (``randrange``, ``gcd``, ``inverse_mod``, and ``random_prime``
    through ``primerandom``) and on the module-level helpers ``kdf``, ``int2bytes``,
    ``bytes2int``, ``mgf`` and ``byte_xor``.
    """

    def __init__(self, segParam):
        """Generate a fresh RSA key pair.

        :param segParam: target modulus size in bits; each prime has ``segParam // 2`` bits,
            so ``n`` ends up with ``segParam - 1`` or ``segParam`` bits.
        """
        # https://www.inf.pucrs.br/~calazans/graduate/TPVLSI_I/RSA-oaep_spec.pdf --> p.15
        # https://stackoverflow.com/questions/34140589/how-to-get-a-prime-of-a-given-length-in-sage --> random_prime info

        self.modSize = segParam

        # Integer division keeps the bit length an integer (``/`` gives a float in plain Python).
        p = primerandom(self.modSize // 2)
        q = primerandom(self.modSize // 2)
        while q == p:  # p == q would make n a perfect square and phi wrong
            q = primerandom(self.modSize // 2)
        n = p*q

        phi = (p-1)*(q-1)

        # Public exponent: random e in [3, phi) with gcd(e, phi) = 1. Starting at 3 rules out
        # e = 1 (encryption would be the identity); e = 2 is never coprime with the even phi.
        e = randrange(3, phi)
        while gcd(e, phi) != 1:
            e = randrange(3, phi)

        d = inverse_mod(e, phi)       # private-key exponent

        # Public Key: (n, e)
        self.publicKey = (n, e)
        # Private Key: (n, d)
        self.privateKey = (n, d)

    def rsaEncrypt(self, publicKey, m):
        """Textbook (unpadded) RSA: return ``m**e mod n`` for ``publicKey = (n, e)``."""
        return pow(m, publicKey[1], publicKey[0])

    def rsaDecrypt(self, c):
        """Textbook RSA decryption with this instance's key: return ``c**d mod n``."""
        return pow(c, self.privateKey[1], self.publicKey[0])

    def encaps(self):
        """Encapsulate a fresh shared key under this instance's public key.

        :returns: ``(cf, dk)`` where ``cf`` is the RSA ciphertext of a random secret ``m``
            encoded as ``modSize // 8`` big-endian bytes and ``dk = kdf(salt, m_bytes)``.

        The KDF salt is stored in ``self.salt`` rather than returned, so ``decap`` only
        works on the same instance (and after ``encaps`` has been called).
        """
        import os
        import secrets

        n = self.publicKey[0]
        # Secret drawn uniformly from [2, n - 2] with a CSPRNG; a small range would let an
        # attacker brute-force m (and hence the derived key) from the ciphertext.
        m = 2 + secrets.randbelow(int(n) - 3)

        self.salt = os.urandom(16)
        mbytes = int2bytes(self.modSize//8, m)
        dk = kdf(self.salt, mbytes)

        c = self.rsaEncrypt(self.publicKey, m)

        cf = int2bytes(self.modSize//8, c)
        return (cf, dk)

    def decap(self, cf):
        """Recover the shared key from ciphertext bytes ``cf`` produced by ``encaps``.

        Uses ``self.salt`` set by the preceding ``encaps`` call on this instance.
        """
        c = bytes2int(cf)
        m = self.rsaDecrypt(c)
        mbytes = int2bytes(self.modSize//8, m)
        return kdf(self.salt, mbytes)

    # --------------------------------------------- OAEP-style padding

    def padding(self, msg, p=b''):
        """EME-OAEP-style encoding of ``msg`` (SHA-256 hash, MGF1 via ``mgf``).

        :param msg: message bytes, at most ``modSize//8 - 2*32 - 1`` long.
        :param p: encoding parameters (label) whose hash is placed in the data block;
            defaults to the empty string.
        :returns: ``maskedSeed || maskedDB``, ``modSize // 8`` bytes in total.
        :raises ValueError: if ``msg`` is too long.
        """
        import os

        hash_fn = hashlib.sha256
        hTam = hash_fn().digest_size
        emTam = self.modSize//8
        mTam = len(msg)

        if mTam > (emTam - 2*hTam - 1):
            raise ValueError('Erro: Mensagem demasiado grande')

        # NOTE(review): PKCS #1 v2.0 encodes into k - 1 octets (v2.1 prepends 0x00) so that
        # EM read as an integer stays below n; here emLen = k, so EM may exceed n if it is
        # ever passed to rsaEncrypt — confirm before using this for encryption.
        ps = b'\0' * (emTam - mTam - 2*hTam - 1)
        pHash = hash_fn(p).digest()

        # Data block: pHash || PS || 0x01 || M, exactly emTam - hTam bytes.
        db = pHash + ps + b'\x01' + msg

        seed = os.urandom(hTam)
        dbMask = mgf(seed, emTam - hTam)
        maskedDB = byte_xor(db, dbMask)
        seedMask = mgf(maskedDB, hTam)
        maskedSeed = byte_xor(seed, seedMask)

        return maskedSeed + maskedDB

    def unpadding(self, em, p=b''):
        """Invert ``padding``: recover the message from the encoded block ``em``.

        :param em: encoded block, expected to be ``modSize // 8`` bytes.
        :param p: encoding parameters used when padding; defaults to the empty string.
        :returns: the original message bytes.
        :raises ValueError: if ``em`` has the wrong length, no 0x01 separator follows the
            zero padding, or the embedded hash of ``p`` does not match.
        """
        hash_fn = hashlib.sha256
        hLen = hash_fn().digest_size
        emLen = self.modSize//8

        if emLen < 2*hLen + 1 or len(em) != emLen:
            raise ValueError('Decoding error')

        maskedSeed = em[:hLen]
        maskedDB = em[hLen:]

        seedMask = mgf(maskedDB, hLen)
        seed = byte_xor(maskedSeed, seedMask)
        dbMask = mgf(seed, emLen - hLen)
        db = byte_xor(maskedDB, dbMask)

        phash = hash_fn(p).digest()
        phash2 = db[:hLen]
        rest = db[hLen:]  # PS || 0x01 || M

        # Skip the zero padding PS. Indexing bytes yields ints, so compare against 0 / 0x01
        # (comparing with b'\0' / b'\x01' is always unequal).
        i = 0
        while i < len(rest) and rest[i] == 0:
            i += 1

        if i == len(rest) or rest[i] != 0x01:
            raise ValueError('Não foi encontrado 01')

        m = rest[i + 1:]

        if phash != phash2:
            raise ValueError('pHashes não são iguais')

        return m

In [10]:
# Demo: encapsulate a key with a fresh 1024-bit key pair, then decapsulate it.
rsa = Rsa(1024)

(cf, dk) = rsa.encaps()
print(dk)

k = rsa.decap(cf)

print(k)

# Round-trip check: decapsulation must reproduce the encapsulated key.
assert k == dk, 'decap did not recover the encapsulated key'

b'\x16\x05\x83\r\xaf\xf36\xd8F\xf7\x86T\xd6T\x02i\xa9\xeaU\x11\x87 \x86\x84\x7f:\xbd\xc7\xb8T\xa9\xcc'
b'\x16\x05\x83\r\xaf\xf36\xd8F\xf7\x86T\xd6T\x02i\xa9\xeaU\x11\x87 \x86\x84\x7f:\xbd\xc7\xb8T\xa9\xcc'
