In [1]:
import random
from sage.rings.real_mpfi import RealIntervalField
import time

In [2]:
def ceill(a,b):
    return  a // b + (a % b > 0)

In [3]:
def generate_rsa_keypair(bits=256):
    # Generate a random prime number p
    p = next_prime(randint(2^(bits//2), 2^(bits//2 + 1)))

    # Generate a random prime number q, making sure it's different from p
    q = next_prime(randint(2^(bits//2), 2^(bits//2 + 1)))
    while q == p:
        q = next_prime(randint(2^(bits//2), 2^(bits//2 + 1)))
    
    n = p * q

    # Calculate totient (Euler's totient function) phi(n)
    phi_n = (p - 1) * (q - 1)

    # Choose public exponent e such that 1 < e < phi_n and gcd(e, phi_n) = 1
    e = 65537
    
    print("prime:", gcd(e,phi_n))

    # Calculate private exponent d such that d * e ≡ 1 (mod phi_n)
    d = inverse_mod(e, phi_n)
    

    # Public key: (e, n)
    # Private key: (d, n)
    return ((e, n), (d, n))

In [4]:
def convert_EB_To_Integer(bytes_message):
    integer_value = int.from_bytes(p, byteorder='big', signed=False)
    return integer_value

In [5]:
def convert_Integer_To_Bytes(integer, modulus_length, target_byte_length=32):
    hex_string = hex(integer)[2:]
    
    if len(hex_string) % 2 != 0:
        hex_string = '0' + hex_string
    
    byte_length = (modulus_length + 7) // 8
    padded_byte_length = max(byte_length, target_byte_length)
    
    # Pad with '00' on the left if needed
    hex_string = hex_string.rjust(2 * padded_byte_length, '0')
    
    return bytes.fromhex(hex_string)

In [6]:
def rsa_encrypt(integer_m, public_key):
    e, n = public_key
    # Encrypt using the public key
    c = power_mod(integer_m, e, n)
    return c

In [7]:
def rsa_decrypt(ciphertext, private_key):
    d, n = private_key
    # Decrypt using the private key
    m = power_mod(ciphertext, d, n) 
    return m

In [8]:
def pkcs1_pad(message, n_bytes):
    # Ensure the message is not too long for the given modulus length
    max_message_length = n_bytes - 11
    print(len(message))
    if len(message) > max_message_length:
        raise ValueError("Message too long for PKCS #1  padding")
    
    # Calculate padding length
    pad_length = n_bytes - len(message) - 3
    
    # Generate random non-zero padding bytes
    padding = bytes([randint(1, 255) for _ in range(pad_length)])
    
    # Combine the result: 0x00 || 0x02 || non-zero padding || 0x00 || message
    padded_message = b'\x00\x02' + padding + b'\x00' + message.encode('utf-8')

    
    return padded_message


In [9]:
def pkcs1_unpad(padded_message):
    # Ensure the first two bytes are 0x00 and 0x02
    if padded_message[:2] != b'\x00\x02':
        raise ValueError("Invalid PKCS #1 padding")
    
    # Find the position of the first 0x00 byte after the 0x02 byte
    separator_position = padded_message.find(b'\x00', 2)
    
    # Ensure the separator byte is present
    if separator_position == -1:
        raise ValueError("Invalid PKCS #1 padding")
    
    # Extract and return the message portion
    return padded_message[separator_position + 1:].decode('utf-8')

In [18]:
def is_pkcs1_padding_conforming(cnew, block_size):
    cnew_decrypted = rsa_decrypt(cnew,private_key) %n 
    message = convert_Integer_To_Bytes(cnew_decrypted,block_size,k) 
    if message[0] != 0x00 or message[1] != 0x02 :
        return False
    return True

In [19]:
#step 1: Blinding (en option si c est déjà PKCS conforming)
def step1(c):
    global nb_oracle_request
    while True: 
        s0 = random.randint(1,n)
        cprime = (c * power_mod(s0,e,n)) % n
        nb_oracle_request += 1
        if(is_pkcs1_padding_conforming(cprime,k)):
            c0 = cprime
            M0 = set([(2*B, 3*B-1)])
            break 
    return c0,M0,s0

In [20]:
def step2(previous_M, previous_s, c0): #avec un seul intervalle
    global nb_oracle_request
    for elmt in previous_M:
        a = elmt[0]
        b = elmt[1]
    lower_bound_r = 2 * (b * previous_s - 2 * B)/n
    r_i = ceil(lower_bound_r)
    while True: 
        lower_bound_si = ceil((2 * B + r_i * n) / b)
        higher_bound_si = floor((3 * B + r_i * n) / a) + 1
        for si in range(lower_bound_si, higher_bound_si):
            cnew = (c0 * power_mod(si,e,n)) % n
            if is_pkcs1_padding_conforming(cnew, k):
                cprime_decrypted = rsa_decrypt(cnew,private_key)
                message = convert_Integer_To_Bytes(cprime_decrypted,k,k)
                print("voici le message correspondant à si:",message)
                return si
            nb_oracle_request += 1
        r_i += 1
    

In [21]:
#construction des nouveaux intervales
def new_interval(previous_M, s):
    new_M = set([])
    for a_set in previous_M:
        a = a_set[0]
        b = a_set[1]
        low_bound = ceil( (a * s - 3 * B + 1) / n)
        high_bound = floor( (b * s - 2 * B ) / n) + 1
        for r in range(low_bound, high_bound):
            print("r",r)
            aa = ceil((2 * B + r * n) / s)
            bb = floor((3 * B - 1 + r * n) / s)
            newa = max(a,aa)
            newb = min(b,bb)
            if newa <= newb:
                new_M |= set([ (newa, newb) ])
    if len(new_M) == 0:
        new_M = previous_M
    return new_M

In [36]:
message = "Hello"

key_bits = 256
k = key_bits / 8
print(k)
B = 2^(8*(k-2))
print("B:",B)

public_key, private_key = generate_rsa_keypair(key_bits)
n = public_key[1]
print(n.nbits())
print("n:",n)
e = public_key[0]
print("e:", e)
d = private_key[0]


p = pkcs1_pad(message,k)
print("p:", p)
I = convert_EB_To_Integer(p) ## message convertie en entier
print("I:",I)
c = rsa_encrypt(I,public_key)
print("c:", c)
nb_oracle_request = 0
mI = rsa_decrypt(c, private_key)
m_c = convert_Integer_To_Bytes(mI,k,k)
print(m_c)
pkcs1_unpad(m_c)


32
B: 1766847064778384329583297500742918515827483896875618958121606201292619776
prime: 1
257
n: 198755298618036868330521871872507046623936415099075788736562733867548982479549
e: 65537
5
p: b'\x00\x02j\xb2\xa1\x0cD/\xe0_$]!\xb0\xae\x1c_\x8e-\x8fQ\xa7\xd8\xd8\xa9\x07\x00Hello'
I: 4270095073148365856556796903888400159753810171988057224155692729975467119
c: 94085248464267071413857763656670892501043366112897764453480296080140981379076
b'\x00\x02j\xb2\xa1\x0cD/\xe0_$]!\xb0\xae\x1c_\x8e-\x8fQ\xa7\xd8\xd8\xa9\x07\x00Hello'


'Hello'

In [37]:
def attack(c):
    print("\n\n======================================================\n",
    "Bleichenbacher Attack on RSA PKCS v1.5", 
    "\n======================================================\n",
    "\nPublic Key\ne =", e, "\nn =", n, "  (", key_bits, "bits )\n\nCiphertext")
    print(c)
    global nb_oracle_request
    start_time = time.time()
    #c0, M, s0 = step1(c)
    
    c0 = c
    s0 = 1
    M = set([(2*B, 3*B-1)])
    si = ceil( n / (3 * B))
    print(si)
    i = 1
    while True:
        print("\n\nIteration ", i,
        "\n---------------------------------------------")
            
        if i >= 1 and len(M) >= 2:
            while True:
                si+=1 #pour ne pas reprendre le même si
                cnew = (c0 * power_mod(si,e,n)) % n
                if is_pkcs1_padding_conforming(cnew, k):
                    break
                nb_oracle_request += 1
        elif len(M) == 1:
            for elmt in M:
                if elmt[0] == elmt[1]:
                    elapsed_time = time.time() - start_time
                    print("found the message in", i, "steps with ", nb_oracle_request, "request to the oracle")
                    print("Time", elapsed_time)
                    message_integer = (elmt[0]* power_mod(s0,-1,n))%n
                    print("Message integer = ", (elmt[0]* power_mod(s0,-1,n))%n)
                    message = convert_Integer_To_Bytes(message_integer,k,k)
                    print("The message is:", pkcs1_unpad(message))
                    return message_integer, nb_oracle_request, elapsed_time
            print("Entering step 2 with :"
                 "\nsi = ",si,
                 "\nM: ",M,
                 "\n---------------------------------------------")
            si = step2(M,si,c0)
        #cnew = (c0 * power_mod(si,e,n)) % n
        #cprime_decrypted = rsa_decrypt(cnew,private_key)
        #print((si*I)%n == cprime_decrypted ) --> TRUE
        M = new_interval(M,si)
        print("Exiting step 2:"
              "\nsi = ",si,
              "\nNew Interval: ",M,
            "\n---------------------------------------------")
        i += 1

In [None]:
%%time
attack(c)