In [3]:
#import library to generate sufficiently random initial key value, have to pip install pycryptodome
from Crypto.Random import get_random_bytes

# S-box from the lesson slides for encryption: maps each 4-bit nibble, written
# as a '0'/'1' string, to its substitute nibble.
sbox = {'0000': '1001', '0001': '0100', '0010': '1010', '0011': '1011',
        '0100': '1101', '0101': '0001', '0110': '1000', '0111': '0101',
        '1000': '0110', '1001': '0010', '1010': '0000', '1011': '0011',
        '1100': '1100', '1101': '1110', '1110': '1111', '1111': '0111'}

# Inverse S-box for decryption. Derived from sbox so the two tables can never
# disagree: inversesbox[sbox[n]] == n for every nibble n.
inversesbox = {out: nib for nib, out in sbox.items()}

# Round constants for key expansion. RCON(i) = x**(i+2) reduced modulo the
# S-AES polynomial x**4 + x + 1, written as a 4-bit nibble followed by '0000'
# to make 8 bits; rconstants[k] holds RCON(k + 1). Every entry must be exactly
# 8 characters wide. Only the first two are used by the two-round schedule.
rconstants = ['10000000', '00110000', '01100000', '11000000', '10110000',
              '01010000', '10100000', '01110000', '11100000', '11110000',
              '11010000', '10010000', '00010000', '00100000', '01000000']

# Plaintexts entered by the user at the prompt below.
plaintexts_list = []
# Two fixed word lists used to demonstrate an encrypt/decrypt round trip.
testing_plaintexts = ['lions', 'tigers', 'bears', 'walruses', 'deer',
                      'giraffes', 'llamas', 'ostriches', 'wolves', 'whales']
testing_plaintexts2 = ["soccer", "basketball", "baseball", "tennis", "golf",
                       "hockey", "rugby", "cricket", "volleyball", "swimming"]

# How many plaintexts to read from the user.
numofplaintexts = 10

# Prompt numofplaintexts times and keep each reply as an unmodified string.
plaintexts_list.extend(input("Please enter the plaintext: ") for _ in range(numofplaintexts))

def converttobinarystrings(input_plaintext):
    """Spell each string's UTF-8 bytes as 8-bit binary.

    Args:
        input_plaintext: sequence of strings.

    Returns:
        list with one '0'/'1' string per input, 8 characters per UTF-8 byte.
    """
    return [
        ''.join(f'{octet:08b}' for octet in bytearray(text, encoding='utf-8'))
        for text in input_plaintext
    ]

def convertbinstringstotexts(binary_string_list):
    """Turn '0'/'1' strings back into text, one character per byte.

    Bytes are decoded as latin-1, which maps every byte value 0-255 to exactly
    one character, so the text can later be re-encoded with latin-1 to get the
    original bytes back.

    Args:
        binary_string_list: iterable of strings made of '0'/'1' characters.

    Returns:
        list of decoded strings, one per input.
    """
    res = []
    for bs in binary_string_list:
        # Round up: a bit string whose length is not a multiple of 8 is treated
        # as left-padded with zeros instead of overflowing int.to_bytes().
        num_bytes = (len(bs) + 7) // 8
        # int('', 2) raises ValueError, so map the empty string to zero bytes.
        value = int(bs, 2) if bs else 0
        res.append(value.to_bytes(num_bytes, byteorder='big').decode('latin-1'))
    return res

def pkcs7_pad(text, blocksize=2):
    """Append PKCS#7-style padding so the UTF-8 encoding of *text* fills whole blocks.

    SAES_encrypt encodes the padded string as UTF-8 and cuts it into 16-bit
    (2-byte) blocks. The padding length must therefore come from the encoded
    byte length, not the character count. A non-ASCII character such as 'é'
    is one character but two bytes. Counting characters would leave a partial
    block that the round functions cannot process.

    Args:
        text: plaintext string.
        blocksize: block size in bytes (default 2, i.e. one 16-bit S-AES block).
            Must be 1..127 so each padding character chr(n) encodes to exactly
            one UTF-8 byte.

    Returns:
        text followed by n copies of chr(n), where 1 <= n <= blocksize.

    Raises:
        ValueError: if blocksize is outside 1..127.
    """
    if not 1 <= blocksize <= 127:
        raise ValueError(f"blocksize must be between 1 and 127, got {blocksize}")
    pad_len = blocksize - (len(text.encode('utf-8')) % blocksize)
    return text + chr(pad_len) * pad_len

def pkcs7_unpad(text):
    """Strip PKCS#7-style padding of the kind added by pkcs7_pad.

    The last character's code point gives the padding length n, and the last
    n characters must all equal chr(n).

    Args:
        text: padded string.

    Returns:
        text with the padding removed.

    Raises:
        ValueError: if text is empty or its padding is malformed.
    """
    if not text:
        raise ValueError("cannot unpad an empty string")
    pad_len = ord(text[-1])
    # Reject 0 explicitly: text[:-0] is text[:0] == '', which would silently
    # discard the whole plaintext instead of signalling bad padding.
    if pad_len == 0 or pad_len > len(text) or text[-pad_len:] != text[-1] * pad_len:
        raise ValueError("invalid PKCS#7 padding")
    return text[:-pad_len]

def split_to_blocks(binarystring):
    """Cut a '0'/'1' string into consecutive 16-character blocks.

    The final block is shorter when the length is not a multiple of 16.
    """
    blocks = []
    for start in range(0, len(binarystring), 16):
        blocks.append(binarystring[start:start + 16])
    return blocks

def makenibbles(input_binary):
    """Split each binary string into 4-character nibbles.

    Args:
        input_binary: sequence of '0'/'1' strings (typically 16-bit blocks).

    Returns:
        list with, for each input string, the list of its nibbles in order
        (the last nibble is shorter if the length is not a multiple of 4).
    """
    return [[block[k:k + 4] for k in range(0, len(block), 4)] for block in input_binary]

# Random 16-bit initial key: 2 bytes from pycryptodome's get_random_bytes.
key0 = get_random_bytes(2)
# The same 16 bits as a '0'/'1' string, most significant bit first.
binaryKey0 = format(int.from_bytes(key0, 'big'), '016b')
# Round keys; assigned once expand_key has been defined.
binaryKey1 = binaryKey2 = None

def expand_key(init_key_val, roundconstant):
    """Derive the next 16-bit S-AES round key from the previous one.

    The key is split into two 8-bit words w0 | w1 and extended as
        w2 = w0 XOR RCON XOR SubNib(RotNib(w1))
        w3 = w2 XOR w1
    where RotNib swaps the two nibbles of w1 and SubNib applies sbox to each.

    Args:
        init_key_val: previous key as a 16-character '0'/'1' string.
        roundconstant: round constant as an 8-character '0'/'1' string.

    Returns:
        The new key w2 + w3 as a 16-character '0'/'1' string.
    """
    left, right = init_key_val[0:8], init_key_val[8:16]
    rotated = right[4:8] + right[0:4]
    g_word = sbox[rotated[:4]] + sbox[rotated[4:]]
    w2 = format(int(left, 2) ^ int(roundconstant, 2) ^ int(g_word, 2), '08b')
    w3 = format(int(w2, 2) ^ int(right, 2), '08b')
    return w2 + w3
    
# Round key 1 comes from the random key with RCON(1); round key 2 from key 1 with RCON(2).
binaryKey1 = expand_key(init_key_val=binaryKey0, roundconstant=rconstants[0])
binaryKey2 = expand_key(init_key_val=binaryKey1, roundconstant=rconstants[1])

def substitute_nibbles(plaintext_binary):
    """Apply the S-AES S-box to every nibble of every block.

    Args:
        plaintext_binary: list of blocks, each a list of 4-character '0'/'1'
            nibble strings.

    Returns:
        A new list of blocks of the same shape with each nibble n replaced by
        sbox[n]. The input is not modified.
    """
    return [[sbox[nib] for nib in block] for block in plaintext_binary]

def inverse_substitute_nibbles(input_binary):
    """Undo substitute_nibbles by applying the inverse S-box to every nibble.

    Args:
        input_binary: list of blocks, each a list of 4-character nibble strings.

    Returns:
        A new list of blocks of the same shape with each nibble n replaced by
        inversesbox[n]. The input is not modified.
    """
    return [[inversesbox[nib] for nib in block] for block in input_binary]

def shift_rows(subbed_binary):
    """S-AES ShiftRows: swap nibbles 1 and 3 of each 4-nibble group, in place.

    With the nibbles of a 16-bit state laid out column-major in a 2x2 matrix,
    rotating the second row is the same as exchanging the second and fourth
    nibbles.

    Args:
        subbed_binary: list of blocks (lists of nibble strings); mutated in place.

    Returns:
        The same list object.
    """
    for block in subbed_binary:
        for base in range(0, len(block), 4):
            block[base + 1], block[base + 3] = block[base + 3], block[base + 1]
    return subbed_binary

def inverse_shift_rows(input_binary):
    """Undo shift_rows by swapping nibbles 1 and 3 of each 4-nibble group, in place.

    The swap is its own inverse. Nothing is removed or trimmed here; padding is
    handled separately by pkcs7_unpad.

    Args:
        input_binary: list of blocks (lists of nibble strings); mutated in place.

    Returns:
        The same list object.
    """
    for block in input_binary:
        for base in range(0, len(block), 4):
            block[base + 3], block[base + 1] = block[base + 1], block[base + 3]
    return input_binary
        
def mix_columns(shifted_binary):
    """S-AES MixColumns (used in round 1 only), applied in place to every block.

    Consecutive nibble pairs (block[0], block[1]), (block[2], block[3]), ...
    form columns with top nibble s0 and bottom nibble s1. The bit equations
    come from the lesson-slide table. They are equivalent to multiplying the
    column by [[1, 4], [4, 1]] in GF(2^4) modulo x^4 + x + 1:
        s0' = s0 ^ 4*s1,   s1' = 4*s0 ^ s1

    Args:
        shifted_binary: list of blocks (lists of 4-character nibble strings);
            mutated in place.

    Returns:
        The same list object, with every column mixed.
    """
    for block in shifted_binary:
        for col in range(0, len(block), 2):
            top, bottom = block[col], block[col + 1]
            # Read all eight bit characters (MSB first) before converting any,
            # so a nibble shorter than 4 characters raises IndexError.
            top_bits = (top[0], top[1], top[2], top[3])
            bottom_bits = (bottom[0], bottom[1], bottom[2], bottom[3])
            t0, t1, t2, t3 = (int(bit, 2) for bit in top_bits)
            b0, b1, b2, b3 = (int(bit, 2) for bit in bottom_bits)
            new_top = (t0 ^ b2, t1 ^ b0 ^ b3, t2 ^ b0 ^ b1, t3 ^ b1)
            new_bottom = (b0 ^ t2, b1 ^ t0 ^ t3, b2 ^ t0 ^ t1, b3 ^ t1)
            block[col] = ''.join(format(bit, '01b') for bit in new_top)
            block[col + 1] = ''.join(format(bit, '01b') for bit in new_bottom)
    return shifted_binary

def inverse_mix_columns(input_binary):
    """Undo mix_columns, in place, for every block.

    Consecutive nibble pairs (block[0], block[1]), ... form columns (s0, s1).
    The bit equations come from the lesson-slide inverse table. They are
    equivalent to multiplying the column by [[9, 2], [2, 9]] in GF(2^4)
    modulo x^4 + x + 1:
        s0' = 9*s0 ^ 2*s1,   s1' = 2*s0 ^ 9*s1

    Args:
        input_binary: list of blocks (lists of 4-character nibble strings);
            mutated in place.

    Returns:
        The same list object, with every column unmixed.
    """
    for block in input_binary:
        for col in range(0, len(block), 2):
            top, bottom = block[col], block[col + 1]
            # Read all eight bit characters (MSB first) before converting any,
            # so a nibble shorter than 4 characters raises IndexError.
            top_bits = (top[0], top[1], top[2], top[3])
            bottom_bits = (bottom[0], bottom[1], bottom[2], bottom[3])
            t0, t1, t2, t3 = (int(bit, 2) for bit in top_bits)
            b0, b1, b2, b3 = (int(bit, 2) for bit in bottom_bits)
            new_top = (t3 ^ b1, t0 ^ b2, t1 ^ b0 ^ b3, t2 ^ t3 ^ b0)
            new_bottom = (t1 ^ b3, t2 ^ b0, t0 ^ t3 ^ b1, t0 ^ b2 ^ b3)
            block[col] = ''.join(format(bit, '01b') for bit in new_top)
            block[col + 1] = ''.join(format(bit, '01b') for bit in new_bottom)
    return input_binary

def add_round_key(mixed_binary, roundkey):
    """XOR each 4-nibble group with the 16-bit round key, in place.

    Because XOR is its own inverse, the same function serves for encryption
    and decryption.

    Args:
        mixed_binary: list of blocks (lists of 4-character nibble strings);
            mutated in place.
        roundkey: 16-character '0'/'1' round key.

    Returns:
        The same list object.
    """
    key_nibbles = [roundkey[k:k + 4] for k in range(0, 16, 4)]
    for block in mixed_binary:
        for base in range(0, len(block), 4):
            for offset, key_nib in enumerate(key_nibbles):
                block[base + offset] = format(int(block[base + offset], 2) ^ int(key_nib, 2), '04b')
    return mixed_binary

def encrypt_round1(plaintext_binary, round_key):
    """Run S-AES round 1: SubNibbles, ShiftRows, MixColumns, AddRoundKey.

    Args:
        plaintext_binary: list of blocks (lists of nibble strings).
        round_key: 16-character '0'/'1' round key.

    Returns:
        The transformed blocks (a new outer list; the input is not mutated
        because substitute_nibbles copies first).
    """
    state = substitute_nibbles(plaintext_binary)
    state = shift_rows(state)
    state = mix_columns(state)
    return add_round_key(state, round_key)

def encrypt_round2(intermediate_binary, round_key):
    """Run S-AES round 2: SubNibbles, ShiftRows, AddRoundKey (no MixColumns).

    Args:
        intermediate_binary: list of blocks produced by encrypt_round1.
        round_key: 16-character '0'/'1' round key.

    Returns:
        The ciphertext blocks as lists of nibble strings.
    """
    state = substitute_nibbles(intermediate_binary)
    state = shift_rows(state)
    return add_round_key(state, round_key)

def decrypt_round2(input_binaries, roundkey):
    """Undo encryption round 2; this is the first step of decryption.

    Applies AddRoundKey with the round-2 key, then inverse ShiftRows, then
    inverse SubNibbles. add_round_key and inverse_shift_rows mutate
    input_binaries in place.

    Args:
        input_binaries: ciphertext blocks (lists of nibble strings).
        roundkey: the round-2 key as a 16-character '0'/'1' string.

    Returns:
        New list of blocks (lists of nibble strings).
    """
    state = add_round_key(input_binaries, roundkey)
    state = inverse_shift_rows(state)
    return inverse_substitute_nibbles(state)

def decrypt_round1(intermediate_binaries, roundkey):
    """Undo encryption round 1; this is the second step of decryption.

    Applies AddRoundKey with the round-1 key, then inverse MixColumns, inverse
    ShiftRows and inverse SubNibbles. Each block's nibbles are then joined back
    into a single string.

    Args:
        intermediate_binaries: blocks returned by decrypt_round2.
        roundkey: the round-1 key as a 16-character '0'/'1' string.

    Returns:
        list of '0'/'1' strings, one per block.
    """
    state = add_round_key(intermediate_binaries, roundkey)
    state = inverse_mix_columns(state)
    state = inverse_shift_rows(state)
    state = inverse_substitute_nibbles(state)
    return [''.join(block) for block in state]

def reassemble_ciphertexts(cipher_nibbles, padded_binaries):
    """Regroup a flat list of encrypted blocks into one binary string per plaintext.

    Each plaintext used ceil(len(binary_text) / 16) consecutive 16-bit blocks.
    That many blocks are taken from cipher_nibbles, and their nibbles are
    concatenated.

    Args:
        cipher_nibbles: flat list of encrypted blocks (lists of nibble strings).
        padded_binaries: padded plaintext binary strings, in encryption order.

    Returns:
        list of ciphertext '0'/'1' strings, one per entry of padded_binaries.
    """
    ciphertexts = []
    cursor = 0
    for binary_text in padded_binaries:
        # ceil(len / 16): number of 16-bit blocks this plaintext occupied.
        count = -(-len(binary_text) // 16)
        chunk = cipher_nibbles[cursor:cursor + count]
        cursor += count
        ciphertexts.append(''.join(''.join(nibs) for nibs in chunk))
    return ciphertexts

def ciphertext_to_nibbles(ciphertext):
    """Turn a ciphertext string back into a list of 16-bit blocks of nibbles.

    latin-1 maps every character 0-255 to exactly one byte, recovering the
    bytes that were decoded with latin-1 when the ciphertext was produced. The
    bit string is left-padded with zeros to a multiple of 16 bits.

    Args:
        ciphertext: string whose characters are all in the latin-1 range.

    Returns:
        list of blocks, each a list of four 4-character '0'/'1' nibbles.
    """
    bits = ''.join(f'{octet:08b}' for octet in bytearray(ciphertext, 'latin-1'))
    bits = bits.zfill(len(bits) + (-len(bits) % 16))
    return [
        [bits[pos:pos + 4] for pos in range(start, start + 16, 4)]
        for start in range(0, len(bits), 16)
    ]

# Regroup decrypted 16-bit blocks into one binary string per original plaintext, using the padded binaries recorded at encryption time to know how many blocks each plaintext spans.
def group_decrypted_blocks(decrypted_blocks, padded_binaries):
    """Join consecutive decrypted blocks back into one binary string per plaintext.

    Args:
        decrypted_blocks: flat list of 16-character '0'/'1' strings.
        padded_binaries: padded plaintext binary strings from encryption; each
            spans ceil(len / 16) blocks.

    Returns:
        list of concatenated '0'/'1' strings, one per entry of padded_binaries.
    """
    grouped = []
    cursor = 0
    for binary_text in padded_binaries:
        # ceil(len / 16): number of 16-bit blocks this plaintext occupied.
        count = -(-len(binary_text) // 16)
        grouped.append(''.join(decrypted_blocks[cursor:cursor + count]))
        cursor += count
    return grouped

def SAES_encrypt(plaintexts, roundkey1, roundkey2):
    """Encrypt a list of strings with two-round S-AES.

    Each plaintext is padded, UTF-8 encoded and split into 16-bit blocks of
    nibbles. All blocks from all plaintexts go through both rounds together.
    The results are then regrouped per plaintext and rendered as latin-1 text.

    NOTE(review): no AddRoundKey with the initial key is performed before
    round 1 (only roundkey1 and roundkey2 are used); textbook S-AES also adds
    K0 first. Confirm this is intended.

    Args:
        plaintexts: iterable of strings.
        roundkey1: 16-character '0'/'1' key for round 1.
        roundkey2: 16-character '0'/'1' key for round 2.

    Returns:
        (ciphertexts, padded_binaries): the ciphertext strings, and the padded
        plaintext binaries that SAES_decrypt needs to regroup blocks.
    """
    padded_binaries = []
    all_nibbles = []
    for plaintext in plaintexts:
        bits = ''.join(format(octet, '08b') for octet in bytearray(pkcs7_pad(plaintext), 'utf-8'))
        padded_binaries.append(bits)
        all_nibbles.extend(makenibbles(split_to_blocks(bits)))

    state = encrypt_round1(all_nibbles, roundkey1)
    state = encrypt_round2(state, roundkey2)

    cipherbinaries = reassemble_ciphertexts(state, padded_binaries)
    return (convertbinstringstotexts(cipherbinaries), padded_binaries)

def SAES_decrypt(ciphertexts, roundkey1, roundkey2, padded_binaries):
    """Decrypt ciphertexts produced by SAES_encrypt.

    Args:
        ciphertexts: ciphertext strings returned by SAES_encrypt.
        roundkey1: 16-character '0'/'1' key used for round 1.
        roundkey2: 16-character '0'/'1' key used for round 2.
        padded_binaries: padded plaintext binaries returned by SAES_encrypt,
            used to know how many blocks belong to each plaintext.

    Returns:
        list of recovered plaintext strings.
    """
    all_nibbles = []
    for ciphertext in ciphertexts:
        all_nibbles.extend(ciphertext_to_nibbles(ciphertext))
    intermediate_val = decrypt_round2(all_nibbles, roundkey2)
    plaintext_blocks = decrypt_round1(intermediate_val, roundkey1)
    grouped_binaries = group_decrypted_blocks(plaintext_blocks, padded_binaries)
    # convertbinstringstotexts yields one latin-1 character per byte, so the
    # padding can be stripped byte-for-byte here.
    byte_texts = convertbinstringstotexts(grouped_binaries)
    original_plaintexts = []
    for text in byte_texts:
        unpadded = pkcs7_unpad(text)
        # SAES_encrypt encoded the plaintext as UTF-8, so decode the recovered
        # bytes as UTF-8 too; otherwise non-ASCII input comes back as mojibake.
        try:
            original_plaintexts.append(unpadded.encode('latin-1').decode('utf-8'))
        except UnicodeDecodeError:
            # Not valid UTF-8 (e.g. decrypted with the wrong key): keep the
            # byte-per-character text rather than failing.
            original_plaintexts.append(unpadded)
    return original_plaintexts

# Demo: round-trip both fixed word lists. Each SAES_encrypt result is a
# (ciphertexts, padded_binaries) tuple.
test_cipher1 = SAES_encrypt(testing_plaintexts, binaryKey1, binaryKey2)
test_cipher2 = SAES_encrypt(testing_plaintexts2, binaryKey1, binaryKey2)
for heading, result in (('First set of test ciphertexts:', test_cipher1),
                        ('Second set of test ciphertexts:', test_cipher2)):
    print(heading)
    print(result[0])
print('\n')

testdecrypt1 = SAES_decrypt(test_cipher1[0], binaryKey1, binaryKey2, test_cipher1[1])
testdecrypt2 = SAES_decrypt(test_cipher2[0], binaryKey1, binaryKey2, test_cipher2[1])
for heading, result in (('First set of test decrypted texts:', testdecrypt1),
                        ('Second set of test decrypted texts:', testdecrypt2)):
    print(heading)
    print(result)
print('\n')

# Round-trip the plaintexts the user typed in.
userencryption = SAES_encrypt(plaintexts_list, binaryKey1, binaryKey2)
userdecryption = SAES_decrypt(userencryption[0], binaryKey1, binaryKey2, userencryption[1])

print('User ciphertexts:')
print(''.join(f"Ciphertext {pos}: {text}, " for pos, text in enumerate(userencryption[0])), end='')
print('\n')
print('Decrypted user texts:')
print(''.join(f"Plaintext {pos}: {text}  " for pos, text in enumerate(userdecryption)), end='')


Please enter the plaintext:  run
Please enter the plaintext:  dog
Please enter the plaintext:  boat
Please enter the plaintext:  seven
Please enter the plaintext:  hurricane
Please enter the plaintext:  nowhere
Please enter the plaintext:  key
Please enter the plaintext:  fork
Please enter the plaintext:  tide
Please enter the plaintext:  bark


First set of test ciphertexts:
['\x9d\x9a\x10\x0fÏF', 'LgVò{X\x82\x0c', 'WÒÆØÏF', 'ÆöÌÈq\x18\x01\x10\x82\x0c', '\\bÁ\x18\x82\x0c', '\x96úÇÖ£±\x01\x10\x82\x0c', 'M\x9då\\\x06Ð\x82\x0c', '\x04p= \x08ÀTBÏF', 'öô¬Á\x01\x10\x82\x0c', '¦õE]\x01\x10\x82\x0c']
Second set of test ciphertexts:
['û\x84\x0b\x80Á\x18\x82\x0c', '7Ûk\x8aa\x197ÛM\x9d\x82\x0c', '7Û\x1b\x8e7ÛM\x9d\x82\x0c', '\x1cn\x1f\x1f\n\x90\x82\x0c', '\x86õ\xad\x91\x82\x0c', '\x84EÛ\x87\x91\x1a\x82\x0c', '\x1b^ÆøÄ\x16', 'Çø\x08ÀZ"Îæ', 'ó´M\x9d\x91\x1a7ÛM\x9d\x82\x0c', '\x07ûèÌ\x92:\x7f\x16\x82\x0c']


First set of test decrypted texts:
['lions', 'tigers', 'bears', 'walruses', 'deer', 'giraffes', 'llamas', 'ostriches', 'wolves', 'whales']
Second set of test decrypted texts:
['soccer', 'basketball', 'baseball', 'tennis', 'golf', 'hockey', 'rugby', 'cricket', 'volleyball', 'swimming']


User ciphertexts:
Ciphertext 0: ^;Ë, Ciphertext 1: e9», Ciphertext 2: ÕfÙ, Ciphertext 3: ¾;Ë, Ciphertext 4: P²;À_6, Ciphertext 5: ¦õÁ6, C