In [1]:
from random import choice, choices
import base64
from Crypto.Cipher import AES

BLOCKSIZE = 16

def XOR(A: bytes, B: bytes):
    assert type(A) is bytes and type(B) is bytes and len(A) == len(B)
    return bytes([ a^b for a, b in zip(A, B) ])

def IntToLittleEndian(x: int, nbytes: int):
    assert type(x) is int and x >= 0
    out = bytes()
    while x >  0:
        out += bytes([ x & 0xFF ])
        x >>= 8
    
    assert len(out) <= nbytes
    out += bytes( [0] * ( nbytes-len(out) ) )
    
    return out

def LittleEndianToInt(x: bytes):
    assert type(x) is bytes
    out = 0
    byte_counter = 0
    while len(x) > 0:
        out += x[0] << byte_counter*8
        byte_counter += 1
        x = x[1:]
        
    return out

class CTR(object):
    def __init__(self, nonce: bytes, key: bytes):
        assert type(nonce) is bytes and len(nonce) < BLOCKSIZE
        self.nonce = nonce
        self.aes = AES.new(key, AES.MODE_ECB)
        
    def EncryptBlock(self, block: bytes, counter: int):
        assert type(counter) is int and counter >= 0
        assert type(block) is bytes and len(block) == BLOCKSIZE
        little_endian: bytes = IntToLittleEndian(counter, nbytes = BLOCKSIZE-len(self.nonce))
        key_block = self.aes.encrypt( self.nonce + little_endian )
        assert type(key_block) is bytes and len(key_block) == BLOCKSIZE
        return XOR(key_block, block)
    
    def DecryptBlock(self, block, counter: int):
        return self.EncryptBlock(block, counter)
    
    def EncryptBytes(self, plaintext: bytes, counter: int):
        assert type(plaintext) is bytes and len(plaintext) <= BLOCKSIZE
        assert type(counter) is int and counter >= 0
        little_endian: bytes = IntToLittleEndian(counter, nbytes = BLOCKSIZE-len(self.nonce))
        key_block = self.aes.encrypt( self.nonce + little_endian )
        assert type(key_block) is bytes and len(key_block) == BLOCKSIZE
        return XOR(key_block[:len(plaintext)], plaintext)
    
    def DecryptBytes(self, ciphertext: bytes, counter: int):
        return self.EncryptBytes(ciphertext, counter)
    
    def EncryptStream(self, plaintext: bytes, counter: int = 0):
        assert type(plaintext) is bytes
        stream = bytes(plaintext)
        assert len(stream) % BLOCKSIZE == 0
        ciphertexts = list()
        
        while len(stream) > 0:
            block =  stream[:BLOCKSIZE]
            stream = stream[BLOCKSIZE:]
            assert len(block) == BLOCKSIZE
            ciphertexts += [ self.EncryptBytes(block, counter) ]
            counter += 1
        
        ciphertext = b''.join(ciphertexts)
        assert len(ciphertext) == len(plaintext)
        return ciphertext
    
    def DecryptStream(self, plaintext: bytes, counter: int = 0):
        return self.EncryptStream(plaintext, counter)
    

In [22]:
ORACLE_KEY = b'\xe7\xef\x1e\x7f\xd7\x87\xa4\xeb\x10<\xd9\x9f\x8b\xec\x03\x8f'

def Oracle(s: bytes):
    key = ORACLE_KEY
    assert type(s) is bytes
    assert set(b';=') & set(s) == set() # ensure neither the ; or = character is input
    prefix = b"comment1=cooking%20MCs;userdata="
    suffix = b";comment2=%20like%20a%20pound%20of%20bacon"
    plaintext = prefix + s + suffix
    pad = BLOCKSIZE - len(plaintext) % BLOCKSIZE
    plaintext += bytes([pad] * pad)
    assert len(plaintext) % BLOCKSIZE == 0
    return CTR(nonce = b'A'*8, key = key).EncryptStream(plaintext)

def IsAdmin(ciphertext: bytes):
    key = ORACLE_KEY
    plaintext = CTR(nonce = b'A'*8, key = key).DecryptStream(ciphertext)
    return b';admin=true;' in plaintext
    

In [23]:
# Ok, for this exercise I'm going to assume that we know the length of the prefix and suffix.
# In a previous exercise we solved for this.  I don't feel like doing that again,
# so I'm going to pretend I did it again and now have that information

len_prefix = len(b"comment1=cooking%20MCs;userdata=")
len_suffix = len(b";comment2=%20like%20a%20pound%20of%20bacon")

In [24]:
my_prefix = bytes( [0] * (BLOCKSIZE - len_prefix%BLOCKSIZE) ) # pad this out so any data I put in starts a new block
my_block_index = ( len_prefix+len(my_prefix) ) // BLOCKSIZE # the index my data block will be

# Put in a block of all zeros, so whatever ciphertext comes out in that position will be the bitmask
ciphertext = Oracle(my_prefix + bytes([0]*BLOCKSIZE))
start, end = my_block_index*BLOCKSIZE, (1+my_block_index)*BLOCKSIZE
bitmask = ciphertext[start:end]
print(bitmask)

b'\x8b\xc1\x9f^&\xb7_\t\x0b&\xe9\x0cF\xab\x0bv'


In [25]:
# So now I know the bitmask, I can modify the ciphertext
block = b'A;admin=true;AAA'
assert len(block) == BLOCKSIZE
ciphertext = ciphertext[:start] + XOR(bitmask, block) + ciphertext[end:]

In [26]:
IsAdmin(ciphertext)

True