In [52]:
################## Fault-injection simulation at the beginning of the (r-2)-th round of QARMAv2-64 ##################


import numpy as np

# ---- Fault propagation patterns ----
# FAULT_PATTERNS[p] is the set of active (non-zero) nibble indices that
# find_matching_fault() expects in the difference it recovers (via
# partial_decrypt() followed by backward_diff_prop()) when a single-nibble
# fault is injected at plaintext position p.  Every entry lists nine indices.
# Indices refer to the same flat, row-major 4x4 layout used by mix_columns()
# (cell (row, col) lives at index 4 * row + col).
FAULT_PATTERNS = {
    0:  [4,6,7,8,9,11,12,13,14],
    1:  [0,1,3,4,5,6,12,14,15],
    2:  [0,1,2,4,5,7,8,10,11],
    3:  [0,2,3,8,9,10,12,13,15],
    4:  [4,5,7,8,10,11,13,14,15],
    5:  [0,2,3,5,6,7,12,13,15],
    6:  [1,2,3,4,6,7,8,9,11],
    7:  [0,1,3,9,10,11,12,14,15],
    8:  [5,6,7,8,9,10,12,13,15],
    9:  [0,1,2,4,5,7,13,14,15],
    10: [0,1,3,4,5,6,9,10,11],
    11: [1,2,3,8,9,11,12,13,14],
    12: [4,5,6,9,10,11,12,14,15],
    13: [1,2,3,4,6,7,12,13,14],
    14: [0,2,3,5,6,7,8,9,10],
    15: [0,1,2,8,10,11,13,14,15],
}

# ---- Substitution, permutation, and XOR ----

# Forward 4-bit S-box and cell permutation (tau) used by the round functions.
SBOX = [4, 7, 9, 11, 12, 6, 14, 15, 0, 5, 1, 13, 8, 3, 2, 10]
TAU = [0, 11, 6, 13, 10, 1, 12, 7, 5, 14, 3, 8, 15, 4, 9, 2]

# Inverse tables are derived from the forward ones instead of being typed by
# hand, so they cannot drift out of sync: INV[FWD[i]] == i for every i.
SBOX_INV = [SBOX.index(value) for value in range(16)]
TAU_INV = [TAU.index(position) for position in range(16)]

def sub_bytes(state):
    """Return a new list with the forward S-box ``SBOX`` applied to every nibble of *state*."""
    return list(map(SBOX.__getitem__, state))

def sub_bytes_inv(state):
    """Return a new list with the inverse S-box ``SBOX_INV`` applied to every nibble of *state*."""
    return list(map(SBOX_INV.__getitem__, state))

def shuffle_tau(state):
    """Permute the cells of *state* by ``TAU``.

    Output cell ``i`` is taken from input cell ``TAU[i]``; a new list is returned.
    """
    return [state[source] for source in TAU]

def shuffle_tau_inv(state):
    """Permute the cells of *state* by ``TAU_INV`` (undoes :func:`shuffle_tau`).

    Output cell ``i`` is taken from input cell ``TAU_INV[i]``; a new list is returned.
    """
    return [state[source] for source in TAU_INV]

def xor_layer(state, layer):
    """XOR *layer* into *state* cell by cell and return the result as a new list.

    As with ``zip``, cells beyond the shorter of the two inputs are ignored.
    """
    return [cell ^ key_cell for cell, key_cell in zip(state, layer)]

def inject_fault(state, pos, fault_val=0x1):
    """Return a copy of *state* whose cell *pos* is XOR-ed with *fault_val*.

    *state* itself is left unmodified (it is copied via ``.copy()`` first).
    """
    corrupted = state.copy()
    corrupted[pos] = corrupted[pos] ^ fault_val
    return corrupted

def get_active_nibbles(diff):
    """Return, in ascending order, the indices of the non-zero entries of *diff*."""
    nonzero_flags = [value != 0 for value in diff]
    return [position for position, flag in enumerate(nonzero_flags) if flag]

# ---- Bit-level MixColumn using row-major layout ----

def hex_to_binvec(nibble):
    """Return the bits of *nibble* as a list of ints, most-significant bit first.

    The binary form is zero-padded to at least four digits.
    """
    return list(map(int, f"{nibble:04b}"))

def binvec_to_hex(vec):
    """Interpret *vec* (0/1 ints, most-significant bit first) as a binary number and return it."""
    bit_string = "".join(str(bit) for bit in vec)
    return int(bit_string, 2)

def rho_vec(vec):
    """Rotate *vec* left by one position: the first element moves to the end."""
    first, rest = vec[:1], vec[1:]
    return rest + first

def rho_pow(vec, power):
    """Apply the one-step left rotation rho to *vec* ``power % 4`` times.

    On a 4-bit cell rho**4 is the identity, so this computes rho**power; a
    negative *power* is reduced modulo 4 first (e.g. -1 behaves like 3).
    """
    result = vec
    for _ in range(power % 4):
        result = result[1:] + result[:1]  # same one-step rotation as rho_vec
    return result

def xor_bits(a, b):
    """Element-wise XOR of two bit lists, stopping at the shorter one (``zip`` semantics)."""
    return list(map(lambda left, right: left ^ right, a, b))

# Diffusion matrix M = circ(0, rho, rho^2, rho^3) acting on 4-bit cell vectors.
# Entry (i, k) applies rho^((k - i) mod 4); the diagonal is the zero map.
# Written out, the rows are:
#     [0,     rho,   rho^2, rho^3]
#     [rho^3, 0,     rho,   rho^2]
#     [rho^2, rho^3, 0,     rho  ]
#     [rho,   rho^2, rho^3, 0    ]
# Because rho^4 is the identity, M * M = I over GF(2), i.e. M is involutory.
diffusion_matrix = [
    [
        (lambda v: [0] * 4)
        if (k - i) % 4 == 0
        # Bind the exponent as a default so each lambda keeps its own power.
        else (lambda v, _power=(k - i) % 4: rho_pow(v, _power))
        for k in range(4)
    ]
    for i in range(4)
]

def mix_columns(state):
    """Multiply the 4x4 nibble state by ``diffusion_matrix`` column by column.

    *state* is a flat, row-major list (cell (row, col) at ``state[4*row + col]``).
    Output cell (r, c) is the XOR over k of ``diffusion_matrix[r][k]`` applied
    to the bit vector of input cell (k, c). The result is returned as a new
    flat, row-major list of ints.
    """
    # Bit-vector view of every input cell, indexed as cells[row][col].
    cells = [[hex_to_binvec(state[4 * r + c]) for c in range(4)] for r in range(4)]

    out = []
    for r in range(4):
        row_maps = diffusion_matrix[r]
        for c in range(4):
            bits = [0] * 4
            for k in range(4):
                bits = xor_bits(bits, row_maps[k](cells[k][c]))
            out.append(binvec_to_hex(bits))
    return out

# ---- Encryption function ----

def full_encrypt(state, L1, L0):
    """Run the simulated final rounds on *state* and return the resulting ciphertext.

    Three rounds of S^-1 -> M -> tau^-1 -> key XOR are applied, with round
    keys L0, L1 and L0 in that order. They are followed by a last S^-1 layer
    and an XOR with L1. (Presumably these are the backward rounds at the end
    of QARMAv2-64, per the file header; confirm against the specification.)
    """
    for round_key in (L0, L1, L0):
        state = shuffle_tau_inv(mix_columns(sub_bytes_inv(state)))
        state = xor_layer(state, round_key)
    return xor_layer(sub_bytes_inv(state), L1)

def partial_decrypt(state, L1):
    """Undo the last two steps of ``full_encrypt``: XOR with *L1*, then apply ``SBOX``.

    ``SBOX`` inverts the final ``SBOX_INV`` layer, so the result is the state
    just after the last round-key XOR.
    """
    return sub_bytes(xor_layer(state, L1))
    
def backward_diff_prop(state):
    """Propagate a difference vector backwards through tau^-1 and M.

    ``shuffle_tau`` undoes ``shuffle_tau_inv``. ``mix_columns`` undoes itself
    because the diffusion matrix is involutory. Both maps are linear over XOR,
    so they can be applied directly to a difference.
    """
    return mix_columns(shuffle_tau(state))

# ---- DFA Matching Engine ----

def find_matching_fault(target_fault_pos, max_trials=50, *, fault_val=0x1, rng=None):
    """Search random single-nibble faults for one matching a target propagation pattern.

    Each trial proceeds as follows:

    * draw a random plaintext, random round keys ``L1``/``L0`` and a random
      fault position;
    * encrypt the correct and the faulted plaintext with ``full_encrypt``;
    * peel the final layers off both ciphertexts with ``partial_decrypt`` and
      ``backward_diff_prop``;
    * compare the active (non-zero) difference nibbles with
      ``FAULT_PATTERNS[target_fault_pos]``.

    Args:
        target_fault_pos: Key into ``FAULT_PATTERNS`` whose pattern is searched for.
        max_trials: Maximum number of random trials.
        fault_val: XOR mask injected into the faulted nibble. It must be a
            non-zero nibble (1..15); the default is 0x1.
        rng: Source of randomness. ``None`` (the default) uses the global
            ``numpy.random`` state, so ``np.random.seed`` keeps controlling
            the run. A ``numpy.random.Generator`` or ``numpy.random.RandomState``
            may be passed for isolated, reproducible runs.

    Returns:
        A dict with keys ``fault_pos``, ``plaintext``, ``correct``, ``faulty``,
        ``difference`` and ``active`` for the first matching trial, or
        ``None`` if no trial matched.

    Raises:
        KeyError: If ``target_fault_pos`` is not a key of ``FAULT_PATTERNS``.
        ValueError: If ``fault_val`` is not in 1..15.
    """
    if not 1 <= fault_val <= 0xF:
        # A zero mask injects nothing; a wider mask would push the faulty cell
        # outside the 4-bit S-box domain.
        raise ValueError(f"fault_val must be a non-zero nibble (1..15), got {fault_val!r}")

    # Loop-invariant: look the target pattern up (and sort it) once.
    expected = FAULT_PATTERNS[target_fault_pos]
    expected_sorted = sorted(expected)

    if rng is None:
        rng = np.random
    # Generator exposes ``integers``; RandomState and the np.random module expose
    # ``randint``. Both treat the upper bound as exclusive.
    draw = rng.integers if hasattr(rng, "integers") else rng.randint

    def random_nibble():
        # Generator.integers returns numpy scalars; normalise to plain int so
        # the returned data looks the same whichever rng is used.
        return int(draw(0, 16))

    for _ in range(max_trials):
        # Same draw order as before (plaintext, L1, L0, fault position), so a
        # seeded global np.random state reproduces earlier runs.
        plaintext = [random_nibble() for _ in range(16)]
        L1 = [random_nibble() for _ in range(16)]
        L0 = [random_nibble() for _ in range(16)]
        fault_pos = random_nibble()

        correct_ct = full_encrypt(plaintext, L1, L0)
        faulty_plaintext = inject_fault(plaintext, fault_pos, fault_val)
        faulty_ct = full_encrypt(faulty_plaintext, L1, L0)

        # Strip the final S^-1 layer and the L1 whitening from both ciphertexts...
        correct_ct1 = partial_decrypt(correct_ct, L1)
        faulty_ct1 = partial_decrypt(faulty_ct, L1)
        diff1 = [c ^ f for c, f in zip(correct_ct1, faulty_ct1)]

        # ...then undo tau^-1 and M of the last full round (both linear in XOR).
        diff = backward_diff_prop(diff1)

        active = get_active_nibbles(diff)

        if sorted(active) == expected_sorted:
            print("\n✅ Matching pattern found!")
            print(f"→ Fault was injected at position : {fault_pos}")
            print(f"→ Target fault position matched  : {target_fault_pos}")
            print(f"→ Active nibbles after 2 rounds  : {active}")
            print(f"→ Expected pattern               : {expected}")
            print(f"→ Original Ciphertext            : {correct_ct}")
            print(f"→ Faulty Ciphertext              : {faulty_ct}")
            return {
                "fault_pos": fault_pos,
                "plaintext": plaintext,
                "correct": correct_ct,
                "faulty": faulty_ct,
                "difference": diff,
                "active": active,
            }

    print("\n❌ No match found in trials.")
    return None

# ---- Run the simulation ----

if __name__ == "__main__":
    # Plaintext nibble whose fault signature (FAULT_PATTERNS entry) is searched for.
    target_position = 11
    result = find_matching_fault(target_fault_pos=target_position)



✅ Matching pattern found!
→ Fault was injected at position : 11
→ Target fault position matched  : 11
→ Active nibbles after 2 rounds  : [1, 2, 3, 8, 9, 11, 12, 13, 14]
→ Expected pattern               : [1, 2, 3, 8, 9, 11, 12, 13, 14]
→ Original Ciphertext            : [9, 7, 15, 9, 5, 3, 4, 15, 8, 7, 7, 11, 0, 5, 3, 11]
→ Faulty Ciphertext              : [1, 13, 1, 9, 8, 0, 0, 6, 7, 0, 11, 4, 0, 14, 10, 7]
