# Exp. 3: Train an MLP on simulated data

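- Dataset: ASCADv1 simulation dataset  
    The plaintexts/ciphertexts, keys, and masks are the same as in the ASCADv1 fixed-key dataset.  
    Each sample point is the Hamming weight (HW) of the intermediate value leaked at the corresponding $snr$ point (see the table below).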
- PoIs: 1 to 4 points, i.e. every combination of $snr_2$ to $snr_5$
- #traces: 50,000 (profiling), 20,000 (attack)
- Preprocessing: min-max normalization to [0.0, 1.0]
- Labeling: LSB of the unmasked S-box output (binary classification)
- Loss function: binary cross entropy with logits (torch.nn.BCEWithLogitsLoss)
- Model architecture: [1-4 (input), 70 (ReLU), 50 (ReLU), 2 (Softmax)]  
    This is the architecture used in the DDLA paper (Timon).
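To make the layer sizes above concrete, here is a minimal NumPy sketch of the forward pass of an [n_in, 70, 50, 2] network with ReLU hidden layers and a softmax output. It is not the project's `MLP_DDLAsim` model: the function names and the He-style random initialization are hypothetical choices for illustration. With 4 inputs the network has 4,002 parameters, and with 1 input it has 3,792.

```python
import numpy as np

def init_mlp(n_in, hidden=(70, 50), n_out=2, seed=0):
    """Random weights for an [n_in, 70, 50, 2] MLP (illustrative init only)."""
    rng = np.random.default_rng(seed)
    sizes = [n_in, *hidden, n_out]
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out)
        params.append((w, b))
    return params

def forward(params, x):
    """ReLU on the hidden layers, softmax on the 2-unit output layer."""
    h = x
    for w, b in params[:-1]:
        h = np.maximum(h @ w + b, 0.0)
    w, b = params[-1]
    logits = h @ w + b
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

def n_params(params):
    """Total number of weights and biases."""
    return sum(w.size + b.size for w, b in params)

params = init_mlp(n_in=4)
probs = forward(params, np.random.default_rng(1).random((8, 4)))
print(probs.shape, n_params(params))  # (8, 2) 4002
```

Each row of `probs` is a probability pair over the two LSB classes; the predicted class is its `argmax`, as in the accuracy computation of `train` below.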

| $snr$ # | Leakage                             | Intermediate value                      |
| ---     | ---                                 | ---                                     |
| 2       | Masked S-box output                 | $SBox(p[3] \oplus k[3]) \oplus r_{out}$ |
| 3       | Mask of S-box output                | $r_{out}$                               |
| 4       | Masked S-box output in linear part  | $SBox(p[3] \oplus k[3]) \oplus r[3]$    |
| 5       | Mask of S-box output in linear part | $r[3]$                                  |
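The table, together with the preprocessing and labeling steps listed earlier, can be turned into a small simulation sketch. It assumes noise-free HW leakage and orders the four points as [snr2, snr3, snr4, snr5]. Both are assumptions, since the exact construction of the simulated dataset is not shown here, and all function names are hypothetical. The AES S-box is computed from its definition (GF(2^8) inverse followed by the affine map) to keep the block self-contained.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi = a & 0x80
        a = (a << 1) & 0xFF
        if hi:
            a ^= 0x1B
        b >>= 1
    return p

def gf_inv(a):
    """Multiplicative inverse in GF(2^8), computed as a^254 (0 maps to 0)."""
    if a == 0:
        return 0
    result, base, e = 1, a, 254
    while e:
        if e & 1:
            result = gf_mul(result, base)
        base = gf_mul(base, base)
        e >>= 1
    return result

def rotl8(x, n):
    """Rotate a byte left by n bits."""
    return ((x << n) | (x >> (8 - n))) & 0xFF

def sbox(x):
    """AES S-box: field inverse followed by the affine transform."""
    b = gf_inv(x)
    return b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63

def hw(x):
    """Hamming weight of a byte."""
    return bin(x).count("1")

def simulate_sample(p, k, r_out, r3):
    """Noise-free leakage points [snr2, snr3, snr4, snr5] for one trace."""
    z = sbox(p ^ k)
    return [hw(z ^ r_out), hw(r_out), hw(z ^ r3), hw(r3)]

def lsb_label(p, k):
    """Binary label: LSB of the unmasked S-box output."""
    return sbox(p ^ k) & 1

def min_max(column):
    """Min-max normalize one PoI column to [0.0, 1.0]."""
    lo, hi = min(column), max(column)
    return [(v - lo) / (hi - lo) if hi > lo else 0.0 for v in column]

print(simulate_sample(p=0x00, k=0x00, r_out=0x00, r3=0xFF))  # [4, 0, 4, 8]
print(lsb_label(0x00, 0x00))  # 1
```

`min_max` is meant to be applied per PoI column over a set of traces. Because $r_{out}$ and $r[3]$ are uniformly random, each masked value on its own is independent of the unmasked S-box output.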

In [1]:
import torch
from pathlib import Path
import os
import hydra
import numpy as np
import matplotlib.pyplot as plt
import pickle

os.chdir('/workspace')
import src

In [2]:
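def get_cfg(pois, output_size, name):
    """Compose the Hydra config for one PoI combination.

    pois: PoI list as a string, e.g. "[[1,2,1],[3,4,1]]"
    output_size: number of selected PoIs (the MLP input size), as a string
    name: suffix appended to the model name, e.g. "24"
    """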
    with hydra.initialize(config_path="../conf", version_base='1.1'):
        cfg = hydra.compose(
            config_name='config',
            overrides=[
                "model=MLP_DDLAsim",
                "model.train_params.steps=2000",
                "model.name=MLP_DDLAsim"+name,
                "model.train_params.batch=1000",
                "dataset@train=ASCADf_profiling",
                "dataset@test=ASCADf_attack",
                "trace_transforms=set_poi",
                "trace_transforms.transforms.0.pois="+pois,
                "trace_transforms.output_size="+output_size,
                "label_transforms=bit",
                "label_transforms.transforms.3.pos=0",
                "save_path=/workspace/notebook/results/exp3/3",
                "n_attack_traces=2000"
                ]
            )
    return cfg
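The `pois` override is a list of `[v, v+1, 1]` triples, and the run names map index `v` to $snr_{v+1}$. Here is a minimal sketch of how such a list could select points from a trace. It assumes the `set_poi` transform treats each triple as a Python slice `start:stop:step` and concatenates the results. That is an assumption, since the transform's code is not shown, and `select_pois` is a hypothetical stand-in.

```python
import json
from itertools import combinations

def select_pois(trace, pois):
    """Hypothetical stand-in for set_poi: concatenate trace[start:stop:step] per triple."""
    out = []
    for start, stop, step in pois:
        out.extend(trace[start:stop:step])
    return out

# Index 0 is not used in this experiment; indices 1-4 hold snr2-snr5.
trace = ["idx0", "snr2", "snr3", "snr4", "snr5"]
override = "[[1,2,1],[3,4,1]]"  # value passed as trace_transforms.transforms.0.pois
print(select_pois(trace, json.loads(override)))  # ['snr2', 'snr4']

# Every non-empty subset of {1, 2, 3, 4}: 4 + 6 + 4 + 1 = 15 runs.
combos = [c for i in range(1, 5) for c in combinations(range(1, 5), i)]
print(len(combos))  # 15
```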

In [3]:
def train(cfg):
    """Build the simulated datasets, train (or load) the MLP, and print the test accuracy."""
    device = hydra.utils.instantiate(cfg.device)
    cpu = torch.device('cpu')

    # Dataset
    ds_prof = hydra.utils.instantiate(cfg.train.dataset)
    ds_test = hydra.utils.instantiate(cfg.test.dataset)

    target_byte = 2  # 0-indexed: the third byte, written p[3]/k[3] in the leakage table
    profiling_dataset = src.datasets.ASCAD_sim.Dataset(
        ds_prof.plaintext,
        ds_prof.key,
        ds_prof.masks,
        target_byte=target_byte,
        trace_transforms=hydra.utils.instantiate(
            cfg.train.dataset.trace_transforms),
        label_transforms=hydra.utils.instantiate(
            cfg.train.dataset.label_transforms)
        )

    test_dataset = src.datasets.ASCAD_sim.Dataset(
        ds_test.plaintext,
        ds_test.key,
        ds_test.masks,
        target_byte=target_byte,
        trace_transforms=hydra.utils.instantiate(
            cfg.train.dataset.trace_transforms),
        label_transforms=hydra.utils.instantiate(
            cfg.train.dataset.label_transforms)
        )

    train_dataloader = torch.utils.data.DataLoader(
        profiling_dataset, batch_size=cfg.train.batch_size, shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=cfg.test.batch_size, shuffle=False
    )

    # Train, or load the cached weights if this configuration was already trained.
    # The same checkpoint path is used for the existence check, saving, and loading.
    model = hydra.utils.instantiate(cfg.model.model)
    ckpt_path = Path(cfg.save_path, cfg.model_name + '.pt')
    if not ckpt_path.exists():
        model = model.to(device)
        train_kwargs = hydra.utils.instantiate(cfg.model.train_params)
        # opt is a callable that still needs the model parameters
        train_kwargs.opt = train_kwargs.opt(model.parameters())
        _ = src.trainDNN.train(
            model,
            train_dataloader,
            test_dataloader,
            **train_kwargs
            )
        # Save from the CPU so the checkpoint loads on any device
        model = model.to(cpu)
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        torch.save(model.state_dict(), ckpt_path)
    else:
        model.load_state_dict(torch.load(ckpt_path, map_location=cpu))
    model = model.to(device)

    # Test
    preds, labels, th = src.utils.make_prediction(
        model, test_dataloader, device,
        cfg.label_transforms.one_hot)
    preds_class = np.argmax(preds, axis=1)
    accuracy = np.mean(labels == preds_class)
    print('Accuracy: ', accuracy)

In [4]:
import itertools

# Train one model per non-empty subset of PoI indices {1, 2, 3, 4}.
# Index v corresponds to snr(v+1), so the run names use snr2-snr5.
for i in range(1, 5):
    for p in itertools.combinations(range(1, 5), i):
        pois = ",".join(f"[{v},{v+1},1]" for v in p)
        name = "".join(str(v + 1) for v in p)
        cfg = get_cfg(f"[{pois}]", f"{len(p)}", name)
        print(f"POIs: [{pois}]")
        train(cfg)

POIs: [[1,2,1]]


100%|██████████| 2000/2000 [10:37<00:00,  3.14it/s]


Accuracy:  0.5031
POIs: [[2,3,1]]


100%|██████████| 2000/2000 [10:32<00:00,  3.16it/s]


Accuracy:  0.499
POIs: [[3,4,1]]


100%|██████████| 2000/2000 [10:22<00:00,  3.21it/s]


Accuracy:  0.4989
POIs: [[4,5,1]]


100%|██████████| 2000/2000 [10:24<00:00,  3.20it/s]


Accuracy:  0.4949
POIs: [[1,2,1],[2,3,1]]


100%|██████████| 2000/2000 [13:52<00:00,  2.40it/s]


Accuracy:  0.5404
POIs: [[1,2,1],[3,4,1]]


100%|██████████| 2000/2000 [13:41<00:00,  2.43it/s]


Accuracy:  0.5031
POIs: [[1,2,1],[4,5,1]]


100%|██████████| 2000/2000 [13:35<00:00,  2.45it/s]


Accuracy:  0.4939
POIs: [[2,3,1],[3,4,1]]


100%|██████████| 2000/2000 [13:22<00:00,  2.49it/s]


Accuracy:  0.5056
POIs: [[2,3,1],[4,5,1]]


100%|██████████| 2000/2000 [13:31<00:00,  2.47it/s]


Accuracy:  0.4955
POIs: [[3,4,1],[4,5,1]]


100%|██████████| 2000/2000 [13:34<00:00,  2.46it/s]


Accuracy:  0.5373
POIs: [[1,2,1],[2,3,1],[3,4,1]]


100%|██████████| 2000/2000 [16:43<00:00,  1.99it/s]


Accuracy:  0.5414
POIs: [[1,2,1],[2,3,1],[4,5,1]]


100%|██████████| 2000/2000 [16:53<00:00,  1.97it/s]


Accuracy:  0.5417
POIs: [[1,2,1],[3,4,1],[4,5,1]]


100%|██████████| 2000/2000 [16:57<00:00,  1.97it/s]


Accuracy:  0.5347
POIs: [[2,3,1],[3,4,1],[4,5,1]]


100%|██████████| 2000/2000 [16:59<00:00,  1.96it/s]


Accuracy:  0.5381
POIs: [[1,2,1],[2,3,1],[3,4,1],[4,5,1]]


100%|██████████| 2000/2000 [19:43<00:00,  1.69it/s]


Accuracy:  0.5627


Simulation results (test accuracy per PoI combination)

| PoIs ($snr$ #) | Accuracy |
| ---            | ---      |
| 2       | 0.5031   |
| 3       | 0.4990   |
| 4       | 0.4989   |
| 5       | 0.4949   |
| 2,3     | 0.5404   |
| 2,4     | 0.5031   |
| 2,5     | 0.4939   |
| 3,4     | 0.5056   |
| 3,5     | 0.4955   |
| 4,5     | 0.5373   |
| 2,3,4   | 0.5414   |
| 2,3,5   | 0.5417   |
| 2,4,5   | 0.5347   |
| 3,4,5   | 0.5381   |
| 2,3,4,5 | 0.5627   |
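One way to read the table is to compare each accuracy with the spread expected from random guessing, and to check which PoI sets contain both a masked value and its mask (snr2 with snr3, or snr4 with snr5, per the leakage table). The sketch below does both using only the accuracies reported above. The number of test traces is an assumption: the dataset description gives 20,000 attack traces, while the config sets `n_attack_traces=2000`. The classification comes out the same for either value.

```python
import math

# Accuracies copied from the table above, keyed by the snr indices used as PoIs.
results = {
    (2,): 0.5031, (3,): 0.4990, (4,): 0.4989, (5,): 0.4949,
    (2, 3): 0.5404, (2, 4): 0.5031, (2, 5): 0.4939,
    (3, 4): 0.5056, (3, 5): 0.4955, (4, 5): 0.5373,
    (2, 3, 4): 0.5414, (2, 3, 5): 0.5417, (2, 4, 5): 0.5347,
    (3, 4, 5): 0.5381, (2, 3, 4, 5): 0.5627,
}

# (masked value, mask) pairs from the leakage table.
MASK_PAIRS = [(2, 3), (4, 5)]

def chance_band(n, z=1.96):
    """Half-width of a ~95% normal-approximation interval around 0.5 for n binary guesses."""
    return z * math.sqrt(0.25 / n)

def has_mask_pair(pois):
    """True if the PoI set contains both shares of at least one masked value."""
    return any(a in pois and b in pois for a, b in MASK_PAIRS)

N_TEST = 20_000  # assumption: accuracy computed over the full attack set
band = chance_band(N_TEST)
for pois, acc in results.items():
    verdict = "above chance" if acc > 0.5 + band else "chance level"
    print(pois, acc, verdict, "pair" if has_mask_pair(pois) else "no pair")
```

With these numbers, exactly the PoI sets that contain a full (masked value, mask) pair fall above the chance band. This is consistent with each share of a Boolean mask being individually independent of the unmasked S-box output.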