# AES

In [125]:
from typing import Any, List
from scripts.utils import split_words, join_words

|Nombre|Tamaño de llave (bits)|Tamaño de bloque (bits)|Rondas|
|------|------------|-------------|------|
|AES-128| 128 | 128 | 10 |
|AES-192| 192 | 128 | 12 |
|AES-256| 256 | 128 | 14 |

Entrada: un bloque de 128 bits como una matriz 4 x 4 (cada entrada es un byte)

In [126]:
# AES S-box, the lookup table used by byte_sub: s_box[b] is the substitute for
# byte b. Laid out 16 entries per line, so line r, column c holds the entry for
# byte 0x<r><c>.
s_box: List[int] = [
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
]

# Inverse AES S-box, used by byte_sub_inv. It is meant to satisfy
# s_box_inv[s_box[b]] == b for every byte b (same 16-per-line layout as s_box).
s_box_inv: List[int] = [
    0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
    0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
    0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
    0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
    0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
    0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
    0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
    0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
    0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
    0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
    0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
    0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
    0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
    0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
    0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D
]

In [127]:
# Example 4x4 state matrix with one byte per entry (bytes 0x00..0x0F written row
# by row). It is a matrix, so the annotation is List[List[int]], not List[int].
A: List[List[int]] = [
    [0x00, 0x01, 0x02, 0x03],
    [0x04, 0x05, 0x06, 0x07],
    [0x08, 0x09, 0x0A, 0x0B],
    [0x0C, 0x0D, 0x0E, 0x0F],
]

## Byte Substitution

Una función invertible que representa la parte no lineal de AES; en general no es aditiva (la suma en $GF(2^8)$ es XOR):

$$
\mathrm{ByteSub}(A_i) \oplus \mathrm{ByteSub}(A_j) \neq \mathrm{ByteSub}(A_i \oplus A_j)
$$

Utiliza la tabla de búsqueda S-box.

In [128]:
def byte_sub(state: List[List[int]]) -> List[List[int]]:
    """ByteSub: replace every byte of the state with its S-box entry.

    Returns a new matrix; ``state`` is left untouched.
    """
    substituted = []
    for row in state:
        substituted.append([s_box[byte] for byte in row])
    return substituted

assert byte_sub([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]) == [[124, 119, 123, 242], [107, 111, 197, 48], [1, 103, 43, 254], [215, 171, 118, 202]]

In [129]:
def byte_sub_inv(state: List[List[int]]) -> List[List[int]]:
    """Inverse ByteSub: map every byte through the inverse S-box.

    Returns a new matrix; ``state`` is left untouched.
    """
    lookup = s_box_inv.__getitem__
    return [list(map(lookup, row)) for row in state]

assert byte_sub_inv([[124, 119, 123, 242], [107, 111, 197, 48], [1, 103, 43, 254], [215, 171, 118, 202]]) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

## Shift rows

Transformación lineal que rota cíclicamente la fila $i$ de la matriz de estado $i$ posiciones a la izquierda ($i = 0, \dots, 3$).

In [130]:
def shift(seq: List[Any], n: int) -> List[Any]:
    """Rotate ``seq`` cyclically to the left by ``n`` positions.

    Args:
        seq: the sequence to rotate.
        n: number of positions; negative values rotate to the right, and values
            larger than ``len(seq)`` wrap around.

    Returns:
        A new rotated list. An empty ``seq`` yields an empty result instead of
        raising ZeroDivisionError from ``n % 0``.
    """
    if len(seq) == 0:
        return seq[:]
    n %= len(seq)
    return seq[n:] + seq[:n]

In [131]:
def shift_rows(B: List[List[int]]) -> List[List[int]]:
    """ShiftRows: row ``r`` (0..3) of the state is rotated left by ``r`` places."""
    return [shift(B[r], r) for r in range(4)]


assert shift_rows([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) == [[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]

In [132]:
def shift_rows_inv(B: List[List[int]]) -> List[List[int]]:
    """Inverse ShiftRows: row ``r`` (0..3) is rotated right by ``r`` places."""
    return [shift(B[r], -r) for r in range(4)]

assert shift_rows_inv([[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

## Mix Columns

Transformación lineal que usa aritmética del campo finito $GF(2^8)$ con el polinomio irreducible $x^8 + x^4 + x^3 + x + 1$.

In [133]:
# MixColumns matrix. It is circulant: each row is the previous one rotated one
# position to the right, so entry (row, col) is base[(col - row) % 4].
M: List[List[int]] = [
    [(0x02, 0x03, 0x01, 0x01)[(col - row) % 4] for col in range(4)]
    for row in range(4)
]

In [134]:
# Inverse MixColumns matrix, also circulant: entry (row, col) is
# base[(col - row) % 4].
M_inv: List[List[int]] = [
    [(0x0e, 0x0b, 0x0d, 0x09)[(col - row) % 4] for col in range(4)]
    for row in range(4)
]

### Using finite fields math

In [135]:
from galois import GF, Poly

# Coefficient field for the byte polynomials (presumably GF(2^8), per the name
# and the "F_2^8" note above). Every coefficient used in this notebook is 0 or 1
# (see int_to_poly), and in characteristic 2, 1 + 1 == 0, so the arithmetic
# matches GF(2)[x].
f2_8 = GF(2, 8)
# AES reduction polynomial (0x11B); products in mix_collumns are taken mod G.
G = Poly([1, 0, 0, 0, 1, 1, 0, 1, 1], field=f2_8) # x^8 + x^4 + x^3 + x + 1

In [136]:
def poly_to_int(poly: Poly) -> int:
    """Pack a polynomial's 0/1 coefficients (highest degree first) into an int."""
    bit_string = "".join(str(c) for c in poly.coefficients().tolist())
    return int(bit_string, 2)

assert poly_to_int(G) == 0b100011011

In [137]:
def int_to_poly(integer: int) -> Poly:
    """Inverse of poly_to_int: the binary digits of ``integer`` become the
    polynomial's coefficients, most significant bit first."""
    binary_digits = bin(integer)[2:]
    return Poly(list(map(int, binary_digits)), field=f2_8)

assert int_to_poly(0b100011011) == G

In [138]:
def int_matrix_to_poly(matrix: List[List[int]]) -> List[List[Poly]]:
    """Convert every byte of an integer matrix to its polynomial form."""
    return [list(map(int_to_poly, row)) for row in matrix]

def poly_matrix_to_int(matrix: List[List[Poly]]) -> List[List[int]]:
    """Convert every polynomial of a matrix back to its integer form."""
    return [list(map(poly_to_int, row)) for row in matrix]

In [139]:
M_poly: List[List[Poly]] = int_matrix_to_poly(M)

def mix_collumns(state: List[List[int]]) -> List[List[int]]:
    """MixColumns: left-multiply the state by ``M`` over GF(2^8).

    Entry (row, col) of the result is sum_k M[row][k] * state[k][col], reduced
    modulo ``G`` after each accumulation step.
    """
    state_polys: List[List[Poly]] = int_matrix_to_poly(state)
    product: List[List[Poly]] = int_matrix_to_poly([[0] * 4 for _ in range(4)])

    for row in range(4):
        for col in range(4):
            acc = product[row][col]
            for idx in range(4):
                acc = (acc + M_poly[row][idx] * state_polys[idx][col]) % G
            product[row][col] = acc

    return poly_matrix_to_int(product)

In [140]:
assert mix_collumns([
    [0x33, 0x3b, 0x61, 0x50],
    [0x1c, 0x4f, 0xea, 0xaf],
    [0x38, 0xb7, 0x21, 0xa9],
    [0xc9, 0x4d, 0x44, 0x2d],
]) == [
    [0xb3, 0x5d, 0x82, 0xce],
    [0x8a, 0x2a, 0x89, 0xd8],
    [0x1f, 0xd6, 0x05, 0xc1],
    [0xf8, 0x2f, 0xe0, 0xac],
]

In [141]:
M_inv_poly: List[List[Poly]] = int_matrix_to_poly(M_inv)

def mix_collumns_inv(state: List[List[int]]) -> List[List[int]]:
    """Inverse MixColumns: left-multiply the state by ``M_inv`` over GF(2^8).

    Entry (row, col) of the result is sum_k M_inv[row][k] * state[k][col],
    reduced modulo ``G`` after each accumulation step.
    """
    state_polys: List[List[Poly]] = int_matrix_to_poly(state)
    product: List[List[Poly]] = int_matrix_to_poly([[0] * 4 for _ in range(4)])

    for row in range(4):
        for col in range(4):
            acc = product[row][col]
            for idx in range(4):
                acc = (acc + M_inv_poly[row][idx] * state_polys[idx][col]) % G
            product[row][col] = acc

    return poly_matrix_to_int(product)

In [142]:
assert mix_collumns_inv([
    [0xb3, 0x5d, 0x82, 0xce],
    [0x8a, 0x2a, 0x89, 0xd8],
    [0x1f, 0xd6, 0x05, 0xc1],
    [0xf8, 0x2f, 0xe0, 0xac],
]) == [
    [0x33, 0x3b, 0x61, 0x50],
    [0x1c, 0x4f, 0xea, 0xaf],
    [0x38, 0xb7, 0x21, 0xa9],
    [0xc9, 0x4d, 0x44, 0x2d],
]

## Key addition

Entrada: Matriz de estado 4x4 (cadena de 128 bits) y llave de ronda (128 bits)

Salida: $C \oplus k_i$

In [143]:
def key_addition(state: List[List[int]], k: List[List[int]]) -> List[List[int]]:
    """AddRoundKey: XOR the state matrix with a round-key matrix entry by entry.

    Args:
        state: state matrix (4x4 in AES).
        k: round-key matrix with the same shape as ``state``.

    Returns:
        A new matrix whose entry (r, c) is ``state[r][c] ^ k[r][c]``. Neither
        input is modified.

    Raises:
        AssertionError: if the matrices differ in row count or row length.
    """
    assert(len(state) == len(k)), "State and k must have same size"

    # Build fresh rows. ``state.copy()`` is only a shallow copy, so XOR-ing its
    # rows in place would also overwrite the caller's matrix.
    result: List[List[int]] = []
    for row1, row2 in zip(state, k):
        assert(len(row1) == len(row2)), "Rows must have same size"
        result.append([a ^ b for a, b in zip(row1, row2)])

    return result

In [144]:
assert key_addition([
    [0x4d, 0x61, 0x73, 0x65],
    [0x65, 0x6a, 0x65, 0x74],
    [0x6e, 0x65, 0x63, 0x6f],
    [0x73, 0x20, 0x72, 0x2e],
], [
    [0x2b, 0x28, 0xab, 0x09],
    [0x7e, 0xae, 0xf7, 0xcf],
    [0x15, 0xd2, 0x15, 0x4f],
    [0x16, 0xa6, 0x88, 0x3c],
]) == [
    [0x66, 0x49, 0xd8, 0x6c],
    [0x1b, 0xc4, 0x92, 0xbb],
    [0x7b, 0xb7, 0x76, 0x20],
    [0x65, 0x86, 0xfa, 0x12],
]

## Key expansion

Input: Una llave (4 palabras)

Output: llaves de ronda

In [145]:
# Round constants for the key schedule. Each one is the previous one multiplied
# by x in GF(2^8); 0x80 * x overflows and reduces modulo x^8 + x^4 + x^3 + x + 1
# to 0x1B.
RC = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]

### Función G

![G function](./img/AES_g_function.png)

In [146]:
def g(w: int, i: int) -> int:
    """Key-schedule function g: RotWord, then SubWord, then XOR with RC[i].

    Args:
        w (int): 32-bit word, handled as 4 bytes via ``split_words(w, 1, 4)``
            (the first byte is the most significant one, as the asserts below
            show).
        i (int): 0-based index into ``RC`` (the key-expansion iteration minus 1).

    Returns:
        int: the transformed 32-bit word.
    """
    v = split_words(w, 1, 4)

    # RotWord: rotate the four bytes left by one position.
    v = [v[1], v[2], v[3], v[0]]

    # SubWord: apply the S-box to each byte.
    v = [s_box[vi] for vi in v]

    # The round constant only affects the first (most significant) byte.
    v[0] ^= RC[i]

    return join_words(v, 1)

assert(g(0x1A38B5EE, 5) == 0x27d528A2)
assert(g(0x09CF4F3C, 0) == 0x8B84EB01)

### Expansion key AES-128

- Rondas: 10
- Llaves: 11 (de 128 bits -> 4 palabras de 32bits)
- Total palabras: 44 (de 32 bits)
- Iteraciones: 10

![expansion key 128](./img/expansion_key128.png)

In [147]:
def expansion_key128(k: List[int]):
    """AES-128 key expansion: 4 key words -> 44 words (11 round keys).

    Word ``j`` (j >= 4) is ``words[j - 4] ^ temp``, where ``temp`` is the
    previous word, passed through ``g`` whenever ``j`` is a multiple of 4.
    """
    assert(len(k) == 4)

    words = [k[0], k[1], k[2], k[3]] + [0] * 40

    for j in range(4, 44):
        temp = words[j - 1]
        if j % 4 == 0:
            temp = g(temp, j // 4 - 1)
        words[j] = words[j - 4] ^ temp

    return words


In [148]:
k0 = split_words(0x2b7e151628aed2a6abf7158809cf4f3c, 4, 4)
w = expansion_key128(k0)
rks = [join_words(w[i:i+4], 4) for i in range(0, len(w), 4)]

assert rks[1] == 0xa0fafe1788542cb123a339392a6c7605, "Key 1 is wrong"
assert rks[2] == 0xf2c295f27a96b9435935807a7359f67f, "Key 2 is wrong"
assert rks[3] == 0x3d80477d4716fe3e1e237e446d7a883b, "Key 3 is wrong"
assert rks[4] == 0xef44a541a8525b7fb671253bdb0bad00, "Key 4 is wrong"
assert rks[5] == 0xd4d1c6f87c839d87caf2b8bc11f915bc, "Key 5 is wrong"
assert rks[6] == 0x6d88a37a110b3efddbf98641ca0093fd, "Key 6 is wrong"
assert rks[7] == 0x4e54f70e5f5fc9f384a64fb24ea6dc4f, "Key 7 is wrong"
assert rks[8] == 0xead27321b58dbad2312bf5607f8d292f, "Key 8 is wrong"
assert rks[9] == 0xac7766f319fadc2128d12941575c006e, "Key 9 is wrong"
assert rks[10] == 0xd014f9a8c9ee2589e13f0cc8b6630ca6, "Key 10 is wrong"


In [149]:
# Build an AES-128 key from a short ASCII string: right-pad it with '0'
# characters up to 16 bytes, then pack each group of 4 bytes into a 32-bit word.
k0 = [ord(ch) for ch in "hello".ljust(16, "0")]
k0 = [join_words(k0[i:i+4], 1) for i in range(0, len(k0), 4)]

w = expansion_key128(k0)
rks = [join_words(w[i:i+4], 4) for i in range(0, len(w), 4)]

assert rks[0] == 0x68656c6c6f3030303030303030303030, "Key 0 is wrong"
assert rks[1] == 0x6d616868025158583261686802515858, "Key 1 is wrong"
assert rks[2] == 0xbe0b021fbc5a5a478e3b322f8c6a6a77, "Key 2 is wrong"
assert rks[3] == 0xb809f77b0453ad3c8a689f130602f564, "Key 3 is wrong"
assert rks[4] == 0xc7efb414c3bc192849d4863b4fd6735f, "Key 4 is wrong"
assert rks[5] == 0x21607b90e2dc62b8ab08e483e4de97dc, "Key 5 is wrong"
assert rks[6] == 0x1ce8fdf9fe349f41553c7bc2b1e2ec1e, "Key 6 is wrong"
assert rks[7] == 0xc4268f313a1210706f2e6bb2decc87ac, "Key 7 is wrong"
assert rks[8] == 0x0f311e2c35230e5c5a0d65ee84c1e242, "Key 8 is wrong"
assert rks[9] == 0x6ca93273598a3c2f038759c18746bb83, "Key 9 is wrong"
assert rks[10] == 0x0043de6459c9e24b5a4ebb8add080009, "Key 10 is wrong"

### Expansion key AES-192

- rondas: 12
- llaves: 13 (de 128 bits -> 4 palabras de 32 bits)
- Total palabras: 52 (de 32 bits)
- Iteraciones: 8

![expansion key 192](./img/expansion_key192.png)

In [150]:
def expansion_key192(k: List[int]):
    """AES-192 key expansion: 6 key words -> 52 words (13 round keys).

    Args:
        k (List[int]): the cipher key as 6 32-bit words.

    Returns:
        List[int]: the 52 expanded words. Word ``j`` (j >= 6) is
        ``w[j - 6] ^ temp``, where ``temp`` is the previous word, passed
        through ``g`` whenever ``j`` is a multiple of 6.

    Raises:
        AssertionError: if ``k`` does not hold exactly 6 words.
    """
    # Same precondition as expansion_key128: without it a short key fails with
    # an IndexError, and extra words would be silently ignored.
    assert(len(k) == 6)

    w = [k[0], k[1], k[2], k[3], k[4], k[5]] + [0] * 46

    for j in range(6, 52):
        temp = w[j - 1]
        if j % 6 == 0:
            # Iteration j // 6 (1..8) uses RC[j // 6 - 1].
            temp = g(temp, j // 6 - 1)
        w[j] = w[j - 6] ^ temp

    return w

In [151]:
k = split_words(0x8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b, 4, 6)
w = expansion_key192(k)
rks = [join_words(w[i:i+6], 4) for i in range(0, len(w), 6)]


assert rks[1] == 0xfe0c91f72402f5a5ec12068e6c827f6b0e7a95b95c56fec2, "Key 1 is wrong"
assert rks[2] == 0x4db7b4bd69b5411885a74796e92538fde75fad44bb095386, "Key 2 is wrong"
assert rks[3] == 0x485af05721efb14fa448f6d94d6dce24aa326360113b30e6, "Key 3 is wrong"
assert rks[4] == 0xa25e7ed583b1cf9a27f939436a94f767c0a69407d19da4e1, "Key 4 is wrong"
assert rks[5] == 0xec1786eb6fa64971485f703222cb8755e26d135233f0b7b3, "Key 5 is wrong"
assert rks[6] == 0x40beeb282f18a2596747d26b458c553ea7e1466c9411f1df, "Key 6 is wrong"
assert rks[7] == 0x821f750aad07d753ca4005388fcc5006282d166abc3ce7b5, "Key 7 is wrong"
assert rks[8] == 0xe98ba06f448c773c8ecc720401002202, "Key 8 is wrong"

### Expansion key AES-256

- rondas: 14
- llaves: 15 (de 128 bits -> 4 palabras de 32 bits)
- palabras: 60 (de 32 bits)
- iteraciones: 7

![expansion key 256](./img/expansion_key256.png)

In [152]:
def h(w: int) -> int:
    """Key-schedule function h used by AES-256: SubWord only (no rotation, no RC).

    Args:
        w (int): 32-bit word, handled as 4 bytes via ``split_words(w, 1, 4)``.

    Returns:
        int: the word with the S-box applied to each of its 4 bytes.
    """
    return join_words([s_box[b] for b in split_words(w, 1, 4)], 1)

In [153]:
def expansion_key256(k: List[int]):
    """AES-256 key expansion: 8 key words -> 60 words (15 round keys).

    Args:
        k (List[int]): the cipher key as 8 32-bit words.

    Returns:
        List[int]: the 60 expanded words. Word ``j`` (j >= 8) is
        ``w[j - 8] ^ temp``, where ``temp`` is the previous word, passed
        through ``g`` when ``j % 8 == 0`` and through ``h`` when ``j % 8 == 4``.

    Raises:
        AssertionError: if ``k`` does not hold exactly 8 words.
    """
    # Same precondition as expansion_key128.
    assert(len(k) == 8)

    w = [k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]] + [0] * 52

    for j in range(8, 60):
        temp = w[j - 1]
        if j % 8 == 0:
            # Iteration j // 8 (1..7) uses RC[j // 8 - 1].
            temp = g(temp, j // 8 - 1)
        elif j % 8 == 4:
            # AES-256 also substitutes the word in the middle of each block.
            temp = h(temp)
        w[j] = w[j - 8] ^ temp

    return w

In [154]:
k = split_words(0x603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4, 4, 8)
w = expansion_key256(k)
rks = [join_words(w[i:i+8], 4) for i in range(0, len(w), 8)]

assert rks[1] == 0x9ba354118e6925afa51a8b5f2067fcdea8b09c1a93d194cdbe49846eb75d5b9a, "Key 1 is wrong"
assert rks[2] == 0xd59aecb85bf3c917fee94248de8ebe96b5a9328a2678a647983122292f6c79b3, "Key 2 is wrong"
assert rks[3] == 0x812c81addadf48ba24360af2fab8b46498c5bfc9bebd198e268c3ba709e04214, "Key 3 is wrong"
assert rks[4] == 0x68007bacb2df331696e939e46c518d80c814e20476a9fb8a5025c02d59c58239, "Key 4 is wrong"
assert rks[5] == 0xde1369676ccc5a71fa2563959674ee155886ca5d2e2f31d77e0af1fa27cf73c3, "Key 5 is wrong"
assert rks[6] == 0x749c47ab18501ddae2757e4f7401905acafaaae3e4d59b349adf6acebd10190d, "Key 6 is wrong"
assert rks[7] == 0xfe4890d1e6188d0b046df344706c631e, "Key 7 is wrong"

## Cifrado y descifrado

In [155]:
def print_hex_matrix(m):
    """Print a matrix one row per line, each entry shown as a hex string."""
    hex_rows = [list(map(hex, row)) for row in m]
    print(*hex_rows, sep="\n")

In [156]:
def to_matrix(k: int):
    """Lay out a 128-bit integer as a 4x4 byte matrix in column-major order.

    Byte ``n`` of ``split_words(k, 1, 16)`` ends up in row ``n % 4``,
    column ``n // 4``.
    """
    octets = split_words(k, 1, 4*4)
    return [[octets[4 * col + row] for col in range(4)] for row in range(4)]

assert to_matrix(0x2b7e151628aed2a6abf7158809cf4f3c) == [
    [0x2b, 0x28, 0xab, 0x09],
    [0x7e, 0xae, 0xf7, 0xcf],
    [0x15, 0xd2, 0x15, 0x4f],
    [0x16, 0xa6, 0x88, 0x3c],
]

In [157]:
def AES_encrypt(state, k, key_size=None):
    """Encrypt one 128-bit block with AES-128, AES-192 or AES-256.

    Args:
        state (int): the 128-bit plaintext block as an integer.
        k (int): the cipher key as an integer.
        key_size (int, optional): 128, 192 or 256. When omitted it is inferred
            from ``k.bit_length()`` as the smallest size that fits. A key whose
            leading bytes are zero cannot be told apart from a shorter key, so
            pass ``key_size`` explicitly in that case.

    Returns:
        List[List[int]]: the ciphertext as a 4x4 matrix in the same layout as
        ``to_matrix``.

    Raises:
        AssertionError: if the key is longer than 256 bits or than ``key_size``.
        ValueError: if ``key_size`` is not 128, 192 or 256.
    """
    # bit_length() counts significant bits only. The previous len(bin(k)) also
    # counted the '0b' prefix, so full-width keys fell into the wrong size.
    assert(k.bit_length() <= 256), "Key must be 256 bits or less"

    if key_size is None:
        if k.bit_length() <= 128:
            key_size = 128
        elif k.bit_length() <= 192:
            key_size = 192
        else:
            key_size = 256

    assert(k.bit_length() <= key_size), "Key does not fit in key_size bits"

    # Rounds per key size: 10 / 12 / 14 (see the table at the top).
    if key_size == 128:
        w = expansion_key128(split_words(k, 4, 4))
        rounds = 10
    elif key_size == 192:
        w = expansion_key192(split_words(k, 4, 6))
        rounds = 12
    elif key_size == 256:
        w = expansion_key256(split_words(k, 4, 8))
        rounds = 14
    else:
        raise ValueError("key_size must be 128, 192 or 256")

    # Group the expanded words into rounds + 1 round keys of 128 bits.
    keys = [join_words(w[i:i+4], 4) for i in range(0, len(w), 4)]

    s = key_addition(to_matrix(state), to_matrix(keys[0]))

    for i in range(1, rounds + 1):
        s = byte_sub(s)
        s = shift_rows(s)

        # The final round omits MixColumns.
        if i != rounds:
            s = mix_collumns(s)

        s = key_addition(s, to_matrix(keys[i]))

    return s


assert(AES_encrypt(0x3243f6a8885a308d313198a2e0370734, 0x2b7e151628aed2a6abf7158809cf4f3c) == [
    [0x39, 0x02, 0xdc, 0x19],
    [0x25, 0xdc, 0x11, 0x6a],
    [0x84, 0x09, 0x85, 0x0b],
    [0x1d, 0xfb, 0x97, 0x32]])

# FIPS-197 Appendix C.3 (AES-256, key size inferred).
assert(AES_encrypt(
    0x00112233445566778899aabbccddeeff,
    0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f,
) == to_matrix(0x8ea2b7ca516745bfeafc49904b496089))

In [158]:
def AES_decrypt(state, k, key_size=None):
    """Decrypt one 128-bit block with AES-128, AES-192 or AES-256.

    Args:
        state (int): the 128-bit ciphertext block as an integer.
        k (int): the cipher key as an integer.
        key_size (int, optional): 128, 192 or 256. When omitted it is inferred
            from ``k.bit_length()`` as the smallest size that fits. A key whose
            leading bytes are zero cannot be told apart from a shorter key, so
            pass ``key_size`` explicitly in that case.

    Returns:
        List[List[int]]: the plaintext as a 4x4 matrix in the same layout as
        ``to_matrix``.

    Raises:
        AssertionError: if the key is longer than 256 bits or than ``key_size``.
        ValueError: if ``key_size`` is not 128, 192 or 256.
    """
    # bit_length() counts significant bits only. The previous len(bin(k)) also
    # counted the '0b' prefix, so full-width keys fell into the wrong size.
    assert(k.bit_length() <= 256), "Key must be 256 bits or less"

    if key_size is None:
        if k.bit_length() <= 128:
            key_size = 128
        elif k.bit_length() <= 192:
            key_size = 192
        else:
            key_size = 256

    assert(k.bit_length() <= key_size), "Key does not fit in key_size bits"

    # Rounds per key size: 10 / 12 / 14 (see the table at the top).
    if key_size == 128:
        w = expansion_key128(split_words(k, 4, 4))
        rounds = 10
    elif key_size == 192:
        w = expansion_key192(split_words(k, 4, 6))
        rounds = 12
    elif key_size == 256:
        w = expansion_key256(split_words(k, 4, 8))
        rounds = 14
    else:
        raise ValueError("key_size must be 128, 192 or 256")

    # Group the expanded words into rounds + 1 round keys of 128 bits.
    keys = [join_words(w[i:i+4], 4) for i in range(0, len(w), 4)]

    s = to_matrix(state)

    # Undo the encryption rounds in reverse order.
    for i in range(rounds, 0, -1):
        s = key_addition(s, to_matrix(keys[i]))

        # The last encryption round had no MixColumns, so nothing to undo here.
        if i != rounds:
            s = mix_collumns_inv(s)

        s = shift_rows_inv(s)
        s = byte_sub_inv(s)

    return key_addition(s, to_matrix(keys[0]))

assert(AES_decrypt(0x3925841d02dc09fbdc118597196a0b32, 0x2b7e151628aed2a6abf7158809cf4f3c) == [[0x32, 0x88, 0x31, 0xe0],
[0x43, 0x5a, 0x31, 0x37],
[0xf6, 0x30, 0x98, 0x07],
[0xa8, 0x8d, 0xa2, 0x34]])

# FIPS-197 Appendix C.3 (AES-256, key size given explicitly).
assert(AES_decrypt(
    0x8ea2b7ca516745bfeafc49904b496089,
    0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f,
    key_size=256,
) == to_matrix(0x00112233445566778899aabbccddeeff))

## Referencias

- Notas de clase
- [Advanced Encryption Standard (AES) - Proceso de cifrado (encrypt)](https://www.teoria.com/jra/aes/encrypt.html)
- [Lecture 8: AES: The Advanced Encryption Standard (PDF file)](https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture8.pdf)
- [FIPS PUB 197: the official AES standard (PDF file)](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197.pdf)