In [None]:
import numpy as np
from scipy import stats
import math
import os
import matplotlib.pyplot as plt
from itertools import chain


In [None]:
# Capture folder to analyse; uncomment one of the alternatives to switch datasets.
# FOLDER = "ChaCha-100-000-Random-Nonce-STM-SHIELDED"
# FOLDER = "ChaCha-100-000-Random-Nonce-STM"
FOLDER = "ChaCha-100-000-Random-Nonce-STM-2"

In [None]:
# FOLDER = "ChaCha-100-000-Random-Nonce-XMEGA"
# FOLDER = "ChaCha-100-000-Random-Nonce-XMEGA-2"

In [None]:
# Folder name for CPA output artefacts (consumer lies in later cells — TODO confirm).
CPA_OUTPUT_FOLDER = "_CPA_STM_2_S1_COL"

# Load the capture data (dataset info, nonces, key and traces)

In [None]:
# Dataset parameters. Everything except NONCE_LEN_BYTES is filled in by read_info().
NONCE_LEN_BYTES = 12    # bytes per nonce (96-bit nonce)
TRACE_CNT = None        # total number of traces (first line of info.txt)
TRACE_LEN = None        # samples per trace (third line of info.txt)
TRACE_RANDOM_CNT = None # set equal to TRACE_CNT by read_info()
CHUNK_SIZE = None       # traces per chunk_<i> folder (second line of info.txt)
CHUNKS_CNT = None       # number of chunk folders
LAST_CHUNK_SIZE = None  # traces in the final (possibly partial) chunk
def read_info(FOLDER):
    """
    Parse ``<FOLDER>/info.txt`` and publish the dataset geometry as module globals.

    info.txt holds three integers, one per line, in this order: number of traces,
    traces per chunk, samples per trace. From them the number of chunk folders and
    the size of the (possibly shorter) last chunk are derived. All values are printed.
    """
    global TRACE_CNT, TRACE_RANDOM_CNT, TRACE_LEN
    global CHUNK_SIZE, CHUNKS_CNT, LAST_CHUNK_SIZE

    with open(f"{FOLDER}/info.txt", 'r') as info_file:
        TRACE_CNT = int(info_file.readline())
        TRACE_RANDOM_CNT = TRACE_CNT  # mirrors TRACE_CNT
        CHUNK_SIZE = int(info_file.readline())
        TRACE_LEN = int(info_file.readline())

    CHUNKS_CNT = math.ceil(TRACE_CNT / CHUNK_SIZE)
    # Every chunk except the last is full; the last one holds the remainder.
    LAST_CHUNK_SIZE = TRACE_CNT - CHUNK_SIZE * (CHUNKS_CNT - 1)

    print(f"TRACE_CNT = {TRACE_CNT}")   
    print(f"CHUNK_SIZE = {CHUNK_SIZE}")   
    print(f"LAST_CHUNK_SIZE = {LAST_CHUNK_SIZE}")   
    print(f"CHUNKS_CNT = {CHUNKS_CNT}")   
    print(f"TRACE_RANDOM_CNT = {TRACE_RANDOM_CNT}")   
    print(f"TRACE_LEN = {TRACE_LEN}")
    
    
# Populate the dataset-geometry globals from the selected capture folder.
read_info(FOLDER)

In [None]:
# TRACE_CNT = 500_000
# CHUNKS_CNT = 50

In [None]:
NONCES = None
def read_nonces(FOLDER):
    """
    Load the per-trace nonces from ``<FOLDER>/chunk_<i>/nonces_random.bin``.

    Each chunk file holds raw bytes, NONCE_LEN_BYTES per trace. Every chunk except
    the last holds CHUNK_SIZE traces; the last one holds LAST_CHUNK_SIZE. The chunks
    are stacked into the global NONCES (dtype uint8, shape (TRACE_CNT, NONCE_LEN_BYTES)),
    trimmed if more rows than expected were read.

    Missing chunk files are reported and skipped. load_traces_random() skips missing
    chunks the same way, so rows stay aligned only if both files of a chunk are missing.

    Raises:
        FileNotFoundError: if not a single chunk file could be found.
        ValueError: if a chunk file's size does not match its expected shape.
    """
    global NONCES

    nonces_list = []
    for chunk_index in range(CHUNKS_CNT):
        chunk_file = os.path.join(FOLDER, f"chunk_{chunk_index}", "nonces_random.bin")
        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        with open(chunk_file, 'rb') as file:
            byte_array = file.read()

        # Only the last chunk may be shorter than CHUNK_SIZE.
        rows = LAST_CHUNK_SIZE if chunk_index == CHUNKS_CNT - 1 else CHUNK_SIZE
        chunk = np.frombuffer(byte_array, dtype=np.uint8).reshape((rows, NONCE_LEN_BYTES))
        nonces_list.append(chunk)

    if not nonces_list:
        # np.vstack([]) would only report "need at least one array to concatenate".
        raise FileNotFoundError(f"No nonces_random.bin chunk files found in {FOLDER}.")

    # Concatenate all chunks
    NONCES = np.vstack(nonces_list)

    desired_shape = (TRACE_CNT, NONCE_LEN_BYTES)
    if NONCES.shape != desired_shape:
        # Trim the array to the desired shape
        print(f"NONCES.shape != desired_shape - {NONCES.shape} != {desired_shape}")
        NONCES = NONCES[:desired_shape[0], :desired_shape[1]]
        print(f"Trimming NONCES to {NONCES.shape}")

    # Print the first nonce as hex for a quick sanity check.
    hex_representation = ''.join(f'{byte:02x}' for byte in NONCES[0][:NONCE_LEN_BYTES])
    print(f"Nonce[0] = {hex_representation}")
    
    
# Load all nonces into the global NONCES array.
read_nonces(FOLDER)

In [None]:
# Load the correct key as raw bytes; it stands in for key material assumed already recovered.
CORRECT_KEY = np.fromfile(f"{FOLDER}/key.bin", dtype=np.uint8)
hex_key = CORRECT_KEY.tobytes().hex()
print(f"Correct Key: {hex_key}")


In [None]:
TRACES_RANDOM = None

def load_traces_random(folder):
    """
    Load the traces from ``<folder>/chunk_<i>/traces_random.bin`` into TRACES_RANDOM.

    Each chunk file holds native-endian uint16 samples, TRACE_LEN per trace. Every
    chunk except the last holds CHUNK_SIZE traces; the last one holds LAST_CHUNK_SIZE.
    The chunks are stacked into the global TRACES_RANDOM with shape (traces, TRACE_LEN).

    Missing chunk files are reported and skipped, mirroring read_nonces().

    Raises:
        FileNotFoundError: if not a single chunk file could be found.
        ValueError: if a chunk file's size does not match its expected shape.
    """
    global TRACES_RANDOM

    traces_list = []
    for chunk_index in range(CHUNKS_CNT):
        chunk_file = os.path.join(folder, f"chunk_{chunk_index}", "traces_random.bin")
        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        with open(chunk_file, 'rb') as file:
            byte_array = file.read()

        # Only the last chunk may be shorter than CHUNK_SIZE.
        rows = LAST_CHUNK_SIZE if chunk_index == CHUNKS_CNT - 1 else CHUNK_SIZE
        chunk_traces = np.frombuffer(byte_array, dtype=np.uint16).reshape((rows, TRACE_LEN))
        traces_list.append(chunk_traces)

    if not traces_list:
        # np.vstack([]) would only report "need at least one array to concatenate".
        raise FileNotFoundError(f"No traces_random.bin chunk files found in {folder}.")

    # Concatenate all chunks
    TRACES_RANDOM = np.vstack(traces_list)

# Load all traces into the global TRACES_RANDOM array.
load_traces_random(FOLDER)

# Adjust size of traces if needed

In [None]:
# # Crop every trace (row) to sample range 0..1499
# TRACE_CNT = 5000
# TRACES_RANDOM = TRACES_RANDOM[:TRACE_CNT, :]
# # TRACES_RANDOM = TRACES_RANDOM[:, 0:8_000]

# NONCES= NONCES[:TRACE_CNT, :]

In [None]:
# Sanity check: expected (number of traces, TRACE_LEN).
print(f"TRACES shape = {TRACES_RANDOM.shape}")

# Helper functions and constants setup

In [None]:
# Sizes used throughout the attack
CHACHA_KEY_LEN_BYTES = 32   # ChaCha20 key length (256 bits)
VALUES_IN_BYTE = 256        # candidate values per byte: 0..255

# ChaCha20 constant "expand 32-byte k" as 16 raw bytes. Read little-endian, four bytes
# per word, it supplies the first state word of every column.
fixed_sigma = bytearray("expand 32-byte k".encode("utf-8"))
print(f"{fixed_sigma[4]:#x}")

# Hamming-weight lookup table: hamming_weight[v] is the number of set bits of the byte v
# (0..255). It serves as the leakage model when building hypothetical power values.
hamming_weight = [bin(value).count("1") for value in range(256)]

def corr2_coeff(A, B):
    """
    Pearson correlation between every row of A and every row of B.

    Rows are variables and columns are observations, so A of shape (m, n) and
    B of shape (k, n) give an (m, k) matrix whose entry [i, j] is corr(A[i], B[j]).
    Rows with zero variance yield NaN instead of a division-by-zero.
    """
    a_centered = A - A.mean(axis=1, keepdims=True)
    b_centered = B - B.mean(axis=1, keepdims=True)

    # Per-row sum of squared deviations; constant rows become NaN so their
    # correlations come out as NaN rather than inf.
    a_ss = (a_centered ** 2).sum(axis=1)
    b_ss = (b_centered ** 2).sum(axis=1)
    a_ss = np.where(a_ss == 0, np.nan, a_ss)
    b_ss = np.where(b_ss == 0, np.nan, b_ss)

    denominator = np.sqrt(np.outer(a_ss, b_ss))
    return (a_centered @ b_centered.T) / denominator

def leftRotate16(n):
    """Rotate a 32-bit word left by 16 bits, i.e. swap its two 16-bit halves."""
    low_half_up = n << 16
    high_half_down = n >> 16
    return (low_half_up | high_half_down) & 0xFFFFFFFF

# Obtain S[0], S[4], S[8] and S[12] (first-round state, column 0) from the already recovered K[4-15] and K[20-31]
- Optional step, used only when the recovery of all K[4-15] and K[20-31] bytes was not successful.
- The correct key is used here to simulate the already recovered key bytes.

In [None]:
import importlib
import _ChaCha20_source

# Reload the project-local module so edits to it take effect without restarting the kernel
importlib.reload(_ChaCha20_source)

# Bind the (re-loaded) quarter-round primitive used below to compute first-round state words
from _ChaCha20_source import QUARTERROUND

In [None]:
# Reference first-round state words of column 0, keyed by state-word index (0, 4, 8, 12).
# Filled from the correct key by calculate_s1_column_0().
CORRECT_STATE_1_COL_0 = {
    0: None,
    4: None,
    8: None,
    12: None
}

# Column-0 first-round state words recovered by the attack so far (None = nothing recovered).
# Filled byte by byte through insert_found_byte_into_s1_column_0().
FOUND_STATE_1_COL_0 = {
    0: None,
    4: None,
    8: None,
    12: None
}

# Per-trace first-round state of columns 1-3, shape (TRACE_CNT, 3, 4), dtype uint32.
# Computed from the key bytes and each trace's nonce by calculate_s1_columns_1_to_3().
S1_COLUMNS = None

In [None]:
def calculate_s1_column_0():
    """
    Compute the reference first-round state of column 0 from the correct key.

    Column 0 of the initial ChaCha20 state holds sigma[0:4], key[0:4], key[16:20]
    and the block counter word (0 here), each packed little-endian into a 32-bit
    word. QUARTERROUND is applied to the column once, and the resulting words are
    stored in CORRECT_STATE_1_COL_0 under the state-word indices 0, 4, 8 and 12.
    """
    global CORRECT_STATE_1_COL_0

    def _le_word(buf, offset):
        # Pack 4 bytes little-endian into a Python int. Going through bytes avoids
        # NumPy 2 promotion, where np.uint8(x) << 8 stays uint8 and drops the bits.
        return int.from_bytes(bytes(buf[offset:offset + 4]), "little")

    key = np.asarray(CORRECT_KEY, dtype=np.uint8)
    col_0 = np.empty(4, dtype=np.uint32)  # one column: four 32-bit words

    # Column 0: Fixed Sigma [0:3] Key[0:3] Key[16:19] Counter
    col_0[0] = _le_word(fixed_sigma, 0)
    col_0[1] = _le_word(key, 0)
    col_0[2] = _le_word(key, 16)
    col_0[3] = 0x00_00_00_00

    # QUARTERROUND is expected to update col_0 in place (its return value is not used).
    QUARTERROUND(col_0, 0, 1, 2, 3)

    for slot, word_index in enumerate((0, 4, 8, 12)):
        CORRECT_STATE_1_COL_0[word_index] = col_0[slot]

In [None]:
# Fill CORRECT_STATE_1_COL_0 with the reference column-0 first-round words.
calculate_s1_column_0()

In [None]:
# Print the reference column-0 state words. Test for None explicitly: a word equal
# to 0 is a valid value and must not be reported as missing.
print("Correct State 1 Column 0:")
for i in (0, 4, 8, 12):
    word = CORRECT_STATE_1_COL_0[i]
    if word is not None:
        print(f"[{i:02}] {hex(word)}")
    else:
        print(f"[{i:02}] None")

In [None]:
def insert_found_byte_into_s1_column_0(byte, word_index, byte_index):
    """
    Write one recovered byte into its slot of FOUND_STATE_1_COL_0.

    The byte replaces bits [8*byte_index, 8*byte_index + 8) of the 32-bit word stored
    under ``word_index``; the other bytes are kept. A word that is still None starts
    from 0. Stored words are plain Python ints.

    Parameters:
        byte (int): The byte value to be placed (0-255); NumPy integers are accepted.
        word_index (int): State-word index in column 0 (0, 4, 8 or 12).
        byte_index (int): The index of the byte within the 32-bit word (0-3, 0 = least significant).

    Raises:
        ValueError: if any argument is out of range.
        TypeError: if byte or byte_index is not an integer.
    """
    global FOUND_STATE_1_COL_0
    import operator  # local import keeps this cell self-contained

    # Ensure the byte is within the valid range
    if not (0 <= byte <= 255):
        raise ValueError("Byte must be in the range 0-255.")
    
    # Ensure the word_index is within the valid range
    if word_index not in (0, 4, 8, 12):
        raise ValueError("Word index must be 0, 4, 8, or 12")
    
    # Ensure the byte_index is within the valid range
    if not (0 <= byte_index <= 3):
        raise ValueError("Byte index must be in the range 0-3.")

    # Work with Python ints. Under NumPy 2 promotion rules, np.uint8(x) << 24 stays uint8
    # (silently 0), and a np.uint32 word cannot be combined with the negative clear-mask.
    byte = operator.index(byte)
    shift = 8 * operator.index(byte_index)

    word = FOUND_STATE_1_COL_0[word_index]
    word = 0 if word is None else int(word)

    # Clear the target byte, then insert the new one.
    FOUND_STATE_1_COL_0[word_index] = (word & ~(0xFF << shift)) | (byte << shift)

In [None]:
# Print the current FOUND_STATE_1_COL_0 words. Test for None explicitly so that a
# recovered word equal to 0 is not reported as missing.
print("Initial FOUND_STATE_1_COL_0:")
for i in (0, 4, 8, 12):
    word = FOUND_STATE_1_COL_0[i]
    if word is not None:
        print(f"[{i}] {hex(word)}")
    else:
        print(f"[{i}] None")

In [None]:
def calculate_s1_columns_1_to_3(key_bytes_array):
    """
    Compute, for every trace, the first-round state of ChaCha20 columns 1-3.

    Column c (c = 1, 2, 3) of the initial state holds sigma[4c:4c+4],
    key[4c:4c+4], key[16+4c:20+4c] and nonce[4(c-1):4c], each packed little-endian.
    QUARTERROUND is applied to each column, which is expected to update it in place.
    The result is stored in the global S1_COLUMNS with shape (TRACE_CNT, 3, 4) and
    dtype uint32, where S1_COLUMNS[t, c - 1, r] is state word 4*r + c of trace t.

    Parameters:
        key_bytes_array (np.ndarray): Array of key bytes (shape: [32]).

    Raises:
        IndexError: if NONCES holds fewer than TRACE_CNT rows.
    """
    global S1_COLUMNS

    def _le_words(buf):
        # Split a byte buffer into little-endian 32-bit words (Python ints). Going through
        # bytes avoids NumPy 2 promotion, where np.uint8(x) << 8 stays uint8 and loses bits.
        raw = bytes(buf)
        return [int.from_bytes(raw[i:i + 4], "little") for i in range(0, len(raw), 4)]

    # Trace-independent words are packed once instead of once per trace.
    sigma_words = _le_words(fixed_sigma)                                     # 4 words
    key_words = _le_words(np.asarray(key_bytes_array, dtype=np.uint8)[:32])  # 8 words

    nonces = np.ascontiguousarray(NONCES[:TRACE_CNT, :12], dtype=np.uint8)
    if nonces.shape[0] < TRACE_CNT:
        raise IndexError(f"NONCES has only {nonces.shape[0]} rows but TRACE_CNT is {TRACE_CNT}.")
    # (TRACE_CNT, 3): nonce bytes 0-3, 4-7, 8-11 as little-endian words.
    nonce_words = nonces.view("<u4").astype(np.uint32)

    S1_COLUMNS = np.empty((TRACE_CNT, 3, 4), dtype=np.uint32)  # 3 columns, 4 words each (32-bit)
    for col in range(3):
        state_col = col + 1  # ChaCha state column 1..3
        S1_COLUMNS[:, col, 0] = sigma_words[state_col]
        S1_COLUMNS[:, col, 1] = key_words[state_col]
        S1_COLUMNS[:, col, 2] = key_words[state_col + 4]
        S1_COLUMNS[:, col, 3] = nonce_words[:, col]

    # S1_COLUMNS[t][col] is a view, so the in-place quarter round writes straight into S1_COLUMNS.
    for trace_idx in range(TRACE_CNT):
        for col in range(3):
            QUARTERROUND(S1_COLUMNS[trace_idx][col], 0, 1, 2, 3)

    print(f"S1_COLUMNS calculated with shape: {S1_COLUMNS.shape}")

In [None]:
# Simulate already-recovered key bytes with the correct key to get S1 columns 1-3 for every trace.
calculate_s1_columns_1_to_3(CORRECT_KEY)

In [None]:
# # Assuming S1_COLUMNS is a NumPy array with shape (TRACE_CNT, 3, 4)
# for trace_idx in range(S1_COLUMNS.shape[0]):  # Iterate over traces
#     print(f"Trace {trace_idx}:")
#     for col_idx in range(S1_COLUMNS.shape[1]):  # Iterate over columns
#         column = S1_COLUMNS[trace_idx][col_idx]
#         hex_values = [hex(value) for value in column]  # Convert each value to hex
#         print(f"  Column {col_idx}: {hex_values}")

In [None]:
def calculate_row_of_intermediate_values_for_s1_0_4_12(s1_columns,
                                                       s1_index,
                                                       byte_index,
                                                       obtained_S1_col_0=None):
    """
    Hypothetical leakage (Hamming weight) of one byte of ``Sz ^ (Sx + Sy)`` for all
    256 candidates of byte ``byte_index`` of the target S1 word (0, 4 or 12).

    The addition is modulo 2**32 (the ``d ^= a + b`` step of the diagonal quarter
    rounds on state words (0,5,10,15), (3,4,9,14) and (1,6,11,12)):

        s1_index 0:  Sx = S1[0] (target), Sy = S1[5],          Sz = S1[15]
        s1_index 4:  Sx = S1[3],          Sy = S1[4] (target), Sz = S1[14]
        s1_index 12: Sx = S1[1],          Sy = S1[6],          Sz = S1[12] (target)

    S1_8 uses a different leakage function (calculate_row_of_intermediate_values_for_s1_8).

    Parameters:
        s1_columns (np.ndarray): The S1 columns 1-3 of one trace (shape: [3, 4]);
            s1_columns[c - 1][r] is state word 4*r + c.
        s1_index (int): The index of the S1 word (0, 4, or 12).
        byte_index (int): The byte index within the word (0-3, 0 = least significant).
        obtained_S1_col_0 (dict): Already known column-0 S1 words keyed 0, 4, 8, 12
            (defaults to all zeros). For s1_index 0/4 the target's bytes below
            ``byte_index`` feed the carry and must be known; its other bytes are ignored.

    Returns:
        np.ndarray: Intermediate values for all possible S1 byte values (shape: [256], dtype uint16).

    Raises:
        ValueError: on an invalid s1_index/byte_index, or when the target word is
            None but its lower bytes are needed for the carry.
    """
    if obtained_S1_col_0 is None:
        obtained_S1_col_0 = {0: 0, 4: 0, 8: 0, 12: 0}  # 4 words - 32 bits each

    # Operands of Sz ^ (Sx + Sy); the target word is taken from obtained_S1_col_0.
    if s1_index == 0:
        x_word, y_word, z_word = obtained_S1_col_0[0], s1_columns[0][1], s1_columns[2][3]
    elif s1_index == 4:
        x_word, y_word, z_word = s1_columns[2][0], obtained_S1_col_0[4], s1_columns[1][3]
    elif s1_index == 12:
        x_word, y_word, z_word = s1_columns[0][0], s1_columns[1][1], obtained_S1_col_0[12]
    else:
        raise ValueError("s1_index must be one of 0, 4, or 12.")

    if not 0 <= byte_index <= 3:
        raise ValueError("Byte index must be in the range 0-3.")

    # The target's own byte is replaced by the candidate, so an unknown (None) target
    # is fine unless its lower bytes feed the carry (targets S1[0] and S1[4]).
    if byte_index > 0 and s1_index in (0, 4) and obtained_S1_col_0[s1_index] is None:
        raise ValueError(
            f"S1[{s1_index}] bytes below {byte_index} must be known to compute the carry.")
    x_word, y_word, z_word = (0 if w is None else int(w) for w in (x_word, y_word, z_word))

    shift = 8 * byte_index
    low_mask = (1 << shift) - 1
    # Carry into this byte from the full lower part of Sx + Sy; looking only at the
    # previous byte pair would miss carries propagated from further below.
    carry = ((x_word & low_mask) + (y_word & low_mask)) >> shift

    x_byte = (x_word >> shift) & 0xFF
    y_byte = (y_word >> shift) & 0xFF
    z_byte = (z_word >> shift) & 0xFF

    candidates = np.arange(256, dtype=np.int64)
    if s1_index == 0:
        possible_values = z_byte ^ ((candidates + y_byte + carry) & 0xFF)
    elif s1_index == 4:
        possible_values = z_byte ^ ((x_byte + candidates + carry) & 0xFF)
    else:
        possible_values = candidates ^ ((x_byte + y_byte + carry) & 0xFF)

    # Convert to hamming weights (or other leakage model)
    hw_table = np.asarray(hamming_weight, dtype=np.uint16)
    return hw_table[possible_values]

In [None]:
def calculate_rotated_previous_intermediate_values_from_s1():
    """
    Per-trace value ((S1[2] + S1[7]) XOR S1[13]) <<< 16, split into 4 bytes.

    This is the rotated ``d`` of the diagonal quarter round on state words 2, 7, 8, 13.
    calculate_row_of_intermediate_values_for_s1_8 needs it to model S1[8] + d,
    including the carry bits. The addition is modulo 2**32. Reads the global
    S1_COLUMNS (word 4*r + c at [t, c - 1, r]) for the first TRACE_CNT traces.

    Returns:
        np.ndarray: shape (TRACE_CNT, 4), dtype uint8; little-endian bytes
        (column 0 = least significant) of the rotated value for every trace.
    """
    columns = np.asarray(S1_COLUMNS[:TRACE_CNT], dtype=np.uint32)

    # Vectorised uint32 arithmetic wraps modulo 2**32 silently, unlike the per-element
    # NumPy-scalar additions, which emitted an overflow RuntimeWarning on every wrap.
    s1_2 = columns[:, 1, 0]
    s1_7 = columns[:, 2, 1]
    s1_13 = columns[:, 0, 3]
    mixed = s1_13 ^ (s1_2 + s1_7)
    rotated = (mixed << np.uint32(16)) | (mixed >> np.uint32(16))

    # Split into 4 x uint8, least significant byte first.
    intermediate_values = np.stack(
        [(rotated >> np.uint32(shift)) & np.uint32(0xFF) for shift in (0, 8, 16, 24)],
        axis=1,
    ).astype(np.uint8)

    #Print first intermediate value in HEX after conversion (4 bytes)
    print(f"Intermediate value[0] = {hex(intermediate_values[0][0])} {hex(intermediate_values[0][1])} {hex(intermediate_values[0][2])} {hex(intermediate_values[0][3])}")
    return intermediate_values

In [None]:
def calculate_row_of_intermediate_values_for_s1_8(s1_columns, 
                                                  byte_index, 
                                                  prev_intermediate_values, 
                                                  obtained_S1_col_0 = None,
                                                  include_XOR = True):
    """
    Hypothetical leakage (Hamming weight) of one byte of ``S1[7] ^ (S1[8] + d)``, or of
    ``S1[8] + d`` alone when include_XOR is False, for all 256 candidates of byte
    ``byte_index`` of S1[8]. Here d = ((S1[2] + S1[7]) ^ S1[13]) <<< 16 and the
    addition is modulo 2**32.

    Parameters:
        s1_columns (np.ndarray): The S1 columns 1-3 of one trace (shape: [3, 4]); S1[7] is s1_columns[2][1].
        byte_index (int): Byte of S1[8] being attacked (0-3, 0 = least significant).
        prev_intermediate_values (np.ndarray): The 4 little-endian bytes of d for this trace.
        obtained_S1_col_0 (dict): Already known column-0 S1 words keyed 0, 4, 8, 12. Only
            [8] is read, and only its bytes below ``byte_index`` (for the carry); it may be
            None when byte_index is 0.
        include_XOR (bool): If True, include the XOR with S1[7] in the model.
    Returns:
        np.ndarray: Intermediate values for all possible S1 byte values (shape: [256], dtype uint16).
    Raises:
        ValueError: if byte_index is outside 0-3, or S1[8] is unknown while byte_index > 0.
    """
    if not 0 <= byte_index <= 3:
        raise ValueError("Byte index must be in the range 0-3.")

    s1_8_word = None if obtained_S1_col_0 is None else obtained_S1_col_0[8]
    shift = 8 * byte_index

    carry = 0
    if byte_index != 0:
        if s1_8_word is None:
            raise ValueError("s1_8_word must be provided for carry calculation.")
        low_mask = (1 << shift) - 1
        # Rebuild the lower bytes of d and add them to the lower bytes of S1[8]. This
        # captures carries propagated through every lower byte, not only the previous one.
        d_low = sum(int(prev_intermediate_values[i]) << (8 * i) for i in range(byte_index))
        carry = ((int(s1_8_word) & low_mask) + d_low) >> shift

    # Python ints avoid uint8 wrap-around warnings from NumPy scalar arithmetic.
    d_byte = int(prev_intermediate_values[byte_index])
    candidates = np.arange(256, dtype=np.int64)
    possible_values = (candidates + d_byte + carry) & 0xFF
    if include_XOR:
        s1_7_byte = (int(s1_columns[2][1]) >> shift) & 0xFF
        possible_values = s1_7_byte ^ possible_values

    # Convert to hamming weights (or other leakage model)
    hw_table = np.asarray(hamming_weight, dtype=np.uint16)
    return hw_table[possible_values]

In [None]:
def calculate_correlation_matrix_s1(s1_word_index,
                                    byte_index, 
                                    num_traces, 
                                    include_XOR,
                                    trace_shift=0, 
                                    obtained_s1_col_0 = CORRECT_STATE_1_COL_0,
                                    previous_intermediate_values = None):
    """
    CPA correlation matrix for one byte of a column-0 S1 word.

    Builds a (num_traces, 256) leakage matrix, with one Hamming-weight row per trace
    and one column per candidate byte value. It then correlates that matrix with the
    traces TRACES_RANDOM[trace_shift : trace_shift + num_traces].

    Parameters:
        s1_word_index (int): Index of the S1 word (0, 4, 8, 12).
        byte_index (int): Index of the S1 byte (0-3).
        num_traces (int): Number of traces to use; clipped to the traces available after trace_shift.
        include_XOR (bool): If True, include the XOR operation in the S1_8 model.
        trace_shift (int): Index of the first trace to use. Default is 0.
        obtained_s1_col_0 (dict): Already known column-0 S1 words keyed 0, 4, 8, 12.
        previous_intermediate_values (np.ndarray): Per-trace 4-byte values (shape: [traces, 4])
            as produced by calculate_rotated_previous_intermediate_values_from_s1(); required for word 8.

    Returns:
        np.ndarray: The correlation matrix, shape (256, TRACE_LEN): row = candidate byte value,
        column = trace sample.

    Raises:
        ValueError: on an invalid word/byte index, an invalid trace window, or missing
            previous intermediate values for word 8.
    """
    # Validate everything up front, before any per-trace work.
    if s1_word_index not in (0, 4, 8, 12):
        raise ValueError("S1 word index must be one of 0, 4, 8, or 12.")
    if byte_index not in range(0, 4):
        raise ValueError("Byte index must be in the range 0-3.")
    if s1_word_index == 8 and previous_intermediate_values is None:
        raise ValueError("Previous intermediate values are required for S1_8 word.")

    max_traces = TRACES_RANDOM.shape[0]
    if trace_shift < 0:
        raise ValueError(f"trace_shift ({trace_shift}) must be non-negative.")
    if trace_shift >= max_traces:
        raise ValueError(f"trace_shift ({trace_shift}) exceeds the number of available traces ({max_traces}).")

    num_traces = min(num_traces, max_traces - trace_shift)
    if num_traces <= 0:
        raise ValueError(f"num_traces must be positive (got {num_traces}).")

    # Slice the traces and S1 columns to the requested window
    window = slice(trace_shift, trace_shift + num_traces)
    traces = TRACES_RANDOM[window, :]
    s1_columns = S1_COLUMNS[window, :]

    # Hypothetical leakage: one row per trace, one column per candidate value
    L_matrix = np.empty((num_traces, 256), dtype=np.uint16)
    for trace_idx in range(num_traces):
        if s1_word_index == 8:
            L_matrix[trace_idx, :] = calculate_row_of_intermediate_values_for_s1_8(
                s1_columns[trace_idx],
                byte_index,
                previous_intermediate_values[trace_idx + trace_shift],
                obtained_s1_col_0,
                include_XOR
            )
        else:
            L_matrix[trace_idx, :] = calculate_row_of_intermediate_values_for_s1_0_4_12(
                s1_columns[trace_idx],
                s1_word_index,
                byte_index,
                obtained_s1_col_0
            )

    # Transpose so rows are variables (candidates / samples) and columns are traces.
    return corr2_coeff(L_matrix.T, traces.T)

In [None]:
def calculate_correlation_matrix_for_s1_words(
    s1_word_indices, 
    byte_indices, 
    num_traces,
    include_XOR,
    trace_shift=0, 
    obtained_s1_col_0=CORRECT_STATE_1_COL_0, 
):
    """
    Run calculate_correlation_matrix_s1 for every (word, byte) pair and store each
    result in CORRELATION_MATRIXES_S1[(word, byte)].

    Words 0, 4 and 12 need only the S1 state. Word 8 additionally uses the global
    PREVIOUS_INTERMEDIATE_VALUES_S1, which must already be computed. Any other word
    index is reported and skipped.

    Parameters:
        s1_word_indices (list): S1 word indices to process (e.g. [0, 4, 8, 12]).
        byte_indices (list): Byte indices (0-3) to process for every word.
        num_traces (int): Number of traces to use for the calculation.
        include_XOR (bool): If True, include the XOR operation in the S1_8 model.
        trace_shift (int): Index of the first trace to use. Default is 0.
        obtained_s1_col_0 (dict): Already known column-0 S1 words keyed 0, 4, 8, 12.

    Raises:
        ValueError: if word 8 is requested while PREVIOUS_INTERMEDIATE_VALUES_S1 is None.
    """
    global CORRELATION_MATRIXES_S1, PREVIOUS_INTERMEDIATE_VALUES_S1

    for word in s1_word_indices:
        for byte in byte_indices:
            print(f"Calculating correlation matrix for S1 word {word}, byte {byte}...")

            if word not in (0, 4, 8, 12):
                print(f"Skipping unsupported S1 word {word}.")
                continue

            # Only word 8 gets the previous intermediate values; the others use the default.
            extra_kwargs = {}
            if word == 8:
                if PREVIOUS_INTERMEDIATE_VALUES_S1 is None:
                    raise ValueError("Previous intermediate values are required for S1 word 8.")
                extra_kwargs["previous_intermediate_values"] = PREVIOUS_INTERMEDIATE_VALUES_S1

            CORRELATION_MATRIXES_S1[(word, byte)] = calculate_correlation_matrix_s1(
                s1_word_index=word,
                byte_index=byte,
                num_traces=num_traces,
                include_XOR=include_XOR,
                trace_shift=trace_shift,
                obtained_s1_col_0=obtained_s1_col_0,
                **extra_kwargs,
            )

    print("Correlation matrices calculation complete.")

# Correlation matrix processing (plotting and (sub)key retrieval)

In [None]:
# Number of traces used per correlation matrix; adjust to the available data.
TRACES_USED = 5000

In [None]:
# Correlation matrices keyed by (S1 word index, byte index) for words 0, 4, 8, 12 and bytes 0-3;
# each value has shape (256 candidates, samples per trace).
CORRELATION_MATRIXES_S1 = {}

In [None]:
# Per-trace ((S1[2] + S1[7]) ^ S1[13]) <<< 16 as 4 little-endian bytes, needed for the S1_8 attack.
# It depends on key bytes of columns 1-3, simulated here with the correct key through S1_COLUMNS.
PREVIOUS_INTERMEDIATE_VALUES_S1 = calculate_rotated_previous_intermediate_values_from_s1()

In [None]:
# Split every reference column-0 S1 word into its 4 bytes (byte 0 = least significant),
# keyed by (word index, byte index) in word-major order.
CORRECT_STATE_1_COL_0_BYTES = {
    (word_index, byte_index): (CORRECT_STATE_1_COL_0[word_index] >> (8 * byte_index)) & 0xFF
    for word_index in (0, 4, 8, 12)
    for byte_index in range(4)
}

In [None]:
# Show the reference S1 bytes; process_and_plot_correlation_matrix_s1 uses them by default as ground truth.
for (word_index, byte_index), byte_value in CORRECT_STATE_1_COL_0_BYTES.items():
    print(f"S1[{word_index}, {byte_index}] = {hex(byte_value)}")

In [None]:
def _plot_s1_candidate_traces(ax, correlation_matrix, correct_byte, guessed_byte, num_candidates):
    """
    Draw the correlation trace of every S1 byte candidate on `ax`.

    Other candidates are faint black dots. The guessed byte is red and the correct
    byte is green; both carry legend labels. The y-axis is fixed to [-1, 1].

    Parameters:
        ax (matplotlib.axes.Axes): Axes to draw on.
        correlation_matrix (np.ndarray): Correlation matrix whose row index is the
            candidate byte value and whose columns are trace samples.
        correct_byte (int): The correct S1 byte value.
        guessed_byte (int): The byte value selected by the attack.
        num_candidates (int): Number of candidate rows to draw.
    """
    for candidate in range(num_candidates):
        if candidate != correct_byte and candidate != guessed_byte:
            ax.plot(correlation_matrix[candidate], 'k.', alpha=0.3)

    # Draw the highlighted rows last so they sit on top of the background candidates.
    ax.plot(correlation_matrix[guessed_byte], 'r.', label=f"Guessed byte = {hex(guessed_byte)}")
    ax.plot(correlation_matrix[correct_byte], 'g.', label=f"Correct byte = {hex(correct_byte)}")

    ax.set_ylim(-1, 1)
    ax.set_xlabel("Trace Sample")
    ax.set_ylabel("Correlation")


def _save_show_close_s1_figure(fig, show_plots, save_plots, output_folder, file_name):
    """
    Optionally save `fig`, optionally show it, then close it.

    Saving happens BEFORE showing: on non-inline backends `plt.show()` can leave
    the figure empty, which would produce a blank image file.
    """
    if save_plots:
        os.makedirs(output_folder, exist_ok=True)
        plot_path = os.path.join(output_folder, file_name)
        fig.savefig(plot_path, bbox_inches='tight', dpi=600)
        print(f"Plot saved to {plot_path}")

    if show_plots:
        plt.show()

    # Free the figure's memory; this function runs once per S1 byte.
    plt.close(fig)


def _report_s1_candidate_ranking(correlation_matrix_cutout, s1_word_index, byte_index, num_candidates, output_folder):
    """
    Rank candidates by their minimum correlation inside the search range (ascending).

    Prints the top 10 candidates and writes the full ranking to
    `<output_folder>/correlation_results_s1_<word>_<byte>.txt`.
    """
    # Ascending by minimum correlation; ties are broken by candidate value (tuple order).
    min_correlations = sorted(
        (np.min(correlation_matrix_cutout[candidate]), candidate)
        for candidate in range(num_candidates)
    )

    print(f"Top 10 sorted candidates for S1 byte {byte_index}:")
    for min_corr, candidate in min_correlations[:10]:
        print(f"Candidate: {hex(candidate)}, Correlation: {min_corr}")

    os.makedirs(output_folder, exist_ok=True)
    output_file = os.path.join(output_folder, f"correlation_results_s1_{s1_word_index}_{byte_index}.txt")
    with open(output_file, "w") as file:
        for min_corr, candidate in min_correlations:
            file.write(f"Candidate: {hex(candidate)}, Correlation: {min_corr}\n")


def process_and_plot_correlation_matrix_s1(
    s1_word_index,
    byte_index, 
    show_plots=True, 
    save_plots=False, 
    zoom_plot=False,
    list_correlations=False,
    correct_s1_bytes=CORRECT_STATE_1_COL_0_BYTES,
    output_folder="plots"
):
    """
    Evaluate the CPA correlation matrix of one S1 column-0 byte and record the guess.

    The guess is the candidate (row) holding the most negative correlation inside the
    sample window BYTE_RANGES_S1[(s1_word_index, byte_index)]. It is passed to
    `insert_found_byte_into_s1_column_0`. A LaTeX table row comparing the guess with
    the correct byte is printed.

    Parameters:
        s1_word_index (int): The S1 word index being processed (0, 4, 8, or 12).
        byte_index (int): The byte index within that word (0-3).
        show_plots (bool): Whether to display the plot(s).
        save_plots (bool): Whether to save the plot(s) as PNG files in `output_folder`.
        zoom_plot (bool): Whether to also produce a plot zoomed to the search window
            BYTE_RANGES_S1[(s1_word_index, byte_index)]. Only has an effect when
            `show_plots` or `save_plots` is True.
        list_correlations (bool): Whether to print the top-10 candidates and write the
            full candidate ranking to a text file in `output_folder`.
        correct_s1_bytes (dict): Mapping (s1_word_index, byte_index) -> correct byte
            value (as built by CORRECT_STATE_1_COL_0_BYTES).
        output_folder (str): Folder for saved plots and ranking files. It must be a path
            whenever `save_plots` or `list_correlations` is True.

    Raises:
        ValueError: If `byte_index` is not in 0-3 or `s1_word_index` is not 0, 4, 8, or 12.
    """
    if byte_index < 0 or byte_index > 3:
        raise ValueError("Byte index must be in the range 0-3.")
    
    if s1_word_index not in [0, 4, 8, 12]:
        raise ValueError("S1 word index must be one of 0, 4, 8, or 12.")

    s1_key = (s1_word_index, byte_index)
    correct_s1_byte = correct_s1_bytes[s1_key]
    correlation_matrix = CORRELATION_MATRIXES_S1[s1_key]
    byte_range = BYTE_RANGES_S1[s1_key]
    start_col, end_col = byte_range

    # Restrict the search to the sample window where this byte is expected to leak.
    correlation_matrix_cutout = correlation_matrix[:, start_col:end_col]

    # The guessed byte is the row holding the global minimum of the window.
    corr_min = np.unravel_index(correlation_matrix_cutout.argmin(), correlation_matrix_cutout.shape)
    obtained_s1_min = corr_min[0]
    min_value = correlation_matrix_cutout[corr_min]

    insert_found_byte_into_s1_column_0(obtained_s1_min, s1_word_index, byte_index)
    # Strongest (most negative) correlation reached by the correct byte, for comparison.
    correct_s1_correlation = np.min(correlation_matrix_cutout[correct_s1_byte])

    output = f"$S_1[{s1_word_index}][{byte_index}]$ & {byte_range} & \\texttt{{{format(obtained_s1_min, '#04X').replace('0X', '0x')}}} ({min_value:.5f}) & \\texttt{{{format(correct_s1_byte, '#04X').replace('0X', '0x')}}} ({correct_s1_correlation:.5f})"
    print(output)

    if list_correlations:
        _report_s1_candidate_ranking(
            correlation_matrix_cutout, s1_word_index, byte_index, VALUES_IN_BYTE, output_folder
        )

    if not show_plots and not save_plots:
        return

    # Full-trace plot with the search window marked.
    fig, ax = plt.subplots(figsize=(12, 6))
    _plot_s1_candidate_traces(ax, correlation_matrix, correct_s1_byte, obtained_s1_min, VALUES_IN_BYTE)
    ax.axvline(x=start_col, color='b', linestyle='--', label='Search Range')
    ax.axvline(x=end_col, color='b', linestyle='--')
    # Build the legend after axvline so the 'Search Range' entry is included.
    ax.legend(loc="upper right")
    _save_show_close_s1_figure(
        fig, show_plots, save_plots, output_folder,
        f"s1_{s1_word_index:02}_byte_{byte_index:02}.png",
    )

    if zoom_plot:
        # Same traces, with the x-axis limited to the search window.
        fig, ax = plt.subplots(figsize=(12, 6))
        _plot_s1_candidate_traces(ax, correlation_matrix, correct_s1_byte, obtained_s1_min, VALUES_IN_BYTE)
        ax.legend(loc="upper right")
        ax.set_xlim(start_col, end_col)
        _save_show_close_s1_figure(
            fig, show_plots, save_plots, output_folder,
            f"s1_{s1_word_index:02}_byte_{byte_index:02}_zoomed.png",
        )

## Execution of the analysis

In [None]:
# The XOR term is enabled only for capture folders whose name contains "STM".
INCLUDE_XOR = FOLDER.find("STM") >= 0
print("INCLUDE_XOR = " + str(INCLUDE_XOR))

In [None]:
# Run the CPA for every S1 column-0 pair: words 0, 4, 8, 12 x bytes 0-3.
# Results are stored in CORRELATION_MATRIXES_S1, keyed by (word, byte).
# To re-run a single pair, narrow s1_word_indices / byte_indices (e.g. [8] and [1]).
calculate_correlation_matrix_for_s1_words(
    obtained_s1_col_0=CORRECT_STATE_1_COL_0,
    s1_word_indices=[0, 4, 8, 12],
    byte_indices=list(range(4)),
    num_traces=TRACES_USED,
    include_XOR=INCLUDE_XOR,
    trace_shift=0,
)

In [None]:
# Sample window (start, end) searched for each (S1 word, byte) pair of column 0.
# For the STM captures no narrowing is applied: every pair (words 0, 4, 8, 12;
# bytes 0-3) searches the whole trace, 0 .. TRACE_LEN.
BYTE_RANGES_S1_STM = {
    (s1_word, s1_byte): (0, TRACE_LEN)
    for s1_word in (0, 4, 8, 12)
    for s1_byte in range(4)
}

# Active ranges used by process_and_plot_correlation_matrix_s1.
BYTE_RANGES_S1 = BYTE_RANGES_S1_STM

In [None]:
# # Define the byte ranges for each s1 word and byte index within the S_1 column 0.
# # The ranges are defined based on the analysis of the correlation matrices.
# # The ranges are adjusted to focus on the most relevant samples for each key byte.
# # The ranges are defined as (start, end) tuples, where start and end are the sample indices.
# BYTE_RANGES_S1_XMEGA = {
#   (0,0):  (11500, 13000), # S1_0
#   (0,1):  (11500, 13000),
#   (0,2):  (11500, 13000),
#   (0,3):  (11500, 13000),
#   (4,0):  (18500, 20500), # S1_4
#   (4,1):  (18500, 20500),
#   (4,2):  (18500, 20500),
#   (4,3):  (18500, 20500),
#   (8,0):  (16500, 17500), # S1_8
#   (8,1):  (16500, 17500),
#   (8,2):  (16500, 17500),
#   (8,3):  (16500, 17500),
#   (12,0): (14000, 15500), # S1_12
#   (12,1): (14000, 15500),
#   (12,2): (14500, 15500),
#   (12,3): (14000, 15500),
# }

# BYTE_RANGES_S1 = BYTE_RANGES_S1_XMEGA

In [None]:
# Summary pass: recover every S1 column-0 byte and print its table row only
# (no plots, no saved files, no candidate listings).
for word_index, byte_index in ((w, b) for w in (0, 4, 8, 12) for b in range(4)):
    process_and_plot_correlation_matrix_s1(
        s1_word_index=word_index,
        byte_index=byte_index,
        correct_s1_bytes=CORRECT_STATE_1_COL_0_BYTES,
        list_correlations=False,
        zoom_plot=False,
        save_plots=False,
        show_plots=False,
        output_folder=None,
    )

In [None]:
print(("S1_8_XOR_INCLUDED_" + str(not INCLUDE_XOR)).upper())  # experiment sub-folder name (XOR setting inverted)

In [None]:
# Full-output pass: for every S1 column-0 byte, save the full and zoomed plots and
# write the candidate ranking into CPA_OUTPUT_FOLDER (nothing is displayed inline).
for word_index, byte_index in ((w, b) for w in (0, 4, 8, 12) for b in range(4)):
    process_and_plot_correlation_matrix_s1(
        s1_word_index=word_index,
        byte_index=byte_index,
        correct_s1_bytes=CORRECT_STATE_1_COL_0_BYTES,
        output_folder=CPA_OUTPUT_FOLDER,
        list_correlations=True,
        zoom_plot=True,
        save_plots=True,
        show_plots=False,
    )

## Success rate (correct bytes and bits)

In [None]:
# Score the recovered S1 column-0 state against the reference state, both per byte
# (exact match) and per bit (number of agreeing bits within each byte).
# Note: despite its name, S1_BYTE_RANGES holds the attacked S1 *word* indices.
S1_BYTE_RANGES = [0, 4, 8, 12]

correct_s1_bytes = 0
correct_s1_bits = 0

for word_index in S1_BYTE_RANGES:
    found_word = FOUND_STATE_1_COL_0[word_index]
    expected_word = CORRECT_STATE_1_COL_0[word_index]
    for byte_index in range(4):  # four bytes per word, least-significant first
        shift = 8 * byte_index
        obtained_byte = (found_word >> shift) & 0xFF
        correct_byte = (expected_word >> shift) & 0xFF

        if correct_byte == obtained_byte:
            correct_s1_bytes += 1

        print(f"S1 Byte (Word {word_index}, Byte {byte_index}): Obtained = {hex(obtained_byte)}, Correct = {hex(correct_byte)}")
        # Agreeing bits = 8 minus the differing bits (popcount of the XOR).
        correct_s1_bits += 8 - bin(obtained_byte ^ correct_byte).count('1')

total_s1_bytes = 4 * len(S1_BYTE_RANGES)
print(f"Correct S1 Bytes: {correct_s1_bytes} / {total_s1_bytes}")
print(f"Correct S1 Bits: {correct_s1_bits} / {total_s1_bytes * 8}")

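## Experiment with the S1_8 attack
Whether the XOR is included in the LF influences the outcome. This experiment re-runs the S1_8 attack with the XOR setting inverted (`not INCLUDE_XOR`) and stores its results in a separate output sub-folder.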

In [None]:
# Widen the S1_8 search window for the XMEGA range table.
# BYTE_RANGES_S1_XMEGA only exists when its (normally commented-out) definition cell
# has been enabled. Guard on it so Restart & Run All does not raise NameError
# when the STM ranges are active.
S1_8_XMEGA_EXPERIMENT_RANGE = (16500, 18500)
if "BYTE_RANGES_S1_XMEGA" in globals():
    for s1_8_byte_index in range(4):
        BYTE_RANGES_S1_XMEGA[(8, s1_8_byte_index)] = S1_8_XMEGA_EXPERIMENT_RANGE
else:
    print("BYTE_RANGES_S1_XMEGA is not defined; skipping the S1_8 XMEGA range override.")

In [None]:
# Sub-folder for the S1_8 experiment, named after the inverted XOR setting.
CPA_OUTPUT_FOLDER_EXPERIMENT = os.path.join(
    CPA_OUTPUT_FOLDER,
    ("S1_8_XOR_INCLUDED_" + str(not INCLUDE_XOR)).upper(),
)
os.makedirs(CPA_OUTPUT_FOLDER_EXPERIMENT, exist_ok=True)
print("Output folder for S1_8_EXPERIMENT: " + CPA_OUTPUT_FOLDER_EXPERIMENT)

In [None]:
# S1_8 experiment: recompute the S1_8 correlation matrices with the XOR setting
# inverted relative to INCLUDE_XOR, then evaluate all four bytes twice. The first
# pass only prints table rows. The second saves full/zoomed plots and candidate
# rankings into CPA_OUTPUT_FOLDER_EXPERIMENT.
xor_included = not INCLUDE_XOR
print(f"INCLUDE_XOR = {xor_included}")

calculate_correlation_matrix_for_s1_words(
    obtained_s1_col_0=CORRECT_STATE_1_COL_0,
    s1_word_indices=[8],
    byte_indices=list(range(4)),
    num_traces=TRACES_USED,
    include_XOR=xor_included,
    trace_shift=0,
)

word_index = 8  # only S1 word 8 takes part in this experiment
for run_options in (
    dict(show_plots=False, save_plots=False, zoom_plot=False,
         list_correlations=False, output_folder=None),
    dict(show_plots=False, save_plots=True, zoom_plot=True,
         list_correlations=True, output_folder=CPA_OUTPUT_FOLDER_EXPERIMENT),
):
    for byte_index in range(4):
        process_and_plot_correlation_matrix_s1(
            s1_word_index=word_index,
            byte_index=byte_index,
            correct_s1_bytes=CORRECT_STATE_1_COL_0_BYTES,
            **run_options,
        )