In [None]:
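# # variables that are set to their default values if not already defined
# m4f           [False] implementation switch: m4f or opt implementation
# adv           [False] use advanced trigger_high
# fisher        [False] target fisher_yates sampling (True) or set_bits (False)
# loop          [0] target loop {0,1,2}
# lp_fct        [0] loop factor: determines the highest external offset
# sample_size   [2] how many samples per glitch setting
# step          [lp_cy] step through glitch settings, defaults to lp_cy
# repeat        [1] how many glitches to inject consecutively
# trig_src      ["ext_continuous"] or "ext_single"
# spot          choose one width/offset spot; by default all 10 spots are iterated
# thr           [5] minimum number of captured keys a base offset needs to be reported
# max_dev       [0] if non-zero, groups whose weight std-dev exceeds it are split by external offset
# ff            [derived from the settings above] name of the key-dump entry to load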

# # optional variables
# base_off -> e_off_low and e_off_up will be calculated
# w and o:                  choose given width/offset
# e_off_low and e_off_up:   can be set if base_off is not set

In [None]:
lvl = "l1"

# Notebook parameters. Each one keeps a value that was defined before this
# cell ran (e.g. injected by a driver notebook) and otherwise falls back to
# its default. `global` is a no-op at module level, so it is not used here.
if "m4f" not in dir():
    m4f = False  # implementation switch: m4f (True) or mupq/opt (False)
if "adv" not in dir():
    adv = False  # use firmware with the advanced trigger_high routine
if "fisher" not in dir():
    fisher = False  # target fisher_yates sampling (True) or set_bits (False)
if "loop" not in dir():
    loop = 0  # target loop, one of {0, 1, 2}
if "thr" not in dir():
    thr = 5  # minimum number of keys a base offset needs to be reported
if "max_dev" not in dir():
    max_dev = 0  # std-dev limit above which a group is split by external offset; 0 disables

# Loop period `lp_cy` per (fisher, loop). External offsets are reduced modulo
# this value to obtain base offsets; the advanced trigger_high firmware adds 6.
LP_CY_TABLE = {
    True: {0: 130, 1: 53, 2: 28},  # fisher_yates
    False: {0: 63, 1: 61, 2: 38},  # set_bits
}
# Fail early: an unknown loop would otherwise give a meaningless lp_cy
# (0, or 6 with adv) and a ZeroDivisionError at the first `% lp_cy`.
if loop not in LP_CY_TABLE[bool(fisher)]:
    raise ValueError(f"loop must be one of 0, 1, 2, got {loop!r}")

if "ff" not in dir():
    # Identifier of the targeted build; also the key under which its captured
    # keys are stored in the shelve dump.
    ff = f"{'' if m4f else 'mupq_'}crypto_kem_bike{lvl}_{'fisher_yates' if fisher else 'set_bits'}_{'adv_' if adv else ''}{loop}_{'m4f' if m4f else 'opt'}_fi"
print(ff)

lp_cy = LP_CY_TABLE[bool(fisher)][loop]
if adv:
    lp_cy += 6


import sys
new_path = '../scripts/'
# Make the project helper modules imported below resolvable; the directory is
# appended only once so re-running the cell does not keep growing sys.path.
if not any(entry == new_path for entry in sys.path):
    sys.path.append(new_path)
    
import kat_bike as kat
import bike_key as bk
import numpy as np

# Level descriptor for the selected BIKE level `lvl`; later cells read
# `bikelvl.d` as the expected weight when checking captured keys.
bikelvl = kat.get_lvl(lvl)

In [None]:
# Load the keys captured for the build `ff` from the shelve dump.
# Opened read-only: the default flag "c" would silently create an empty
# database when the path is wrong, hiding the real problem behind a KeyError.
# NOTE: shelve unpickles the stored values, so only open trusted dumps.
import shelve
with shelve.open("../traces/key_dump", flag="r") as data:
    if ff not in data:
        raise KeyError(f"no captured keys stored under {ff!r} in ../traces/key_dump")
    keys = data[ff]
# `ff` is removed so that re-running the configuration cell rebuilds it from
# the current parameters (that cell only defines `ff` when it is missing).
# Consequently this cell must be preceded by the configuration cell on re-run.
del ff

In [None]:
def intlists(mupq_key, lvl):
    """Decode a mupq-format key into integer index lists.

    Parameters
    ----------
    mupq_key : key bytes understood by ``bk._parse_mupq_key``.
    lvl : level descriptor; ``lvl.r_bytes`` is the number of leading bytes of
        the secret-key blob that form h0 (the remainder forms h1).

    Returns
    -------
    tuple
        ``(int_wlists, (h0_positions, h1_positions))`` where ``int_wlists`` is
        the result of ``bk._wlists_to_ilists`` on the parsed weight lists and
        each ``*_positions`` list holds, in ascending order, the indices of the
        bits set in h0 / h1 read as little-endian unsigned integers.
    """
    def set_bit_positions(raw):
        # Little-endian bytes -> ascending indices of the 1-bits.
        digits = bin(int.from_bytes(raw, "little"))
        return [pos for pos, digit in enumerate(reversed(digits)) if digit == "1"]

    parsed = bk._parse_mupq_key(mupq_key, lvl)
    int_wlists = bk._wlists_to_ilists(parsed[0])
    secret = parsed[1]

    split_at = lvl.r_bytes
    h0_positions = set_bit_positions(secret[:split_at])
    h1_positions = set_bit_positions(secret[split_at:])
    return int_wlists, (h0_positions, h1_positions)

def show_weights(fault_dict, key, amount=0, single=False, max_dev=0):
    """Print Hamming-weight statistics for the keys stored under ``fault_dict[key]``.

    Nothing is printed when the group is missing/empty, smaller than
    ``amount``, or looks intact: both h0 and h1 have a mean weight strictly
    between d-1 and d+1 (d = ``bikelvl.d``) and a standard deviation below 1.

    Parameters
    ----------
    fault_dict : dict mapping an offset to a list of ``(settings, key)``
        entries; ``settings[2]`` is the external offset and ``key.mupq_key``
        the captured key bytes.
    key : the offset in ``fault_dict`` to report on.
    amount : minimum group size required to report anything.
    single : True when ``key`` is an individual external offset (recursive
        call); the corresponding base offset ``key % lp_cy`` is printed too.
    max_dev : if non-zero and either weight standard deviation exceeds it,
        the group is split by external offset and each part is reported.
    """
    entries = fault_dict[key]
    # A missing or empty group has nothing to report; without this guard it
    # would fail in len(None) or np.max([]).
    if not entries or len(entries) < amount:
        return

    # Loop-invariant: resolve the level descriptor once, not per key.
    level = kat.get_lvl(lvl)
    idx_h0, idx_h1, idx_w0, idx_w1 = [], [], [], []
    for entry in entries:
        idx_lists = intlists(entry[1].mupq_key, level)
        idx_w0.append(idx_lists[0][0])
        idx_w1.append(idx_lists[0][1])
        idx_h0.append(idx_lists[1][0])
        idx_h1.append(idx_lists[1][1])

    # h0/h1 weight = number of set bits; w0/w1 weight = number of distinct
    # indices (duplicates in a weight list lower its weight).
    w_h0 = [len(l) for l in idx_h0]
    w_h1 = [len(l) for l in idx_h1]
    w_w0 = [len(set(l)) for l in idx_w0]
    w_w1 = [len(set(l)) for l in idx_w1]
    h0_dev, h0_mean = np.std(w_h0), np.mean(w_h0)
    h1_dev, h1_mean = np.std(w_h1), np.mean(w_h1)

    d = bikelvl.d
    weight_ok = (d - 1 < h0_mean < d + 1) and (d - 1 < h1_mean < d + 1)
    weight_ok = weight_ok and h0_dev < 1 and h1_dev < 1
    if weight_ok:
        return

    # Too much spread in this group: on the top-level call (and only when
    # max_dev is enabled) report every external offset on its own.
    if max_dev != 0 and not single and (max_dev < h0_dev or max_dev < h1_dev):
        by_ext_off = {}
        for entry in entries:
            by_ext_off.setdefault(entry[0][2], []).append(entry)
        for ext_off in by_ext_off:
            show_weights(by_ext_off, ext_off, amount=0, single=True, max_dev=max_dev)
        return

    print(f"for base external offset {key}{ f' = {key%lp_cy} base offset' if single else '' } there are {len(idx_h0)} captured keys")
    print(f"Hamming weight of h0:\n mean {h0_mean}\n standard deviation {h0_dev}")
    print(f" max {np.max(w_h0)}\n min {np.min(w_h0)}")
    print(f"Hamming weight of h1:\n mean {h1_mean}\n standard deviation {h1_dev}")
    print(f" max {np.max(w_h1)}\n min {np.min(w_h1)}")
    if fisher:
        print(f"Hamming weight of w0:\n mean {np.mean(w_w0)}\n standard deviation {np.std(w_w0)}")
        print(f" max {np.max(w_w0)}\n min {np.min(w_w0)}")
        print(f"Hamming weight of w1:\n mean {np.mean(w_w1)}\n standard deviation {np.std(w_w1)}")
        print(f" max {np.max(w_w1)}\n min {np.min(w_w1)}")
    print()

In [None]:
# Group captured keys by base offset: the external offset (settings[2])
# reduced modulo the loop period `lp_cy`.
e_off2keys = {}
for k in keys:
    e_off2keys.setdefault(k[0][2] % lp_cy, []).append(k)

if "base_off" not in dir():
    # No specific offset requested: keep offsets with at least `thr` keys.
    e_off2keys = {bo: group for bo, group in e_off2keys.items() if len(group) >= thr}
elif base_off in e_off2keys:
    e_off2keys = {base_off: e_off2keys[base_off]}
else:
    # Unknown base offset: select nothing instead of storing a None group,
    # which show_weights and the later cells cannot handle.
    print(f"no captured keys with base offset {base_off}")
    e_off2keys = {}

for k in e_off2keys:
    show_weights(e_off2keys, k, max_dev=max_dev)

In [None]:

# For each selected base-offset group, count how many keys came from each
# raw external offset (settings[2]), listed in first-seen order.
for lis in e_off2keys.values():
    eoffs = {}
    for tup in lis:
        eo = tup[0][2]
        eoffs[eo] = eoffs.get(eo, 0) + 1
    print(eoffs)

In [None]:
# for s,k in keys:
#     l0,l1 = k.wlists_as_int
#     l0.sort()
#     l1.sort()
#     print(l0)
#     print(l1)

#     for i in range(len(l0)-1): 
#         if l0[i] == l0[i+1]:
#             print(l0[i])
#         if l1[i] == l1[i+1]:
#             print(l1[i])

In [None]:
k_diff = []  # not referenced by the visible cells; possibly used by later ones

def show_double(keys, threshold=5):
    """Print diagnostics for captured keys that look faulted.

    A key is reported when at least ``threshold`` entries of its two weight
    lists are smaller than D (``bikelvl.d``), or when it is malformed: a
    weight list without exactly D distinct entries, or a secret key whose
    total number of set bits differs from 2*D.

    Parameters
    ----------
    keys : iterable of ``(settings, key)`` entries; ``settings[2]`` is the
        external offset and ``key`` exposes ``wlists_as_int`` and ``sk``.
    threshold : number of weight-list entries below D that makes a key
        interesting on its own.
    """
    d = bikelvl.d
    for idx, entry in enumerate(keys):
        key = entry[1]
        wlists = key.wlists_as_int
        w0, w1 = wlists[0], wlists[1]

        low_cnt = sum(1 for j in w0 if j < d) + sum(1 for j in w1 if j < d)
        malformed = (
            len(set(w0)) != d
            or len(set(w1)) != d
            or int.from_bytes(key.sk, 'little').bit_count() != d * 2
        )
        if low_cnt < threshold and not malformed:
            continue

        print(f"\nkey {idx}, eo {entry[0][2]%lp_cy}, {entry[0]}")
        clusters(key)
        print(f"{low_cnt}\t bits set lower than D={d}")
        # Return value is unused; presumably called for its output (loud=True).
        bk.emph_difference(key, bikelvl, loud=True)
        # Distinct loop variable so the key index `idx` is never clobbered.
        print("wl0:")
        for j in key.wlists_as_int[0]:
            if j < d:
                print(j)
        print("wl1:")
        for j in key.wlists_as_int[1]:
            if j < d:
                print(j)

def clusters(key):
    """Print the strongest cluster found in a key's index lists, if large.

    ``bk.find_cluster(ints, key.level.d, 5)`` is run on the two index lists
    returned by ``bk.__get_sk_ilist`` (presumably h0/h1) and on both weight
    lists. The result whose ``[0][1]`` value is largest wins (the earlier list
    on ties); it is printed only when that value exceeds 8.
    """
    h0_ints, h1_ints = bk.__get_sk_ilist(key)
    w0_ints, w1_ints = key.wlists_as_int

    best = [[0, 0]]  # renamed from `max` to avoid shadowing the builtin
    for ints in (h0_ints, h1_ints, w0_ints, w1_ints):
        found = bk.find_cluster(ints, key.level.d, 5)
        if found[0][1] > best[0][1]:
            best = found

    if best[0][1] > 8:
        print(f"Clusters within weightlists and binary vector: {best}")

# Print diagnostics for suspicious keys, one base-offset group at a time.
for offset_group in e_off2keys.values():
    show_double(offset_group)

In [None]:
# # export into file
# with open("/home/till/double_indicies.txt", 'a') as file:
#     for keys in key_dump[-1:]:
#         for i,k in enumerate(keys):
#             cnt = 0
#             for j in k.wlists_as_int[0]:
#                 if j < t_com.lvl.d: cnt += 1
#             for j in k.wlists_as_int[1]:
#                 if j < t_com.lvl.d: cnt += 1
#             if cnt < 5: continue# or cnt > 69: continue
#             file.write(f"{k.mupq_key.hex()}\n").