In [1]:
import torch
from pathlib import Path
import os
import hydra
import numpy as np
import matplotlib.pyplot as plt
import pickle

os.chdir('/workspace')
import src

In [2]:
def get_cfg(pos, steps=7000, n_attack_traces=2000,
            save_root="/workspace/notebook/results/exp3"):
    """Compose the Hydra config for one single-bit KAN run.

    Args:
        pos: Bit position written into ``label_transforms.transforms.3.pos``.
            It is also the name of the per-run results sub-directory.
        steps: Value for ``model.train_params.steps`` (default 7000, as before).
        n_attack_traces: Value for the ``n_attack_traces`` override
            (default 2000, as before).
        save_root: Parent directory of the per-run output; the resulting
            override is ``save_path=<save_root>/<pos>``.

    Returns:
        The config object produced by ``hydra.compose``.
    """
    overrides = [
        # 4-5-1 KAN. NOTE(review): the input width of 4 presumably matches
        # the number of POIs selected by the trace transform below.
        "model=KAN2h",
        "model.model.width.0=4",
        "model.model.width.1=5",
        "model.model.width.2=1",
        f"model.train_params.steps={steps}",
        "dataset@train=ASCADf_profiling",
        "dataset@test=ASCADf_attack",
        "trace_transforms=set_poi",
        "trace_transforms.transforms.0.pois=[[1,5,1]]",
        "label_transforms=bit",
        # NOTE(review): assumes entry 3 of the `bit` label pipeline is the
        # transform that takes the bit position; confirm in conf/.
        f"label_transforms.transforms.3.pos={pos}",
        f"save_path={save_root}/{pos}",
        f"n_attack_traces={n_attack_traces}",
    ]
    # initialize() is used as a context manager so Hydra's global state is
    # released on exit. That allows get_cfg to be called repeatedly, once
    # per bit position.
    with hydra.initialize(config_path="../conf", version_base='1.1'):
        return hydra.compose(config_name='config', overrides=overrides)

In [3]:
def _build_ascad_dataset(ds_raw, cfg, target_byte):
    """Wrap one raw dataset split in ``src.datasets.ASCAD_sim.Dataset``.

    Args:
        ds_raw: Instantiated dataset split exposing ``plaintext``, ``key``
            and ``masks``.
        cfg: Composed Hydra config; only ``cfg.train.dataset`` transforms
            are read.
        target_byte: Passed through unchanged as ``target_byte``.
    """
    # Both splits deliberately use the *profiling* split's transforms, so
    # training and attack data get identical preprocessing.
    # NOTE(review): confirm cfg.test.dataset does not define its own
    # trace/label transforms that were meant to be used here.
    return src.datasets.ASCAD_sim.Dataset(
        ds_raw.plaintext,
        ds_raw.key,
        ds_raw.masks,
        target_byte=target_byte,
        trace_transforms=hydra.utils.instantiate(
            cfg.train.dataset.trace_transforms),
        label_transforms=hydra.utils.instantiate(
            cfg.train.dataset.label_transforms),
    )


def train(cfg, target_byte=2):
    """Train (or reload) the configured KAN model and report attack-set accuracy.

    If ``<cfg.save_path>/<cfg.model_name>.ckpt`` already exists, the model is
    loaded from it and no training happens. Otherwise the model is trained
    on the profiling split and the checkpoint is written, with the directory
    created if needed.

    Args:
        cfg: Composed Hydra config (e.g. from ``get_cfg``).
        target_byte: Forwarded to ``ASCAD_sim.Dataset``. The default of 2 is
            the value this notebook previously hard-coded.

    Returns:
        The accuracy: the fraction of test samples whose argmax prediction
        equals the label. It is also printed.
    """
    device = hydra.utils.instantiate(cfg.device)
    cpu = torch.device('cpu')

    # Dataset
    ds_prof = hydra.utils.instantiate(cfg.train.dataset)
    ds_test = hydra.utils.instantiate(cfg.test.dataset)
    profiling_dataset = _build_ascad_dataset(ds_prof, cfg, target_byte)
    test_dataset = _build_ascad_dataset(ds_test, cfg, target_byte)

    train_dataloader = torch.utils.data.DataLoader(
        profiling_dataset, batch_size=cfg.train.batch_size, shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=cfg.test.batch_size, shuffle=False
    )

    # Train, or reuse an existing checkpoint
    ckpt_name = cfg.model_name + '.ckpt'
    if Path(cfg.save_path, ckpt_name).exists():
        model = hydra.utils.instantiate(cfg.model.model)
        model.load_ckpt(ckpt_name, cfg.save_path)
    else:
        # The KAN-format dataset is only consumed by training, so it is built
        # only on this path. It is still built before the model is
        # instantiated, preserving the original call order.
        kan_dataset = src.utils.to_KAN_dataset(
            train_dataloader, test_dataloader,
            device=device)
        model = hydra.utils.instantiate(cfg.model.model)
        model = model.to(device)
        model.train(
            kan_dataset,
            **hydra.utils.instantiate(cfg.model.train_params)
            )
        Path(cfg.save_path).mkdir(exist_ok=True, parents=True)
        # Saved from CPU so the checkpoint does not depend on the training device.
        model.to(cpu).save_ckpt(ckpt_name, cfg.save_path)
    model = model.to(device)

    # Test
    preds, labels, _ = src.utils.make_prediction(
        model, test_dataloader, device,
        cfg.label_transforms.one_hot)
    preds_class = np.argmax(preds, axis=1)
    accuracy = np.mean(labels == preds_class)
    print('Accuracy: ', accuracy)
    return accuracy

In [4]:
# One independent run per bit position 0-7. Each run composes its own config
# (with its own save_path), then trains a model or reloads its checkpoint.
for bit_pos in range(8):
    cfg = get_cfg(bit_pos)
    train(cfg)

train loss: 8.22e-01 | test loss: 8.28e-01 | reg: 2.17e+00 : 100%|█| 7000/7000 [08:13<00:00, 14.18it


save this model to /workspace/notebook/results/exp3/0/KAN-2Hidden.ckpt
Accuracy:  0.561


train loss: 8.32e-01 | test loss: 8.17e-01 | reg: 2.26e+00 : 100%|█| 7000/7000 [08:13<00:00, 14.19it


save this model to /workspace/notebook/results/exp3/1/KAN-2Hidden.ckpt
Accuracy:  0.5548


train loss: 8.16e-01 | test loss: 8.27e-01 | reg: 2.25e+00 : 100%|█| 7000/7000 [08:18<00:00, 14.06it


save this model to /workspace/notebook/results/exp3/2/KAN-2Hidden.ckpt
Accuracy:  0.5631


train loss: 8.31e-01 | test loss: 8.31e-01 | reg: 2.21e+00 : 100%|█| 7000/7000 [08:07<00:00, 14.37it


save this model to /workspace/notebook/results/exp3/3/KAN-2Hidden.ckpt
Accuracy:  0.5711


train loss: 8.29e-01 | test loss: 8.20e-01 | reg: 2.23e+00 : 100%|█| 7000/7000 [08:09<00:00, 14.31it


save this model to /workspace/notebook/results/exp3/4/KAN-2Hidden.ckpt
Accuracy:  0.5633


train loss: 8.17e-01 | test loss: 8.31e-01 | reg: 2.22e+00 : 100%|█| 7000/7000 [08:08<00:00, 14.32it


save this model to /workspace/notebook/results/exp3/5/KAN-2Hidden.ckpt
Accuracy:  0.5504


train loss: 8.21e-01 | test loss: 8.21e-01 | reg: 2.28e+00 : 100%|█| 7000/7000 [08:14<00:00, 14.15it


save this model to /workspace/notebook/results/exp3/6/KAN-2Hidden.ckpt
Accuracy:  0.554


train loss: 8.16e-01 | test loss: 8.22e-01 | reg: 2.25e+00 : 100%|█| 7000/7000 [08:19<00:00, 14.02it


save this model to /workspace/notebook/results/exp3/7/KAN-2Hidden.ckpt
Accuracy:  0.5557


Accuracy:  0.561  
Accuracy:  0.5548  
Accuracy:  0.5631  
Accuracy:  0.5711  
Accuracy:  0.5633  
Accuracy:  0.5504  
Accuracy:  0.554  
Accuracy:  0.5557  