Plot trained KAN and set linear function on input activations (and fine-tune)

In [None]:
# Choose dataset to plot/symbolic fitting
# ASCADf | ASCADv
ds = 'ASCADf'

----- Program codes ----

In [1]:
import torch
from pathlib import Path
import os
import hydra
import numpy as np
import matplotlib.pyplot as plt
import pickle

os.chdir('/workspace')
import src

In [2]:
if ds == 'ASCADf':
    pt = "[[156, 157, 1],[517, 518, 1]]"
elif ds == 'ASCADv':
    pt = "[[188, 189, 1],[1071, 1072, 1]]"

In [3]:
with hydra.initialize(config_path="../../conf", version_base='1.1'):
    cfg = hydra.compose(
        config_name='config',
        overrides=[
            "model=KAN1h",
            "model.model.width.1=1",
            "model.model.grid=3",
            "model.model.k=3",
            "model.train_params.steps=3000",
            f"dataset@train={ds}_profiling",
            f"dataset@test={ds}_attack",
            "trace_transforms=set_poi",
            f"trace_transforms.transforms.0.pois={pt}",
            "trace_transforms.output_size=2",
            "label_transforms=bit",
            "label_transforms.transforms.3.pos=0",
            f"save_path=/workspace/results/ascad/KAN_{ds}_snr/0",
            "n_attack_traces=2000"
            ]
        )

In [4]:
device = hydra.utils.instantiate(cfg.device)
cpu = torch.device('cpu')

Setup dataset

In [5]:
profiling_dataset = hydra.utils.instantiate(cfg.train.dataset)
test_dataset = hydra.utils.instantiate(cfg.test.dataset)

train_dataloader = torch.utils.data.DataLoader(
    profiling_dataset, batch_size=cfg.train.batch_size, shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=cfg.test.batch_size, shuffle=False
)
KANds = src.utils.to_KAN_dataset(
    train_dataloader, test_dataloader,
    device=device)

Train KAN

In [6]:
model = hydra.utils.instantiate(cfg.model.model)
model.load_ckpt(cfg.model_name+'.ckpt', cfg.save_path)
model = model.to(device)

RuntimeError: Error(s) in loading state_dict for KAN:
	size mismatch for biases.0.weight: copying a param with shape torch.Size([1, 2]) from checkpoint, the shape in current model is torch.Size([1, 1]).
	size mismatch for act_fun.0.grid: copying a param with shape torch.Size([4, 4]) from checkpoint, the shape in current model is torch.Size([2, 4]).
	size mismatch for act_fun.0.coef: copying a param with shape torch.Size([4, 6]) from checkpoint, the shape in current model is torch.Size([2, 6]).
	size mismatch for act_fun.0.scale_base: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for act_fun.0.scale_sp: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for act_fun.0.mask: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for act_fun.1.grid: copying a param with shape torch.Size([4, 4]) from checkpoint, the shape in current model is torch.Size([2, 4]).
	size mismatch for act_fun.1.coef: copying a param with shape torch.Size([4, 6]) from checkpoint, the shape in current model is torch.Size([2, 6]).
	size mismatch for act_fun.1.scale_base: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for act_fun.1.scale_sp: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for act_fun.1.mask: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for symbolic_fun.0.mask: copying a param with shape torch.Size([2, 2]) from checkpoint, the shape in current model is torch.Size([1, 2]).
	size mismatch for symbolic_fun.0.affine: copying a param with shape torch.Size([2, 2, 4]) from checkpoint, the shape in current model is torch.Size([1, 2, 4]).
	size mismatch for symbolic_fun.1.mask: copying a param with shape torch.Size([2, 2]) from checkpoint, the shape in current model is torch.Size([2, 1]).
	size mismatch for symbolic_fun.1.affine: copying a param with shape torch.Size([2, 2, 4]) from checkpoint, the shape in current model is torch.Size([2, 1, 4]).

Evaluation

In [None]:
preds, labels, th = src.utils.make_prediction(
    model, test_dataloader, device)
preds_class = np.argmax(preds, axis=1)
accuracy = np.mean(labels == preds_class)
print('Accuracy: ', accuracy)

Plot trained KAN

In [None]:
model(KANds['train_input'][:10000])
model.plot(folder=cfg.save_path+'/raw')
plt.savefig(
    Path(cfg.save_path, f'{cfg.model.name}_raw.png'),
    dpi=300, bbox_inches=0)
src.utils.plot_KAN(cfg, model, 'raw')

Set symbolic functions

In [None]:
model(KANds['test_input'][:10000])
model.fix_symbolic(0,0,0, 'x')
model.fix_symbolic(0,1,0, 'x')

Fine-tuning

In [None]:
_ = model.train(
    KANds,
    **hydra.utils.instantiate(cfg.model.train_params)
    )

Evaluate

In [None]:
preds, labels, th = src.utils.make_prediction(
    model, test_dataloader, device)

preds_class = np.argmax(preds, axis=1)
accuracy = np.mean(labels == preds_class)
print('Accuracy: ', accuracy)

In [None]:
model(KANds['train_input'][:10000])
model.plot(folder=cfg.save_path+'/fixed')
plt.savefig(
    Path(cfg.save_path, f'{cfg.model.name}_fixed.png'),
    dpi=300, bbox_inches=0)
src.utils.plot_KAN(cfg, model, 'fixed')

In [None]:
correct_key = test_dataset.key[0][cfg.target_byte]
key_hyposesis = range(256)
if not Path(cfg.save_path, 'label_hyposesis.npy').exists():
    label_hyposesis = src.utils.make_label_hyposesis(
        test_dataset, key_hyposesis)
    np.save(Path(cfg.save_path, 'label_hyposesis.npy'), label_hyposesis)
label_hyposesis = np.load(Path(cfg.save_path, 'label_hyposesis.npy'))

ge = src.sca_utils.calc_guessing_entropy(
    preds, label_hyposesis, correct_key,
    cfg.n_attack_traces, n_trial=cfg.n_trials)
print(np.where(ge==0)[0])

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(ge)
fig.savefig(Path(cfg.save_path, f'{cfg.model.name}_fixed.png'), dpi=300, bbox_inches=0)