# Setup
The first cell is mandatory; the second and third cells can be run on demand.

In [None]:
# Target selection: only the last active BUILDPLATFORM assignment takes effect,
# so reorder / comment out lines to switch boards.
BUILDPLATFORM = 'stm32f4discovery'
BUILDPLATFORM = 'cw308t-stm32f415'
#BUILDPLATFORM = 'cw308t-stm32f3'
# BIKE level used to pick firmware and KAT files (later cells use e.g. "l11")
lvl = "l1"
m4f = True # implementation switch, either use m4f or mupq implementation
# for m4f only levels 1 and 3 are implemented

# Runs the setup notebook; later cells use `scope`, `target` and `t_com`,
# presumably defined there for the CW/discovery targets — TODO confirm.
%run "Setup_Disco_or_CWLITE.ipynb"

In [None]:
# Build the firmware via the build notebook. That notebook runs `make clean`
# before building; if the firmware is already built (or you are unsure), build
# from a terminal without `make clean` instead.
%run "Build_Firmware.ipynb"

The target `mps2-an386` (qemu) is not supported as conveniently as the other targets.

First, start the target with `qemu-system-arm -M mps2-an386 -nographic -serial pty -semihosting -kernel <path_to_firmware_file>`. The firmware path could be, e.g., `elf/mupq_crypto_kem_bikel11_opt_fi.elf` when called from the pqm4 folder. This creates a pseudo terminal that can be used as the serial interface. It is usually located in `/dev/pts/`; qemu prints its exact name on startup, and the variable `i` in the next cell has to be set accordingly.
Once `i` is set, the next cell sets up the `Communication_Target` object to comfortably interact with the target.

In [None]:
import serial
import sys

new_path = '../scripts/'

if new_path not in sys.path:
    sys.path.append(new_path)
import target_com as com

lvl = 'l11'
# Number of the pseudo terminal qemu printed on startup, e.g. 3 for `/dev/pts/3`.
# This is a manual setting: replace None before running the cell.
i = None
if i is None:
    raise ValueError("set `i` to the /dev/pts number printed by qemu before running this cell")

# serial line at 38400 baud on the qemu pseudo terminal
target = serial.Serial(f"/dev/pts/{i}", 38400)
t_com = com.Communication_Target(target, lvl)

def reboot_flush():
    """Discard all bytes currently pending on the qemu serial line.

    The target is not rebooted from here: first use `system_reset` in the
    qemu console, then call this function.
    """
    target.read_all()

---

If the firmware needs to be flashed, run the next cell; otherwise skip it.

In [None]:
import kat_bike as kat
# Firmware name, e.g. "crypto_kem_bikel1_m4f_fi" (m4f) or
# "mupq_crypto_kem_bikel1_opt_fi" (mupq) for lvl = "l1".
ff = f"{'' if m4f else 'mupq_'}crypto_kem_bike{lvl}_{'m4f' if m4f else 'opt'}_fi"
# Flash the firmware; the flash notebook presumably reads `ff` — TODO confirm.
%run "Flash_Disco_CWLITE.ipynb"

# drop whatever the target emitted on the serial line after flashing
reboot_flush()

---
# Communication with the target

The following cells trigger specific operations or read data from the target's memory.

In [None]:
# Flush pending serial output (for qemu: run `system_reset` in the qemu console first).
reboot_flush()

In [None]:
# Reset the target's pseudo random number generator.
t_com.reset_prng()

In [None]:
# Start key generation without blocking, then wait for the target to report completion.
t_com.keygen_async()
t_com.check_done()

In [None]:
# 20 encapsulation/decapsulation round trips; print what `t_com.c_ss()` returns
# after each (presumably the shared-secret comparison result).
for _round in range(20):
    t_com.encaps()
    t_com.decaps()
    ss_check = t_com.c_ss()
    print(ss_check.hex())

In [None]:
# Run one encapsulation on the target.
t_com.encaps()

In [None]:
# Run one decapsulation on the target.
t_com.decaps()

In [None]:
# Read shared secrets and keys back from the target and print each as hex,
# in the order: ss, ss after decaps, pk, sk, sk in mupq layout.
for reader_name in ("r_ss", "r_ss_dec", "r_pk", "r_sk", "r_sk_mupq"):
    print(getattr(t_com, reader_name)().hex())

---
## Generate correct keys

Use key data from a KAT to verify the communication and the computation of the weight index list, which the mupq implementation requires.

In [None]:
import kat_bike as kat

# KAT entries for level `lvl`; each entry exposes .pk, .sk, .ct and .ss
ref_kat = kat.read_rsp(lvl)

In [None]:
import bike_key as conv

# choose which KAT entry to use
cnt = 5
pk = ref_kat[cnt].pk
sk = ref_kat[cnt].sk
ct = ref_kat[cnt].ct
ss = ref_kat[cnt].ss

# convert the reference secret key into the layout expected by the mupq firmware
key = conv.BIKE_key(sk, lvl)
key.pk = pk

# write keys and ciphertext to board
t_com.w_pk(pk)
t_com.w_ct(ct)
t_com.w_sk(key.mupq_key)

# decapsulate the KAT ciphertext on the board and compare with the KAT shared secret;
# report a mismatch explicitly instead of printing nothing
t_com.decaps()
ss_board = t_com.r_ss_dec()
if ss_board == ss:
    print("decapsulation on board returned same shared secret as reference implementation\nseems like everything is setup correctly")
else:
    print("decapsulation on board returned a DIFFERENT shared secret than the reference implementation")
    print(f"board:     {ss_board.hex()}")
    print(f"reference: {ss.hex()}")

---
# Capture Power Traces

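In the following we capture power traces for the different subroutines of the KEM. This works only for `CWLITE` and not for the `discovery board`: the discovery board uses a hardware RNG and pqm4 implements no PRNG for it, while `CWLITE` has only a software PRNG and no hardware RNG. The capture resets the PRNG before every segment, so that all segments of one trace come from runs with identical randomness. This is only possible with a resettable software PRNG.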

In [None]:
import numpy as np
import time
import math

# Routines to capture, in execution order; they run back to back so that every
# trace sees fresh keys and ciphertexts. If the triggers are set/unset within a
# loop it might make sense to use the async routines; their completion is
# awaited with `t_com.check_done()` after the capture.
routines = [t_com.keygen, t_com.encaps, t_com.decaps_async]
traces = dict()      # routine -> list of concatenated traces, one per run
avg_traces = dict()
Segments = list()    # segments to capture per routine, same order as `routines`
cycles = dict()      # routine -> trigger-high cycle count of every capture

# set trace length to maximum
scope.adc.samples = 24400

###########
# following values can be adjusted
###########
scope.adc.decimate = 50  # adjust trace resolution
Traces = 20  # amount of separate traces per routine
seed = 0  # prng will be regenerated seed times


def _is_async(routine):
    """Return True if `routine` is an async variant (its name contains "async")."""
    return "async" in routine.__name__


# Dry run: measure how long each routine keeps the trigger high and derive how
# many segments of scope.adc.samples samples are needed to cover it.
for routine in routines:
    traces[routine] = list()
    cycles[routine] = list()
    scope.arm()
    routine()
    scope.capture()
    if _is_async(routine):
        t_com.check_done()

    # a segment is assumed to span samples * decimate cycles; 10 % margin on top
    s = math.ceil(scope.adc.trig_count * 1.1 / scope.adc.samples / scope.adc.decimate)
    Segments.append(s)

    print("for routine {}".format(routine.__name__))
    print("which takes about {} clock cycles".format(scope.adc.trig_count))
    print("we will take {} segments for {} traces taking every {}th datum".format(s, Traces, scope.adc.decimate))

print()

# not interested in ...
# to avoid capturing traces for a specific routine
# one can set the maximum segments to capture for this routine to 0
#Segments[0] = 0
#Segments[1] = 0
#Segments[2] = 0

for s in range(np.max(Segments)):
    print("start capturing segment {}".format(s))
    # set cycles to wait after trigger event before capturing trace
    # NOTE(review): the offset advances by `samples` per segment, while the segment
    # count above assumes samples * decimate cycles per segment — confirm the unit
    # of scope.adc.offset for decimated captures.
    scope.adc.offset = scope.adc.samples * s
    # reset the PRNG so that run t of this segment presumably repeats run t of segment 0
    t_com.reset_prng()
    for _ in range(seed):
        t_com.regen_prng()

    for t in range(Traces):
        # to get different ct and keys for encaps and decaps
        # we take samples from consecutive runs of every routine
        for routine, s_end in zip(routines, Segments):
            # decaps is last and no other routine depends on its output, so it can be
            # skipped once all its segments are captured; startswith() matches both
            # `decaps` and `decaps_async`
            if routine.__name__.startswith("decaps") and s >= s_end:
                continue

            scope.arm()
            if not routine() and not _is_async(routine):
                # raise instead of exit(): exit() ends the IPython kernel and loses
                # every trace captured so far
                raise RuntimeError("out of sync during {}".format(routine.__name__))
            # scope.capture() returns True when the capture timed out
            timed_out = scope.capture()
            if _is_async(routine):
                t_com.check_done()
            if timed_out:
                print("warning: capture timed out for {} (segment {}, trace {})".format(routine.__name__, s, t))

            # store only the segments this routine needs (Segments[i] == 0 stores none)
            if s < s_end:
                if s == 0:
                    # append new trace
                    traces[routine].append(scope.get_last_trace())
                else:
                    # extend the existing trace by this segment
                    traces[routine][t] = np.concatenate((traces[routine][t], scope.get_last_trace()), axis=0)

            # save the number of cycles the trigger was high
            cycles[routine].append(scope.adc.trig_count)

To show the cycles the trigger was high during capturing one can run the following cell.

In [None]:
# Summarize, per routine, how many cycles the trigger stayed high across all captures.
for routine in routines:
    observed = cycles[routine]
    print(f"for {routine.__name__} the trigger was high for ... cycles")
    print(f" {np.max(observed)} maximum")
    print(f" {np.average(observed)} average")
    print(f" {np.min(observed)} minimum")
	

## Save and Load Traces
The next cell stores the currently captured power traces.

In [None]:
import shelve

# Persist the raw traces of the last capture run, one shelve entry per routine.
# Key layout: "raw_<tag>_<routine name>_<decimation>".
trace_tag = "reenc.sparse"
# writeback is not needed because entries are only assigned, never mutated in place;
# without it the large trace lists are not cached in memory until close.
# The context manager closes (and flushes) the shelf even if an error occurs.
with shelve.open("../traces/power") as data:
    for r in routines:
        key = "raw_{}_{}_{}".format(trace_tag, r.__name__, scope.adc.decimate)
        data[key] = traces[r]

The following cells load them again.

The first block will show the keys available in the `shelve`, the second block's intention is to load one of the raw traces (several traces per key) and calculate the average and the standard deviation of those traces and make them ready for plotting in the `trace` list.

In [None]:
# show saved keys
import shelve

# the context manager closes the shelf even if listing fails
with shelve.open("../traces/power") as data:
    for k in data.keys():
        print(k)

In [None]:
import shelve
import numpy as np

# name of the routine whose raw traces should be loaded
alg = "keygen"

# choose a key from the previous cell's output
# NOTE(review): this default key layout differs from the one the save cell writes
key = "raw_{}_50".format(alg)
with shelve.open("../traces/power") as data:
    raw_traces = data[key]

# per-sample mean and standard deviation across the individual traces, ready for plotting
trace = [np.average(raw_traces, axis=0), np.std(raw_traces, axis=0)]

## Visualization
The following cells visualize the traces.

In [None]:
# Run this cell to process the traces captured in this session. If the traces
# were just loaded from the shelve, `trace` is already set: skip this cell.

routine_idx = 0   # index into `routines` of the routine of interest
n_show = 14000    # number of leading samples to keep for plotting

r = routines[routine_idx]
captured = traces[r]
trace = [
    np.average(captured, axis=0)[:n_show],
    np.std(captured, axis=0)[:n_show],
]


In [None]:
import matplotlib.pylab as plt

# Plot every curve in `trace` (e.g. mean and standard deviation) in one figure.
fig, ax = plt.subplots()
for curve in trace:
    ax.plot(curve)
# NOTE(review): the x axis is the sample index; with decimation one sample spans
# several clock cycles — confirm the axis label.
ax.set_xlabel("Clock Cycles")
ax.set_ylabel("Voltage")
plt.show()

# Faulty Keys

Before we try to inject faults during key generation to obtain faulty keys, we first compute faulty keys on the host and check whether and how the target board processes them.

For every secret key there is exactly one public key, so if we alter a secret key we have to compute the corresponding public key instead of reusing the original one.
$$ sk = (h_0, h_1, \sigma ) $$
$$ pk = h_1 \cdot h_0^{-1} $$

## Verify Host Public Key Generation
The next code block takes a key pair from a KAT file, calculates the public key from the given secret key and compares both public keys.

In [None]:
import sys

# Make ../scripts importable. `sys` is imported here because the qemu setup
# cell, which also imports it, may not have been run.
new_path = globals().get("new_path", '../scripts/')
if new_path not in sys.path:
    sys.path.append(new_path)

import bike_key as bk
import kat_bike as kat

# choose key pair of KAT file and BIKE level
k = 60
lvl = "l11"

rsp = kat.read_rsp(lvl)
# recompute the public key on the host from the KAT secret key
key = bk.BIKE_key(rsp[k].sk, lvl)

print(rsp[k].pk.hex())
print(key.pk.hex())
if key.pk == rsp[k].pk:
    print("The calculated key matches the one from the file!\nSUCCESS")
else:
    print("The calculated key does NOT match the one from the file!\nERROR")

## Use host generated (Faulty) Keys
To determine which kinds of faults have which effects, we generate faulty keys of various kinds on the host and pass them to the target. The aim is a decoding failure rate (DFR) of about 50%. With the following cells we can determine the DFR for different weights and kinds of faulty keys.

In [None]:
# init target
# Only the last BUILDPLATFORM assignment takes effect; reorder or comment out to switch.
BUILDPLATFORM = 'cw308t-stm32f3'
BUILDPLATFORM = 'stm32f4discovery'
lvl = "l11"
%run "Setup_Disco_or_CWLITE.ipynb"

In [None]:
# example faulty key generation and enc/dec on target
import sys

# make ../scripts importable (sys was previously only imported by the qemu cell)
new_path = globals().get("new_path", '../scripts/')
if new_path not in sys.path:
    sys.path.append(new_path)
import bike_key as bk

fm = bk.FaultMode(bk.FK_Kind.TWO, bk.PK_Kind.SK, bk.WL_Kind.MISMATCH, bk.Fault.SK)
# weight: half of t_com.lvl.d, converted to int like the other cells do
# (d * 0.5 is a float, non-integral for odd d)
key = bk.faulty_key_fm(int(t_com.lvl.d * 0.5), fm, lvl)

t_com.w_sk(key.mupq_key)
t_com.w_pk(key.pk)

t_com.encaps()
t_com.decaps()
print(t_com.c_ss().hex())

In [None]:
# trigger enc/dec and read result: 5 round trips, printing what `t_com.c_ss()`
# returns after each
for _round in range(5):
    t_com.encaps()
    t_com.decaps()
    ss_check = t_com.c_ss()
    print(ss_check.hex())

In [None]:
# check if key analysis works fine
import sys

# make ../scripts importable (sys was previously only imported by the qemu cell)
new_path = globals().get("new_path", '../scripts/')
if new_path not in sys.path:
    sys.path.append(new_path)
import bike_key as bk

# For every valid fault mode, build faulty keys at a low and a high weight and
# check that analyze_key recognizes the fault mode used to build them.
valid_fms = bk.get_valid_faultmodes()
for i, fm in enumerate(valid_fms):
    for d in [int(t_com.lvl.d*0.6), int(t_com.lvl.d*1.4)]:
        key = bk.faulty_key_fm(d, fm, lvl)
        f = bk.analyze_key(key.mupq_key, t_com.lvl)
        check = f[0] == fm
        if not check:
            print(f"in round {i},\nfor weight {d} and \n{fm}\n\n{f[0]}\n{f[1]}, {f[2]} and {f[3]} was recognized\n\n")

print(f"if there was no output, every of the {len(valid_fms)} different fault modes was successfully recognized.")

The previous cells check whether the basic methods used in the following steps work correctly.

### Decoding Failure Rate DFR
The following cell runs through all valid FaultModes (or a subset of them), generates faulty keys for each of them over a range of weights, and triggers multiple encapsulations and decapsulations. Whether each run succeeded is stored. The cells afterwards save the results via shelve, compute the DFR from the results, and plot the DFR for every FaultMode separately.

In [None]:
# capture enc/dec results for multiple fault modes and weights
import bike_key as bk
keys = 12  # faulty keys generated per (fault mode, weight) pair
runs = 30  # enc/dec round trips per faulty key

D = t_com.lvl.d
# weights tried are range(dmin, dmax), i.e. dmax itself is excluded
dmin, dmax = int(D * 0.45), int(D * 2)

# all valid fault modes; the commented lines below build a hand-picked subset instead
# NOTE(review): the save/load cells map entries back via the order of
# bk.get_valid_faultmodes(), so a subset will not round-trip correctly
fms = bk.get_valid_faultmodes() #wl_kind=[bk.WL_Kind.MISMATCH])
# wlk = [bk.WL_Kind.MULTI] #, bk.WL_Kind.UNSET, bk.WL_Kind.MISMATCH]
# pkk = [bk.PK_Kind.SK]
# fms = [bk.FaultMode(bk.FK_Kind.TWOa,pk,wl,bk.Fault.SK) for pk in pkk for wl in wlk]

# results[fm][d] collects one t_com.c_ss() value per enc/dec run
results = dict()
for i, fm in enumerate(fms):
	print(f"start with fault mode {i+1} of {len(fms)}")
	results[fm] = dict()
	for d in range(dmin, dmax):
		results[fm][d] = list()
		for _ in range(keys):
			# fm.new() presumably yields a fresh FaultMode instance per key — TODO confirm
			key = bk.faulty_key_fm(d, fm.new(), lvl)
			t_com.w_sk(key.mupq_key)
			t_com.w_pk(key.pk)
			for _ in range(runs):
				t_com.encaps()	
				t_com.decaps()
				results[fm][d].append(t_com.c_ss())

In [None]:
# save enc/dec data
import shelve

# One entry per fault mode, keyed by its position in `results` ("faultm0", ...).
# The load cell maps entries back via the order of bk.get_valid_faultmodes(),
# so this only round-trips if `results` was filled in that order.
# writeback is unnecessary (values are only assigned); `with` closes the shelf.
with shelve.open("../traces/power") as data:
    for i, r in enumerate(results):
        key = f"faultm{i}"
        data[key] = results[r]

In [None]:
# load dec/enc data
import shelve
import sys

# bk is needed below; make ../scripts importable in case earlier cells were skipped
new_path = globals().get("new_path", '../scripts/')
if new_path not in sys.path:
    sys.path.append(new_path)
import bike_key as bk

results = dict()
fms = bk.get_valid_faultmodes()
# entries were stored as "faultm{i}" in the order of the fault modes
with shelve.open("../traces/power") as data:
    for i, fm in enumerate(fms):
        results[fm] = data[f"faultm{i}"]

# Recover the weight range from the loaded data so that the DFR print/plot cells
# also work in a fresh kernel where the capture cell (defining dmin/dmax) did not run.
weights = [d for per_weight in results.values() for d in per_weight]
if weights:
    dmin, dmax = min(weights), max(weights) + 1
	

In [None]:
# compute decoding failure rate
# dfr[i] holds, for the i-th fault mode in `results` and each weight in insertion
# order, the fraction of enc/dec results that are not equal to b'\x00'
dfr = {
    i: [1 - (outcomes.count(b'\x00') / len(outcomes)) for outcomes in results[fm].values()]
    for i, fm in enumerate(results)
}

In [None]:
# Print each fault mode followed by "weight dfr" pairs, weights starting at dmin.
for i, fm in enumerate(fms):
    print(fm)
    for weight, rate in enumerate(dfr[i], start=dmin):
        print(weight, rate)

In [None]:
# plot dfr and print fault mode
import matplotlib.pylab as plt

# One figure per fault mode, also saved as PNG named after the fault mode fields.
for i, fm in enumerate(fms):
    fig, ax = plt.subplots()
    ax.set_xlim(dmin, dmax)
    ax.set_xlabel("Row Weight D")
    ax.set_ylabel("Decoding Failure Rate DFR")
    # pad with zeros so that the list index equals the row weight
    ax.plot([0] * dmin + dfr[i])
    print(fm)
    fig.savefig(f"{fm.SK},{fm.PK},{fm.WK},{fm.Fault}.png")
    plt.show()

In [None]:
# display the raw DFR lists (key: fault-mode index, value: DFR per weight)
dfr

In [None]:
# Disconnect target and scope — presumably only valid for the ChipWhisperer setup;
# the qemu path uses a plain serial.Serial as `target`.
target.dis()
scope.dis()