In [None]:
class SimplifiedAES:
    """Simplified AES (S-AES): a pre-round plus two rounds on 16-bit blocks.

    Both the key and the data block are handled as 16-bit integers. Inside
    the cipher a block is a list of four nibbles in row-major order
    ``[s00, s01, s10, s11]``; the integer's nibbles, most significant first,
    fill that 2x2 matrix column by column (s00, s10, s01, s11).
    """

    # S-Box: substitution table indexed by a 4-bit nibble.
    sBox = [
        0x9, 0x4, 0xA, 0xB,
        0xD, 0x1, 0x8, 0x5,
        0x6, 0x2, 0x0, 0x3,
        0xC, 0xE, 0xF, 0x7,
    ]

    # Inverse S-Box: sBoxI[sBox[x]] == x for every nibble x.
    sBoxI = [
        0xA, 0x5, 0x9, 0xB,
        0x1, 0x7, 0x8, 0xF,
        0x6, 0x0, 0x2, 0x3,
        0xC, 0x4, 0xD, 0xE,
    ]

    def __init__(self, key):
        """Derive and store the three round keys for the 16-bit ``key``."""
        k0, k1, k2 = self.key_expansion(key)
        self.pre_round_key = k0  # K0 = w0 || w1
        self.round1_key = k1     # K1 = w2 || w3
        self.round2_key = k2     # K2 = w4 || w5

    def sub_word(self, word):
        """Return the 8-bit ``word`` with the S-Box applied to each nibble."""
        high_nibble = word >> 4
        low_nibble = word & 0x0F
        return self.sBox[high_nibble] << 4 | self.sBox[low_nibble]

    def rot_word(self, word):
        """Return the 8-bit ``word`` with its two nibbles swapped."""
        return (word << 4 & 0xF0) | (word >> 4 & 0x0F)

    def key_expansion(self, key):
        """Expand a 16-bit key into the tuple (K0, K1, K2) of state vectors.

        Only the low 16 bits of ``key`` are read.
        """
        rcon1, rcon2 = 0x80, 0x30  # round constants for rounds 1 and 2

        w0 = key >> 8 & 0xFF
        w1 = key & 0xFF
        w2 = w0 ^ self.sub_word(self.rot_word(w1)) ^ rcon1
        w3 = w2 ^ w1
        w4 = w2 ^ self.sub_word(self.rot_word(w3)) ^ rcon2
        w5 = w4 ^ w3

        # Every w_i is 8 bits wide, so each pair joins into one 16-bit key.
        word_pairs = ((w0, w1), (w2, w3), (w4, w5))
        return tuple(self.int_to_state(hi << 8 | lo) for hi, lo in word_pairs)

    def gf_mult(self, a, b):
        """Multiply ``a`` and ``b`` in GF(2^4) modulo x^4 + x + 1.

        Only the low four bits of each argument are used.
        """
        a &= 0x0F
        b &= 0x0F
        result = 0
        for _ in range(4):  # b has at most four bits to consume
            if b & 1:
                result ^= a
            b >>= 1
            # Multiply a by x; reduce by x^4 + x + 1 when the degree reaches 4.
            a = (a << 1) ^ 0b10011 if a & 0b1000 else a << 1
        return result

    def int_to_state(self, n):
        """Split a 16-bit integer into the state vector [s00, s01, s10, s11]."""
        # Right-shift amounts that extract s00, s01, s10, s11 respectively.
        return [(n >> shift) & 0xF for shift in (12, 4, 8, 0)]

    def state_to_int(self, m):
        """Pack a state vector [s00, s01, s10, s11] back into a 16-bit integer."""
        s00, s01, s10, s11 = m[0], m[1], m[2], m[3]
        return (s00 << 12) + (s10 << 8) + (s01 << 4) + s11

    def add_round_key(self, s1, s2):
        """XOR two state vectors element-wise (addition in GF(2^4))."""
        return [x ^ y for x, y in zip(s1, s2)]

    def sub_nibbles(self, sbox, state):
        """Replace every nibble of ``state`` with its entry in ``sbox``."""
        return [sbox[value] for value in state]

    def shift_rows(self, state):
        """Swap the two nibbles of the bottom row.

        The swap is its own inverse, so this also serves as InvShiftRows.
        """
        s00, s01, s10, s11 = state[0], state[1], state[2], state[3]
        return [s00, s01, s11, s10]

    def mix_columns(self, state):
        """Multiply each column by the matrix [[1, 4], [4, 1]] over GF(2^4)."""
        mul = self.gf_mult
        s00, s01, s10, s11 = state[0], state[1], state[2], state[3]
        return [
            s00 ^ mul(4, s10),
            s01 ^ mul(4, s11),
            s10 ^ mul(4, s00),
            s11 ^ mul(4, s01),
        ]

    def inverse_mix_columns(self, state):
        """Multiply each column by the matrix [[9, 2], [2, 9]] over GF(2^4).

        This matrix is the inverse of the one used by ``mix_columns``.
        """
        mul = self.gf_mult
        s00, s01, s10, s11 = state[0], state[1], state[2], state[3]
        return [
            mul(9, s00) ^ mul(2, s10),
            mul(9, s01) ^ mul(2, s11),
            mul(9, s10) ^ mul(2, s00),
            mul(9, s11) ^ mul(2, s01),
        ]

    def encrypt(self, plaintext):
        """Encrypt a 16-bit ``plaintext`` block with this instance's key."""
        sbox = self.sBox
        # Pre-round: add K0.
        state = self.add_round_key(self.pre_round_key, self.int_to_state(plaintext))
        # Round 1: substitute, shift, mix, add K1.
        state = self.sub_nibbles(sbox, state)
        state = self.shift_rows(state)
        state = self.mix_columns(state)
        state = self.add_round_key(self.round1_key, state)
        # Round 2: substitute, shift, add K2 (no column mixing).
        state = self.sub_nibbles(sbox, state)
        state = self.shift_rows(state)
        state = self.add_round_key(self.round2_key, state)
        return self.state_to_int(state)

    def decrypt(self, ciphertext):
        """Decrypt a 16-bit ``ciphertext`` block with this instance's key."""
        inv_sbox = self.sBoxI
        # Undo round 2.
        state = self.add_round_key(self.round2_key, self.int_to_state(ciphertext))
        state = self.shift_rows(state)
        state = self.sub_nibbles(inv_sbox, state)
        # Undo round 1.
        state = self.add_round_key(self.round1_key, state)
        state = self.inverse_mix_columns(state)
        state = self.shift_rows(state)
        state = self.sub_nibbles(inv_sbox, state)
        # Undo the pre-round.
        state = self.add_round_key(self.pre_round_key, state)
        return self.state_to_int(state)

# Example usage: encrypt one block, then decrypt it to show the round trip.
key = 0b0100101011110101
s_aes = SimplifiedAES(key)
plaintext = 0b1101011100111000

# Format the label from the variable itself so it cannot drift from the
# value actually encrypted. Note that bin() omits leading zero bits.
print("Plaintext:", bin(plaintext))
ciphertext = s_aes.encrypt(plaintext)
print("Ciphertext:", bin(ciphertext))

decrypted_plaintext = s_aes.decrypt(ciphertext)
print("Decrypted Plaintext:", bin(decrypted_plaintext))


Plaintext: 0b1101011100111000
Ciphertext: 0b10110000011100
Decrypted Plaintext: 0b1101011100111000
