# **Project 3 Group 4 Problem 2**

In [None]:
# Implementation of a mini block cipher with a 16-bit key and a 16-bit block (Task 1).
# Includes: S-box, ShiftRows, MixColumns, AddRoundKey, key expansion,
# encrypt_round1/2, decrypt_round2/1, encrypt, decrypt, generation of plaintext-ciphertext pairs,
# meet-in-the-middle (MITM) helpers, and an MITM demo.

In [None]:
# =============================
# Cell 1 - Imports and Utilities
# =============================
import random
from collections import defaultdict

def split_words(x):
    """Split a 16-bit integer into its four nibbles, most significant first.

    Each nibble is masked to 4 bits, so any bits above bit 15 are ignored.
    """
    return [(x >> shift) & 0xF for shift in (12, 8, 4, 0)]

def join_nibbles(nibs):
    """Pack nibs[0..3] (most significant first) into one 16-bit integer.

    Each entry is masked to 4 bits before packing. Entries beyond index 3 are ignored.
    """
    packed = 0
    for position in range(4):
        packed = (packed << 4) | (nibs[position] & 0xF)
    return packed

In [None]:
# =============================
# Cell 2 - S-box and inverse S-box
# =============================

# 4-bit S-box (SAES). INV_SBOX is its inverse permutation:
# INV_SBOX[SBOX[n]] == n for every nibble n in 0..15.
SBOX = [0x9, 0x4, 0xA, 0xB, 0xD, 0x1, 0x8, 0x5, 0x6, 0x2, 0x0, 0x3, 0xC, 0xE, 0xF, 0x7]
INV_SBOX = [SBOX.index(nibble) for nibble in range(16)]

def substitute(state):
    """SubNibbles: return a new list with each nibble of ``state`` looked up in SBOX.

    Only the low four bits of each entry are used as the S-box index.
    """
    indices = (nibble & 0xF for nibble in state)
    return [SBOX[index] for index in indices]

def inv_substitute(state):
    """Inverse SubNibbles: return a new list with each nibble looked up in INV_SBOX.

    Only the low four bits of each entry are used as the lookup index.
    """
    indices = (nibble & 0xF for nibble in state)
    return [INV_SBOX[index] for index in indices]


In [None]:
# =============================
# Cell 3 - ShiftRows and inverse ShiftRows
# =============================

def shiftrows(state):
    """ShiftRows: exchange state[2] and state[3], leaving state[0] and state[1] in place.

    In this file's layout, rows are (s0, s1) and (s2, s3), and mix() treats
    (s0, s2) and (s1, s3) as columns. Rotating the 2-nibble second row is
    therefore a swap of s2 and s3. Returns a new list.

    NOTE(review): textbook S-AES orders nibbles column-major and swaps
    nibbles 1 and 3 instead. This layout is the transpose, so outputs will
    not match published S-AES test vectors. Confirm this is intended.
    """
    return [state[k] for k in (0, 1, 3, 2)]

def inv_shiftrows(state):
    """Inverse ShiftRows. The swap of state[2] and state[3] is its own inverse.

    Returns a new list.
    """
    unchanged = list(state[:2])
    return unchanged + [state[3], state[2]]

In [None]:
# =============================
# Cell 4 - MixColumns, Inverse MixColumns, and GF(16) Multiplication
# =============================

# GF(16) arithmetic for SAES, modulo the irreducible polynomial x^4 + x + 1 (0x13).
def gf_multiplication(a, b):
    """Multiply two GF(16) elements given as 4-bit integers.

    Forms the carry-less (XOR) product of ``a`` with the low four bits of
    ``b``, reduces it modulo x^4 + x + 1, and returns the low 4 bits.
    """
    product = 0
    for shift in (0, 1, 2, 3):
        if b & (1 << shift):
            product ^= a << shift
    # Clear bits 7..4, highest first, by XORing in shifted copies of 0x13
    # (this encodes x^4 = x + 1).
    for bit in (7, 6, 5, 4):
        if product & (1 << bit):
            product ^= 0x13 << (bit - 4)
    return product & 0xF

# SAES MixColumns and its inverse on the 2x2 nibble matrix
def mix(state):
    """MixColumns: multiply each column by [[1, 4], [4, 1]] over GF(16).

    The four nibbles form a 2x2 matrix whose columns are (state[0], state[2])
    and (state[1], state[3]). Each column (top, bottom) becomes
    (top ^ 4*bottom, 4*top ^ bottom). Inputs are masked to 4 bits, and a new
    list is returned.
    """
    top0, top1, bottom0, bottom1 = (nibble & 0xF for nibble in state)
    out = [0, 0, 0, 0]
    for col, (top, bottom) in enumerate(((top0, bottom0), (top1, bottom1))):
        out[col] = top ^ gf_multiplication(4, bottom)
        out[col + 2] = gf_multiplication(4, top) ^ bottom
    return out

def inv_mix(state):
    """Inverse MixColumns: multiply each column by [[9, 2], [2, 9]] over GF(16).

    [[9, 2], [2, 9]] is the inverse of mix()'s [[1, 4], [4, 1]] modulo
    x^4 + x + 1 (9 ^ 2*4 = 1 and 9*4 ^ 2 = 0). Hence inv_mix(mix(s)) == s for
    any four 4-bit nibbles s. Columns are (state[0], state[2]) and
    (state[1], state[3]), matching mix(). Inputs are masked to 4 bits.
    """
    s0, s1, s2, s3 = [x & 0xF for x in state]
    # Column 0: (s0, s2). gf_multiplication already returns 4-bit values,
    # so the XOR results need no further masking.
    temp0 = gf_multiplication(9, s0) ^ gf_multiplication(2, s2)
    temp2 = gf_multiplication(2, s0) ^ gf_multiplication(9, s2)
    # Column 1: (s1, s3)
    temp1 = gf_multiplication(9, s1) ^ gf_multiplication(2, s3)
    temp3 = gf_multiplication(2, s1) ^ gf_multiplication(9, s3)
    return [temp0, temp1, temp2, temp3]

In [None]:
# =============================
# Cell 5 - AddRoundKey
# =============================

def add_round_key(state, round_key):
    """AddRoundKey: XOR the four state nibbles with the nibbles of a 16-bit round key.

    The key is split most-significant nibble first via split_words, so key
    nibble i is combined with state[i]. Results are masked to 4 bits.
    """
    return [(state[pos] ^ key_nibble) & 0xF
            for pos, key_nibble in enumerate(split_words(round_key))]

In [None]:
# =============================
# Cell 6 - Key expansion
# =============================

def key_expansion(K):
    """Expand a 16-bit key K into the two 16-bit round keys (K1, K2).

    K is split into bytes w0 (high) and w1 (low) and extended with the SAES
    word schedule:
        w2 = w0 ^ g(w1, 0x80)    w3 = w2 ^ w1
        w4 = w2 ^ g(w3, 0x30)    w5 = w4 ^ w3
    Here g(w, rcon) swaps the two nibbles of byte w, passes each nibble
    through SBOX, and XORs in the round constant rcon.

    Returns (K1, K2) with K1 = w2:w3 and K2 = w4:w5. The original key w0:w1
    is not returned.
    """
    def g(word, rcon):
        # RotNib: exchange the high and low nibble of the byte.
        rotated = ((word << 4) | (word >> 4)) & 0xFF
        # SubNib: run each nibble through the S-box.
        substituted = ((SBOX[(rotated >> 4) & 0xF] << 4) | SBOX[rotated & 0xF]) & 0xFF
        return substituted ^ rcon

    high_byte, low_byte = (K >> 8) & 0xFF, K & 0xFF
    w2 = high_byte ^ g(low_byte, 0x80)  # round constant 1 = 1000 0000
    w3 = w2 ^ low_byte
    w4 = w2 ^ g(w3, 0x30)  # round constant 2 = 0011 0000
    w5 = w4 ^ w3
    return (w2 << 8) | w3, (w4 << 8) | w5


In [None]:
# =============================
# Cell 7 - Round functions
# =============================

# Two rounds of encryption to get the ciphertext.
def encrypt_round1(state_bits, K1):
    """First encryption round under round key K1.

    Steps on the 16-bit block ``state_bits``: SubNibbles -> ShiftRows ->
    MixColumns -> AddRoundKey(K1). Returns the resulting 16-bit block.
    """
    nibbles = split_words(state_bits)
    after_mix = mix(shiftrows(substitute(nibbles)))
    return join_nibbles(add_round_key(after_mix, K1))

def encrypt_round2(state_bits, K2):
    """Second (final) encryption round under round key K2.

    Steps on the 16-bit block ``state_bits``: SubNibbles -> ShiftRows ->
    AddRoundKey(K2). There is no MixColumns in this round. Returns the
    resulting 16-bit block.
    """
    nibbles = split_words(state_bits)
    after_shift = shiftrows(substitute(nibbles))
    return join_nibbles(add_round_key(after_shift, K2))

# Two rounds of decryption to recover the plaintext
def decrypt_round2(state_bits, K2):
    """Undo encrypt_round2 with round key K2.

    Steps on the 16-bit block ``state_bits``: AddRoundKey(K2) ->
    inverse ShiftRows -> inverse SubNibbles. Returns the resulting 16-bit block.
    """
    keyed = add_round_key(split_words(state_bits), K2)
    return join_nibbles(inv_substitute(inv_shiftrows(keyed)))

def decrypt_round1(state_bits, K1):
    """Undo encrypt_round1 with round key K1.

    Steps on the 16-bit block ``state_bits``: AddRoundKey(K1) ->
    inverse MixColumns -> inverse ShiftRows -> inverse SubNibbles.
    Returns the resulting 16-bit block.
    """
    keyed = add_round_key(split_words(state_bits), K1)
    unmixed = inv_mix(keyed)
    return join_nibbles(inv_substitute(inv_shiftrows(unmixed)))



In [None]:
# =============================
# Cell 8 - Full encryption and decryption 
# =============================

def encrypt(plaintext, K):
    """Encrypt a 16-bit plaintext under the 16-bit key K.

    K is expanded into round keys (K1, K2) by key_expansion. The block then
    goes through encrypt_round1 with K1, followed by encrypt_round2 with K2.
    """
    round_key_1, round_key_2 = key_expansion(K)
    return encrypt_round2(encrypt_round1(plaintext, round_key_1), round_key_2)

def decrypt(ciphertext, K):
    """Decrypt a 16-bit ciphertext under the 16-bit key K.

    Expands K into (K1, K2) and applies the rounds in reverse order:
    decrypt_round2 with K2, then decrypt_round1 with K1.
    """
    round_key_1, round_key_2 = key_expansion(K)
    return decrypt_round1(decrypt_round2(ciphertext, round_key_2), round_key_1)



In [None]:

# =============================
# Cell 9 - Create at least 10 plaintext-ciphertext pairs
# =============================

def create_pairs(n=12, seed=None, *, format='hex', ensure_unique=True):
    """Create known plaintext-ciphertext pairs under ONE random 16-bit key.

    A single key K is drawn from the seeded RNG and expanded into round keys
    (K1, K2). Then ``n`` random 16-bit plaintexts are encrypted under it. A
    shared key is what the meet-in-the-middle demo needs, because every known
    pair must come from the same secret.

    Besides returning the pairs, this function publishes module-level globals
    that the MITM cells read:
        attack_list / atl                      plaintexts as ints
        attack_list_hex / atl_hex              plaintexts as 4-digit hex strings
        attack_list_ciphertexts / atl_ciphertexts           ciphertexts as ints
        attack_list_ciphertexts_hex / atl_ciphertexts_hex   ciphertexts as hex strings
        K_form, K1_form, K2_form               key and round keys as 4-digit hex strings
    The int globals are always ints and the string globals are always hex,
    whatever ``format`` is, because consumers parse them with ``int(..., 16)``.

    Args:
        n: Number of pairs to generate. Must be positive, and at most 65536
            when ``ensure_unique`` is True (there are only 2**16 plaintexts).
        seed: Optional seed for reproducible results.
        format: Representation of the returned values: 'hex', 'int', or 'bin'.
        ensure_unique: When True, no plaintext is repeated. Because the key is
            fixed, this is the same as unique (key, plaintext) combinations.

    Returns:
        list[dict[str, str | int]]: one dict per pair with the keys
        ``key_<format>``, ``K1_<format>``, ``K2_<format>``,
        ``plaintext_<format>``, and ``ciphertext_<format>``.

    Raises:
        ValueError: if ``n`` is not positive, ``format`` is unknown, or more
            than 65536 unique pairs are requested.
    """
    global attack_list, attack_list_hex
    global attack_list_ciphertexts, attack_list_ciphertexts_hex
    global atl, atl_hex, atl_ciphertexts, atl_ciphertexts_hex
    global K_form, K1_form, K2_form

    if n <= 0:
        raise ValueError('n must be a positive integer')
    if format not in {'hex', 'int', 'bin'}:
        raise ValueError("format must be one of {'hex', 'int', 'bin'}")
    if ensure_unique and n > 0x10000:
        # Only 2**16 distinct plaintexts exist; the loop below would never finish.
        raise ValueError('ensure_unique=True allows at most 65536 pairs')

    formatter = {
        'hex': lambda value: f"{value:04X}",
        'bin': lambda value: f"{value:016b}",
        'int': lambda value: value,
    }[format]

    def to_hex(value):
        # Fixed hex representation for the MITM globals, independent of `format`.
        return f"{value:04X}"

    rng = random.Random(seed)
    # Draw the key first so a given seed reproduces the same key and plaintexts.
    K = rng.randrange(0, 0x10000)
    # The key is fixed for the whole batch, so expand it only once.
    K1, K2 = key_expansion(K)

    pairs = []
    seen = set()
    plaintexts = []
    ciphertexts = []
    while len(pairs) < n:
        P = rng.randrange(0, 0x10000)

        if ensure_unique:
            if P in seen:
                continue
            seen.add(P)

        C = encrypt(P, K)
        plaintexts.append(P)
        ciphertexts.append(C)

        pairs.append({
            f'key_{format}': formatter(K),
            f'K1_{format}': formatter(K1),
            f'K2_{format}': formatter(K2),
            f'plaintext_{format}': formatter(P),
            f'ciphertext_{format}': formatter(C),
        })

    # Publish data for the MITM cells (the atl* names alias the attack_list* lists).
    attack_list = atl = plaintexts
    attack_list_hex = atl_hex = [to_hex(p) for p in plaintexts]
    attack_list_ciphertexts = atl_ciphertexts = ciphertexts
    attack_list_ciphertexts_hex = atl_ciphertexts_hex = [to_hex(c) for c in ciphertexts]
    K_form, K1_form, K2_form = to_hex(K), to_hex(K1), to_hex(K2)

    return pairs

# def create_pairs_hex(*args, **kwargs):
#     """Compatibility helper that defaults to hexadecimal output.
#
#     This makes it convenient to request the legacy hex format when additional
#     args such as `ensure_unique` are provided.
#     """
#     kwargs.setdefault('format', 'hex')
#     return create_pairs(*args, **kwargs)


In [None]:

# =============================
# Cell 10 - Display created plaintext-ciphertext pairs
# =============================

# Create 12 distinct plaintext-ciphertext pairs in hex format. The fixed seed
# makes every run reproduce the same key and pairs. Print one pair per line.
pairs_hex = create_pairs(12, seed=12345, ensure_unique=True)
print(*pairs_hex, sep="\n")


In [None]:
# =============================
# Cell 11 - Meet-in-the-middle helpers
# =============================

# Two-round encryption with the round keys supplied directly. This lets the
# attack treat the cipher as a single callable parameterized by (K1, K2).
def double_round_encrypt(plaintext, K1, K2):
    """Encrypt a block with independently chosen round keys K1 and K2.

    Unlike encrypt(), which derives K1 and K2 from one 16-bit key via
    key_expansion(), this takes both round keys directly.
    """
    middle_state = encrypt_round1(plaintext, K1)
    return encrypt_round2(middle_state, K2)

# Reverse of double_round_encrypt: undo round 2 with K2, then round 1 with K1.
def double_round_decrypt(ciphertext, K1, K2):
    """Decrypt a block with independently chosen round keys K1 and K2."""
    middle_state = decrypt_round2(ciphertext, K2)
    return decrypt_round1(middle_state, K1)

def meet_in_the_middle_attack(pairs):
    """Recover candidate (K1, K2) round-key pairs from known plaintext/ciphertext pairs.

    Round 1 depends only on K1 and round 2 only on K2. The two 16-bit keys
    can therefore be searched separately and matched on the 16-bit state
    between the rounds:

    * encrypt the first plaintext through round 1 under every K1, and index
      the results by intermediate state;
    * decrypt the first ciphertext through round 2 under every K2, where each
      hit in the index yields (K1, K2) candidates;
    * for each remaining pair, keep only the candidates whose full two-round
      encryption maps that plaintext to its ciphertext. Stop early once none
      are left.

    The matching stage costs 2 * 2**16 single-round evaluations instead of
    2**32 full encryptions.

    Args:
        pairs: iterable of (plaintext, ciphertext) tuples.

    Returns:
        Sorted list of (K1, K2) tuples consistent with every pair.

    Raises:
        ValueError: if ``pairs`` is empty.
    """
    known = list(pairs)
    if not known:
        raise ValueError('At least one plaintext/ciphertext pair is required')

    first_plain, first_cipher = known[0]

    # Meet from the front: intermediate state -> every K1 that produces it.
    by_middle = defaultdict(list)
    for round1_key in range(0x10000):
        by_middle[encrypt_round1(first_plain, round1_key)].append(round1_key)

    # Meet from the back: look up each round-2 decryption in the table.
    # .get() is used so that misses do not insert empty lists.
    surviving = {
        (round1_key, round2_key)
        for round2_key in range(0x10000)
        for round1_key in by_middle.get(decrypt_round2(first_cipher, round2_key), ())
    }

    # Every other known pair must also be reproduced exactly.
    for plain, cipher in known[1:]:
        surviving = {keys for keys in surviving
                     if double_round_encrypt(plain, *keys) == cipher}
        if not surviving:
            break

    return sorted(surviving)


In [None]:
# =============================
# Cell 12 - Meet-in-the-middle attack
# =============================

#   [1] Takes the secret round keys (K1_form, K2_form) and the known plaintexts (atl) published by
#       create_pairs() in Cell 9. It then encrypts each plaintext with the two-round helper to build
#       the plaintext/ciphertext pairs from which the attack tries to recover the keys.

#   [2] Runs meet_in_the_middle_attack on the first pair, then narrows the candidate set with each further
#       pair (one pair, then two, and so on). After each step it prints how many candidate key pairs remain
#       and previews up to five of them, which shows how additional known pairs shrink the search space.
#       Filtering the survivors incrementally gives exactly the candidates that re-running the attack on
#       every prefix would, but the expensive 2 x 65536-key meet-in-the-middle search runs only once.

#   [3] After the loop, if at least one candidate remains, it takes the unique pair. If several remain,
#       it prefers the true secret pair when present and otherwise falls back to the first candidate,
#       for demonstration purposes.

#   [4] To verify success, it encrypts a fresh plaintext (0xFFFF) with the true secret keys and decrypts
#       that ciphertext with the recovered candidate pair. It prints both the keys and the recovered
#       plaintext; an assertion confirms that the decrypted value matches the original plaintext.

#   [5] If no candidates survived, the script reports that more known plaintext/ciphertext pairs are needed.

#   [1]

print('Actual Secret K1: ', K1_form)
print('Actual Secret K2: ', K2_form)
print('Attack list (hex): ', atl_hex)
print('Attack List (dec): ', atl)
print('Attack list ciphertexts (hex): ', atl_ciphertexts_hex)
print('Attack list ciphertexts (dec): ', atl_ciphertexts)

# Known plaintexts for the attack: the values generated by create_pairs().
plaintext_samples = list(atl)
print('Plaintext Samples: ', plaintext_samples)

secret_K1 = int(K1_form, 16)
secret_K2 = int(K2_form, 16)

# Alternative fixed test vector (uncomment to attack these values instead):
# secret_K1 = 0x1234
# secret_K2 = 0xABCD
# plaintext_samples = [0x0123, 0x4567, 0x89AB, 0xFEDC, 0x0F0F ,0xFFAC, 0x02FA, 0x78A8, 0x99A1, 0x11BA, 0xBBBB, 0xCCA1]

pairs_for_attack = [
    (plaintext, double_round_encrypt(plaintext, secret_K1, secret_K2))
    for plaintext in plaintext_samples
]

print('Pairs for attack: ', pairs_for_attack)

#   [2]
final_candidates = None
for count, (plaintext, ciphertext) in enumerate(pairs_for_attack, start=1):
    if count == 1:
        # Full meet-in-the-middle search, driven by the first known pair.
        candidates = meet_in_the_middle_attack([(plaintext, ciphertext)])
    else:
        # The same consistency filter meet_in_the_middle_attack applies to its extra
        # pairs: keep only key pairs that also map this plaintext to its ciphertext.
        # Filtering a sorted list keeps it sorted.
        candidates = [
            (K1, K2)
            for (K1, K2) in candidates
            if double_round_encrypt(plaintext, K1, K2) == ciphertext
        ]
    final_candidates = candidates
    print(f'Using {count} plaintext/ciphertext pair(s) -> {len(candidates)} candidate key pair(s)')

    for K1, K2 in candidates[:5]:
        print(f'  Candidate -> K1={K1:04X}, K2={K2:04X}')
    if len(candidates) > 5:
        print('  ...')

#   [3]
if final_candidates:
    final_count = len(final_candidates)
    if final_count == 1:
        recovered_K1, recovered_K2 = final_candidates[0]
        print('Unique key pair recovered.')
    else:
        true_pair = (secret_K1, secret_K2)
        recovered_K1, recovered_K2 = true_pair if true_pair in final_candidates else final_candidates[0]
        print(f'Multiple candidates remain ({final_count}). Displaying one consistent pair:')

    #   [4]
    plaintext_to_recover = 0xFFFF

    ciphertext_example = double_round_encrypt(plaintext_to_recover, secret_K1, secret_K2)

    recovered_plaintext = double_round_decrypt(ciphertext_example, recovered_K1, recovered_K2)

    print(f'Using recovered keys K1={recovered_K1:04X}, K2={recovered_K2:04X}')
    print(f'  Decrypting ciphertext {ciphertext_example:04X} yields plaintext {recovered_plaintext:04X}')
    assert recovered_plaintext == plaintext_to_recover
else:
    #   [5]
    print('No key pairs were recovered. Collect additional pairs and retry.')
