In [1]:
import numpy as np


In [2]:
def read_hex(filename):
    """Read a text file holding one hexadecimal integer per line.

    Parameters
    ----------
    filename : str or path-like
        Path of the text file. Every non-blank line must contain a single
        hexadecimal number (an optional ``0x`` prefix and surrounding
        whitespace are accepted by ``int(..., base=16)``).

    Returns
    -------
    numpy.ndarray
        1-D ``uint32`` array of the parsed values, in file order.

    Raises
    ------
    ValueError
        If a non-blank line is not a valid hexadecimal integer.
    """
    # `with` closes the handle deterministically; the original relied on GC.
    with open(filename, mode="r") as handle:
        # Blank lines (e.g. a trailing empty line) are skipped instead of
        # crashing on int("", 16).
        values = [int(line, base=16) for line in handle if line.strip()]
    return np.asarray(values, dtype=np.uint32)


In [292]:
class AES:
    """Vectorised AES block cipher (ECB mode) built on numpy.

    Data, keys and round keys are flat ``uint32`` arrays holding one byte value
    (0..255) per element. ``encrypt``/``decrypt`` process any number of 16-byte
    blocks at once; each block is transformed independently (ECB).

    Class attributes are computed once, when the class body runs:

    * ``sbox`` / ``inv_sbox``: loaded with ``read_hex`` from ``sbox.dat`` and
      ``inv_sbox.dat`` in the current working directory (assumed to hold the
      256-entry forward and inverse S-box -- TODO confirm file contents).
    * ``shift`` / ``inv_shift``: byte permutations implementing ShiftRows and
      its inverse on a 16-byte state where byte ``4*c + r`` is row ``r`` of
      column ``c``.
    * ``mix_column_matrix`` / ``inv_mix_column_matrix``: GF(2^8) coefficient
      matrices for MixColumns / InvMixColumns.
    * ``rounds``: key length in bytes -> number of cipher rounds.

    All methods are static; call them on the class, e.g. ``AES.encrypt``.
    """

    sbox = read_hex("sbox.dat")
    inv_sbox = read_hex("inv_sbox.dat")

    # out[i] = in[shift[i]]: row r of the column-major state is rotated left by r.
    shift = np.asarray([0,5,10,15,4,9,14,3,8,13,2,7,12,1,6,11], dtype=np.uint32)
    inv_shift = np.asarray([0,13,10,7,4,1,14,11,8,5,2,15,12,9,6,3], dtype=np.uint32)

    mix_column_matrix = np.asarray([[2,3,1,1],
                                    [1,2,3,1],
                                    [1,1,2,3],
                                    [3,1,1,2]], dtype=np.uint32)

    inv_mix_column_matrix = np.asarray([[0xe,0xb,0xd,0x9],
                                        [0x9,0xe,0xb,0xd],
                                        [0xd,0x9,0xe,0xb],
                                        [0xb,0xd,0x9,0xe]], dtype=np.uint32)

    # key length in bytes -> number of rounds
    rounds = {16:10, 24:12, 32:14}

    @staticmethod
    def g(word, i):
        """Key-schedule core function applied to one 4-byte word.

        Rotates ``word`` left by one byte, substitutes every byte through the
        S-box and XORs round constant ``rc[i-1]`` into the first byte.

        Parameters
        ----------
        word : numpy.ndarray
            Four byte values.
        i : int
            1-based key-schedule iteration, 1..10.

        Returns
        -------
        numpy.ndarray
            A new 4-element array; ``word`` is not modified.

        Raises
        ------
        ValueError
            If ``i`` is outside 1..10 (previously ``i == 0`` silently used rc[-1]).
        """
        rc = [1,2,4,8,16,32,64,128,27,54]
        if not 1 <= i <= len(rc):
            raise ValueError("round constant index must be in 1..10")
        rotated = word[[1, 2, 3, 0]]   # fancy indexing -> copy
        substituted = AES.sbox_substitution(rotated)
        substituted[0] = np.bitwise_xor(substituted[0], rc[i-1])
        return substituted

    @staticmethod
    def key_transform4g(key, i, g_in=None):
        """Derive the next 16 key bytes (4 words) from the previous 16.

        Word 0 is ``key[0:4] ^ g(g_in, i)``; every following word is the
        previous key's word XOR the newly produced word before it.

        Parameters
        ----------
        key : numpy.ndarray
            Previous 16 bytes of the expanded key.
        i : int
            1-based iteration, forwarded to ``g``.
        g_in : numpy.ndarray, optional
            Word fed to ``g``; defaults to the last word ``key[12:16]``. The
            192/256-bit schedules pass the last word of the longer previous block.

        Returns
        -------
        numpy.ndarray
            New uint32 array with the same shape as ``key``.
        """
        if g_in is None:
            g_in = key[12:16]
        new_key = np.empty(key.shape, dtype=np.uint32)
        new_key[0:4] = np.bitwise_xor(key[0:4], AES.g(g_in, i))
        new_key[4:8] = np.bitwise_xor(key[4:8], new_key[0:4])
        new_key[8:12] = np.bitwise_xor(key[8:12], new_key[4:8])
        new_key[12:16] = np.bitwise_xor(key[12:16], new_key[8:12])
        return new_key

    @staticmethod
    def key_transform6g(key, i):
        """Derive the next 24 key bytes (6 words) from the previous 24 (AES-192).

        ``g`` is applied to the last word of ``key`` (bytes 20..23); every other
        word chains as in ``key_transform4g``.
        """
        new_key = np.empty(key.shape, dtype=np.uint32)
        new_key[0:4] = np.bitwise_xor(key[0:4], AES.g(key[20:24], i))
        new_key[4:8] = np.bitwise_xor(key[4:8], new_key[0:4])
        new_key[8:12] = np.bitwise_xor(key[8:12], new_key[4:8])
        new_key[12:16] = np.bitwise_xor(key[12:16], new_key[8:12])
        new_key[16:20] = np.bitwise_xor(key[16:20], new_key[12:16])
        new_key[20:24] = np.bitwise_xor(key[20:24], new_key[16:20])
        return new_key

    @staticmethod
    def key_transform8hg(key, i):
        """Derive the next 32 key bytes (8 words) from the previous 32 (AES-256).

        The first four words follow ``key_transform4g`` with ``g`` applied to
        the last word of ``key`` (bytes 28..31). Word 4 additionally passes new
        word 3 through the S-box before chaining.
        """
        new_key = np.empty(key.shape, dtype=np.uint32)
        new_key[0:16] = AES.key_transform4g(key[0:16], i, g_in=key[28:32])
        new_key[16:20] = np.bitwise_xor(key[16:20], AES.sbox_substitution(new_key[12:16]))
        new_key[20:24] = np.bitwise_xor(key[20:24], new_key[16:20])
        new_key[24:28] = np.bitwise_xor(key[24:28], new_key[20:24])
        new_key[28:32] = np.bitwise_xor(key[28:32], new_key[24:28])
        return new_key

    @staticmethod
    def generate_key_schedule128(key):
        """Expand a 16-byte key into 11 round keys.

        Returns
        -------
        list of numpy.ndarray
            ``[key, k1, ..., k10]``; element 0 is the caller's ``key`` object.
        """
        schedule = [key]
        for i in range(1, 11):
            schedule.append(AES.key_transform4g(schedule[-1], i))
        return schedule

    @staticmethod
    def generate_key_schedule192(key):
        """Expand a 24-byte key into 13 round keys, returned as a (13, 16) array."""
        words = np.empty(4 * 52, dtype=np.uint32)   # 52 words = 13 round keys
        words[0:24] = key
        for i in range(1, 8):
            words[24 * i:24 * (i + 1)] = AES.key_transform6g(words[24 * (i - 1):24 * i], i)
        # 48 words exist now. The last 4 need only the first part of iteration 8:
        # chain from words 42..45 with g() applied to word 47 (bytes 188..191).
        words[192:208] = AES.key_transform4g(words[168:184], 8, g_in=words[188:192])
        return words.reshape(13, 16)

    @staticmethod
    def generate_key_schedule256(key):
        """Expand a 32-byte key into 15 round keys, returned as a (15, 16) array."""
        words = np.empty(4 * 60, dtype=np.uint32)   # 60 words = 15 round keys
        words[0:32] = key
        for i in range(1, 7):
            words[32 * i:32 * (i + 1)] = AES.key_transform8hg(words[32 * (i - 1):32 * i], i)
        # 56 words exist now. The last 4 need only the first half of iteration 7:
        # chain from words 48..51 with g() applied to word 55 (bytes 220..223).
        words[224:240] = AES.key_transform4g(words[192:208], 7, g_in=words[220:224])
        return words.reshape(15, 16)

    @staticmethod
    def mul_array(A, lambda_):
        """Multiply each element of ``A`` by the constant ``lambda_`` in GF(2^8).

        Carry-less (XOR) multiplication, then reduction modulo the AES
        polynomial x^8 + x^4 + x^3 + x + 1 (0x11B).

        Parameters
        ----------
        A : array_like
            Byte values 0..255 (larger values are not rejected but give
            meaningless results).
        lambda_ : int
            Constant multiplier, 0..255. MixColumns uses 1, 2, 3 and
            InvMixColumns uses 0x9, 0xb, 0xd, 0xe.

        Returns
        -------
        numpy.ndarray
            New uint32 array with the shape of ``A``; ``A`` is never modified.

        Raises
        ------
        ValueError
            If ``lambda_`` is outside 0..255.
        """
        irreducible = 27 + 256   # 0x11B
        lambda_ = int(lambda_)
        if not 0 <= lambda_ <= 0xFF:
            raise ValueError("illegal lambda_ chosen for multiplication")
        A = np.asarray(A, dtype=np.uint32)
        # Fresh output buffer: the old `p = A` aliased the input for lambda_ == 1.
        p = np.zeros(A.shape, dtype=np.uint32)
        for bit in range(lambda_.bit_length()):
            if (lambda_ >> bit) & 1:
                p ^= A << bit
        # For A < 256 the product's highest possible bit is 6 + bit_length(lambda_).
        # Clear bits >= 8 from the top down; each XOR only touches lower bits.
        for bit in range(6 + lambda_.bit_length(), 7, -1):
            mask = ((p >> bit) & 1).astype(bool)
            p[mask] ^= irreducible << (bit - 8)
        return p

    @staticmethod
    def key_addition(data, key):
        """AddRoundKey: XOR every 16-byte row of ``data`` with ``key``."""
        return np.bitwise_xor(data, key)

    @staticmethod
    def sbox_substitution(data):
        """SubBytes: replace every byte value by its S-box entry (new array)."""
        return AES.sbox[data]

    @staticmethod
    def inv_sbox_substitution(data):
        """InvSubBytes: replace every byte value by its inverse S-box entry."""
        return AES.inv_sbox[data]

    @staticmethod
    def shift_rows(data):
        """ShiftRows on an (n_blocks, 16) array of states."""
        return data[:, AES.shift]

    @staticmethod
    def inv_shift_rows(data):
        """InvShiftRows on an (n_blocks, 16) array of states."""
        return data[:, AES.inv_shift]

    @staticmethod
    def mix_column(data, inv=False):
        """MixColumns (or InvMixColumns if ``inv``) on an (n_blocks, 16) array.

        Each block is viewed as (4 columns, 4 rows); new row ``j`` of every
        column is the GF(2^8) dot product of row ``j`` of the coefficient matrix
        with that column.
        """
        data_sq = np.reshape(data, (data.shape[0], 4, 4))   # [block, column, row]
        if inv:
            mcm = AES.inv_mix_column_matrix
        else:
            mcm = AES.mix_column_matrix

        M = np.empty(data_sq.shape, dtype=np.uint32)   # [block, row, column]
        for j in range(4):
            M[:,j,:] = np.bitwise_xor(
                np.bitwise_xor(AES.mul_array(data_sq[:,:,0], mcm[j,0]),
                               AES.mul_array(data_sq[:,:,1], mcm[j,1])),
                np.bitwise_xor(AES.mul_array(data_sq[:,:,2], mcm[j,2]),
                               AES.mul_array(data_sq[:,:,3], mcm[j,3])))

        # Back to [block, column, row] so the flat byte order is column-major again.
        return np.reshape(np.swapaxes(M,1,2), data.shape)

    @staticmethod
    def inv_mix_column(data):
        """InvMixColumns on an (n_blocks, 16) array."""
        return AES.mix_column(data, inv=True)

    @staticmethod
    def _check_inputs(data, key):
        """Validate ``data``/``key`` and expand ``key`` into its round keys.

        Returns
        -------
        list or numpy.ndarray
            11 round keys (list) for a 16-byte key, or a (13, 16) / (15, 16)
            array for a 24- / 32-byte key.

        Raises
        ------
        ValueError
            If ``data`` is not 1-D with a length that is a multiple of 16, if any
            value is outside 0..255, or if the key is not 16, 24 or 32 bytes.
        """
        data = np.asarray(data)
        key = np.asarray(key)
        # data needs to be a flattened array with length a multiple of 16
        if data.ndim != 1:
            raise ValueError("Data needs to be a flat (1-D) array of byte values")
        if data.shape[0] % 16 != 0:
            raise ValueError("Data length needs to be a multiple of 16(bytes, 128 bit)")
        if key.ndim != 1 or len(key) not in AES.rounds:
            raise ValueError("The key has to be of size 16, 24 or 32 bytes (128, 192, 256 bits, respectively).")
        # Byte values index the 256-entry S-box; anything else would crash there.
        for label, values in (("Data", data), ("Key", key)):
            if values.size and (values.min() < 0 or values.max() > 0xFF):
                raise ValueError("%s values need to be bytes in the range 0..255" % label)
        key = key.astype(np.uint32, copy=False)
        if len(key) == 16:
            return AES.generate_key_schedule128(key)
        if len(key) == 24:
            return AES.generate_key_schedule192(key)
        return AES.generate_key_schedule256(key)

    @staticmethod
    def encrypt(data, key):
        """Encrypt ``data`` with AES in ECB mode.

        Parameters
        ----------
        data : array_like
            Flat sequence of byte values; length must be a multiple of 16.
        key : array_like
            16, 24 or 32 byte values (AES-128/192/256).

        Returns
        -------
        numpy.ndarray
            Flat uint32 ciphertext with the same length as ``data``.

        Raises
        ------
        ValueError
            On malformed ``data`` or ``key`` (see ``_check_inputs``).
        """
        keys = AES._check_inputs(data, key)
        data = np.asarray(data, dtype=np.uint32)
        state = np.reshape(data, (data.shape[0] // 16, 16))
        rounds = AES.rounds[len(key)]
        for i in range(rounds):
            state = AES.key_addition(state, keys[i])
            state = AES.sbox_substitution(state)
            state = AES.shift_rows(state)
            if i < rounds - 1:   # the final round omits MixColumns
                state = AES.mix_column(state)
        state = AES.key_addition(state, keys[-1])
        return state.flatten()

    @staticmethod
    def decrypt(data, key):
        """Decrypt ECB ciphertext produced by ``encrypt`` with the same ``key``.

        Parameters, return value and exceptions mirror ``encrypt``.
        """
        keys = AES._check_inputs(data, key)[::-1]   # round keys in reverse order
        data = np.asarray(data, dtype=np.uint32)
        state = np.reshape(data, (data.shape[0] // 16, 16))
        rounds = AES.rounds[len(key)]
        for i in range(rounds):
            state = AES.key_addition(state, keys[i])
            if i > 0:   # undoes the final encryption round, which has no MixColumns
                state = AES.inv_mix_column(state)
            state = AES.inv_shift_rows(state)
            state = AES.inv_sbox_substitution(state)
        state = AES.key_addition(state, keys[-1])
        return state.flatten()

In [257]:
def string2array(string):
    """Turn a text string into a uint32 array of its character code points.

    Each character maps to ``ord(character)``; no text encoding is applied, so
    characters beyond Latin-1 yield values above 255.
    """
    return np.asarray([ord(character) for character in string], dtype=np.uint32)

def array2string(array):
    """Build a string from a sequence of character code points.

    Inverse of ``string2array``.

    Parameters
    ----------
    array : iterable of int
        Code points (Python or numpy integers).

    Returns
    -------
    str
        The concatenated characters.

    Raises
    ------
    ValueError
        If a value is not a valid Unicode code point.
    """
    # `chr` is the built-in; the original called an undefined `char`
    # (NameError). int() turns numpy integer scalars into plain ints.
    return "".join(chr(int(code)) for code in array)

def hex2array(hex_):
    """Parse a hexadecimal string into a uint32 array of byte values.

    Two hex digits form one byte. ASCII whitespace between bytes is ignored,
    so the space-separated output of ``array2hex`` parses back correctly.

    Parameters
    ----------
    hex_ : str or bytes
        Hex digits, upper or lower case.

    Returns
    -------
    numpy.ndarray
        1-D ``uint32`` array with one element per byte.

    Raises
    ------
    ValueError
        On an odd number of digits or a non-hex character.
    """
    if isinstance(hex_, (bytes, bytearray)):
        hex_ = hex_.decode("ascii")
    # bytes.fromhex validates every pair. The old fixed-width slicing silently
    # parsed chunks like " 2" (int() strips whitespace) and a lone trailing nibble.
    return np.asarray(list(bytes.fromhex(hex_)), dtype=np.uint32)

def array2hex(array):
    """Render byte values as space-separated, upper-case, two-digit hex.

    Example: ``[0x39, 0x25, 0x0B]`` -> ``"39 25 0B"``.
    """
    return " ".join("%02X" % byte for byte in array)

In [293]:
# 128 bit key test
# Known-answer vector: the cipher example of FIPS-197 Appendix B.
plain = hex2array("3243F6A8885A308D313198A2E0370734")
correct_cipher = hex2array("3925841D02DC09FBDC118597196A0B32")
key = hex2array("2B7E151628AED2A6ABF7158809CF4F3C")

# Encryption must reproduce the published ciphertext ...
cipher = AES.encrypt(plain, key)
cipher_hex = array2hex(cipher)
print(cipher_hex)
assert cipher_hex == array2hex(correct_cipher)

# ... and decryption must give back the plaintext.
decrypted = AES.decrypt(cipher, key)
decrypted_hex = array2hex(decrypted)
print(decrypted_hex)
assert array2hex(plain) == decrypted_hex

39 25 84 1D 02 DC 09 FB DC 11 85 97 19 6A 0B 32
32 43 F6 A8 88 5A 30 8D 31 31 98 A2 E0 37 07 34


In [294]:
# 192 bit key test
# Values match the NIST SP 800-38A ECB-AES192 example (block #1).
plain = hex2array("6BC1BEE22E409F96E93D7E117393172A")
correct_cipher = hex2array("BD334F1D6E45F25FF712A214571FA5CC")
key = hex2array("8E73B0F7DA0E6452C810F32B809079E562F8EAD2522C6B7B")

# Encryption must reproduce the published ciphertext ...
cipher = AES.encrypt(plain, key)
cipher_hex = array2hex(cipher)
print(cipher_hex)
assert cipher_hex == array2hex(correct_cipher)

# ... and decryption must give back the plaintext.
decrypted = AES.decrypt(cipher, key)
decrypted_hex = array2hex(decrypted)
print(decrypted_hex)
assert array2hex(plain) == decrypted_hex

BD 33 4F 1D 6E 45 F2 5F F7 12 A2 14 57 1F A5 CC
6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A


In [295]:
# 256 bit key test
# Values match the NIST SP 800-38A ECB-AES256 example (block #1).
plain = hex2array("6BC1BEE22E409F96E93D7E117393172A")
correct_cipher = hex2array("F3EED1BDB5D2A03C064B5A7E3DB181F8")
key = hex2array("603DEB1015CA71BE2B73AEF0857D77811F352C073B6108D72D9810A30914DFF4")

# Encryption must reproduce the published ciphertext ...
cipher = AES.encrypt(plain, key)
cipher_hex = array2hex(cipher)
print(cipher_hex)
assert cipher_hex == array2hex(correct_cipher)

# ... and decryption must give back the plaintext.
decrypted = AES.decrypt(cipher, key)
decrypted_hex = array2hex(decrypted)
print(decrypted_hex)
assert array2hex(plain) == decrypted_hex

F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8
6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A


In [298]:
# 256 bit key test multiple blocks
# AES.encrypt handles every 16-byte block independently (ECB), so eight copies
# of the single-block vector must yield eight copies of its ciphertext.
plain = hex2array("6BC1BEE22E409F96E93D7E117393172A" * 8)
correct_cipher = hex2array("F3EED1BDB5D2A03C064B5A7E3DB181F8" * 8)
key = hex2array("603DEB1015CA71BE2B73AEF0857D77811F352C073B6108D72D9810A30914DFF4")

# Batched encryption must match the repeated published ciphertext ...
cipher = AES.encrypt(plain, key)
cipher_hex = array2hex(cipher)
print(cipher_hex)
assert cipher_hex == array2hex(correct_cipher)

# ... and batched decryption must give back every plaintext block.
decrypted = AES.decrypt(cipher, key)
decrypted_hex = array2hex(decrypted)
print(decrypted_hex)
assert array2hex(plain) == decrypted_hex

F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8 F3 EE D1 BD B5 D2 A0 3C 06 4B 5A 7E 3D B1 81 F8
6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A 6B C1 BE E2 2E 40 9F 96 E9 3D 7E 11 73 93 17 2A
