# Alchemist Infuser

In [124]:
# Hardware selection: both the scope type and the target platform are the
# CW-Nano. PLATFORM is also used to pick the firmware file name.
SCOPETYPE = PLATFORM = 'CWNANO'

In [125]:
%run "../setup/Setup_Generic.ipynb"



INFO: Caught exception on reconnecting to target - attempting to reconnect to scope first.
INFO: This is a work-around when USB has died without Python knowing. Ignore errors above this line.
INFO: Found ChipWhisperer😍


In [126]:
# Flash the firmware image built for the selected platform
# (`scope` and `prog` presumably come from Setup_Generic.ipynb).
cw.program_target(scope, prog, f"alchemistInfuser-{PLATFORM}.hex")

Detected known STMF32: STM32F04xxx
Extended erase (0x44), this can take ten seconds or more
Attempting to program 6707 bytes at 0x8000000
STM32F Programming flash...
STM32F Reading flash...
Verified flash OK, 6707 bytes


In [127]:
import time
import os
import numpy as np
import chipwhisperer as cw
from sklearn.ensemble import ExtraTreesClassifier


# ======================================================
# CONFIG
# ======================================================
# --- acquisition ---
REPEATS_PER_PT = 50       # captures averaged for every plaintext value
SAMPLES = 4500            # ADC samples per capture

# --- peak finding ---
FIRST_SEARCH_END = 500    # the alignment reference is searched in template[:FIRST_SEARCH_END]
MIN_SEP_FIRST = 150       # masking radius between greedily picked peaks (both blocks)
SNAP_W = 6                # +/- window for snapping a peak onto the local energy maximum

# --- DPA / ML windows ---
WIN_DPA = 50              # per-byte DPA window is [center - WIN_DPA, center + WIN_DPA)
NEG_RADIUS = 60           # ML negatives are drawn farther than this from any first-block peak
TAIL_OFFSET = 200         # second-block search starts this far after the last first-block peak
ML_WIN = 12               # half-width (per side) of the window used for ML features
N_SECOND = 8              # not referenced below; the second-block search hard-codes 8


# ======================================================
# CW SETUP
# ======================================================
# Connect to the scope and the SimpleSerial target, apply the default scope
# configuration, then set the capture length.
# NOTE(review): `scope` already exists at this point (it is used by
# cw.program_target in the cell above, after %run Setup_Generic.ipynb);
# re-creating it here may conflict with that open connection — confirm.
scope = cw.scope()
target = cw.target(scope, cw.targets.SimpleSerial)
scope.default_setup()
scope.adc.samples = SAMPLES

def reset_target():
    """Pulse the target's nRST line.

    The line is driven 'low', then released to 'high_z'. The function sleeps
    50 ms after each step.
    """
    for level in ('low', 'high_z'):
        scope.io.nrst = level
        time.sleep(0.05)

def encrypt_8(pt: bytes):
    """
    Capture one power trace while the target processes ``pt``.

    Arms the scope, sends ``pt`` with the 'e' command and waits for the
    capture. It then reads and discards the 8-byte 'r' reply.

    Raises:
        RuntimeError: if scope.capture() reports a timeout (truthy return).

    Returns:
        Whatever scope.get_last_trace() returns for this capture.
    """
    scope.arm()
    target.simpleserial_write('e', pt)
    timed_out = scope.capture()
    if timed_out:
        raise RuntimeError("capture timed out")
    target.simpleserial_read('r', 8)  # drain the reply; its contents are not needed
    return scope.get_last_trace()

def verify_key(key16: bytes):
    """Submit ``key16`` with the 'c' command and return the 17-byte 'r' reply."""
    target.simpleserial_write('c', key16)
    return target.simpleserial_read('r', 17)

# ======================================================
# 1) ALIGNING ACQUISITION
# ======================================================
def acquire_aligned_traces(repeats=REPEATS_PER_PT):
    """
    Capture, align and average traces for every plaintext byte value.

    For pt = 0..255 the message [pt]*8 is encrypted ``repeats`` times. Each
    capture is aligned on the *first* strong operation, which is located in a
    pt=0 template, using cross-correlation. The aligned captures are then
    averaged.

    Args:
        repeats: number of captures averaged per plaintext value.

    Returns:
        np.ndarray of shape (256, n_samples): one averaged trace per pt, where
        n_samples is the length of the traces returned by encrypt_8.
    """
    print("[*] building initial template for alignment ...")
    ref_start, ref_seg = _alignment_reference()

    traces = []
    for pt in range(256):
        print(f"[+] pt={pt}")
        msg = bytes([pt] * 8)
        reps = [_align_trace(np.ravel(encrypt_8(msg)), ref_start, ref_seg)
                for _ in range(repeats)]
        traces.append(np.vstack(reps).mean(axis=0))
    return np.vstack(traces)  # (256, S)


def _alignment_reference(n_templates=5, ref_w=30):
    """Build the reference snippet used to align every capture.

    The snippet is cut from the mean of ``n_templates`` pt=0 captures, around
    the sample with the largest |amplitude| in the first FIRST_SEARCH_END
    samples (up to ``ref_w`` samples on each side).

    Returns:
        (ref_start, ref_seg): the template index where the snippet starts, and
        the snippet with its mean removed.
    """
    tmpl = np.vstack([np.ravel(encrypt_8(bytes([0] * 8)))
                      for _ in range(n_templates)]).mean(axis=0)
    search_end = min(FIRST_SEARCH_END, len(tmpl))
    ref_idx = int(np.argmax(np.abs(tmpl[:search_end])))
    ref_start = max(0, ref_idx - ref_w)
    ref_end = min(len(tmpl), ref_idx + ref_w + 1)
    ref_seg = tmpl[ref_start:ref_end]
    # A zero-mean reference makes the correlation insensitive to a DC offset
    # in the capture, so the peak tracks waveform shape rather than level.
    return ref_start, ref_seg - ref_seg.mean()


def _align_trace(tr, ref_start, ref_seg, max_shift=200):
    """Shift ``tr`` so its best match to ``ref_seg`` sits at ``ref_start``.

    ``ref_start`` is where the snippet sits in the template. Lags in
    [-max_shift, +max_shift] are searched, so both early and late captures can
    be corrected. np.roll is circular: up to ``max_shift`` samples wrap around
    the ends of the trace.
    """
    lo = max(0, ref_start - max_shift)
    hi = min(len(tr), ref_start + max_shift + ref_seg.size)
    cand = tr[lo:hi]
    if cand.size < ref_seg.size:
        return tr  # too short to correlate; leave this capture unaligned
    corr = np.correlate(cand, ref_seg, mode="valid")
    # corr[k] scores tr[lo + k : lo + k + len(ref_seg)] against the template
    # snippet at ref_start, so the capture's lag is lo + k - ref_start.
    lag = lo + int(np.argmax(corr)) - ref_start
    return np.roll(tr, -lag)

# ======================================================
# 2) GLOBAL DPAS (8 traces, whole length)
# ======================================================
def compute_global_dpas(traces: np.ndarray):
    """
    Compute one difference-of-means curve per bit over the whole trace length.

    The row index of ``traces`` is taken as the plaintext byte (0..255), the
    same value for all 8 bytes. For bit b, rows are split on bit b of
    ``pt ^ (1 << b)``. The curve is mean(rows where that bit is 1) minus
    mean(rows where it is 0).

    Args:
        traces: (256, S) averaged traces.

    Returns:
        float32 array of shape (8, S).
    """
    n_samples = traces.shape[1]
    result = np.zeros((8, n_samples), dtype=np.float32)
    values = np.arange(256, dtype=np.uint8)
    for b in range(8):
        selector = ((values ^ (1 << b)) >> b) & 1
        hi_mean = traces[selector == 1].mean(axis=0)
        lo_mean = traces[selector == 0].mean(axis=0)
        result[b] = hi_mean - lo_mean
    return result

# ======================================================
# 3) FIRST 8 PEAKS (signal-based)
# ======================================================
def find_first_8_peaks(dpas: np.ndarray, *, n_peaks=8, search_end=2000,
                       min_sep=None, snap_w=None):
    """
    Locate the first-block operations as the strongest peaks of the DPA energy.

    energy[t] is the mean over bits of |dpas[bit, t]|. Peaks are picked
    greedily inside energy[:search_end]; after each pick, the samples within
    min_sep of it are masked out. Each pick is then snapped to the energy
    maximum within +/- snap_w samples.

    Args:
        dpas: (n_bits, S) array, e.g. from compute_global_dpas.
        n_peaks: number of peaks to return (one per key byte of the block).
        search_end: only energy[:search_end] is searched.
        min_sep: masking radius around each pick; defaults to MIN_SEP_FIRST.
        snap_w: snapping half-width; defaults to SNAP_W.

    Returns:
        Sorted list of ``n_peaks`` int sample indices.

    Raises:
        ValueError: if every sample of the search window is masked before
            ``n_peaks`` peaks have been picked.
    """
    if min_sep is None:
        min_sep = MIN_SEP_FIRST
    if snap_w is None:
        snap_w = SNAP_W

    energy = np.mean(np.abs(dpas), axis=0)
    # Masked samples become -inf (not 0) so a masked position can never be
    # picked a second time once the real peaks are used up.
    avail = energy[:search_end].copy()
    peaks = []
    for _ in range(n_peaks):
        idx = int(np.argmax(avail))
        if np.isneginf(avail[idx]):
            raise ValueError(
                f"only {len(peaks)} of {n_peaks} peaks fit in the first "
                f"{search_end} samples with min_sep={min_sep}")
        peaks.append(idx)
        avail[max(0, idx - min_sep):idx + min_sep] = -np.inf

    snapped = []
    for p in peaks:
        lo = max(0, p - snap_w)
        hi = min(len(energy), p + snap_w + 1)
        snapped.append(lo + int(np.argmax(energy[lo:hi])))
    return sorted(snapped)

# ======================================================
# 4) ML detector trained on first 8 ops
# ======================================================
def make_feat(dpas, center, win=None):
    """
    Build the ML feature vector for one candidate sample position.

    For each of the first 8 rows of ``dpas``, take |dpas[row]| over the
    2*win+1 samples centred on ``center`` and concatenate the rows
    (row-major). Window samples outside [0, S) are zero. The padding is
    placed on the side where data is missing, so ``center`` always maps to
    index ``win`` within each row's slice.

    Args:
        dpas: (>=8, S) DPA array.
        center: sample index the window is centred on.
        win: half-width of the window; defaults to ML_WIN.

    Returns:
        float32 vector of length 8 * (2*win + 1).
    """
    if win is None:
        win = ML_WIN
    width = 2 * win + 1
    n_samples = dpas.shape[1]
    feat = np.zeros((8, width), dtype=np.float32)
    s = center - win
    lo, hi = max(s, 0), min(s + width, n_samples)
    if lo < hi:
        feat[:, lo - s:hi - s] = np.abs(dpas[:8, lo:hi])
    return feat.ravel()

def train_op_detector(dpas, first_peaks):
    """
    Train a classifier that recognises the DPA signature of one operation.

    Positives are the features (make_feat) at every first-block peak shifted
    by -3..+3 samples. Negatives are sample positions drawn uniformly (with
    replacement, seeded RNG) from those farther than NEG_RADIUS from every
    peak, added until the training set holds 2000 rows.

    Args:
        dpas: (8, S) global DPA array.
        first_peaks: sample indices of the first-block operations.

    Returns:
        A fitted ExtraTreesClassifier (label 1 = "operation here").

    Raises:
        ValueError: if the exclusion zones cover every sample. No negative
            could then ever be drawn, and the sampling loop would never end.
    """
    X = []
    y = []
    S = dpas.shape[1]
    # positives: around each first-peak
    for p in first_peaks:
        for off in range(-3, 4):
            c = p + off
            if 0 <= c < S:
                X.append(make_feat(dpas, c))
                y.append(1)
    # negatives: random positions away from peaks
    rng = np.random.default_rng(1337)
    bad = {
        p + off
        for p in first_peaks
        for off in range(-NEG_RADIUS, NEG_RADIUS + 1)
        if 0 <= p + off < S
    }
    if len(bad) >= S:
        raise ValueError(
            "no sample position lies outside the NEG_RADIUS exclusion zones; "
            "cannot draw negative examples")
    while len(X) < 2000:
        c = int(rng.integers(0, S))
        if c in bad:
            continue
        X.append(make_feat(dpas, c))
        y.append(0)
    X = np.vstack(X)
    y = np.array(y, dtype=np.int32)
    clf = ExtraTreesClassifier(
        n_estimators=400,
        max_features='sqrt',
        n_jobs=-1,
        class_weight='balanced',
        random_state=42,
    )
    clf.fit(X, y)
    return clf

# ======================================================
# 5) FIND SECOND 8 OPS WITH ML
# ======================================================
def find_second_8_peaks(dpas, first_peaks, clf=None):
    """
    Locate the 8 second-block operations that follow the first block.

    The search starts TAIL_OFFSET samples after the last first-block peak.
    Each candidate centre c in [start, S - ML_WIN - 1) is scored in one of
    two ways:
    - with a classifier: ``clf.predict_proba`` on make_feat(dpas, c),
      taking the class-1 column;
    - when clf is None: the dot product of the energy window around c with
      the energy window around the last first-block peak.
    Eight peaks are then picked greedily at or after ``start``, masking
    +/- MIN_SEP_FIRST around each pick. Each pick is snapped to the energy
    maximum within +/- SNAP_W.

    Args:
        dpas: (8, S) global DPA array.
        first_peaks: sorted first-block peak indices (only the last is used).
        clf: fitted classifier exposing predict_proba, or None for the
            template-matching fallback.

    Returns:
        Sorted list of 8 int sample indices.

    Raises:
        ValueError: if there is no room to search after the first block, or
            the search region is fully masked before 8 peaks are found.
    """
    energy = np.mean(np.abs(dpas), axis=0)
    S = dpas.shape[1]
    start = first_peaks[-1] + TAIL_OFFSET
    stop = S - ML_WIN - 1  # candidate centres are start .. stop-1
    if start >= stop:
        raise ValueError(
            f"second-block search start {start} leaves no room in {S} samples")

    scores = np.zeros(S, dtype=np.float32)
    if clf is not None:
        # One batched predict_proba call instead of one call per sample:
        # identical probabilities, without thousands of per-call overheads.
        feats = np.vstack([make_feat(dpas, c) for c in range(start, stop)])
        scores[start:stop] = clf.predict_proba(feats)[:, 1]
    else:
        # fallback: template match with last first-peak
        tmpl_center = first_peaks[-1]
        tmpl = energy[tmpl_center - ML_WIN:tmpl_center + ML_WIN + 1]
        for c in range(start, stop):
            seg = energy[c - ML_WIN:c + ML_WIN + 1]
            if seg.size != tmpl.size:
                continue
            scores[c] = np.dot(seg, tmpl)

    # Greedy picks restricted to [start, S). Masked samples become -inf so
    # they can never be re-picked, and no pick can fall inside the first block.
    region = scores[start:].copy()
    sec_peaks = []
    for _ in range(8):
        rel = int(np.argmax(region))
        if np.isneginf(region[rel]):
            raise ValueError(
                "second-block search region exhausted before 8 peaks were found")
        sec_peaks.append(start + rel)
        region[max(0, rel - MIN_SEP_FIRST):rel + MIN_SEP_FIRST] = -np.inf

    snapped = []
    for p in sec_peaks:
        lo = max(0, p - SNAP_W)
        hi = min(S, p + SNAP_W + 1)
        snapped.append(lo + int(np.argmax(energy[lo:hi])))
    snapped.sort()
    return snapped

# ======================================================
# 6) BUILD BLOCK DPAS
# ======================================================
def build_block_dpas(traces, block_peaks, key_prefix=None, win=None):
    """
    Compute per-bit DPA curves in a window around each operation of a block.

    For every peak ``center`` (one per key byte) and every bit b, the 256
    plaintext rows are split on bit b of ``pt ^ (1 << b)``. When
    ``key_prefix`` is given, that value is further XORed with
    ``key_prefix[byte_idx]``. The curve is mean(rows in the "1" group) minus
    mean(rows in the "0" group), over samples [center - win, center + win).

    Every curve is exactly 2*win long with the peak centre at index ``win``.
    Any part of the window that falls outside the trace is filled with 0, so
    the centre index stays fixed even for peaks near the trace edges.

    Args:
        traces: (>=256, S) averaged traces; row i belongs to plaintext byte i.
        block_peaks: one sample index per key byte of the block.
        key_prefix: None for the first block. For the second block, a sequence
            holding one key byte per entry of block_peaks.
        win: half-width of the window; defaults to WIN_DPA.

    Returns:
        list of 8 * len(block_peaks) 1-D arrays, each of length 2*win.
        Entry byte_idx*8 + bit holds the curve for that byte and bit.
    """
    if win is None:
        win = WIN_DPA
    width = 2 * win
    n_samples = traces.shape[1]
    pts = np.arange(256)

    dpas = []
    for byte_idx, center in enumerate(block_peaks):
        k = 0 if key_prefix is None else key_prefix[byte_idx]
        s = center - win
        lo, hi = max(s, 0), min(s + width, n_samples)
        for bit in range(8):
            ones_mask = (((pts ^ (1 << bit) ^ k) >> bit) & 1).astype(bool)
            if lo < hi:
                diff = (traces[np.flatnonzero(ones_mask), lo:hi].mean(axis=0)
                        - traces[np.flatnonzero(~ones_mask), lo:hi].mean(axis=0))
                curve = np.zeros(width, dtype=diff.dtype)
                curve[lo - s:hi - s] = diff
            else:
                curve = np.zeros(width)  # window lies entirely outside the trace
            dpas.append(curve)
    return dpas

# ======================================================
# 7) EXTRACT BYTES FROM BLOCK DPAS
# ======================================================
def extract_block_key(block_dpas, win=None):
    """
    Turn 64 per-bit DPA curves into 8 key bytes.

    ``block_dpas`` is laid out as 8 bytes × 8 bits: entry byte_i*8 + bit is
    the curve for bit ``bit`` of byte ``byte_i``. Each curve is read at the
    operation's position, index ``win``, where build_block_dpas places the
    peak centre. A strictly positive value there is taken as key bit 1.

    Args:
        block_dpas: sequence of 64 1-D arrays.
        win: index of the peak centre inside each curve; defaults to WIN_DPA.
            For a curve of length 2*win this equals len(curve)//2. Reading at
            ``win`` also stays correct for curves truncated at the end of the
            trace.

    Returns:
        list of 8 ints in 0..255; bit b of each byte comes from curve
        byte_i*8 + b.
    """
    if win is None:
        win = WIN_DPA
    key = []
    for byte_i in range(8):
        byte = 0
        for bit in range(8):
            dpa = block_dpas[byte_i * 8 + bit]
            centre = min(win, len(dpa) - 1)
            if dpa[centre] > 0:
                byte |= 1 << bit
        key.append(byte)
    return key

# ======================================================
# MAIN
# ======================================================
def main():
    """
    Run the full attack end to end:

    1. acquire aligned traces;
    2. locate the operations of both 8-byte blocks;
    3. recover the 16-byte key with per-bit DPA;
    4. submit the key to the target and print its reply.
    """
    reset_target()
    print("[*] acquiring + aligning traces ...")
    traces = acquire_aligned_traces()

    print("[*] computing global DPAs ...")
    dpas = compute_global_dpas(traces)

    print("[*] finding first 8 ops ...")
    first_peaks = find_first_8_peaks(dpas)
    print("[+] first block peaks:", first_peaks)

    print("[*] training ML op detector on first block ...")
    clf = train_op_detector(dpas, first_peaks)
    # train_op_detector as written always returns a fitted classifier; this
    # guard only matters if it is changed to return None.
    if clf is None:
        print("[!] sklearn not available, will use template fallback.")

    print("[*] finding second 8 ops (ML/template) ...")
    second_peaks = find_second_8_peaks(dpas, first_peaks, clf)
    print("[+] second block peaks:", second_peaks)

    # --------------------------------------------------
    # FIRST BLOCK DPA -> key[0..7]
    # --------------------------------------------------
    print("[*] building first-block DPAs ...")
    block1_dpas = build_block_dpas(traces, first_peaks, key_prefix=None, win=WIN_DPA)
    key_first8 = extract_block_key(block1_dpas, win=WIN_DPA)
    print("[+] key first 8:", key_first8)

    # --------------------------------------------------
    # SECOND BLOCK DPA -> key[8..15]
    # --------------------------------------------------
    print("[*] building second-block DPAs ...")
    block2_dpas = build_block_dpas(traces, second_peaks, key_prefix=key_first8, win=WIN_DPA)
    key_last8 = extract_block_key(block2_dpas, win=WIN_DPA)
    print("[+] key last 8:", key_last8)

    full_key = bytes(key_first8 + key_last8)
    print("\n[+] recovered key bytes:", list(full_key))
    print("[+] recovered key hex  :", full_key.hex())

    print("[*] verifying on target ...")
    resp = verify_key(full_key)
    # ChipWhisperer's simpleserial_read returns None when the read times out.
    if resp is None:
        print("[!] no response from target (read timed out)")
    else:
        print("[+] device says:", resp.decode(errors="ignore"))

# Entry point. Inside a Jupyter kernel __name__ is "__main__", so running this
# cell starts the full attack immediately.
if __name__ == "__main__":
    main()




[*] acquiring + aligning traces ...
[*] building initial template for alignment ...
[+] pt=0
[+] pt=1
[+] pt=2
[+] pt=3
[+] pt=4
[+] pt=5
[+] pt=6
[+] pt=7
[+] pt=8
[+] pt=9
[+] pt=10
[+] pt=11
[+] pt=12
[+] pt=13
[+] pt=14
[+] pt=15
[+] pt=16
[+] pt=17
[+] pt=18
[+] pt=19
[+] pt=20
[+] pt=21
[+] pt=22
[+] pt=23
[+] pt=24
[+] pt=25
[+] pt=26
[+] pt=27
[+] pt=28
[+] pt=29
[+] pt=30
[+] pt=31
[+] pt=32
[+] pt=33
[+] pt=34
[+] pt=35
[+] pt=36
[+] pt=37
[+] pt=38
[+] pt=39
[+] pt=40
[+] pt=41
[+] pt=42
[+] pt=43
[+] pt=44
[+] pt=45
[+] pt=46
[+] pt=47
[+] pt=48
[+] pt=49
[+] pt=50
[+] pt=51
[+] pt=52
[+] pt=53
[+] pt=54
[+] pt=55
[+] pt=56
[+] pt=57
[+] pt=58
[+] pt=59
[+] pt=60
[+] pt=61
[+] pt=62
[+] pt=63
[+] pt=64
[+] pt=65
[+] pt=66
[+] pt=67
[+] pt=68
[+] pt=69
[+] pt=70
[+] pt=71
[+] pt=72
[+] pt=73
[+] pt=74
[+] pt=75
[+] pt=76
[+] pt=77
[+] pt=78
[+] pt=79
[+] pt=80
[+] pt=81
[+] pt=82
[+] pt=83
[+] pt=84
[+] pt=85
[+] pt=86
[+] pt=87
[+] pt=88
[+] pt=89
[+] pt=90
[+] pt=91
[+] pt

In [113]:
import struct
import time
import random
import glob
import os

import chipwhisperer as cw
import matplotlib.pyplot as plt
import numpy as np

# Connect to the scope and the SimpleSerial target, apply the default scope
# configuration, then set the capture length to 4500 samples.
# NOTE(review): if an earlier cell already opened `scope`, re-creating it here
# may conflict with that open connection — confirm.
scope = cw.scope()
target = cw.target(scope, cw.targets.SimpleSerial)
scope.default_setup()
scope.adc.samples = 4500


def reset_target():
    """Reset the target: hold nRST low for 50 ms, then release it (high-Z) and wait another 50 ms."""
    settle_s = 0.05
    scope.io.nrst = 'low'
    time.sleep(settle_s)
    scope.io.nrst = 'high_z'
    time.sleep(settle_s)


# simpleserial_addcmd('c', 16, verify);
def verify(key):
    """Send ``key`` with the 'c' command and return the target's 17-byte 'r' reply.

    ``key`` is any iterable of byte values; it is converted with bytes().
    """
    payload = bytes(key)
    target.simpleserial_write('c', payload)
    return target.simpleserial_read('r', 17)

# Key under test. The ASCII byte string spells exactly the same 16 byte
# values as the integer list [78, 97, 74, 45, ..., 115].
key = list(b"NaJ-UgRd2pXk8v5s")

response = verify(key)
print(response)



bytearray(b'a1c{Wh1teDragonT}')
