In [1]:
# This is an alternate solution to this challenge.
# My original solution did work, but for some reason only with RSA BIT_STRENGTH <= 32,
# and I have no idea why.
# Here I re-solve the challenge while following Dr. Paccagnella's code, which is confirmed to work, more closely.
# The first solution is largely my own code, working directly from Bleichenbacher's paper, with only incidental
# references to Dr. Paccagnella's code, whereas this solution is more of a translation of Dr. Paccagnella's
# code into my own style.

from RSA import RSAClient, RSAServer, BIT_STRENGTH
from Crypto import Random

# Create the RSA server, and a client constructed from the keyword arguments
# returned by server.GetPubkey() (the server's public key).
server = RSAServer()
client = RSAClient(**server.GetPubkey())

In [2]:
# Sanity check before attacking anything: encrypt a message as wide as the
# modulus in bytes, decrypt it again, and confirm the message survives intact.

message = b'A' * (BIT_STRENGTH // 8)
ciphertext = client.Encrypt(message)
result = server.DecryptBytes(ciphertext)

print(result, len(result), sep='\n')
assert message == result[-len(message):]

b'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'
32


In [3]:
def Ceil(numerator, denominator):
    """Return ceil(numerator / denominator) using exact integer arithmetic.

    True division would go through floats, which cannot represent RSA-sized
    integers exactly. Works for any nonzero denominator, including negative ones.
    """
    return -(-numerator // denominator)


def Floor(numerator, denominator):
    """Return floor(numerator / denominator) using exact integer arithmetic.

    This is plain floor division. Computing it as Ceil(...) - 1 is off by
    one whenever the division is exact (it would give 2 for 9 / 3).
    """
    return numerator // denominator


assert  Ceil(10, 3) == 4
assert Floor(10, 3) == 3
# Exact division is the case an off-by-one Floor gets wrong.
assert  Ceil(9, 3) == 3
assert Floor(9, 3) == 3

In [4]:
class Oracle(RSAServer):
    """RSA server that additionally answers padding-oracle queries."""

    def IsPaddingCorrect(self, encrypted_int_data):
        """Report whether a ciphertext decrypts to a conforming block.

        The decrypted bytes count as conforming when they are exactly as long
        as the modulus in bytes and start with 00 02. Only this boolean is
        revealed to the caller.
        """
        decrypted = self.DecryptBytes(encrypted_int_data)
        modulus_len = (self.n.bit_length() + 7) // 8
        if len(decrypted) != modulus_len:
            return False
        return decrypted.startswith(b'\x00\x02')
    
def Pad(binary_data, length = BIT_STRENGTH//8):
    """Pad `binary_data` to `length` bytes with PKCS#1 v1.5 encryption padding.

    Layout: 00 02 || PS || 00 || binary_data. PS is random and contains no
    zero bytes. A zero inside PS would be mistaken for the separator when
    unpadding and would corrupt the recovered message.

    Raises ValueError if binary_data is too long to leave the minimum of
    8 padding bytes that PKCS#1 v1.5 requires.
    """
    padding_length = length - 3 - len(binary_data)
    if padding_length < 8:
        raise ValueError('Message too long to pad to %d bytes' % length)

    rng = Random.new()
    padding_string = b''
    while len(padding_string) < padding_length:
        # Draw the remaining number of bytes and discard any zeros; repeat
        # until the padding string is full.
        chunk = rng.read(padding_length - len(padding_string))
        padding_string += chunk.replace(b'\x00', b'')

    return b'\x00\x02' + padding_string + b'\x00' + binary_data

Pad(b'Hello, world!')

b'\x00\x02\xf7\xb2\xa0=5\xa6/\x9e\xf2jNx\xcc=R\xcc\x00Hello, world!'

In [5]:
# Set up the attack: a padding-oracle server, its public key, and a padded
# target ciphertext for the attacker to recover.
oracle = Oracle()
pubkey = oracle.GetPubkey()
e, n = pubkey['e'], pubkey['n']
client = RSAClient(**pubkey)
message = b'Hello, world!'

# k is the byte length of the modulus. Derive it from n itself (the same way
# Oracle.IsPaddingCorrect does) rather than from BIT_STRENGTH, so that the
# padded length, B and the oracle's length check always agree.
k = Ceil(n.bit_length(), 8)
ciphertext = client.Encrypt(Pad(message, k))

# A conforming k-byte block starts with 00 02, so as an integer it lies in [2B, 3B).
B = 2**(8 * (k - 2))

In [6]:
# Do the attack

# Step 1 (blinding) in Bleichenbacher's paper multiplies c by s_0^e for random
# s_0 until the result is conforming. The final answer must then be multiplied
# by s_0^-1 mod n. The target here was produced by Pad(), so it is already
# conforming and s_0 = 1. Blinding and unblinding are therefore not
# implemented, and a non-conforming target is reported as an error.

c_0 = ciphertext
M = [(2 * B, 3 * B - 1)]
i = 1

if not oracle.IsPaddingCorrect(c_0):
    # An explicit raise (unlike `assert False`) still fires under `python -O`.
    raise RuntimeError('Target ciphertext is not PKCS#1 v1.5 conforming; '
                       'blinding (step 1) is not implemented.')

In [7]:
def CalculateNewM(intervals, lower_bound, upper_bound):
    """Add the closed interval [lower_bound, upper_bound] to `intervals`, merging overlaps.

    `intervals` is a list of (a, b) tuples holding pairwise-disjoint closed
    intervals. It is updated in place. Every existing interval that overlaps
    the new one is absorbed into a single merged interval, not just the first
    one found, so the list stays pairwise disjoint. The resulting list is
    sorted by lower bound. Returns None.
    """
    new_a, new_b = lower_bound, upper_bound
    kept = []

    for a, b in intervals:
        if b < new_a or a > new_b:
            kept.append((a, b))
        else:
            # The existing intervals are disjoint from one another, so widening
            # (new_a, new_b) here cannot make it overlap an interval that was
            # already placed in `kept`.
            new_a = min(new_a, a)
            new_b = max(new_b, b)

    kept.append((new_a, new_b))
    kept.sort()
    intervals[:] = kept

while True:

    # Step 2.a: Starting the search -- the smallest s >= n / 3B for which
    # c_0 * s^e is conforming.
    if i == 1:
        s = Ceil(n, 3 * B)
        while True:

            c = (c_0 * pow(s, e, n)) % n
            if oracle.IsPaddingCorrect(c):
                break

            s += 1

    # Step 2.b: Searching with more than one interval left -- the next
    # conforming s above the previous one.
    elif len(M) >= 2:
        while True:
            s += 1
            c = (c_0 * pow(s, e, n)) % n

            if oracle.IsPaddingCorrect(c):
                break

    # Step 2.c: Searching with one interval left
    elif len(M) == 1:
        a, b = M[0]

        # Step 4: the interval has shrunk to a single integer -- the plaintext.
        if a == b:
            solution = a
            break

        # Choose r >= 2(b*s - 2B)/n, then try s in [(2B + r*n)/b, (3B + r*n)/a);
        # if no s in that range is conforming, move on to r + 1.
        r = Ceil(2 * (b * s - 2 * B), n)
        s = Ceil(2 * B + r * n, b)

        while True:
            c = (c_0 * pow(s, e, n)) % n
            if oracle.IsPaddingCorrect(c):
                break

            s += 1
            # The bound is strict: s < (3B + r*n)/a  <=>  s < Ceil(3B + r*n, a).
            if s >= Ceil(3 * B + r * n, a):
                r += 1
                s = Ceil(2 * B + r * n, b)

    # Step 3: Narrowing the set of solutions
    M_new = []

    for a, b in M:
        min_r = Ceil(a * s - 3 * B + 1, n)
        max_r = (b * s - 2 * B) // n

        for r in range(min_r, max_r + 1):
            l = max(a, Ceil(2 * B + r * n, s))
            u = min(b, (3 * B - 1 + r * n) // s)

            # [ceil((2B + r*n)/s), floor((3B - 1 + r*n)/s)] has width (B - 1)/s,
            # so once s exceeds B it may contain no integer at all. Such an r
            # contributes nothing to the union of step 3; skip it.
            if l > u:
                continue

            # Do append and merge
            CalculateNewM(M_new, l, u)

    # The plaintext always lies in some interval, so an empty M means a bug.
    if len(M_new) == 0:
        raise Exception('Unexpected error: there are 0 intervals.')

    M = M_new
    i += 1

In [8]:
# The recovered plaintext as an integer (the whole padded block, before unpadding).
solution

3589675748917653109152975854075669722762603212638742011209542056436065313

In [9]:
# Convert the recovered integer back into the padded block. Use the modulus
# byte length, which is the length the oracle checks against.
cracked_message = solution.to_bytes(byteorder='big', length = Ceil(n.bit_length(), 8))
print(cracked_message)
print(message)

# The attack recovers the whole padded block, so the original message must be its tail.
assert cracked_message.endswith(message)

b'\x00\x02\x08\x1cyBX8\x17\nxT;g_z\xccb\x00Hello, world!'
b'Hello, world!'
