In [71]:
from sage.all import *
import struct
import re
import base64
from collections import deque
from Crypto.PublicKey import RSA
from Crypto.Cipher import AES
from Crypto import Random
from Crypto.Cipher import PKCS1_v1_5

In [26]:
# Toggle tqdm progress bars for the long-running loops below. tqdm is only
# imported when enabled, so every use of `tqdm` must be guarded by TQDM_ON.
TQDM_ON = True
if TQDM_ON:
    from tqdm import tqdm

In [30]:
def fast_pairwise_gcd_helper_recursive(nums:list):
    """Batch GCD (recursive reference version).

    For each nums[i], returns gcd(nums[i], product of all other nums[j]).
    A recursive product tree of squares is built, then the total product P
    is reduced down that tree so each leaf holds P mod nums[i]**2; dividing
    that by nums[i] gives (product of the others) mod nums[i].

    Raises:
        ValueError: if fewer than two numbers are given.
    """
    count = len(nums)
    if count < 2:
        raise ValueError

    values = [Integer(x) for x in nums]
    squares = [v**2 for v in values]

    product = Integer(1)
    for v in values:
        product *= v

    # (lo, hi) -> product of squares[lo..hi]
    square_tree = {}

    def build(lo, hi):
        # Fill square_tree for the segment [lo, hi] and return its product.
        if lo == hi:
            node = squares[lo]
        else:
            mid = (lo + hi) >> 1
            node = build(lo, mid) * build(mid + 1, hi)
        square_tree[(lo, hi)] = node
        return node

    build(0, count - 1)

    remainders = [None] * count

    def descend(acc, lo, hi):
        # Reduce acc modulo this segment's product of squares, then recurse.
        reduced = acc % square_tree[(lo, hi)]
        if lo == hi:
            remainders[lo] = reduced
            return
        mid = (lo + hi) >> 1
        descend(reduced, lo, mid)
        descend(reduced, mid + 1, hi)

    # The root is skipped: P mod (P**2) is just P.
    top_mid = (count - 1) >> 1
    descend(product, 0, top_mid)
    descend(product, top_mid + 1, count - 1)

    return [gcd(v, rem // v) for v, rem in zip(values, remainders)]





In [40]:
def fast_pairwise_gcd_helper(nums:list):
    """Batch GCD: for each nums[i], compute gcd(nums[i], product of all other nums[j]).

    A product tree is built over the inputs. The total product P is then
    pushed down a remainder tree modulo the squared node products, so leaf i
    holds P mod nums[i]**2. Since P = nums[i] * Q_i (Q_i = product of the
    others), (P mod nums[i]**2) // nums[i] == Q_i mod nums[i]. Its gcd with
    nums[i] equals gcd(nums[i], Q_i).

    Args:
        nums: sequence of at least two integers (converted to Sage Integers).
            The caller's list is NOT modified.

    Returns:
        list of gcd values, same length and order as nums.

    Raises:
        ValueError: if fewer than two numbers are given.
    """
    origin_len = len(nums)
    if origin_len<2:
        raise ValueError

    progress = tqdm if TQDM_ON else (lambda it: it)

    # Pad with 1s up to the next power of two so the tree is complete and fits a
    # heap layout: node i has children 2*i+1 and 2*i+2, and leaf k sits at index
    # (length - 1) + k. Padding with 1 leaves every product unchanged.
    length = 1 << (origin_len - 1).bit_length()
    # Build a new list; the previous `nums += ...` padded the caller's list in place.
    nums_sage = [Integer(x) for x in nums] + [Integer(1)] * (length - origin_len)
    node_count = 2 * length - 1
    leaf_offset = length - 1

    # Product tree, built bottom-up. Its root is the total product, which avoids
    # the quadratic cost of a linear running product over ever-growing integers.
    product_table = [None] * node_count
    product_table[leaf_offset:] = nums_sage
    for i in progress(range(leaf_offset - 1, -1, -1)):
        product_table[i] = product_table[2*i+1] * product_table[2*i+2]
    total_product = product_table[0]

    # Remainder tree: each node reduces its parent's remainder modulo its own
    # product squared. Squares are formed on the fly rather than stored as a
    # second full tree. Parents always precede children in this index order.
    mod_ni_square_table = [None] * node_count
    mod_ni_square_table[0] = total_product
    for i in progress(range(1, node_count)):
        mod_ni_square_table[i] = mod_ni_square_table[(i-1)>>1] % (product_table[i] ** 2)

    # Leaf k holds total_product mod nums[k]**2, which is divisible by nums[k].
    # Padding leaves are excluded from the result.
    return [
        gcd(n, mod_ni_square_table[leaf_offset + k] // n)
        for k, n in enumerate(nums_sage[:origin_len])
    ]


In [33]:
# Smoke test: 6 and 9 share the factor 3; 11, 17 and 13 are coprime to everything else.
fast_pairwise_gcd_helper([6,9,11,17,13])

100%|██████████| 15/15 [00:00<00:00, 175738.99it/s]
100%|██████████| 14/14 [00:00<00:00, 170698.42it/s]


[3, 3, 1, 1, 1]

In [34]:
# Cross-check the recursive and iterative batch-GCD helpers on a small input
# where 202 and 303 share the factor 101 and everything else is coprime.
sample_nums = [17,13,101*2,101*3,61,79]
recursive_res = fast_pairwise_gcd_helper_recursive(sample_nums)
print(recursive_res)
# Pass a copy: fast_pairwise_gcd_helper's padding step may extend its argument in place.
iterative_res = fast_pairwise_gcd_helper(list(sample_nums))
print(iterative_res)
assert recursive_res == iterative_res == [1, 1, 101, 101, 1, 1]

[1, 1, 101, 101, 1, 1]


100%|██████████| 15/15 [00:00<00:00, 173797.13it/s]
100%|██████████| 14/14 [00:00<00:00, 109145.46it/s]

[1, 1, 101, 101, 1, 1]





In [37]:
# One hex-encoded modulus per line. `lines` is kept so that flagged indices
# can be mapped back to the raw text later.
with open("./moduli.sorted") as fin:
    lines = fin.readlines()
ns = [int(line.strip(), 16) for line in lines]

In [43]:
# Pad the moduli with 1s up to the next power of two so the BFS-ordered trees
# built below are complete binary trees (children of node i at 2*i+1 and
# 2*i+2). Padding with 1 leaves every product unchanged.
origin_len = len(ns)
if origin_len<2:
    raise ValueError
length = 1 << (origin_len - 1).bit_length()
# Build a new list rather than `nums = ns; nums += ...`, which padded `ns` in place.
nums = ns + [1] * (length - origin_len)

nums_sage = [Integer(x) for x in nums]
nums_square = [x**2 for x in nums_sage]

# total_product is taken from the root of the product tree in a later cell;
# a linear running product over all moduli took ~18 minutes on this data.

# total_product = Integer(1)
# for num_s in (tqdm(nums_sage) if TQDM_ON else nums_sage):
#     total_product*=num_s

100%|██████████| 131072/131072 [18:37<00:00, 117.28it/s]


In [45]:
# len(str(total_product))

30738564

In [56]:
# Enumerate segment-tree nodes (lo, hi) in breadth-first order. Because
# `length` is a power of two the tree is complete, so node i's children sit
# at indices 2*i+1 and 2*i+2, and the leaves form the tail of the list.
index_pair_list = []
bfs_queue = deque([(0, length - 1)])
while bfs_queue:
    lo, hi = bfs_queue.popleft()
    index_pair_list.append((lo, hi))
    if lo < hi:
        mid = (lo + hi) >> 1
        bfs_queue.extend(((lo, mid), (mid + 1, hi)))

# Bottom-up product of squares: leaves hold n_i**2, and each inner node holds
# the product of its two children.
sq_product_table = [0] * len(index_pair_list)
node_order = range(len(index_pair_list) - 1, -1, -1)
for i in (tqdm(node_order) if TQDM_ON else node_order):
    lo, hi = index_pair_list[i]
    sq_product_table[i] = (
        nums_square[lo] if lo == hi
        else sq_product_table[2 * i + 1] * sq_product_table[2 * i + 2]
    )


100%|██████████| 262143/262143 [00:06<00:00, 40778.43it/s]  


In [57]:
# Bottom-up product tree over the (padded) moduli, reusing the BFS node
# layout. The root product_table[0] is the product of all moduli.
product_table = [0] * len(index_pair_list)
node_order = range(len(index_pair_list) - 1, -1, -1)
for i in (tqdm(node_order) if TQDM_ON else node_order):
    lo, hi = index_pair_list[i]
    if lo != hi:
        product_table[i] = product_table[2 * i + 1] * product_table[2 * i + 2]
    else:
        product_table[i] = nums_sage[lo]

100%|██████████| 262143/262143 [00:02<00:00, 87525.06it/s]  


In [59]:
# The product of all (padded) moduli is the root of the product tree.
total_product=product_table[0]

In [48]:
# Size check: number of decimal digits in the root of the squared-product tree.
len(str(sq_product_table[0]))

61477127

In [47]:
# Remainder tree: push total_product down from the root, reducing modulo
# each node's product of squares. The leaves end up holding total_product
# mod n_i**2.
mod_ni_square_table = [0] * len(index_pair_list)
mod_ni_square_table[0] = total_product
child_nodes = range(1, len(index_pair_list))
for i in (tqdm(child_nodes) if TQDM_ON else child_nodes):
    parent = (i - 1) >> 1
    mod_ni_square_table[i] = mod_ni_square_table[parent] % sq_product_table[i]

# Leaves occupy the tail of the BFS order. Walk backwards until the first
# inner node, copying each leaf's remainder to its element position.
mod_ni_square = [0] * length
for i in reversed(range(len(index_pair_list))):
    lo, hi = index_pair_list[i]
    if lo < hi:
        break
    mod_ni_square[lo] = mod_ni_square_table[i]

# gcd(n_i, (P mod n_i**2) // n_i) == gcd(n_i, product of the other moduli).
# Entries for the padding past origin_len are dropped.
gcd_helper_res = [
    gcd(nums_sage[k], mod_ni_square[k] // nums_sage[k]) for k in range(length)
][:origin_len]

100%|██████████| 262142/262142 [00:27<00:00, 9627.67it/s]  


In [50]:
# Indices of moduli whose gcd with the product of all other moduli exceeds 1,
# i.e. moduli that share a factor with at least one other modulus.
have_gcd_lines_i = [i for i, r in enumerate(gcd_helper_res) if r > 1]

In [51]:
# Display the indices of the flagged moduli.
have_gcd_lines_i

[71679, 81922]

In [52]:
# Unpack the flagged indices. Assumes exactly two were found; unpacking raises ValueError otherwise.
a,b = have_gcd_lines_i

In [63]:
# Re-read the two flagged moduli from the raw hex lines (same order as `ns`).
n_a,n_b = int(lines[a],16),int(lines[b],16)

In [66]:
# Shared factor of the two moduli, converted to a plain Python int.
p = int(gcd(n_a,n_b))

In [68]:
# Cofactor of each modulus after dividing out the shared factor p
# (the other RSA prime, assuming each modulus is a product of two primes).
q_a = n_a//p
q_b = n_b//p

In [69]:
def read_int(byte_arr:bytes):
    """Parse a length-prefixed big-endian integer from the front of byte_arr.

    Layout: a 4-byte little-endian length L, then L bytes holding the
    integer in big-endian order.

    Returns:
        (value, remaining_bytes)

    Raises:
        ValueError: if byte_arr is not bytes or is too short.
    """
    if not isinstance(byte_arr, bytes) or len(byte_arr) < 4:
        raise ValueError
    payload_end = 4 + int.from_bytes(byte_arr[:4], "little")
    if len(byte_arr) < payload_end:
        raise ValueError
    value = int.from_bytes(byte_arr[4:payload_end], "big")
    return value, bytes(byte_arr[payload_end:])

def fast_pow_with_mod(a:int,b:int,moder:int)->int:
    """Compute a**b mod moder by right-to-left binary exponentiation.

    Args:
        a: base.
        b: exponent; must be non-negative.
        moder: modulus.

    Returns:
        a**b reduced modulo moder (matches pow(a, b, moder); 0 when moder == 1).

    Raises:
        ValueError: if b is negative (the loop would otherwise silently return 1).
    """
    if b < 0:
        raise ValueError("exponent must be non-negative")
    # Start from 1 % moder so that b == 0 still yields a reduced result
    # (e.g. 0 for moder == 1), as pow(a, 0, moder) does.
    res = 1 % moder
    a %= moder
    while b > 0:
        if b & 1:
            res = (res * a) % moder
        a = (a * a) % moder
        b >>= 1
    return res

def extended_gcd(a, b):
    """Extended Euclidean algorithm.

    Returns (g, x, y) with a*x + b*y == g; for non-negative a, b, g == gcd(a, b).

    Iterative, so it does not exceed Python's recursion limit on RSA-sized
    inputs (a 2048-bit pair needs on the order of a thousand Euclid steps).
    It produces exactly the same (x, y) as the previous recursive version:
    both multiply the same sequence of step matrices, just in opposite order.
    """
    # Invariant: current a == xa*a0 + ya*b0 and current b == xb*a0 + yb*b0,
    # where a0, b0 are the original arguments.
    xa, ya = 1, 0
    xb, yb = 0, 1
    while a != 0:
        q = b // a
        a, b = b % a, a
        xa, ya, xb, yb = xb - q * xa, yb - q * ya, xa, ya
    return (b, xb, yb)
    
def extended_euclidian(a,b):
    """Extended Euclidean algorithm for positive integers.

    Accepts any integer-like value implementing ``__index__`` (Python int or
    bool, and also e.g. Sage or NumPy integers). Values are converted to
    Python int first, so callers no longer need explicit int(...) casts.

    Returns:
        (g, x, y) with g == gcd(a, b) and a*x + b*y == g.

    Raises:
        ValueError: if a or b is not integer-like, or is not positive.
    """
    import operator
    try:
        a, b = operator.index(a), operator.index(b)
    except TypeError:
        raise ValueError("arguments must be integers") from None
    if a<=0 or b<=0:
        raise ValueError
    # Invariant: num1 == ka1*a + kb1*b and num2 == ka2*a + kb2*b.
    num1,num2 = a,b
    ka1,kb1,ka2,kb2=1,0,0,1
    while num2>0:
        floor_div_res = num1//num2
        num3 = num1%num2
        ka3,kb3=ka1-floor_div_res*ka2,kb1-floor_div_res*kb2
        num1,ka1,kb1 = num2,ka2,kb2
        num2,ka2,kb2 = num3,ka3,kb3
    return num1,ka1,kb1

def mod_inverse(a, m):
    """Return the inverse of a modulo m, normalized into m's residue range.

    Raises:
        ValueError: if gcd(a, m) != 1 (propagates extended_euclidian's
        ValueError for invalid arguments).
    """
    divisor, coeff_a, _ = extended_euclidian(a, m)
    if divisor == 1:
        return coeff_a % m
    raise ValueError(f"Modular inverse does not exist, a and m have gcd:{divisor}")
    
def mod_inverse_of_prime(a,m_p):
    """Inverse of a modulo m_p using Fermat's little theorem: a**(m_p-2) mod m_p.

    m_p is assumed (not checked) to be prime.

    Raises:
        ValueError: if a is a multiple of m_p.
    """
    residue = a % m_p
    if residue == 0:
        raise ValueError
    return fast_pow_with_mod(residue, m_p - 2, m_p)

In [72]:
def build_rsa_key(p,q,e):
    """Build an RSA private key object from prime factors p, q and public exponent e.

    p, q and e are converted to Python ints, so Sage Integers are accepted.
    d is computed as e^-1 mod (p-1)(q-1). The factors are also passed to
    RSA.construct, so it does not have to recover them from (n, e, d).

    Raises:
        ValueError: if e is not invertible modulo (p-1)(q-1) (from mod_inverse),
        or if RSA.construct rejects the components.
    """
    p, q, e = int(p), int(q), int(e)
    n = p*q
    phi_n = (p-1)*(q-1)
    d = mod_inverse(e,phi_n)
    return RSA.construct((n, e, d, p, q))

In [73]:
def bits_to_mpi(s):
    """Prefix byte string s with its length as a 4-byte little-endian unsigned int.

    This is the framing that read_bytes / read_int parse. The explicit '<'
    fixes the byte order and size. A plain 'I' uses the host's native byte
    order and native size, which would break that format on big-endian hosts.

    Raises:
        struct.error: if len(s) does not fit in 32 bits.
    """
    return struct.pack('<I', len(s)) + s

bits_to_mpi(b"abc")

b'\x03\x00\x00\x00abc'

In [78]:
def read_bytes(b:bytes):
    """Split a length-prefixed byte string off the front of b.

    Layout: a 4-byte little-endian length L followed by L payload bytes.

    Returns:
        (payload, remaining_bytes), both as bytes.

    Raises:
        AssertionError: if b is shorter than the header or the declared payload.
    """
    assert len(b)>=4
    octets = list(b)
    size = sum(octet << (8 * k) for k, octet in enumerate(octets[:4]))
    end = 4 + size
    assert len(b)>=end
    return bytes(octets[4:end]), bytes(octets[end:])


In [79]:
encrypt_header = '-----BEGIN PRETTY BAD ENCRYPTED MESSAGE-----\n'
encrypt_footer = '-----END PRETTY BAD ENCRYPTED MESSAGE-----\n'

with open("./hw6.pdf.enc.asc") as fin:
    enc_string_all = fin.read()

# The armored file must be framed exactly by the header and footer lines.
assert enc_string_all.startswith(encrypt_header)
assert enc_string_all.endswith(encrypt_footer)

# Strip the armor and decode the base64 body.
body_start, body_stop = len(encrypt_header), -len(encrypt_footer)
enc_msg_b64_encoded = enc_string_all[body_start:body_stop]
enc_msg_bytes = base64.b64decode(enc_msg_b64_encoded)

# Binary layout: a 4-byte little-endian length plus the RSA-encrypted AES key,
# then one AES block of IV, then the AES ciphertext.
aeskey_rsa_encrypted, rest = read_bytes(enc_msg_bytes)
iv, origin_msg_aes_encrypted = rest[:AES.block_size], rest[AES.block_size:]

In [81]:
# Sanity check: the CBC ciphertext length should be a multiple of the AES block size (expect 0).
len(origin_msg_aes_encrypted)%AES.block_size

0

In [84]:
def dec(p,q,e,out_path="./hw6.pdf"):
    """Decrypt the parsed message with the RSA key (p, q, e) and write the plaintext.

    Reads the module-level aeskey_rsa_encrypted, iv and origin_msg_aes_encrypted.
    The AES key is recovered with RSA PKCS#1 v1.5, and the body is decrypted
    with AES-CBC. The trailing padding, whose length is given by the last
    byte, is then stripped.

    Args:
        p, q: prime factors of the RSA modulus.
        e: public exponent.
        out_path: where to write the plaintext (defaults to the original path).

    Returns:
        The decrypted plaintext bytes, which are also written to out_path.

    Raises:
        ValueError: if RSA decryption of the AES key fails or the padding length is invalid.
    """
    rsa_key = build_rsa_key(p,q,e)
    rsa_cipher = PKCS1_v1_5.new(rsa_key)
    aes_key = rsa_cipher.decrypt(aeskey_rsa_encrypted,sentinel=None)
    if aes_key is None:
        # decrypt() returns the sentinel instead of raising when decryption fails.
        raise ValueError("RSA decryption of the AES key failed (wrong key?)")
    aes_cipher = AES.new(aes_key,AES.MODE_CBC,iv)
    padded_msg_bytes = aes_cipher.decrypt(origin_msg_aes_encrypted)
    # The last byte gives the padding length, which must be 1..block_size.
    # Without this check a final byte of 0 would make `[:-0]` return b"" and
    # silently write an empty file.
    pad_len = padded_msg_bytes[-1] if padded_msg_bytes else 0
    if not 1 <= pad_len <= AES.block_size:
        raise ValueError(f"invalid padding length: {pad_len}")
    original_msg_bytes = padded_msg_bytes[:-pad_len]
    with open(out_path,"wb") as fout:
        fout.write(original_msg_bytes)
    return original_msg_bytes

In [85]:
# Public exponent. 65537 is assumed here; it is not read from any key file in this notebook.
e=65537
# Decrypt using the factors recovered for the first flagged modulus.
dec(p,q_a,e)