In [2]:
# Python code for a 4-round AES variant (3 full rounds + 1 final round without
# MixColumns) and an integral (Square) attack on its last round key.
#
# NOTE: `^` is used as bitwise XOR throughout. Under the SageMath preparser `^`
# means exponentiation, so run this in a plain Python kernel (or with Sage's
# preparser turned off).

# AES S-box and Inverse S-box (standard definitions)
# Sbox[x] is the SubBytes image of byte x (0 <= x < 256). The table is laid out
# as 16 lines of 16 entries: line = high nibble of x, position = low nibble.
Sbox = [
 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16
]

# Inverse S-box: InvSbox[Sbox[x]] == x for every byte x. Same 16x16 layout as
# Sbox. Used by inv_sub_bytes and by the integral attack to undo the final
# round's SubBytes on a guessed key byte.
InvSbox = [
 0x52,0x09,0x6a,0xd5,0x30,0x36,0xa5,0x38,0xbf,0x40,0xa3,0x9e,0x81,0xf3,0xd7,0xfb,
 0x7c,0xe3,0x39,0x82,0x9b,0x2f,0xff,0x87,0x34,0x8e,0x43,0x44,0xc4,0xde,0xe9,0xcb,
 0x54,0x7b,0x94,0x32,0xa6,0xc2,0x23,0x3d,0xee,0x4c,0x95,0x0b,0x42,0xfa,0xc3,0x4e,
 0x08,0x2e,0xa1,0x66,0x28,0xd9,0x24,0xb2,0x76,0x5b,0xa2,0x49,0x6d,0x8b,0xd1,0x25,
 0x72,0xf8,0xf6,0x64,0x86,0x68,0x98,0x16,0xd4,0xa4,0x5c,0xcc,0x5d,0x65,0xb6,0x92,
 0x6c,0x70,0x48,0x50,0xfd,0xed,0xb9,0xda,0x5e,0x15,0x46,0x57,0xa7,0x8d,0x9d,0x84,
 0x90,0xd8,0xab,0x00,0x8c,0xbc,0xd3,0x0a,0xf7,0xe4,0x58,0x05,0xb8,0xb3,0x45,0x06,
 0xd0,0x2c,0x1e,0x8f,0xca,0x3f,0x0f,0x02,0xc1,0xaf,0xbd,0x03,0x01,0x13,0x8a,0x6b,
 0x3a,0x91,0x11,0x41,0x4f,0x67,0xdc,0xea,0x97,0xf2,0xcf,0xce,0xf0,0xb4,0xe6,0x73,
 0x96,0xac,0x74,0x22,0xe7,0xad,0x35,0x85,0xe2,0xf9,0x37,0xe8,0x1c,0x75,0xdf,0x6e,
 0x47,0xf1,0x1a,0x71,0x1d,0x29,0xc5,0x89,0x6f,0xb7,0x62,0x0e,0xaa,0x18,0xbe,0x1b,
 0xfc,0x56,0x3e,0x4b,0xc6,0xd2,0x79,0x20,0x9a,0xdb,0xc0,0xfe,0x78,0xcd,0x5a,0xf4,
 0x1f,0xdd,0xa8,0x33,0x88,0x07,0xc7,0x31,0xb1,0x12,0x10,0x59,0x27,0x80,0xec,0x5f,
 0x60,0x51,0x7f,0xa9,0x19,0xb5,0x4a,0x0d,0x2d,0xe5,0x7a,0x9f,0x93,0xc9,0x9c,0xef,
 0xa0,0xe0,0x3b,0x4d,0xae,0x2a,0xf5,0xb0,0xc8,0xeb,0xbb,0x3c,0x83,0x53,0x99,0x61,
 0x17,0x2b,0x04,0x7e,0xba,0x77,0xd6,0x26,0xe1,0x69,0x14,0x63,0x55,0x21,0x0c,0x7d
]

# --- GF(2^8) multiplication functions (with irreducible polynomial 0x11b) ---
def xtime(a):
    """Multiply a GF(2^8) element by x (i.e. by 2), reducing modulo 0x11b."""
    doubled = a << 1
    # A carry into bit 8 means the degree reached 8: fold it back with the AES polynomial.
    reduced = doubled ^ 0x11b if doubled & 0x100 else doubled
    return reduced & 0xFF

def gf_mul(a, b):
    """
    Multiply ``a`` by ``b`` in GF(2^8) with the AES polynomial 0x11b.

    Shift-and-add: walk the 8 low bits of ``b`` from least significant up; when
    a bit is set, XOR the matching multiple a*x^bit into the product. The
    multiple is doubled with ``xtime`` after every bit.
    """
    product = 0
    multiple = a
    for bit in range(8):
        if (b >> bit) & 1:
            product ^= multiple
        multiple = xtime(multiple)
    return product

# --- AES round functions ---

def sub_bytes(state):
    """Return a new state in which every byte of ``state`` is replaced by its S-box image."""
    substituted = []
    for row in state:
        substituted.append([Sbox[value] for value in row])
    return substituted

def inv_sub_bytes(state):
    """Return a new state in which every byte of ``state`` is mapped back through the inverse S-box."""
    restored = []
    for row in state:
        restored.append([InvSbox[value] for value in row])
    return restored

def shift_rows(state):
    """Return a new state where row number ``offset`` is rotated left by ``offset`` positions."""
    return [row[offset:] + row[:offset] for offset, row in enumerate(state)]

def inv_shift_rows(state):
    """Return a new state where row number ``shift`` is rotated right by ``shift`` positions (undoes shift_rows)."""
    # For shift == 0, row[-0:] is the whole row and row[:-0] is empty, so row 0 is unchanged.
    return [row[-shift:] + row[:-shift] for shift, row in enumerate(state)]

def mix_single_column(col):
    """
    Mix one column (list of 4 bytes) using the AES MixColumns matrix.

    The matrix is:
        [2 3 1 1]
        [1 2 3 1]
        [1 1 2 3]
        [3 1 1 2]
    In GF(2^8), multiplying by 1 is the identity, by 2 is ``xtime(a)`` and by 3 is
    ``xtime(a) ^ a``. So each input byte needs a single ``xtime`` call instead
    of four generic 8-step ``gf_mul`` calls. This function runs for every column
    of every full round, so the saving dominates the cost of an encryption.

    Returns a new list of 4 bytes.
    """
    a0, a1, a2, a3 = col
    d0, d1, d2, d3 = xtime(a0), xtime(a1), xtime(a2), xtime(a3)  # 2 * a_i
    return [d0 ^ (d1 ^ a1) ^ a2 ^ a3,   # 2*a0 ^ 3*a1 ^ 1*a2 ^ 1*a3
            a0 ^ d1 ^ (d2 ^ a2) ^ a3,   # 1*a0 ^ 2*a1 ^ 3*a2 ^ 1*a3
            a0 ^ a1 ^ d2 ^ (d3 ^ a3),   # 1*a0 ^ 1*a1 ^ 2*a2 ^ 3*a3
            (d0 ^ a0) ^ a1 ^ a2 ^ d3]   # 3*a0 ^ 1*a1 ^ 1*a2 ^ 2*a3

def mix_columns(state):
    """
    Return a new state with MixColumns applied to each of the 4 columns.

    The state is stored row-wise (state[r][c] is row r, column c). Each column
    is gathered across the rows, mixed, and written back in the same position.
    """
    mixed_cols = [mix_single_column([state[r][c] for r in range(4)]) for c in range(4)]
    return [[mixed_cols[c][r] for c in range(4)] for r in range(4)]

def add_round_key(state, round_key):
    """Return a new 4x4 matrix holding the byte-wise XOR of ``state`` and ``round_key``."""
    combined = []
    for r in range(4):
        state_row, key_row = state[r], round_key[r]
        combined.append([state_row[c] ^ key_row[c] for c in range(4)])
    return combined

# --- Key schedule for 128-bit AES, truncated to `rounds` rounds (rounds + 1 round keys) ---
def key_expansion(key, rounds=4):
    """
    Expand a 16-byte key into ``rounds + 1`` round keys (default: 4 rounds -> 5 keys).

    Each round key is a list of 4 words, and each word is a list of 4 bytes.
    In this notebook every word is used as one ROW of the 4x4 state, matching
    the row-major ``bytes_to_state``. FIPS-197 instead fills the state column by
    column, so ciphertexts from this code will not match official AES test
    vectors; the attack only relies on internal consistency.

    Raises ValueError if ``key`` is not 16 integers in 0..255, or if ``rounds``
    is outside 1..10 (AES-128 defines round constants for 10 rounds).
    """
    key = list(key)
    if len(key) != 16:
        raise ValueError(f"key must be 16 bytes, got {len(key)}")
    if any(not 0 <= b <= 255 for b in key):
        # An out-of-range value would otherwise surface later as an IndexError in an S-box lookup.
        raise ValueError("key bytes must be integers in range(256)")
    if not 1 <= rounds <= 10:
        raise ValueError(f"rounds must be between 1 and 10, got {rounds}")

    # AES-128 round constants: rcon[j] = x^j in GF(2^8), for rounds 1..10.
    rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]

    # Words w[0..3] are the key itself; we need 4 * (rounds + 1) words in total.
    w = [key[4 * i:4 * i + 4] for i in range(4)]
    for i in range(4, 4 * (rounds + 1)):
        temp = list(w[i - 1])
        if i % 4 == 0:
            temp = temp[1:] + temp[:1]        # RotWord: rotate left by one byte
            temp = [Sbox[b] for b in temp]    # SubWord: S-box on each byte
            temp[0] ^= rcon[i // 4 - 1]       # XOR first byte with this round's constant
        # XOR with the word 4 positions earlier
        w.append([w[i - 4][j] ^ temp[j] for j in range(4)])

    # Group every 4 consecutive words into one round key (one word per row).
    return [w[4 * r:4 * r + 4] for r in range(rounds + 1)]

# --- AES encryption (4 rounds: 3 full rounds + final round without MixColumns) ---
def aes_encrypt(plaintext, round_keys):
    """
    Encrypt one state with reduced-round AES.

    plaintext:  a 4x4 matrix (list of 4 lists of 4 bytes each)
    round_keys: sequence of 4x4 round keys.
                - Key 0 is XORed in first.
                - Each key strictly between the first and the last drives one
                  full round (SubBytes, ShiftRows, MixColumns, AddRoundKey).
                - The last key closes a final round that omits MixColumns.
                With the 5 keys from key_expansion this is 3 full rounds + 1
                final round.
    Returns the ciphertext as a new 4x4 matrix.
    """
    state = add_round_key(plaintext, round_keys[0])
    for round_key in round_keys[1:-1]:
        state = add_round_key(mix_columns(shift_rows(sub_bytes(state))), round_key)
    # Final round: no MixColumns.
    return add_round_key(shift_rows(sub_bytes(state)), round_keys[-1])

# --- Helper functions for converting between a 16-byte list and a 4x4 state matrix ---
def bytes_to_state(b):
    """
    Split a flat sequence of 16 bytes into a 4x4 state, row by row:
    b[0:4] becomes row 0, b[4:8] row 1, and so on (row-major order).
    """
    return [b[start:start + 4] for start in range(0, 16, 4)]

def state_to_bytes(state):
    """
    Flatten a 4x4 state matrix (list of rows) into a 16-byte list, row by row.

    This is the inverse of bytes_to_state. It uses a comprehension rather than
    ``sum(state, [])``: repeated list concatenation is quadratic, and the
    comprehension also accepts rows given as tuples.
    """
    return [byte for row in state for byte in row]

# --- Setup: random key, round key expansion, and encryption of chosen plaintexts ---
import random  # needed here: this cell runs before the later cells that import it

# Draw a random 16-byte key. random.randrange(256) yields 0..255. randint's upper
# bound is inclusive, so randint(0, 256) could produce the non-byte value 256
# and crash the S-box lookups with an IndexError.
key = [random.randrange(256) for _ in range(16)]
round_keys = key_expansion(key)

# For testing, let’s encrypt a sample plaintext.
# (You can uncomment these lines to test a single encryption.)
# sample_plaintext = bytes_to_state([0x00]*16)
# sample_ciphertext = aes_encrypt(sample_plaintext, round_keys)
# print("Sample ciphertext:", state_to_bytes(sample_ciphertext))

# --- Integral (Square) attack on the (0,0) byte of the final round key ---
# We construct a set of 256 plaintexts in which byte (0,0) takes every value
# 0..255, while the other 15 bytes all hold one shared random constant.

plaintexts = []
constant = random.randrange(256)
for a in range(256):
    pt = [constant] * 16  # 16-byte plaintext, every byte = constant
    pt[0] = a             # vary byte 0, i.e. row 0, col 0 of the state
    plaintexts.append(bytes_to_state(pt))
# Encrypt all plaintexts
ciphertexts = [aes_encrypt(pt, round_keys) for pt in plaintexts]

# Why this works: after the three full rounds every state byte is balanced,
# meaning its XOR over the 256 texts is 0. The final round is SubBytes,
# ShiftRows (row 0 is not shifted) and AddRoundKey, so
#     ct[0][0] = Sbox[x] ^ k,   where x is the balanced byte at (0,0).
# For the correct guess k, the XOR of InvSbox[ct[0][0] ^ k] over all 256
# ciphertexts is 0. A wrong guess also gives 0 by chance (roughly 1 time in
# 256), which is why more than one candidate can survive.
#
# We now try all 256 possible values for the key byte at position (0,0) in the final round key.
attack_candidates = []
for candidate in range(256):
    acc = 0
    for ct in ciphertexts:
        # Peel off the guessed key byte, then undo SubBytes; XOR is addition in GF(2^8).
        acc ^= InvSbox[ct[0][0] ^ candidate]
    if acc == 0:
        attack_candidates.append(candidate)

print("Possible candidates for final round key byte at position (0,0):")
print(attack_candidates)
print("Actual final round key byte (position (0,0)): %d" % round_keys[-1][0][0])


Possible candidates for final round key byte at position (0,0):
[78, 111]
Actual final round key byte (position (0,0)): 78


In [3]:
from functools import reduce

import random  # do not rely on an import from another cell

plaintexts = []
# One byte value 0..255 (randint(0, 256) is inclusive and could yield 256).
constant = random.randrange(256)
for a in range(256):
    pt = [constant] * 16  # 16-byte plaintext, every byte = constant
    pt[0] = a             # vary byte 0 (row 0, col 0 of the state)
    plaintexts.append(bytes_to_state(pt))

# Encrypt all plaintexts with the round keys from the setup cell
# (aes_encrypt only reads round_keys, so no defensive copy is needed).
ciphertexts = [aes_encrypt(pt, round_keys) for pt in plaintexts]

# XOR-fold the raw (0,0) ciphertext byte: start with acc = 0, then acc ^ ct[0][0].
# Unlike the bytes before the final round, the raw ciphertext byte is generally
# NOT balanced (the final SubBytes breaks the property). That is why the attack
# must peel off the last round with a key guess first.
total = reduce(lambda acc, ct: acc ^ ct[0][0],
               ciphertexts,
               0)
print("XOR‐sum of all ct[0][0] bytes is", total)

XOR‐sum of all ct[0][0] bytes is 207


In [6]:
import random
from collections import defaultdict

def integral_attack_stats(num_keys=10000, rng=None):
    """
    Runs `num_keys` random‐key experiments of the 4‐round AES integral attack
    on state‐byte (0,0).  Returns a dict: { number_of_candidates: frequency }.
    Warns if, in any trial, the true key‐byte is *not* among the candidates.

    rng: optional source of randomness with a ``randrange`` method, e.g.
         ``random.Random(seed)`` for reproducible runs. Defaults to the global
         ``random`` module (same call sequence as before).
    """
    if rng is None:
        rng = random
    dist = defaultdict(int)

    # 0) random base plaintext shared by all trials
    #    (15 fixed random bytes; byte #0 will vary)
    base = [rng.randrange(256) for _ in range(16)]

    for trial in range(num_keys):
        # 1) random key & expand
        key = [rng.randrange(256) for _ in range(16)]
        round_keys = key_expansion(key)
        true_byte = round_keys[-1][0][0]

        # 2) build the 256-set of plaintexts (byte 0 takes every value)
        plaintexts = []
        for a in range(256):
            pt = list(base)
            pt[0] = a
            plaintexts.append(bytes_to_state(pt))

        # 3) encrypt them all
        ciphertexts = [aes_encrypt(pt, round_keys) for pt in plaintexts]

        # 4) attack on byte (0,0).
        #    XOR is self-inverse, so ciphertext bytes that occur an even number
        #    of times cancel out of every guess's sum. Only bytes with odd
        #    multiplicity need visiting: same candidates, fewer lookups.
        odd_bytes = set()
        for ct in ciphertexts:
            odd_bytes ^= {ct[0][0]}

        candidates = []
        for guess in range(256):
            acc = 0
            for c in odd_bytes:
                acc ^= InvSbox[c ^ guess]
            if acc == 0:
                candidates.append(guess)

        dist[len(candidates)] += 1

        # 5) verify correct byte is in candidates
        if true_byte not in candidates:
            print(f"⚠️  Trial {trial}: true key‐byte 0x{true_byte:02x} missing!")

    return dict(dist)


trials_number = 1000
# Repeat the attack under independent random keys and tally how many key guesses survive.
dist = integral_attack_stats(trials_number)

# Histogram: surviving-candidate count -> number of trials
print(f'# candidates → frequency over {trials_number} trials')
for n_candidates, freq in sorted(dist.items()):
    print(f"{n_candidates:3d} → {freq:5d}")


KeyboardInterrupt: 

In [43]:
import random
import collections

def approximate_distribution(num_probes=256, extra_true_prob=1/256, runs=100_000):
    """
    Monte-Carlo estimate of the number of True probes per experiment.

    Each of the `runs` experiments has one probe that is always True. The
    other num_probes-1 probes are independently True with probability
    extra_true_prob, so the count is 1 + Binomial(num_probes-1, extra_true_prob).
    Returns a dict mapping total True-count -> empirical probability, in
    increasing order of count.
    """
    tally = collections.Counter()
    for _ in range(runs):
        extra_hits = sum(1 for _ in range(num_probes - 1) if random.random() < extra_true_prob)
        tally[1 + extra_hits] += 1
    return {count: tally[count] / runs for count in sorted(tally)}

if __name__ == "__main__":
    # Print the simulated shifted-binomial distribution as a two-column table.
    dist = approximate_distribution()
    print("Total True\tProbability")
    for line in (f"{total_true:>10}\t{prob:.4f}" for total_true, prob in dist.items()):
        print(line)


Total True	Probability
         1	0.3652
         2	0.3708
         3	0.1857
         4	0.0600
         5	0.0148
         6	0.0029
         7	0.0004
         8	0.0001
         9	0.0000


In [None]:
import numpy as np
from scipy.stats import chisquare
import matplotlib.pyplot as plt

from scipy.stats import binom

# 1) Observed distribution: number of surviving key guesses per integral-attack trial.
num_trials = 1000
obs_dist = integral_attack_stats(num_trials)            # {k: freq} from num_trials integral-attack trials
assert min(obs_dist) >= 1, "some trial rejected the true key byte; the attack is broken"

# 2) Theoretical distribution, computed exactly. The true guess always survives,
#    and each of the other 255 guesses is modelled as surviving independently
#    with prob. 1/256, so K = 1 + Binomial(255, 1/256).
#    A second Monte-Carlo run of only num_trials samples is itself noisy, and
#    the chi-square would then compare two samples instead of data vs. model.
n_wrong, p_wrong = 255, 1 / 256
ks = list(range(1, max(obs_dist) + 1))
theo_probs = {k: float(binom.pmf(k - 1, n_wrong, p_wrong)) for k in ks}   # {k: P(K = k)}

obs = np.array([obs_dist.get(k, 0) for k in ks], dtype=float)
probs = np.array([theo_probs[k] for k in ks])
# Make the last bin open-ended: P(K >= ks[-1]) = P(B >= ks[-1] - 1) = sf(ks[-1] - 2).
# This way the expected counts add up to num_trials, as scipy.stats.chisquare requires.
probs[-1] = binom.sf(ks[-1] - 2, n_wrong, p_wrong)
exp = probs * num_trials
labels = [str(k) for k in ks[:-1]] + [f'>={ks[-1]}']

# 3) Merge the tail bins with expected count < 5 (chi-square validity rule of thumb).
small = np.flatnonzero(exp < 5)
if small.size:
    i0 = int(small[0])
    obs = np.concatenate([obs[:i0], [obs[i0:].sum()]])
    exp = np.concatenate([exp[:i0], [exp[i0:].sum()]])
    labels = labels[:i0] + [f'>={ks[i0]}']

# 4) Chi-square goodness-of-fit test (no parameters are estimated from the data).
chi2, pval = chisquare(f_obs=obs, f_exp=exp)
print(f"χ² = {chi2:.2f}, p-value = {pval:.4f}")

# 5) Plot observed vs. theoretical on one string-category axis, so the merged
#    tail bin is drawn with its real expected probability.
fig, ax = plt.subplots()
ax.bar(labels, obs / num_trials, alpha=0.6, label='Observed')
ax.plot(labels, exp / num_trials, marker='o', linestyle='-', label='Theory')
ax.set_xlabel('Number of Candidates / “True” Count')
ax.set_ylabel('Probability')
ax.set_title('Observed vs. Shifted-Binomial Distribution')
ax.legend()
fig.tight_layout()
plt.show()
