In [None]:
from base64 import b64decode
from Crypto.Cipher import AES
from Crypto import Random
from S1C08 import count_aes_ecb_repetitions
from S2C10 import aes_ecb_encrypt
from S2C09 import pkcs7_unpad

class ECBOracle:
    """Encryption oracle that keeps one random key for its whole lifetime.

    The key is drawn once, when the oracle is constructed, and reused for every
    call to encrypt(). A fixed secret string (given by the creator of the oracle)
    is appended to each plaintext before it is encrypted.
    """

    def __init__(self, secret_padding):
        """Stores the secret suffix and generates the random key.

        :param secret_padding: the bytes appended to every plaintext passed to encrypt().
        """
        self._secret_padding = secret_padding

        # The first entry of AES.key_size is used as key length
        # (presumably 16 bytes, i.e. AES-128 - depends on the Crypto library).
        key_length = AES.key_size[0]
        self._key = Random.new().read(key_length)

    def encrypt(self, data):
        """Returns the ECB encryption of the given data followed by the secret suffix.

        :param data: the attacker-controlled bytes placed before the secret suffix.
        :return: whatever aes_ecb_encrypt produces for the combined plaintext under the fixed key.
        """
        plaintext = data + self._secret_padding
        return aes_ecb_encrypt(plaintext, self._key)


def find_block_length(encryption_oracle):
    """Discovers the block size of the cipher hidden behind the encryption_oracle.

    The oracle is first queried with an empty input to learn the baseline ciphertext
    size. Then inputs made of 1, 2, 3, ... bytes are submitted until the ciphertext
    size changes: the amount by which it changed is returned as the block length.

    :param encryption_oracle: any object exposing an encrypt(bytes) method returning a sized result.
    :return: the difference between the first changed ciphertext length and the baseline one.
    """
    base_size = len(encryption_oracle.encrypt(b''))

    filler_len = 0
    while True:
        # Feed one more byte than the previous attempt
        filler_len += 1
        grown_size = len(encryption_oracle.encrypt(b'A' * filler_len))

        # The first jump in output size reveals the size of one block
        if grown_size != base_size:
            return grown_size - base_size


def get_next_byte(block_length, curr_decrypted_message, encryption_oracle):
    """Recovers one more byte of the unknown text that the oracle appends to our input.

    A filler of 'A' bytes is chosen so that the filler, the bytes already recovered
    and the one unknown byte together end exactly on a block boundary. The oracle's
    output for the filler alone is then compared with its output for
    filler + recovered bytes + guess, for every possible guess from 0 to 255.

    :param block_length: the block size of the cipher used by the oracle.
    :param curr_decrypted_message: the bytes of the unknown text recovered so far.
    :param encryption_oracle: any object exposing an encrypt(bytes) method.
    :return: the matching byte (as a bytes object of length 1), or b'' if no guess matches.
    """
    # Number of filler bytes needed to push the unknown byte to the end of a block
    filler_size = -(len(curr_decrypted_message) + 1) % block_length
    filler = b'A' * filler_size

    # Only this many leading bytes of each ciphertext matter for the comparison
    window = filler_size + len(curr_decrypted_message) + 1

    # Reference ciphertext: here the last byte of the window is the real unknown byte
    target = encryption_oracle.encrypt(filler)[:window]

    known_part = filler + curr_decrypted_message

    # Guesses are evaluated lazily, in increasing order, stopping at the first match
    matching_guesses = (
        bytes([guess])
        for guess in range(256)
        if encryption_oracle.encrypt(known_part + bytes([guess]))[:window] == target
    )
    return next(matching_guesses, b'')


def byte_at_a_time_ecb_decryption_simple(encryption_oracle):
    """Performs the byte-at-a-time ECB decryption attack to discover the secret padding used by the oracle.

    The attack first measures the block length of the oracle, checks that the oracle
    really behaves like ECB (equal plaintext blocks give repeated ciphertext blocks) and
    then recovers the appended secret one byte at a time with get_next_byte.

    :param encryption_oracle: any object exposing an encrypt(bytes) method.
    :return: the recovered bytes. The result may still carry trailing padding bytes of the
             cipher, so the caller is expected to strip them (main does it with pkcs7_unpad).
    :raises AssertionError: if no repeated block is detected in the probe ciphertext.
    """
    block_length = find_block_length(encryption_oracle)

    # 64 identical bytes must produce repeated ciphertext blocks if the oracle uses ECB
    ciphertext = encryption_oracle.encrypt(bytes([0] * 64))
    assert count_aes_ecb_repetitions(ciphertext) > 0

    # The ciphertext of the empty input is an upper bound on the length of the secret
    # (it also accounts for whatever padding the oracle adds).
    mysterious_text_length = len(encryption_oracle.encrypt(b''))

    secret_padding = b''
    for _ in range(mysterious_text_length):
        next_byte = get_next_byte(block_length, secret_padding, encryption_oracle)

        # get_next_byte returns b'' when no candidate matches. From that point on the known
        # prefix would never change again, so every further iteration would repeat exactly
        # the same failing oracle queries: stop here instead of wasting them.
        if not next_byte:
            break

        secret_padding += next_byte
    return secret_padding


def main():
    """Runs the attack against a fresh oracle and verifies that the secret is recovered."""
    encoded_secret = ("Um9sbGluJyBpbiBteSA1LjAKV2l0aCBteSByYWctdG9wIGRvd24gc28gbXkgaGF"
                      "pciBjYW4gYmxvdwpUaGUgZ2lybGllcyBvbiBzdGFuZGJ5IHdhdmluZyBqdXN0IH"
                      "RvIHNheSBoaQpEaWQgeW91IHN0b3A/IE5vLCBJIGp1c3QgZHJvdmUgYnkK")
    expected_secret = b64decode(encoded_secret)

    # The oracle hides the decoded secret behind its encrypt() method
    recovered_secret = byte_at_a_time_ecb_decryption_simple(ECBOracle(expected_secret))

    # The attack is successful if, once unpadded, the recovered bytes equal the secret
    assert pkcs7_unpad(recovered_secret) == expected_secret


# Run the self-check only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()