# Advanced Encryption Standard (AES) Python Implementation
U.S. Department of Commerce. (2001). FIPS PUB 197: Advanced Encryption Standard (AES). Available at: https://csrc.nist.gov/publications/detail/fips/197/final [Accessed 8 February 2020].

## Overview
A. Base algorithm
  1. Round Key Generation
  2. Cipher - Encryption:    i. SubBytes,     ii. ShiftRows,     iii. MixColumns,    iv. xor round key
  3. Cipher - Decryption

### Import and Script Operation Variables

In [1]:
import sys

In [2]:
# Debug verbosity (0-3) read throughout this notebook:
#   0 = no debug messages
#   1 = cell-level summaries (plaintext, hex blocks, ciphertext, decrypted text)
#   2 = per-round state and key-schedule traces
#   3 = detailed per-byte / per-cell traces
debug = 1

### Static Variables

In [3]:
# AES S-box (FIPS-197 Figure 7), stored as 2-character hex strings.
# The substitute for byte 0xXY is aes_sbox[X][Y]: the high nibble selects the
# row and the low nibble the column, which is how apply_permutation indexes it.
aes_sbox     = [['63','7c','77','7b','f2','6b','6f','c5','30','01','67','2b','fe','d7','ab','76'],
                ['ca','82','c9','7d','fa','59','47','f0','ad','d4','a2','af','9c','a4','72','c0'],
                ['b7','fd','93','26','36','3f','f7','cc','34','a5','e5','f1','71','d8','31','15'],
                ['04','c7','23','c3','18','96','05','9a','07','12','80','e2','eb','27','b2','75'],
                ['09','83','2c','1a','1b','6e','5a','a0','52','3b','d6','b3','29','e3','2f','84'],
                ['53','d1','00','ed','20','fc','b1','5b','6a','cb','be','39','4a','4c','58','cf'],
                ['d0','ef','aa','fb','43','4d','33','85','45','f9','02','7f','50','3c','9f','a8'],
                ['51','a3','40','8f','92','9d','38','f5','bc','b6','da','21','10','ff','f3','d2'],
                ['cd','0c','13','ec','5f','97','44','17','c4','a7','7e','3d','64','5d','19','73'],
                ['60','81','4f','dc','22','2a','90','88','46','ee','b8','14','de','5e','0b','db'],
                ['e0','32','3a','0a','49','06','24','5c','c2','d3','ac','62','91','95','e4','79'],
                ['e7','c8','37','6d','8d','d5','4e','a9','6c','56','f4','ea','65','7a','ae','08'],
                ['ba','78','25','2e','1c','a6','b4','c6','e8','dd','74','1f','4b','bd','8b','8a'],
                ['70','3e','b5','66','48','03','f6','0e','61','35','57','b9','86','c1','1d','9e'],
                ['e1','f8','98','11','69','d9','8e','94','9b','1e','87','e9','ce','55','28','df'],
                ['8c','a1','89','0d','bf','e6','42','68','41','99','2d','0f','b0','54','bb','16']]

# Inverse S-box (FIPS-197 Figure 14), indexed the same way as aes_sbox;
# decryption passes it to SubBytes to undo the forward substitution.
inv_aes_sbox = [['52','09','6a','d5','30','36','a5','38','bf','40','a3','9e','81','f3','d7','fb'],
                ['7c','e3','39','82','9b','2f','ff','87','34','8e','43','44','c4','de','e9','cb'],
                ['54','7b','94','32','a6','c2','23','3d','ee','4c','95','0b','42','fa','c3','4e'],
                ['08','2e','a1','66','28','d9','24','b2','76','5b','a2','49','6d','8b','d1','25'],
                ['72','f8','f6','64','86','68','98','16','d4','a4','5c','cc','5d','65','b6','92'],
                ['6c','70','48','50','fd','ed','b9','da','5e','15','46','57','a7','8d','9d','84'],
                ['90','d8','ab','00','8c','bc','d3','0a','f7','e4','58','05','b8','b3','45','06'],
                ['d0','2c','1e','8f','ca','3f','0f','02','c1','af','bd','03','01','13','8a','6b'],
                ['3a','91','11','41','4f','67','dc','ea','97','f2','cf','ce','f0','b4','e6','73'],
                ['96','ac','74','22','e7','ad','35','85','e2','f9','37','e8','1c','75','df','6e'],
                ['47','f1','1a','71','1d','29','c5','89','6f','b7','62','0e','aa','18','be','1b'],
                ['fc','56','3e','4b','c6','d2','79','20','9a','db','c0','fe','78','cd','5a','f4'],
                ['1f','dd','a8','33','88','07','c7','31','b1','12','10','59','27','80','ec','5f'],
                ['60','51','7f','a9','19','b5','4a','0d','2d','e5','7a','9f','93','c9','9c','ef'],
                ['a0','e0','3b','4d','ae','2a','f5','b0','c8','eb','bb','3c','83','53','99','61'],
                ['17','2b','04','7e','ba','77','d6','26','e1','69','14','63','55','21','0c','7d']]

# MixColumns coefficient matrix (FIPS-197 §5.1.3). Row j holds the GF(2^8)
# multipliers for output byte j of each state column; MixColumns only knows
# how to multiply by the values '01', '02' and '03' used here.
mix_matrix       = [['02', '03', '01', '01'],
                    ['01', '02', '03', '01'],
                    ['01', '01', '02', '03'],
                    ['03', '01', '01', '02']]

# InvMixColumns coefficient matrix (FIPS-197 §5.3.3), the inverse of
# mix_matrix; InvMixColumns handles the multipliers '0e', '0b', '0d', '09'.
inv_mix_matrix   = [['0e', '0b', '0d', '09'],
                    ['09', '0e', '0b', '0d'],
                    ['0d', '09', '0e', '0b'],
                    ['0b', '0d', '09', '0e']]

### Functions (from DES)

In [4]:
# Custom print formatting
def cPrint(str, format):
    """Print text, optionally wrapped in an ANSI terminal style.

    format "underline" or "bold" wraps the text in the matching escape code
    followed by a reset; any other value prints the text unchanged.
    (Parameter names shadow builtins but are kept for caller compatibility.)
    """
    styles = (("underline", "\033[4m"), ("bold", "\033[1m"))
    for style_name, escape in styles:
        if format == style_name:
            print(escape + str + "\033[0m")
            return
    print(str)

def split_into_blocks(input_str, block_length):
    """Cut a sequence into consecutive chunks of block_length items.

    The final chunk is shorter when the length is not an exact multiple.
    eg. split_into_blocks('abc123', 3) returns ['abc', '123']
    """
    chunk_starts = range(0, len(input_str), block_length)
    return [input_str[start:start + block_length] for start in chunk_starts]

    
# Binary, Hex, ASCII Conversion
def convert_hex_to_binary(hexdigits):
    """Expand a hex string into a binary string, four bits per hex digit.

    Leading zeros of every nibble are kept, so no information is lost.
    eg. convert_hex_to_binary('1F') returns '00011111'
    """
    return ''.join(format(int(digit, 16), '04b') for digit in hexdigits)

def convert_binary_to_hex(binarydigits):
    """Convert a binary string to an upper-case hex string.

    The result is zero-padded to ceil(len(binarydigits) / 4) digits, so leading
    zero nibbles survive a hex -> binary -> hex round trip.
    eg. convert_binary_to_hex('11111') returns '1F'
    """
    digit_count = -(-len(binarydigits) // 4)  # ceiling division
    value = int(binarydigits, 2)
    return '{:0{}X}'.format(value, digit_count)

def convert_ascii_to_hex(ascii_input):
    """Encode each character as its code point in lower-case hex.

    ASCII characters always give exactly two hex digits.
    eg. convert_ascii_to_hex('abc') returns '616263'
    """
    return ''.join('%02x' % ord(ch) for ch in ascii_input)

def convert_hex_to_ascii(hex_input):
    """Decode a hex string two digits at a time into characters.

    A trailing unpaired hex digit is ignored.
    eg. convert_hex_to_ascii('616263') returns 'abc'
    """
    pairs = (hex_input[pos:pos + 2] for pos in range(0, len(hex_input) - 1, 2))
    return ''.join(chr(int(pair, 16)) for pair in pairs)


def xor(x,y):
    # Function performs manual bitwise exclusive or ie. x ^ y where x and y are binary
    # Example: xor('110001001100011','110010001100101') returns '000011000000110'
    
    result = ""
    for i in range(len(x)):
        if x[i] == y[i]: result += '0'
        else: result += '1'
    return result


def xor_hex(x,y):
    """Bitwise exclusive-or of two hex strings.

    Both operands are expanded to binary, XORed with the shared ``xor`` helper
    and converted back, so the result is upper-case and keeps the same number
    of hex digits as ``x`` (leading zeros preserved). ``y`` must be at least as
    long as ``x``; any extra trailing digits of ``y`` are ignored.

    Example: xor_hex('6263','6465') returns '0606'
    """
    binary_x = convert_hex_to_binary(x)
    binary_y = convert_hex_to_binary(y)
    return convert_binary_to_hex(xor(binary_x, binary_y))


def xor_hex_grid(grid1, grid2):
    """XOR two grids of hex strings cell by cell (AddRoundKey on the AES state).

    Parameters
    ----------
    grid1 : list[list[str]]
        First grid; its shape determines the shape of the result.
    grid2 : list[list[str]]
        Second grid, at least as large as grid1 in both dimensions.

    Returns
    -------
    list[list[str]]
        New grid where cell [i][j] is xor_hex(grid1[i][j], grid2[i][j]).
    """
    return [
        [xor_hex(cell, grid2[i][j]) for j, cell in enumerate(outer)]
        for i, outer in enumerate(grid1)
    ]


def apply_permutation(permuted_choice_table,ini_input):
    """Substitute every byte of a hex string through a 16x16 lookup table.

    ini_input is consumed two hex characters (one byte) at a time: the first
    character selects the table row and the second the column, and the table
    entries are concatenated. Despite the name (inherited from the DES code)
    this is an S-box substitution, used with aes_sbox / inv_aes_sbox.

    eg. apply_permutation(aes_sbox, '00ff') returns '6316'
    """
    if debug >= 3: print("[DEBUG]\tStarting function apply_permutation\n\tInput:",ini_input)

    permutated_out = ""
    for byte_hex in split_into_blocks(ini_input, 2):
        table_row = int(byte_hex[:1], 16)   # high nibble -> row
        table_col = int(byte_hex[1:], 16)   # low nibble -> column
        looked_up = permuted_choice_table[table_row][table_col]
        permutated_out += looked_up

        if debug >= 3:
            print("\tLookup hex:",byte_hex, "Returned value:",looked_up, "( Lookup Ref (Row x Column)",table_row,"x",table_col,")")

    if debug >= 3:
        print("[DEBUG]\tPermutated output :",permutated_out)

    return permutated_out

def list_to_string(list):
    """Concatenate a grid (a list of lists of strings) into one string.

    Cells are joined inner-list by inner-list with no separator, e.g. the AES
    state [['32', '43'], ['f6', 'a8']] becomes '3243f6a8'.
    The parameter name shadows the builtin ``list``; it is kept for caller
    compatibility, so only str methods are used here.
    """
    return ''.join(''.join(inner) for inner in list)

## Key Expansion (Generating Round Key )

### Functions for Generating Round Key 

In [5]:
# Function to rotate a word: circular left shift by one byte (two hex characters)
def RotWord(input_bits): # Improvement; Could use left shift function from DES
    """Cyclically rotate a word left by one byte: [a0,a1,a2,a3] -> [a1,a2,a3,a0].

    input_bits is a hex string, so one byte is the first two characters, which
    are moved to the end.
    """
    if debug >= 3:
        print("[DEBUG]\tStarting function RotWord")

    leading_byte, remainder = input_bits[:2], input_bits[2:]
    rotated_bits = remainder + leading_byte

    if debug >= 3:
        print("\tInput:",input_bits,"Rotated output:",rotated_bits)

    return rotated_bits


# Wrapper for apply_permutation
def SubWord(input_hex):
    """Apply the forward AES S-box to every byte of a word (FIPS-197 SubWord()).

    Thin wrapper around apply_permutation with the aes_sbox table.
    """
    if debug >= 3:
        print("[DEBUG]\tStarting function SubWord")
    substituted = apply_permutation(aes_sbox, input_hex)
    return substituted


def gen_rcon(round):
    """Return the round-constant word Rcon[round] as an 8-character hex string.

    The constant byte starts at 0x8d for round 0 (so round 1 yields 0x01) and
    is doubled in GF(2^8) once per round, reducing by the AES polynomial 0x11b
    whenever the doubling overflows a byte: 01, 02, 04, ..., 80, 1b, 36, ...
    The word is that byte followed by three zero bytes, e.g. '01000000'.
    See https://crypto.stackexchange.com/questions/2418/how-to-use-rcon-in-key-expansion-of-128-bit-advanced-encryption-standard
    """
    if debug >= 3:
        print("[DEBUG]\tStarting function gen_rcon\n\tround:",round)

    rc = 0x8d
    for _ in range(round):
        rc <<= 1
        if rc & 0x100:       # overflowed past 8 bits: reduce modulo 0x11b
            rc ^= 0x11b
    return '%02x' % rc + '000000'


### Generate Round Keys

## Cipher

In [6]:
def generate_round_keys(k, Nk, Nb, Nr):
    """Expand the cipher key into the AES key schedule (FIPS-197 §5.2 KeyExpansion).

    Produces (Nr + 1) * Nb 32-bit words w[0..] and groups them Nb at a time
    into round keys.

    Parameters
    ----------
    k : str
        Cipher key as hex, exactly 8 * Nk characters (Nk words of 8 hex chars).
    Nk : int
        Key length in 32-bit words (4 = AES-128, 6 = AES-192, 8 = AES-256).
    Nb : int
        Block length in 32-bit words (4 for AES).
    Nr : int
        Number of rounds (10, 12 or 14 respectively).

    Returns
    -------
    list[list[str]]
        The round keys, each a list of Nb words (8-character hex strings);
        for the standard AES parameters there are Nr + 1 of them.

    Raises
    ------
    ValueError
        If k is not exactly 8 * Nk hex characters long. Otherwise a key of the
        wrong length would be split into the wrong number of words and the
        schedule would come out wrong (or fail with an obscure IndexError).
    """
    if len(k) != 8 * Nk:
        raise ValueError("Key must be %d hex characters for Nk=%d, got %d" % (8 * Nk, Nk, len(k)))

    if debug >= 2: print("\n[DEBUG]\tStarting function generate_round_keys with variables:\n\tk:",k,"Nk:",Nk,"Nb:",Nb,"Nr:",Nr)

    # w[0..Nk-1]: the cipher key split into 32-bit words (8 hex characters)
    split_key = split_into_blocks(k, 8)
    if debug >= 2: print("[DEBUG]\tSplit key Array:", split_key)

    # "temp" in FIPS-197: always w[i-1]; seeded with the last key word
    state = split_key[-1]

    # The first round keys come straight from the cipher key, Nb words each
    gen_key = [split_key[j:j + Nb] for j in range(0, len(split_key), Nb)]
    key_num = len(gen_key) - 1

    if debug >= 2: print("[DEBUG]\tStarting loop range",(Nr+1)*Nb)

    for i in range(Nk, (Nr + 1) * Nb):
        if i % Nb == 0:
            # Every Nb words begins a new round key
            key_num += 1
            if debug >= 2: print("[DEBUG]\tStarting new key", key_num,"on iteration:",i,"with new Rcon:",gen_rcon(i // Nk))
            gen_key.append([])

        if i % Nk == 0:
            # temp = SubWord(RotWord(temp)) xor Rcon[i/Nk]
            state = xor_hex(SubWord(RotWord(state)), gen_rcon(i // Nk))
            if debug >= 2: print("[DEBUG]\tAfter XOR with Rcon\t", state)
        elif Nk > 6 and i % Nk == 4:
            # AES-256 only: extra SubWord half-way through each Nk-word group
            state = SubWord(state)
            if debug >= 2: print("After SubWord\t\t", state)

        if debug >= 2: print("[DEBUG]\tw[i–Nk]\t\t\t", split_key[i-Nk])

        # w[i] = temp xor w[i-Nk]
        next_key_section = xor_hex(state, split_key[i - Nk])
        if debug >= 2: print("[DEBUG]\tw[i]= temp XOR w[i-Nk]\t", next_key_section)

        split_key.append(next_key_section)
        gen_key[key_num].append(next_key_section)

        # The new word is temp for the next iteration
        state = next_key_section

    return gen_key


### Functions for Cipher

In [7]:
def SubBytes(sbox, round_grid):
    """Substitute every byte of the state through ``sbox`` (SubBytes / InvSubBytes).

    Parameters
    ----------
    sbox : list[list[str]]
        16x16 lookup table: aes_sbox to encrypt, inv_aes_sbox to decrypt.
    round_grid : list[list[str]]
        State grid; each cell is one byte as two hex characters.

    Returns
    -------
    list[list[str]]
        New grid of the same shape with every cell substituted.

    eg. SubBytes(aes_sbox, round_grid)
    """
    if debug >= 3: print("[DEBUG]\tStarting function SubBytes")

    round_grid_subbed = []
    for line in round_grid:
        subbed_line = []
        for cell in line:
            # Look the byte up once and reuse it for the trace; a second lookup
            # would repeat apply_permutation's own debug output and its work.
            subbed = apply_permutation(sbox, cell)
            subbed_line.append(subbed)
            if debug >= 3: print("\tProcessing bit:", cell, "SubBytes out:", subbed)
        round_grid_subbed.append(subbed_line)

    return round_grid_subbed


def ShiftRows(round_grid_subbed, direction):
    """Cyclically shift each row of the 4x4 AES state by its row index.

    The state is stored column-major: round_grid_subbed[c][r] is the byte
    (two hex characters) in column c, row r. Row r is rotated left by r
    positions, or right by r positions when direction == 'right' (the inverse
    used by decryption). Any other direction value shifts left.

    eg. ShiftRows(state, 'left')
    """
    # View the state row-wise so each row is a contiguous list
    rows = [list(r) for r in zip(*round_grid_subbed)]

    shifted_rows = []
    for row_index, cells in enumerate(rows):
        offset = -row_index if direction == 'right' else row_index
        shifted_rows.append(cells[offset:] + cells[:offset])

    # Back to the column-major layout
    round_grid_shifted = [list(c) for c in zip(*shifted_rows)]

    if debug >= 3:
        print("[DEBUG]\tround_grid_transposed:\n", rows)
        print("[DEBUG]\tround_grid_shifted:\n", round_grid_shifted)

    return round_grid_shifted


def mx2(hex_input):
    """Multiply one byte by {02} in GF(2^8) (AES ``xtime``, FIPS-197 §4.2.1).

    The byte is shifted left one bit; if its top bit was set, the result is
    reduced by XOR with 0x1B (the low byte of the AES polynomial 0x11B).
    Deduced from: https://crypto.stackexchange.com/questions/2402/how-to-solve-mixcolumns/2403

    Parameters
    ----------
    hex_input : str
        Exactly one byte as two hex characters (either case), e.g. "d4".

    Returns
    -------
    str
        The product as two upper-case hex characters, e.g. mx2("d4") == "B3".

    Raises
    ------
    ValueError
        If hex_input is not exactly two hex characters (other lengths used to
        produce a silently wrong result).
    """
    if len(hex_input) != 2 or any(ch not in "0123456789abcdefABCDEF" for ch in hex_input):
        raise ValueError("mx2 expects one byte as 2 hex characters, got %r" % (hex_input,))

    doubled = int(hex_input, 16) << 1
    if doubled & 0x100:        # top bit shifted out of the byte
        doubled ^= 0x11B       # clears bit 8 and XORs 0x1B into the low byte
    return "%02X" % doubled


def MixColumns(round_grid_shifted,mix_matrix):
    """Apply the AES MixColumns transformation (FIPS-197 §5.1.3).

    Output byte j of column i is the XOR, over ci, of
    mix_matrix[j][ci] * column[ci], with multiplication in GF(2^8).

    Parameters
    ----------
    round_grid_shifted : list[list[str]]
        State as a list of columns, each a list of 2-hex-character bytes.
    mix_matrix : list[list[str]]
        Coefficient matrix whose entries are '01', '02' or '03'.

    Returns
    -------
    list[list[str]]
        New state in the same column layout.

    Raises
    ------
    ValueError
        If mix_matrix contains a coefficient other than '01', '02' or '03'
        (previously such a term was silently dropped from the sum).
    """
    if debug >= 3: print("[DEBUG]\tStarting function MixColumns")

    round_grid_mixed = []
    for column in round_grid_shifted:
        mixed_column = []
        for j in range(len(column)):
            mix_coef = mix_matrix[j]
            temp_mix_stack = None
            if debug >= 3: print("\tCell Input:", column[j])

            for ci, coef in enumerate(mix_coef):
                if coef == '01':
                    ci_calc_out = column[ci]
                elif coef == '02':
                    ci_calc_out = mx2(column[ci])
                elif coef == '03':
                    # x*3 = (x*2) xor x
                    ci_calc_out = xor_hex(mx2(column[ci]), column[ci])
                else:
                    raise ValueError("Unsupported MixColumns coefficient %r (expected '01', '02' or '03')" % (coef,))

                if debug >= 3: print("\t\t", column[ci],"x", mix_coef[ci],"=", ci_calc_out)

                # Accumulate the GF(2^8) sum (xor) of the products
                temp_mix_stack = ci_calc_out if temp_mix_stack is None else xor_hex(temp_mix_stack, ci_calc_out)

            if debug >= 3: print("\tCell Output=", temp_mix_stack)
            mixed_column.append(temp_mix_stack)
        round_grid_mixed.append(mixed_column)

    return round_grid_mixed


def InvMixColumns(round_grid_shifted,mix_matrix):
    """Apply the AES InvMixColumns transformation (FIPS-197 §5.3.3).

    Same column-by-matrix product as MixColumns but with the inverse
    coefficients {09, 0b, 0d, 0e}, each built from repeated doubling (mx2)
    and XOR. See
    https://crypto.stackexchange.com/questions/2569/how-does-one-implement-the-inverse-of-aes-mixcolumns

    Parameters
    ----------
    round_grid_shifted : list[list[str]]
        State as a list of columns, each a list of 2-hex-character bytes.
    mix_matrix : list[list[str]]
        Coefficient matrix whose entries are '09', '0b', '0d' or '0e'
        (lower-case, as in inv_mix_matrix).

    Returns
    -------
    list[list[str]]
        New state in the same column layout.

    Raises
    ------
    ValueError
        If mix_matrix contains any other coefficient (previously such a term
        was silently dropped from the sum).
    """
    if debug >= 3: print("[DEBUG]\tCalculating Mix Columns")

    round_grid_mixed = []
    for column in round_grid_shifted:
        mixed_column = []
        for j in range(len(column)):
            mix_coef = mix_matrix[j]
            temp_mix_stack = None
            if debug >= 3: print("\n[DEBUG]\tCell:", column[j])

            for ci, coef in enumerate(mix_coef):
                x = column[ci]
                if coef == '09':
                    # x*9 = (((x*2)*2)*2) + x   (where + is xor)
                    ci_calc_out = xor_hex(mx2(mx2(mx2(x))), x)
                elif coef == '0b':
                    # x*11 = ((((x*2)*2)+x)*2)+x
                    ci_calc_out = xor_hex(mx2(xor_hex(mx2(mx2(x)), x)), x)
                elif coef == '0d':
                    # x*13 = ((((x*2)+x)*2)*2)+x
                    ci_calc_out = xor_hex(mx2(mx2(xor_hex(mx2(x), x))), x)
                elif coef == '0e':
                    # x*14 = ((((x*2)+x)*2)+x)*2
                    ci_calc_out = mx2(xor_hex(mx2(xor_hex(mx2(x), x)), x))
                else:
                    raise ValueError("Unsupported InvMixColumns coefficient %r (expected '09', '0b', '0d' or '0e')" % (coef,))

                if debug >= 3: print("[DEBUG]\tVal:", column[ci],"\tCoef:", mix_coef[ci],"\tOut:", ci_calc_out)

                # Accumulate the GF(2^8) sum (xor) of the products
                temp_mix_stack = ci_calc_out if temp_mix_stack is None else xor_hex(temp_mix_stack, ci_calc_out)

            if debug >= 3: print("[DEBUG]\tMixedCell=",temp_mix_stack)
            mixed_column.append(temp_mix_stack)
        round_grid_mixed.append(mixed_column)

    return round_grid_mixed


### Encryption

In [8]:
def aes_encrypt(k,m, bit_mode):
    """Encrypt one 128-bit block with AES (FIPS-197 Cipher()).

    Parameters
    ----------
    k : str
        Cipher key as hex: 32, 48 or 64 characters for AES-128/192/256.
    m : str
        One plaintext block as exactly 32 hex characters (16 bytes).
    bit_mode : int
        Key size in bits: 128, 192 or 256.

    Returns
    -------
    list[list[str]]
        Ciphertext state as 4 columns of 4 bytes (2 hex characters each);
        list_to_string() flattens it to the 32-character hex ciphertext.

    Raises
    ------
    ValueError
        If bit_mode is unsupported, or k / m have the wrong length.

    eg. aes_encrypt('2b7e151628aed2a6abf7158809cf4f3c','3243f6a8885a308d313198a2e0370734', 128)
    """
    # (Nk, Nb, Nr) per FIPS-197 Figure 4: key words, block words, rounds
    aes_params = {128: (4, 4, 10), 192: (6, 4, 12), 256: (8, 4, 14)}

    # Raise rather than sys.exit() so callers (and notebook cells) can handle it
    if bit_mode not in aes_params:
        raise ValueError("[ERROR] Bit Mode is not supported. Use '128', '192' or '256'")
    if len(k) != bit_mode // 4:
        raise ValueError("Key must be %d hex characters for AES-%d, got %d" % (bit_mode // 4, bit_mode, len(k)))
    if len(m) != 32:
        raise ValueError("Message block must be exactly 32 hex characters (128 bits), got %d" % len(m))

    Nk, Nb, Nr = aes_params[bit_mode]

    # A. Prepare round keys
    ## A1. Key expansion: Nr+1 round keys of Nb words (8 hex characters each)
    round_keys = generate_round_keys(k, Nk, Nb, Nr)
    if debug >= 2:
        print("\n[DEBUG]\tround_keys:")
        print(round_keys)

    ## A2. Split every word into bytes (2 hex characters) to match the state grid
    round_keys_grid = [[split_into_blocks(word, 2) for word in round_key] for round_key in round_keys]
    if debug >= 2:
        print("\n[DEBUG]\t round_keys_grid:")
        print(round_keys_grid)

    # B. Prepare the state
    ## B1. Split the 128-bit message into 4 words (8 hex characters = 32 bits each)
    split_m = split_into_blocks(m,8)
    if debug >= 2:
        print("\n[DEBUG]\t split_m:")
        print(split_m)

    ## B2. Each word becomes a state column of 4 bytes (2 hex characters each)
    round_grid = [split_into_blocks(word, 2) for word in split_m]
    if debug >= 2: print("\n[DEBUG]\tround_grid:\t",round_grid)

    # C. Iterations
    #   round 0:          xor round key only
    #   rounds 1..Nr-1:   SubBytes, ShiftRows, MixColumns, xor round key
    #   round Nr:         SubBytes, ShiftRows, xor round key (no MixColumns)
    for i in range(Nr+1):
        if debug >= 2: print("\n\n[DEBUG]\tRound:",i)

        state = round_grid
        if i > 0:
            # 1. SubBytes
            state = SubBytes(aes_sbox, state)
            if debug >= 2: print("[DEBUG]\tsub_grid:\t",state)

            # 2. ShiftRows
            state = ShiftRows(state, 'left')
            if debug >= 2: print("[DEBUG]\tshift_grid:\t",state)

            if i != Nr:
                # 3. MixColumns (skipped in the final round)
                state = MixColumns(state,mix_matrix)
                if debug >= 2: print("[DEBUG]\tmix_grid:\t",state)

        # 4. xor round key i
        round_grid = xor_hex_grid(state, round_keys_grid[i])
        if debug >= 2:
            print("[DEBUG]\tround_key:\t",round_keys_grid[i])
            print("[DEBUG]\txor:\t\t",round_grid)

    return round_grid

### Decryption

In [9]:
def aes_decrypt(k,c, bit_mode):
    """Decrypt one 128-bit block with AES (inverse of aes_encrypt).

    Parameters
    ----------
    k : str
        Cipher key as hex: 32, 48 or 64 characters for AES-128/192/256.
    c : str
        One ciphertext block as exactly 32 hex characters (16 bytes).
    bit_mode : int
        Key size in bits: 128, 192 or 256.

    Returns
    -------
    list[list[str]]
        Plaintext state as 4 columns of 4 bytes (2 hex characters each);
        list_to_string() flattens it to the 32-character hex plaintext.

    Raises
    ------
    ValueError
        If bit_mode is unsupported, or k / c have the wrong length.

    eg. aes_decrypt('2b7e151628aed2a6abf7158809cf4f3c','3925841D02DC09FBDC118597196A0B32', 128)
    """
    # (Nk, Nb, Nr) per FIPS-197 Figure 4: key words, block words, rounds
    aes_params = {128: (4, 4, 10), 192: (6, 4, 12), 256: (8, 4, 14)}

    # Raise rather than sys.exit() so callers (and notebook cells) can handle it
    if bit_mode not in aes_params:
        raise ValueError("[ERROR] Bit Mode is not supported. Use '128', '192' or '256'")
    if len(k) != bit_mode // 4:
        raise ValueError("Key must be %d hex characters for AES-%d, got %d" % (bit_mode // 4, bit_mode, len(k)))
    if len(c) != 32:
        raise ValueError("Cipher block must be exactly 32 hex characters (128 bits), got %d" % len(c))

    Nk, Nb, Nr = aes_params[bit_mode]

    # A. Prepare round keys
    ## A1. Key expansion: Nr+1 round keys of Nb words (8 hex characters each)
    round_keys = generate_round_keys(k, Nk, Nb, Nr)
    if debug >= 2: print("[DEBUG]\t round_keys:",round_keys)

    ## A2. Split every word into bytes (2 hex characters) to match the state grid
    round_keys_grid = [[split_into_blocks(word, 2) for word in round_key] for round_key in round_keys]
    if debug >= 2: print("\n[DEBUG]\t round_keys_grid:\t",round_keys_grid)

    # B. Prepare the state
    ## B1. Split the 128-bit cipher block into 4 words (8 hex characters = 32 bits each)
    split_c = split_into_blocks(c,8)
    if debug >= 2: print("[DEBUG]\t round_grid:\t",split_c)

    ## B2. Each word becomes a state column of 4 bytes (2 hex characters each)
    round_grid = [split_into_blocks(word, 2) for word in split_c]
    if debug >= 2: print("[DEBUG]\t round_grid:\t",round_grid)

    # C. Iterations, undoing aes_encrypt's rounds in reverse order
    #   round Nr:         xor round key, inverse ShiftRows, inverse SubBytes
    #   rounds Nr-1..1:   xor round key, InvMixColumns, inverse ShiftRows, inverse SubBytes
    #   round 0:          xor round key only
    for i in reversed(range(Nr+1)):
        if debug >= 2: print("\n\n[DEBUG]\t Round:",i)

        # 1. xor round key i
        state = xor_hex_grid(round_grid, round_keys_grid[i])
        if debug >= 2:
            print("[DEBUG]\t round_key:\t",round_keys_grid[i])
            print("[DEBUG]\t xor:\t\t",state)

        if i > 0:
            if i != Nr:
                # 2. Inverse MixColumns (not applied after the final round key)
                state = InvMixColumns(state,inv_mix_matrix)
                if debug >= 2: print("[DEBUG]\t mix_grid:\t",state)

            # 3. Inverse ShiftRows (shift right)
            state = ShiftRows(state, 'right')
            if debug >= 2: print("[DEBUG]\t shift_grid:\t",state)

            # 4. Inverse SubBytes
            state = SubBytes(inv_aes_sbox, state)
            if debug >= 2: print("[DEBUG]\t sub_grid:\t",state)

        round_grid = state

    return round_grid


# Modes of Operation

## Padding Functions

In [10]:
def apply_padding_s1(hex_m, block_size=32):
    """Pad a hex message with padding scheme 1 (ISO/IEC 9797-1 method 2).

    A single 0x80 byte is appended, then as many 0x00 bytes as needed to reach
    a multiple of ``block_size`` hex characters (32 = one 16-byte AES block).
    Padding is always added: a message that already fills whole blocks gets
    one extra block ``80 00 .. 00``, so remove_padding_s1 can always find the
    marker unambiguously.

    Parameters
    ----------
    hex_m : str
        Message as hex; must contain whole bytes (an even number of characters).
    block_size : int, optional
        Block length in hex characters (default 32, i.e. 128 bits).

    Returns
    -------
    str
        The padded hex string, a positive multiple of ``block_size`` long.

    Raises
    ------
    ValueError
        If hex_m has an odd number of hex characters.
    """
    if len(hex_m) % 2:
        raise ValueError("hex_m must contain whole bytes (an even number of hex characters)")

    # Space left in the last block; a whole block when hex_m is already aligned
    padding_len = block_size - (len(hex_m) % block_size)
    return hex_m + "80" + "0" * (padding_len - 2)  # 80 = hex for 128

def remove_padding_s1(hex_m_padded):
    """Strip padding scheme 1 (a 0x80 byte followed by zero or more 0x00 bytes).

    Working back from the end one byte (two hex characters) at a time, trailing
    ``00`` bytes are skipped; if the byte reached is ``80``, it and everything
    after it are removed. Because the scan is byte-aligned and only crosses
    zero bytes, a 0x80 byte elsewhere in the message, or an "80" straddling two
    bytes, is never mistaken for the marker. If no marker is found, the input
    is returned unchanged.

    Parameters
    ----------
    hex_m_padded : str
        Padded message as hex (an even number of characters).

    Returns
    -------
    str
        The message hex with the padding removed.
    """
    end = len(hex_m_padded)
    while end >= 2 and hex_m_padded[end - 2:end] == "00":
        end -= 2
    if end >= 2 and hex_m_padded[end - 2:end] == "80":
        return hex_m_padded[:end - 2]
    return hex_m_padded

## Cipher Block Chain (CBC) Mode

### Encryption

In [11]:
# Input m (ascii) & k (hex)

# 128 bit test
# m = '3243f6a8885a308d313198a2e0370734' # Test m1
# m = '00112233445566778899aabbccddeeff' # Test m2
k = '2b7e151628aed2a6abf7158809cf4f3c'

# 256 bit test
# m = '00112233445566778899aabbccddeeff' # Test m1
# k = '000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f'

# NOTE: a fixed IV keeps this demo reproducible, but real CBC use needs a fresh,
# unpredictable IV for every message (e.g. secrets.token_hex(16)).
iv = '3A8E3A8E332ECC7323729FECC73729F3' # 32 hex characters = 16 bytes = 128 bits (one AES block)
m = "Hello World! This is a test of AES encryption." # ASCII plaintext of any length

# AES key size in bits, derived from the hex key (4 bits per hex character),
# so switching to the 256-bit test key above needs no other change
bit_mode = len(k) * 4


# Begin CBC mode encryption
if debug >= 1:
    cPrint("Encrypting", "bold")
    print("\r")

# Convert ASCII m to hex
hex_m = convert_ascii_to_hex(m)

# Add padding scheme 1 (0x80 then 0x00 bytes) up to a whole number of blocks
hex_m_padded = apply_padding_s1(hex_m)

# Split padded hex into 16-byte blocks (32 hex characters each)
hex_m_blocks = split_into_blocks(hex_m_padded, 32)

if debug >= 1:
    cPrint("m", "underline")
    print(m)
    print("\r")

    cPrint("m Hex", "underline")
    print(hex_m)
    print("\r")

    cPrint("m Hex Blocks (Post Padding)", "underline")
    print(hex_m_blocks)
    print("\r")

# CBC chaining: C_1 = E_k(P_1 xor IV), C_n = E_k(P_n xor C_{n-1})
c_blocks = []
transformer = iv
for block in hex_m_blocks:
    if debug >= 3:
        print("\n[DEBUG]\t","Starting new block:",block)

    # xor the plaintext block with the previous cipher block (the IV for the first)
    transformed_block = xor_hex(transformer,block)

    # Encrypt the chained block
    c_block = list_to_string(aes_encrypt(k, transformed_block, bit_mode))
    c_blocks.append(c_block)

    # This cipher block is chained into the next plaintext block
    transformer = c_block

c = ''.join(c_blocks)

if debug >= 1:
    cPrint("Cipher text (Hex)", "underline")
    print(c)
    print("\r")


[1mEncrypting[0m

[4mm[0m
Hello World! This is a test of AES encryption.

[4mm Hex[0m
48656c6c6f20576f726c6421205468697320697320612074657374206f662041455320656e6372797074696f6e2e

[4mm Hex Blocks (Post Padding)[0m
['48656c6c6f20576f726c642120546869', '7320697320612074657374206f662041', '455320656e6372797074696f6e2e8000']

[4mCipher text (Hex)[0m
E8471311A31371E60948771851DAFA8D022F80631274480D6E70A4B77645F8460A1C640DCC96AA97E3C926C5A593ECD3



### Decryption

In [12]:
# Begin CBC mode decryption
if debug >= 1:
    print("\r")
    cPrint("Decrypting", "bold")
    print("\r")

# AES key size in bits, derived from the hex key (4 bits per hex character)
bit_mode = len(k) * 4

# Split c into 16-byte blocks (32 hex characters each) for decryption
hex_c_blocks = split_into_blocks(c, 32)

if debug >= 1:
    cPrint("c Hex Blocks", "underline")
    print(hex_c_blocks)
    print("\r")

# CBC: P_n = D_k(C_n) xor C_{n-1}, with the IV standing in for C_0.
# Each block is paired with its predecessor by position; looking the
# predecessor up by value (list.index) picks the wrong block whenever two
# ciphertext blocks are equal, and costs O(n) per block.
d_blocks = []
for block, transformer in zip(hex_c_blocks, [iv] + hex_c_blocks[:-1]):
    if debug >= 3:
        print("\n[DEBUG]\t","Starting new block:",block)

    # Decrypt block
    d_block = list_to_string(aes_decrypt(k, block, bit_mode))

    # xor with the previous cipher block (or the IV) to undo the chaining
    transformed_d_block = xor_hex(transformer,d_block)
    d_blocks.append(transformed_d_block)

decrypted_m = ''.join(d_blocks)

if debug >= 1 :
    cPrint("Decrypted m (Padded) (Hex)", "underline")
    print(decrypted_m)
    print("\r")

    unpadded_hex = remove_padding_s1(decrypted_m)

    cPrint("Decrypted m (Hex)", "underline")
    print(unpadded_hex)
    print("\r")

    cPrint("Decrypted m (ascii)", "underline")
    print(convert_hex_to_ascii(unpadded_hex))
    print("\r")


[1mDecrypting[0m

[4mc Hex Blocks[0m
['E8471311A31371E60948771851DAFA8D', '022F80631274480D6E70A4B77645F846', '0A1C640DCC96AA97E3C926C5A593ECD3']

[4mDecrypted m (Padded) (Hex)[0m
48656C6C6F20576F726C6421205468697320697320612074657374206F662041455320656E6372797074696F6E2E8000

[4mDecrypted m (Hex)[0m
48656C6C6F20576F726C6421205468697320697320612074657374206F662041455320656E6372797074696F6E2E

[4mDecrypted m (ascii)[0m
Hello World! This is a test of AES encryption.

