In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pywt

In [2]:
import pickle as pic
from itertools import product

In [3]:
from utils import *
from templates import *
from signal_strength import SIGNAL_STRENGTHS_METHODS
from data_loader import EntireTraceIterator

In [4]:
# Reload cached outputs of the "Wavelets" and "Labels" sections below, so the analysis
# can resume without recomputing them. pickle executes code on load: only open trusted files.
with open("wavelets_haar_2.pic", "rb") as wavelet_file:
    wavelets = pic.load(wavelet_file)

with open("labels.pic", "rb") as label_file:
    cached_labels = pic.load(label_file)
(rws_perms_labels, round_perms_labels, copy_perms_labels,
 rws_masks_labels, round_masks_labels) = cached_labels

In [5]:
# Location of the acquisition: traces (.mat) and the matching key log (.log) share a stem.
# Built with pathlib so the separator is correct on every OS (the loader receives plain str).
from pathlib import Path

ACQUISITION_DIR = Path("..") / "acquisition" / "739094_maskshuffle_allrounds"
ACQUISITION_STEM = "carto_eB4-Rnd-3-WhiteningAndFullFilter-Masking-Shuffling"
traces_path = str(ACQUISITION_DIR / f"{ACQUISITION_STEM}.mat")
key_path = str(ACQUISITION_DIR / f"{ACQUISITION_STEM}.log")

In [7]:
# Dataset dimensions and the division size used by the trace iterator.
NUM_TRACES = 739_094     # number of traces; sizes every per-trace array below
TRACE_SIZE = 200_000     # passed to the iterator as trace_size
PER_DIVISION = 250_000   # traces yielded per division by the iterator

data_loader = EntireTraceIterator(
    traces_path,
    key_path,
    nr_populations=1,
    nr_scenarios=1,
    trace_size=TRACE_SIZE,
    traces_per_division=PER_DIVISION,
    parse_output="keyshares+perms",
)

## Wavelets

In [8]:
# Compute the level-2 Haar approximation coefficients of every trace and cache them on disk
# in chunks of WAVELET_CHUNK traces. Each file is named by the GLOBAL trace index range it
# covers, which is what the assembly cell below expects. Do not hard-code a division offset
# here: the iterator starts at division 0 on a fresh run.
import os

WAVELET_CHUNK = PER_DIVISION // 10
os.makedirs("./wavelets_haar_2", exist_ok=True)

for j, (_, traces, _, _) in enumerate(data_loader((0,), (0,))):
    traces = traces[0][0]
    offset = j * PER_DIVISION  # global index of this division's first trace
    # Iterate over the traces actually returned; the last division can be shorter than
    # PER_DIVISION, and an empty slice would make pywt.wavedec fail.
    for i in range(0, len(traces), WAVELET_CHUNK):
        # [0] keeps only the approximation coefficients (wavedec applies along the last axis).
        wavelet = pywt.wavedec(traces[i:i + WAVELET_CHUNK], 'haar', level=2)[0]
        with open(f"./wavelets_haar_2/wavelet_{offset + i}_{offset + i + WAVELET_CHUNK}.pic", "wb") as w:
            pic.dump(wavelet, w)
    del traces  # release the division before the iterator loads the next one

In [9]:
# Reassemble the cached wavelet chunks into one (NUM_TRACES, WAVELET_SIZE) float16 matrix.
# Sizes are derived from the config instead of magic numbers so they stay in sync with the
# cell that wrote the chunks (chunk = PER_DIVISION // 10 traces).
WAVELET_CHUNK = PER_DIVISION // 10
# Each Haar DWT level halves the sample count; level 2 -> TRACE_SIZE / 4
# (assumes TRACE_SIZE is divisible by 4).
WAVELET_SIZE = TRACE_SIZE // 4

wavelets = np.zeros((NUM_TRACES, WAVELET_SIZE), dtype=np.float16)
for div in range(0, NUM_TRACES, WAVELET_CHUNK):
    filepath = f"./wavelets_haar_2/wavelet_{div}_{div + WAVELET_CHUNK}.pic"
    with open(filepath, "rb") as r:
        wavelet = pic.load(r)
    # A chunk may hold fewer rows than WAVELET_CHUNK (short division); rows that are never
    # written stay all-zero and are detected by the next cell.
    wavelets[div:div + wavelet.shape[0]] = np.float16(wavelet)

In [11]:
# Indices of rows that are entirely zero, i.e. traces whose wavelet was never filled in above.
np.nonzero(~np.any(wavelets, axis=1))

(array([499999], dtype=int64),)

In [12]:
# Remove the trace whose wavelet row is all zeros (the index reported by the cell above).
# The same index is removed from the labels in the "Labels" section so both stay aligned.
# Guarded on length: re-running this cell, or running it on the already-trimmed matrix
# reloaded from wavelets_haar_2.pic, must not drop a second, valid trace.
MISSING_TRACE_IDX = 499_999
if wavelets.shape[0] == NUM_TRACES:
    assert not wavelets[MISSING_TRACE_IDX].any(), "expected the missing trace's row to be all zeros"
    wavelets = np.delete(wavelets, MISSING_TRACE_IDX, axis=0)


In [13]:
# Cache the wavelet matrix; it is reloaded by the load cell at the top of the notebook.
with open("wavelets_haar_2.pic", "wb") as wavelet_file:
    pic.dump(wavelets, wavelet_file)

## Labels

In [14]:
# One label array per target; the trace index is always the last axis.
rws_perms_labels = np.zeros(NUM_TRACES, dtype=int)
rws_masks_labels = np.zeros((KEYROUND_WIDTH_B4, NR_SHARES, NUM_TRACES), dtype=int)
round_perms_labels = np.zeros(NUM_TRACES, dtype=int)
copy_perms_labels = np.zeros((LATEST_ROUND - EARLIEST_ROUND, NUM_TRACES), dtype=int)
round_masks_labels = np.zeros((LATEST_ROUND - EARLIEST_ROUND, BLOCK_WIDTH_B4, NR_SHARES, NUM_TRACES), dtype=int)

# Destination arrays, in the same order as the tuple unpacked from get_all_labels below.
label_targets = (rws_perms_labels, rws_masks_labels, round_perms_labels, copy_perms_labels, round_masks_labels)

for j, (seeds_sub, key, output_sub) in enumerate(data_loader((0,), (0,), return_traces=False)):
    rws_perms, rws_masks, round_perms, copy_perms, round_masks = get_all_labels(seeds_sub[0][0], key[0], output_sub[0][0])
    start = j * PER_DIVISION  # global index of this division's first trace
    division_labels = (rws_perms, rws_masks, round_perms, copy_perms, round_masks)
    for target, chunk in zip(label_targets, division_labels):
        target[..., start:start + chunk.shape[-1]] = chunk

In [35]:
# Drop the label column of the missing trace, mirroring the row removed from `wavelets`,
# so labels and wavelet rows stay aligned. Each array is only trimmed while it still has
# NUM_TRACES columns. Re-running this cell, or running it on the already-trimmed arrays
# reloaded from labels.pic, therefore does not remove a second, valid trace.
MISSING_TRACE_IDX = 499_999


def _drop_missing_trace(labels):
    """Return `labels` without column MISSING_TRACE_IDX on the last (trace) axis, if not yet dropped."""
    if labels.shape[-1] != NUM_TRACES:
        return labels
    return np.delete(labels, MISSING_TRACE_IDX, axis=-1)


rws_perms_labels = _drop_missing_trace(rws_perms_labels)
rws_masks_labels = _drop_missing_trace(rws_masks_labels)
round_perms_labels = _drop_missing_trace(round_perms_labels)
copy_perms_labels = _drop_missing_trace(copy_perms_labels)
round_masks_labels = _drop_missing_trace(round_masks_labels)

In [37]:
# Cache all label arrays; the tuple order must match the unpacking in the load cell at the top.
all_labels = (rws_perms_labels, round_perms_labels, copy_perms_labels, rws_masks_labels, round_masks_labels)
with open("labels.pic", "wb") as label_file:
    pic.dump(all_labels, label_file)

## RWS permutations

In [None]:
# Fit the SOST estimator on the wavelet coefficients with the RWS permutation labels.
# The constructor takes a .pic path; the Visualization section reads the same file back.
sost_path = "./leakage_points_haar_2/f_rws_perms_sost.pic"
signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](sost_path)
signal_strength.fit(wavelets, rws_perms_labels, 0)

## RWS masks

In [None]:
# One SOST fit per (keyround word, share) of the RWS masks, each with its own result file.
for keyround_idx in range(KEYROUND_WIDTH_B4):
    print(f"{keyround_idx}", end="\r")  # in-place progress counter
    for share_idx in range(NR_SHARES):
        sost_path = f"./leakage_points_haar_2/f_rws_masks_{keyround_idx}_{share_idx}_sost.pic"
        signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](sost_path)
        mask_labels = rws_masks_labels[keyround_idx, share_idx]
        signal_strength.fit(wavelets, mask_labels, 0)

## Round permutations

In [None]:
# Fit the SOST estimator on the wavelet coefficients with the round permutation labels.
# The constructor takes a .pic path; the Visualization section reads the same file back.
sost_path = "./leakage_points_haar_2/f_round_perms_sost.pic"
signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](sost_path)
signal_strength.fit(wavelets, round_perms_labels, 0)

## Copy permutations

In [None]:
# One SOST fit per round for the copy permutation labels.
# copy_perms_labels has LATEST_ROUND - EARLIEST_ROUND rows (see the Labels section), so it is
# indexed relative to EARLIEST_ROUND; indexing with the absolute round_idx is wrong (or raises
# IndexError) whenever EARLIEST_ROUND != 0. Assumes row k holds round EARLIEST_ROUND + k.
for round_idx in range(EARLIEST_ROUND, LATEST_ROUND):
    print(f"{round_idx}", end="\r")  # in-place progress counter
    signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](f"./leakage_points_haar_2/f_copy_perms_{round_idx}_sost.pic")
    signal_strength.fit(wavelets, copy_perms_labels[round_idx - EARLIEST_ROUND], 0)

## Round masks

In [6]:
# SOST of the round-mask labels on the RAW traces (not wavelets) of the first division only,
# loaded with a larger division size. Dedicated names are used so the global PER_DIVISION /
# data_loader configured at the top are not clobbered; the wavelet cells depend on them.
BIGDATA_PER_DIVISION = 320_000
bigdata_loader = EntireTraceIterator(traces_path, key_path, nr_populations=1, nr_scenarios=1, trace_size=TRACE_SIZE, traces_per_division=BIGDATA_PER_DIVISION, parse_output="keyshares+perms")

# Only the first division is used.
_, first_traces, _, _ = next(iter(bigdata_loader((0,), (0,))))
first_traces = first_traces[0][0]
nr_first = len(first_traces)  # labels are cut to the number of traces actually loaded

for round_idx in range(EARLIEST_ROUND, LATEST_ROUND):
    for block_idx in range(BLOCK_WIDTH_B4):
        print(f"{round_idx * BLOCK_WIDTH_B4 + block_idx}", end="\r")  # in-place progress counter
        for share_idx in range(NR_SHARES):
            signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](f"./leakage_points_bigdata/f_round_masks_{round_idx}_{block_idx}_{share_idx}_sost.pic")
            # round_masks_labels axis 0 has LATEST_ROUND - EARLIEST_ROUND entries, so it is indexed
            # relative to EARLIEST_ROUND (assumes entry k is round EARLIEST_ROUND + k).
            signal_strength.fit(first_traces, round_masks_labels[round_idx - EARLIEST_ROUND, block_idx, share_idx, :nr_first], 0)

97

In [None]:
# One SOST fit per (round, block, share) of the round masks against the wavelet coefficients.
# round_masks_labels axis 0 has LATEST_ROUND - EARLIEST_ROUND entries (see the Labels section),
# so it is indexed relative to EARLIEST_ROUND; the absolute round_idx is wrong (or raises
# IndexError) whenever EARLIEST_ROUND != 0. Assumes entry k holds round EARLIEST_ROUND + k.
for round_idx in range(EARLIEST_ROUND, LATEST_ROUND):
    for block_idx in range(BLOCK_WIDTH_B4):
        print(f"{round_idx * BLOCK_WIDTH_B4 + block_idx}", end="\r")  # in-place progress counter
        for share_idx in range(NR_SHARES):
            signal_strength = SIGNAL_STRENGTHS_METHODS["SOST"](f"./leakage_points_haar_2/f_round_masks_{round_idx}_{block_idx}_{share_idx}_sost.pic")
            signal_strength.fit(wavelets, round_masks_labels[round_idx - EARLIEST_ROUND, block_idx, share_idx], 0)

## Visualization

In [4]:
import pickle as pic
import plotly.express as px

folder = "leakage_points_haar_2"
all_plot_labels = []


def _load_normalized_sost(filename, nr_classes):
    """Load a pickled SOST curve from ./{folder}/ and divide it by the number of class pairs.

    nr_classes * (nr_classes - 1) / 2 is the number of unordered class pairs.
    Returns None when the file cannot be opened or read (OSError), e.g. it was never computed.
    """
    try:
        with open(f"./{folder}/{filename}", "rb") as handle:
            return pic.load(handle) / (nr_classes * (nr_classes - 1) / 2)
    except OSError:
        return None


# Labels are appended in plotting order: RWS perm, RWS masks, round perm, copy perms, round masks.
sig_rws = _load_normalized_sost("f_rws_perms_sost.pic", KEYROUND_WIDTH_B4)
if sig_rws is not None:
    all_plot_labels.append("RWS perm")

sig_masks_rws = []
for keyround_idx in range(KEYROUND_WIDTH_B4):
    for share_idx in range(NR_SHARES):
        curve = _load_normalized_sost(f"f_rws_masks_{keyround_idx}_{share_idx}_sost.pic", len(KEY_ALPHABET))
        if curve is not None:
            sig_masks_rws.append(curve)
            all_plot_labels.append(f"RWS mask {keyround_idx} {share_idx}")

sig_round_perm = _load_normalized_sost("f_round_perms_sost.pic", KEYROUND_WIDTH_B4 // BLOCK_WIDTH_B4)
if sig_round_perm is not None:
    all_plot_labels.append("Round perm")

sig_copy_perms = []
sig_masks_perms = []  # round-mask curves
for round_idx in range(EARLIEST_ROUND, LATEST_ROUND):
    curve = _load_normalized_sost(f"f_copy_perms_{round_idx}_sost.pic", BLOCK_WIDTH_B4)
    if curve is not None:
        sig_copy_perms.append(curve)
        all_plot_labels.append(f"Copy perm {round_idx}")

for round_idx in range(EARLIEST_ROUND, LATEST_ROUND):
    for block_idx in range(BLOCK_WIDTH_B4):
        for share_idx in range(NR_SHARES):
            curve = _load_normalized_sost(f"f_round_masks_{round_idx}_{block_idx}_{share_idx}_sost.pic", len(KEY_ALPHABET))
            if curve is not None:
                sig_masks_perms.append(curve)
                all_plot_labels.append(f"Round mask {round_idx} {block_idx} {share_idx}")


In [None]:
# Haar 3: zoom into the round-mask SOST curves in four windows of trace_size // 4 samples.
trace_size = 25_000
window = trace_size // 4
everything_full = sig_masks_perms
# Only the round-mask curves are plotted, so only their labels may be used. They are the
# last len(sig_masks_perms) entries of all_plot_labels; enumerating the whole list would
# put e.g. "RWS perm" on the first mask curve.
mask_plot_labels = all_plot_labels[len(all_plot_labels) - len(everything_full):]
plot_labs = {f'wide_variable_{i}': l for i, l in enumerate(mask_plot_labels)}
for it in range(0, trace_size, window):
    everything = [sost[it:it + window] for sost in everything_full]
    fig = px.line(y=everything)
    # px.line names wide-form columns wide_variable_<i>; swap in the human-readable labels.
    fig.for_each_trace(lambda t: t.update(name=plot_labs[t.name], legendgroup=plot_labs[t.name],
                                          hovertemplate=t.hovertemplate.replace(t.name, plot_labs[t.name])))
    fig.update_layout(yaxis_range=[-1, 100])
    fig.show()

In [None]:
# Haar 2: full-length (unwindowed) view of the round-mask SOST curves only.
everything = sig_masks_perms
# Only the round-mask curves are plotted, so only their labels may be used. They are the
# last len(sig_masks_perms) entries of all_plot_labels; enumerating the whole list would
# mislabel the legend whenever RWS/round/copy curves were also loaded.
mask_plot_labels = all_plot_labels[len(all_plot_labels) - len(everything):]
plot_labs = {f'wide_variable_{i}': l for i, l in enumerate(mask_plot_labels)}
fig = px.line(y=everything)
# px.line names wide-form columns wide_variable_<i>; swap in the human-readable labels.
fig.for_each_trace(lambda t: t.update(name=plot_labs[t.name], legendgroup=plot_labs[t.name],
                                      hovertemplate=t.hovertemplate.replace(t.name, plot_labs[t.name])))
fig.show()

In [None]:
# Original traces: every loaded SOST curve (same order as all_plot_labels), shown in four
# windows of trace_size // 4 samples.
trace_size = 200_000
window = trace_size // 4
all_curves = ([sig_rws] if sig_rws is not None else []) + sig_masks_rws + ([sig_round_perm] if sig_round_perm is not None else []) + sig_copy_perms + sig_masks_perms
# px.line names wide-form columns wide_variable_<i>; map them to the human-readable labels.
plot_labs = {f'wide_variable_{i}': l for i, l in enumerate(all_plot_labels)}
for it in range(0, trace_size, window):
    everything = [sost[it:it + window] for sost in all_curves]
    fig = px.line(y=everything)
    fig.for_each_trace(lambda t: t.update(name=plot_labs[t.name], legendgroup=plot_labs[t.name],
                                          hovertemplate=t.hovertemplate.replace(t.name, plot_labs[t.name])))
    fig.show()