In [None]:
import numpy as np
from scipy import stats
import math
import os

import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots

In [None]:
# Capture to analyse: keep exactly one of these lines active.
# The keystream cell and the two trace-trimming cells further down are
# platform-specific (STM vs XMEGA) and depend on which capture is selected.
# FOLDER = "ChaCha-100-000-Fixed-Nonce-XMEGA"
FOLDER = "ChaCha-100-000-Fixed-Nonce-STM"

In [None]:
CIPHERTEXT_LEN_BYTES = 64
PLAINTEXT_LEN_BYTES = CIPHERTEXT_LEN_BYTES

# Acquisition parameters, all filled in by read_info().
TRACE_CNT = None
TRACE_LEN = None
TRACE_RANDOM_CNT = None
CHUNK_SIZE = None
CHUNKS_CNT = None
LAST_CHUNK_SIZE = None

def read_info(folder):
    """
    Populate the acquisition globals from "<folder>/info.txt" and print them.

    info.txt holds one integer per line, in this order:
    1. number of traces        -> TRACE_CNT (also copied to TRACE_RANDOM_CNT)
    2. chunk size              -> CHUNK_SIZE
    3. trace length (samples)  -> TRACE_LEN

    Derived values:
    - CHUNKS_CNT = ceil(TRACE_CNT / CHUNK_SIZE)
    - LAST_CHUNK_SIZE = number of traces in the final, possibly partial, chunk
    """
    global TRACE_CNT, TRACE_LEN, TRACE_RANDOM_CNT
    global CHUNK_SIZE, CHUNKS_CNT, LAST_CHUNK_SIZE

    with open(f"{folder}/info.txt", 'r') as info_file:
        TRACE_CNT = int(info_file.readline())
        TRACE_RANDOM_CNT = TRACE_CNT
        CHUNK_SIZE = int(info_file.readline())
        TRACE_LEN = int(info_file.readline())

    CHUNKS_CNT = math.ceil(TRACE_CNT / CHUNK_SIZE)
    traces_in_full_chunks = (CHUNKS_CNT - 1) * CHUNK_SIZE
    LAST_CHUNK_SIZE = TRACE_CNT - traces_in_full_chunks

    summary = (
        ("TRACE_CNT", TRACE_CNT),
        ("CHUNK_SIZE", CHUNK_SIZE),
        ("LAST_CHUNK_SIZE", LAST_CHUNK_SIZE),
        ("CHUNKS_CNT", CHUNKS_CNT),
        ("TRACE_RANDOM_CNT", TRACE_RANDOM_CNT),
        ("TRACE_LEN", TRACE_LEN),
    )
    for label, value in summary:
        print(f"{label} = {value}")
    
    
# Load the acquisition parameters of the selected capture.
read_info(folder=FOLDER)

In [None]:
CIPHERTEXTS = None
def read_ciphertexts(folder):
    """
    Load every chunk's ciphertexts into the global CIPHERTEXTS.

    Chunk i is read from "<folder>/chunk_<i>/ciphertexts_random.bin" as raw uint8
    data and reshaped to (CHUNK_SIZE, CIPHERTEXT_LEN_BYTES), or to
    (LAST_CHUNK_SIZE, CIPHERTEXT_LEN_BYTES) for the last chunk. The chunks are
    stacked in index order and cut to TRACE_CNT rows.

    Relies on read_info() having set CHUNKS_CNT, CHUNK_SIZE, LAST_CHUNK_SIZE and
    TRACE_CNT. Missing chunk files are reported and skipped; every later row then
    moves up by the skipped chunk's size.

    Raises:
    - FileNotFoundError: if no chunk file exists at all.
    - ValueError: (from reshape) if a chunk file's size does not match its expected shape.
    """
    global CIPHERTEXTS

    ciphertexts_list = []
    for chunk_index in range(CHUNKS_CNT):
        chunk_file = os.path.join(folder, f"chunk_{chunk_index}", "ciphertexts_random.bin")
        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        rows = LAST_CHUNK_SIZE if chunk_index == CHUNKS_CNT - 1 else CHUNK_SIZE
        with open(chunk_file, 'rb') as file:
            byte_array = file.read()
        ciphertexts_list.append(
            np.frombuffer(byte_array, dtype=np.uint8).reshape((rows, CIPHERTEXT_LEN_BYTES))
        )

    # np.vstack([]) would otherwise fail with an unhelpful "need at least one array" error.
    if not ciphertexts_list:
        raise FileNotFoundError(f"No ciphertexts_random.bin chunk files found under {folder}")

    CIPHERTEXTS = np.vstack(ciphertexts_list)

    desired_shape = (TRACE_CNT, CIPHERTEXT_LEN_BYTES)
    if CIPHERTEXTS.shape != desired_shape:
        print(f"CIPHERTEXTS.shape != desired_shape - {CIPHERTEXTS.shape} != {desired_shape}")
        CIPHERTEXTS = CIPHERTEXTS[:TRACE_CNT, :CIPHERTEXT_LEN_BYTES]
        # Slicing can only shrink the array, so say so when rows are actually missing.
        if CIPHERTEXTS.shape[0] < TRACE_CNT:
            print(f"Warning: only {CIPHERTEXTS.shape[0]} of {TRACE_CNT} ciphertexts were loaded")
        else:
            print(f"Trimming CIPHERTEXTS to {CIPHERTEXTS.shape}")

    # Print the first ciphertext in hex as a quick sanity check.
    print(f"CT[0] = {CIPHERTEXTS[0].tobytes().hex()}")


read_ciphertexts(FOLDER)

In [None]:
PLAINTEXTS = None
def read_plaintexts(folder):
    """
    Load every chunk's plaintexts into the global PLAINTEXTS.

    Chunk i is read from "<folder>/chunk_<i>/plaintexts_random.bin" as raw uint8
    data and reshaped to (CHUNK_SIZE, PLAINTEXT_LEN_BYTES), or to
    (LAST_CHUNK_SIZE, PLAINTEXT_LEN_BYTES) for the last chunk. The chunks are
    stacked in index order and cut to TRACE_CNT rows.

    Relies on read_info() having set CHUNKS_CNT, CHUNK_SIZE, LAST_CHUNK_SIZE and
    TRACE_CNT. Missing chunk files are reported and skipped; every later row then
    moves up by the skipped chunk's size.

    Raises:
    - FileNotFoundError: if no chunk file exists at all.
    - ValueError: (from reshape) if a chunk file's size does not match its expected shape.
    """
    global PLAINTEXTS

    pts_list = []
    for chunk_index in range(CHUNKS_CNT):
        chunk_file = os.path.join(folder, f"chunk_{chunk_index}", "plaintexts_random.bin")
        if not os.path.exists(chunk_file):
            print(f"Chunk file {chunk_file} does not exist.")
            continue

        rows = LAST_CHUNK_SIZE if chunk_index == CHUNKS_CNT - 1 else CHUNK_SIZE
        with open(chunk_file, 'rb') as file:
            byte_array = file.read()
        pts_list.append(
            np.frombuffer(byte_array, dtype=np.uint8).reshape((rows, PLAINTEXT_LEN_BYTES))
        )

    # np.vstack([]) would otherwise fail with an unhelpful "need at least one array" error.
    if not pts_list:
        raise FileNotFoundError(f"No plaintexts_random.bin chunk files found under {folder}")

    PLAINTEXTS = np.vstack(pts_list)

    desired_shape = (TRACE_CNT, PLAINTEXT_LEN_BYTES)
    if PLAINTEXTS.shape != desired_shape:
        print(f"PLAINTEXTS.shape != desired_shape - {PLAINTEXTS.shape} != {desired_shape}")
        PLAINTEXTS = PLAINTEXTS[:TRACE_CNT, :PLAINTEXT_LEN_BYTES]
        # Slicing can only shrink the array, so say so when rows are actually missing.
        if PLAINTEXTS.shape[0] < TRACE_CNT:
            print(f"Warning: only {PLAINTEXTS.shape[0]} of {TRACE_CNT} plaintexts were loaded")
        else:
            print(f"Trimming PLAINTEXTS to {PLAINTEXTS.shape}")

    # Print the first plaintext in hex as a quick sanity check.
    print(f"PT[0] = {PLAINTEXTS[0].tobytes().hex()}")

read_plaintexts(FOLDER)

In [None]:
# Load the ground-truth key of this capture (raw bytes) and show it in hex.
correct_key = np.fromfile(f"{FOLDER}/key.bin", dtype=np.uint8)
hex_key = correct_key.tobytes().hex()
print(f"Correct Key: {hex_key}")

In [None]:
# Load the nonce used for encryption (raw bytes from "nonce.bin") and show it in hex.
# NOTE(review): the nonce is described as a 12-byte value; its length is not checked here.
nonce = np.fromfile(f"{FOLDER}/nonce.bin", dtype=np.uint8)
hex_nonce = nonce.tobytes().hex()
print(f"Nonce: {hex_nonce}")

In [None]:



# Known ChaCha20 keystream (one 64-byte block) for each capture, as hex strings.
# The keystream is combined with the plaintext to produce the ciphertext; the CPA
# below treats it as identical for every trace and uses it as ground truth when
# scoring the recovered bytes.
KEYSTREAMS_HEX = {
    # STM keystream
    "STM": "73e96f51a6b7eeb730563e7db6f7ee22a95997abc498a52c141a1941769b3734805bcb9f529d93fcfbb46752889d4e560538c5e7d0a7cbce6e66da115323097b",
    # XMEGA keystream
    "XMEGA": "70b602649c3f6fc355e0cd92c77945e825eaf2e0c1bb90e9e09fa8635a786b62b897af7729b9c76894a9a864af4245e722963a355be562c80314b9a2562694b4",
}

# Pick the keystream that belongs to FOLDER, so switching the dataset in the
# configuration cell cannot leave a stale keystream selected.
_matching_platforms = [name for name in KEYSTREAMS_HEX if name in FOLDER]
if len(_matching_platforms) != 1:
    raise ValueError(f"Cannot tell which keystream belongs to FOLDER={FOLDER!r}")
keystream_str = KEYSTREAMS_HEX[_matching_platforms[0]]

# Hex string -> uint8 array (two hex digits per byte).
keystream = np.array(list(bytes.fromhex(keystream_str)), dtype=np.uint8)

hex_keystream = keystream.tobytes().hex()
print(f"Keystream: {hex_keystream}")

In [None]:
TRACES_RANDOM = None

# Read the power traces of the selected capture. They are stored in chunks:
# "<FOLDER>/chunk_<i>/traces_random.bin" holds raw uint16 samples, reshaped to
# (CHUNK_SIZE, TRACE_LEN), or (LAST_CHUNK_SIZE, TRACE_LEN) for the last chunk.
# A missing chunk is reported and skipped; later rows then move up by its size.
traces_list = []
for chunk_index in range(CHUNKS_CNT):
    chunk_file = os.path.join(FOLDER, f"chunk_{chunk_index}", "traces_random.bin")
    if not os.path.exists(chunk_file):
        print(f"Chunk file {chunk_file} does not exist.")
        continue

    rows = LAST_CHUNK_SIZE if chunk_index == CHUNKS_CNT - 1 else CHUNK_SIZE
    with open(chunk_file, 'rb') as file:
        byte_array = file.read()
    traces_list.append(np.frombuffer(byte_array, dtype=np.uint16).reshape((rows, TRACE_LEN)))

# np.vstack([]) would otherwise fail with an unhelpful "need at least one array" error.
if not traces_list:
    raise FileNotFoundError(f"No traces_random.bin chunk files found under {FOLDER}")

TRACES_RANDOM = np.vstack(traces_list)
# vstack copied the data; drop the per-chunk buffers so they can be freed.
del traces_list, byte_array

desired_shape = (TRACE_CNT, TRACE_LEN)
if TRACES_RANDOM.shape != desired_shape:
    TRACES_RANDOM = TRACES_RANDOM[:TRACE_CNT, :TRACE_LEN]
    # Slicing can only shrink the array, so say so when rows are actually missing.
    if TRACES_RANDOM.shape[0] < TRACE_CNT:
        print(f"Warning: only {TRACES_RANDOM.shape[0]} of {TRACE_CNT} traces were loaded")
    else:
        print(f"Trimming TRACES_RANDOM to {TRACES_RANDOM.shape}")

In [None]:
%matplotlib notebook
import matplotlib.pylab as plt
plt.figure()
plt.title(f'First trace from set using constant PT - CHACHA20')
plt.plot(TRACES_RANDOM[0], 'g')
plt.show()


In [None]:
# Number of byte positions attacked: one 64-byte keystream block. Despite "KEY" in
# the name, these positions index plaintext/ciphertext/keystream bytes.
CHACHA_KEY_LEN_BYTES = 64
# Number of candidate values for a single byte guess (0..255).
VALUES_IN_BYTE = 1 << 8

In [None]:
# Hamming weight (number of set bits) of every 8-bit value, indexed by the value:
# hamming_weight[v] is the popcount of v for v in 0..255. Used as the CPA leakage model.
hamming_weight = [bin(value).count("1") for value in range(256)]

def corr2_coeff(A, B):
    """
    Pearson correlation coefficient between every row of A and every row of B.

    Parameters:
    - A: 2-D array of shape (m, n); each row is one variable with n observations.
    - B: 2-D array of shape (k, n); each row is one variable with n observations.

    Returns:
    - float64 array of shape (m, k) where element [i, j] is corr(A[i], B[j]).
      A pair involving a zero-variance (constant) row yields 0 instead of NaN:
      a NaN would otherwise be picked by argmin() in the CPA and silently
      produce a wrong key guess.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)

    # Centre each row on its own mean.
    A_mA = A - A.mean(axis=1, keepdims=True)
    B_mB = B - B.mean(axis=1, keepdims=True)

    # Row-wise sums of squares.
    ssA = (A_mA ** 2).sum(axis=1)
    ssB = (B_mB ** 2).sum(axis=1)

    numerator = A_mA @ B_mB.T
    denominator = np.sqrt(np.outer(ssA, ssB))

    # Divide only where both rows vary; leave 0 elsewhere.
    result = np.zeros_like(numerator)
    np.divide(numerator, denominator, out=result, where=denominator > 0)
    return result

def check_obtained_key(correct_key, obtained_key_range):
    """
    Compare recovered bytes with the expected ones and print a summary.

    Position i of correct_key (in this notebook: the known keystream) is compared
    with obtained_key_range[i] for every i in range(len(correct_key)). Prints
    "<matches> / <total> bytes correct:" followed by one mark per position.
    """
    marks = [
        '✔️' if correct_key[idx] == obtained_key_range[idx] else '❌'
        for idx in range(len(correct_key))
    ]
    n_correct = marks.count('✔️')

    print(f"{n_correct} / {len(correct_key)} bytes correct:")
    print(' '.join(marks))


In [None]:
def _print_elementwise(array, convert):
    """Apply the scalar function `convert` to every element of `array` and print the result."""
    print(np.vectorize(convert)(array))


def print_array_in_hex(array):
    """Print `array` with every element shown as a Python hex literal (e.g. '0xff')."""
    _print_elementwise(array, hex)


def print_array_in_binary(array):
    """Print `array` with every element shown as a Python binary literal (e.g. '0b101')."""
    _print_elementwise(array, bin)

In [None]:
def get_correlations_for_byte(byte_number):
    """
    Compute the CPA correlation matrix for one byte position.

    Leakage model: for trace j and guess g in 0..VALUES_IN_BYTE-1 the hypothesis is
    hamming_weight[PLAINTEXTS[j][byte_number] ^ g]. Assuming
    ciphertext = plaintext XOR keystream, the correct guess is the keystream byte
    and the hypothesis then equals the Hamming weight of the ciphertext byte.

    Parameters:
    - byte_number: index of the byte within each plaintext row.

    Returns:
    - array of shape (VALUES_IN_BYTE, n_samples): correlation (via corr2_coeff)
      between each guess's hypothesis vector and every sample column of TRACES_RANDOM.

    NOTE: row j of TRACES_RANDOM is paired with PLAINTEXTS[j], so TRACES_RANDOM must
    start at trace 0 of the capture.
    """
    # Use the rows actually present in TRACES_RANDOM: TRACE_CNT is reassigned by the
    # trimming cells and can disagree with the array, which broke the matrix product.
    n_traces = TRACES_RANDOM.shape[0]

    # Vectorised form of the per-trace/per-guess loop: hypotheses[j, g] = HW(pt_j ^ g).
    hw_table = np.asarray(hamming_weight, dtype=np.uint16)
    pt_bytes = np.asarray(PLAINTEXTS[:n_traces, byte_number], dtype=np.uint16)
    guesses = np.arange(VALUES_IN_BYTE, dtype=np.uint16)
    hypotheses = hw_table[pt_bytes[:, None] ^ guesses[None, :]]

    return corr2_coeff(hypotheses.T, TRACES_RANDOM.T)

In [None]:
# STM32 capture only: keep the first STM_TRACE_CNT traces.
# Guarded on FOLDER because the XMEGA trimming cell follows this one; on a Run All
# both would otherwise slice TRACES_RANDOM one after the other.
STM_TRACE_CNT = 23
if "STM" in FOLDER:
    TRACES_RANDOM = TRACES_RANDOM[:STM_TRACE_CNT, :]
    # Keep TRACE_CNT equal to the number of rows actually kept.
    TRACE_CNT = TRACES_RANDOM.shape[0]


In [None]:
# XMEGA capture only: keep XMEGA_TRACE_CNT traces starting at OFFSET.
# Guarded on FOLDER: on an STM capture this cell runs right after the STM trim and
# would otherwise set TRACE_CNT = 2650 while only 23 traces remain.
# NOTE: get_correlations_for_byte pairs row j of TRACES_RANDOM with PLAINTEXTS[j],
# so a non-zero OFFSET misaligns traces and plaintexts.
OFFSET = 0
XMEGA_TRACE_CNT = 2650
if "XMEGA" in FOLDER:
    TRACES_RANDOM = TRACES_RANDOM[OFFSET:OFFSET + XMEGA_TRACE_CNT, :]
    # Keep TRACE_CNT equal to the number of rows actually kept.
    TRACE_CNT = TRACES_RANDOM.shape[0]


In [None]:
# One (guess x sample) correlation matrix per attacked byte position.
correlations = [get_correlations_for_byte(byte_index) for byte_index in range(CHACHA_KEY_LEN_BYTES)]
print("Done")

In [None]:
# Quick shape check of the CPA inputs/outputs.
if len(correlations) > 0:
    print(f"Size of the first correlation matrix: {correlations[0].shape}")
    print(f"Size of the first TRACES_RANDOM: {len(TRACES_RANDOM)}")

In [None]:
# |correlation| a sample must reach to anchor the XMEGA windowed search
# (the "TRESHOLD" spelling is kept because the CPA functions reference it).
# INDENTATION is the half-width, in samples, of the search window around that sample.
CORR_TRESHOLD, INDENTATION = 1 / 2, 10

## XMEGA CPA

In [None]:
def perform_CPA_in_range(list_of_correlations = None, start = 0, end = CHACHA_KEY_LEN_BYTES, graph = False):
    """
    Perform Correlation Power Analysis (CPA) over the byte positions [start, end).

    For each byte the (guess x sample) correlation matrix is restricted to the
    samples [c - INDENTATION, c + INDENTATION), where c is the right-most sample at
    which any guess reaches |correlation| >= CORR_TRESHOLD. The guess holding the
    most negative correlation inside that window is taken as the recovered byte.
    If no sample reaches the threshold, the whole trace is searched instead.

    Parameters:
    - list_of_correlations: per-byte correlation matrices indexed by byte position
      (e.g. the `correlations` list). If None or empty, each matrix is computed on
      the fly with get_correlations_for_byte().
    - start: first byte position to analyse (inclusive).
    - end: last byte position to analyse (exclusive).
    - graph: if True, plot every guess's correlation curve for each analysed byte.

    Returns:
    - obtained_key_range: uint8 array of length CHACHA_KEY_LEN_BYTES; positions
      outside [start, end) are left at 0.
    """
    obtained_key_range = np.zeros(CHACHA_KEY_LEN_BYTES, dtype=np.ubyte)
    use_precomputed = list_of_correlations is not None and len(list_of_correlations) > 0

    for ANALYZED_BYTE in range(start, end):
        # Full matrix for this byte; it is only read, so no copy is needed.
        if use_precomputed:
            full_corr = list_of_correlations[ANALYZED_BYTE]
        else:
            full_corr = get_correlations_for_byte(ANALYZED_BYTE)
        n_samples = full_corr.shape[1]

        # Right-most sample at which any guess reaches the threshold.
        strong_samples = np.flatnonzero((np.abs(full_corr) >= CORR_TRESHOLD).any(axis=0))
        if strong_samples.size > 0:
            col_index_max = int(strong_samples[-1])
            start_col = max(0, col_index_max - INDENTATION)
            end_col = min(n_samples, col_index_max + INDENTATION)
        else:
            # Without this fallback the window would be centred on -1, i.e. [0, INDENTATION - 1).
            print(f"No |correlation| >= {CORR_TRESHOLD} for byte {ANALYZED_BYTE}; searching all {n_samples} samples")
            start_col, end_col = 0, n_samples

        # Most negative correlation inside the window -> (guess, offset in window).
        window = full_corr[:, start_col:end_col]
        key_guess, offset = np.unravel_index(window.argmin(), window.shape)
        position = start_col + offset  # absolute sample index in the trace
        obtained_key = int(key_guess)
        print(f"Obtained Expanded Key[{ANALYZED_BYTE}] = {hex(obtained_key)} with min. correlation {window[key_guess, offset]} at position {position}")

        if graph:
            plt.figure(figsize=(7, 3))
            # Every other guess in black (one plot call, one line per guess).
            others = np.arange(full_corr.shape[0]) != obtained_key
            plt.plot(full_corr[others].T, 'k')
            plt.plot(full_corr[obtained_key], 'r', label="Obtained key")
            plt.ylim(-1, 1)
            # Highlight the lowest value in yellow
            plt.scatter(position, full_corr[obtained_key, position], color='yellow', zorder=5, label=f'The lowest value in a range ({start_col}, {end_col})')
            plt.ylabel('Correlation')
            plt.xlabel('Trace sample')
            plt.legend(loc='upper left')
            plt.show()

        obtained_key_range[ANALYZED_BYTE] = obtained_key
    return obtained_key_range

In [None]:
# Analyse only the last byte (index 63) and plot its correlation curves.
obtained_key_range = perform_CPA_in_range(correlations, start=63, end=64, graph=True)

In [None]:
# Recover all 64 byte positions without plotting.
obtained_key_range = perform_CPA_in_range(correlations, start=0, end=64)

In [None]:
# Score the recovered bytes against the known keystream.
check_obtained_key(correct_key=keystream, obtained_key_range=obtained_key_range)

## STM CPA

In [None]:
def perform_CPA_in_range_STM(list_of_correlations = None, start = 0, end = CHACHA_KEY_LEN_BYTES, graph = False):
    """
    Perform Correlation Power Analysis (CPA) over the byte positions [start, end), STM variant.

    Unlike perform_CPA_in_range, no window is used: for each byte the guess holding
    the most negative correlation anywhere in the (guess x sample) matrix is taken
    as the recovered byte. Each result is compared with the global `keystream`;
    mismatches are reported and the number of matches is printed at the end.

    Parameters:
    - list_of_correlations: per-byte correlation matrices indexed by byte position
      (e.g. the `correlations` list). If None or empty, each matrix is computed on
      the fly with get_correlations_for_byte().
    - start: first byte position to analyse (inclusive).
    - end: last byte position to analyse (exclusive).
    - graph: if True, plot every guess's correlation curve for each analysed byte.

    Returns:
    - obtained_key_range: uint8 array of length CHACHA_KEY_LEN_BYTES; positions
      outside [start, end) are left at 0.
    """
    obtained_key_range = np.zeros(CHACHA_KEY_LEN_BYTES, dtype=np.ubyte)
    use_precomputed = list_of_correlations is not None and len(list_of_correlations) > 0
    min_guess_correct_cnt = 0

    for ANALYZED_BYTE in range(start, end):
        # Full matrix for this byte; it is only read, so no copy is needed.
        if use_precomputed:
            byte_corr = list_of_correlations[ANALYZED_BYTE]
        else:
            byte_corr = get_correlations_for_byte(ANALYZED_BYTE)

        # Global minimum over all guesses and samples -> (guess, sample).
        key_guess, position = np.unravel_index(byte_corr.argmin(), byte_corr.shape)
        obtained_key = int(key_guess)
        print(f"Obtained Expanded Key [{ANALYZED_BYTE}] = {hex(obtained_key)} with min. correlation {byte_corr[key_guess, position]} at position {position}")

        if keystream[ANALYZED_BYTE] == obtained_key:
            min_guess_correct_cnt += 1
        else:
            print(f"Wrong Key[{ANALYZED_BYTE}]")

        if graph:
            plt.figure(figsize=(7, 3))
            # Every other guess in black (one plot call, one line per guess).
            others = np.arange(byte_corr.shape[0]) != obtained_key
            plt.plot(byte_corr[others].T, 'k')
            plt.plot(byte_corr[obtained_key], 'r', label=f'Obtained key = {hex(obtained_key)}')
            plt.ylim(-1, 1)
            # Highlight the lowest value in yellow
            plt.scatter(position, byte_corr[key_guess, position], color='yellow', zorder=5, label='Lowest value')
            plt.ylabel('Correlation')
            plt.xlabel('Trace sample')
            plt.legend(loc='upper left')
            plt.show()

        obtained_key_range[ANALYZED_BYTE] = obtained_key

    print(f"min_guess_correct_cnt = {min_guess_correct_cnt}")
    return obtained_key_range

In [None]:
# STM variant: global correlation minimum over the whole trace, all 64 bytes.
obtained_key_range = perform_CPA_in_range_STM(correlations, start=0, end=64)

In [None]:
# Score the STM-variant result against the known keystream.
check_obtained_key(correct_key=keystream, obtained_key_range=obtained_key_range)
    

In [None]:
# Analyse only byte 0 with the STM variant and plot its correlation curves.
obtained_key_range = perform_CPA_in_range_STM(correlations, start=0, end=1, graph=True)