In [None]:
import random
import signal
import string
from Crypto.Util.number import *
from hashlib import sha256

from os import urandom

# Placeholder flag that Task.handle prints when the final guess equals the secret.
FLAG = b'ctf{f3k3_fl4g_g0_brr_____}'

# Bit length required of each RSA prime factor (asserted in RSA.check).
P_BITS = 512
# Bit length requested for RSA public exponents; evaluates to 329 for P_BITS = 512.
# NOTE(review): 0.292 presumably refers to the Boneh-Durfee small-exponent bound -- confirm.
E_BITS = int(P_BITS * 2 * 0.292) + 30
# Upper bound on the query counter shared by both query loops in Task.handle.
CNT_MAX = 7

class LCG:
    """Order-6 linear recurrence generator over a prime modulus.

    ``init`` populates ``p`` (modulus), ``a`` (six coefficients), ``b``
    (additive constant) and ``s`` (the six-word state window).
    """

    def __init__(self):
        self.init()

    def next(self):
        """Return the oldest state word and slide the window forward by one."""
        oldest = self.s[0]
        feedback = sum(coeff * word for coeff, word in zip(self.a, self.s))
        # A new list is bound each step; the previous window object is not mutated.
        self.s = self.s[1:] + [(feedback + self.b) % self.p]
        return oldest

    def init(self):
        """Draw a fresh modulus, constant, coefficient vector and state."""
        modulus = getPrime(2 * P_BITS)
        # Redraw until the prime has exactly the requested bit length.
        while modulus.bit_length() != 2 * P_BITS:
            modulus = getPrime(2 * P_BITS)
        self.p = modulus
        self.b = getRandomRange(1, self.p)
        self.a = [getRandomRange(1, self.p) for _ in range(6)]
        self.s = [getRandomRange(1, self.p) for _ in range(6)]

class RSA:
    """RSA key pair whose public exponent is perturbed by words of a shared LCG."""

    def __init__(self, l, p=0, q=0):
        """Create a key bound to generator ``l``.

        When ``p`` is falsy both primes are generated here (any ``q`` passed
        in is ignored); otherwise ``abs(p)`` and ``abs(q)`` are used as given.
        A random E_BITS-bit prime exponent is drawn and ``check`` validates
        the resulting key.
        """
        self.l = l
        if p:
            self.p = abs(p)
            self.q = abs(q)
        else:
            p = getPrime(P_BITS)
            while p.bit_length() != P_BITS:
                p = getPrime(P_BITS)
            self.p = p
            q = getPrime(P_BITS)
            while q.bit_length() != P_BITS:
                q = getPrime(P_BITS)
            self.q = q
        self.e = getPrime(E_BITS)
        self.check()

    def enc(self, m):
        """Return ``m`` raised to the current public exponent modulo ``n``."""
        return pow(m, self.e, self.n)

    def noisy_enc(self, m, r=1):
        """Exponentiate ``m`` with the exponent XOR-ed with the next LCG word.

        If ``r`` is truthy the stored exponent is refreshed first, which
        consumes one additional LCG word.
        """
        if r:
            self.refresh()
        noisy_exponent = self.e ^^ self.l.next()
        return pow(m, noisy_exponent, self.n)

    def dec(self, c):
        """Return ``c`` raised to the private exponent modulo ``n``."""
        return pow(c, self.d, self.n)

    def check(self):
        """Derive ``n``, ``phi`` and ``d`` and assert the key invariants.

        Consumes 20 LCG words for the encrypt/decrypt round-trip checks.
        """
        assert self.p.bit_length() == P_BITS
        assert self.q.bit_length() == P_BITS
        self.n = self.p * self.q
        self.phi = (self.p - 1) * (self.q - 1)
        assert self.e.bit_length() >= E_BITS
        assert self.e < self.phi
        assert GCD(self.e, self.phi) == 1
        self.d = inverse(self.e, self.phi)
        assert self.d.bit_length() >= E_BITS
        for _round in range(20):
            sample = self.l.next() % self.n
            assert self.dec(self.enc(sample)) == sample

    def refresh(self):
        """XOR the exponent with the next LCG word and reduce it modulo 2**E_BITS."""
        mixed = self.e ^^ self.l.next()
        self.e = mixed % (2 ** E_BITS)

class Task:
    """One interactive challenge session driven over stdin/stdout."""

    def __init__(self):
        pass

    def proof_of_work(self):
        """Ask for the 4 hidden leading characters of a random 20-character string.

        Prints the SHA-256 digest and the 16 visible characters, reads one line
        and returns True only if it is 4 characters long and reproduces the digest.
        """
        random.seed(urandom(16))
        picked = []
        for _ in range(20):
            picked.append(random.choice(string.ascii_letters + string.digits))
        proof = ''.join(picked)
        digest = sha256(proof.encode()).hexdigest()
        print(f'sha256(XXXX + {proof[4:]}) == {digest}')
        print('Give me XXXX:')
        x = input().strip()
        if len(x) == 4 and sha256((x + proof[4:]).encode()).hexdigest() == digest:
            return True
        return False

    def timeout_handler(self, signum, frame):
        """Signal-handler signature; always raises TimeoutError."""
        raise TimeoutError()

    def recv_hex(self, l):
        """Read one line from stdin and parse it as hexadecimal.

        ``l`` is accepted but not used by this implementation.
        """
        line = input().strip()
        return int(line, 16)

    def handle(self):
        """Run the two-phase oracle game and print the flag on a correct guess."""
        # Disabled in this copy: the SIGALRM timeouts and the proof-of-work gate.
        #signal.signal(signal.SIGALRM, self.timeout_handler)
        #signal.alarm(60)
        # if not self.proof_of_work():
        #     print('You must pass the PoW!')
        #     return
        #signal.alarm(20)

        print('Give me your RSA key plz.')
        first_prime = self.recv_hex(P_BITS // 4)
        second_prime = self.recv_hex(P_BITS // 4)
        pq = [first_prime, second_prime]

        # Both keys draw from the same generator, alice first.
        lcg = LCG()
        alice = RSA(lcg)
        bob = RSA(lcg, *pq)
        secrets = getRandomNBitInteger(P_BITS // 8)
        secrets_ct = alice.enc(secrets)
        print(f'{alice.e}\n{alice.n}')
        print(f'{lcg.p}\n{lcg.a}\n{lcg.b}\n{lcg.s}')

        # Phase 1: encryption queries; a zero plaintext ends the phase early.
        used = 0
        while used < CNT_MAX:
            print('pt: ', end='')
            pt = self.recv_hex(P_BITS // 2)
            if pt == 0:
                break
            ct = bob.noisy_enc(alice.noisy_enc(pt))
            print('ct:', hex(ct))
            used += 1

        print(secrets_ct)
        secrets_ct = bob.noisy_enc(secrets_ct)
        print('secrets_ct:', hex(secrets_ct))

        # Reseed the generator and rebuild bob from the same primes.
        lcg.init()
        bob = RSA(lcg, *pq)
        print(f'{lcg.p}\n{lcg.a}\n{lcg.b}\n{lcg.s}')

        # Phase 2: decryption queries share the remaining budget of phase 1.
        seen = set()
        while used < CNT_MAX:
            print('ct: ', end='')
            ct = self.recv_hex(P_BITS // 2)
            if ct == 0:
                break
            pt = alice.dec(ct)
            if pt in seen:
                print('You can only decrypt each ciphertext once.')
                return
            seen.add(pt)
            pt = bob.noisy_enc(pt)
            print('pt:', hex(pt))
            used += 1

        guess = self.recv_hex(P_BITS // 4)
        if guess != secrets:
            print('Wrong!')
        else:
            print('Wow, how do you know that?')
            print('Here is the flag:', FLAG)
        return

# Run a single interactive session when this cell/script is executed directly.
if __name__ == "__main__":
    task = Task()
    task.handle()


In [1]:
from binascii import hexlify
from gmpy2 import *
import math
import os
import sys


# 32 bytes from os.urandom, hex-encoded and parsed base-16 into an mpz seed.
SEED = mpz(hexlify(os.urandom(32)).decode(), 16)
# gmpy2 random state that later cells pass to get_smooth_prime.
STATE = random_state(SEED)


def get_prime(state, bits):
    """Return ``next_prime`` of a random ``bits``-bit value with its top bit forced on.

    ``state`` is the gmpy2 random state the value is drawn from.
    """
    top_bit = 1 << (bits - 1)
    candidate = mpz_urandomb(state, bits) | top_bit
    return next_prime(candidate)


def get_smooth_prime(state, bits, smoothness=16):
    """Search for a ``bits``-bit prime ``p`` together with the factors of ``p - 1``.

    ``p - 1`` is assembled as 2 times a run of ``smoothness``-bit primes times
    two final primes whose size is adjusted until the product has exactly
    ``bits`` bits; the search ends when product + 1 passes ``is_prime``.

    Returns ``(p, factors)`` where ``factors`` is the sorted list of the
    primes multiplied together (repeats are kept).
    """
    product = mpz(2)
    factors = [product]
    small_limit = bits - 2 * smoothness
    while product.bit_length() < small_limit:
        small = get_prime(state, smoothness)
        factors.append(small)
        product *= small

    # Split the remaining bit budget between the two closing primes.
    tail_bits = (bits - product.bit_length()) // 2

    while True:
        first = get_prime(state, tail_bits)
        second = get_prime(state, tail_bits)
        candidate = product * first * second
        size = candidate.bit_length()
        if size < bits:
            tail_bits += 1
        elif size > bits:
            tail_bits -= 1
        elif is_prime(candidate + 1):
            break

    factors.extend((first, second))
    return (candidate + 1, sorted(factors))


#P, pfs = get_smooth_prime(STATE, 512)


In [2]:
from tqdm import tqdm

# Bit length the generated primes are checked against in later cells.
P_BITS = 512

def dlog(g, y, p, pfs):
    """Discrete logarithm of ``y`` to base ``g`` in GF(p) by Pohlig-Hellman.

    Parameters:
        g, y: integers (anything ``GF(p)`` accepts).
        p:    prime modulus.
        pfs:  prime factors of ``p - 1``; a prime listed k times is treated
              as the prime power ``prime^k``.

    Returns:
        An integer ``x`` with ``g^x == y`` in GF(p). It is determined modulo
        the multiplicative order of ``g``, which equals ``p - 1`` only when
        ``g`` is a primitive root.

    Raises:
        ValueError: if ``y`` is not a power of ``g`` (raised here or by
        ``discrete_log``).
    """
    from collections import Counter

    F = GF(p)
    base, target = F(g), F(y)
    residues, moduli = [], []
    # Group repeated factors: solving modulo `prime` once per occurrence would
    # only recover x mod prime, not x mod prime^k.
    for prime, exponent in Counter(int(f) for f in pfs).items():
        cofactor = (p - 1) // (prime ** exponent)
        sub_base = base ^ cofactor
        sub_target = target ^ cofactor

        # The projected base has order prime^j for some j <= exponent; find it
        # by repeated powering so no factoring is needed.
        order, probe = 1, sub_base
        while probe != 1:
            probe = probe ^ prime
            order *= prime

        if order == 1:
            # g has no component in this subgroup, so this factor carries no
            # information about x -- but y must then be trivial here as well.
            if sub_target != 1:
                raise ValueError("y is not a power of g")
            continue

        residues.append(discrete_log(sub_target, sub_base, ord=order))
        moduli.append(order)

    if not residues:
        return 0
    return crt(residues, moduli)

def _element_order(g, p, pfs):
    """Order of ``g`` in GF(p)*, given the prime factors ``pfs`` of ``p - 1`` with multiplicity."""
    base = GF(p)(g)
    order = p - 1
    for ell in pfs:
        ell = int(ell)
        # Drop one copy of ell whenever the element is still trivial without it.
        if base ^ (order // ell) == 1:
            order //= ell
    return order


def dlog_composite(g, y, primes, facts):
    """Discrete logarithm of ``y`` to base ``g`` modulo the product of ``primes``.

    Parameters:
        g, y:   integers.
        primes: the prime factors of the modulus.
        facts:  for each prime ``p`` in ``primes``, the prime factors of
                ``p - 1`` (with multiplicity), as expected by ``dlog``.

    Returns:
        ``x`` such that ``g^x == y`` modulo every prime in ``primes``,
        obtained by combining the per-prime logarithms with ``crt``.

    Each per-prime logarithm is only determined modulo the order of ``g`` in
    GF(p)*, so that order -- not ``p - 1`` -- is used as the CRT modulus.
    Using ``p - 1`` gives inconsistent or wrong congruences whenever ``g`` is
    not a primitive root. The moduli are generally not coprime (all orders may
    share factors such as 2); ``crt`` is relied on to merge them and to raise
    ValueError if they are inconsistent.
    """
    residues, moduli = [], []
    # Materialise the pairs so tqdm knows the total and can show real progress.
    for p, pf in tqdm(list(zip(primes, facts))):
        residues.append(dlog(g, y, p, pf))
        moduli.append(_element_order(g, p, pf))
    return crt(residues, moduli)


In [3]:
# Draw the first smooth prime; pf1 holds the factors of p - 1 returned by get_smooth_prime.
while True:
    p, pf1 = get_smooth_prime(STATE, 512)
    p = int(p)
    if p.bit_length() == P_BITS: break
        
# Second smooth prime, generated the same way, with factor list pf2.
while True:
    q, pf2 = get_smooth_prime(STATE, 512)
    q = int(q)
    if q.bit_length() == P_BITS: break
        
# Display both primes as the cell result.
p, q

(10902995298555379214173922615534898283384518483931807056619746678584622512049509408032957357947895499327690145411631247921912901667778501909030057938518863,
 12274909239996328672653125301682106499291367864729384741992307956851485249039354146766582663498689810555011312721394468295425937240245825348732639190926223)

In [4]:
from Crypto.Util.number import *

# Local self-test input: a random exponent of P_BITS // 8 bits over the modulus n = p * q.
n = p * q
secrets = getRandomNBitInteger(P_BITS // 8)

print(secrets)

# Raise the fixed base 3 to the secret; the following cell tries to recover the exponent.
y = pow(3, secrets, n)


13299307709408882842


In [5]:
# Recover the exponent of y to base 3 from the known factor lists of p - 1 and q - 1.
sec = dlog_composite(3, y, [p, q], [pf1, pf2])
print(sec)

2it [00:04,  2.11s/it]

13299307709408882842





In [16]:
# Reduce the recovered exponent modulo phi(n) = (p - 1)(q - 1).
phi = (p - 1) * (q - 1)
sec % phi

16803998651525666372162064344294759209429351018507714974539262720787015393812605036758150911956777421784313067608844950198472564436163888022606141012136983634037993757915092707350638013684471538263234576781601709063840037270774921898842387279600232084495849043684891158454993726397556946149024199430662399242

In [17]:
# Sanity check that the recovered exponent reproduces y.
# NOTE(review): the recorded result is False and the execution counts are out of
# order, so `sec`/`y` were presumably rebound by another cell -- re-run from a fresh kernel.
pow(3, sec, n) == y

False

In [12]:
# Same values as the server cell's constants, redefined for this experiment cell.
P_BITS = 512
# Evaluates to 329 for P_BITS = 512.
E_BITS = int(P_BITS * 2 * 0.292) + 30
CNT_MAX = 7

class LCG:
    """Six-term linear recurrence generator modulo a random prime.

    State lives in ``p`` (modulus), ``a`` (coefficients), ``b`` (offset) and
    ``s`` (the last six words); all four are drawn by ``init``.
    """

    def __init__(self):
        self.init()

    def next(self):
        """Emit the front word of the state and append the next recurrence word."""
        window = self.s
        emitted = window[0]
        feedback = sum(coeff * word for coeff, word in zip(self.a, window))
        appended = (feedback + self.b) % self.p
        self.s = window[1:] + [appended]
        return emitted

    def init(self):
        """Pick a new modulus of exactly 2 * P_BITS bits and random parameters below it."""
        while True:
            candidate = getPrime(2 * P_BITS)
            if candidate.bit_length() != 2 * P_BITS:
                continue
            self.p = candidate
            break
        self.b = getRandomRange(1, self.p)
        self.a = [getRandomRange(1, self.p) for _ in range(6)]
        self.s = [getRandomRange(1, self.p) for _ in range(6)]
        
# Print the low bit of the first 20 outputs of a freshly initialised generator.
lmao = LCG()

for i in range(20):
    print(lmao.next() & 1, end = ' ')

0 1 0 0 1 0 1 1 1 0 1 0 0 0 0 1 1 1 1 0 

In [14]:
import socket

def sendline(sock, data):
    """Send ``data`` followed by a single newline byte in one ``sendall`` call."""
    payload = data + b"\n"
    sock.sendall(payload)

def recvline(sock):
    """Read byte by byte until a newline or end of stream.

    Returns the bytes read, including the trailing newline when one was seen;
    the result is shorter (possibly empty) if the peer closed first.
    """
    line = bytearray()
    while True:
        byte = sock.recv(1)
        if not byte:
            break
        line += byte
        if byte == b"\n":
            break
    return bytes(line)

def recvuntil(sock, delim):
    """Read byte by byte until ``delim`` occurs in the data or the stream ends.

    Returns everything read so far, delimiter included when it was found.
    """
    collected = bytearray()
    while True:
        byte = sock.recv(1)
        if not byte:
            break
        collected += byte
        if delim in collected:
            break
    return bytes(collected)


In [53]:
# Public-exponent bit length, computed with the same formula as the server cell (329 for P_BITS = 512).
E_BITS = int(P_BITS * 2 * 0.292) + 30

class LCG:
    """Client-side replica of the generator, built from explicitly given parameters."""

    def __init__(self, p, a, b, s):
        """Store modulus ``p``, coefficients ``a``, offset ``b`` and state window ``s``."""
        self.p, self.a, self.b, self.s = p, a, b, s

    def next(self):
        """Return the front word of the state and advance the recurrence by one step.

        The state is rebound to a new list, so the list passed to the
        constructor is never modified.
        """
        window = self.s
        feedback = sum(coeff * word for coeff, word in zip(self.a, window))
        self.s = window[1:] + [(feedback + self.b) % self.p]
        return window[0]


def _recv_int(sock):
    """Read one line from ``sock`` and parse it as a base-10 integer."""
    return int(recvline(sock).decode().strip())


def _recv_int_list(sock):
    """Read one line holding a list literal such as ``[1, 2, 3]``."""
    import ast

    # The text comes from the network: literal_eval only accepts literals,
    # whereas eval would execute any expression the peer chooses to send.
    return ast.literal_eval(recvline(sock).decode().strip())


def _recv_lcg(sock):
    """Read the four lines (p, a, b, s) printed by the server and rebuild its LCG."""
    lcg_p = _recv_int(sock)
    lcg_a = _recv_int_list(sock)
    lcg_b = _recv_int(sock)
    lcg_s = _recv_int_list(sock)
    return LCG(lcg_p, lcg_a, lcg_b, lcg_s)


def run(host='0.0.0.0', port=5000):
    """Attempt one exploitation round against the challenge server.

    Parameters:
        host, port: address of the server (defaults match the local test setup).

    Steps performed:
      1. Send two smooth primes p, q as the "bob" key.
      2. Read alice's (e, n); give up (return None) when alice_n > p * q.
      3. Mirror the server LCG, submit plaintext 2 and solve a discrete log
         modulo p * q to obtain an exponent estimate.
      4. End phase 1, read the re-encrypted secret and strip one exponent layer.
      5. Mirror the reseeded LCG and, when its first word is odd, submit the
         encryption of 3.

    The attack is still incomplete: nothing is returned and no final guess is
    sent. The two ``assert`` statements encode preconditions on the LCG words
    and abort the attempt with AssertionError when they do not hold.

    The socket is closed on every exit path.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((host, port))

        recvline(sock)  # greeting line

        while True:
            p, pf1 = get_smooth_prime(STATE, 512)
            p = int(p)
            if p.bit_length() == P_BITS: break

        while True:
            q, pf2 = get_smooth_prime(STATE, 512)
            q = int(q)
            if q.bit_length() == P_BITS: break

        sendline(sock, str(hex(p)[2:]).encode())
        sendline(sock, str(hex(q)[2:]).encode())

        n = p * q

        alice_e = _recv_int(sock)
        alice_n = _recv_int(sock)

        # Values reduced mod alice_n must survive reduction mod n, so retry otherwise.
        if alice_n > n:
            return

        print(':Thonk')

        phi = (p - 1) * (q - 1)

        lcg = _recv_lcg(sock)
        rands = [lcg.next() for _ in range(4)]
        mask = -1

        for idx, rand in enumerate(rands):
            if rand % 2 == 0 and idx > 0:
                break
            recvuntil(sock, b'pt: ')
            sendline(sock, b'2')
            y = int(recvline(sock).decode().strip().split(': ')[1], 16)
            mask = rand

        assert(mask > 1)
        g = pow(2, alice_e ^^ mask, alice_n)
        e1_ = dlog_composite(g, y, [p, q], [pf1, pf2])
        e1 = int(e1_ ^^ rands[0])
        print(e1)

        # A zero plaintext ends the server's encryption phase.
        recvuntil(sock, b'pt: ')
        sendline(sock, b'0')

        recvline(sock)  # bare secrets_ct line printed before the noisy one
        secrets_ct = int(recvline(sock).decode().strip().split(': ')[1], 16)

        e1_ = int(e1 ^^ rands[1])
        print(e1_)
        d1_ = int(pow(e1_, -1, phi))
        secrets_ct = int(pow(secrets_ct, d1_, n))
        print(secrets_ct < alice_n)

        print(':On to the next Round')

        lcg = _recv_lcg(sock)
        rands = [lcg.next() for _ in range(2)]
        assert(rands[0] % 2 != rands[1] % 2)

        enc_three = pow(3, alice_e, alice_n)

        if rands[0] & 1:
            recvuntil(sock, b'ct: ')
            # NOTE(review): the original call had no payload (a TypeError).
            # enc_three is computed just above and otherwise unused, so it is
            # presumably the intended ciphertext -- confirm.
            sendline(sock, hex(int(enc_three))[2:].encode())
            recvline(sock)
    

# Retry indefinitely: run() returns early when alice_n > n, and the loop has no
# exit condition of its own, so it only stops when run() raises.
while True:
    print('Trying...')
    run()

Trying...
Trying...
Trying...
Trying...
Trying...
:Thonk
:Thonk more :0 :0 :0


2it [00:03,  1.85s/it]


71687207728299686893086699259754507009999983848301852034869567633627621575973118141486727970972958594319253873808518804524428754409896434098627489975974995740980903834076940639104075022611609796256833547671978046446336074155639009222202698402948011504591473895633103622956857892058273403392629827848829318319
118097862335938588527995196628675679663538535854331156458324567112049786007918496567289518884297265882767136772353598937937595136396104758159822422821197257994812601480399619990952487813467781522910144063339757031182012079279666594975279607272088167339758367137492885478282671066959615909711830367815610014131
True
:On to the next Round


AssertionError: 

In [43]:
# Scratch check that eval turns a textual list literal into a Python list.
x = eval('[2, 3, 4]')
print(x)

[2, 3, 4]
