In [3]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# ===============================
# Constants and AES Parameters
# ===============================

# Inverse AES S-Box (FIPS-197 InvSubBytes table), indexed by byte value 0..255.
# Used for the last-round CPA attack: INV_SBOX[ciphertext_byte ^ key_guess]
# predicts the state byte that entered the final SubBytes.
INV_SBOX = np.array([
    0x52,0x09,0x6A,0xD5,0x30,0x36,0xA5,0x38,0xBF,0x40,0xA3,0x9E,0x81,0xF3,0xD7,0xFB,
    0x7C,0xE3,0x39,0x82,0x9B,0x2F,0xFF,0x87,0x34,0x8E,0x43,0x44,0xC4,0xDE,0xE9,0xCB,
    0x54,0x7B,0x94,0x32,0xA6,0xC2,0x23,0x3D,0xEE,0x4C,0x95,0x0B,0x42,0xFA,0xC3,0x4E,
    0x08,0x2E,0xA1,0x66,0x28,0xD9,0x24,0xB2,0x76,0x5B,0xA2,0x49,0x6D,0x8B,0xD1,0x25,
    0x72,0xF8,0xF6,0x64,0x86,0x68,0x98,0x16,0xD4,0xA4,0x5C,0xCC,0x5D,0x65,0xB6,0x92,
    0x6C,0x70,0x48,0x50,0xFD,0xED,0xB9,0xDA,0x5E,0x15,0x46,0x57,0xA7,0x8D,0x9D,0x84,
    0x90,0xD8,0xAB,0x00,0x8C,0xBC,0xD3,0x0A,0xF7,0xE4,0x58,0x05,0xB8,0xB3,0x45,0x06,
    0xD0,0x2C,0x1E,0x8F,0xCA,0x3F,0x0F,0x02,0xC1,0xAF,0xBD,0x03,0x01,0x13,0x8A,0x6B,
    0x3A,0x91,0x11,0x41,0x4F,0x67,0xDC,0xEA,0x97,0xF2,0xCF,0xCE,0xF0,0xB4,0xE6,0x73,
    0x96,0xAC,0x74,0x22,0xE7,0xAD,0x35,0x85,0xE2,0xF9,0x37,0xE8,0x1C,0x75,0xDF,0x6E,
    0x47,0xF1,0x1A,0x71,0x1D,0x29,0xC5,0x89,0x6F,0xB7,0x62,0x0E,0xAA,0x18,0xBE,0x1B,
    0xFC,0x56,0x3E,0x4B,0xC6,0xD2,0x79,0x20,0x9A,0xDB,0xC0,0xFE,0x78,0xCD,0x5A,0xF4,
    0x1F,0xDD,0xA8,0x33,0x88,0x07,0xC7,0x31,0xB1,0x12,0x10,0x59,0x27,0x80,0xEC,0x5F,
    0x60,0x51,0x7F,0xA9,0x19,0xB5,0x4A,0x0D,0x2D,0xE5,0x7A,0x9F,0x93,0xC9,0x9C,0xEF,
    0xA0,0xE0,0x3B,0x4D,0xAE,0x2A,0xF5,0xB0,0xC8,0xEB,0xBB,0x3C,0x83,0x53,0x99,0x61,
    0x17,0x2B,0x04,0x7E,0xBA,0x77,0xD6,0x26,0xE1,0x69,0x14,0x63,0x55,0x21,0x0C,0x7D
], dtype=np.uint8)

# AES-128 round constants (FIPS-197 Rcon), indexed by round number 1..10.
# Entry 0 is a 0x00 placeholder so that RCON[r] lines up with round r.
RCON = np.array(
    [0x00]                                              # placeholder, no round 0
    + [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80]  # rounds 1-8: powers of x
    + [0x1B, 0x36],                                     # rounds 9-10: reduced mod 0x11B
    dtype=np.uint8,
)

# Inverse ShiftRows expressed as a source-index permutation over a row-major
# 4x4 layout: entry 4*row + col holds the input index for output byte
# (row, col), i.e. row r is rotated right by r positions.
# NOTE(review): FIPS-197 serializes the AES state column-major
# (byte i = row i % 4, column i // 4); confirm this row-major layout matches
# the device's byte order before relying on it.
INV_SHIFTROWS_MAP = [4 * row + (col - row) % 4 for row in range(4) for col in range(4)]

# ===============================
# Leakage Models
# ===============================

def hamming_weight(val):
    """Count how many bits of ``val`` are set to 1 (e.g. 0b1011 -> 3)."""
    binary_digits = bin(val)
    return sum(1 for digit in binary_digits if digit == "1")

def msb(val):
    """Return bit 7 (mask 0x80) of ``val`` as 0 or 1."""
    return (val & 0x80) >> 7

def hamming_distance(val1, val2):
    """Count the bit positions in which ``val1`` and ``val2`` differ."""
    differing_bits = val1 ^ val2
    return bin(differing_bits).count("1")

def leakage_model(byte_val, prev_val, model="HW"):
    """
    Map an intermediate byte to its predicted leakage under ``model``.

    "HW"  -> Hamming weight of ``byte_val`` (``prev_val`` is ignored)
    "MSB" -> bit 7 of ``byte_val`` (``prev_val`` is ignored)
    "HD"  -> Hamming distance between ``byte_val`` and ``prev_val``

    Raises ValueError for any other model name.
    """
    if model == "HW":
        return hamming_weight(byte_val)
    if model == "MSB":
        return msb(byte_val)
    if model == "HD":
        return hamming_distance(byte_val, prev_val)
    raise ValueError("Unsupported leakage model")

# ===============================
# CPA Attack
# ===============================

def cpa_byte(traces, ciphertexts, byte_idx, model="HW"):
    """
    Run last-round CPA against one byte of the AES-128 round-10 key.

    AddRoundKey is the final AES operation and is applied position-wise, so
    key byte ``byte_idx`` is XORed straight into ciphertext byte ``byte_idx``.
    For each key guess k the predicted intermediate is
    ``INV_SBOX[C[byte_idx] ^ k]`` (the state byte entering the last SubBytes).
    It is mapped to a leakage value with ``leakage_model`` and
    Pearson-correlated against every time sample.

    Parameters
    ----------
    traces : array-like, shape (num_traces, num_samples)
    ciphertexts : array-like
        One 16-byte ciphertext per trace. Accepts an integer array of shape
        (num_traces, 16) or a sequence of ``bytes``-like objects. Values are
        assumed to be byte values 0..255.
    byte_idx : int
        Key/ciphertext byte position, 0..15.
    model : str
        Leakage model name forwarded to ``leakage_model`` ("HW", "MSB", "HD").

    Returns
    -------
    best_guess : int
        Key guess with the largest absolute correlation at any sample.
    correlations : ndarray, shape (256, num_samples)
        Pearson correlation per (guess, sample). Set to 0 where undefined
        (constant hypothesis or constant sample).

    Raises
    ------
    ValueError
        Out-of-range ``byte_idx``, mismatched input sizes, or an unsupported
        leakage model.
    """
    if not 0 <= byte_idx < 16:
        raise ValueError(f"byte_idx must be in [0, 15], got {byte_idx}")

    traces = np.asarray(traces, dtype=np.float64)
    if traces.ndim != 2:
        raise ValueError("traces must be a 2-D array (num_traces, num_samples)")

    ct = np.asarray(ciphertexts)
    if ct.ndim != 2:
        # e.g. a sequence of ``bytes`` objects: expand each into its byte values.
        ct = np.array([list(c) for c in ciphertexts])
    ct = ct.astype(np.uint8, copy=False)
    if ct.shape[0] != traces.shape[0]:
        raise ValueError("traces and ciphertexts must have the same number of rows")

    # FIPS-197 byte order: byte i is state[row i % 4][column i // 4]. ShiftRows
    # moved the state byte at (row, (col + row) % 4) to position byte_idx, so
    # ref_idx is where the predicted intermediate lived before the last round.
    # The "HD" model compares it with the ciphertext byte later written to that
    # same position (assumes an in-place state register -- confirm for the DUT).
    # For byte 0, ref_idx == 0.
    row, col = byte_idx % 4, byte_idx // 4
    ref_idx = row + 4 * ((col + row) % 4)

    # Leakage lookup table lut[value, reference] built from leakage_model, so
    # the model is evaluated 65536 times rather than num_traces * 256 times.
    # An unsupported model raises ValueError here, before any heavy work.
    lut = np.array(
        [[leakage_model(v, p, model) for p in range(256)] for v in range(256)],
        dtype=np.float64,
    )

    guesses = np.arange(256, dtype=np.uint8)
    states = INV_SBOX[ct[:, byte_idx, None] ^ guesses]          # (num_traces, 256)
    hypothetical = lut[states, ct[:, ref_idx, None]]            # (num_traces, 256)

    # Pearson correlation for every (guess, sample) pair in one matrix product;
    # the trace centring is done once instead of once per key guess.
    hyp_c = hypothetical - hypothetical.mean(axis=0)
    trc_c = traces - traces.mean(axis=0)
    numerator = hyp_c.T @ trc_c                                 # (256, num_samples)
    denominator = np.outer(
        np.sqrt((hyp_c ** 2).sum(axis=0)),
        np.sqrt((trc_c ** 2).sum(axis=0)),
    )
    with np.errstate(divide="ignore", invalid="ignore"):
        # Zero variance -> correlation undefined; use 0 so NaN can't win argmax.
        correlations = np.where(denominator > 0, numerator / denominator, 0.0)

    best_guess = int(np.argmax(np.max(np.abs(correlations), axis=1)))
    return best_guess, correlations

def recover_last_round_key(traces, ciphertexts, model="HW"):
    """
    Attack all 16 last-round key bytes with ``cpa_byte``.

    Prints one progress line per byte. Returns ``(key, correlations)``, where
    ``key`` is a 16-byte ``bytes`` ordered by byte index and ``correlations``
    is the list of per-byte correlation matrices in the same order.
    """
    guesses, corr_mats = [], []
    for byte_pos in range(16):
        guess, corr = cpa_byte(traces, ciphertexts, byte_pos, model=model)
        guesses.append(guess)
        corr_mats.append(corr)
        print(f"[{model}] Byte {byte_pos:02}: 0x{guess:02x}")
    return bytes(guesses), corr_mats

# ===============================
# Key Schedule Reversal
# ===============================

def _gf_xtime(a):
    """Multiply a GF(2^8) element by x (i.e. by 0x02) modulo the AES polynomial 0x11B."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def _aes_forward_sbox():
    """Build the forward AES S-box: GF(2^8) inverse followed by the FIPS-197 affine map."""
    # Power/log tables of the generator 0x03: exp[i] = 3**i, log[3**i] = i.
    exp = [0] * 255
    log = [0] * 256
    elem = 1
    for i in range(255):
        exp[i] = elem
        log[elem] = i
        elem ^= _gf_xtime(elem)  # elem * 3 == (elem * 2) ^ elem
    sbox = []
    for x in range(256):
        inv = exp[(255 - log[x]) % 255] if x else 0  # 0 has no inverse; maps to 0
        out = inv
        for _ in range(4):
            inv = ((inv << 1) | (inv >> 7)) & 0xFF  # rotate left by one bit
            out ^= inv
        sbox.append(out ^ 0x63)
    return bytes(sbox)


# Forward S-box used by SubWord in the key expansion (computed once at import).
_AES_SBOX = _aes_forward_sbox()


def reverse_key_schedule(last_round_key, last_round=10):
    """
    Invert the AES-128 key expansion to recover the original cipher key.

    Going back one round at a time (FIPS-197 sec. 5.2), round key r-1 is
    rebuilt from round key r with
        w[4(r-1)+j] = w[4r+j] ^ w[4r+j-1]                          (j = 1, 2, 3)
        w[4(r-1)]   = w[4r] ^ SubWord(RotWord(w[4r-1])) ^ Rcon[r]
    Here w[4r-1] is the last word of round key r-1 itself, so words 1-3 are
    recovered first. SubWord uses the *forward* S-box.

    Parameters
    ----------
    last_round_key : bytes-like
        The 16-byte round key of round ``last_round``.
    last_round : int, optional
        Round index of ``last_round_key`` (0..10). Defaults to 10, the final
        AES-128 round key.

    Returns
    -------
    bytes
        The 16-byte cipher key (round key 0).

    Raises
    ------
    ValueError
        If the key is not 16 bytes or ``last_round`` is outside 0..10.
    """
    key = bytearray(last_round_key)
    if len(key) != 16:
        raise ValueError(f"expected a 16-byte round key, got {len(key)} bytes")
    if not 0 <= last_round <= 10:
        raise ValueError(f"last_round must be in [0, 10], got {last_round}")

    # Round constants: rcon[r] = x**(r-1) in GF(2^8); rcon[0] is unused.
    rcon = [0] * (last_round + 1)
    value = 1
    for r in range(1, last_round + 1):
        rcon[r] = value
        value = _gf_xtime(value)

    for r in range(last_round, 0, -1):
        prev = bytearray(16)
        # Words 1..3 of the previous round key (bytes 4..15).
        for b in range(4, 16):
            prev[b] = key[b] ^ key[b - 4]
        # Word 0 needs RotWord(previous word 3) -> bytes 13, 14, 15, 12.
        rotated = prev[13:16] + prev[12:13]
        for b in range(4):
            prev[b] = key[b] ^ _AES_SBOX[rotated[b]]
        prev[0] ^= rcon[r]
        key = prev
    return bytes(key)

# ===============================
# Evaluation & Visualization
# ===============================

def compute_snr(traces, leakages):
    """
    Estimate the per-sample signal-to-noise ratio of ``traces`` for a leakage class.

    Traces are grouped by their value in ``leakages``. At every time sample:
        signal = variance, across classes, of the class-mean traces
        noise  = mean, across classes, of the within-class variances
    and SNR = signal / noise. A 1e-9 term guards against division by zero.

    Parameters
    ----------
    traces : array-like, shape (num_traces, num_samples)
    leakages : array-like, shape (num_traces,)
        Class label (e.g. predicted leakage value) for each trace.

    Returns
    -------
    ndarray, shape (num_samples,)

    Raises
    ------
    ValueError
        If ``leakages`` does not have one entry per trace.
    """
    traces = np.asarray(traces, dtype=np.float64)
    leakages = np.asarray(leakages)
    if leakages.shape[0] != traces.shape[0]:
        raise ValueError("leakages must have one entry per trace")

    classes = np.unique(leakages)
    class_means = np.empty((classes.size,) + traces.shape[1:])
    class_vars = np.empty_like(class_means)
    for i, label in enumerate(classes):
        members = traces[leakages == label]
        class_means[i] = members.mean(axis=0)
        class_vars[i] = members.var(axis=0)

    signal = class_means.var(axis=0)
    # Noise is the within-class spread at each sample. The total variance
    # would also contain the signal, and averaging it over all samples would
    # discard where the noise actually is.
    noise = class_vars.mean(axis=0)
    return signal / (noise + 1e-9)

def plot_snr(snr):
    """Draw the SNR curve against time-sample index and display it."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(snr)
    ax.set_title("Signal-to-Noise Ratio (SNR)")
    ax.set_xlabel("Time Sample")
    ax.set_ylabel("SNR")
    ax.grid()
    fig.tight_layout()
    plt.show()

def plot_correlation(corr_matrix, title="CPA Correlation"):
    """Show |corr_matrix| as a heat map: rows are key guesses, columns are time samples."""
    fig, ax = plt.subplots(figsize=(10, 5))
    heat = ax.imshow(np.abs(corr_matrix), cmap='hot', aspect='auto')
    fig.colorbar(heat, ax=ax, label="|Correlation|")
    ax.set_title(title)
    ax.set_xlabel("Time Sample")
    ax.set_ylabel("Key Guess (0-255)")
    fig.tight_layout()
    plt.show()

# ===============================
# Main Execution
# ===============================

if __name__ == "__main__":
    # Load the trace set. The ``with`` block closes the .npz archive once the
    # arrays have been read into memory; a bare np.load leaves it open.
    with np.load("data/traces_1.npz") as data:
        traces = data["wave"]
        ciphertexts = data["dut_io_computed_data"]

    # Choose leakage model: "HW", "HD", or "MSB"
    attack_model = "MSB"

    print(f"Running CPA attack using model: {attack_model}")
    last_round_key, corrs = recover_last_round_key(traces, ciphertexts, model=attack_model)

    # Recover original AES key
    original_key = reverse_key_schedule(last_round_key)
    print("\nRecovered original key:      ", original_key.hex())

    # Compare with known correct key
    true_key = bytes.fromhex("10a58869d74be5a374cf867cfb473859")
    print("Expected correct key:        ", true_key.hex())
    print("Key match:", original_key == true_key)

    # SNR Visualization (for first byte). Byte 0 is not moved by ShiftRows,
    # so the same ciphertext byte is both the key-mixing input and the HD
    # reference value.
    snr_byte = 0
    leaks = np.array([
        leakage_model(INV_SBOX[c[snr_byte] ^ last_round_key[snr_byte]], c[snr_byte], attack_model)
        for c in ciphertexts
    ])
    snr = compute_snr(traces, leaks)
    plot_snr(snr)

    # Correlation Plot for Byte 0
    plot_correlation(corrs[0], title=f"Correlation (Byte 0, {attack_model})")


Running CPA attack using model: MSB
[MSB] Byte 00: 0x02
[MSB] Byte 01: 0xbb
[MSB] Byte 02: 0x27


KeyboardInterrupt: 