### AES Round Demonstration (AES-128 Simplified)

This code demonstrates the core transformations of **one AES-128 round** in simplified form.  
It applies the main AES operations to a single 16-byte block:

1. **State conversion**  
2. **AddRoundKey**  
3. **SubBytes**  
4. **ShiftRows**  
5. (MixColumns and the key schedule are omitted for simplicity; the cipher key itself is used as the round key)

---

#### 1. Convert plaintext and key to AES state matrices

```python
state = bytes_to_matrix(plaintext)
key_matrix = bytes_to_matrix(key)
```

* AES operates on a 4×4 matrix of bytes called the *state*, filled in column-major order: input byte `i` goes to row `i % 4`, column `i // 4`.

* bytes_to_matrix converts 16 bytes into a 4x4 matrix for internal AES operations.

* matrix_to_bytes converts the matrix back to 16 bytes after transformations.

In [1]:
# Import required libraries
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import numpy as np
from os import urandom
import sympy as sp
from IPython.display import display
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os

In [2]:
# AES Round Demonstration (AES-128 Simplified)
# Shows AddRoundKey, SubBytes and ShiftRows for a single round (MixColumns omitted)

import os
from IPython.display import display, Math

# --- State conversion (AES uses column-major!) ---
def bytes_to_matrix(b):
    """Convert 16 bytes into a 4x4 AES state matrix (column-major).

    Byte ``b[r + 4*c]`` lands at ``state[r][c]``, i.e. the input fills the
    matrix column by column.

    Args:
        b: A 16-item bytes-like object (``bytes``, ``bytearray``) or a list
            of ints.

    Returns:
        A list of 4 rows, each a list of 4 ints.

    Raises:
        ValueError: If ``b`` does not contain exactly 16 items. Without this
            check the slicing below would silently build short or long rows.
    """
    if len(b) != 16:
        raise ValueError(f"AES state requires exactly 16 bytes, got {len(b)}")
    # Row r collects every 4th byte starting at offset r -> column-major layout.
    return [list(b[r::4]) for r in range(4)]

def matrix_to_bytes(m):
    """Flatten a 4x4 AES state matrix back into 16 bytes.

    The matrix is read column by column (the inverse of the column-major
    layout), so for a 4-row matrix ``m[r][c]`` becomes byte ``r + 4*c``.
    """
    # zip(*m) walks the matrix column by column; emit each column top-to-bottom.
    return bytes(cell for column in zip(*m) for cell in column)

# --- Pretty printing ---
def show_matrix(state, title="State"):
    """Render a state matrix as a LaTeX ``bmatrix`` of hex bytes.

    Args:
        state: Matrix given as a list of rows of ints; each int is formatted
            as (at least) two uppercase hex digits.
        title: Label shown to the left of the matrix. It is wrapped in
            ``\\text{...}`` so multi-word titles such as "Initial State" keep
            their spaces; in math mode spaces are dropped.
    """
    rows = [" & ".join(f"{b:02X}" for b in row) for row in state]
    # LaTeX row separator is "\\".
    body = r"\\ ".join(rows)
    # NOTE: title is inserted verbatim; LaTeX special characters are not escaped.
    latex_str = rf"\text{{{title}}} = \begin{{bmatrix}} {body} \end{{bmatrix}}"
    display(Math(latex_str))

# --- AES S-Box ---
# Forward AES substitution box (as tabulated in FIPS-197), stored as a flat
# 256-entry list indexed by the input byte value. The high nibble of the byte
# selects the row and the low nibble the column of the grid below, e.g.
# S_BOX[0x53] == 0xED.
S_BOX = [
    # 0     1      2      3     4      5      6      7     8      9      A      B     C      D      E      F
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, # 0
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, # 1
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, # 2
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, # 3
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, # 4
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, # 5
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, # 6
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, # 7
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, # 8
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, # 9
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, # A
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, # B
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, # C
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, # D
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, # E
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16  # F
]

def sub_bytes(state, sbox=None):
    """Apply the AES SubBytes step: substitute every byte through an S-box.

    Args:
        state: Matrix given as a list of rows of ints; each int is used as an
            index into the lookup table.
        sbox: Optional 256-entry lookup table. ``None`` (the default) uses the
            module's forward ``S_BOX``. Passing another table (e.g. an inverse
            S-box) reuses the same step for InvSubBytes.

    Returns:
        A new matrix of the same shape; the input is not modified.
    """
    # Resolved at call time so the default always tracks the module-level S_BOX.
    table = S_BOX if sbox is None else sbox
    return [[table[b] for b in row] for row in state]

def shift_rows(state):
    """Apply the AES ShiftRows step.

    Row ``r`` is rotated cyclically to the left by ``r`` positions. Row 0 is
    unchanged, row 1 moves by one, row 2 by two and row 3 by three.

    Returns:
        A new matrix in which every row is a fresh list. No row aliases the
        input, so mutating the result never changes the caller's state.
    """
    # row[r:] + row[:r] is a left rotation by r; for r == 0 it is a plain copy.
    return [row[r:] + row[:r] for r, row in enumerate(state)]

def add_round_key(state, key_matrix):
    """XOR the state with a round-key matrix element by element (AddRoundKey).

    Both arguments are matrices given as lists of rows of ints. Rows and
    elements are paired with ``zip``, so pairing stops at the shorter one.
    A new matrix is returned, and neither input is modified.
    """
    mixed = []
    for state_row, key_row in zip(state, key_matrix):
        mixed.append([s ^ k for s, k in zip(state_row, key_row)])
    return mixed

# --- Demo with AES state ---
# One 16-byte block is pushed through the simplified round below. A fresh
# random key is drawn on every run, so the printed values differ each time.
plaintext = b"HelloAESRound1!!"  # exactly 16 bytes
key = os.urandom(16)

key_matrix = bytes_to_matrix(key)
state = bytes_to_matrix(plaintext)

print("Initial State:")
show_matrix(state, "Initial State")

# Each entry: (heading printed before the matrix, transformation to apply).
# MixColumns is deliberately absent from this simplified round.
round_steps = (
    ("\nAfter AddRoundKey:", lambda s: add_round_key(s, key_matrix)),
    ("\nAfter SubBytes:", sub_bytes),
    ("\nAfter ShiftRows:", shift_rows),
)
for heading, transform in round_steps:
    state = transform(state)
    print(heading)
    show_matrix(state, "State")

# Final (no MixColumns for now)
final_bytes = matrix_to_bytes(state)
print("\nFinal State as bytes:", final_bytes)
print("Final Hex:", final_bytes.hex())


Initial State:


<IPython.core.display.Math object>


After AddRoundKey:


<IPython.core.display.Math object>


After SubBytes:


<IPython.core.display.Math object>


After ShiftRows:


<IPython.core.display.Math object>


Final State as bytes: b'Qj%\xa4\xbd+\xbde\xf0;\\\xeb\xd8\xfbQ.'
Final Hex: 516a25a4bd2bbd65f03b5cebd8fb512e
