# Mini Block Cipher + MITM Attack

This notebook implements the mini block cipher (16-bit block, 16-bit key) and a meet-in-the-middle (MITM) attack on it. It follows the project specification and demonstrates how the MITM attack recovers the key. The notebook also runs a small Monte Carlo experiment to estimate how many `(P,C)` pairs are typically needed to identify the key uniquely.

## 1. Cipher primitives and round functions

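This cell defines:

- the S-box and inverse S-box;
- the nibble helpers and the layers `shift_rows`, `mix_columns`, and `add_round_key`;
- the key schedule `key_expansion`;
- the round functions `encrypt_round1`, `encrypt_round2`, `decrypt_round2`, and `decrypt_round1`;
- the full two-round `encrypt` and `decrypt`.

The 16-bit key is split into two bytes, and each round key repeats one byte in both halves: the high byte for round 1 and the low byte for round 2.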

In [None]:
# --- Cipher primitives and functions ---
# 4-bit substitution table used by sub_nibbles: SBOX[x] is the output for
# input nibble x. Entries are listed in input order 0x0, 0x1, ..., 0xF.
SBOX = dict(enumerate((
    0x9, 0x4, 0xA, 0xB,
    0xD, 0x1, 0x8, 0x5,
    0x6, 0x2, 0x0, 0x3,
    0xC, 0xE, 0xF, 0x7,
)))
# Inverse table used by inv_sub_nibbles. This is well defined because every
# output value above occurs exactly once (SBOX is a bijection on 0x0..0xF).
INV_SBOX = {out: inp for inp, out in SBOX.items()}

def nibble_get(state, idx):
    """Return the 4-bit nibble at position ``idx`` of ``state``.

    Position 0 is the most significant nibble (bits 15-12) and position 3 the
    least significant (bits 3-0).
    """
    return (state >> (4 * (3 - idx))) & 0xF

def nibble_set(state, idx, val):
    """Return ``state`` with nibble ``idx`` replaced by the low 4 bits of ``val``.

    Uses the same positions as nibble_get (0 = bits 15-12). The mask also keeps
    only the low 16 bits of ``state``.
    """
    offset = 4 * (3 - idx)
    keep = ~(0xF << offset) & 0xFFFF  # every 16-bit position except the target nibble
    return (state & keep) | ((val & 0xF) << offset)

def sub_nibbles(state):
    """Substitute each of the four nibbles of ``state`` through SBOX."""
    substituted = [SBOX[nibble_get(state, pos)] for pos in range(4)]
    result = 0
    for pos, nib in enumerate(substituted):
        result = nibble_set(result, pos, nib)
    return result

def inv_sub_nibbles(state):
    """Substitute each of the four nibbles of ``state`` through INV_SBOX (undoes sub_nibbles)."""
    restored = [INV_SBOX[nibble_get(state, pos)] for pos in range(4)]
    result = 0
    for pos, nib in enumerate(restored):
        result = nibble_set(result, pos, nib)
    return result

def shift_rows(state):
    """Swap nibbles 2 and 3 of ``state``; nibbles 0 and 1 are left in place.

    mix_columns treats nibbles (0, 2) and (1, 3) as the two columns, so nibbles
    2 and 3 form the second row. This swap rotates that row by one position.
    """
    n0, n1, n2, n3 = (nibble_get(state, pos) for pos in range(4))
    shifted = 0
    for pos, nib in enumerate((n0, n1, n3, n2)):
        shifted = nibble_set(shifted, pos, nib)
    return shifted

def inv_shift_rows(state):
    """Undo shift_rows.

    shift_rows only swaps nibbles 2 and 3, and applying that swap twice restores
    the input. The inverse is therefore shift_rows itself.
    """
    return shift_rows(state)

def mix_columns(state):
    """Mix each two-nibble column of ``state``.

    Column c consists of nibble c (top) and nibble c + 2 (bottom), for c in {0, 1}.
    The new top is top XOR bottom and the new bottom is the old top.
    """
    mixed = 0
    for top_pos in (0, 1):
        top = nibble_get(state, top_pos)
        bottom = nibble_get(state, top_pos + 2)
        mixed = nibble_set(mixed, top_pos, top ^ bottom)
        mixed = nibble_set(mixed, top_pos + 2, top)
    return mixed

def inv_mix_columns(state):
    """Undo mix_columns column by column.

    For each column (nibble c on top, nibble c + 2 below), mix_columns stored
    (top ^ bottom, top). The original top is therefore the stored bottom, and
    the original bottom is the XOR of the two stored nibbles.
    """
    unmixed = 0
    for top_pos in (0, 1):
        stored_top = nibble_get(state, top_pos)
        stored_bottom = nibble_get(state, top_pos + 2)
        unmixed = nibble_set(unmixed, top_pos, stored_bottom)
        unmixed = nibble_set(unmixed, top_pos + 2, stored_top ^ stored_bottom)
    return unmixed

def add_round_key(state, round_key):
    """XOR the low 16 bits of ``round_key`` into ``state``."""
    key_bits = round_key & 0xFFFF
    return state ^ key_bits

def key_expansion(key16):
    """Derive the two round keys from a 16-bit key.

    The high byte feeds round 1 and the low byte feeds round 2. Each 16-bit
    round key is its byte repeated in both halves (byte * 0x101).

    Returns:
        ``(k1, k2, k1_8, k2_8)``: the two 16-bit round keys followed by the two
        8-bit subkeys they were built from.
    """
    high, low = divmod(key16 & 0xFFFF, 0x100)
    return high * 0x101, low * 0x101, high, low

def encrypt_round1(state, round_key):
    """First round: substitute, shift rows, mix columns, then add ``round_key``."""
    for layer in (sub_nibbles, shift_rows, mix_columns):
        state = layer(state)
    return add_round_key(state, round_key)

def encrypt_round2(state, round_key):
    """Second (final) round: substitute and shift rows, with no column mixing, then add ``round_key``."""
    for layer in (sub_nibbles, shift_rows):
        state = layer(state)
    return add_round_key(state, round_key)

def decrypt_round2(state, round_key):
    """Invert encrypt_round2: remove ``round_key``, unshift rows, then inverse-substitute."""
    unkeyed = add_round_key(state, round_key)
    return inv_sub_nibbles(inv_shift_rows(unkeyed))

def decrypt_round1(state, round_key):
    """Invert encrypt_round1 by applying the inverse layers in reverse order."""
    unkeyed = add_round_key(state, round_key)
    return inv_sub_nibbles(inv_shift_rows(inv_mix_columns(unkeyed)))

def encrypt(plaintext16, key16):
    """Encrypt a 16-bit block under a 16-bit key with the two rounds.

    Only the low 16 bits of ``plaintext16`` are used, and the result is in 0..0xFFFF.
    """
    round_key1, round_key2 = key_expansion(key16)[:2]
    block = encrypt_round1(plaintext16 & 0xFFFF, round_key1)
    return encrypt_round2(block, round_key2) & 0xFFFF

def decrypt(ciphertext16, key16):
    """Invert encrypt.

    Undoes round 2 with the low-byte round key first, then round 1 with the
    high-byte round key. Only the low 16 bits of ``ciphertext16`` are used.
    """
    round_key1, round_key2 = key_expansion(key16)[:2]
    block = decrypt_round2(ciphertext16 & 0xFFFF, round_key2)
    return decrypt_round1(block, round_key1) & 0xFFFF

# Confirms that every definition in this cell executed without error.
print("Cipher primitives loaded.")

## 2. Generate (P,C) pairs

Draw a random secret 16-bit key (with a fixed seed for reproducibility), then produce 10 plaintext–ciphertext pairs under that key for the demonstration.

In [None]:
import random, pandas as pd
# Fixed seed so the demonstration key and plaintexts are reproducible.
random.seed(42)
true_key = random.randint(0, 0xFFFF)

# Draw each plaintext and encrypt it immediately. The RNG is called in the same
# order as an explicit loop would call it.
pairs = [(p, encrypt(p, true_key)) for p in (random.randint(0, 0xFFFF) for _ in range(10))]

df = pd.DataFrame(
    [{'plaintext': f'0x{p:04X}', 'ciphertext': f'0x{c:04X}'} for p, c in pairs]
)
print(f"Secret key (example): 0x{true_key:04X}")
df

## 3. MITM attack implementation

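This cell implements the MITM attack. For the first pair, it tabulates `encrypt_round1(P, K1)` over all 256 values of the 8-bit subkey K1 and `decrypt_round2(C, K2)` over all 256 values of K2. It then matches equal intermediate states to obtain candidate `(K1, K2)` combinations. Because the two subkeys are independent bytes of the key, this step costs 2 × 256 round evaluations instead of 65,536 full encryptions. The remaining pairs prune the candidates, and the survivors are finally checked against every pair.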

In [None]:
from collections import defaultdict

def mitm_recover_key(pairs, verbose=False):
    """Recover the 16-bit key from known (plaintext, ciphertext) pairs by meet-in-the-middle.

    The key consists of two independent 8-bit subkeys: ``k1_8`` (high byte,
    used by round 1) and ``k2_8`` (low byte, used by round 2). Each is expanded
    to a 16-bit round key by repeating the byte.

    For the first pair, round 1 is computed forwards under all 256 values of
    ``k1_8``, and round 2 is undone under all 256 values of ``k2_8``. Subkey
    combinations whose intermediate states agree become the initial candidates.
    Every further pair then discards the candidates that disagree on it.

    Args:
        pairs: Non-empty sequence of ``(P, C)`` tuples, assumed to be produced
            under one common key.
        verbose: If true, print the candidate count after each processed pair.

    Returns:
        ``(verified, pruning)``, where:

        * ``verified`` is the set of ``(k1_8, k2_8)`` tuples whose full key
          ``(k1_8 << 8) | k2_8`` encrypts every pair correctly.
        * ``pruning`` lists the candidate count after the first pair and after
          each further pair processed. Processing stops as soon as at most one
          candidate remains, so ``pruning`` can be shorter than ``pairs``.

    Raises:
        ValueError: If ``pairs`` is empty.
    """
    # len() rather than truthiness so array-like sequences also work.
    if len(pairs) == 0:
        raise ValueError("mitm_recover_key needs at least one (P, C) pair")

    def _round_key(sub8):
        # Same expansion as key_expansion: the 8-bit subkey repeated in both bytes.
        return (sub8 << 8) | sub8

    subkey_space = range(256)
    P0, C0 = pairs[0]

    # Forward half: state after round 1 -> every k1_8 that produces it.
    forward = defaultdict(list)
    for k1_8 in subkey_space:
        forward[encrypt_round1(P0, _round_key(k1_8))].append(k1_8)

    # Backward half: state before round 2 -> every k2_8 that produces it.
    backward = defaultdict(list)
    for k2_8 in subkey_space:
        backward[decrypt_round2(C0, _round_key(k2_8))].append(k2_8)

    # Meet in the middle: all subkey combinations that agree on the first pair.
    candidate_pairs = set()
    for X, k1_list in forward.items():
        if X not in backward:  # `in` does not insert keys into a defaultdict
            continue
        for k1_8 in k1_list:
            for k2_8 in backward[X]:
                candidate_pairs.add((k1_8, k2_8))

    pruning = [len(candidate_pairs)]
    if verbose:
        print(f"Candidates after first pair: {pruning[-1]}")

    for i, (P, C) in enumerate(pairs[1:], start=1):
        candidate_pairs = {
            (k1_8, k2_8)
            for (k1_8, k2_8) in candidate_pairs
            if encrypt_round1(P, _round_key(k1_8)) == decrypt_round2(C, _round_key(k2_8))
        }
        pruning.append(len(candidate_pairs))
        if verbose:
            print(f"After pair {i+1}: {len(candidate_pairs)}")
        if len(candidate_pairs) <= 1:
            break

    # Final full-cipher check against every pair, including any pairs the
    # early break above did not reach.
    verified = {
        (k1_8, k2_8)
        for (k1_8, k2_8) in candidate_pairs
        if all(encrypt(P, (k1_8 << 8) | k2_8) == C for P, C in pairs)
    }
    return verified, pruning

# Run on pairs from previous cell (if present).
# Only the existence of `pairs` is checked up front. A NameError raised inside
# the attack itself (for example because the cipher cell was not run) then
# surfaces as-is instead of being misreported as a missing `pairs`.
if 'pairs' in globals():
    recovered, pruning = mitm_recover_key(pairs, verbose=True)
    print('Recovered:', recovered)
    print('Pruning history:', pruning)
else:
    print('Run the generation cell above to produce `pairs`.')

## 4. Monte Carlo: How many (P,C) pairs are needed on average?

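We repeat the experiment for many random keys (`trials`). For each key, we measure how many plaintext/ciphertext pairs are needed before the MITM attack leaves a unique candidate key; the pairs are taken in order from a freshly generated list of random plaintexts. Keys that are not isolated within `max_pairs` pairs are recorded as `max_pairs + 1`. We also measure the runtime for each key.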

In [None]:
import time, statistics, random, pandas as pd

def pairs_needed_for_key(true_key, max_pairs=16):
    """Count how many random (P, C) pairs the MITM attack needs to isolate ``true_key``.

    Draws ``max_pairs`` random plaintexts from the global ``random`` generator,
    encrypts them under ``true_key``, and runs the attack once on all of them.

    ``pruning[i]`` from mitm_recover_key is the number of keys consistent with
    the first ``i + 1`` pairs. The first position where it reaches 1 is
    therefore the answer, so there is no need to rerun the full attack for
    every prefix length.

    Args:
        true_key: The 16-bit key to encrypt under.
        max_pairs: Maximum number of pairs to try.

    Returns:
        ``(n, pruning)``:

        * ``n`` is the smallest number of pairs after which exactly one key
          remains, or ``None`` if ``max_pairs`` pairs are not enough.
        * ``pruning`` is the candidate-count history up to and including that
          point. It is the full history when ``n`` is ``None``, and ``[]`` when
          ``max_pairs < 1``.
    """
    # Previously this case hit an UnboundLocalError on `pruning`.
    if max_pairs < 1:
        return None, []
    pool = [random.randint(0, 0xFFFF) for _ in range(max_pairs)]
    pcs = [(p, encrypt(p, true_key)) for p in pool]
    _, pruning = mitm_recover_key(pcs)
    # The true key always survives pruning, so the counts never drop to 0.
    # Exactly one remaining candidate therefore means the key is identified.
    if 1 in pruning:
        n = pruning.index(1) + 1
        return n, pruning[:n]
    return None, pruning

# Run Monte Carlo
random.seed(0)
trials = 100
max_pairs = 16
# Value recorded for keys that max_pairs pairs could not isolate. It is
# included in the mean/median below, so those statistics are lower bounds
# whenever some trials are unsolved.
unsolved = max_pairs + 1
results = []
times = []
for _ in range(trials):
    k = random.randint(0, 0xFFFF)
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # intervals. time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    need, prune = pairs_needed_for_key(k, max_pairs=max_pairs)
    elapsed = time.perf_counter() - start
    results.append(need if need is not None else unsolved)
    times.append(elapsed)

results_series = pd.Series(results)
print(f"Trials: {trials}, mean pairs needed: {results_series.mean():.2f}, median: {results_series.median():.2f}")
print(f"Fraction solved within {max_pairs} pairs: {(results_series <= max_pairs).mean():.2%}")

summary_df = pd.DataFrame({
    'trial': list(range(1, len(results) + 1)),
    'pairs_needed': results,
    'time_s': times,
})
summary_df.head(20)

### 4.1 Histogram of pairs needed

In [None]:
import matplotlib.pyplot as plt
# Trials not solved within max_pairs pairs were stored as max_pairs + 1.
# Label that bar ">max_pairs" so it is not read as "needed max_pairs + 1 pairs".
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(results, bins=range(1, max_pairs + 3), align='left')
# With align='left', the bar for bin [k, k + 1) is centred on k, so the ticks sit on the bars.
tick_positions = list(range(1, max_pairs + 2))
tick_labels = [str(n) for n in range(1, max_pairs + 1)] + [f'>{max_pairs}']
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_xlabel('Number of (P,C) pairs needed')
ax.set_ylabel('Frequency')
ax.set_title('Histogram: pairs needed to uniquely recover key (Monte Carlo)')
ax.grid(True)
plt.show()

### 4.2 Timing vs pairs needed

In [None]:
# One point per trial: pairs needed (x) against the measured time for that key (y).
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(results, times)
ax.set_xlabel('Pairs needed to recover key')
ax.set_ylabel('Time (s) for MITM run on that key')
ax.set_title('Runtime vs pairs needed')
ax.grid(True)
plt.show()

## 5. Save and download the notebook

The notebook file is saved alongside this environment. You can download it using the link provided after running this cell.