In [227]:
# import the necessary libraries here
from Crypto.Util import number
from Crypto.PublicKey import RSA as CryptoRSA
from Crypto.Util.number import long_to_bytes, bytes_to_long

In [228]:
class RSA:
    """Implements the RSA public key encryption / decryption."""

    def __init__(self, key_length):
        # define self.p, self.q, self.e, self.n, self.d here based on key_length
        # self.key = RSA.generate(key_length)
        self.key = CryptoRSA.generate(key_length)
        self.p = self.key.p
        self.q = self.key.q
        self.n = self.key.n
        self.e = self.key.e
        self.d = self.key.d

    def encrypt(self, message):
        # return encryption of binary_data here
        message_int = bytes_to_long(message)
        power = pow(message_int, self.e, self.n)
        return power

    def decrypt(self, encrypted_int_data):
        # return decryption of encrypted_binary_data here
        power = pow(encrypted_int_data, self.d, self.n)
        return long_to_bytes(power)

In [229]:
class RSAParityOracle(RSA):
    """Extends the RSA class by adding a method to verify the parity of data."""

    def is_parity_odd(self, encrypted_int_data):
        # Decrypt the input data and return whether the resulting number is odd
        decrypted_int = pow(encrypted_int_data, self.d, self.n)
        return decrypted_int % 2 == 1


Oracle in a way gives least significant bit each time
C = P^e mod n

multiplying with 2^e on both sides gives
C1 = C*2^e = (2P)^e mod n
i.e., C1 is cipher of 2P
decrypting msg to get 2P

if rightmost bit is on i.e., oracle returns odd then P > N/2
otherwise bit is off i.e., oracle returns even then P < N/2

again checking for 4P now..
if even then P < N/4 or N/2 <= P < 3N/4
or if odd then N/4 <= P < N/2 or 3N/4 <= P < N

This way each turn doing like binary search we get P finally

In [230]:

def parity_oracle_attack(ciphertext, rsa_parity_oracle):
    # implement the attack and return the obtained plaintext
    e = rsa_parity_oracle.e
    n = rsa_parity_oracle.n
    left = 0
    right = n
    multiplier = pow(2, e, n)
    
    while right - left >= 1:
        ciphertext = (ciphertext * multiplier) % n
        odd = rsa_parity_oracle.is_parity_odd(ciphertext)

        mid = (left + right) // 2
        if odd:
            left = mid
        else:
            right = mid
    return long_to_bytes(left)

        

In [231]:
def main():
    input_bytes = input("Enter the message: ")
    # input_bytes = "Hello, World!"

    # Generate a 1024-bit RSA pair
    rsa_parity_oracle = RSAParityOracle(1024)

    # Encrypt the message
    ciphertext = rsa_parity_oracle.encrypt(input_bytes.encode())
    print("Encrypted message is: ", ciphertext)
    
    decrypted = rsa_parity_oracle.decrypt(ciphertext)
    print("Directly decrypted text is: ",decrypted)
    
    
    # Check if the attack works
    plaintext = parity_oracle_attack(ciphertext, rsa_parity_oracle)
    print("Obtained plaintext: ", plaintext)
    # assert plaintext == input_bytes.encode()


if __name__ == '__main__':
    main()

Encrypted message is:  95190703287181858901777329166188169881045146673027932460461298463026682091738299068461205661851863319252833745006656518486857487833678264144914229553569826801174857691081253458457863832935114178738838905988790059400046079733607457556534602548499449305699903979661091922948834099827120682884882729035812531084
Directly decrypted text is:  b'Hello, World!'
Obtained plaintext:  b'Hello, World\x12'
