In [19]:
from pwn import *
from string import printable, ascii_lowercase, ascii_uppercase, digits
HOST = 'chal-careless-padding.chal.hitconctf.com'
HOST = 'localhost'
PORT = '11111'

DICTIONARY = ascii_lowercase + digits + '_!?' + ascii_uppercase

r = remote(HOST, PORT)
print(r.recvuntil(b's your encrypted key: '))
key = bytes.fromhex(r.recvline().decode().strip())
r.recv()
IV = key[:16]
CT = key[16:]

def get_blocks(data: bytes, bs: int = 16) -> list:
    return [data[i:i+bs] for i in range(0, len(data), bs)]

def xor(a: bytes, b: bytes, repeat = True) -> bytes:
    '''
        to get the key to roll over, the module operator
        is used over the length of the key, 
        so string_pos MOD key_length
        ex: 0%3 = 0, 1%3 = 1, 2%3 = 2, 3%3 = 0, etc..
    '''
    xored = []
    if repeat:
        for char_pos, c in enumerate(a):
            xored.append(c ^ b[char_pos%len(b)])
    else:
        if len(a)<len(b):
            b, a = a, b
        for char_pos in range(len(b)):
            xored.append(a[char_pos] ^ b[char_pos])
        for char_pos in range(len(b), len(a)):
            xored.append(a[char_pos])
    return bytes(xored)


def bit_flip(previous_block: bytes, guess: int, flip_pos: int, padding_byte: int) -> bytes:
    flipped_byte = bytes([previous_block[flip_pos] ^ guess ^ padding_byte])
    return previous_block[:flip_pos] + flipped_byte + previous_block[flip_pos+1:]

def send_msg(r, ciphertext: bytes) -> str:
    payload = ciphertext.hex()
    r.sendline(payload)
    response = r.recvline().decode().strip()
    print(response)
    response = r.recv().decode().strip()
    print(response)
    return response


def get_padding_size(r, ciphertext: bytes):
    ct_blocks = get_blocks(ciphertext)[-2:]
    # ct_blocks[0] = b'R\xd5\x0bx\xcb\xec\x1b\xf4\xb9k\x84\x07\xa4_\xa3\xb9'
    for byte in range(16):
        # print(f"position: {byte}")
        ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=byte, padding_byte=1)
        iv_to_send = ct_blocks[0]
        ct_to_send = b''.join(ct_blocks[1:])
        response = send_msg(r, b''.join([iv_to_send, ct_to_send]))
        # print(response)
        if 'put that weirdo in me' in response:
            return 16-byte-1

def get_padding_byte(r, ciphertext: bytes, padding_size):
    for byte in printable:
        byte = ord(byte)
        # print(f'guessing byte: {bytes([byte])}')
        ct_blocks = get_blocks(ciphertext)[-2:]
        ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-padding_size, padding_byte=15-padding_size-2)
        ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-padding_size-2, padding_byte=byte)
        ct_to_send = b''.join(ct_blocks[1:])
        iv = ct_blocks[0]
        response = send_msg(r, b''.join([iv, ct_to_send]))
        if 'Bad key' in response:
            return byte

def recover_block(r, ciphertext: bytes, padding_size, padding_byte):
    known = b''
    for position in range(3):
        for byte in printable:
            byte = ord(byte)
            # print(f'guessing byte: {bytes([byte])}')
            ct_blocks = get_blocks(ciphertext)[-2:]
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-padding_size, padding_byte=position)
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-padding_size-2, padding_byte=padding_byte)
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=byte, flip_pos=position, padding_byte=padding_byte)

            ct_to_send = b''.join(ct_blocks[1:])
            iv = ct_blocks[0]
            response = send_msg(r, b''.join([iv, ct_to_send]))
            if 'Bad key' in response:
                known += bytes([byte])
                # print(f'known: {known}')
                # print('_'*50)
                break
            else:
                # print(response)
                pass
    return known
        

# padding_size = get_padding_size(r, CT)
# # padding_size = 10

# print(f'padding size: {padding_size}')

# padding_byte = get_padding_byte(r, CT, padding_size)
# padding_byte = b'd'[0]^1
# print(f'padding byte: {chr(padding_byte)}')

# # get first block
# known_lastblock = b'}"}'
# lastblock = recover_block(r, CT, padding_size, padding_byte)
# print("lastblock =", lastblock + known_lastblock)

# get candidates for block before last 15th byte

def confirm_15th_byte_from_padding_byte(previous_block, current_block, padding_byte, candidates):
    for candidate in candidates:
        for character in printable:
            previous_block_new = bit_flip(previous_block, guess=ord(candidate), flip_pos=14, padding_byte=ord(character))
            previous_block_new = bit_flip(previous_block_new, guess=padding_byte, flip_pos=13, padding_byte=14)
            ct_to_send = current_block
            iv = previous_block_new
            response = send_msg(r, b''.join([iv, ct_to_send]))
            if 'Bad key' in response:
                print(f"found 15th byte: {candidate}, byte: {character}")
                return candidate


def get_block_offending_byte(iv, ciphertexts):
    for byte in range(1,256):
        iv_new = iv
        for i in range(16):
            iv_new = bit_flip(iv_new, guess=byte, flip_pos=i, padding_byte=0)
        ct_to_send = ciphertexts
        response = send_msg(r, b''.join([iv_new, ct_to_send]))
        if 'Bad key' in response:
            print(f"found offending byte: {bytes([byte])}")
            print('_'*50)
            return byte
        
def get_block_15th_byte_candidates(iv, ciphertexts, offending_byte, known_plaintext):
    candidates = []
    for i in range(16):
        iv_new = bit_flip(iv, guess=offending_byte, flip_pos=i, padding_byte=0)
        ct_to_send = ciphertexts
        response = send_msg(r, b''.join([iv_new, ct_to_send]))
        if 'Bad key' in response:
            print(f"found position: {i}")
            print('_'*50)
            last_byte = known_plaintext[i] ^ offending_byte
            print(f"last byte: {bytes([last_byte])}")
            for character in printable:
                    if ord(character) % 16 == i:
                        candidates.append(character)
            return candidates, last_byte

        
def get_block_15th_byte_candidates_old(iv, ciphertexts):
    candidates = []
    for i in range(16):
        for byte in range(256):
            # print(f"trying byte: {bytes([byte])}")
            iv_new = bit_flip(iv, guess=byte, flip_pos=i, padding_byte=0)
            ct_to_send = ciphertexts
            response = send_msg(r, b''.join([iv_new, ct_to_send]))
            if 'Bad key' in response:
                print(f"found position: {i}")
                print('_'*50)
                for character in printable:
                    if ord(character) % 16 == i:
                        candidates.append(character)
                return candidates
            
# known_plaintext = b'{"key": "hitcon{'
known_block = b'p4dd1ng_w0n7_r4w'
known_block = b'this_is_a_really'
ct_blocks = get_blocks(CT)
test_ct = b''.join(ct_blocks)
# offending_byte = get_block_offending_byte(test_ct[16:32], test_ct[32:48])
# candidates, last_byte = get_block_15th_byte_candidates(test_ct[16:32], test_ct[32:48], offending_byte, known_block)
# print(f'candidates: {candidates}')
# print(f'last byte: {bytes([last_byte])}')
# print(f"offending_byte: {bytes([offending_byte])}")

# candidates: ['d', 't', '4']
# last byte: b'w'
# offending_byte: b'\x0e'
# w
# ['d', 't', '4']
# r , s
# _ 
# 6, 7
# n, o
# 1, 0
# v, w
# f, g
# n, o
# 0, 1
# d, e
# d, e
# 4, 5

def get_chars_for_pos(pos: int):
    candidates = []
    for character in printable:
        if ord(character)%16 == pos:
            candidates.append(character)
    return candidates

def get_character(iv, ciphertexts, known_plaintext, pos):
    iv_new_orig = iv
    candidates = []
    iv_new_orig = bit_flip(iv_new_orig, guess=ord(known_plaintext[-1]), flip_pos=15, padding_byte=b'{'[0])
    chars_for_pos = get_chars_for_pos(pos)
    for character in printable:
        print(f"trying character: {character}")
        iv_new = bit_flip(iv_new_orig, guess=ord(known_plaintext[-2]), flip_pos=14, padding_byte=ord(chars_for_pos[0]))
        iv_new = bit_flip(iv_new, guess=ord(character), flip_pos=pos, padding_byte=b'{'[0])
        ct_to_send = ciphertexts
        iv = iv_new
        response = send_msg(r, b''.join([iv, ct_to_send]))
        if 'Bad key' in response:
            candidates.append(character)
            if len(candidates) == 2:
                return candidates
    return candidates
# candidates: ['n', 'N', '.', '>', '^', '~']
# 0, 1
# known_block = 'p4dd1ng_w0n7_st0'
known_block = 'ly'
print(f"known_block: {known_block}")
block = 0
if len(known_block) != 16:
    candidates = get_character(test_ct[16*(block):16*(block+1)], test_ct[16*(block+1):16*(block+2)], known_block, 15-len(known_block))
    print(f'candidates: {candidates}')
else:
    print('done')

[x] Opening connection to localhost on port 11111
[x] Opening connection to localhost on port 11111: Trying ::1
[+] Opening connection to localhost on port 11111: Done


b'b\'{"key": "hitcon{this_is_a_really_long_flag_12that_is_very_hard_to_guess_for_testing}"}\'\nb\'{"key": "hitcon{this_is_a_really_long_flag_12that_is_very_hard_to_guess_for_testing}"}eeeeeeeeee\'\n\n*********************************************************\nYou are put into the careless prison and trying to escape.\nThanksfully, someone forged a key for you, but seems like it\'s encrypted... \nFortunately they also leave you a copied (and apparently alive) prison door.\nThe replica pairs with this encrypted key. Wait, how are this suppose to help?\nAnyway, here\'s your encrypted key: '
known_block: ly
trying character: 0
[b"this_is_a_rea'm{"]
message_len: 15, X: m, Y: {
_Y location: -3
_Y: b"'"
_Y^1: b'&'
Y: b'{'
Don't put that weirdo in me!
Try unlock:
trying character: 1
[b'this_is_a_rea&m{']
message_len: 15, X: m, Y: {
_Y location: -3
_Y: b'&'
_Y^1: b"'"
Y: b'{'
Don't put that weirdo in me!
Try unlock:
trying character: 2
[b'this_is_a_rea%m{']
message_len: 15, X: m, Y: {
_Y locatio

  r.sendline(payload)


In [None]:
get_chars_for_pos(13)

['m', 'M', '-', '=', ']', '}', '\r']

In [None]:
#!/usr/local/bin/python
import random
import os
from secret import flag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import json

N = 16

# 0 -> 0, 1~N -> 1, (N+1)~(2N) -> 2 ...
def count_blocks(length):
    block_count = (length-1) // N + 1
    return block_count

def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message)-1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y

def my_padding(message):
    print(message)
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len =  block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count-2)*N+(X%N)]
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    if X==Y:
        Y = Y^1
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    padded = message.ljust(result_len, bytes([Y]))
    return padded

def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    print(f"pos: {(block_count-2)*N+(X%N)}, Y_ = {bytes([_Y])}, Y = {bytes([Y])}")
    if (Y != _Y and Y != _Y^1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]
def get_blocks(data: bytes, bs: int = 16) -> list:
    return [data[i:i+bs] for i in range(0, len(data), bs)]

# message = b'a'*16*4 +b'_bcdefghijklm3zzzzzzzzzz' + b'}"}'
# X = message[-1]
# padded = my_padding(message)
# message_len = len(message)
# block_count = count_blocks(message_len)
# print(f'message_len: {message_len}, block_count: {block_count}')
# print((block_count-2)*16+(X%16))
# chr(message[(block_count-2)*16+(X%16)])


In [None]:
import json
def bit_flip(previous_block: bytes, guess: int, flip_pos: int, padding_byte: int) -> bytes:
    flipped_byte = bytes([previous_block[flip_pos] ^ guess ^ padding_byte])
    print(f"previous_block[flip_pos]: {bytes([previous_block[flip_pos]])}")
    return previous_block[:flip_pos] + flipped_byte + previous_block[flip_pos+1:]

k = b'YELLOW SUBMARINE'
iv = b'\x01\x02\x03\x04' * 4
# iv = os.urandom(16)
cipher = AES.new(k, AES.MODE_CBC, iv) 
message = 'hitcon{' + 'a426'*4*3 + '_-@0123486759Ab2DeF}'
message = json.dumps({'key':message}).encode()
padded = my_padding(message)
ciphertext = cipher.encrypt(padded)
print(get_blocks(padded))

# ct_blocks[0] = b'R\xd5\x0bx\xcb\xec\x1b\xf4\xb9k\x84\x07\xa4_\xa3\xb9'
def detect_padding_byte():
    for byte in range(1,256):
        print(f'guessing byte: {bytes([byte])}')
        ct_blocks = get_blocks(ciphertext)[-2:]
        ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-6, padding_byte=byte)
        # ct_blocks[0] = bytes([96]) + ct_blocks[0][1:]
        # ct_blocks[0] = bit_flip(ct_blocks[0], guess=0, flip_pos=0, padding_byte=1)
        ct_to_send = b''.join(ct_blocks[1:])
        iv = ct_blocks[0]
        cipher = AES.new(k, AES.MODE_CBC, iv)
        decrypted = cipher.decrypt(ct_to_send)
        print(f'decrypted: {get_blocks(decrypted)}')
        try:
            unpadded = my_unpad(decrypted)
            print(f'unpadded: {get_blocks(unpadded)}')
            print('_'*50)
        except:
            print('incorrect padding')
            print('_'*50)
            

# detect_padding_byte()
PADDING_SIZE = 10
PADDING_BYTE = b'A'[0]
# try to get padding byte

# known = b''
# print(f'guessing byte: {bytes([byte])}')
ct_blocks = get_blocks(ciphertext)[-3:-1]
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-PADDING_SIZE, padding_byte=position)
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-PADDING_SIZE-2, padding_byte=PADDING_BYTE)
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'e'[0], flip_pos=0, padding_byte=PADDING_BYTE)
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'{'[0], flip_pos=5, padding_byte=b'\x00'[0])
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'e'[0], flip_pos=-15, padding_byte=b'~'[0])
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=13, flip_pos=14, padding_byte=b'b'[0])
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=1, flip_pos=14, padding_byte=b'b'[0])
# ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'_'[0], flip_pos=2, padding_byte=0)
# iv = bit_flip(iv, guess=b'_'[0], flip_pos=15, padding_byte=b'`'[0])
# iv = bit_flip(iv, guess=b'"'[0], flip_pos=1, padding_byte=b'_'[0])

# ct_blocks[-2] = b'0234h678901aac\x2c!'
print(f'ct_blocks: {ct_blocks}')
ct_to_send = b''.join(ct_blocks[1:])
iv = ct_blocks[0]
cipher = AES.new(k, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(ct_to_send)
print(f'decrypted: {get_blocks(decrypted)}')
unpadded = my_unpad(decrypted)
print(f'unpadded: {get_blocks(unpadded)}')
print('_'*50)

b'{"key": "hitcon{a426a426a426a426a426a426a426a426a426a426a426a426_-@0123486759Ab2DeF}"}'
X: b'}', Y: b'A'
X: b'}', Y: b'A'
[b'{"key": "hitcon{', b'a426a426a426a426', b'a426a426a426a426', b'a426a426a426a426', b'_-@0123486759Ab2', b'DeF}"}AAAAAAAAAA']
ct_blocks: [b'\xef\x17\xbf;\xce\xcd!go\xf0r\x90\x08\xb5\x81+', b'\xa49\xf1\xb2\xd9IJ\xe4u\xbb}\xc3B;\x0e\xc2']
decrypted: [b'_-@0123486759Ab2']
pos: -14, Y_ = b'@', Y = b'2'


ValueError: Incorrect Padding

In [None]:
import json
#!/usr/local/bin/python
import random
import os
from secret import flag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import json

N = 16

# 0 -> 0, 1~N -> 1, (N+1)~(2N) -> 2 ...
def count_blocks(length):
    block_count = (length-1) // N + 1
    return block_count

def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message)-1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y

def my_padding(message):
    print(message)
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len =  block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count-2)*N+(X%N)]
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    if X==Y:
        Y = Y^1
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    padded = message.ljust(result_len, bytes([Y]))
    return padded

def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    print(f"pos: {(block_count-2)*N+(X%N)}, Y_ = {bytes([_Y])}, Y = {bytes([Y])}")
    if (Y != _Y and Y != _Y^1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]

def get_blocks(data: bytes, bs: int = 16) -> list:
    return [data[i:i+bs] for i in range(0, len(data), bs)]

def bit_flip(previous_block: bytes, guess: int, flip_pos: int, padding_byte: int) -> bytes:
    flipped_byte = bytes([previous_block[flip_pos] ^ guess ^ padding_byte])
    # print(f"previous_block[flip_pos]: {bytes([previous_block[flip_pos]])}")
    return previous_block[:flip_pos] + flipped_byte + previous_block[flip_pos+1:]

k = b'YELLOW SUBMARINE'
iv = b'\x01\x02\x03\x04' * 4
# iv = os.urandom(16)
cipher = AES.new(k, AES.MODE_CBC, iv) 
message = 'hitcon{' + 'a426'*4*3 + '_-@0123486759Ab2DeF}'
message = json.dumps({'key':message}).encode()
padded = my_padding(message)
ciphertext = cipher.encrypt(padded)
print(get_blocks(padded))

# ct_blocks[0] = b'R\xd5\x0bx\xcb\xec\x1b\xf4\xb9k\x84\x07\xa4_\xa3\xb9'
def detect_padding_byte():
    for byte in range(1,256):
        print(f'guessing byte: {bytes([byte])}')
        ct_blocks = get_blocks(ciphertext)[-2:]
        ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'}'[0], flip_pos=15-6, padding_byte=byte)
        # ct_blocks[0] = bytes([96]) + ct_blocks[0][1:]
        # ct_blocks[0] = bit_flip(ct_blocks[0], guess=0, flip_pos=0, padding_byte=1)
        ct_to_send = b''.join(ct_blocks[1:])
        iv = ct_blocks[0]
        cipher = AES.new(k, AES.MODE_CBC, iv)
        decrypted = cipher.decrypt(ct_to_send)
        print(f'decrypted: {get_blocks(decrypted)}')
        try:
            unpadded = my_unpad(decrypted)
            print(f'unpadded: {get_blocks(unpadded)}')
            print('_'*50)
        except:
            print('incorrect padding')
            print('_'*50)
            

# detect_padding_byte()
PADDING_SIZE = 10
PADDING_BYTE = b'A'[0]
# try to get padding byte

# known = b''
# print(f'guessing byte: {bytes([byte])}')
possible_chars = b'2brBR'
for i in possible_chars:
    for character in DICTIONARY:
        try:
            print(f"trying candidate: {bytes([i])}, character: {character}")
            ct_blocks = get_blocks(ciphertext)[-3:-1]
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=i, flip_pos=14, padding_byte=ord(character))
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'A'[0], flip_pos=13, padding_byte=14)

            # ct_blocks[-2] = b'0234h678901aac\x2c!'
            # print(f'ct_blocks: {ct_blocks}')
            ct_to_send = b''.join(ct_blocks[1:])
            iv = ct_blocks[0]
            cipher = AES.new(k, AES.MODE_CBC, iv)
            decrypted = cipher.decrypt(ct_to_send)
            unpadded = my_unpad(decrypted)
            print(f'decrypted: {get_blocks(decrypted)}')
            print(f'unpadded: {get_blocks(unpadded)}')

            ct_blocks = get_blocks(ciphertext)[-3:-1]
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=i, flip_pos=14, padding_byte=ord(character))
            ct_blocks[0] = bit_flip(ct_blocks[0], guess=b'A'[0], flip_pos=13, padding_byte=15)
            ct_to_send = b''.join(ct_blocks[1:])
            iv = ct_blocks[0]
            cipher = AES.new(k, AES.MODE_CBC, iv)
            decrypted = cipher.decrypt(ct_to_send)
            unpadded = my_unpad(decrypted)

            print(f'decrypted: {get_blocks(decrypted)}')
            print(f'unpadded: {get_blocks(unpadded)}')
            print('_'*50)
        except:
            pass

b'{"key": "hitcon{a426a426a426a426a426a426a426a426a426a426a426a426_-@0123486759Ab2DeF}"}'
X: b'}', Y: b'A'
X: b'}', Y: b'A'
[b'{"key": "hitcon{', b'a426a426a426a426', b'a426a426a426a426', b'a426a426a426a426', b'_-@0123486759Ab2', b'DeF}"}AAAAAAAAAA']
trying candidate: b'2', character: a
pos: -15, Y_ = b'-', Y = b'2'
trying candidate: b'2', character: b
pos: -2, Y_ = b'2', Y = b'2'
decrypted: [b'_-@0123486759\x0e22']
unpadded: [b'_-@0123486759\x0e']
pos: -1, Y_ = b'2', Y = b'2'
decrypted: [b'_-@0123486759\x0f22']
unpadded: [b'_-@0123486759\x0f']
__________________________________________________
trying candidate: b'2', character: c
pos: -13, Y_ = b'0', Y = b'2'
trying candidate: b'2', character: d
pos: -12, Y_ = b'1', Y = b'2'
trying candidate: b'2', character: e
pos: -11, Y_ = b'2', Y = b'2'
decrypted: [b'_-@0123486759\x0e52']
unpadded: [b'_-@0123486759\x0e5']
pos: -11, Y_ = b'2', Y = b'2'
decrypted: [b'_-@0123486759\x0f52']
unpadded: [b'_-@0123486759\x0f5']
___________________________

In [None]:
import json
#!/usr/local/bin/python
import random
import os
from secret import flag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import json

N = 16

# 0 -> 0, 1~N -> 1, (N+1)~(2N) -> 2 ...
def count_blocks(length):
    block_count = (length-1) // N + 1
    return block_count

def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message)-1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y

def my_padding(message):
    print(message)
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len =  block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count-2)*N+(X%N)]
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    if X==Y:
        Y = Y^1
    print(f'X: {bytes([X])}, Y: {bytes([Y])}')
    padded = message.ljust(result_len, bytes([Y]))
    return padded

def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    print(f"pos: {(block_count-2)*N+(X%N)}, Y_ = {bytes([_Y])}, Y = {bytes([Y])}")
    if (Y != _Y and Y != _Y^1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]

def get_blocks(data: bytes, bs: int = 16) -> list:
    return [data[i:i+bs] for i in range(0, len(data), bs)]

def bit_flip(previous_block: bytes, guess: int, flip_pos: int, padding_byte: int) -> bytes:
    flipped_byte = bytes([previous_block[flip_pos] ^ guess ^ padding_byte])
    # print(f"previous_block[flip_pos]: {bytes([previous_block[flip_pos]])}")
    return previous_block[:flip_pos] + flipped_byte + previous_block[flip_pos+1:]

k = b'YELLOW SUBMARINE'
iv = b'\x01\x02\x03\x04' * 4
# iv = os.urandom(16)
cipher = AES.new(k, AES.MODE_CBC, iv) 
message = 'hitcon{this_is_a_really_long_flag_12that_is_very_hard_to_guess_for_testing}'
message = json.dumps({'key':message}).encode()
padded = my_padding(message)
ciphertext = cipher.encrypt(padded)
print(get_blocks(padded))

# try to get padding byte
def get_block_15th_byte_candidates(iv, ciphertexts, known_plaintext):
    candidates = []
    for i in range(16):
        for byte in range(256):
            # print(f"trying byte: {bytes([byte])}")
            iv_new = bit_flip(iv, guess=byte, flip_pos=i, padding_byte=0)
            iv = iv_new
            ct_to_send = ciphertexts
            try:
                cipher = AES.new(k, AES.MODE_CBC, iv)
                decrypted = cipher.decrypt(ct_to_send)
                print(f'decrypted: {get_blocks(decrypted)}')
                unpadded = my_unpad(decrypted)
                print(f'unpadded: {get_blocks(unpadded)}')
                print('_'*50)
                for character in DICTIONARY:
                    if ord(character) % 16 == i:
                        candidates.append(character)
                last_byte = known_plaintext[i] ^ byte
                return candidates, i, last_byte
            except:
                pass
known_plaintext = b'{"key": "hitcon{'
ct_blocks = get_blocks(ciphertext)
test_ct = b''.join(ct_blocks[:2])
candidates, pos, last_byte = get_block_15th_byte_candidates(iv, test_ct, known_plaintext)
print(f'candidates: {candidates}, pos: {pos}, last_byte: {chr(last_byte)}')

def get_chars_for_pos(pos: int):
    candidates = []
    for character in DICTIONARY:
        if ord(character)%16 == pos:
            candidates.append(character)
    return candidates

print('#'*50)
def get_character(iv, ciphertexts, known_plaintext, pos):
    length_known = len(known_plaintext)
    iv_new_orig = iv
    candidates = []
    iv_new_orig = bit_flip(iv_new_orig, guess=ord(known_plaintext[-1]), flip_pos=15, padding_byte=b'{'[0])
    # for i in range(15, 15-length_known, -1):
    #     print(known_plaintext[length_known-(15-i)-1])
    #     iv_new_orig = bit_flip(iv_new_orig, guess=ord(known_plaintext[length_known-(15-i)-1]), flip_pos=i, padding_byte=b'{'[0])
    chars_for_pos = get_chars_for_pos(pos)
    for character in DICTIONARY:
        print(f"trying character: {character}")
        iv_new = bit_flip(iv_new_orig, guess=ord(known_plaintext[-2]), flip_pos=14, padding_byte=ord(chars_for_pos[0]))
        iv_new = bit_flip(iv_new, guess=ord(character), flip_pos=pos, padding_byte=b'{'[0])
        ct_to_send = ciphertexts
        try:
            cipher = AES.new(k, AES.MODE_CBC, iv_new)
            decrypted = cipher.decrypt(ct_to_send)
            print(f'decrypted: {get_blocks(decrypted)}')
            unpadded = my_unpad(decrypted)
            print(f'unpadded: {get_blocks(unpadded)}')
            print('_'*50)
            candidates.append(character)
        except:
            pass
    return candidates
            
candidates = get_character(test_ct[:16], test_ct[16:], 'lly', 15-len('lly'))
print(f'candidates: {candidates}')
# ['a', 'c', 'p', 'q', 's', '0', '1', '3']
# candidates: ['a', 'b', 'c', 'f', 'g', 'p', 'q', 'r', 's', 'v', '0', '1', '2', '3', '6']

b'{"key": "hitcon{this_is_a_really_long_flag_12that_is_very_hard_to_guess_for_testing}"}'
X: b'}', Y: b'e'
X: b'}', Y: b'e'
[b'{"key": "hitcon{', b'this_is_a_really', b'_long_flag_12tha', b't_is_very_hard_t', b'o_guess_for_test', b'ing}"}eeeeeeeeee']
decrypted: [b'{"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'z"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'x"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'{"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'\x7f"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'z"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'|"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b'{"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'
decrypted: [b's"key": "hitcon{', b'this_is_a_really']
pos: 12, Y_ = b'c', Y = b'y'

In [None]:
def get_chars_for_pos(pos: int):
    candidates = []
    for character in DICTIONARY:
        if ord(character)%16 == pos:
            candidates.append(character)
    return candidates

get_chars_for_pos(13)

['m']

In [None]:
# we know X % 16 = 2
from string import printable

for character in printable:
    if ord(character)%16 == 2:
        print(character, end='')

2brBR"

In [None]:
chr(b'{'[0] ^ 1)

'z'