# 题目生成代码
### Note: 请在 SageMath 环境下运行

In [1]:
NO_OF_RSA_KEYS = 50


def generate_prime_candidate(length):
    """ Generate an odd integer randomly
        Args:
            length -- int -- the length of the number to generate, in bits
        return a integer
    """
    # generate random bits
    p = getrandbits(length)
    # apply a mask to set MSB and LSB to 1 (making sure that first bit is set and also is odd)
    p |= (1 << length - 1) | 1
    return p

def generate_prime(bits):
    p = generate_prime_candidate(bits)
    return next_prime(p)

def generate_pq(bits, known_prime=None):
    if known_prime:
        return known_prime, generate_prime(bits)
    return generate_prime(bits), generate_prime(bits)

def generate_rsa_params(bits, known_prime=None):
    e = 65537
    p,q = generate_pq(bits, known_prime)
    n =p*q
    phi = (p-1)*(q-1)
    assert gcd(phi, e) == 1
    d = inverse_mod(e, phi)
    return n, e, p, q, d

def rsa_encrypt(pt, e, n):
    return pow(pt, e, n)

def rsa_decrypt(ct, d, n):
    return pow(ct, d, n)

In [2]:
rsa_params = [generate_rsa_params(1024) for _ in range(NO_OF_RSA_KEYS-1)]

In [3]:
# Here is where the actual vuln happens - 2 RSA keys share the same prime factor
p = generate_prime(1024)
rsa1 = generate_rsa_params(1024, p)
rsa2 = generate_rsa_params(1024, p)
rsa_params.append(rsa1)
rsa_params.append(rsa2)
shuffle(rsa_params)
assert gcd(rsa1[0], rsa2[0]) == p

In [4]:
with open("flag.txt", "r") as f:
    flag_str = f.read()

In [6]:
# Use one of those keys to encrypt flag
n, e, p, q, d = rsa1
import binascii
#flag_bytes = int(flag_str.encode("hex"), 16)
tmp=binascii.b2a_hex(bytes(flag_str,encoding='utf8'))
flag_bytes = int(tmp,16)
ct = rsa_encrypt(flag_bytes, e, n)

In [7]:
chall = """INTERCEPTED CIPHERTEXT
==================

ct = {0}


INTERCEPTED KEYS:
(Format: keys = [(n, e), (n, e), .... , (n, e)])
==================

keys = [{1}]


""".format(
    ct,
    ",".join(["[{0}, {1}]".format(key[0], key[1]) for key in rsa_params])
) 
print(chall)
# chall = """n = {0}

# e = {1}

# ct = {2}
# """.format(n, e, ct)

INTERCEPTED CIPHERTEXT

ct = 20118276453892579066339909282308048126097781178508942228028710390784826579726610411164722205621918123106446232145614779743477644810889368889582420924275310461925446264009790560420197830565358237564885261035298063523750318192795166462539249084090792002991859874734573441305387439896199568295963663776338229127860698832416609801326562119147629560567153964082905367248727717187996048111518887310625421252052695952078950428256474331819084796406544927117899492535530208975071753021861181520708690503120163974585952656364809126172174421737541000956846337361339632433142706673737147453793498724462625874577301643912632478815


INTERCEPTED KEYS:
(Format: keys = [(n, e), (n, e), .... , (n, e)])

keys = [[1263481717179427124786892103530611891998562898889140515209410805348594709397430564918722844446877508504004403761371389804291394570464269103205741815327321164765776032513740537847210529618351220307025233104989312231623771065108675145729346282360122876253418084078069634325576

In [8]:
with open("chall.txt", "w") as f:
    f.write(chall)

# 示例解答
最简单的办法是两两测试寻找公因数，但密钥过多得情况下很麻烦，采用 https://factorable.net/weakkeys12.conference.pdf 论文中提到的简化的方法如下：

In [9]:
import copy

def product_level(tree_level):
    level = copy.deepcopy(tree_level)
    if len(tree_level) % 2 != 0:
        level.append(1)
    return [level[i]*level[i+1] for i in range(0,len(level),2)]

def build_product_tree(numbers):
    tree_level = numbers
    product_tree = [tree_level]
    while len(tree_level) != 1:
        tree_level = product_level(tree_level)
        product_tree.append(tree_level)
    return product_tree

def build_remainder_tree(product_tree):
    remainder_tree = []
    level = product_tree[0]
    for i in range(1, len(product_tree)):
        next_level = product_tree[i]
        remainder_level = []
        for j in range(len(next_level)):
            sq = pow(next_level[j],2)
            product_parent = level[j//2]
            remainder_level.append(product_parent % sq)
        remainder_tree.append(remainder_level)
        level = remainder_level
    return remainder_tree

def find_shared_factors(prod_tree, rem_tree):
    results = []
    for idx, remainder in enumerate(rem_tree[-1]):
        moduli = prod_tree[0][idx]
        quotient = remainder//moduli
        results.append(gcd(quotient, moduli))
    return results

product_tree = build_product_tree(list(map(lambda x: x[0], rsa_params)))

remainder_tree = build_remainder_tree(product_tree[::-1])

results = find_shared_factors(product_tree,remainder_tree)

In [10]:
# When the element in the result list is not 1 then that means that a shared factor has been found.
# The index of that factor is the corresponding index of the key in the rsa_params list.
results

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 174207074869769245216828780377687199097860593989360776354825328855344958796239782647328126693132149691015425442580527819234175345165191730237185209528449591033357649321832992605031987197922077499429956601055199631671808946686920483830948483133395993133453391093912230365911667888905523690208917071166391354383,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 174207074869769245216828780377687199097860593989360776354825328855344958796239782647328126693132149691015425442580527819234175345165191730237185209528449591033357649321832992605031987197922077499429956601055199631671808946686920483830948483133395993133453391093912230365911667888905523690208917071166391354383,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [11]:
common_factor = max(results)
indices = [idx for idx, f in enumerate(results) if f == common_factor]

In [28]:
# Attempt to decrypt flag with cracked keys

for idx in indices:
    print("\n[+] Attempting decryption with key [{0}]".format(idx))
    n, e = rsa_params[idx][0], rsa_params[idx][1]
    p = common_factor
    q = n/p
    phi = (p-1)*(q-1)
    assert gcd(phi, e) == 1
    d = inverse_mod(e, phi)
    pt = rsa_decrypt(ct, d, n)
    print(hex(Integer(pt)))
    #output = binascii.hexlify(bytes(hex(Integer(pt))))
    al = []
    a=hex(Integer(pt))[2:]
    for i in range(0, len(a), 2):
        b = a[i:i+2]
        al.append(chr(int(b, 16)))
    print (''.join(al))


[+] Attempting decryption with key [23]
0x1785a4d09125d280b186327781cf397a97d5b39776db6bf75d69759fa8ad3c0c42e9584a8975f7ba314b5848a4844ea016839860836ab2b21ca94a0e42df3515db105149c29e3cb98a8e39f06acd6c5f0d0c5f2f6391b2659719796f9cc6c750fa14a05dee0d1eb1720d09dd7d4952d0cd8baa7e635f14282c0a3a4428a717ce87ea419414750958daad6b18ac6b6a3af7ded031ffd9ce0a630dbb8643eca7d329339b4d5a6b3bf3002914d9cdb823d30a2c5ac964183d76a677911add0c2e7b42a5163e9bf519fa306565f5c5c3f108d93d7e4fb057ec3ffa1971c0fe82d4e6b06df4655f813cdc38a52f5a5f538dba0cd4f379d712f9f6e369197dbc8a54bb
¤Ð%Ò±2wÏ9zÕ³vÛk÷]iu¨­<BéXJu÷º1KXH¤N `j²²©JBß5ÛQIÂ<¹9ðjÍl__/c²eyoÆÇPú ]î±r	Ý}IRÐÍª~c_(,
:D(§ÎêAu	XÚ­k¬kj:÷ÞÐ1ÿÙÎ
c»Cì§Ó)3MZk;ó )ÙÍ¸#Ó
,ZÉd=v¦wÝ.{B¥>õú0eeõÅÃñÙ=~O°Wì?úqÀþÔæ°môe_<Ü8¥/Z_SºÔóy×ùöãi}¼T»

[+] Attempting decryption with key [35]
0x6c7a756374667b7875656a69335f7a68336e5f7368754031217d
lzuctf{xueji3_zh3n_shu@1!}
