# SYMMETRIC CRYPTOGRAPHY

## OVERVIEW

- Symmetric-key algorithms use the same key for both encryption and decryption.

- **Goal**: Use short keys to securely and efficiently send long messages.

- **Example**: Advanced Encryption Standard (AES)

- **Types**: Block ciphers and stream ciphers

### 1. Block Ciphers
- Break up a plaintext into fixed-length blocks and send each block through an encryption function together with a secret key. AES is an example of such a cipher.

### 2. Stream Ciphers
- Encrypt one byte of plaintext at a time by XORing the data with a pseudo-random keystream.
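
A minimal sketch of the XOR idea behind a stream cipher. The keystream generator here (`toy_keystream`, built from SHA-256 over a key and a counter) is a hypothetical stand-in chosen for illustration only; it is not a vetted cipher and not how any particular real stream cipher works.

```python
import hashlib


def toy_keystream(key: bytes, length: int) -> bytes:
    """Toy keystream: SHA-256(key || counter) blocks, concatenated.

    Illustration only; an assumption of this sketch, not a real cipher.
    """
    out = bytearray()
    counter = 0
    while len(out) < length:
        out += hashlib.sha256(key + counter.to_bytes(8, "big")).digest()
        counter += 1
    return bytes(out[:length])


def xor_stream(key: bytes, data: bytes) -> bytes:
    """XOR every data byte with the matching keystream byte."""
    keystream = toy_keystream(key, len(data))
    return bytes(d ^ k for d, k in zip(data, keystream))


message = b"attack at dawn"
ciphertext = xor_stream(b"secret key", message)
# XOR is its own inverse, so the same call decrypts.
recovered = xor_stream(b"secret key", ciphertext)
print(recovered == message)
```

Because encryption and decryption are the same XOR, both sides only need to regenerate the same keystream from the shared key.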




## KEYED PERMUTATIONS

- AES performs a keyed permutation, meaning that it maps every possible input block to a unique output block, with a key determining which permutation to perform.

- Using the same key, the permutation can be performed in reverse, mapping each output block back to its input block.

- **NOTE: There should be a 1-1 correspondence between input and output blocks, i.e. a reversible mapping.**

- The mathematical term for a one-to-one correspondence is 
<span style='color:yellow'> bijection </span>

- More generally, a function can be injective, surjective, or both, as shown below:

![image.png](attachment:image.png)

1. <span style='color:yellow'> Injective </span>: no two As point to the same B. Many to One is _not okay_, but a B may be left without a matching A. Such a function is said to be **One to One**.

2. <span style='color:yellow'> Surjective </span>: every B has __at least__ one matching A (maybe more than one), so no B is left out. Can be **Many to One**.

3. <span style='color:yellow'> Bijective </span>: both injective and surjective, i.e. a perfect pairing between the sets in which every element has exactly one partner and no one is left out. Is called a **One to One Correspondence**.
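
The three definitions can be checked mechanically on small finite sets. The sketch below is illustrative only: the helper names and `toy_keyed_permutation` (a key-seeded shuffle of 4-bit blocks) are assumptions made for this example and have nothing to do with how AES builds its permutation.

```python
import random


def is_injective(mapping: dict) -> bool:
    """No two inputs share the same output."""
    return len(set(mapping.values())) == len(mapping)


def is_surjective(mapping: dict, codomain: set) -> bool:
    """Every element of the codomain is hit at least once."""
    return set(mapping.values()) == codomain


def is_bijective(mapping: dict, codomain: set) -> bool:
    return is_injective(mapping) and is_surjective(mapping, codomain)


def toy_keyed_permutation(key: int, block_bits: int = 4) -> dict:
    """Toy example: the key seeds a shuffle of all possible blocks."""
    blocks = list(range(2 ** block_bits))
    shuffled = blocks[:]
    random.Random(key).shuffle(shuffled)
    return dict(zip(blocks, shuffled))


codomain = set(range(16))
perm = toy_keyed_permutation(key=1234)
inverse = {out: inp for inp, out in perm.items()}  # exists because perm is a bijection

doubling = {x: 2 * x for x in range(8)}    # injective, but misses the odd numbers
halving = {x: x // 2 for x in range(16)}   # many to one

print(is_bijective(perm, codomain))
print(is_injective(doubling), is_surjective(doubling, codomain))
print(is_injective(halving), is_surjective(halving, set(range(8))))
```

Only the bijective mapping can be inverted for every block, which is why a block cipher has to be a permutation.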

## RESISTING BRUTE FORCE

- If a block cipher is secure, there should be no way for an attacker to distinguish the output of AES from a random permutation of bits, and there should be no better way to undo the permutation than to brute-force every possible key.

- A cipher is considered **broken** if there is an attack that takes fewer steps than brute force.

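- The best known single-key attack against AES is the 
<span style='color:yellow'> **biclique attack** </span> 
, a meet-in-the-middle variant that only slightly lowers the effort for AES-128, to roughly 2^126.1 operations instead of 2^128. Attacks that exploit keys related in a specific way form a separate class, known as related-key attacks.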
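
To get a feel for these numbers, here is a rough back-of-the-envelope calculation. The guessing rate of 10^12 keys per second is a hypothetical figure picked for illustration, and 2^126.1 is the commonly cited complexity of the biclique attack on AES-128.

```python
SECONDS_PER_YEAR = 365.25 * 24 * 3600
GUESSES_PER_SECOND = 1e12  # hypothetical attacker speed, an assumption

# Time to try every AES-128 key at the assumed rate.
keyspace = 2 ** 128
years_brute_force = keyspace / GUESSES_PER_SECOND / SECONDS_PER_YEAR

# Speed-up offered by an attack costing about 2^126.1 operations.
biclique_speedup = 2 ** (128 - 126.1)

print(f"brute force: about {years_brute_force:.2e} years")
print(f"biclique speed-up: about {biclique_speedup:.1f}x")
```

A speed-up of this size makes the cipher "broken" in the academic sense above while leaving it far out of practical reach.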

## STRUCTURE OF AES

- At a high level, *AES-128* begins with a _key schedule_ and then runs _10 rounds_ over a state.

- The starting state is just the plaintext block we want to encrypt, represented as a 4 x 4 matrix of bytes, which is then passed through 10 rounds of invertible transformations.

### Phases of AES Encryption
![image.png](attachment:image.png)

[Video Explanation of the AES Implementation](https://youtu.be/gP4PqVGudtg)

#### 1. <span style='color:yellow'> Key Expansion/Schedule </span>
From the 128-bit key, 11 separate 128-bit _round keys_ are derived, one for each _AddRoundKey_ step.

#### 2. <span style='color:yellow'> Initial Key Addition </span>
_AddRoundKey_ step: the bytes of the first round key are XORed with the bytes of the state.

#### 3. <span style='color:yellow'> Round </span>
This phase is looped 10 times: 9 main rounds and 1 final round.

1. SubBytes - each byte of the state is replaced by a different byte according to the _S-box table_

2. ShiftRows - the last three rows of the state matrix are cyclically shifted left by their row number, i.e. row 0 by 0, row 1 by 1, row 2 by 2 and row 3 by 3

3. MixColumns - matrix multiplication is performed on the columns of the state, combining the 4 bytes in each column. _Skipped in the final round_

4. AddRoundKey - the bytes of the current round key are XORed with the bytes of the state
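
The phases above can be written out as an explicit sequence of steps. This is only a sketch of the ordering (the function name `aes128_encrypt_steps` is made up for this example and no actual encryption happens), but it makes it easy to see why 11 round keys are needed and where MixColumns is skipped.

```python
from collections import Counter


def aes128_encrypt_steps(n_rounds: int = 10) -> list:
    """Return (step name, round key / round index) pairs in encryption order."""
    steps = [("AddRoundKey", 0)]  # initial key addition
    for rnd in range(1, n_rounds + 1):
        steps.append(("SubBytes", rnd))
        steps.append(("ShiftRows", rnd))
        if rnd != n_rounds:  # MixColumns is skipped in the final round
            steps.append(("MixColumns", rnd))
        steps.append(("AddRoundKey", rnd))
    return steps


steps = aes128_encrypt_steps()
counts = Counter(name for name, _ in steps)
print(len(steps), dict(counts))
```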


Write a `matrix2bytes` function to turn the state matrix back into bytes, and submit the resulting plaintext as the flag.

In [16]:
matrix = [
    [99, 114, 121, 112],
    [116, 111, 123, 105]
]

OneD = [n for sub in matrix for n in sub]
OneD

[99, 114, 121, 112, 116, 111, 123, 105]

In [17]:
import numpy as np
matrix = np.array([
    [99, 114, 121, 112],
    [116, 111, 123, 105]
])
OneD_flatten = matrix.flatten()
OneD_flatten

array([ 99, 114, 121, 112, 116, 111, 123, 105])

In [18]:
def bytes2matrix(text):
    """ Converts a 16-byte array into a 4x4 matrix.  """
    return [list(text[i:i+4]) for i in range(0, len(text), 4)]

def matrix2bytes(matrix):
    """ Converts a 4x4 matrix into a 16-character string, one character per byte.  """
    return ''.join(chr(element) for sub in matrix for element in sub)
    


matrix = [
    [99, 114, 121, 112],
    [116, 111, 123, 105],
    [110, 109, 97, 116],
    [114, 105, 120, 125],
]

print(matrix2bytes(matrix))


crypto{inmatrix}


### Round Keys

- The general idea is that it takes in our 16 byte key and produces 11 4x4 matrices called round keys derived from our initial key. These allow AES to get extra mileage out of the single key provided.

- The _AddRoundKey Step_ XORs the current state with the current round key.

![image.png](attachment:image.png)

- This step also occurs as the final step of each round. It is the only step in which key material is mixed into the state, which is what makes AES a "keyed permutation" rather than a fixed one: the key determines which permutation is applied to the plaintext.

Complete the `add_round_key` function, then use the `matrix2bytes` function to get your next flag.


In [19]:
state = [
    [206, 243, 61, 34],
    [171, 11, 93, 31],
    [16, 200, 91, 108],
    [150, 3, 194, 51],
]

round_key = [
    [173, 129, 68, 82],
    [223, 100, 38, 109],
    [32, 189, 53, 8],
    [253, 48, 187, 78],
]
    
def add_round_key(s, k):
    return [[ s[i][j] ^ k[i][j] for j in range(4)] for i in range(4)]

print(matrix2bytes(add_round_key(state, round_key)))



crypto{r0undk3y}
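
Since XOR is its own inverse, applying AddRoundKey twice with the same round key restores the original state, so the step needs no separate inverse during decryption. A small check, reusing the state and round key from the cell above:

```python
def add_round_key(s, k):
    """XOR each state byte with the round key byte at the same position."""
    return [[s[i][j] ^ k[i][j] for j in range(4)] for i in range(4)]


state = [
    [206, 243, 61, 34],
    [171, 11, 93, 31],
    [16, 200, 91, 108],
    [150, 3, 194, 51],
]

round_key = [
    [173, 129, 68, 82],
    [223, 100, 38, 109],
    [32, 189, 53, 8],
    [253, 48, 187, 78],
]

mixed = add_round_key(state, round_key)
restored = add_round_key(mixed, round_key)  # the second XOR undoes the first

print(restored == state)
```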


### Confusion Through Substitution

The first step of each AES round is _SubBytes_, which takes each byte of the state matrix and substitutes it with a different byte from a preset 16x16 lookup table called the <span style='color:yellow'> Substitution Box or S-Box</span>. It is there to ensure the **Confusion Property**: the relationship between the ciphertext and the key should be as complex as possible.

**Purpose of the S-Box**: to transform the input in a way that resists approximation by linear or algebraic methods, i.e. to provide high non-linearity.

![image.png](attachment:image.png)

The simplest way to express the function is through the following high-degree polynomial:

![image-2.png](attachment:image-2.png)
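
The S-box does not have to be memorised as a table. It can also be generated by taking the multiplicative inverse of each byte in Rijndael's field (with 0 mapped to 0) and then applying a fixed affine transformation. The sketch below follows that construction; the helper names are choices made for this example.

```python
def gmul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) with the AES reduction polynomial."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


def gf_inverse(x: int) -> int:
    """Multiplicative inverse via x^254 (0 is mapped to 0 by convention)."""
    if x == 0:
        return 0
    result = 1
    for _ in range(254):
        result = gmul(result, x)
    return result


def rotl8(x: int, n: int) -> int:
    """Rotate an 8-bit value left by n bits."""
    return ((x << n) | (x >> (8 - n))) & 0xFF


def s_box_entry(x: int) -> int:
    b = gf_inverse(x)
    # Affine transformation: XOR of four rotations plus the constant 0x63.
    return b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63


generated_s_box = [s_box_entry(x) for x in range(256)]

# The inverse table exists because the S-box is a bijection on bytes.
generated_inv_s_box = [0] * 256
for x, y in enumerate(generated_s_box):
    generated_inv_s_box[y] = x

print([hex(v) for v in generated_s_box[:4]])
```

The first entries printed should agree with the first row of the `s_box` table used in the challenge code.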



Implement sub_bytes, send the state matrix through the inverse S-box and then convert it to bytes to get the flag.

In [20]:
s_box = (
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
)

inv_s_box = (
    0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
    0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
    0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
    0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
    0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
    0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
    0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
    0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
    0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
    0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
    0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
    0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
    0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
    0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
    0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D,
)

state = [
    [251, 64, 182, 81],
    [146, 168, 33, 80],
    [199, 159, 195, 24],
    [64, 80, 182, 255],
]

# Iterate over each element in the state matrix.
# Replace each element with its corresponding value in the given S-box.
# Convert the result into bytes to obtain the flag.

def sub_bytes(s, sbox=s_box):
    """Return a new matrix with every byte of `s` looked up in `sbox`."""
    return [[sbox[byte] for byte in row] for row in s]

print(matrix2bytes(sub_bytes(state, sbox=inv_s_box)))



crypto{l1n34rly}


### Diffusion through Permutation

- This section covers the **Diffusion Property** in symmetric cryptography: every part of the input should influence every part of the output. The interplay of the operations below helps create a more complex and secure cipher by spreading the influence of individual bytes throughout the entire state.

- Substitutions need to be alternated with steps that scramble the state in an invertible way, so that a substitution applied to one byte ends up influencing all other bytes in the state.

- The _ShiftRows_ and _MixColumns_ steps combine to achieve this.

1. <span style='color:yellow'> ShiftRows</span>
- The importance of this step is to avoid the columns being encrypted independently, in which case AES degenerates into four independent block ciphers.

![image.png](attachment:image.png)

2. <span style='color:yellow'> MixColumns</span>
- This step performs matrix multiplication in Rijndael's Galois field between the columns of the state matrix and a preset matrix. Every byte of a column thus affects all the bytes of the resulting column.

![image-2.png](attachment:image-2.png)
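
A small worked example of ShiftRows may help. It assumes the layout used by the challenge code in these notes, where `s[c][r]` holds the byte in column `c` and row `r` of the state. The functions here return new matrices instead of modifying the state in place, which is a choice made for this sketch.

```python
def shift_rows(s):
    """Row r is rotated left by r positions (s[c][r] = column c, row r)."""
    return [[s[(c + r) % 4][r] for r in range(4)] for c in range(4)]


def inv_shift_rows(s):
    """Row r is rotated right by r positions, undoing shift_rows."""
    return [[s[(c - r) % 4][r] for r in range(4)] for c in range(4)]


# Bytes 0..15 loaded column by column, as bytes2matrix would do.
state = [[4 * c + r for r in range(4)] for c in range(4)]

shifted = shift_rows(state)
restored = inv_shift_rows(shifted)

print(shifted)
# [[0, 5, 10, 15], [4, 9, 14, 3], [8, 13, 2, 7], [12, 1, 6, 11]]
print(restored == state)
```

After the shift, every column holds one byte from each of the original columns, which is what stops the columns from being encrypted independently.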

#### AES Galois Field
- In Rijndael's Galois field, GF(2^8), every mathematical operation results in an 8-bit number, so addition, subtraction, multiplication and division are redefined.
- Addition and subtraction are both performed by the exclusive or (XOR) operation.
- Multiplication is a bit more complicated:
    1. Take two 8-bit numbers `a` and `b`, and an 8-bit product `p`.
    2. Set `p` to 0.
    3. Make a copy of `a` and `b`, since the loop modifies them.
    4. Run the following loop 8 times:
        1. If the low bit of `b` is set, set `p` = `p` XOR `a`.
        2. Keep track of whether the MSB of `a` is set to 1.
        3. Shift `a` one bit to the left, discarding the MSB and setting the LSB to zero.
        4. If the MSB of `a` was 1 before the shift, set `a` = `a` XOR `0x1b`.
        5. Shift `b` one bit to the right, discarding the LSB and setting the MSB to zero.
    5. `p` now holds the product of `a` and `b`.

    - Example:
    - `p` = 0, `a` = 7, `b` = 3

    Iteration 1
    1. LSB of `b` is 1, `p` = `p` XOR `a` = 7
    2. `a` is shifted one bit to the left, `a` = 14
    3. The MSB of `a` was not 1, thus no XOR with `0x1b`
    4. `b` is shifted one bit to the right, `b` = 1

    Iteration 2
    1. LSB of `b` is 1, `p` = `p` XOR `a` = 7 XOR 14 = 9
    2. `a` is shifted one bit to the left, `a` = 28
    3. The MSB of `a` was not 1, thus no XOR with `0x1b`
    4. `b` is shifted one bit to the right, `b` = 0

    The loop runs 6 more times, but since `b` is now 0 the product no longer changes, so the result is 9.

    Sample Code:

    ```c
    unsigned char gmul(unsigned char a, unsigned char b) {
        unsigned char p = 0;
        unsigned char counter;
        unsigned char hi_bit_set;
        for (counter = 0; counter < 8; counter++) {
            if ((b & 1) == 1)
                p ^= a;
            hi_bit_set = (a & 0x80);
            a <<= 1;
            if (hi_bit_set == 0x80)
                a ^= 0x1b;
            b >>= 1;
        }
        return p;
    }
    ```

![Matrix Representation of MixColumns](attachment:image-3.png)
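
The multiplication procedure and the MixColumns matrix can be combined into a short Python sketch. `gmul` is a direct port of the loop described above; `MIX_MATRIX` holds the standard MixColumns coefficients, and `mix_column` is a name chosen for this example. It favours clarity over speed, unlike the `xtime` based code used in the challenge.

```python
def gmul(a: int, b: int) -> int:
    """Multiply two bytes in Rijndael's Galois field."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF  # keep `a` within 8 bits
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


MIX_MATRIX = [
    [2, 3, 1, 1],
    [1, 2, 3, 1],
    [1, 1, 2, 3],
    [3, 1, 1, 2],
]


def mix_column(column):
    """Multiply one state column by the MixColumns matrix (addition is XOR)."""
    out = []
    for row in MIX_MATRIX:
        acc = 0
        for coeff, byte in zip(row, column):
            acc ^= gmul(coeff, byte)
        out.append(acc)
    return out


print(gmul(7, 3))                    # 9, as in the worked example
print(hex(gmul(0x57, 0x83)))         # 0xc1
print([hex(b) for b in mix_column([0xDB, 0x13, 0x53, 0x45])])
# ['0x8e', '0x4d', '0xa1', '0xbc']
```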






We've provided code to perform MixColumns and the forward ShiftRows operation. After implementing `inv_shift_rows`, take the state, run `inv_mix_columns` on it, then `inv_shift_rows`, convert to bytes and you will have your flag.

In [21]:
def shift_rows(s):
    s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1]
    s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
    s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3]


def inv_shift_rows(s):
    s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1]
    s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
    s[0][3], s[1][3], s[2][3], s[3][3] = s[1][3], s[2][3], s[3][3], s[0][3]


# learned from http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c
xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1)


def mix_single_column(a):
    # see Sec 4.1.2 in The Design of Rijndael
    t = a[0] ^ a[1] ^ a[2] ^ a[3]
    u = a[0]
    a[0] ^= t ^ xtime(a[0] ^ a[1])
    a[1] ^= t ^ xtime(a[1] ^ a[2])
    a[2] ^= t ^ xtime(a[2] ^ a[3])
    a[3] ^= t ^ xtime(a[3] ^ u)


def mix_columns(s):
    for i in range(4):
        mix_single_column(s[i])


def inv_mix_columns(s):
    # see Sec 4.1.3 in The Design of Rijndael
    for i in range(4):
        u = xtime(xtime(s[i][0] ^ s[i][2]))
        v = xtime(xtime(s[i][1] ^ s[i][3]))
        s[i][0] ^= u
        s[i][1] ^= v
        s[i][2] ^= u
        s[i][3] ^= v

    mix_columns(s)


state = [
    [108, 106, 71, 86],
    [96, 62, 38, 72],
    [42, 184, 92, 209],
    [94, 79, 8, 54],
]

print(state)

inv_mix_columns(state)
print(state)

inv_shift_rows(state)
print(state)

print(matrix2bytes(state))


[[108, 106, 71, 86], [96, 62, 38, 72], [42, 184, 92, 209], [94, 79, 8, 54]]
[[99, 111, 102, 125], [116, 102, 82, 112], [49, 51, 121, 100], [115, 114, 123, 85]]
[[99, 114, 121, 112], [116, 111, 123, 100], [49, 102, 102, 85], [115, 51, 82, 125]]
crypto{d1ffUs3R}
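
The diffusion property itself can be observed with a small experiment: change one byte of the state and count how many bytes differ after applying ShiftRows followed by MixColumns. The sketch below leaves out SubBytes and AddRoundKey to isolate the two diffusion steps, uses non-mutating helpers, and assumes the same `s[c][r]` layout as the challenge code; all of these are choices made for this example.

```python
def xtime(a: int) -> int:
    """Multiply by 2 in Rijndael's field."""
    a <<= 1
    return (a ^ 0x11B) if a & 0x100 else a


def mix_column(col):
    a0, a1, a2, a3 = col
    t = a0 ^ a1 ^ a2 ^ a3
    return [
        a0 ^ t ^ xtime(a0 ^ a1),
        a1 ^ t ^ xtime(a1 ^ a2),
        a2 ^ t ^ xtime(a2 ^ a3),
        a3 ^ t ^ xtime(a3 ^ a0),
    ]


def shift_rows(s):
    return [[s[(c + r) % 4][r] for r in range(4)] for c in range(4)]


def diffusion_round(s):
    """ShiftRows followed by MixColumns on every column."""
    return [mix_column(col) for col in shift_rows(s)]


def differing_bytes(x, y):
    return sum(a != b for cx, cy in zip(x, y) for a, b in zip(cx, cy))


base = [[0] * 4 for _ in range(4)]
flipped = [col[:] for col in base]
flipped[0][0] ^= 0x01  # change a single bit of a single byte

one_a, one_b = diffusion_round(base), diffusion_round(flipped)
two_a, two_b = diffusion_round(one_a), diffusion_round(one_b)

after_one = differing_bytes(one_a, one_b)
after_two = differing_bytes(two_a, two_b)
print(after_one, after_two)  # 4 16
```

One round spreads the change over a whole column, and the second round spreads it over the entire state.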


### Bringing It All Together

Having implemented all the steps, we can see how _SubBytes_ (providing confusion) and _ShiftRows_ and _MixColumns_ (providing diffusion) work together to repeatedly circulate non-linear transformations over the state.

_AddRoundKey_ seeds the key into the network, making it a keyed permutation cipher.

Decryption involves performing the steps in reverse while applying the inverse operations. The _KeyExpansion_ still needs to be run first, and the round keys are used in reverse order.

We've provided the key expansion code, and ciphertext that's been properly encrypted by AES-128. Copy in all the building blocks you've coded so far, and complete the `decrypt` function that implements the steps shown in the diagram. The decrypted plaintext is the flag.

**Explanation of Decryption Steps**

Key Expansion: The key is expanded using `expand_key`, generating a list of round keys.

Initial Round: The `add_round_key` operation is performed with the last round key.

Inverse Rounds: For each round:

1. Inverse ShiftRows: Reverses the row shifting operation.

2. Inverse SubBytes: Reverses the byte substitution using the inverse S-box.

3. AddRoundKey: XORs the state with the corresponding round key.

4. Inverse MixColumns: Reverses the column mixing operation (skipped in the last round).

Final Round: The last round only involves InverseShiftRows, InverseSubBytes, and AddRoundKey (no InverseMixColumns).
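
The decryption order can be derived mechanically: write down the encryption steps, reverse the list, and replace each step with its inverse (AddRoundKey is its own inverse). The sketch below only manipulates step names; `ENCRYPT_STEPS`, `INVERSE` and `decrypt_steps` are names made up for this illustration.

```python
N_ROUNDS = 10

# Encryption order as (step name, round key index).
ENCRYPT_STEPS = [("AddRoundKey", 0)]
for rnd in range(1, N_ROUNDS + 1):
    ENCRYPT_STEPS += [("SubBytes", rnd), ("ShiftRows", rnd)]
    if rnd != N_ROUNDS:  # no MixColumns in the final round
        ENCRYPT_STEPS.append(("MixColumns", rnd))
    ENCRYPT_STEPS.append(("AddRoundKey", rnd))

INVERSE = {
    "SubBytes": "InvSubBytes",
    "ShiftRows": "InvShiftRows",
    "MixColumns": "InvMixColumns",
    "AddRoundKey": "AddRoundKey",  # XOR undoes itself
}

# Decryption: same steps, reversed, each replaced by its inverse.
decrypt_steps = [(INVERSE[name], rnd) for name, rnd in reversed(ENCRYPT_STEPS)]

for name, rnd in decrypt_steps[:5]:
    print(name, rnd)
```

The first five entries are AddRoundKey with round key 10, followed by InvShiftRows, InvSubBytes, AddRoundKey with round key 9 and InvMixColumns, which matches the explanation above.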

In [22]:
N_ROUNDS = 10

def expand_key(master_key):
    """
    Expands and returns a list of key matrices for the given master_key.
    """

    # Round constants https://en.wikipedia.org/wiki/AES_key_schedule#Round_constants
    r_con = (
        0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
        0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
        0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
        0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39,
    )

    # Initialize round keys with raw key material.
    key_columns = bytes2matrix(master_key)
    iteration_size = len(master_key) // 4

    # Each iteration has exactly as many columns as the key material.
    i = 1
    while len(key_columns) < (N_ROUNDS + 1) * 4:
        # Copy previous word.
        word = list(key_columns[-1])

        # Perform schedule_core once every "row".
        if len(key_columns) % iteration_size == 0:
            # Circular shift.
            word.append(word.pop(0))
            # Map to S-BOX.
            word = [s_box[b] for b in word]
            # XOR with first byte of R-CON, since the other bytes of R-CON are 0.
            word[0] ^= r_con[i]
            i += 1
        elif len(master_key) == 32 and len(key_columns) % iteration_size == 4:
            # Run word through S-box in the fourth iteration when using a
            # 256-bit key.
            word = [s_box[b] for b in word]

        # XOR with equivalent word from previous iteration.
        word = bytes(i^j for i, j in zip(word, key_columns[-iteration_size]))
        key_columns.append(word)

    # Group key words in 4x4 byte matrices.
    expanded_keys = [key_columns[4*i : 4*(i+1)] for i in range(len(key_columns) // 4)]
    # print(f"Expanded keys: {expanded_keys}") 

    return expanded_keys

def add_round_key(s, k):
    return [[ s[i][j] ^ k[i][j] for j in range(4)] for i in range(4)]

def sub_bytes(s, s_box):
    return [[s_box[byte] for byte in row] for row in s]

def mix_single_column(a):
    # see Sec 4.1.2 in The Design of Rijndael
    t = a[0] ^ a[1] ^ a[2] ^ a[3]
    u = a[0]
    a[0] ^= t ^ xtime(a[0] ^ a[1])
    a[1] ^= t ^ xtime(a[1] ^ a[2])
    a[2] ^= t ^ xtime(a[2] ^ a[3])
    a[3] ^= t ^ xtime(a[3] ^ u)


def mix_columns(s):
    for i in range(4):
        mix_single_column(s[i])


def inv_mix_columns(s):
    # see Sec 4.1.3 in The Design of Rijndael
    for i in range(4):
        u = xtime(xtime(s[i][0] ^ s[i][2]))
        v = xtime(xtime(s[i][1] ^ s[i][3]))
        s[i][0] ^= u
        s[i][1] ^= v
        s[i][2] ^= u
        s[i][3] ^= v

    mix_columns(s)

def encrypt(key, plaintext):
    round_keys = expand_key(key)

    # Accept str as well as bytes, since matrix2bytes returns a str.
    if isinstance(plaintext, str):
        plaintext = plaintext.encode('latin-1')

    # Use the same matrix layout as decrypt (no transposition).
    pt_matrix = bytes2matrix(plaintext)
    # print(f"Initial matrix: {pt_matrix}")

    pt_matrix = add_round_key(pt_matrix, round_keys[0])
    # print(f"After initial AddRoundKey: {pt_matrix}")

    # Main rounds 1..9 use round keys 1..9.
    for i in range(1, N_ROUNDS):
        pt_matrix = sub_bytes(pt_matrix, s_box)
        shift_rows(pt_matrix)
        mix_columns(pt_matrix)
        pt_matrix = add_round_key(pt_matrix, round_keys[i])

    # Final round (skips the MixColumns step)
    pt_matrix = sub_bytes(pt_matrix, s_box)
    shift_rows(pt_matrix)
    pt_matrix = add_round_key(pt_matrix, round_keys[N_ROUNDS])

    ciphertext = matrix2bytes(pt_matrix)

    return ciphertext
    

def decrypt(key, ciphertext):
    round_keys = expand_key(key) # Remember to start from the last round key and work backwards through them when decrypting
    # print(f"Round keys: {round_keys}")

    # Convert ciphertext to state matrix
    c_matrix = bytes2matrix(ciphertext)
    # print(f"Initial matrix: {c_matrix}")

    # Initial add round key step
    c_matrix = add_round_key(c_matrix, round_keys[N_ROUNDS])
    # print(f"After initial AddRoundKey: {c_matrix}")

    for i in range(N_ROUNDS - 1, 0, -1):
        inv_shift_rows(c_matrix)
        # print(f"After InvShiftRows: {c_matrix}")

        c_matrix = sub_bytes(c_matrix, inv_s_box)
        # print(f"After InvSubBytes: {c_matrix}")

        c_matrix = add_round_key(c_matrix, round_keys[i])
        # print(f"After AddRoundKey: {c_matrix}")

        inv_mix_columns(c_matrix)
        # print(f"After InvMixColumns: {c_matrix}")

    # Run final round (skips the InvMixColumns step)
    inv_shift_rows(c_matrix)
    c_matrix = sub_bytes(c_matrix, inv_s_box)
    c_matrix = add_round_key(c_matrix, round_keys[0])

    # Convert state matrix to plaintext
    plaintext = matrix2bytes(c_matrix)

    return plaintext


key        = b'\xc3,\\\xa6\xb5\x80^\x0c\xdb\x8d\xa5z*\xb6\xfe\\'
ciphertext = b'\xd1O\x14j\xa4+O\xb6\xa1\xc4\x08B)\x8f\x12\xdd'
plaintext = decrypt(key, ciphertext)

print(plaintext)
# Round trip: matrix2bytes returns a str, so encode it before comparing.
print(encrypt(key, plaintext).encode('latin-1') == ciphertext)

crypto{MYAES128}
True
