In [1]:
# As before, I used Dr. Paccagnella's implementation as a guide in several areas
# https://github.com/ricpacca/cryptopals/blob/master/S6C47.py

from RSA import BIT_STRENGTH, RSAClient, RSAServer
from random import choice, randint
from math import ceil, floor

def PCKS1Pad(plaintext: bytes, k = 2*BIT_STRENGTH//8):
    """Build a PKCS #1 v1.5 (block type 2) encryption block of exactly k bytes.

    Layout of the result:

        0x00 0x02 | nonzero random padding | 0x00 | plaintext

    Args:
        plaintext: message to wrap; must be of type ``bytes`` and at most
            ``k - 3`` bytes long (two header bytes plus the 0x00 separator).
        k: total length of the encryption block in bytes.

    Returns:
        The padded block, ``k`` bytes long.

    Raises:
        AssertionError: if ``plaintext`` is not ``bytes`` or does not fit.

    Note:
        The padding comes from the ``random`` module, which is not a
        cryptographically secure source.
    """
    assert type(plaintext) is bytes

    # The standard requires at least 8 padding bytes (len(plaintext) <= k-11).
    # That requirement is deliberately omitted so this works with small bit
    # sizes, but the header and separator must still fit: without this check
    # an oversized plaintext would silently yield a block longer than k.
    pad_len = k - 3 - len(plaintext)
    assert pad_len >= 0, 'plaintext is too long for a k-byte encryption block'

    # Padding bytes must be nonzero so the first 0x00 after the header
    # unambiguously marks the start of the plaintext.
    padding = bytes(randint(1, 0xFF) for _ in range(pad_len))
    assert 0x00 not in padding
    return b'\x00\x02' + padding + b'\x00' + plaintext


def PCKS1Check(encryption_block: bytes):
    """Report whether a decrypted block carries the 0x00 0x02 header.

    Only the first two bytes are inspected; the padding and separator are
    not validated here.

    Args:
        encryption_block: decrypted block; must be of type ``bytes``.

    Returns:
        True if the block begins with ``b'\\x00\\x02'``, otherwise False.
    """
    assert type(encryption_block) is bytes
    return encryption_block.startswith(b'\x00\x02')
    
def PCKS1UnPad(plaintext: bytes):
    """Strip the PKCS #1 v1.5 header and padding from a decrypted block.

    Args:
        plaintext: the full padded block, starting with 0x00 0x02.

    Returns:
        Everything after the first 0x00 byte that follows the two-byte header.

    Raises:
        AssertionError: if the block does not pass ``PCKS1Check``.
        ValueError: if no 0x00 separator follows the header.
    """
    assert PCKS1Check(plaintext)
    # Start the search at offset 2 so the header's own leading 0x00 is skipped.
    separator_at = plaintext.index(b'\x00', 2)
    return plaintext[separator_at + 1:]

class PCKS1Oracle(RSAServer):
    """RSA server that additionally answers "is this ciphertext PKCS conforming?"."""

    def PCKS1Check(self, ciphertext: int):
        """Decrypt ``ciphertext`` and report whether the result is PKCS conforming.

        The cryptopals instructions say only the first two bytes need checking,
        but Dr. Paccagnella's implementation notes that the length must be
        checked as well for the attack to work (he reports spending "hours"
        debugging this), so the length check is done here preemptively.

        Args:
            ciphertext: the ciphertext as an integer.

        Returns:
            True only if the decrypted block starts with 0x00 0x02 and is
            exactly as long as the modulus ``self.n`` in bytes.
        """
        # DecryptBytes is inherited from RSAServer (defined in RSA.py).
        decrypted = self.DecryptBytes(ciphertext)
        if not PCKS1Check(decrypted):   # module-level header check, not this method
            return False
        return len(decrypted) == ceil(self.n.bit_length() / 8)

In [2]:
# k is the length of an encryption block in bytes (2*BIT_STRENGTH bits;
# presumably the size of the modulus n -- confirm against RSA.py).
#
# B = 2^(8(k-2)), so B-1 is the greatest integer that fits in the k-2 bytes
# that follow the two-byte 0x00 0x02 prefix.  As k-byte big-endian strings:
#     B  = 0x00 01 00...00
#     2B = 0x00 02 00...00   (smallest block with the 0x00 0x02 prefix)
#     3B = 0x00 03 00...00   (first block that no longer has the prefix)

k = (2 * BIT_STRENGTH) // 8
B = pow(2, 8 * (k - 2))

In [3]:
# If a number passes the oracle, it must be in [2B, 3B), because 3B would be invalid
# This is not biconditional, however

In [4]:
# Create the padding oracle (the attack target) and fetch its public key.
oracle = PCKS1Oracle()
pubkey = oracle.GetPubkey()
# pubkey is a mapping with keys 'e' and 'n' (see the printed output);
# it is passed to the client as keyword arguments.
client = RSAClient(**pubkey)
print(pubkey)

{'e': 3, 'n': 12872930962670972071}


In [5]:
# Check that everything is working: pad, encrypt, and confirm that the oracle
# accepts the ciphertext and that decrypting + unpadding recovers the message.

message = b'Hello'
padded_message = PCKS1Pad(message)
ciphertext = client.Encrypt(padded_message)
# The ciphertext of a correctly padded message must pass the oracle.
assert oracle.PCKS1Check(ciphertext)
# Round trip through the server's decryption.
assert PCKS1UnPad(oracle.DecryptBytes(ciphertext)) == message

In [6]:
# Now do the attack, starting with Step 1 ("blinding"):
# find s0 such that c0 = c * s0^e mod n is PKCS conforming.
# s0 = 1 is tried first, so a ciphertext that is already conforming
# is used unchanged.

solution = None
i = 0
e, n = pubkey['e'], pubkey['n']

c = int(ciphertext)
s0 = 1
c0 = c * pow(s0, e, n) % n

while not oracle.PCKS1Check(c0):
    # Not conforming: draw a fresh blinding factor and try again.
    s0 = randint(0, n-1)
    c0 = c * pow(s0, e, n) % n

# A conforming plaintext lies in [2B, 3B-1]; that is the initial interval set.
M0 = [(2*B, 3*B-1)]
i = 1

print(f'{i = }, {s0 = }, {c0 = }')

i = 1, s0 = 1, c0 = 12737469434191777123


In [7]:
# Step 2a: find the smallest s1 >= n/(3B) such that c0 * s1^e mod n
# is PKCS conforming.

assert i == 1 and oracle.PCKS1Check(c0)

# Exact integer ceiling division: -(-x // y) == ceil(x / y) for y > 0.
# Going through a float (ceil(n / (3*B))) rounds the quotient to 53 bits.
s1 = -(-n // (3*B))

while not oracle.PCKS1Check(c0 * pow(s1, e, n) % n):
    s1 += 1
c1 = c0 * pow(s1, e, n) % n
i = 2
M1 = M0
print(f'{i = }, {s1 = }, {c1 = }')

i = 2, s1 = 45710, c1 = 3095778069094715468


In [8]:
# Steps 2b, 2c, 3 and 4 of Bleichenbacher's attack, repeated until the set of
# candidate intervals M collapses to a single point.
#
# All bounds are computed with exact integer arithmetic.  The quantities
# involved (a*s, r*n, ...) are far wider than the 53-bit mantissa of a float,
# so ceil(x / y) / floor(x / y) with true division silently produce wrong
# interval endpoints once the modulus grows beyond roughly 64 bits.

M = M1
s = s1


def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ceil(numerator / denominator) exactly, for denominator > 0."""
    return -(-numerator // denominator)


while solution is None:
    print(f'{i = }')

    assert i > 1

    if len(M) >= 2:
        # Step 2b: several intervals left -- take the smallest conforming
        # s greater than the previous one.
        print('\tM has two or more elements')
        s += 1
        while not oracle.PCKS1Check(c0 * pow(s, e, n) % n):
            s += 1
    elif len(M) == 1:
        # Step 2c: one interval [a, b] left -- search for small r, s with
        #     r >= 2 * (b*s_prev - 2B) / n
        #     (2B + r*n) / b <= s < (3B + r*n) / a
        print('\tM has one element')
        [(a, b)] = M
        assert b >= a

        r = _ceil_div(2 * (b * s - 2 * B), n)
        s = _ceil_div(2 * B + r * n, b)
        print(f'\t{r=}, {s=}')

        while not oracle.PCKS1Check(c0 * pow(s, e, n) % n):
            assert s < n
            assert r < n
            s += 1
            # s < (3B + r*n) / a, cross-multiplied to stay in integers (a > 0).
            if s * a >= 3 * B + r * n:
                r += 1
                s = _ceil_div(2 * B + r * n, b)

    else:
        # This should not happen
        assert False, 'M is empty: no candidate interval is left'

    # Step 3: narrow the set of solutions using the conforming s just found.
    print('Starting step 3')
    Mnew = []
    for a, b in M:
        r_values = range(_ceil_div(a * s - 3 * B + 1, n), 1 + (b * s - 2 * B) // n)
        print(f'{len(r_values)=}')
        for r in r_values:
            anew = max(a, _ceil_div(2 * B + r * n, s))
            bnew = min(b, (3 * B - 1 + r * n) // s)
            # An r in range can still give an interval containing no integer;
            # such empty intervals must not be carried into M.
            if anew <= bnew:
                Mnew.append((anew, bnew))
    M = Mnew

    # Step 4: a single interval of length one is the (blinded) plaintext.
    # Undo the blinding from step 1: m = a * s0^-1 mod n.
    if len(M) == 1 and M[0][0] == M[0][1]:
        solution = M[0][0] * pow(s0, -1, n) % n
    i += 1


i = 2
	M has one element
	r=6, s=91469
Starting step 3
len(r_values)=3
i = 3
	M has more than two elements
Starting step 3
len(r_values)=1
len(r_values)=0
len(r_values)=1
i = 4
	M has more than two elements
Starting step 3
len(r_values)=1
len(r_values)=0
i = 5
	M has one element
	r=21, s=479941
Starting step 3
len(r_values)=1
i = 6
	M has one element
	r=45, s=1028445
Starting step 3
len(r_values)=1
i = 7
	M has one element
	r=93, s=2125451
Starting step 3
len(r_values)=1
i = 8
	M has one element
	r=189, s=4319463
Starting step 3
len(r_values)=1
i = 9
	M has one element
	r=379, s=8661779
Starting step 3
len(r_values)=1
i = 10
	M has one element
	r=763, s=17437829
Starting step 3
len(r_values)=1
i = 11
	M has one element
	r=1527, s=34898511
Starting step 3
len(r_values)=1
i = 12
	M has one element
	r=3057, s=69865583
Starting step 3
len(r_values)=1
i = 13
	M has one element
	r=6115, s=139754019
Starting step 3
len(r_values)=1
i = 14
	M has one element
	r=12231, s=279530891
Starting step 

In [9]:
# Convert the recovered integer back into a fixed-length big-endian block
# and confirm it matches the padded message that was originally encrypted.
cracked_message = solution.to_bytes(length=2*BIT_STRENGTH//8, byteorder='big')
print(cracked_message)
assert padded_message == cracked_message

b'\x00\x02\x00Hello'


In [None]:
# So, I only got this working with relatively small bit strengths.
# With strengths larger than 32, I found the program would take an intolerably long time to find a solution,
# or sometimes would converge early to an incorrect solution.
# Regardless, I decided that solving this in the case of 32 bits was sufficient, that there was probably some
# minor error causing this problem that wasn't worth hours of debugging, and that it was time to move on.