In [4]:
import sys
import os
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

from sklearn import preprocessing
import numpy as np
import pickle as pkl
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow import keras
from utils.dataset_utils import batch_generator 
from DANN_trainer import DANN
from utils.LoadData import load_CW_Source,load_CW_Target_validation
from utils.DANN_config import config


In [5]:
# Let TensorFlow allocate GPU memory on demand instead of reserving the whole
# device up front. Memory growth can only be configured before the GPUs are
# initialised; re-running this cell after a model has been built raises
# RuntimeError, which is reported here instead of aborting the notebook.
gpus = tf.config.list_physical_devices(device_type='GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

In [6]:
def run_main():
    """Train a DANN side-channel model with an AES device as source and an SM4 device as target.

    For every attacked key byte this function:
      1. loads source-device traces with ``load_CW_Source`` and target-device
         traces with ``load_CW_Target_validation``,
      2. one-hot encodes the labels into exactly 256 classes,
      3. standardises all traces with a ``StandardScaler`` fitted on the
         source training traces only, so both domains share one normalisation,
      4. reshapes the traces to ``(n, trace_length, 1)`` for the CNN,
      5. builds batch generators and trains a fresh ``DANN`` model whose
         checkpoints go to ``<output_path>/models/byte=<n>``.

    Per the original implementation's note, target labels are intended only
    for validation accuracy, not gradient updates; that contract lives in
    ``DANN.train`` (not visible here) — NOTE(review): confirm there.
    """
    # ---- Hyper-parameters ---------------------------------------------------
    trace_length = 800                       # samples per power trace
    power_input_shape = (trace_length, 1)    # (samples, channels) fed to the CNN
    num_classes = 256                        # one class per possible byte value
    init_learning_rate = 5e-4
    momentum_rate = 0.9
    batch_size = 2000
    half_batch = batch_size // 2             # source and target each fill half a training batch
    epochs = 60
    sec_param = 18000                        # `sec` argument for both loaders (fixed value from the original implementation)
    pre_trained_path = None

    # ---- Paths --------------------------------------------------------------
    # Raw string: in a normal literal "\D" and "\X" are invalid escape sequences
    # (SyntaxWarning on Python 3.12+) and any "\n", "\t", ... would be mangled.
    output_path = r'E:\DL_result\DANN\XMEGA_MDS_CT/'
    profiling_Data_path = '../Dataset/AES_device1/'
    Target_Data_path = '../Dataset/SM4_device1/'

    def _one_hot(labels):
        # Pass num_classes explicitly: otherwise to_categorical infers the width
        # from max(label) + 1, so a split lacking the highest byte values would
        # get fewer than 256 columns and mismatch the other splits / model head.
        return to_categorical(labels, num_classes=num_classes).astype(np.float32)

    for byte_num in range(0, 1):
        print(f"================ Processing Byte {byte_num} ================")

        checkpoints_dir = os.path.abspath(
            os.path.join(output_path, 'models', f'byte={byte_num}')
        )

        sca_config = config(
            pre_trained_path=pre_trained_path,
            checkpoints_dir=checkpoints_dir,
            power_input_shape=power_input_shape,
            trace_length=trace_length,
            init_learning_rate=init_learning_rate,
            momentum_rate=momentum_rate,
            batch_size=batch_size,
            epochs=epochs,
        )

        # ---- Source domain (profiling device) -------------------------------
        # The loader also returns a source validation split; it is not used by
        # the training call below, so it is discarded instead of being processed.
        src_traces_train, src_labels_train, _, _, _, _ = load_CW_Source(
            in_file=profiling_Data_path,
            sec=sec_param,
            byte=byte_num,
        )
        src_labels_train = _one_hot(src_labels_train)

        scaler = preprocessing.StandardScaler()
        src_traces_train = scaler.fit_transform(src_traces_train)

        # ---- Target domain --------------------------------------------------
        tgt_traces_train, tgt_labels_train, tgt_traces_val, tgt_labels_val, _, _ = load_CW_Target_validation(
            in_file=Target_Data_path,
            sec=sec_param,
            byte=byte_num,
        )
        tgt_labels_train = _one_hot(tgt_labels_train)
        tgt_labels_val = _one_hot(tgt_labels_val)
        # Reuse the source-fitted scaler so both domains share one normalisation.
        tgt_traces_train = scaler.transform(tgt_traces_train)
        tgt_traces_val = scaler.transform(tgt_traces_val)

        # ---- CNN input shape: (n_traces, trace_length, 1) --------------------
        src_traces_train = src_traces_train.reshape((-1, *power_input_shape))
        tgt_traces_train = tgt_traces_train.reshape((-1, *power_input_shape))
        tgt_traces_val = tgt_traces_val.reshape((-1, *power_input_shape))

        # ---- Batch generators -----------------------------------------------
        train_source_gen = batch_generator(
            [src_traces_train, src_labels_train],
            batch_size=half_batch,
        )
        train_target_gen = batch_generator(
            [tgt_traces_train, tgt_labels_train],
            batch_size=half_batch,
            shuffle=False,
        )
        val_target_gen = batch_generator(
            [tgt_traces_val, tgt_labels_val],
            batch_size=batch_size,
        )

        # One epoch covers the larger of the two training domains.
        train_iters = max(
            len(src_traces_train) // half_batch,
            len(tgt_traces_train) // half_batch,
        )
        # NOTE(review): this is 0 when there are fewer than batch_size target
        # validation traces; confirm DANN.train handles that case.
        val_iters = len(tgt_traces_val) // batch_size

        # Start each byte from a clean Keras graph/session.
        tf.keras.backend.clear_session()

        sca_model = DANN(sca_config)
        sca_model.train(
            train_source_gen,
            train_target_gen,
            val_target_gen,
            train_iters,
            val_iters,
        )

In [1]:
# Entry point. Inside a notebook kernel __name__ is always '__main__',
# so executing this cell starts training.
if '__main__' == __name__:
    run_main()