In [139]:
import sys

# Round constants XORed into the first word of each derived round key
# (rCon1 for w[2], rCon2 for w[4] in key_generation).
rCon1 = 0x80  # 0b1000_0000
rCon2 = 0x30  # 0b0011_0000

# 4-bit S-box: sBox[x] is the substitution for nibble x.
sBox = [
    0x9, 0x4, 0xA, 0xB, 0xD, 0x1, 0x8, 0x5,
    0x6, 0x2, 0x0, 0x3, 0xC, 0xE, 0xF, 0x7,
]

# Inverse S-box, derived from sBox so the two tables can never disagree:
# sBoxI[sBox[x]] == x for every nibble x.
sBoxI = [sBox.index(nibble) for nibble in range(16)]

# Key-schedule storage filled in by key_generation:
# six 8-bit words w[0]..w[5], and round keys K0 = w0w1, K1 = w2w3, K2 = w4w5.
w = [None] * 6
keys = [None] * 3

In [140]:
def rotate_word(w):
    """Swap the two nibbles of an 8-bit word (S-AES RotNib).

    Bits of `w` above the low byte are ignored, so the result is in 0..0xFF.
    """
    high_nibble = (w >> 4) & 0x0F
    low_nibble = w & 0x0F
    return (low_nibble << 4) | high_nibble

In [141]:
def sub_nib(w):
    """Pass both nibbles of an 8-bit word through `sBox` (S-AES SubNib).

    Only the low byte of `w` is used; the result is an 8-bit value.
    """
    upper = sBox[(w >> 4) & 0x0F]
    lower = sBox[w & 0x0F]
    return (upper << 4) | lower

In [142]:
def key_generation(key):
    """Expand `key` into the three S-AES round keys.

    Only bits 0-15 of `key` are used. The six 8-bit key-schedule words are
    written to the module-level list `w`, and the round keys
    K0 = w0w1, K1 = w2w3, K2 = w4w5 to the module-level list `keys`, which is
    then printed. Returns None.
    """
    # K0 is the key itself, split into its two bytes.
    w[0] = (key >> 8) & 0xff
    w[1] = key & 0xff
    keys[0] = (w[0] << 8) | w[1]

    # Each later round key: first word = previous first word ^ RCON ^
    # SubNib(RotNib(previous second word)); the second word chains off it.
    for rnd, rcon in ((1, rCon1), (2, rCon2)):
        first, second = 2 * rnd, 2 * rnd + 1
        w[first] = w[first - 2] ^ rcon ^ sub_nib(rotate_word(w[second - 2]))
        w[second] = w[second - 2] ^ w[first]
        keys[rnd] = (w[first] << 8) | w[second]

    print(keys)

In [143]:
def plain_to_state_array(n):
    """Split a 16-bit block into the 2x2 S-AES nibble state (row-major list).

    Example: 0x2EAF = 0010 1110 1010 1111 becomes the matrix
        0010 1010
        1110 1111
    returned as [0x2, 0xA, 0xE, 0xF], i.e. [s00, s01, s10, s11].
    The block's nibbles fill the matrix column by column.

    Parameters
    ----------
    n : int
        Block value; must satisfy 0 <= n <= 0xFFFF.

    Returns
    -------
    list[int]
        Four nibbles [s00, s01, s10, s11], each in 0..15.

    Raises
    ------
    ValueError
        If `n` is outside the 16-bit range. Without this check a negative
        value yields negative nibbles that index the S-box from the end
        (silently wrong output), and a value >= 65536 yields a nibble > 15
        that only fails later inside the S-box lookup.
    """
    if not 0 <= n <= 0xFFFF:
        raise ValueError(
            "block must be a 16-bit value (0..65535), got {!r}".format(n))
    return [(n >> 12) & 0xF, (n >> 4) & 0xF, (n >> 8) & 0xF, n & 0xF]

def state_array_to_plain(m):
    """Pack a 2x2 nibble state [s00, s01, s10, s11] back into one integer.

    Nibbles are emitted column by column (s00, s10, s01, s11), most
    significant first, mirroring plain_to_state_array.
    """
    s00, s01, s10, s11 = m[0], m[1], m[2], m[3]
    return (s00 << 12) + (s10 << 8) + (s01 << 4) + s11

In [144]:
def add_round_key(state, key):
    """XOR `state` with `key` element-wise and return the result as a new list.

    Pairs are formed with zip, so the output is as long as the shorter input.
    """
    combined = []
    for state_nibble, key_nibble in zip(state, key):
        combined.append(state_nibble ^ key_nibble)
    return combined

In [145]:
def nibble_sub(state):
    """Return a new list with every nibble of `state` replaced by sBox[nibble]."""
    return list(map(sBox.__getitem__, state))

def inv_nibble_sub(state):
    """Return a new list with every nibble of `state` replaced by sBoxI[nibble]."""
    substituted = []
    for nibble in state:
        substituted.append(sBoxI[nibble])
    return substituted

In [146]:
def shift_rows(state):
    """Swap the two nibbles of the second row of the state.

    With the row-major layout [s00, s01, s10, s11] produced by
    plain_to_state_array, this exchanges s10 and s11. Applying it twice
    restores the input, so it also serves as the inverse ShiftRows.
    Returns a new list.
    """
    top_left, top_right = state[0], state[1]
    bottom_left, bottom_right = state[2], state[3]
    return [top_left, top_right, bottom_right, bottom_left]

In [147]:
def mult(p1, p2):
    """Multiply p1 and p2 as polynomials over GF(2) modulo x^4 + x + 1.

    This is multiplication in GF(2^4) using shift-and-add. The result is
    masked to 4 bits.
    """
    product = 0
    multiplicand, multiplier = p1, p2
    while multiplier:
        if multiplier & 1:
            product ^= multiplicand
        multiplicand <<= 1
        if multiplicand & 0x10:
            # x^4 overflowed: reduce by x^4 + x + 1 (0b10011).
            multiplicand ^= 0x13
        multiplier >>= 1
    return product & 0xF

def mix_columns(s):
    """Mix each state column with the matrix [[1, 4], [4, 1]] over GF(2^4).

    The columns are the index pairs (0, 2) and (1, 3) of the row-major state.
    Returns a new 4-element list.
    """
    mixed = [0] * 4
    for top, bottom in ((0, 2), (1, 3)):
        mixed[top] = s[top] ^ mult(4, s[bottom])
        mixed[bottom] = s[bottom] ^ mult(4, s[top])
    return mixed

def inv_mix_columns(s):
    """Mix each state column with the matrix [[9, 2], [2, 9]] over GF(2^4).

    This undoes mix_columns. Columns are the index pairs (0, 2) and (1, 3).
    Returns a new 4-element list.
    """
    unmixed = [0] * 4
    for top, bottom in ((0, 2), (1, 3)):
        unmixed[top] = mult(9, s[top]) ^ mult(2, s[bottom])
        unmixed[bottom] = mult(9, s[bottom]) ^ mult(2, s[top])
    return unmixed

In [148]:
def encrpyt(pt):
    """Encrypt one 16-bit block with S-AES using the module-level round keys.

    `key_generation` must have been called first so that `keys[0..2]` hold
    the round keys (they start out as None).

    The name is a historical misspelling kept for existing callers; the
    correctly spelled alias `encrypt` refers to the same function.

    Parameters
    ----------
    pt : int
        16-bit plaintext block.

    Returns
    -------
    int
        16-bit ciphertext block.
    """
    state = plain_to_state_array(pt)
    # Initial key whitening with K0.
    state = add_round_key(state, plain_to_state_array(keys[0]))

    # Round 1: full round (substitute, shift, mix, add K1).
    state = nibble_sub(state)
    state = shift_rows(state)
    state = mix_columns(state)
    state = add_round_key(state, plain_to_state_array(keys[1]))

    # Round 2: final round omits MixColumns.
    state = nibble_sub(state)
    state = shift_rows(state)
    state = add_round_key(state, plain_to_state_array(keys[2]))

    return state_array_to_plain(state)


# Correctly spelled alias; `encrpyt` is kept so existing callers still work.
encrypt = encrpyt
    

In [154]:
def decrpyt(pt):
    """Decrypt one 16-bit S-AES block using the module-level round keys.

    This applies the inverses of encrpyt's steps in reverse order.
    `key_generation` must have been called first with the key used for
    encryption.

    The name is a historical misspelling kept for existing callers; the
    correctly spelled alias `decrypt` refers to the same function.

    Parameters
    ----------
    pt : int
        16-bit ciphertext block. The parameter name is historical: it holds
        ciphertext, not plaintext.

    Returns
    -------
    int
        16-bit recovered plaintext block.
    """
    state = plain_to_state_array(pt)

    # Undo round 2: remove K2, then inverse ShiftRows (shift_rows is its own
    # inverse) and inverse NibbleSub.
    state = add_round_key(state, plain_to_state_array(keys[2]))
    state = shift_rows(state)
    state = inv_nibble_sub(state)

    # Undo round 1: remove K1, inverse MixColumns, inverse ShiftRows and
    # inverse NibbleSub.
    state = add_round_key(state, plain_to_state_array(keys[1]))
    state = inv_mix_columns(state)
    state = shift_rows(state)
    state = inv_nibble_sub(state)

    # Undo the initial key whitening with K0.
    state = add_round_key(state, plain_to_state_array(keys[0]))

    return state_array_to_plain(state)


# Correctly spelled alias; `decrpyt` is kept so existing callers still work.
decrypt = decrpyt

In [155]:
plain_text = int(input("Enter plaintext (Numeric value < 65536): "))
key = int(input("Enter key (Numeric value < 65536): "))

# The block and the key are both 16-bit quantities, so reject anything else
# up front. key_generation would otherwise silently keep only the key's low
# 16 bits. An out-of-range plaintext either indexes the S-box from the end
# (negative) or runs past it (>= 65536).
for label, value in (("plaintext", plain_text), ("key", key)):
    if not 0 <= value <= 0xFFFF:
        raise ValueError(
            "{} must be in range 0..65535, got {}".format(label, value))

key_generation(key)
cipher_text = encrpyt(plain_text)
print("Cipher: ",cipher_text)

recovered = decrpyt(cipher_text)
print("Plain: ",recovered)
# Sanity check: decryption must invert encryption under the same key.
assert recovered == plain_text, "decryption did not recover the plaintext"

Enter plaintext (Numeric value < 65536): 35
Enter key (Numeric value): 34
[34, 10760, 29563]
Cipher:  6215
Plain:  35
