# Exp.3 Train KAN with simulation data

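- Dataset: ASCADv1 simulation dataset  
    Plaintexts/ciphertexts, keys, and masks are the same as in the ASCADv1 fixed-key dataset.  
    Each sample point is derived from the Hamming weight, HW(·), of the intermediate value at the corresponding $snr$ point (see the table below).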
- PoI: 1~4 pts, combination of $snr2$ to $snr5$
- #traces: 50,000 (profiling), 20,000 (attack)
- Preprocess: min-max normalization to [0.0, 1.0]
- Labeling: LSB of the unmasked Sbox output (two-class classification)
- Loss function: Softmax cross entropy (torch.nn.BCEWithLogitsLoss)
- Model architecture: [1~4(Input), 5, 1, 2(Softmax)]  

| $snr$ # | Leakage                            | Intermediate value                      |
| ---     | ---                                | ---                                     |
| 2       | Masked Sbox Output                 | $SBox(p[3] \oplus k[3]) \oplus r_{out}$ |
| 3       | Mask of Sbox output                | $r_{out}$                               |
| 4       | Masked Sbox output in linear part  | $SBox(p[3] \oplus k[3]) \oplus r[3]$    |
| 5       | Mask of Sbox output in linear part | $r[3]$                                  |

In [1]:
import torch
from pathlib import Path
import os
import hydra
import numpy as np
import matplotlib.pyplot as plt
import pickle

os.chdir('/workspace')
import src

In [2]:
def get_cfg(pois, output_size, name,
            save_path="/workspace/notebook/results/exp3/1"):
    """Compose the Hydra config for one KAN run on the simulation data.

    Args:
        pois: Hydra-syntax list of PoI slices, e.g. ``"[[1,2,1],[2,3,1]]"``;
            passed verbatim to ``trace_transforms.transforms.0.pois``.
        output_size: Number of selected PoIs, as an int or its string form;
            passed to ``trace_transforms.output_size``.
        name: Suffix appended to ``"KAN1h_"`` to form ``model.name``.
        save_path: Directory where checkpoints are written/looked up.
            Defaults to the path previously hard-coded here.

    Returns:
        The composed Hydra config.
    """
    # f-strings (rather than ``+``) so int arguments such as output_size=3
    # are accepted as well as strings.
    overrides = [
        "model=KAN2h",
        # KAN2h with hidden widths 5 and 1.
        "model.model.width.1=5",
        "model.model.width.2=1",
        "model.train_params.steps=7000",
        f"model.name=KAN1h_{name}",
        "dataset@train=ASCADf_profiling",
        "dataset@test=ASCADf_attack",
        "trace_transforms=set_poi",
        f"trace_transforms.transforms.0.pois={pois}",
        f"trace_transforms.output_size={output_size}",
        "label_transforms=bit",
        # Select bit 0 (LSB) of the label.
        "label_transforms.transforms.3.pos=0",
        f"save_path={save_path}",
        "n_attack_traces=2000",
    ]
    with hydra.initialize(config_path="../conf", version_base='1.1'):
        cfg = hydra.compose(config_name='config', overrides=overrides)
    return cfg

In [3]:
def train(cfg, target_byte=2):
    """Train (or load) a KAN on the ASCAD simulation data and report accuracy.

    Simulated profiling/attack datasets are built from the plaintexts, keys
    and masks of the datasets configured under ``cfg.train`` / ``cfg.test``.
    The model from ``cfg.model.model`` is trained unless a checkpoint named
    ``cfg.model_name + '.ckpt'`` already exists in ``cfg.save_path``, in which
    case that checkpoint is loaded instead. The model is then evaluated on the
    attack set.

    Args:
        cfg: Composed Hydra config (e.g. from ``get_cfg``).
        target_byte: Byte index passed to the simulation dataset. Defaults to
            2, the value previously hard-coded here.

    Returns:
        float: Classification accuracy on the attack set (also printed).
    """
    device = hydra.utils.instantiate(cfg.device)
    cpu = torch.device('cpu')

    # Dataset
    ds_prof = hydra.utils.instantiate(cfg.train.dataset)
    ds_test = hydra.utils.instantiate(cfg.test.dataset)

    def _make_sim_dataset(ds):
        """Wrap ``ds``'s plaintext/key/masks in an ASCAD_sim dataset."""
        # NOTE(review): both splits use cfg.train.dataset's transforms, exactly
        # as before; confirm cfg.test.dataset's transforms are not intended.
        return src.datasets.ASCAD_sim.Dataset(
            ds.plaintext,
            ds.key,
            ds.masks,
            target_byte=target_byte,
            trace_transforms=hydra.utils.instantiate(
                cfg.train.dataset.trace_transforms),
            label_transforms=hydra.utils.instantiate(
                cfg.train.dataset.label_transforms),
        )

    profiling_dataset = _make_sim_dataset(ds_prof)
    test_dataset = _make_sim_dataset(ds_test)

    train_dataloader = torch.utils.data.DataLoader(
        profiling_dataset, batch_size=cfg.train.batch_size, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=cfg.test.batch_size, shuffle=False)
    KANds = src.utils.to_KAN_dataset(
        train_dataloader, test_dataloader, device=device)

    # Train, or reuse a previously saved checkpoint with the same name.
    ckpt_name = cfg.model_name + '.ckpt'
    model = hydra.utils.instantiate(cfg.model.model)
    if not Path(cfg.save_path, ckpt_name).exists():
        model = model.to(device)
        # This train() takes the dataset plus training params (presumably the
        # KAN library's fit loop), not nn.Module.train(mode).
        _ = model.train(
            KANds, **hydra.utils.instantiate(cfg.model.train_params))
        Path(cfg.save_path).mkdir(exist_ok=True, parents=True)
        # The checkpoint is saved from the CPU copy of the model.
        model.to(cpu).save_ckpt(ckpt_name, cfg.save_path)
    else:
        model.load_ckpt(ckpt_name, cfg.save_path)
    model = model.to(device)

    # Test
    preds, labels, _ = src.utils.make_prediction(
        model, test_dataloader, device, cfg.label_transforms.one_hot)
    preds_class = np.argmax(preds, axis=1)
    labels = np.asarray(labels)
    if labels.ndim > 1:
        # One-hot (or column-vector) labels would broadcast against the 1-D
        # predicted classes and give a meaningless accuracy; reduce them to
        # class indices first.
        labels = (np.argmax(labels, axis=1) if labels.shape[1] > 1
                  else labels.reshape(-1))
    accuracy = float(np.mean(labels == preds_class))
    print('Accuracy: ', accuracy)
    return accuracy

In [4]:
import itertools

# Sweep every non-empty subset of the four simulated leakage points.
# PoI index v is written as the slice spec [v, v+1, 1] (presumably
# start/stop/step over sample columns) and labelled snr(v+1) in the model name.
for n_pois in range(1, 5):
    for subset in itertools.combinations(range(1, 5), n_pois):
        poi_spec = "[" + ",".join(f"[{v},{v+1},1]" for v in subset) + "]"
        model_suffix = "".join(str(v + 1) for v in subset)
        cfg = get_cfg(poi_spec, str(len(subset)), model_suffix)
        print(f"POIs: {poi_spec}")
        train(cfg)

POIs: [[1,2,1]]


train loss: 8.32e-01 | test loss: 8.33e-01 | reg: 1.45e+00 : 100%|█| 7000/7000 [06:30<00:00, 17.91it


save this model to /workspace/notebook/results/exp3//KAN1h_2.ckpt
Accuracy:  0.5031
POIs: [[2,3,1]]


train loss: 8.33e-01 | test loss: 8.32e-01 | reg: 1.45e+00 : 100%|█| 7000/7000 [06:26<00:00, 18.12it


save this model to /workspace/notebook/results/exp3//KAN1h_3.ckpt
Accuracy:  0.4965
POIs: [[3,4,1]]


train loss: 8.33e-01 | test loss: 8.32e-01 | reg: 1.35e+00 : 100%|█| 7000/7000 [06:19<00:00, 18.43it


save this model to /workspace/notebook/results/exp3//KAN1h_4.ckpt
Accuracy:  0.507
POIs: [[4,5,1]]


train loss: 8.33e-01 | test loss: 8.33e-01 | reg: 1.43e+00 : 100%|█| 7000/7000 [06:11<00:00, 18.85it


save this model to /workspace/notebook/results/exp3//KAN1h_5.ckpt
Accuracy:  0.498
POIs: [[1,2,1],[2,3,1]]


train loss: 8.31e-01 | test loss: 8.30e-01 | reg: 1.96e+00 : 100%|█| 7000/7000 [06:49<00:00, 17.10it


save this model to /workspace/notebook/results/exp3//KAN1h_23.ckpt
Accuracy:  0.5424
POIs: [[1,2,1],[3,4,1]]


train loss: 8.32e-01 | test loss: 8.33e-01 | reg: 1.67e+00 : 100%|█| 7000/7000 [06:53<00:00, 16.93it


save this model to /workspace/notebook/results/exp3//KAN1h_24.ckpt
Accuracy:  0.5069
POIs: [[1,2,1],[4,5,1]]


train loss: 8.33e-01 | test loss: 8.33e-01 | reg: 1.68e+00 : 100%|█| 7000/7000 [06:46<00:00, 17.23it


save this model to /workspace/notebook/results/exp3//KAN1h_25.ckpt
Accuracy:  0.4978
POIs: [[2,3,1],[3,4,1]]


train loss: 8.31e-01 | test loss: 8.33e-01 | reg: 1.73e+00 : 100%|█| 7000/7000 [06:39<00:00, 17.53it


save this model to /workspace/notebook/results/exp3//KAN1h_34.ckpt
Accuracy:  0.4987
POIs: [[2,3,1],[4,5,1]]


train loss: 8.32e-01 | test loss: 8.33e-01 | reg: 1.72e+00 : 100%|█| 7000/7000 [06:39<00:00, 17.51it


save this model to /workspace/notebook/results/exp3//KAN1h_35.ckpt
Accuracy:  0.4942
POIs: [[3,4,1],[4,5,1]]


train loss: 8.31e-01 | test loss: 8.30e-01 | reg: 1.87e+00 : 100%|█| 7000/7000 [06:43<00:00, 17.34it


save this model to /workspace/notebook/results/exp3//KAN1h_45.ckpt
Accuracy:  0.5373
POIs: [[1,2,1],[2,3,1],[3,4,1]]


train loss: 8.29e-01 | test loss: 8.30e-01 | reg: 1.93e+00 : 100%|█| 7000/7000 [07:26<00:00, 15.66it


save this model to /workspace/notebook/results/exp3//KAN1h_234.ckpt
Accuracy:  0.5383
POIs: [[1,2,1],[2,3,1],[4,5,1]]


train loss: 8.27e-01 | test loss: 8.31e-01 | reg: 2.02e+00 : 100%|█| 7000/7000 [07:23<00:00, 15.78it


save this model to /workspace/notebook/results/exp3//KAN1h_235.ckpt
Accuracy:  0.5407
POIs: [[1,2,1],[3,4,1],[4,5,1]]


train loss: 8.31e-01 | test loss: 8.32e-01 | reg: 1.90e+00 : 100%|█| 7000/7000 [07:23<00:00, 15.78it


save this model to /workspace/notebook/results/exp3//KAN1h_245.ckpt
Accuracy:  0.5394
POIs: [[2,3,1],[3,4,1],[4,5,1]]


train loss: 8.31e-01 | test loss: 8.29e-01 | reg: 1.81e+00 : 100%|█| 7000/7000 [07:23<00:00, 15.78it


save this model to /workspace/notebook/results/exp3//KAN1h_345.ckpt
Accuracy:  0.5425
POIs: [[1,2,1],[2,3,1],[3,4,1],[4,5,1]]


train loss: 8.24e-01 | test loss: 8.29e-01 | reg: 2.20e+00 : 100%|█| 7000/7000 [08:31<00:00, 13.70it


save this model to /workspace/notebook/results/exp3//KAN1h_2345.ckpt
Accuracy:  0.5586


Simulation results

| PoIs    | Accuracy |
| ---     | ---      |
| 2       | 0.5031   |
| 3       | 0.4965   |
| 4       | 0.507   |
| 5       | 0.498   | 
| 2,3     | 0.5424   |
| 2,4     | 0.5069   |
| 2,5     | 0.4978   |
| 3,4     | 0.4987   |
| 3,5     | 0.4942   |
| 4,5     | 0.5373   |
| 2,3,4   | 0.5383   |
| 2,3,5   | 0.5407   |
| 2,4,5   | 0.5394   |
| 3,4,5   | 0.5425   |
| 2,3,4,5 | 0.5586   |