## CIIC5018 / ICOM5018
## Network Security and Cryptography
## Project 5: AES Implementation
### Francis Jose Patron Fidalgo (802180833)
### sec: 060
### oct/16/2022

## Addition of two numbers

In [280]:
def add(a: int, b: int) -> int:
    return a ^ b 

## mult of two numbers

In [281]:
def mult(a: int, b: int) -> int:
    modulo = 0b100011011    # x^8+x^4+x^3+x+1       
    prod= 0
    for i in range(8):
        prod = prod << 1 
        if prod & 0b100000000:
            prod = prod ^ modulo
        if b & 0b010000000:
            prod = prod ^ a
        b = b << 1
    return prod

## Inverse of a number

In [282]:
def inverse(x: int) -> int: # x^254 = 1/x
    p=mult(x,x)   # p = x^2
    x=mult(p,p)   # x = x^4
    p=mult(p,x)   # p = x^(2+4)
    x=mult(x,x)   # x = x^8
    p=mult(p,x)   # p = x^(2+4+8)
    x=mult(x,x)   # x = x^16
    p=mult(p,x)   # p = x^(2+4+8+16)
    x=mult(x,x)   # x = x^32
    p=mult(p,x)   # p = x^(2+4+8+16+32)
    x=mult(x,x)   # x = x^64
    p=mult(p,x)   # p = x^(2+4+8+16+32+64)
    x=mult(x,x)   # x = x^128
    p=mult(p,x)   # p = x^(2+4+8+16+32+64+128)
    return p

## Generation of S-Box and InvS-Box

In [283]:
def transform_sbox(byte: hex) -> str:
    c = 0x63
    new_num = 0x00
    for i in range(8):
        new_num = new_num | (((byte >> i) & 0x01) 
            ^ ((byte >> ((i + 4) % 8 )) & 0x01) 
            ^ ((byte >> ((i + 5) % 8 )) & 0x01) 
            ^ ((byte >> ((i + 6) % 8 )) & 0x01)
            ^ ((byte >> ((i + 7) % 8 )) & 0x01)
            ^ ((c >> i) & 0x01)) << i
    return new_num

In [284]:
def transform_isbox(byte: hex) -> str:
    c = 0x05
    new_num = 0x00
    for i in range(8):
        new_num = new_num | (((byte >> ((i + 2) % 8 )) & 0x01) 
            ^ ((byte >> ((i + 5) % 8 )) & 0x01) 
            ^ ((byte >> ((i + 7) % 8 )) & 0x01)
            ^ ((c >> i) & 0x01)) << i
    return new_num

In [285]:
def get_sbox():
    sbox = []
    # Initialize the S-box with the byte values in ascending sequence row by row
    for i in range(16):
        row = []
        for j in range(16):
            row.append(int(hex(i)+hex(j)[2:], 16))
        sbox.append(row)
    # Map each byte in the S-box to its multiplicative inverse in the finite field
    # GF(^8); the value {00} is mapped to itself.
    for row in sbox:
        for i in range(len(row)):
            row[i] = inverse(row[i])
            # Apply the transformation to each bit of each byte in the S-box
            row[i] = '{:02x}'.format(transform_sbox(row[i]))
    return sbox

In [286]:
def get_isbox():
    sbox = []
    # Initialize the S-box with the byte values in ascending sequence row by row
    for i in range(16):
        row = []
        for j in range(16):
            row.append(int(hex(i)+hex(j)[2:], 16))
        sbox.append(row)
    for row in sbox:
        for i in range(len(row)):
            # Apply the transformation to each bit of each byte in the iS-box
            row[i] = transform_isbox(row[i])
            # Map each byte in the S-box to its multiplicative inverse in the finite field
            # GF(^8); the value {00} is mapped to itself.
            row[i] = '{:02x}'.format(inverse(row[i]))
    return sbox

In [287]:
SBOX = get_sbox()
ISBOX = get_isbox()

# I. function to convert a sequence of 16 bytes to a 4x4 square

In [288]:
def gen4x4Square(bytes):
    # take in 16 bytes, return 4x4 matrix
    # support for string keys
    if type(bytes) is str:
        bytes = [bytes[i:i+2] for i in range(0, len(bytes), 2)]
    return [[b for b in bytes[i*4:i*4+4]] for i in range(4)]

# II. function to convert a 4x4 square to a sequence of 16 bytes

In [289]:
def gen16Bytes(matrix):
    bytes = []
    [[bytes.append(matrix[row][col]) for col in range(4)] for row in range(4)]
    return ''.join(bytes)

# III. A function to print the current state

In [290]:
def print_state(state, round=None, rkey=None, exp_keys=None):
    print('-'*35)
    if not round is None:
        print(f'{round}')
        print('-'*35)
    key = [['']*4]*4
    kname = ''
    if not rkey is None:
        key = exp_keys[rkey:rkey+4]
        kname = 'round key:'
    print(f'state:                 {kname}')
    for i in range(4):
        print(f'{state[0][i]} {state[1][i]} {state[2][i]} {state[3][i]}            {key[0][i]} {key[1][i]} {key[2][i]} {key[3][i]}')
    print('')

# IV. SubByte

In [291]:
def sub_byte(state):
    new_state = [[],[],[],[]]
    for row in range(4):
        for col, byte in enumerate(state[row]):
            byte = '{:02x}'.format(int(byte, 16))
            new_state[row].append(SBOX[int(byte[0], 16)][int(byte[1], 16)])
    return new_state

# V. InvSubByte

In [292]:
def inv_sub_byte(state):
    new_state = [[],[],[],[]]
    for row in range(4):
        for col, byte in enumerate(state[row]):
            byte = '{:02x}'.format(int(byte, 16))
            new_state[row].append(ISBOX[int(byte[0], 16)][int(byte[1], 16)])
    return new_state

# VI. ShiftRows

In [293]:
def shift_row(state):
    rows = []
    # switch to list of columns
    for i in range(4):
        rows.append([int(col[i], 16) for col in state])
    new_state = rows
    # left shift second row by 1
    new_state[1].append(new_state[1].pop(0))
    # left shift third row by 2
    new_state[2].append(new_state[2].pop(0))
    new_state[2].append(new_state[2].pop(0))
    # left shift third row by 3
    new_state[3].append(new_state[3].pop(0))
    new_state[3].append(new_state[3].pop(0))
    new_state[3].append(new_state[3].pop(0))
    # switch back
    return [['{:02x}'.format(row[i]) for row in new_state] for i in range(4)]

# VII. InvShiftRows

In [294]:
def inv_shift_row(state):
    rows = []
    # switch to list of columns
    for i in range(4):
        rows.append([int(col[i], 16) for col in state])
    new_state = rows
    # left shift second row by 1
    new_state[1].insert(0, new_state[1].pop())
    # left shift third row by 2
    new_state[2].insert(0, new_state[2].pop())
    new_state[2].insert(0, new_state[2].pop())
    # left shift third row by 3
    new_state[3].insert(0, new_state[3].pop())
    new_state[3].insert(0, new_state[3].pop())
    new_state[3].insert(0, new_state[3].pop())
    # switch back
    return [['{:02x}'.format(row[i]) for row in new_state] for i in range(4)]

# VIII. MixColumns

In [295]:
def mix(col):
    # col = [a, b, c, d]
    new_col = col.copy()
    new_col[0] = '{:02x}'.format(mult(2, col[0]) ^ mult(3, col[1]) ^ mult(1, col[2]) ^ mult(1, col[3]))
    new_col[1] = '{:02x}'.format(mult(1, col[0]) ^ mult(2, col[1]) ^ mult(3, col[2]) ^ mult(1, col[3]))
    new_col[2] = '{:02x}'.format(mult(1, col[0]) ^ mult(1, col[1]) ^ mult(2, col[2]) ^ mult(3, col[3]))
    new_col[3] = '{:02x}'.format(mult(3, col[0]) ^ mult(1, col[1]) ^ mult(1, col[2]) ^ mult(2, col[3]))
    return new_col
def imix(col):
    # col = [a, b, c, d]
    new_col = col.copy()
    new_col[0] = '{:02x}'.format(mult(0x0E, col[0]) ^ mult(0x0B, col[1]) ^ mult(0x0D, col[2]) ^ mult(0x09, col[3]))
    new_col[1] = '{:02x}'.format(mult(0x09, col[0]) ^ mult(0x0E, col[1]) ^ mult(0x0B, col[2]) ^ mult(0x0D, col[3]))
    new_col[2] = '{:02x}'.format(mult(0x0D, col[0]) ^ mult(0x09, col[1]) ^ mult(0x0E, col[2]) ^ mult(0x0B, col[3]))
    new_col[3] = '{:02x}'.format(mult(0x0B, col[0]) ^ mult(0x0D, col[1]) ^ mult(0x09, col[2]) ^ mult(0x0E, col[3]))
    return new_col

In [296]:
def mix_columns(state):
    new_state = [[int(j, 16) for j in i] for i in state]
    for i, col in enumerate(new_state):
        new_state[i] = mix(col)
    return new_state

# IX. InvMicColumns

In [297]:
def inv_mix_columns(state):
    new_state = [[int(j, 16) for j in i] for i in state]
    for i, col in enumerate(new_state):
        new_state[i] = imix(col)
    return new_state

# X. AddRoundKey

In [298]:
def add_round_key(state, round_key):
    new_state = state.copy()
    for row in range(4):
        for col in range(4):
            new_state[row][col] = '{:02x}'.format(int(state[row][col], 16) ^ int(round_key[row][col], 16))
    return new_state

# XI. AES key expansion

In [299]:
get_rcon = [['00', '00', '00', '00'], 
            ['01', '00', '00', '00'], 
            ['02', '00', '00', '00'], 
            ['04', '00', '00', '00'], 
            ['08', '00', '00', '00'], 
            ['10', '00', '00', '00'], 
            ['20', '00', '00', '00'], 
            ['40', '00', '00', '00'], 
            ['80', '00', '00', '00'],
            ['1b', '00', '00', '00'],
            ['36', '00', '00', '00'],]

In [300]:
def rot_word(w):
    # performs a one-byte circular left shift on a word. This means that an
    # input word [B 0, B 1, B 2, B 3] is transformed into [B 1, B 2, B 3, B 0
    new_w = w.copy()
    new_w.append(new_w.pop(0))
    return new_w

In [301]:
def sub_word(w):
    # performs a byte substitution on each byte of its input word, using the S-box
    new_w = []
    for b in w:
        new_w.append(SBOX[int(b[0], 16)][int(b[1], 16)])
    return new_w

In [302]:
def word_xor(w1, w2):
	new_w = []
	for i in range(4):
		new_w.append('{:02x}'.format(int(w1[i], 16) ^ int(w2[i], 16)))
	return new_w

In [303]:
def key_expansion(key):
	# support for string keys
	if type(key) is str:
		key = [key[i:i+2] for i in range(0, len(key), 2)]
	w = [[]]*44
	for i in range(4):
		w[i] = [key[4*i], key[4*i+1], key[4*i+2], key[4*i+3]]
	for i in range(4, 44):
		tmp = w[i-1]
		if i % 4 == 0:
			tmp = word_xor(sub_word(rot_word(tmp)), get_rcon[int(i/4)]) 
		w[i] = word_xor(w[i-4], tmp)		
	return w

# XII. AES encryption & decryption

In [304]:
def encryption_round(state, rkey, expanded_keys):
    state = sub_byte(state)
    state = shift_row(state)
    state = mix_columns(state)
    state = add_round_key(state, expanded_keys[rkey:rkey+4])
    return state

In [305]:
def decryption_round(state, rkey, expanded_keys):
    state = inv_sub_byte(state)
    state = inv_shift_row(state)
    state = inv_mix_columns(state)
    state = add_round_key(state, inv_mix_columns(expanded_keys[rkey-4:rkey]))
    return state

In [306]:
def encrypt(plain_text, key, print=False, all_states=None):
    # create input state
    init_state = gen4x4Square(plain_text)
    expanded_keys = key_expansion(key)
    rkey = 0
    # add states to an array (optional)
    if not all_states is None:
        all_states.append(init_state)
    # check if we want to print states
    if print:
        print_state(init_state, 'input_state', rkey, expanded_keys)
    # initial transformation
    state = add_round_key(init_state, expanded_keys[rkey:rkey+4])
    if print:
        print_state(state, 'initial_trans_state', rkey, expanded_keys)
    rkey+=4
    if not all_states is None:
        all_states.append(state)
    # go for 9 rounds
    for i in range(9):
        state = encryption_round(state, rkey, expanded_keys)
        if not all_states is None:
            all_states.append(state)
        if print:
            print_state(state, i+1, rkey, expanded_keys)
        rkey+=4
    # last round = no mix columns
    state = add_round_key(shift_row(sub_byte(state)), expanded_keys[rkey:rkey+4])
    if print:
        print_state(state, 10)
    if not all_states is None:
        all_states.append(state)
    return gen16Bytes(state)

In [307]:
def decrypt(cypher_txt, key):
    # create input state
    init_state = gen4x4Square(cypher_txt)
    expanded_keys = key_expansion(key)
    rkey = 44
    # initial transformation
    state = add_round_key(init_state, expanded_keys[rkey-4:rkey])
    rkey-=4
    # go for 9 rounds
    for i in range(9):
        state = decryption_round(state, rkey, expanded_keys)
        rkey-=4
    # last round = no mix columns
    state = add_round_key(inv_shift_row(inv_sub_byte(state)), expanded_keys[rkey-4:rkey])
    return gen16Bytes(state)

# Testing Procedure

### 1. Key Expansion    

In [308]:
key = '0f1571c947d9e8590cb7add6af7f6798'
expanded_key = key_expansion(key)
for i, row in enumerate(expanded_key):
	print(f'w{i} = {row[0]} {row[1]} {row[2]} {row[3]}')

w0 = 0f 15 71 c9
w1 = 47 d9 e8 59
w2 = 0c b7 ad d6
w3 = af 7f 67 98
w4 = dc 90 37 b0
w5 = 9b 49 df e9
w6 = 97 fe 72 3f
w7 = 38 81 15 a7
w8 = d2 c9 6b b7
w9 = 49 80 b4 5e
w10 = de 7e c6 61
w11 = e6 ff d3 c6
w12 = c0 af df 39
w13 = 89 2f 6b 67
w14 = 57 51 ad 06
w15 = b1 ae 7e c0
w16 = 2c 5c 65 f1
w17 = a5 73 0e 96
w18 = f2 22 a3 90
w19 = 43 8c dd 50
w20 = 58 9d 36 eb
w21 = fd ee 38 7d
w22 = 0f cc 9b ed
w23 = 4c 40 46 bd
w24 = 71 c7 4c c2
w25 = 8c 29 74 bf
w26 = 83 e5 ef 52
w27 = cf a5 a9 ef
w28 = 37 14 93 48
w29 = bb 3d e7 f7
w30 = 38 d8 08 a5
w31 = f7 7d a1 4a
w32 = 48 26 45 20
w33 = f3 1b a2 d7
w34 = cb c3 aa 72
w35 = 3c be 0b 38
w36 = fd 0d 42 cb
w37 = 0e 16 e0 1c
w38 = c5 d5 4a 6e
w39 = f9 6b 41 56
w40 = b4 8e f3 52
w41 = ba 98 13 4e
w42 = 7f 4d 59 20
w43 = 86 26 18 76


### 2. AES Encryption & Decryption

### Encrypt

In [309]:
plaintxt = '0123456789abcdeffedcba9876543210'
key = '0f1571c947d9e8590cb7add6af7f6798'
cypher_txt = encrypt(plaintxt, key, True)

-----------------------------------
input_state
-----------------------------------
state:                 round key:
01 89 fe 76            0f 47 0c af
23 ab dc 54            15 d9 b7 7f
45 cd ba 32            71 e8 ad 67
67 ef 98 10            c9 59 d6 98

-----------------------------------
initial_trans_state
-----------------------------------
state:                 round key:
0e ce f2 d9            0f 47 0c af
36 72 6b 2b            15 d9 b7 7f
34 25 17 55            71 e8 ad 67
ae b6 4e 88            c9 59 d6 98

-----------------------------------
1
-----------------------------------
state:                 round key:
65 0f c0 4d            dc 9b 97 38
74 c7 e8 d0            90 49 fe 81
70 ff e8 2a            37 df 72 15
75 3f ca 9c            b0 e9 3f a7

-----------------------------------
2
-----------------------------------
state:                 round key:
5c 6b 05 f4            d2 49 de e6
7b 72 a2 6d            c9 80 7e ff
b4 34 31 12            6b b4 c6 d3
9a 9b 7f 94 

### Decrypt

In [310]:
decrypt(cypher_txt, key)

'0123456789abcdeffedcba9876543210'

### 3. The Avalanche Effect

In [311]:
original_key_states = []
modified_key_states = []
encrypt('0123456789abcdeffedcba9876543210', key, False, original_key_states)
encrypt('0023456789abcdeffedcba9876543210', key, False, modified_key_states)
for i in range(len(original_key_states)):
    print('------------------------------------------')
    print(f'original: {gen16Bytes(original_key_states[i])}')
    print(f'modified: {gen16Bytes(modified_key_states[i])}')

------------------------------------------
original: 0e3634aece7225b6f26b174ed92b5588
modified: 0f3634aece7225b6f26b174ed92b5588
------------------------------------------
original: 0e3634aece7225b6f26b174ed92b5588
modified: 0f3634aece7225b6f26b174ed92b5588
------------------------------------------
original: 657470750fc7ff3fc0e8e8ca4dd02a9c
modified: c4a9ad090fc7ff3fc0e8e8ca4dd02a9c
------------------------------------------
original: 5c7bb49a6b72349b05a2317ff46d1294
modified: fe2ae569f7ee8bb8c1f5a2bb37ef53d5
------------------------------------------
original: 7115262448dc747e5cdac7227da9bd9c
modified: ec093dfb7c45343d689017507d485e62
------------------------------------------
original: f867aee8b437a5210c24c1974cffeabc
modified: 43efdb697244df808e8d9364ee0ae6f5
------------------------------------------
original: 721eb200ba06206dcbd4bce704fa654e
modified: 7b28a5d5ed643287e006c099bb375302
------------------------------------------
original: 0ad9d85689f9f77bc1c5f71185e5fb14
modified: 3

In [312]:

original_key_states = []
modified_key_states = []
encrypt(plaintxt, '0f1571c947d9e8590cb7add6af7f6798', False, original_key_states)
encrypt(plaintxt, '0e1571c947d9e8590cb7add6af7f6798', False, modified_key_states)
for i in range(len(original_key_states)):
    print('------------------------------------------')
    print(f'original: {gen16Bytes(original_key_states[i])}')
    print(f'modified: {gen16Bytes(modified_key_states[i])}')

------------------------------------------
original: 0e3634aece7225b6f26b174ed92b5588
modified: 0f3634aece7225b6f26b174ed92b5588
------------------------------------------
original: 0e3634aece7225b6f26b174ed92b5588
modified: 0f3634aece7225b6f26b174ed92b5588
------------------------------------------
original: 657470750fc7ff3fc0e8e8ca4dd02a9c
modified: c5a9ad090ec7ff3fc1e8e8ca4cd02a9c
------------------------------------------
original: 5c7bb49a6b72349b05a2317ff46d1294
modified: 90905fa9563356d15f3760f3b8259985
------------------------------------------
original: 7115262448dc747e5cdac7227da9bd9c
modified: 18aeb7aa794b3b66629448d575c7cebf
------------------------------------------
original: f867aee8b437a5210c24c1974cffeabc
modified: f81015f993c978a876ae017cb49e7eec
------------------------------------------
original: 721eb200ba06206dcbd4bce704fa654e
modified: 5955c91b4e769f3cb4a94768e98d5267
------------------------------------------
original: 0ad9d85689f9f77bc1c5f71185e5fb14
modified: d