In [None]:
import numpy as np
from scipy import stats
import math
import os
import matplotlib.pyplot as plt
from itertools import chain


In [None]:
# Dataset folder for the STM capture. The next cell reassigns FOLDER for the XMEGA
# capture, so whichever of the two cells runs last decides which dataset is loaded.
FOLDER = "ChaCha-100-000-Random-Nonce-STM"
# FOLDER = "ChaCha-100-000-Random-Nonce-STM-2"  # alternative STM dataset

In [None]:
# Dataset folder for the XMEGA capture (overrides the STM assignment in the previous cell).
FOLDER = "ChaCha-100-000-Random-Nonce-XMEGA"
# FOLDER = "ChaCha-100-000-Random-Nonce-XMEGA-2"  # alternative XMEGA dataset

In [None]:
# Folder that receives the saved plots and per-byte candidate rankings
# (passed as output_folder to process_and_plot_correlation_matrix).
CPA_OUTPUT_FOLDER = "_CPA_XMEGA_1_KEYS_NO_CARRY"

# Load the content of files 

In [None]:
NONCE_LEN_BYTES = 12

# Acquisition parameters; all of them are filled in by read_info() from <FOLDER>/info.txt.
TRACE_CNT = None
TRACE_LEN = None
TRACE_RANDOM_CNT = None
CHUNK_SIZE = None
CHUNKS_CNT = None
LAST_CHUNK_SIZE = None


def read_info(FOLDER):
    """
    Parse ``<FOLDER>/info.txt`` and publish the acquisition parameters as globals.

    info.txt is read as three integer lines: the total number of traces, the
    number of traces per chunk, and the number of samples per trace. From these
    the number of chunks (rounded up) and the size of the last, possibly partial,
    chunk are derived. Every value is printed.

    Globals set: TRACE_CNT, TRACE_RANDOM_CNT (equal to TRACE_CNT), CHUNK_SIZE,
    TRACE_LEN, CHUNKS_CNT and LAST_CHUNK_SIZE.
    """
    global TRACE_CNT, TRACE_RANDOM_CNT, TRACE_LEN
    global CHUNK_SIZE, CHUNKS_CNT, LAST_CHUNK_SIZE

    with open(f"{FOLDER}/info.txt", 'r') as info_file:
        # Line order in info.txt: trace count, chunk size, trace length.
        TRACE_CNT = int(info_file.readline())
        TRACE_RANDOM_CNT = TRACE_CNT
        CHUNK_SIZE = int(info_file.readline())
        TRACE_LEN = int(info_file.readline())

    CHUNKS_CNT = math.ceil(TRACE_CNT / CHUNK_SIZE)
    # Whatever remains after the full chunks belongs to the last chunk.
    LAST_CHUNK_SIZE = TRACE_CNT - CHUNK_SIZE * (CHUNKS_CNT - 1)

    print(f"TRACE_CNT = {TRACE_CNT}")
    print(f"CHUNK_SIZE = {CHUNK_SIZE}")
    print(f"LAST_CHUNK_SIZE = {LAST_CHUNK_SIZE}")
    print(f"CHUNKS_CNT = {CHUNKS_CNT}")
    print(f"TRACE_RANDOM_CNT = {TRACE_RANDOM_CNT}")
    print(f"TRACE_LEN = {TRACE_LEN}")
    
    
# Parse <FOLDER>/info.txt and populate the TRACE_* / CHUNK* globals used by the loaders.
read_info(FOLDER)

In [None]:
# Optional manual override of values read from info.txt (uncomment to use).
# NOTE(review): LAST_CHUNK_SIZE is not recomputed here - confirm it still matches
# the chunk files before enabling this override.
# TRACE_CNT = 500_000
# CHUNKS_CNT = 50

In [None]:
NONCES = None


def read_nonces(FOLDER):
    """
    Load the per-trace nonces from ``<FOLDER>/chunk_<i>/nonces_random.bin``.

    Each chunk file is read as raw bytes and reshaped to (rows, NONCE_LEN_BYTES),
    where rows is CHUNK_SIZE for every chunk except the last one, which holds
    LAST_CHUNK_SIZE rows. A missing chunk file is reported and skipped. The
    chunks are stacked into the global NONCES array, which is trimmed when its
    shape differs from (TRACE_CNT, NONCE_LEN_BYTES). The first nonce is printed
    in hex as a sanity check.
    """
    global NONCES
    global CHUNK_SIZE, LAST_CHUNK_SIZE, CHUNKS_CNT, TRACE_CNT

    loaded_chunks = []
    final_index = CHUNKS_CNT - 1
    for idx in range(CHUNKS_CNT):
        chunk_file = os.path.join(FOLDER, f"chunk_{idx}", "nonces_random.bin")

        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        with open(chunk_file, 'rb') as fh:
            raw = fh.read()

        rows = LAST_CHUNK_SIZE if idx == final_index else CHUNK_SIZE
        loaded_chunks.append(
            np.frombuffer(raw, dtype=np.uint8).reshape((rows, NONCE_LEN_BYTES))
        )

    NONCES = np.vstack(loaded_chunks)

    desired_shape = (TRACE_CNT, NONCE_LEN_BYTES)
    if NONCES.shape != desired_shape:
        print(f"NONCES.shape != desired_shape - {NONCES.shape} != {desired_shape}")
        NONCES = NONCES[:desired_shape[0], :desired_shape[1]]
        print(f"Trimming NONCES to {NONCES.shape}")

    # Hex dump of the first nonce (first 12 bytes), two lowercase digits per byte.
    hex_representation = ''.join(f'{b:02x}' for b in NONCES[0][:12])
    print(f"Nonce[0] = {hex_representation}")


read_nonces(FOLDER)

In [None]:
# Load the correct key (raw bytes of key.bin); it serves as the reference key
# throughout the rest of the notebook.
CORRECT_KEY = np.fromfile(f"{FOLDER}/key.bin", dtype=np.uint8)
hex_key = CORRECT_KEY.tobytes().hex()
print(f"Correct Key: {hex_key}")


In [None]:
TRACES_RANDOM = None


def load_traces_random(folder):
    """
    Load the power traces from ``<folder>/chunk_<i>/traces_random.bin``.

    Each chunk file is interpreted as native-byte-order uint16 samples and
    reshaped to (rows, TRACE_LEN), where rows is CHUNK_SIZE for every chunk
    except the last one, which holds LAST_CHUNK_SIZE rows. A missing chunk file
    is reported and skipped. All chunks are stacked into the global
    TRACES_RANDOM array (one row per trace).
    """
    global TRACES_RANDOM, CHUNK_SIZE, LAST_CHUNK_SIZE, CHUNKS_CNT, TRACE_CNT

    pieces = []
    for idx in range(CHUNKS_CNT):
        chunk_file = os.path.join(folder, f"chunk_{idx}", "traces_random.bin")

        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        with open(chunk_file, 'rb') as fh:
            payload = fh.read()

        n_rows = CHUNK_SIZE if idx != CHUNKS_CNT - 1 else LAST_CHUNK_SIZE
        pieces.append(np.frombuffer(payload, dtype=np.uint16).reshape((n_rows, TRACE_LEN)))

    TRACES_RANDOM = np.vstack(pieces)


load_traces_random(FOLDER)

# Adjust size of traces if needed

In [None]:
# Optional: restrict the analysis to a subset of the data (uncomment to use).
# TRACE_CNT = 5000
# TRACES_RANDOM = TRACES_RANDOM[:TRACE_CNT, :]     # keep only the first TRACE_CNT traces
# # TRACES_RANDOM = TRACES_RANDOM[:, 0:8_000]      # additionally crop every trace to samples 0..7999

# NONCES= NONCES[:TRACE_CNT, :]                    # keep the nonces aligned with the kept traces

In [None]:
# Sanity check: (number of traces, samples per trace).
print(f"TRACES shape = {TRACES_RANDOM.shape}")

# Helper functions and constants setup

In [None]:
CHACHA_KEY_LEN_BYTES = 32
VALUES_IN_BYTE = 256  # number of distinct byte values, 0..255

# ChaCha constant sigma: the 16 ASCII bytes of "expand 32-byte k",
# consumed below as four little-endian 32-bit words.
fixed_sigma = bytearray("expand 32-byte k", "utf-8")
print(hex(fixed_sigma[4]))

# Hamming weight (number of set bits) of every byte value, indexed by the value itself.
# Used as the leakage model for the correlation attack.
hamming_weight = [bin(value).count("1") for value in range(VALUES_IN_BYTE)]

def corr2_coeff(A, B):
    """
    Pearson correlation between every row of A and every row of B.

    Rows are variables and columns are observations. For A of shape (m, n) and
    B of shape (k, n) the result has shape (m, k); entry (i, j) is the
    correlation of A[i] with B[j]. A row with zero variance yields NaN
    correlations instead of a division by zero.
    """
    A_centered = A - A.mean(1)[:, None]
    B_centered = B - B.mean(1)[:, None]

    A_sq_sum = (A_centered ** 2).sum(1)
    B_sq_sum = (B_centered ** 2).sum(1)

    # Constant rows have no spread; mark them NaN so their correlations become NaN.
    A_sq_sum[A_sq_sum == 0] = np.nan
    B_sq_sum[B_sq_sum == 0] = np.nan

    denominator = np.sqrt(np.dot(A_sq_sum[:, None], B_sq_sum[None]))
    numerator = np.dot(A_centered, B_centered.T)
    return numerator / denominator

def leftRotate16(n):
    """Rotate a 32-bit word left by 16 bits, i.e. swap its two 16-bit halves."""
    upper_half_down = n >> 16
    lower_half_up = n << 16
    return (lower_half_up | upper_half_down) & 0xFFFFFFFF

## Functions to obtain the K[4-15] and K[20-31]

In [None]:
def calculate_rotated_previous_intermediate_values_vector(size, 
                                                          trace_shift, 
                                                          column_index, 
                                                          obtained_key_byte_array,
                                                          nonces=None,
                                                          sigma=None):
    """
    Compute the rotated ``d`` word of one ChaCha state column for a window of traces.

    For every trace i in [trace_shift, trace_shift + size) this evaluates

        d = rotl32(nonce_word_i ^ ((key_word + sigma_word) mod 2**32), 16)

    with c = column_index and all words assembled little-endian (LSB first):
    key_word from key bytes [4c, 4c+4), sigma_word from sigma bytes [4c, 4c+4),
    and nonce_word_i from nonce bytes [4(c-1), 4c) of trace i.

    Parameters:
    - size: Number of traces to process.
    - trace_shift: Index of the first trace to process.
    - column_index: State column, 1..3. Column 0 is rejected because the nonce
      word is taken from nonce bytes [4(c-1), 4c).
    - obtained_key_byte_array: Key bytes (at least 4*column_index + 4 entries).
    - nonces: Optional 2-D uint8 nonce array (one row per trace). Defaults to
      the global NONCES.
    - sigma: Optional 16-byte constant. Defaults to the global fixed_sigma.

    Returns:
    - np.ndarray of shape (size, 4), dtype uint8: each rotated word split into
      bytes, least-significant byte first.

    Raises:
    - ValueError if column_index is outside 1..3 or the requested window runs
      past the available nonces.
    """
    if column_index == 0:
        raise ValueError("Column index must be greater than 0.")
    if not 1 <= column_index <= 3:
        raise ValueError("Column index must be in the range 1-3.")

    if nonces is None:
        nonces = NONCES
    if sigma is None:
        sigma = fixed_sigma

    if size + trace_shift > len(nonces):
        raise ValueError("Trace shift must be smaller.")

    def _le_word(byte_seq, offset):
        # Assemble 4 bytes (LSB first) as a Python int. The int() conversion matters:
        # with NumPy uint8 scalars, `b << 8` may stay uint8 (NEP 50) and lose the byte.
        return sum(int(byte_seq[offset + k]) << (8 * k) for k in range(4))

    key_word = _le_word(obtained_key_byte_array, column_index * 4)
    sigma_word = _le_word(sigma, column_index * 4)
    # a = sigma + key (mod 2**32): the carry out of bit 31 is discarded explicitly.
    a_word = np.uint32((key_word + sigma_word) & 0xFFFFFFFF)

    # Nonce bytes of this column for every requested trace, widened to uint32
    # before shifting so that no bits are lost.
    first = (column_index - 1) * 4
    nonce_bytes = np.asarray(
        nonces[trace_shift:trace_shift + size, first:first + 4], dtype=np.uint32
    )
    nonce_words = (nonce_bytes[:, 0]
                   | (nonce_bytes[:, 1] << 8)
                   | (nonce_bytes[:, 2] << 16)
                   | (nonce_bytes[:, 3] << 24))

    d_words = nonce_words ^ a_word
    # 16-bit left rotation; uint32 shifts drop the bits pushed out on the left.
    rotated = (d_words << 16) | (d_words >> 16)

    # Split every word into its 4 bytes, least-significant first -> shape (size, 4).
    return np.stack(
        [(rotated >> shift) & 0xFF for shift in (0, 8, 16, 24)], axis=1
    ).astype(np.uint8)

In [None]:
# MODIFICATION: NO CARRY
def calculate_row_of_intermediate_values_for_key_byte_4_to_15(nonce_row, 
                                                              byte_index, 
                                                              obtained_key_byte_array=None):
    """
    Hamming-weight hypotheses of one trace for every candidate of a key byte 4-15.

    For key byte k and candidate value v the modelled intermediate is the
    per-byte form of ``nonce ^ (key + sigma)`` with the carry between bytes
    ignored ("NO CARRY"):

        nonce_row[k - 4] ^ ((v + fixed_sigma[k]) & 0xFF)

    Parameters:
        nonce_row (np.ndarray): A single row of the NONCES array (NONCE_LEN_BYTES bytes).
        byte_index (int): Index of the key byte (4-15).
        obtained_key_byte_array: Unused here; kept because the caller passes it.

    Returns:
        np.ndarray: uint16 array of shape (256,) holding the Hamming weight of
        the intermediate for each candidate value 0..255.

    Raises:
        ValueError: If byte_index is outside 4-15.
    """
    if not 4 <= byte_index <= 15:
        raise ValueError("Byte index must be in the range 4-15.")

    # Key bytes 4..15 pair with nonce bytes 0..11.
    nonce_byte = int(nonce_row[byte_index - 4])

    # Evaluate all 256 candidates at once; `& 0xFF` drops the carry out of this byte.
    candidates = np.arange(256)
    possible_values_vector = (
        nonce_byte ^ ((candidates + fixed_sigma[byte_index]) & 0xFF)
    ).astype(np.uint8)

    # Convert to Hamming weights (the leakage model) via table lookup.
    return np.asarray(hamming_weight, dtype=np.uint16)[possible_values_vector]

In [None]:
# MODIFICATION: NO CARRY
def calculate_row_of_intermediate_values_for_key_byte_20_to_31(byte_index, 
                                                               previous_intermediate_value_array, 
                                                               obtained_key_byte_array=None):
    """
    Hamming-weight hypotheses of one trace for every candidate of a key byte 20-31.

    For key byte k and candidate value v the modelled intermediate is the
    per-byte form, with the carry between bytes ignored ("NO CARRY"), of

        ((v + prev[k % 4]) ^ key[k - 16]) & 0xFF

    where prev holds the bytes (LSB first) of the rotated word for this trace.

    Parameters:
        byte_index (int): Index of the key byte (20-31).
        previous_intermediate_value_array (np.ndarray): The 4 bytes of the rotated
            word for this trace; byte ``byte_index % 4`` is used.
        obtained_key_byte_array (np.ndarray): Known key bytes; byte ``byte_index - 16``
            (i.e. 4-15) is used. Required.

    Returns:
        np.ndarray: uint16 array of shape (256,) holding the Hamming weight of
        the intermediate for each candidate value 0..255.

    Raises:
        ValueError: If byte_index is outside 20-31 or obtained_key_byte_array is None.
    """
    if not 20 <= byte_index <= 31:
        raise ValueError("Byte index must be in the range 20-31.")
    if obtained_key_byte_array is None:
        raise ValueError("obtained_key_byte_array is required for byte index 20-31.")

    # int() keeps the arithmetic out of NumPy uint8 scalars, which would wrap
    # (and may warn about overflow) on every candidate.
    prev_byte = int(previous_intermediate_value_array[byte_index % 4])
    paired_key_byte = int(obtained_key_byte_array[byte_index - 16])

    # Evaluate all 256 candidates at once; `& 0xFF` drops the carry out of this byte.
    candidates = np.arange(256)
    possible_values_vector = (
        ((candidates + prev_byte) ^ paired_key_byte) & 0xFF
    ).astype(np.uint8)

    # Convert to Hamming weights (the leakage model) via table lookup.
    return np.asarray(hamming_weight, dtype=np.uint16)[possible_values_vector]

In [None]:
def calculate_correlation_matrix_col_1_to_3(byte_index, 
                                            num_traces, 
                                            trace_shift=0, 
                                            obtained_key_array = CORRECT_KEY,
                                            previous_intermediate_values = None):
    """
    Calculate the correlation matrix for one key byte from ChaCha20 state columns 1-3.

    For each trace in the window [trace_shift, trace_shift + num_traces) a row of
    256 Hamming-weight hypotheses is built (one per candidate value), and every
    hypothesis column is correlated with every trace sample.

    Parameters:
        byte_index (int): Index of the key byte (4-15 or 20-31).
        num_traces (int): Number of traces to use; capped at the number of traces
            that have both a trace and a nonce after trace_shift.
        trace_shift (int): Index of the first trace to use. Default is 0.
        obtained_key_array (np.ndarray): Key bytes (32 entries) passed to the row functions.
        previous_intermediate_values (np.ndarray): Rotated-word bytes per trace, shape
            (N, 4), required for byte_index >= 20. Row ``trace_shift + trace_idx`` is
            used for the trace_idx-th trace of the window, i.e. the array is assumed
            to start at trace 0. NOTE(review): the notebook builds it with
            trace_shift=0; confirm the convention before using a non-zero shift.

    Returns:
        np.ndarray: Correlation matrix of shape (256, TRACE_LEN).

    Raises:
        ValueError: On an unsupported byte_index, a trace_shift past the available
            traces, or missing / too short previous_intermediate_values.
    """
    global NONCES, TRACES_RANDOM

    # Validate the request before doing any work.
    if not (4 <= byte_index <= 15 or 20 <= byte_index <= 31):
        raise ValueError("Byte index must be in the range 4-15 or 20-31.")
    if byte_index >= 20 and previous_intermediate_values is None:
        raise ValueError("Previous intermediate values are required for byte index >= 20.")

    # Traces and nonces are paired row by row, so only their common prefix is usable.
    max_traces = min(TRACES_RANDOM.shape[0], NONCES.shape[0])
    if trace_shift >= max_traces:
        raise ValueError(f"trace_shift ({trace_shift}) exceeds the number of available traces ({max_traces}).")

    num_traces = min(num_traces, max_traces - trace_shift)

    if byte_index >= 20 and len(previous_intermediate_values) < trace_shift + num_traces:
        raise ValueError(
            f"previous_intermediate_values has {len(previous_intermediate_values)} rows, "
            f"but {trace_shift + num_traces} are needed (trace_shift + num_traces)."
        )

    # Slice the traces and nonces to the requested window.
    window = slice(trace_shift, trace_shift + num_traces)
    traces = TRACES_RANDOM[window, :]
    nonces = NONCES[window, :]

    # Hypothesis matrix: one row per trace, one column per key candidate.
    L_matrix = np.empty((num_traces, 256), dtype=np.float32)
    for trace_idx in range(num_traces):
        if byte_index <= 15:
            L_matrix[trace_idx, :] = calculate_row_of_intermediate_values_for_key_byte_4_to_15(
                nonces[trace_idx], byte_index, obtained_key_array
            )
        else:
            L_matrix[trace_idx, :] = calculate_row_of_intermediate_values_for_key_byte_20_to_31(
                byte_index, previous_intermediate_values[trace_idx + trace_shift], obtained_key_array
            )

    # Rows = 256 candidates, columns = trace samples.
    return corr2_coeff(L_matrix.T, traces.T)

## Correlation matrix calculation using the correct key

In [None]:
# Rotated-word bytes for state columns 1, 2 and 3, keyed by column index.
PREVIOUS_INTERMEDIATE_VALUES = {}

def calculate_all_rotated_previous_intermediate_values(size, trace_shift, obtained_key_byte_array):
    """
    Fill PREVIOUS_INTERMEDIATE_VALUES[c] for columns c = 1, 2, 3.

    Each entry is the result of calculate_rotated_previous_intermediate_values_vector
    for that column. The global dict is updated in place, and progress is printed.

    Parameters:
        size (int): Number of rows (traces/nonces) to process.
        trace_shift (int): Index of the first trace to process.
        obtained_key_byte_array (np.ndarray): Key bytes (CHACHA_KEY_LEN_BYTES entries).
    """
    global PREVIOUS_INTERMEDIATE_VALUES

    for column_index in (1, 2, 3):
        print(f"Calculating rotated previous intermediate values for column {column_index}...")
        column_values = calculate_rotated_previous_intermediate_values_vector(
            size=size,
            trace_shift=trace_shift,
            column_index=column_index,
            obtained_key_byte_array=obtained_key_byte_array,
        )
        PREVIOUS_INTERMEDIATE_VALUES[column_index] = column_values

    print("Rotated previous intermediate values for columns 1, 2, and 3 have been calculated and stored.")

In [None]:
def calculate_correlation_matrix_for_byte_range(
    byte_range, 
    num_traces, 
    trace_shift=0, 
    obtained_key_array=CORRECT_KEY, 
):
    """
    Compute and store correlation matrices for every key byte in an inclusive range.

    Results are written to CORRELATION_MATRIXES[byte_index]. Bytes below 4 and
    bytes 16-19 are skipped with a message. Bytes 20-31 additionally need
    PREVIOUS_INTERMEDIATE_VALUES[column] (column 1-3), which is filled by
    calculate_all_rotated_previous_intermediate_values.

    Parameters:
        byte_range (tuple): (start, end) byte indices, both inclusive.
        num_traces (int): Number of traces to use for the calculation.
        trace_shift (int): Index of the first trace to use. Default is 0.
        obtained_key_array (np.ndarray): Array of obtained key bytes (shape: [32]).

    Raises:
        ValueError: If a byte in 20-31 is requested but the intermediate values
            for its column have not been computed.
    """
    global CORRELATION_MATRIXES, PREVIOUS_INTERMEDIATE_VALUES
    start_byte, end_byte = byte_range

    for byte_index in range(start_byte, end_byte + 1):
        print(f"Calculating correlation matrix for key byte {byte_index}...")

        if byte_index < 4:
            # Key bytes 0–3 are not part of the main recovery process
            print(f"Skipping key byte {byte_index} (not part of recovery).")
            continue

        if 4 <= byte_index <= 15:
            # Calculate correlation matrix for key bytes 4–15
            CORRELATION_MATRIXES[byte_index] = calculate_correlation_matrix_col_1_to_3(
                byte_index=byte_index,
                num_traces=num_traces,
                trace_shift=trace_shift,
                obtained_key_array=obtained_key_array
            )
        elif 20 <= byte_index <= 31:
            # Maps 20–23 -> 1, 24–27 -> 2, 28–31 -> 3
            column_index = (byte_index - 20) // 4 + 1

            # PREVIOUS_INTERMEDIATE_VALUES starts as an empty dict (never None),
            # so check for this column's entry explicitly instead of only for None.
            if PREVIOUS_INTERMEDIATE_VALUES is None or column_index not in PREVIOUS_INTERMEDIATE_VALUES:
                raise ValueError(
                    "Previous intermediate values are required for byte index >= 20 "
                    f"(missing column {column_index}; run "
                    "calculate_all_rotated_previous_intermediate_values first)."
                )

            CORRELATION_MATRIXES[byte_index] = calculate_correlation_matrix_col_1_to_3(
                byte_index=byte_index,
                num_traces=num_traces,
                trace_shift=trace_shift,
                obtained_key_array=obtained_key_array,
                previous_intermediate_values=PREVIOUS_INTERMEDIATE_VALUES[column_index]
            )
        else:
            print(f"Skipping unsupported key byte {byte_index}.")

    print("Correlation matrices calculation complete.")

# Correlation matrix processing (plotting and (sub)key retrieval)

In [None]:
# Define the number of traces to use for the calculation
# This can be adjusted based on the available data and requirements;
# calculate_correlation_matrix_col_1_to_3 caps it at the number of loaded traces.
TRACES_USED = 5000

In [None]:
# Correlation matrices, keyed by key-byte index; filled by
# calculate_correlation_matrix_for_byte_range.
CORRELATION_MATRIXES = dict()

# Per-byte (start, end) sample windows in which the minimum correlation is searched;
# assigned from BYTE_RANGES_STM / BYTE_RANGES_XMEGA further down.
BYTE_RANGES = list()

# Recovered value of each of the 32 key bytes (None until the byte is processed).
FOUND_KEY_BYTES = [None for _ in range(32)]

In [None]:
def _plot_key_candidates(correlation_matrix, correct_key_byte, guessed_key_byte, search_range, zoom):
    """Draw every candidate's correlation trace, highlighting the guessed and correct key; return the figure."""
    start_col, end_col = search_range
    fig, ax = plt.subplots(figsize=(12, 6))

    # Background: every candidate that is neither the guess nor the correct key.
    for candidate in range(correlation_matrix.shape[0]):
        if candidate != correct_key_byte and candidate != guessed_key_byte:
            ax.plot(correlation_matrix[candidate], 'k.', alpha=0.3)

    ax.plot(correlation_matrix[guessed_key_byte], 'r.', label=f"Guessed Key = {hex(guessed_key_byte)}")
    ax.plot(correlation_matrix[correct_key_byte], 'g.', label=f"Correct Key = {hex(correct_key_byte)}")

    ax.set_ylim(-1, 1)
    ax.set_xlabel("Trace Sample")
    ax.set_ylabel("Correlation")

    if zoom:
        ax.set_xlim(start_col, end_col)
    else:
        ax.axvline(x=start_col, color='b', linestyle='--', label='Search Range')
        ax.axvline(x=end_col, color='b', linestyle='--')

    # Create the legend last so the search-range marker is included.
    ax.legend(loc="upper right")
    return fig


def process_and_plot_correlation_matrix(
    byte_index, 
    show_plots=True, 
    save_plots=False, 
    zoom_plot=False,
    list_correlations=False,
    correct_key_bytes=CORRECT_KEY,
    output_folder="plots"
):
    """
    Pick the best key candidate for one key byte, record it and optionally plot it.

    The guess is the candidate (row) holding the most negative correlation inside
    the sample window BYTE_RANGES[byte_index] = (start, end), end exclusive. NaN
    entries are ignored. The guess is stored in FOUND_KEY_BYTES[byte_index], and
    a LaTeX table row comparing it with the correct key byte is printed.

    Parameters:
        byte_index (int): The key byte index being processed (0-31).
        show_plots (bool): Whether to display the plot(s).
        save_plots (bool): Whether to save the plot(s) as PNG files.
        zoom_plot (bool): Whether to also produce a plot zoomed to the search window.
        list_correlations (bool): Whether to print the top-10 candidates and write the
            full ranking to ``output_folder/correlation_results_byte_<i>.txt``.
        correct_key_bytes (np.ndarray): The correct key bytes for comparison.
        output_folder (str): Folder for saved plots / rankings.

    Raises:
        ValueError: If byte_index is outside 0-31, or if the search window holds only
            NaN correlations.
    """
    global FOUND_KEY_BYTES, CORRELATION_MATRIXES, BYTE_RANGES, VALUES_IN_BYTE

    if byte_index < 0 or byte_index >= 32:
        raise ValueError("Byte index must be in the range 0-31.")

    correct_key_byte = correct_key_bytes[byte_index]
    correlation_matrix = CORRELATION_MATRIXES[byte_index]
    start_col, end_col = BYTE_RANGES[byte_index]

    # Restrict the search to the configured sample window.
    correlation_matrix_cutout = correlation_matrix[:, start_col:end_col]

    # corr2_coeff returns NaN for zero-variance samples. Plain argmin would return the
    # position of the first NaN instead of the true minimum, so use the NaN-aware variants.
    corr_min = np.unravel_index(np.nanargmin(correlation_matrix_cutout), correlation_matrix_cutout.shape)
    obtained_key_min = corr_min[0]
    min_value = correlation_matrix_cutout[corr_min]
    FOUND_KEY_BYTES[byte_index] = obtained_key_min

    # Strongest (most negative) correlation reached by the correct key byte.
    correct_key_correlation = np.nanmin(correlation_matrix_cutout[correct_key_byte])

    output = f"{byte_index:>2} & {BYTE_RANGES[byte_index]} & \\texttt{{{format(obtained_key_min, '#04X').replace('0X', '0x')}}} ({min_value:.5f}) & \\texttt{{{format(correct_key_byte, '#04X').replace('0X', '0x')}}} ({correct_key_correlation:.5f})"
    print(output)

    if list_correlations:
        # Rank candidates by their minimum correlation (most negative first).
        min_correlations = sorted(
            (np.nanmin(correlation_matrix_cutout[candidate]), candidate)
            for candidate in range(VALUES_IN_BYTE)
        )

        print(f"Top 10 sorted candidates for key byte {byte_index}:")
        for min_corr, candidate in min_correlations[:10]:
            print(f"Candidate: {hex(candidate)}, Correlation: {min_corr}")

        os.makedirs(output_folder, exist_ok=True)
        output_file = os.path.join(output_folder, f"correlation_results_byte_{byte_index}.txt")
        with open(output_file, "w") as file:
            for min_corr, candidate in min_correlations:
                file.write(f"Candidate: {hex(candidate)}, Correlation: {min_corr}\n")

    if not show_plots and not save_plots:
        return

    variants = [(False, f"key_byte_{byte_index:02}.png")]
    if zoom_plot:
        variants.append((True, f"key_byte_{byte_index:02}_zoomed.png"))

    for zoomed, file_name in variants:
        fig = _plot_key_candidates(
            correlation_matrix, correct_key_byte, obtained_key_min, (start_col, end_col), zoomed
        )

        # Save before showing: with inline backends, show() may close the figure,
        # and saving afterwards would write a blank image.
        if save_plots:
            os.makedirs(output_folder, exist_ok=True)
            plot_path = os.path.join(output_folder, file_name)
            fig.savefig(plot_path, bbox_inches='tight', dpi=600)
            print(f"Plot saved to {plot_path}")

        if show_plots:
            plt.show()

        # Close the figure to free memory.
        plt.close(fig)

## Execution of the analysis

In [None]:
# Precompute the rotated-word bytes for columns 1-3 (PREVIOUS_INTERMEDIATE_VALUES).
# They are the inputs of the leakage model for key bytes 20-31 and are built here
# from the correct key.
calculate_all_rotated_previous_intermediate_values(
    obtained_key_byte_array=CORRECT_KEY,
    trace_shift=0,
    size=TRACES_USED,
)

In [None]:
# Compute CORRELATION_MATRIXES for key bytes 0-31. Bytes 0-3 and 16-19 are reported
# as skipped; matrices are produced for bytes 4-15 and 20-31.
calculate_correlation_matrix_for_byte_range(
    (0, 31),
    TRACES_USED,
    trace_shift=0,
    obtained_key_array=CORRECT_KEY,
)

In [None]:
# Search windows (start, end) in trace-sample indices for every key byte 0-31 on the
# STM target, chosen from inspection of the correlation matrices. Bytes 0-3 and
# 16-19 are not attacked and simply cover the whole trace.
BYTE_RANGES_STM = (
    [(0, TRACE_LEN)] * 4      # Bytes 0-3
    + [(500, 850)] * 4        # Bytes 4-7
    + [(570, 920)] * 8        # Bytes 8-15
    + [(0, TRACE_LEN)] * 4    # Bytes 16-19
    + [(500, 950)] * 4        # Bytes 20-23
    + [(550, 950)] * 8        # Bytes 24-31
)

BYTE_RANGES = BYTE_RANGES_STM

In [None]:
# Search windows (start, end) in trace-sample indices for every key byte 0-31 on the
# XMEGA target. process_and_plot_correlation_matrix searches samples [start, end).
# The windows were chosen from inspection of the correlation matrices; tuples in the
# trailing comments are alternative windows noted by the author.
# Bytes 0-3 and 16-19 are not attacked and simply cover the whole trace.
BYTE_RANGES_XMEGA = [
    (0, TRACE_LEN),  # Byte 0
    (0, TRACE_LEN),  # Byte 1
    (0, TRACE_LEN),  # Byte 2
    (0, TRACE_LEN),  # Byte 3
    
    (4700, 5100),# Byte 4
    (4700, 5100),# Byte 5
    (4770, 5100),# Byte 6
    (4700, 5100),# Byte 7
    
    (7100, 7500),# Byte 8 (7100, 7200)
    (7100, 7500),# Byte 9
    (7100, 7500),# Byte 10
    (7100, 7500),# Byte 11 (7100, 7150)
    
    (9450, 9800),# Byte 12
    (9450, 9800),#(9450, 9550),# Byte 13
    (9550, 9600),#(9450, 9550),# Byte 14 (9480, 9515), found somehow
    (9450, 9520),#(9450, 9520),# Byte 15
    
    (0, TRACE_LEN),# Byte 16
    (0, TRACE_LEN),# Byte 17
    (0, TRACE_LEN),# Byte 18
    (0, TRACE_LEN),# Byte 19
    
    (5500, 5760),# Byte 20
    (5500, 5760),# Byte 21 (5000, 6500), - 14 instead 13
    (5500, 5750),# Byte 22
    (5500, 5760),# Byte 23
    
    (7700, 8160),# Byte 24
    (7700, 8160),# Byte 25
    (7700, 8150),# Byte 26
    (7700, 8160),# Byte 27
    
    (9900, 11000),# Byte 28
    (9400, 11000),# Byte 29
    (9400, 11000),# Byte 30
    (9400, 11000),# Byte 31
]

# Select the XMEGA windows; this overrides the STM assignment if both cells are run.
BYTE_RANGES = BYTE_RANGES_XMEGA

In [None]:
# Recover key bytes 4-15 and 20-31 without plotting. Each guess is stored in
# FOUND_KEY_BYTES, and one summary line per byte compares it with the correct key.
for byte_index in chain(range(4, 16), range(20, 32)):
    process_and_plot_correlation_matrix(
        byte_index,
        output_folder=None,
        correct_key_bytes=CORRECT_KEY,
        list_correlations=False,
        zoom_plot=False,
        save_plots=False,
        show_plots=False,
    )

In [None]:
# Recover key bytes 4-15 and 20-31 again, this time writing the full and zoomed plots
# plus the ranked candidate list of every byte into CPA_OUTPUT_FOLDER.
# Nothing is shown inline; the guesses are stored in FOUND_KEY_BYTES.
for byte_index in chain(range(4, 16), range(20, 32)):
    process_and_plot_correlation_matrix(
        byte_index,
        output_folder=CPA_OUTPUT_FOLDER,
        correct_key_bytes=CORRECT_KEY,
        list_correlations=True,
        zoom_plot=True,
        save_plots=True,
        show_plots=False,
    )

## Success rate in bits

In [None]:
# Key-byte indices targeted by the attack (bytes 0-3 and 16-19 are not recovered).
KEY_BYTE_RANGES = list(range(4, 16)) + list(range(20, 32))

# Running totals of fully correct bytes and of matching bits.
correct_key_bytes = 0
correct_key_bits = 0

for byte_index in KEY_BYTE_RANGES:
    obtained_byte = np.uint8(FOUND_KEY_BYTES[byte_index])
    correct_byte  = np.uint8(CORRECT_KEY[byte_index])

    print(f"Byte {byte_index}: Obtained = {hex(obtained_byte)}, Correct = {hex(correct_byte)}")

    correct_key_bytes += int(obtained_byte == correct_byte)

    # Matching bits = 8 minus the number of differing bits (popcount of the XOR).
    matching_bits = 8 - bin(int(obtained_byte) ^ int(correct_byte)).count('1')
    correct_key_bits += matching_bits

# Print the results
print(f"Correct Key Bytes: {correct_key_bytes} / {len(KEY_BYTE_RANGES)}")
print(f"Correct Key Bits: {correct_key_bits} / {len(KEY_BYTE_RANGES) * 8}")