# Textbook RSA

This is a pretty fun challenge that essentially boils down to: if you can determine $d_1$, you are done. For math purposes, we will use $d_1$, $N_1$, and $N_2$ to refer to `priv1`, `pub1`, and `pub2` respectively.

Let's start off by getting our parameters from the remote server.

In [1]:
from pwn import *
sh = remote('chals.ctf.sg', 10301)
def getline():
    sh.recvuntil(b'= ')
    return int(sh.recvline(), 16)

enc_flag = getline()
enc_d1 = getline()
N1 = getline()
N2 = getline()
assert N2>2*N1, "Unlucky, try again"

print(f'{enc_flag=}')
print(f'{enc_d1=}')
print(f'{N1=}')
print(f'{N2=}')

[+] Opening connection to chals.ctf.sg on port 10301: Done
enc_flag=1388714090704906567737751820371049635813317690959338577217579618718306592417951718932766694793089754045991949063726979658535268848385960389362390405296449316302361528389024290397582037
enc_d1=3925965248886933389323944939633610468454108277934339138210192235270241445899985396165397341051835966952505868094883467717628929069596773329644288186985679987676300860825217778491052387
N1=2435358538630566974005724606037897829743591167742859547132109786791544639082574131698254012704489428622440555764717019055845504876382969702524795667185669308078066892872302798671295183
N2=7129685878343952152493055619580676123740220078164018320975410092116141694830564940977902387382549139187422734846999791729348860103043228587726621551517850670710000039425133528602793683


The first thing we need is a mental model of how big the values are. It should look something like this:
![model](textbook_model.png)

We know that $0 < d_1 < N_1 < N_2$, and more specifically we will require $N_1 < \frac{N_2}{2}$ for our exploit to work -- but don't worry, this happens more than 96% of the time, or to be precise with probability $2-\frac{3}{2}\ln{2}\approx 0.9603$ (proof left to the reader), and we'll just rerun the script otherwise.

Now, the idea here is we want to leak the value $t = \frac{N_2}{d_1} \in \mathbb{R}$. In the picture above, we'd have $t$ roughly equal to 7.4.

The server gives us an oracle whereby for any integer $m$ it can tell us whether $d_1 m \bmod{N_2} \geq N_1$ -- let's call this `is_error(m)`. This works because textbook RSA is multiplicatively homomorphic: `enc_d1` is $d_1^e \bmod N_2$, so multiplying it by $m^e \bmod N_2$ gives a valid encryption of $d_1 m \bmod N_2$, and the decryptor will happily turn that back into $d_1 m \bmod N_2$ for us.
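To see the blinding trick in isolation, here is a minimal local sketch with toy numbers. Everything in it is an assumption for illustration: the tiny primes, the names, and in particular `server_is_error`, which is only my stand-in model of the remote service (decrypt under key 2, complain when the result is at least $N_1$).

```python
# Toy model of the oracle. All numbers and names are illustrative
# assumptions, not the challenge's real parameters or server code.
E = 65537

# Key 1: N1 = 61 * 53, with d1 the inverse of E modulo phi(N1).
N1 = 61 * 53
d1 = pow(E, -1, 60 * 52)

# Key 2: N2 = 101 * 103, chosen so that N2 > 2 * N1.
N2 = 101 * 103
d2 = pow(E, -1, 100 * 102)

enc_d1 = pow(d1, E, N2)  # the value the server hands out


def server_is_error(ciphertext: int) -> bool:
    """Hypothetical server: decrypt under key 2, fail if the result is >= N1."""
    return pow(ciphertext, d2, N2) >= N1


def is_error(m: int) -> bool:
    """Client side: blind enc_d1 with m^E so the server decrypts d1*m mod N2."""
    return server_is_error(pow(m, E, N2) * enc_d1 % N2)


for m in range(1, 20):
    # The oracle's answer matches the predicate d1*m mod N2 >= N1.
    assert is_error(m) == (d1 * m % N2 >= N1)
```

The client never sees $d_1 m \bmod N_2$ itself, only the one-bit answer, which is all the rest of the attack needs.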

In any case, if we just sweep the small values of $m$ in order, the oracle gives us a run of `False`s (while $d_1 m < N_1$), followed by a run of `True`s (while $N_1 \leq d_1 m < N_2$), followed by a run of `False`s again once $d_1 m$ wraps around $N_2$, and so on. In particular, the point where the first run of `True`s ends tells us the value of $\left\lfloor t \right\rfloor$. In the picture above, we'd have `is_error(7) == True` and `is_error(8) == False`, so what we learn from this is that $7 < t < 8$.
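A tiny local simulation makes the pattern visible. The numbers below are assumptions picked to resemble the picture ($t = 200/27 \approx 7.4$), and `is_error` here is a local stand-in, not the real server.

```python
# Toy numbers resembling the picture; illustrative assumptions only.
d1, N1, N2 = 27, 60, 200


def is_error(m: int) -> bool:
    """Local stand-in for the server oracle."""
    return d1 * m % N2 >= N1


pattern = [is_error(m) for m in range(1, 9)]
print(pattern)  # [False, False, True, True, True, True, True, False]

# The first True run ends exactly when m * d1 wraps past N2.
first_true = pattern.index(True) + 1         # m = 3
wrap = pattern.index(False, first_true) + 1  # m = 8
floor_t = wrap - 1
print(floor_t)  # 7
```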

Now we wish to split this interval in half, so that $t \in [7, 7.5]$ or $t \in [7.5,8]$. We cannot query `is_error(7.5)` directly because `7.5` is not an integer, but it turns out that `is_error(15)` answers exactly this question: it returns `True` when $t > 7.5$ and `False` when $t < 7.5$. The proof is again left to the reader as an exercise, but it's worth pointing out that it requires the crucial property that $N_1+d_1<N_2$, which we can only guarantee if $N_1 < \frac{N_2}{2}$.
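For readers who would rather skip the exercise, here is a sketch of one possible argument. The integer $c$ and the level $i$ are my own notation: suppose we already know that $\frac{c}{2^i} < t < \frac{c+1}{2^i}$, which is the same as saying

$$c\,d_1 < 2^i N_2 < (c+1)\,d_1.$$

Doubling gives $2c\,d_1 < 2^{i+1} N_2 < (2c+2)\,d_1$, and therefore

$$2^{i+1} N_2 - d_1 < (2c+1)\,d_1 < 2^{i+1} N_2 + d_1.$$

If $(2c+1)\,d_1 < 2^{i+1} N_2$, i.e. $t$ lies above the midpoint, then

$$(2c+1)\,d_1 \bmod N_2 = (2c+1)\,d_1 - (2^{i+1}-1)\,N_2 \in (N_2 - d_1,\ N_2),$$

which exceeds $N_1$ precisely because $N_1 + d_1 < N_2$, so the oracle says `True`. Otherwise the residue is $(2c+1)\,d_1 - 2^{i+1} N_2 \in [0, d_1)$, which is below $N_1$, so the oracle says `False`. The example in the text is the case $c = 7$, $i = 0$, where the query is $2c+1 = 15$.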

So we can keep halving the interval this way, effectively learning one bit of $t$ per query. Since the value we want is $d_1 = \frac{N_2}{t}$, we can stop as soon as we have an interval $[t_L, t_H]$ small enough such that
$$\left\lceil\frac{N_2}{t_H}\right\rceil = \left\lfloor\frac{N_2}{t_L}\right\rfloor,$$
and this value must necessarily be the $d_1$ that we seek.
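The whole procedure, stopping rule included, can be tried offline. The sketch below is my own toy version rather than the script used against the server: it keeps $t_L$ and $t_H$ as exact `Fraction`s to stay close to the real-number mental model, and the parameters and the lambda oracle are assumptions.

```python
from fractions import Fraction
from math import ceil, floor


def recover_d1(is_error, N2):
    """Recover d1 from the oracle by bisecting t = N2 / d1 (sketch)."""
    # Sweep: skip the initial False run, then the True run.
    m = 1
    while not is_error(m):
        m += 1
    while is_error(m):
        m += 1
    # The oracle flipped back to False at m, so m - 1 < t <= m.
    t_lo, t_hi = Fraction(m - 1), Fraction(m)

    # Halve until only one integer d1 is compatible with the interval.
    while ceil(N2 / t_hi) != floor(N2 / t_lo):
        mid = (t_lo + t_hi) / 2
        # mid equals (2c+1) / 2^(i+1) in lowest terms, so its numerator
        # is the integer we query in place of the fractional midpoint.
        if is_error(mid.numerator):
            t_lo = mid  # t > mid
        else:
            t_hi = mid  # t <= mid
    return floor(N2 / t_lo)


# Toy parameters (assumptions) satisfying d1 < N1 < N2 / 2.
d1, N1, N2 = 2753, 3233, 10403
print(recover_d1(lambda m: d1 * m % N2 >= N1, N2))  # 2753
```

On these toy numbers the loop needs ten halving queries after the sweep, which fits the one-bit-per-query picture.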

For Python purposes, we store `t` shifted left by `i` bits (that is, as $\lfloor 2^i t \rfloor$ after `i` halvings) so it's always an integer, but the mental model should be that it is a real number.

In [2]:
from tqdm import trange

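querycount = 0
def is_error(m):
    # Blind enc_d1 with m^e so the server ends up decrypting d1*m mod N2.
    global querycount
    querycount += 1
    sh.sendline(str((pow(m, 65537, N2) * enc_d1) % N2).encode())
    return b'Error' in sh.recvline()

# c1 = last m of the initial False run, i.e. the largest m with m*d1 < N1.
c1 = next(x for x in trange(2,999) if is_error(x)) - 1
print(f'{c1=}')
# c2 = floor(t). Since c1*d1 < N1 we know t > c1*N2/N1, so we can skip ahead.
c2 = next(x for x in trange(c1*N2//N1+1,999) if not is_error(x)) - 1
print(f'{c2=}')

# Invariant: t == floor(2**i * N2 / d1), so d1 lies in [lower, upper].
t = c2
for i in trange(999):
    upper = 2**i * N2 // t
    lower = 2**i * N2 // (t+1) + 1
    if upper==lower: break
    # One query per extra bit of t: True means t is in the upper half.
    t = 2*t + is_error(2*t+1)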

print(f'It took us a total of {querycount} queries, but we got there!')
print(f'{lower=}')

  0%|          | 1/997 [00:01<17:32,  1.06s/it]


c1=2


  0%|          | 1/993 [00:01<18:16,  1.11s/it]


c2=6


 61%|██████    | 605/999 [05:30<03:35,  1.83it/s]

It took us a total of 609 queries, but we got there!
lower=1077344320795541414931625917851759065564889987872514518370601750746303037616643868748886588239417868915668163328367309419779925665438236578280067547964746219064327646845986944575367681





Hooray! We successfully leaked $d_1$, so all that's left to do is get the flag.

In [3]:
from Crypto.Util.number import long_to_bytes
from hashlib import shake_128

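def xor_bytes(a, b): return type(a)([x ^ y for x, y in zip(a, b)])
def unpad(msg: bytes) -> bytes:
    msg = bytearray(msg)
    # Unmask the tail using SHAKE-128 of the 38-byte head,
    # then the head using SHAKE-128 of the recovered tail.
    msg[38:] = xor_bytes(msg[38:], shake_128(msg[:38]).digest(38))
    msg[:38] = xor_bytes(msg[:38], shake_128(msg[38:]).digest(38))
    # Strip the trailing msg[-1] + 8 bytes.
    msg = msg[:-msg[-1]-8]
    return bytes(msg)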

unpad(long_to_bytes(pow(enc_flag, lower, N1)))

b'CTFSG{https://arxiv.org/abs/1802.03367?salt=290nlk01nx}'

## Post-solve analysis: Reducing the query count

600 queries is pretty good, but we actually only need about half that. Recall that $d_1 e = 1 + k \phi(N_1)$ for some integer $1 \leq k < e$. This means that
$$d_1 \approx \frac{kN_1}{e},$$
and in fact this approximation matches the exact value of $d_1$ in roughly its top half of bits [citation needed, probably Boneh-Durfee]. The reason is that $\phi(N_1) = N_1 - (p+q) + 1$ for the prime factors $p, q$ of $N_1$, so when the two primes are of similar size $\phi(N_1)$ differs from $N_1$ by only about $\sqrt{N_1}$. Since $k$ is small (roughly 16 bits of information), if we have 16 bits of $t$ from the original procedure then we have a small enough interval for $d_1$ such that we can pick out the specific value of $k$. In other words, we can learn 300 bits from just 16 bits, saving some 284 queries.
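As a sanity check on this idea, here is a small offline sketch. It is not something I ran against the server, and it rests on assumptions: a toy modulus with factors near $2^{61}$ and $2^{64}$, and a coarse estimate that keeps the top 20 bits of $d$ (a little more than the 16 bits mentioned above, purely for safety margin).

```python
E = 65537

# Toy modulus (assumption): two factors of similar size.
p, q = 2**61 - 1, 2**64 - 59
N = p * q
phi = (p - 1) * (q - 1)
d = pow(E, -1, phi)          # the "unknown" private exponent
k_true = (E * d - 1) // phi  # ground truth, used only for checking

# Pretend the oracle phase stopped after pinning down d's top 20 bits.
shift = d.bit_length() - 20
d_coarse = (d >> shift) << shift

# E*d/N sits just below k, so rounding the coarse ratio recovers k.
k = (E * d_coarse + N // 2) // N
assert k == k_true

# With k known, k*N/E is off by less than p + q, i.e. about half the bits.
d_estimate = k * N // E
error = abs(d_estimate - d)
print(d.bit_length(), error.bit_length())
```

The two printed bit lengths show how many of the top bits come for free once $k$ is pinned down.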

Not much can be done for the lower half, but towards the end you can just brute force over all remaining candidates for $d_1$ (as the range becomes small enough). This actually admits a meet-in-the-middle attack, so that e.g. you can learn the lowest 60 bits with an $O(2^{30})$ brute force, which is typically feasible (depending on your appetite / GPU / time utility), and that saves another 60 or so queries.
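One way the meet-in-the-middle step could look is a baby-step giant-step search against a known plaintext/ciphertext pair. This is a sketch under assumptions: the function `mitm_low_bits`, the choice of base 2 and the toy key are all hypothetical, and at toy scale it only recovers 12 bits with a table of 64 entries.

```python
E = 65537


def mitm_low_bits(N: int, d_high: int, b: int):
    """Find d = d_high + x with 0 <= x < 2**b such that (2^E)^d == 2 mod N."""
    g = 2
    c = pow(g, E, N)  # a known pair: c^d == g for the true d
    half = b // 2

    # Baby steps: table of g * c^(-lo) for every low half lo.
    c_inv = pow(c, -1, N)
    table = {}
    val = g
    for lo in range(1 << half):
        table.setdefault(val, lo)
        val = val * c_inv % N

    # Giant steps: c^(d_high + hi * 2^half) for every high half hi.
    giant = pow(c, 1 << half, N)
    val = pow(c, d_high, N)
    for hi in range(1 << (b - half)):
        if val in table:
            return d_high + (hi << half) + table[val]
        val = val * giant % N
    return None


# Toy key (assumption): N = 1019 * 2039, far too small to be secure.
p, q = 1019, 2039
N = p * q
d = pow(E, -1, (p - 1) * (q - 1))
b = 12
d_high = (d >> b) << b  # pretend the oracle gave us all but the low 12 bits
print(mitm_low_bits(N, d_high, b) == d)
```

Scaled up to `b = 60`, the table would hold $2^{30}$ entries, which is where the $O(2^{30})$ figure comes from.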

In conclusion, this challenge can be solved in under 300 queries, though I haven't bothered to implement any of it so it's just theoretical.