<a href="https://colab.research.google.com/github/DikshantBadawadagi/Encryption-Algorithms/blob/main/S_AES.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import numpy as np

# S-Box for SubBytes
sbox = np.array([
    [0x9, 0x4, 0xA, 0xB],
    [0xD, 0x1, 0x8, 0x5],
    [0x6, 0x2, 0x0, 0x3],
    [0xC, 0xE, 0xF, 0x7]
])

# Inverse S-Box for InvSubBytes
inv_sbox = np.array([
    [0xA, 0x5, 0x9, 0xB],
    [0x1, 0x7, 0x8, 0xF],
    [0x6, 0x0, 0x2, 0x3],
    [0xC, 0x4, 0xD, 0xE]
])

# MixColumns constant matrix
mix_columns_matrix = np.array([
    [1, 4],
    [4, 1]
])

def sub_bytes(state):
    return np.array([[sbox[min(s >> 2, 3), s & 0x3] for s in row] for row in state])

def inv_sub_bytes(state):
    return np.array([[inv_sbox[min(s >> 2, 3), s & 0x3] for s in row] for row in state])

def shift_rows(state):
    return np.array([state[0], np.roll(state[1], -1)])

def inv_shift_rows(state):
    return np.array([state[0], np.roll(state[1], 1)])

def mix_columns(state):
    return np.array([
        [(state[0, 0] * mix_columns_matrix[0, 0] ^ state[1, 0] * mix_columns_matrix[0, 1]) % 16,
         (state[0, 1] * mix_columns_matrix[0, 0] ^ state[1, 1] * mix_columns_matrix[0, 1]) % 16],
        [(state[0, 0] * mix_columns_matrix[1, 0] ^ state[1, 0] * mix_columns_matrix[1, 1]) % 16,
         (state[0, 1] * mix_columns_matrix[1, 0] ^ state[1, 1] * mix_columns_matrix[1, 1]) % 16]
    ])

def inv_mix_columns(state):
    inv_mix_columns_matrix = np.array([
        [9, 2],
        [2, 9]
    ])
    return np.array([
        [(state[0, 0] * inv_mix_columns_matrix[0, 0] ^ state[1, 0] * inv_mix_columns_matrix[0, 1]) % 16,
         (state[0, 1] * inv_mix_columns_matrix[0, 0] ^ state[1, 1] * inv_mix_columns_matrix[0, 1]) % 16],
        [(state[0, 0] * inv_mix_columns_matrix[1, 0] ^ state[1, 0] * inv_mix_columns_matrix[1, 1]) % 16,
         (state[0, 1] * inv_mix_columns_matrix[1, 0] ^ state[1, 1] * inv_mix_columns_matrix[1, 1]) % 16]
    ])

def add_round_key(state, round_key):
    return state ^ round_key

def key_expansion(key):
    w = [key[:, 0], key[:, 1]]
    rcon = [0x80, 0x30]
    for i in range(2, 6):
        temp = w[i-1].copy()
        if i % 2 == 0:
            temp = np.roll(temp, -1)
            temp = np.array([sbox[min(t >> 2, 3), t & 0x3] for t in temp])
            temp[0] ^= rcon[i//2 - 1]
        w.append(w[i-2] ^ temp)
    return np.array(w).T.reshape(2, 2, 3)

def encrypt(plaintext, key):
    state = np.array([[ord(c) % 16 for c in plaintext[:2]], [ord(c) % 16 for c in plaintext[2:]]])
    print(f"Debug: Initial state = {state}")  # Debug output
    round_keys = key_expansion(key)
    print(f"Debug: Round keys = {round_keys}")  # Debug output

    state = add_round_key(state, round_keys[:, :, 0])

    for i in range(1, 3):
        print(f"Debug: State before sub_bytes in round {i} = {state}")  # Debug output
        state = sub_bytes(state)
        print(f"Debug: State after sub_bytes in round {i} = {state}")  # Debug output
        state = shift_rows(state)
        if i < 2:
            state = mix_columns(state)
        state = add_round_key(state, round_keys[:, :, i])
        print(f"Debug: State after round {i} = {state}")  # Debug output

    return ''.join([chr(s + ord('0')) for row in state for s in row])

def decrypt(ciphertext, key):
    state = np.array([[ord(c) - ord('0') for c in ciphertext[:2]], [ord(c) - ord('0') for c in ciphertext[2:]]])
    round_keys = key_expansion(key)

    state = add_round_key(state, round_keys[:, :, 2])

    for i in range(1, 3):
        state = inv_shift_rows(state)
        state = inv_sub_bytes(state)
        state = add_round_key(state, round_keys[:, :, 2-i])
        if i < 2:
            state = inv_mix_columns(state)

    return ''.join([chr((s % 16) + ord('A')) for row in state for s in row])

# Example usage
key = np.array([[0, 1], [4, 5]])
plaintext = "ABCD"
print(f"Debug: Key = {key}")  # Debug output
print(f"Debug: Plaintext = {plaintext}")  # Debug output
ciphertext = encrypt(plaintext, key)
decrypted = decrypt(ciphertext, key)

print(f"Plaintext: {plaintext}")
print(f"Ciphertext: {ciphertext}")
print(f"Decrypted: {decrypted}")

Debug: Key = [[0 1]
 [4 5]]
Debug: Plaintext = ABCD
Debug: Initial state = [[1 2]
 [3 4]]
Debug: Round keys = [[[  0   1 129]
  [128 176  48]]

 [[  4   5   0]
  [  5  12   9]]]
Debug: State before sub_bytes in round 1 = [[  1 130]
 [  7   1]]
Debug: State after sub_bytes in round 1 = [[ 4 15]
 [ 5  4]]
Debug: State after round 1 = [[  5 187]
 [  1   5]]
Debug: State before sub_bytes in round 2 = [[  5 187]
 [  1   5]]
Debug: State after sub_bytes in round 2 = [[1 7]
 [4 1]]
Debug: State after round 2 = [[128  55]
 [  1  13]]
Plaintext: ABCD
Ciphertext: °g1=
Decrypted: MHKJ
