 ## Importar paquetes
 
`trsfile` lee e interpreta el archivo `.trs` que contiene las trazas.

In [None]:
import trsfile
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tnrange, tqdm

In [None]:
# Show the installed trsfile version (useful when reproducing results).
print (trsfile.__version__)

## Definir constantes

In [None]:
# Open the trace set read-only. The handle is intentionally left open (no
# `with` block) because later cells index `trs_dataset` directly.
trs_dataset_path = r'trace_dataset.trs'
trs_dataset = trsfile.open(trs_dataset_path, mode='r')
# NOTE(review): not referenced by any of the visible cells; confirm it is still needed.
plot_size = {'width':15, 'height':4}

# AES forward S-box (SubBytes lookup table), indexed by byte value 0x00-0xFF.
# Used by the key-schedule functions and the S-box correlation helpers.
AES_Sbox = np.array([
        0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
        0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
        0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
        0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
        0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
        0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
        0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
        0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
        0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
        0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
        0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
        0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
        0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
        0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
        0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
        0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
        ], dtype=np.uint8)

# AES inverse S-box: INV_SBOX[AES_Sbox[x]] == x. The DPA cell uses it to
# predict the state byte before the last-round SubBytes from a ciphertext
# byte XORed with a key-byte guess.
INV_SBOX = np.array([
    0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
    0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
    0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
    0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
    0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
    0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
    0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
    0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
    0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
    0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
    0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
    0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
    0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
    0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
    0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D
    ], dtype=np.uint8)

# Key-schedule round constants: RCON[r] is XORed into the first byte of the
# rotated/substituted word for round r. Index 0 is a placeholder; the key
# schedule functions below only index rounds 1..n_rounds.
RCON = (
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
    0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
    0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
    0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39,
)

## Características del dataset

In [None]:
# Summarize the dataset: number of traces, samples per trace, and the raw
# metadata bytes of the first two traces. Bytes 0-15 are reported as the
# plaintext and bytes 16 onward as the ciphertext.
data_sample = np.frombuffer(trs_dataset[0].data, dtype=np.uint8)
data_sample2 = np.frombuffer(trs_dataset[1].data, dtype=np.uint8)

for caption, content in (
    ('Número de trazas:', len(trs_dataset)),
    ('Número de muestras (puntos por traza): ', len(trs_dataset[0])),
    ('Longitud de la metadata:', len(data_sample)),
    ('Metadata de la primera traza:', data_sample),
    ('Metadata de la segunda traza:', data_sample2),
    ('plaintext de la primera traza:', data_sample[0:16]),
    ('plaintext de la segunda traza:', data_sample2[0:16]),
    ('ciphertext de la primera traza:', data_sample[16:]),
    ('ciphertext de la segunda traza:', data_sample2[16:]),
):
    print(caption, content)
    print('-----------------------------------------------------------')

## Plot de demostración (traza recopilada durante el proceso de cifrado)

In [None]:
# Apply the project's matplotlib style sheet, then plot the raw samples of the
# first trace and release the figure afterwards.
plt.style.use('./plot_styles/pltstyle.mplstyle')
plt.plot(trs_dataset[0])
plt.show()
plt.close()

## Funciones de [pre-]procesamiento

In [None]:
##--
##
##---------------------------------------------------------------------------------------------------
def compute_mean_std_data_byte(trs_file, number_traces, plaintext_byte):
    """Running mean/variance of the trace samples and of one metadata byte.

    Feeds the first ``number_traces`` traces of ``trs_file``, one row per
    trace, into two incremental ``StandardScaler`` objects:

    * one over all samples of the trace (one feature per sample point);
    * one over the single value ``trs_file[i].data[plaintext_byte]``.

    Returns
    -------
    tuple
        ``([samples_mean, samples_var], [byte_mean, byte_var])``, taken from
        each scaler's ``mean_`` and ``var_`` attributes.
    """
    byte_stats = StandardScaler()
    sample_stats = StandardScaler()
    label = '[INFO]: computing mean and std (byte pos: {})'.format(plaintext_byte)

    for trace_idx in tnrange(number_traces, desc=label):
        trace = trs_file[trace_idx]
        # Each trace contributes exactly one observation (row) to each scaler.
        byte_row = np.array(trace.data[plaintext_byte]).reshape(1, -1)
        byte_stats.partial_fit(byte_row)
        sample_row = np.array(trace).reshape(1, -1)
        sample_stats.partial_fit(sample_row)

    return (
        [sample_stats.mean_, sample_stats.var_],
        [byte_stats.mean_, byte_stats.var_],
    )

##--
##
##---------------------------------------------------------------------------------------------------
def compute_corr(trs_file, number_traces, plaintext_byte):
    """Pearson correlation between every trace sample and one metadata byte.

    Parameters
    ----------
    trs_file :
        Opened trace set; ``trs_file[i]`` is a trace whose samples can be
        sliced, and ``trs_file[i].data`` holds that trace's metadata bytes.
    number_traces : int
        Number of leading traces to use.
    plaintext_byte : int
        Index into ``.data`` of the byte to correlate with.

    Returns
    -------
    numpy.ndarray
        float64 array with one correlation coefficient per sample point.
        Sample points with zero standard deviation get 0. If the byte (or
        every sample) is constant, a warning is printed and all zeros are
        returned.
    """
    n_samples = len(trs_file[0])
    samples_corr = np.zeros(shape=(n_samples,), dtype=np.float64)

    (samples_mean, samples_var), (metadata_mean, metadata_var) = compute_mean_std_data_byte(
        trs_file, number_traces, plaintext_byte)
    samples_std = np.sqrt(samples_var)
    metadata_std = np.sqrt(metadata_var)

    # Accumulate sum((x - mean_x) * (y - mean_y)) over the traces.
    for i in tnrange(number_traces, desc='[INFO]: computing correlation (byte pos: {})'.format(plaintext_byte)):
        trace = trs_file[i]
        samples = np.asarray(trace[:], dtype=np.float64)
        samples_corr += (samples - samples_mean) * (trace.data[plaintext_byte] - metadata_mean)

    if np.count_nonzero(metadata_std) == 0 or np.count_nonzero(samples_std) == 0:
        print ('[WARNING]: Metadata or samples standard deviation of data byte {} is zero'.format(plaintext_byte))
        print ('[INFO]: Returning zero correlation')
        return np.zeros(shape=(n_samples,), dtype=np.float64)

    # StandardScaler.var_ is the population variance (ddof=0), so the matching
    # Pearson normalisation is N * std_x * std_y (not N - 1).
    denominator = number_traces * samples_std * metadata_std
    # Sample points that never vary have no defined correlation; report 0
    # instead of inf/nan.
    return np.divide(samples_corr, denominator,
                     out=np.zeros_like(samples_corr), where=denominator != 0)

##--
##
##---------------------------------------------------------------------------------------------------
def compute_mean_std_sbox(trs_file, number_traces, plaintext_byte, key_byte):
    """Running mean/variance of the trace samples and of an S-box output.

    The modelled value for trace ``i`` is
    ``AES_Sbox[data[plaintext_byte] ^ data[key_byte]]`` with
    ``data = trs_file[i].data`` (the per-trace metadata bytes).

    Returns
    -------
    tuple
        ``([samples_mean, samples_var], [sbox_mean, sbox_var])``, taken from
        the ``mean_``/``var_`` attributes of two incremental StandardScalers.
    """
    trace_mean_std  = StandardScaler()
    aes_sbox_scaler = StandardScaler()

    for i in tnrange(number_traces, desc='[INFO]: mean and std sbox(pt pos: {}, key pos: {})'.format(plaintext_byte, key_byte)):
        trace = trs_file[i]
        # The S-box input bytes come from the trace metadata, not from the
        # power samples (indexing the trace itself returns sample values).
        meta = trace.data
        sbox_out = AES_Sbox[meta[plaintext_byte] ^ meta[key_byte]]
        aes_sbox_scaler.partial_fit(np.array(sbox_out).reshape(1, -1))
        trace_mean_std.partial_fit(np.array(trace).reshape(1, -1))

    return ([trace_mean_std.mean_, trace_mean_std.var_], [aes_sbox_scaler.mean_, aes_sbox_scaler.var_])

##--
##
##---------------------------------------------------------------------------------------------------
def compute_corr_sbox(trs_file, number_traces, plaintext_byte, key_byte):
    """Pearson correlation between every trace sample and an AES S-box output.

    The leakage model for trace ``i`` is
    ``AES_Sbox[data[plaintext_byte] ^ data[key_byte]]`` where
    ``data = trs_file[i].data`` (the per-trace metadata bytes).

    Parameters
    ----------
    trs_file :
        Opened trace set (samples sliceable, metadata in ``.data``).
    number_traces : int
        Number of leading traces to use.
    plaintext_byte, key_byte : int
        Indices into ``.data`` of the two bytes XORed before the S-box.

    Returns
    -------
    numpy.ndarray
        float64 array with one correlation coefficient per sample point.
        Constant sample points get 0. If the model (or every sample) is
        constant, a warning is printed and all zeros are returned.
    """
    n_samples = len(trs_file[0])
    samples_corr = np.zeros(shape=(n_samples,), dtype=np.float64)

    (samples_mean, samples_var), (metadata_mean, metadata_var) = compute_mean_std_sbox(
        trs_file, number_traces, plaintext_byte, key_byte)
    samples_std = np.sqrt(samples_var)
    metadata_std = np.sqrt(metadata_var)

    for i in tnrange(number_traces, desc='[INFO]: computing corr sbox(pt: {},key: {})'.format(plaintext_byte, key_byte)):
        trace = trs_file[i]
        # Read the S-box inputs from the metadata bytes, not the samples.
        meta = trace.data
        sbox_out = AES_Sbox[meta[plaintext_byte] ^ meta[key_byte]]
        samples = np.asarray(trace[:], dtype=np.float64)
        samples_corr += (samples - samples_mean) * (sbox_out - metadata_mean)

    if np.count_nonzero(metadata_std) == 0 or np.count_nonzero(samples_std) == 0:
        print ('[WARNING]: Metadata or samples standard deviation of AES Sbox plaintext {} and key {} is zero'.format(plaintext_byte, key_byte))
        print ('[INFO]: Returning zero correlation')
        return np.zeros(shape=(n_samples,), dtype=np.float64)

    # StandardScaler.var_ is the population variance (ddof=0), so Pearson's r
    # is the summed cross-product divided by N * std_x * std_y.
    denominator = number_traces * samples_std * metadata_std
    return np.divide(samples_corr, denominator,
                     out=np.zeros_like(samples_corr), where=denominator != 0)

## Análisis de correlación

In [None]:
# Correlate the first 2000 traces with metadata byte 16, which the
# dataset-inspection cell labels as the first ciphertext byte.
corr = compute_corr(trs_dataset, 2000, 16)

In [None]:
# Plot the magnitude of the per-sample correlation stored in `corr`.
plt.style.use('./plot_styles/pltstyle.mplstyle')
plt.plot(np.abs(corr))
plt.show()
plt.close()

## Realizar ataque utilizando Differential Power Analysis (DPA)

### Funciones para el key schedule (expansión directa e inversa / back-scheduling)

In [None]:
def forward_key_schedule(key, n_rounds):
    """Expand an AES-128 key into ``n_rounds + 1`` round keys.

    ``key`` is a sequence of 16 byte values. The result is a flat list of
    ``16 * (n_rounds + 1)`` byte values; round key ``r`` is the slice
    ``[16 * r : 16 * (r + 1)]``, with round 0 being the input key.
    """
    expanded = list(key)
    for word_idx in range(4, 4 * (n_rounds + 1)):
        # Previous 4-byte word w[word_idx - 1].
        t0, t1, t2, t3 = expanded[4 * word_idx - 4:4 * word_idx]
        if not word_idx % 4:
            # RotWord + SubWord, then XOR the round constant into byte 0.
            t0, t1, t2, t3 = (AES_Sbox[t] for t in (t1, t2, t3, t0))
            t0 ^= RCON[word_idx // 4]
        # Word four positions back, w[word_idx - 4].
        w0, w1, w2, w3 = expanded[4 * word_idx - 16:4 * word_idx - 12]
        expanded += [t0 ^ w0, t1 ^ w1, t2 ^ w2, t3 ^ w3]
    return expanded

def backward_key_schedule(last_round_key, n_rounds):
    """Invert the AES-128 key schedule starting from the last round key.

    ``last_round_key`` holds the 16 bytes of round key ``n_rounds``. Each step
    derives the preceding round key and prepends it. The result is a flat list
    with round 0 (the original cipher key) in the first 16 entries and the
    given round key in the last 16.
    """
    schedule = list(last_round_key)
    for rnd in range(n_rounds, 0, -1):
        current = schedule[:16]
        # Words 1..3 of the previous key: w_prev[j] = w[j] ^ w[j - 1],
        # giving bytes 4..15 of the previous round key.
        tail = [current[idx] ^ current[idx - 4] for idx in range(4, 16)]
        # g(w_prev[3]): RotWord, SubWord and the round constant.
        last_word = tail[8:12]
        transformed = [AES_Sbox[b] for b in last_word[1:] + last_word[:1]]
        transformed[0] = transformed[0] ^ RCON[rnd]
        # Word 0 of the previous key: w_prev[0] = g(w_prev[3]) ^ w[0].
        head = [g ^ current[idx] for idx, g in enumerate(transformed)]
        schedule = head + tail + schedule
    return schedule

### Algoritmo del DPA

In [None]:
# Difference-of-means DPA on the last AES-128 round.
#
# For each key-byte position, each key guess and each bit of the predicted
# intermediate INV_SBOX[ciphertext_byte ^ key_guess], the traces are split
# into two groups by that bit. The guess whose group means differ the most,
# at any sample point, wins. One candidate last-round key is produced per
# selection bit and is then inverted through the key schedule.
#
# The prediction depends only on the ciphertext-byte value, so the traces are
# first summed per ciphertext-byte value (256 classes). Every (guess, bit)
# group sum is a 0/1-weighted sum of those class sums. As a result the trace
# file is read once per byte position instead of once per (bit, byte, guess).
# The results are mathematically identical to grouping the traces directly.
databyte_pos_init = 16
databyte_pos_end  = 32
data_length       = 16  # AES-128 block size in bytes
used_points_init  = 0
used_points_end   = 29000
verbose           = True

number_of_bits = 8

ciphertext_array = np.empty(shape=(len(trs_dataset), data_length), dtype=np.uint8)
interval = [0, ciphertext_array.shape[0]]

for i in tnrange(interval[0], interval[1], desc='[INFO]: Getting ciphertext'):
    ciphertext_array[i] = np.frombuffer(trs_dataset[i].data[databyte_pos_init:databyte_pos_end], dtype=np.uint8)

n_points = used_points_end - used_points_init
n_traces = interval[1] - interval[0]

# hypothesis[g, v]: predicted intermediate value for key guess g when the
# ciphertext byte equals v.
all_bytes = np.arange(256, dtype=np.uint8)
hypothesis = INV_SBOX[np.bitwise_xor.outer(all_bytes, all_bytes)]

# recovered_keys[bit, byte_pos]: winning key guess when selecting on `bit`.
recovered_keys = np.zeros((number_of_bits, data_length), dtype=np.uint8)

for byte_pos in tnrange(data_length, desc="[INFO]: Iterando por los bytes de la posible llave"):
    # Sum and count the traces for every value of this ciphertext byte.
    class_sums   = np.zeros((256, n_points), dtype=np.float64)
    class_counts = np.zeros(256, dtype=np.float64)
    for trace_index in range(interval[0], interval[1]):
        value = ciphertext_array[trace_index][byte_pos]
        class_sums[value]   += trs_dataset[trace_index][used_points_init:used_points_end]
        class_counts[value] += 1
    total_sum = class_sums.sum(axis=0)

    for bit in range(number_of_bits):
        # selection[g, v] == 1 when the target bit is set for guess g and value v.
        selection  = ((hypothesis >> bit) & 1).astype(np.float64)
        one_count  = selection @ class_counts   # traces per guess with bit == 1
        zero_count = n_traces - one_count
        one_sum    = selection @ class_sums     # (256 guesses, n_points)
        zero_sum   = total_sum - one_sum

        # A guess that leaves a group empty has no defined difference of means.
        # Score it 0 instead of letting NaN/inf win the argmax.
        valid = (one_count > 0) & (zero_count > 0)
        one_sum  /= np.where(valid, one_count, 1)[:, None]    # -> group means (in place)
        zero_sum /= np.where(valid, zero_count, 1)[:, None]
        np.subtract(one_sum, zero_sum, out=one_sum)
        np.abs(one_sum, out=one_sum)
        delta = np.where(valid, one_sum.max(axis=1), 0.0)

        recovered_keys[bit, byte_pos] = delta.argmax()

for bit in range(number_of_bits):
    recovered_key = recovered_keys[bit].tolist()
    print("Round key:", bytes(recovered_key[:16]).hex())
    # The recovered bytes form the round-10 key; walk the schedule back to round 0.
    round_keys = backward_key_schedule(recovered_key, n_rounds=10)
    print("Possible key:", bytes(round_keys[:16]).hex())


## Probando la llave

In [None]:
from Crypto.Cipher import AES

# Paste the hex-encoded AES-128 key recovered by the DPA cell here.
key = bytes.fromhex("<<Insertar la clave>>")


aes128 = AES.new(key, AES.MODE_ECB)
verbose = True
count = 0
matches = 0
# `total=` must be passed by keyword: tqdm's first positional argument is the
# iterable, so passing an int there leaves the bar without a total.
with tqdm(total=len(trs_dataset), desc="[INFO]: Probando la llave contra todos los plaintext") as bar:
    for aTrace in trs_dataset:
        plaintext = aTrace.data[:16]
        ciphertext = aes128.encrypt(plaintext)
        # Metadata bytes 16..31 hold the ciphertext (same range as the DPA cell).
        ciphertext_from_trace = aTrace.data[16:32]
        # Check every trace, not only the first one.
        if ciphertext == ciphertext_from_trace:
            matches += 1
        if verbose and count == 0:
            print("[INFO]: Plaintext", plaintext.hex())
            print("[INFO]: Ciphertext (plaintext encriptado por el algoritmo):")
            print(ciphertext.hex())
            print("[INFO]: ciphertext contenido en la información de la traza:")
            print(ciphertext_from_trace.hex())
            count += 1
        bar.update()

print("[INFO]: {} / {} trazas coinciden".format(matches, len(trs_dataset)))
if len(trs_dataset) > 0 and matches == len(trs_dataset):
    print("[INFO]: CONTRASEÑA DESCUBIERTA!!")
else:
    print("[INFO]: TU NO PASARAS!!")