# Extended Euclidean Algorithm 

This notebook demonstrates the Extended Euclidean Algorithm with verbose, step-by-step prints.
It is self-contained and intended to show how the multiplicative inverse of `b` modulo `m` is computed,
including the intermediate `q`, `t1/t2/t3` values and final verification.

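The implementation works on ordinary integers (classic modular arithmetic). The same tabular procedure carries over to the inverses
needed for AES S-box construction, but there the arithmetic is on polynomials over GF(2) reduced by the AES modulus 0x11B (the field GF(2^8)),
so integer subtraction and division must be replaced by XOR and carry-less polynomial division.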

## Contract
- Inputs: `m` (modulus), `b` (value to invert mod `m`). Both integers.
- Output: `b_inv` such that `(b * b_inv) % m == 1` if an inverse exists; otherwise an exception is raised.
- Error modes: raises `ValueError` if no inverse exists (i.e. `gcd(m, b) != 1`).

We print each algorithm step to make the computation transparent.
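
As a cross-check on the contract above, here is a minimal, non-verbose sketch of the same computation. The helper name `modinv` is hypothetical (it is not part of this notebook), and the comparison assumes Python 3.8 or later, where the built-in `pow` accepts an exponent of `-1` together with a modulus.

```python
def modinv(b, m):
    """Return b^-1 mod m using the iterative extended Euclidean algorithm.

    Illustrative helper only; the name is an assumption, not the notebook's API.
    """
    old_r, r = m, b % m   # remainder sequence
    old_t, t = 0, 1       # coefficients of b, so that r == t*b (mod m) at every step
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_t, t = t, old_t - q * t
    if old_r != 1:        # old_r now holds gcd(m, b)
        raise ValueError(f"No multiplicative inverse for {b} mod {m} (GCD != 1)")
    return old_t % m


for b, m in [(550, 1759), (7, 26)]:
    inv = modinv(b, m)
    assert (b * inv) % m == 1      # the contract: (b * b_inv) % m == 1
    assert inv == pow(b, -1, m)    # agrees with the built-in (Python 3.8+)
    print(f"{b}^-1 mod {m} = {inv}")
```

Both routes should agree whenever an inverse exists; when it does not, `modinv` raises `ValueError`, which matches the error mode listed in the contract.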

In [2]:
def extended_euclidean_verbose(m, b):
    """Extended Euclidean Algorithm (verbose).
    Returns multiplicative inverse of b modulo m if it exists, printing full step trace.
    """
    a1, a2, a3 = 1, 0, m
    b1, b2, b3 = 0, 1, b
    q = 0
    # Header for the trace table
    print(f"{'q':>3} | {'a1':>6} | {'a2':>6} | {'a3':>10} | {'b1':>6} | {'b2':>6} | {'b3':>10}")
    print('-'*62)
    # Loop until remainder is 0 or 1
    while b3 != 0 and b3 != 1:
        print(f"{q:>3} | {a1:6} | {a2:6} | {a3:10} | {b1:6} | {b2:6} | {b3:10}")
        q = a3 // b3
        t1 = a1 - q * b1
        t2 = a2 - q * b2
        t3 = a3 - q * b3
        print(f"    q = {q}  ->  t1 = {t1}, t2 = {t2}, t3 = {t3}")
        # rotate rows (like the tabular extended-euclid method)
        a1, a2, a3 = b1, b2, b3
        b1, b2, b3 = t1, t2, t3
    # Final state
    print('-'*62)
    print(f"Final: a1={a1}, a2={a2}, a3={a3}  |  b1={b1}, b2={b2}, b3={b3}")
    if b3 == 0:
        # gcd != 1 -> no inverse
        raise ValueError(f"No multiplicative inverse for {b} mod {m} (GCD != 1)")
    # b3 == 1 -> b2 is the coefficient so that b*b2 + m*(something) = 1
    inverse = b2 % m
    print(f"Multiplicative inverse of {b} modulo {m} is: {inverse} (0x{inverse:02X})")
    return inverse


def validate_multiplicative_inverse(b, b_inv, m):
    product = (b * b_inv) % m
    print(f"Validation: ({b} * {b_inv}) % {m} = {product}")
    return product == 1


In [3]:
# Demonstrations: run the verbose algorithm on a few examples
# 1) Example from the repository: m = 1759, b = 550 (expected inverse 355)
m = 1759
b = 550
print("Example 1: m=1759, b=550")
inv1 = extended_euclidean_verbose(m, b)
print('Expected (from earlier notes): 355 ->', inv1, '\n')

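# 2) Integer example using the AES modulus value: m = 0x11B = 283 (a prime), b = 0xC2.
#    Note: this is ordinary integer arithmetic mod 283, not polynomial arithmetic in GF(2^8).
m = 0x11B  # bit pattern of the AES polynomial x^8 + x^4 + x^3 + x + 1, treated here as the integer 283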
b = 0xC2
print("Example 2: AES field modulus m=0x11B (283), b=0xC2 (194)")
try:
    inv2 = extended_euclidean_verbose(m, b)
    ok = validate_multiplicative_inverse(b, inv2, m)
    print('Validation result:', ok)
except ValueError as e:
    print('Error:', e)

# 3) Small co-prime check: m = 26, b = 7 -> inverse should be 15 (because 7*15 = 105 = 1 mod 26)
print()
m = 26
b = 7
inv3 = extended_euclidean_verbose(m, b)
print('Check:', validate_multiplicative_inverse(b, inv3, m))


Example 1: m=1759, b=550
  q |     a1 |     a2 |         a3 |     b1 |     b2 |         b3
--------------------------------------------------------------
  0 |      1 |      0 |       1759 |      0 |      1 |        550
    q = 3  ->  t1 = 1, t2 = -3, t3 = 109
  3 |      0 |      1 |        550 |      1 |     -3 |        109
    q = 5  ->  t1 = -5, t2 = 16, t3 = 5
  5 |      1 |     -3 |        109 |     -5 |     16 |          5
    q = 21  ->  t1 = 106, t2 = -339, t3 = 4
 21 |     -5 |     16 |          5 |    106 |   -339 |          4
    q = 1  ->  t1 = -111, t2 = 355, t3 = 1
--------------------------------------------------------------
Final: a1=106, a2=-339, a3=4  |  b1=-111, b2=355, b3=1
Multiplicative inverse of 550 modulo 1759 is: 355 (0x163)
Expected (from earlier notes): 355 -> 355 

Example 2: AES field modulus m=0x11B (283), b=0xC2 (194)
  q |     a1 |     a2 |         a3 |     b1 |     b2 |         b3
--------------------------------------------------------------
  0 |   
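
Example 2 treats 0x11B as the ordinary integer 283, so its result is an inverse modulo 283 rather than an inverse in GF(2^8). The sketch below shows how the same table-style algorithm might look when the values are polynomials over GF(2) encoded as bit patterns: subtraction becomes XOR and integer division becomes carry-less polynomial division. The helper names `poly_mul`, `poly_divmod` and `gf256_inverse` are hypothetical and are not defined elsewhere in this notebook.

```python
def poly_mul(a, b):
    """Carry-less multiplication of two GF(2) polynomials encoded as ints (no reduction)."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a <<= 1
        b >>= 1
    return p


def poly_divmod(a, b):
    """Polynomial long division over GF(2); returns (quotient, remainder)."""
    if b == 0:
        raise ZeroDivisionError("polynomial division by zero")
    q = 0
    while a.bit_length() >= b.bit_length():
        shift = a.bit_length() - b.bit_length()
        q ^= 1 << shift      # add x^shift to the quotient
        a ^= b << shift      # subtract (XOR) the shifted divisor
    return q, a


def gf256_inverse(b, m=0x11B):
    """Multiplicative inverse of a nonzero byte in GF(2^8) via the extended Euclidean algorithm."""
    if not 0 < b < 256:
        raise ValueError("b must be a nonzero byte")
    a2, a3 = 0, m    # row A: coefficient of b, remainder
    b2, b3 = 1, b    # row B: coefficient of b, remainder
    while b3 != 1:   # m is irreducible, so the gcd is always 1
        q, r = poly_divmod(a3, b3)
        a2, a3, b2, b3 = b2, b3, a2 ^ poly_mul(q, b2), r
    return b2


inv = gf256_inverse(0x53)
print(f"inverse of 0x53 in GF(2^8): 0x{inv:02X}")   # 0xCA
# Verify: the carry-less product reduced by 0x11B must be 1
assert poly_divmod(poly_mul(0x53, inv), 0x11B)[1] == 1
```

Only the coefficient of `b` is tracked, since the coefficient of `m` is not needed to read off the inverse.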

## S-box construction

How is the S-box constructed? The construction proceeds as follows:
1. Initialize the S-box with the byte values in ascending order, row by row: the first row contains {00}, {01}, {02}, ..., {0F}, the second row contains {10}, {11}, {12}, ..., {1F}, etc.
2. Map each nonzero byte in the S-box to its multiplicative inverse in GF(2^8); {00} is mapped to itself.
3. Each byte in the S-box is a sequence of 8 bits (b7, b6, ..., b1, b0). Apply the following transformation over GF(2) to each bit of each byte in the S-box:

   $$b_i' = b_i \oplus b_{(i+4) \bmod 8} \oplus b_{(i+5) \bmod 8} \oplus b_{(i+6) \bmod 8} \oplus b_{(i+7) \bmod 8} \oplus c_i$$

   where $c_i$ is the $i$-th bit of {63}: (c7, c6, c5, c4, c3, c2, c1, c0) = (01100011).
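
Step 3 can also be written with byte rotations, because bit `i` of a left rotation by `k` is bit `(i - k) mod 8` of the input. The sketch below compares that form against the bit-by-bit formula; the names `rotl8`, `affine_rot` and `affine_bits` are hypothetical helpers introduced only for this illustration. The worked value uses the fact that {CA} is the multiplicative inverse of {53} in GF(2^8).

```python
def rotl8(x, k):
    """Rotate an 8-bit value left by k bits (0 < k < 8)."""
    return ((x << k) | (x >> (8 - k))) & 0xFF


def affine_rot(b):
    """Affine transform written with rotations: b ^ rotl(b,1..4) ^ 0x63."""
    return b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63


def affine_bits(b, c=0x63):
    """Affine transform written bit by bit, following the formula in step 3."""
    out = 0
    for i in range(8):
        bit = (c >> i) & 1
        for offset in (0, 4, 5, 6, 7):
            bit ^= (b >> ((i + offset) % 8)) & 1
        out |= bit << i
    return out


# Both forms agree on every byte
assert all(affine_rot(b) == affine_bits(b) for b in range(256))

print(f"{affine_rot(0x00):02X}")  # 63: {00} maps to itself in step 2, then to the constant
print(f"{affine_rot(0xCA):02X}")  # ED: so the S-box entry for {53} is {ED}
```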

In [4]:
# Build AES S-box (multiplicative inverse in GF(2^8) + AES affine transform)

def gf_mul(a, b):
    """Galois Field (2^8) multiplication used by AES (irreducible poly x^8 + x^4 + x^3 + x + 1).
    This is the usual bytewise implementation using 0x1B reduction on carry.
    """
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        carry = a & 0x80
        a = (a << 1) & 0xFF
        if carry:
            a ^= 0x1B
        b >>= 1
    return p & 0xFF


def gf_pow(a, power):
    """Exponentiation in GF(2^8) using repeated squaring.
    For multiplicative inverse use power = 254 (since a^(2^8-2) = a^-1 for a != 0).
    """
    result = 1
    base = a
    while power > 0:
        if power & 1:
            result = gf_mul(result, base)
        base = gf_mul(base, base)
        power >>= 1
    return result & 0xFF


def affine_transform(byte):
    """AES affine transform applied to a byte (after multiplicative inverse).
    Uses the standard bitwise definition with constant 0x63.
    """
    c = 0x63
    out = 0
    for i in range(8):
        # XOR of the five bits: b_i ^ b_{i+4} ^ b_{i+5} ^ b_{i+6} ^ b_{i+7} (indices mod 8)
        bit = ((byte >> i) & 1) ^ ((byte >> ((i + 4) % 8)) & 1) ^ ((byte >> ((i + 5) % 8)) & 1) ^ ((byte >> ((i + 6) % 8)) & 1) ^ ((byte >> ((i + 7) % 8)) & 1) ^ ((c >> i) & 1)
        out |= (bit << i)
    return out & 0xFF


# Generate S-box
sbox_generated = [0] * 256
for x in range(256):
    if x == 0:
        inv = 0
    else:
        inv = gf_pow(x, 254)  # multiplicative inverse in GF(2^8)
    sbox_generated[x] = affine_transform(inv)

# Reference AES S-box (standard)
AES_SBOX = [
    0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
    0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
    0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
    0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
    0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
    0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
    0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
    0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
    0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
    0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
    0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
    0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
    0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
    0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
    0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
    0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16
]

# Compare
match = sbox_generated == AES_SBOX
print('S-box match with reference:', match)
if not match:
    print('\nMismatches (index, generated, reference):')
    for i, (g, r) in enumerate(zip(sbox_generated, AES_SBOX)):
        if g != r:
            print(f'{i:02X}: 0x{g:02X} != 0x{r:02X}')
else:
    # print compact summary of first 16 bytes
    print('\nFirst 16 bytes of generated S-box:')
    print(' '.join(f"{x:02X}" for x in sbox_generated[:16]))

# Build inverse S-box and sanity-check
inv_sbox = [0]*256
for i, val in enumerate(sbox_generated):
    inv_sbox[val] = i
# Validate by applying sbox then inv_sbox maps back a few sample values
samples = [0x00, 0x53, 0xA7, 0xFF]
print('\nRound-trip checks:')
for s in samples:
    r = inv_sbox[sbox_generated[s]]
    print(f'0x{s:02X} -> Sbox 0x{sbox_generated[s]:02X} -> Inv 0x{r:02X}')


S-box match with reference: True

First 16 bytes of generated S-box:
63 7C 77 7B F2 6B 6F C5 30 01 67 2B FE D7 AB 76

Round-trip checks:
0x00 -> Sbox 0x63 -> Inv 0x00
0x53 -> Sbox 0xED -> Inv 0x53
0xA7 -> Sbox 0x5C -> Inv 0xA7
0xFF -> Sbox 0x16 -> Inv 0xFF


In [6]:
# AES core step implementations: SubBytes, ShiftRows, MixColumns, AddRoundKey

def print_state(state, title=None):
    if title:
        print(title)
    for r in range(4):
        print(' '.join(f"{state[r][c]:02X}" for c in range(4)))
    print()


def sub_bytes(state, sbox):
    """Apply S-box substitution to every byte of the state.
    State is a 4x4 list-of-lists (rows x columns), values 0..255.
    Returns a new state (doesn't mutate input).
    """
    return [[sbox[state[r][c]] for c in range(4)] for r in range(4)]


def inv_sub_bytes(state, inv_sbox):
    return [[inv_sbox[state[r][c]] for c in range(4)] for r in range(4)]


def shift_rows(state):
    """Shift rows to the left by row index (row 0 no shift).
    Returns new state.
    """
    out = [list(row) for row in state]
    for r in range(1, 4):
        out[r] = state[r][r:] + state[r][:r]
    return out


def inv_shift_rows(state):
    out = [list(row) for row in state]
    for r in range(1, 4):
        out[r] = state[r][-r:] + state[r][:-r]
    return out


def mix_single_column(col):
    """Mix one column (4 bytes) using AES forward matrix.
    `col` is a list of 4 bytes [s0, s1, s2, s3].
    """
    s0, s1, s2, s3 = col
    r0 = (gf_mul(0x02, s0) ^ gf_mul(0x03, s1) ^ s2 ^ s3) & 0xFF
    r1 = (s0 ^ gf_mul(0x02, s1) ^ gf_mul(0x03, s2) ^ s3) & 0xFF
    r2 = (s0 ^ s1 ^ gf_mul(0x02, s2) ^ gf_mul(0x03, s3)) & 0xFF
    r3 = (gf_mul(0x03, s0) ^ s1 ^ s2 ^ gf_mul(0x02, s3)) & 0xFF
    return [r0, r1, r2, r3]


def mix_columns(state):
    """Apply MixColumns to every column of the state. Returns a new state (input is not mutated)."""
    out = [[0]*4 for _ in range(4)]
    for c in range(4):
        col = [state[r][c] for r in range(4)]
        mixed = mix_single_column(col)
        for r in range(4):
            out[r][c] = mixed[r]
    return out


def inv_mix_single_column(col):
    s0, s1, s2, s3 = col
    r0 = (gf_mul(0x0E, s0) ^ gf_mul(0x0B, s1) ^ gf_mul(0x0D, s2) ^ gf_mul(0x09, s3)) & 0xFF
    r1 = (gf_mul(0x09, s0) ^ gf_mul(0x0E, s1) ^ gf_mul(0x0B, s2) ^ gf_mul(0x0D, s3)) & 0xFF
    r2 = (gf_mul(0x0D, s0) ^ gf_mul(0x09, s1) ^ gf_mul(0x0E, s2) ^ gf_mul(0x0B, s3)) & 0xFF
    r3 = (gf_mul(0x0B, s0) ^ gf_mul(0x0D, s1) ^ gf_mul(0x09, s2) ^ gf_mul(0x0E, s3)) & 0xFF
    return [r0, r1, r2, r3]


def inv_mix_columns(state):
    out = [[0]*4 for _ in range(4)]
    for c in range(4):
        col = [state[r][c] for r in range(4)]
        mixed = inv_mix_single_column(col)
        for r in range(4):
            out[r][c] = mixed[r]
    return out


def add_round_key(state, round_key):
    """XOR the state with the round key (both 4x4). Returns new state."""
    return [[state[r][c] ^ round_key[r][c] for c in range(4)] for r in range(4)]


# Quick demonstration using the functions and the previously generated AES_SBOX
# Ensure AES_SBOX exists in the notebook; it was defined earlier as a list.
try:
    sbox_ref = AES_SBOX
except NameError:
    # Fallback: if AES_SBOX not present, use sbox_generated
    sbox_ref = sbox_generated

# Build inverse sbox for inv_sub_bytes
inv_sbox_ref = [0]*256
for i, v in enumerate(sbox_ref):
    inv_sbox_ref[v] = i

# Example state (4 rows x 4 cols)
sample_state = [
    [0x32, 0x88, 0x31, 0xE0],
    [0x43, 0x5A, 0x31, 0x37],
    [0xF6, 0x30, 0x98, 0x07],
    [0xA8, 0x8D, 0xA2, 0x34]
]

print_state(sample_state, 'Initial state:')
# SubBytes
s1 = sub_bytes(sample_state, sbox_ref)
print_state(s1, 'After SubBytes:')
# ShiftRows
s2 = shift_rows(s1)
print_state(s2, 'After ShiftRows:')
# MixColumns
s3 = mix_columns(s2)
print_state(s3, 'After MixColumns:')
# AddRoundKey (use zero key for demo)
zero_key = [[0]*4 for _ in range(4)]
s4 = add_round_key(s3, zero_key)
print_state(s4, 'After AddRoundKey (zero key):')

# Inverse operations to check correctness
s_inv_mix = inv_mix_columns(s3)
print_state(s_inv_mix, 'InvMixColumns(s3) -> should be state after ShiftRows (s2):')
print('Matches ShiftRows result:', s_inv_mix == s2)

s_inv_shift = inv_shift_rows(s2)
print_state(s_inv_shift, 'InvShiftRows(s2) -> should equal SubBytes result (s1):')
print('Matches SubBytes result:', s_inv_shift == s1)

s_inv_sub = inv_sub_bytes(s1, inv_sbox_ref)
print_state(s_inv_sub, 'InvSubBytes(s1) -> should equal initial state:')
print('Matches initial state:', s_inv_sub == sample_state)


Initial state:
32 88 31 E0
43 5A 31 37
F6 30 98 07
A8 8D A2 34

After SubBytes:
23 C4 C7 E1
1A BE C7 9A
42 04 46 C5
C2 5D 3A 18

After ShiftRows:
23 C4 C7 E1
BE C7 9A 1A
46 C5 42 04
18 C2 5D 3A

After MixColumns:
C1 C6 3F C9
96 C7 73 E3
39 CF 3E BD
AD CA 30 52

After AddRoundKey (zero key):
C1 C6 3F C9
96 C7 73 E3
39 CF 3E BD
AD CA 30 52

InvMixColumns(s3) -> should be state after ShiftRows (s2):
23 C4 C7 E1
BE C7 9A 1A
46 C5 42 04
18 C2 5D 3A

Matches ShiftRows result: True
InvShiftRows(s2) -> should equal SubBytes result (s1):
23 C4 C7 E1
1A BE C7 9A
42 04 46 C5
C2 5D 3A 18

Matches SubBytes result: True
InvSubBytes(s1) -> should equal initial state:
32 88 31 E0
43 5A 31 37
F6 30 98 07
A8 8D A2 34

Matches initial state: True
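
The round trip above shows that the forward and inverse steps are consistent with each other, but not that MixColumns itself is right. One way to check it independently is a single column whose result can be verified by hand. The sketch below is self-contained and uses hypothetical helper names (`xtime`, `mix_column`); the column `DB 13 53 45` is a commonly quoted test input.

```python
def xtime(a):
    """Multiply a byte by x (that is, by 0x02) in GF(2^8) modulo 0x11B."""
    a <<= 1
    if a & 0x100:
        a ^= 0x11B
    return a


def mul3(a):
    """Multiply by 0x03, using 3*a = 2*a XOR a."""
    return xtime(a) ^ a


def mix_column(col):
    """Apply the forward MixColumns matrix to one 4-byte column."""
    s0, s1, s2, s3 = col
    return [
        xtime(s0) ^ mul3(s1) ^ s2 ^ s3,
        s0 ^ xtime(s1) ^ mul3(s2) ^ s3,
        s0 ^ s1 ^ xtime(s2) ^ mul3(s3),
        mul3(s0) ^ s1 ^ s2 ^ xtime(s3),
    ]


result = mix_column([0xDB, 0x13, 0x53, 0x45])
print(' '.join(f"{b:02X}" for b in result))  # 8E 4D A1 BC
```

For example, the first output byte is `2*DB ^ 3*13 ^ 53 ^ 45 = AD ^ 35 ^ 53 ^ 45 = 8E`.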


## AES key expansion





In [8]:
# AES Key Expansion (Key schedule) — RotWord, SubWord, Rcon, key_expansion

# Uses `sbox_generated` or `AES_SBOX` and `gf_mul` defined earlier in this notebook.

def rot_word(word):
    """Rotate a word (4-byte list) left by one byte."""
    return [word[1], word[2], word[3], word[0]]


def sub_word(word, sbox):
    """Apply S-box to each byte of the 4-byte word."""
    return [sbox[b] for b in word]


def compute_rcon(n):
    """Compute Rcon list up to index n (1-based). Rcon[1] = 0x01.
    Returns list where rcon[i] is the byte for iteration i (i starting at 1).
    """
    rcon = [0]* (n+1)
    rcon[1] = 0x01
    for i in range(2, n+1):
        rcon[i] = gf_mul(rcon[i-1], 0x02)
    return rcon


def key_expansion(key_bytes):
    """Expand a cipher key (bytes) into the AES key schedule.
    Accepts key_bytes as an iterable of 16/24/32 bytes (AES-128/192/256).
    Returns list of 4-byte words (each a list of 4 ints) forming the expanded key.
    """
    key = list(key_bytes)  # works for bytes, bytearray, or any iterable of ints

    # Validate the byte length directly; checking only len(key) // 4 would accept e.g. 17 bytes
    if len(key) not in (16, 24, 32):
        raise ValueError('Key must be 16, 24 or 32 bytes long')
    Nk = len(key) // 4  # number of 32-bit words in key

    # Number of rounds
    Nr = {4:10, 6:12, 8:14}[Nk]
    Nb = 4
    n_words = Nb * (Nr + 1)

    # Fill initial words (first Nk words) from the key (word = 4 bytes)
    w = []
    for i in range(Nk):
        word = key[4*i:4*i+4]
        w.append(word)

    # Compute Rcon up to needed count
    rcon = compute_rcon((n_words // Nk) + 1)

    # Pick the S-box once: prefer the generated table, then fall back to the reference tables
    sbox = sbox_generated if 'sbox_generated' in globals() else (AES_SBOX if 'AES_SBOX' in globals() else sbox_ref)

    # Expand
    for i in range(Nk, n_words):
        temp = w[i-1].copy()
        if i % Nk == 0:
            temp = sub_word(rot_word(temp), sbox)
            temp[0] ^= rcon[i//Nk]
        elif Nk > 6 and i % Nk == 4:
            # Extra SubWord for AES-256
            temp = sub_word(temp, sbox)
        # w[i] = w[i-Nk] XOR temp
        prev = w[i-Nk]
        new_word = [prev[j] ^ temp[j] for j in range(4)]
        w.append(new_word)

    return w


def words_to_round_key_matrix(words, round_index):
    """Produce a 4x4 round key matrix (rows x cols) for given round_index.
    Each round uses 4 words: words[4*r : 4*r+4]; words are word[col] = [b0,b1,b2,b3] with b0->row0.
    Returns matrix as list of 4 rows each containing 4 bytes.
    """
    start = round_index * 4
    round_words = words[start:start+4]
    # matrix[row][col] = round_words[col][row]
    matrix = [[round_words[col][row] for col in range(4)] for row in range(4)]
    return matrix


# Demo: expand the AES test key from FIPS-197
# Key: 2b7e151628aed2a6abf7158809cf4f3c (AES-128 example)
sample_key_hex = '2b7e151628aed2a6abf7158809cf4f3c'
sample_key = bytes.fromhex(sample_key_hex)

expanded = key_expansion(sample_key)

# Print round keys in hex, one per line
Nk = len(sample_key)//4
Nr = {4:10, 6:12, 8:14}[Nk]
print(f'Expanded key words: {len(expanded)} words (should be {4*(Nr+1)})')
for r in range(Nr+1):
    mat = words_to_round_key_matrix(expanded, r)
    flat = [mat[row][col] for col in range(4) for row in range(4)]  # column-major order used in AES
    print(f'Round {r:02d} key:', ' '.join(f'{b:02X}' for b in flat))

# Quick check: the round 0 key must equal the original cipher key.
# Recompute it explicitly; after the loop above, `flat` holds the last round key.
round0 = words_to_round_key_matrix(expanded, 0)
round0_flat = [round0[row][col] for col in range(4) for row in range(4)]
print('\nRound 0 equals original key?', round0_flat == list(sample_key))


Expanded key words: 44 words (should be 44)
Round 00 key: 2B 7E 15 16 28 AE D2 A6 AB F7 15 88 09 CF 4F 3C
Round 01 key: A0 FA FE 17 88 54 2C B1 23 A3 39 39 2A 6C 76 05
Round 02 key: F2 C2 95 F2 7A 96 B9 43 59 35 80 7A 73 59 F6 7F
Round 03 key: 3D 80 47 7D 47 16 FE 3E 1E 23 7E 44 6D 7A 88 3B
Round 04 key: EF 44 A5 41 A8 52 5B 7F B6 71 25 3B DB 0B AD 00
Round 05 key: D4 D1 C6 F8 7C 83 9D 87 CA F2 B8 BC 11 F9 15 BC
Round 06 key: 6D 88 A3 7A 11 0B 3E FD DB F9 86 41 CA 00 93 FD
Round 07 key: 4E 54 F7 0E 5F 5F C9 F3 84 A6 4F B2 4E A6 DC 4F
Round 08 key: EA D2 73 21 B5 8D BA D2 31 2B F5 60 7F 8D 29 2F
Round 09 key: AC 77 66 F3 19 FA DC 21 28 D1 29 41 57 5C 00 6E
Round 10 key: D0 14 F9 A8 C9 EE 25 89 E1 3F 0C C8 B6 63 0C A6

Round 0 equals original key? True
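
The key schedule depends on the round constants produced by `compute_rcon`: each constant is the previous one multiplied by 0x02 in GF(2^8), starting from 0x01. The following self-contained sketch lists the ten constants an AES-128 schedule needs; `xtime` and `rcon_bytes` are hypothetical names chosen for this illustration and return a 0-based list, unlike the 1-based list used in the cell above.

```python
def xtime(a):
    """Multiply a byte by 0x02 in GF(2^8) modulo 0x11B."""
    a <<= 1
    if a & 0x100:
        a ^= 0x11B
    return a


def rcon_bytes(count):
    """Return the first `count` round-constant bytes as a 0-based list."""
    out = []
    value = 0x01
    for _ in range(count):
        out.append(value)
        value = xtime(value)
    return out


print(' '.join(f"{b:02X}" for b in rcon_bytes(10)))
# 01 02 04 08 10 20 40 80 1B 36
```

The first eight values are plain doublings; the ninth wraps around through the reduction step, since 0x80 doubled overflows a byte and is reduced to 0x1B.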


In [9]:
# Unit tests: verify key expansion against FIPS-197 example via AES-128 encryption test
# 1) Check round 0 equals original key
# 2) Encrypt the FIPS sample block and compare to expected ciphertext (69c4e0d8...)


def bytes_to_state(b):
    # b is 16-byte iterable; fill column-major
    return [[b[4*c + r] for c in range(4)] for r in range(4)]


def state_to_bytes(state):
    return bytes([state[r][c] for c in range(4) for r in range(4)])


def build_round_keys_matrices(expanded_words):
    # Each round key consists of 4 words, so Nr = (number of words / 4) - 1
    Nr = (len(expanded_words) // 4) - 1
    return [words_to_round_key_matrix(expanded_words, r) for r in range(Nr+1)]


def aes_encrypt_block(plaintext16, expanded_words):
    sbox_local = sbox_generated if 'sbox_generated' in globals() else (AES_SBOX if 'AES_SBOX' in globals() else sbox_ref)
    state = bytes_to_state(list(plaintext16))
    Nr = (len(expanded_words) // 4) - 1
    round_keys = build_round_keys_matrices(expanded_words)

    # Initial round key
    state = add_round_key(state, round_keys[0])

    # Rounds 1..Nr-1
    for r in range(1, Nr):
        state = sub_bytes(state, sbox_local)
        state = shift_rows(state)
        state = mix_columns(state)
        state = add_round_key(state, round_keys[r])

    # Final round
    state = sub_bytes(state, sbox_local)
    state = shift_rows(state)
    state = add_round_key(state, round_keys[Nr])

    return state_to_bytes(state)


# Test vectors from FIPS-197
key_hex = '000102030405060708090a0b0c0d0e0f'
pt_hex  = '00112233445566778899aabbccddeeff'
expected_ct_hex = '69c4e0d86a7b0430d8cdb78070b4c55a'

key = bytes.fromhex(key_hex)
pt = bytes.fromhex(pt_hex)
expected_ct = bytes.fromhex(expected_ct_hex)

expanded_test = key_expansion(key)
# Check round 0 equals original key
round0_mat = words_to_round_key_matrix(expanded_test, 0)
round0_flat = [round0_mat[row][col] for col in range(4) for row in range(4)]
assert round0_flat == list(key), f"Round 0 mismatch with original key: {round0_flat} != {list(key)}"
print('Round 0 matches original key — OK')

# Encrypt and compare
ct = aes_encrypt_block(pt, expanded_test)
print('Encrypted CT:', ct.hex())
assert ct == expected_ct, f"AES encryption mismatch. expected {expected_ct.hex()}, got {ct.hex()}"
print('AES-128 encryption against FIPS test vector — OK')

# Summary
print('\nAll unit tests passed: key expansion and AES encrypt check OK')

Round 0 matches original key — OK
Encrypted CT: 69c4e0d86a7b0430d8cdb78070b4c55a
AES-128 encryption against FIPS test vector — OK

All unit tests passed: key expansion and AES encrypt check OK
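
The encryption test relies on the column-major layout implemented by `bytes_to_state` and `state_to_bytes`: input byte `4*c + r` lands in row `r`, column `c`. The sketch below repeats those two helpers so that it runs on its own and prints the layout for the bytes 0x00 to 0x0F.

```python
def bytes_to_state(b):
    """Fill a 4x4 state column by column from a 16-byte sequence."""
    return [[b[4*c + r] for c in range(4)] for r in range(4)]


def state_to_bytes(state):
    """Read the 4x4 state back out column by column."""
    return bytes(state[r][c] for c in range(4) for r in range(4))


block = bytes(range(16))
state = bytes_to_state(block)
for row in state:
    print(' '.join(f"{v:02X}" for v in row))
# 00 04 08 0C
# 01 05 09 0D
# 02 06 0A 0E
# 03 07 0B 0F

# The two helpers are inverses of each other
assert state_to_bytes(state) == block
```

Reading the state row by row instead would transpose it, which is a frequent cause of ciphertext mismatches against published test vectors.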
