In [1]:
from sage.all import *

import random
import string

In [2]:
# Alphabet for random test messages: ASCII letters (lower then upper case) followed by digits.
ALPH = ''.join([string.ascii_letters, string.digits])

In [3]:
def rand_str(n):
    """Return a random string of length ``n`` drawn (with replacement) from ALPH."""
    picked = [random.choice(ALPH) for _ in range(n)]
    return ''.join(picked)

In [4]:
def str_to_int(s):
    """Interpret ``s`` as a big-endian byte string and return it as an integer.

    Each character contributes exactly one byte (its code point), so
    ``'AB' -> 0x4142``. The empty string maps to 0.

    Raises:
        ValueError: if any character's code point is above 0xFF. Such a character
            would need more than two hex digits, which makes the encoding ambiguous
            (e.g. '\\u0100A' and '\\x01\\x00A' would both give 0x10041).
    """
    codes = [ord(c) for c in s]
    if any(code > 0xFF for code in codes):
        raise ValueError('str_to_int only supports characters up to U+00FF')
    # `or '0'`: int('', 16) would raise for the empty string.
    return int(''.join('%02x' % code for code in codes) or '0', 16)

In [5]:
def solve_dicrete_log(a, b, N):
    """Return x such that a^x == b in the finite field GF(N).

    (sic: 'dicrete' -- the misspelled name is kept because callers use it.)
    N must be a prime (power) for Sage's GF(N) to be constructed. The discrete
    log itself is delegated to Sage's ``log`` on field elements.
    NOTE(review): presumably Sage raises when b is not a power of a -- the retry
    loop in ``solve`` relies on failures surfacing as exceptions; confirm.
    """
    field = GF(N)
    base = field(a)
    target = field(b)
    return target.log(base)

In [6]:
def try_to_decode(mess, enc):
    """Make one randomized attempt at a key (e, d, N) with mess^e == enc (mod N).

    Two random primes p, q are drawn with a bound a few bits above half the
    bit length of the inputs. The exponent is found by:
      1. solving log_mess(enc) separately modulo p and modulo q;
      2. combining the two logs with the CRT over the moduli (p - 1, q - 1) to get e.
    d is the inverse of e modulo (p - 1)(q - 1).

    Many attempts fail for unlucky primes, for example:
      * no discrete log exists;
      * the two logs are CRT-incompatible;
      * e is not invertible;
      * p == q.
    These failures surface as exceptions, and the caller (``solve``) retries.

    Returns:
        (e, d, N)

    Raises:
        ValueError: if the two random primes coincide.
        Exceptions raised by the Sage helpers on the other failure modes above.
    """
    bit_len = max(len(bin(mess)), len(bin(enc))) - 2
    # Randomize the prime size a little so repeated attempts explore different N.
    pq_len = bit_len // 2 + random.randint(1, 3)
    # `**`, not `^`: `^` only means "power" under the Sage preparser and is XOR
    # in plain Python, which would silently produce tiny primes.
    p = random_prime(2 ** pq_len)
    q = random_prime(2 ** pq_len)
    if p == q:
        # For N = p^2 the totient is p*(p - 1), not (p - 1)^2, so d would be wrong.
        raise ValueError('p and q must be distinct primes')
    N = p * q
    phi = (p - 1) * (q - 1)

    x1 = solve_dicrete_log(mess, enc, p)
    x2 = solve_dicrete_log(mess, enc, q)

    e = crt([x1, x2], [p - 1, q - 1])
    d = inverse_mod(e, phi)
    return e, d, N

In [7]:
def check_solution(e, d, N, mess, enc):
    """Return True iff (e, d, N) maps ``mess`` to ``enc`` and back again.

    Both modular exponentiations are always evaluated. The result is
    (mess^e mod N == enc) and ((mess^e mod N)^d mod N == mess).
    """
    ciphertext = power_mod(mess, e, N)
    roundtrip = power_mod(ciphertext, d, N)
    encrypts_ok = ciphertext == enc
    return encrypts_ok and roundtrip == mess

In [8]:
def solve(mess, enc, max_tries=None):
    """Retry ``try_to_decode`` until it yields a key that passes ``check_solution``.

    Args:
        mess: plaintext integer.
        enc: ciphertext integer the key must map ``mess`` to.
        max_tries: optional cap on the number of attempts. The default None keeps
            retrying indefinitely (the historical behaviour).

    Returns:
        (e, d, N, attempts), where ``attempts`` counts every try including the
        successful one.

    Raises:
        RuntimeError: if ``max_tries`` attempts all fail.
    """
    attempts = 0
    while max_tries is None or attempts < max_tries:
        attempts += 1
        try:
            e, d, N = try_to_decode(mess, enc)
            if check_solution(e, d, N, mess, enc):
                return e, d, N, attempts
        except Exception:
            # Unlucky primes make the attempt raise; just try again.
            # `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit stop the loop.
            pass
    raise RuntimeError('no valid key found in {} tries'.format(max_tries))

#### Self-test: generate random test messages and check that the solver recovers a key

In [9]:
def test_all(n_trials=100, msg_len=12, verbose=False):
    """Self-test the solver on freshly generated RSA-style encryptions.

    For each of ``n_trials`` random ``msg_len``-character messages:
      1. build a key with N = p*q and a prime public exponent e >= 17 that is
         coprime to (p - 1)(q - 1);
      2. encrypt the message with that key;
      3. ask ``solve`` for *some* key mapping the message to that ciphertext.
    Finally, print the average number of attempts ``solve`` needed.

    Args:
        n_trials: number of random messages to try (default 100).
        msg_len: length of each random message (default 12).
        verbose: if True, print every recovered key.
    """
    counts = []
    for _ in range(n_trials):
        num = str_to_int(rand_str(msg_len))
        bit_len = len(bin(num)) - 2
        # `**`, not `^`: `^` is only exponentiation under the Sage preparser.
        prime_bound = 2 ** (bit_len // 2 + 1)
        p, q = random_prime(prime_bound), random_prime(prime_bound)
        N = p * q
        phi = (p - 1) * (q - 1)

        # Draw prime exponents until one is invertible modulo phi.
        e = random_prime(2 ** bit_len, lbound=17)
        while gcd(e, phi) != 1:
            e = random_prime(2 ** bit_len, lbound=17)

        enc = power_mod(num, e, N)

        found_e, found_d, found_N, cnt = solve(num, enc)
        counts.append(cnt)
        if verbose:
            print('Got solution ({}, {}, {}) in {} tries\n'.format(found_e, found_d, found_N, cnt))

    if not counts:
        print('No trials run.')
        return
    # float() before dividing so the average is not truncated outside the Sage preparser.
    print('Average tries: {}'.format(float(sum(counts)) / len(counts)))

In [10]:
# Run the local self-test and report how long it takes.
%time test_all()

Average tries: 42.22
CPU times: user 56.7 s, sys: 493 ms, total: 57.2 s
Wall time: 57.7 s


### The solution itself

In [11]:
import socket
import time
import re

In [12]:
def test_solution(host='localhost', port=1337, delay=0.1):
    """Answer the challenge server's prompts until it sends the flag.

    Each prompt is expected to contain exactly two integers, each preceded by ": "
    (the message and its ciphertext). For each prompt, run ``solve`` on the two
    integers and reply with the line "e d N". The loop stops when a received chunk
    contains 'shadowctf'.

    Args:
        host, port: address of the challenge server (defaults match the original).
        delay: seconds to wait after connecting and after each reply, giving the
            server time to send its next prompt.

    Raises:
        RuntimeError: if the server closes the connection before sending a flag.
    """
    from contextlib import closing

    # closing() guarantees the socket is released even if solve() or parsing fails.
    with closing(socket.socket()) as sock:
        sock.connect((host, port))
        time.sleep(delay)

        while True:
            data = sock.recv(1024)
            if not data:
                raise RuntimeError('connection closed before a flag was received')
            s = data.decode()
            if 'shadowctf' in s:
                print('Got a flag!')
                break
            mess, enc = map(int, re.findall(r": (\d+)", s))
            e, d, N, _ = solve(mess, enc)
            # Sockets take bytes (str raises TypeError on Python 3).
            # sendall keeps writing until the whole line is sent.
            sock.sendall("{} {} {}\n".format(e, d, N).encode())
            time.sleep(delay)

In [13]:
# Solve the live challenge against the local server and time the full exchange.
%time test_solution()

Got a flag!
CPU times: user 46.7 s, sys: 404 ms, total: 47.1 s
Wall time: 58 s
