# Hyperspace Jump Drive

In [122]:
# SCOPETYPE is not referenced in this file; presumably it is read by
# ../setup/Setup_Generic.ipynb (run in the next cell) — confirm there.
# PLATFORM selects the firmware image hyperspaceJumpDrive-{PLATFORM}.hex.
SCOPETYPE = 'CWNANO'
PLATFORM = 'CWNANO'

In [123]:
# IPython magic: execute the shared setup notebook in this namespace. The next
# cell uses `cw`, `scope` and `prog` without defining them, so they presumably
# come from here — TODO confirm the exact names it defines.
%run "../setup/Setup_Generic.ipynb"



INFO: Caught exception on reconnecting to target - attempting to reconnect to scope first.
INFO: This is a work-around when USB has died without Python knowing. Ignore errors above this line.
INFO: Found ChipWhisperer😍


In [124]:
# Flash the challenge firmware built for PLATFORM onto the target using programmer `prog`.
cw.program_target(scope, prog, "hyperspaceJumpDrive-{}.hex".format(PLATFORM))

Detected known STMF32: STM32F04xxx
Extended erase (0x44), this can take ten seconds or more
Attempting to program 13919 bytes at 0x8000000
STM32F Programming flash...
STM32F Reading flash...
Verified flash OK, 13919 bytes


In [125]:
import time
import chipwhisperer as cw
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier

# =========================================================
# CONFIG
# =========================================================
REPEATS = 10          # traces per mask to average
SAMPLES = 2200        # ADC samples captured per trace (scope.adc.samples)
N_BYTES = 12          # secret length: one DPA peak / one recovered byte each
MIN_SEP = 180         # greedy fallback: min spacing (samples) between peaks
LOCAL_SNAP = 6        # pre-ML snap: +/- window to move peaks to local energy max
ML_SEARCH = 3         # how far ML can look around each peak
FEAT_WIN = 4          # feature half-window for ML



# =========================================================
# CW SETUP
# =========================================================
# NOTE(review): `scope` is already used by cw.program_target in an earlier cell,
# so the setup notebook appears to have opened a connection. Opening a second
# one here may conflict with it — confirm, or reuse the existing objects.
scope = cw.scope()
target = cw.target(scope, cw.targets.SimpleSerial)
scope.default_setup()
scope.adc.samples = SAMPLES  # capture length per trace

def reset_target():
    """Pulse the target's nRST pin: drive it low, then release it to high-Z.

    Each level is held for 50 ms.
    """
    for level in ('low', 'high_z'):
        scope.io.nrst = level
        time.sleep(0.05)

# =========================================================
# TARGET CMDS
# =========================================================
def arm_system(seq: bytes):
    """Send *seq* with simpleserial command 'a'.

    Returns whatever simpleserial_read('r', 17) gives back.
    """
    target.simpleserial_write('a', seq)
    return target.simpleserial_read('r', 17)

def invert_polarity(mask: int):
    """Capture one power trace while the target runs command 'p' with a 1-byte *mask*.

    Raises RuntimeError if scope.capture() reports a timeout.
    Returns the captured trace flattened to 1-D.
    """
    scope.arm()
    target.simpleserial_write('p', bytes([mask]))
    timed_out = scope.capture()
    if timed_out:
        raise RuntimeError("capture timed out")
    # Drain the 1-byte 'r' reply; its value is not used.
    target.simpleserial_read('r', 1)
    return np.ravel(scope.get_last_trace())

# =========================================================
# ACQ
# =========================================================
def acquire_traces(repeats=REPEATS):
    """Return a (256, n_samples) array; row m is the average of *repeats* captures for mask m."""
    averaged = []
    for mask in range(256):
        print(f"[+] capturing mask={mask}")
        captures = np.vstack([invert_polarity(mask) for _ in range(repeats)])
        averaged.append(captures.mean(axis=0))
    return np.vstack(averaged)

# =========================================================
# DPA
# =========================================================
def compute_dpas(traces: np.ndarray):
    """Return an (8, n_samples) float32 array of per-bit difference-of-means traces.

    Row index ``m`` of *traces* is treated as the mask value m (0..255). Only
    the first 256 rows are used. Row ``b`` of the result is

        mean(rows whose mask has bit b CLEAR) - mean(rows whose mask has bit b SET)

    Any sign convention downstream depends on this orientation.
    """
    _, n_samples = traces.shape
    masks = np.arange(256)
    per_mask = traces[:256]
    dpas = np.zeros((8, n_samples), dtype=np.float32)
    for bit in range(8):
        bit_clear = ((masks >> bit) & 1) == 0
        dpas[bit] = per_mask[bit_clear].mean(axis=0) - per_mask[~bit_clear].mean(axis=0)
    return dpas

# =========================================================
# PEAK FINDING (periodic autocorrelation, greedy fallback, local-max snap)
# =========================================================
def periodic_peaks(energy, n_peaks, search_start=0, first_window=200,
                   acf_min=50, acf_max=400):
    """Place ``n_peaks`` equally spaced peak positions using autocorrelation.

    The first peak is the argmax of ``energy`` inside
    ``[search_start, search_start + first_window)``. The spacing is the lag
    in ``[acf_min, acf_max)`` where the autocorrelation of the mean-removed
    ``energy`` trace is largest.

    Args:
        energy: 1-D per-sample energy trace.
        n_peaks: number of positions to return.
        search_start, first_window: window searched for the first peak.
        acf_min, acf_max: half-open range of candidate periods in samples.
            ``acf_max`` is clipped to the trace length.

    Returns:
        ``(peaks, period)``, where ``peaks`` is a list of ``n_peaks`` ints.
        Returns ``(None, None)`` when:
        - the first-peak window or the lag range does not fit the trace;
        - no lag in range is positively correlated (no periodic structure);
        - fewer than ``n_peaks`` positions fit inside the trace.
    """
    energy = np.asarray(energy, dtype=float)
    n = len(energy)
    if first_window <= 0 or search_start + first_window > n:
        return None, None

    window = energy[search_start:search_start + first_window]
    first = search_start + int(np.argmax(window))

    # Remove the DC level first. For a non-negative trace the raw
    # autocorrelation contains a mean**2 * (n - lag) term that shrinks with
    # lag. Its argmax then collapses onto acf_min whatever the real period is.
    centred = energy - energy.mean()
    corr = np.correlate(centred, centred, mode='full')
    corr = corr[corr.size // 2:]          # keep non-negative lags only

    acf_max = min(acf_max, len(corr) - 1)
    if acf_min >= acf_max:
        return None, None

    period = acf_min + int(np.argmax(corr[acf_min:acf_max]))
    if corr[period] <= 0:
        # No lag in range is positively correlated: no usable periodicity,
        # let the caller fall back to another peak finder.
        return None, None

    peaks = [first + i * period for i in range(n_peaks)]
    peaks = [p for p in peaks if p < n]
    if len(peaks) < n_peaks:
        return None, None
    return peaks, period

def greedy_peaks(energy, n_peaks, min_sep):
    """Return ``n_peaks`` indices picked greedily from ``energy``, sorted ascending.

    Each round:
    - takes the global argmax of a private copy of ``energy``;
    - zeroes the half-open span ``[idx - min_sep, idx + min_sep)``, clipped
      to the trace, so the next pick lands elsewhere.

    Zeroing assumes ``energy`` is non-negative, as a mean of absolute values
    is. The caller's array is not modified.
    """
    work = energy.copy()
    size = len(work)
    picks = []
    for _ in range(n_peaks):
        top = int(np.argmax(work))
        picks.append(top)
        work[max(0, top - min_sep):min(size, top + min_sep)] = 0.0
    return sorted(picks)

def snap_peaks_to_local_max(peaks, energy, window=6):
    """Move every peak to the argmax of ``energy`` within +/- ``window`` samples.

    The window is clipped to the trace. The result is sorted ascending.
    Peaks that snap onto the same sample are NOT deduplicated.
    """
    limit = len(energy)

    def _snap(p):
        lo = max(0, p - window)
        hi = min(limit, p + window + 1)
        return lo + np.argmax(energy[lo:hi])

    return sorted(_snap(p) for p in peaks)

def find_12_peaks(dpas):
    """Locate N_BYTES peak indices in the mean |DPA| energy across bits.

    Steps:
    - try evenly spaced peaks from periodic_peaks;
    - if that fails, fall back to greedy_peaks with MIN_SEP spacing;
    - snap every peak to its local maximum within +/- LOCAL_SNAP samples.
    """
    energy = np.abs(dpas).mean(axis=0)
    coarse, _period = periodic_peaks(energy, N_BYTES)
    if coarse is None:
        coarse = greedy_peaks(energy, N_BYTES, MIN_SEP)
    return snap_peaks_to_local_max(coarse, energy, window=LOCAL_SNAP)

# =========================================================
# ML REFINER (with guardrail)
# =========================================================
def _make_local_feature(dpas, idx, win=FEAT_WIN):
    """Build the ML feature vector for sample ``idx``.

    Taps are ``idx-win .. idx+win``, clamped to ``[0, n_samples-1]``. Layout
    (float32):
    - the mean |DPA| energy at each tap;
    - then |dpas[bit]| at each tap, for bit 0..7 in order.

    That gives ``9 * (2*win + 1)`` values. Requires at least 8 rows in *dpas*.
    """
    n_samples = dpas.shape[1]
    taps = np.clip(idx + np.arange(-win, win + 1), 0, n_samples - 1)
    energy = np.mean(np.abs(dpas), axis=0)
    # np.ix_ with rows 0..7 raises IndexError when fewer than 8 rows exist.
    per_bit = np.abs(dpas[np.ix_(np.arange(8), taps)])
    return np.concatenate([energy[taps], per_bit.ravel()]).astype(np.float32)

def train_peak_refiner(dpas, coarse_peaks, neg_samples=1500):
    """Fit an ExtraTreesClassifier separating peak-like samples from background.

    Positives (label 1): every coarse peak and its +/-1 neighbours that lie
    inside the trace.

    Negatives (label 0): up to ``neg_samples`` indices drawn uniformly with a
    fixed seed (1234), using at most ``3 * neg_samples`` draws. Draws inside a
    +/-10-sample exclusion zone around any coarse peak are rejected.

    Features come from ``_make_local_feature``. Returns the fitted classifier.
    """
    n_samples = dpas.shape[1]

    positives = [
        p + shift
        for p in coarse_peaks
        for shift in (-1, 0, 1)
        if 0 <= p + shift < n_samples
    ]

    excluded = np.zeros(n_samples, dtype=bool)
    for p in coarse_peaks:
        excluded[max(0, p - 10):min(n_samples, p + 11)] = True

    rng = np.random.default_rng(1234)
    negatives = []
    draws = 0
    while len(negatives) < neg_samples and draws < neg_samples * 3:
        draws += 1
        candidate = int(rng.integers(0, n_samples))
        if not excluded[candidate]:
            negatives.append(candidate)

    # Row order (positives first, then negatives) is kept fixed so the seeded
    # forest is reproducible.
    X = np.vstack([_make_local_feature(dpas, q) for q in positives + negatives])
    y = np.array([1] * len(positives) + [0] * len(negatives), dtype=np.int32)

    clf = ExtraTreesClassifier(
        n_estimators=400,
        max_features='sqrt',
        n_jobs=-1,
        class_weight='balanced',
        random_state=99,
    )
    clf.fit(X, y)
    return clf

def ml_refine_peaks_safe(dpas, coarse_peaks, clf, search=ML_SEARCH):
    """Nudge each coarse peak within +/- ``search`` samples without lowering its energy.

    For every coarse peak ``p`` the candidates are the in-range indices in
    ``[p - search, p + search]``. A candidate replaces the current best when:
    - its energy (mean |DPA| over rows) is strictly higher, or
    - its energy is equal and the classifier's class-1 probability is higher.

    The classifier therefore only breaks exact energy ties. The chosen index
    never has lower energy than ``p``.

    Args:
        dpas: 2-D DPA array, one row per bit, one column per sample.
        coarse_peaks: peak indices to refine.
        clf: fitted classifier with ``predict_proba`` (column 1 = peak class),
            or ``None`` to skip refinement.
        search: half-width of the candidate window.

    Returns:
        The refined peaks, sorted. ``coarse_peaks`` is returned unchanged when
        ``clf`` is None, or when two coarse peaks were pulled onto the same
        index (which would lose a byte position).
    """
    if clf is None:
        return coarse_peaks

    energy = np.mean(np.abs(dpas), axis=0)
    T = dpas.shape[1]
    final = []

    for p in coarse_peaks:
        best_idx = p
        best_energy = energy[p]
        best_ml = 0.0

        candidates = [q for q in range(p - search, p + search + 1) if 0 <= q < T]
        if not candidates:
            final.append(best_idx)
            continue

        # One batched predict_proba per peak rather than one call per
        # candidate; per-row probabilities are the same either way.
        feats = np.vstack([_make_local_feature(dpas, q) for q in candidates])
        probas = clf.predict_proba(feats)[:, 1]

        for q, proba in zip(candidates, probas):
            q_energy = energy[q]
            # Guardrail: higher energy always wins; ML score only breaks ties.
            if q_energy > best_energy or (q_energy == best_energy and proba > best_ml):
                best_idx = q
                best_energy = q_energy
                best_ml = proba

        final.append(best_idx)

    # Refinement must keep one distinct peak per coarse peak (one per byte);
    # a collision means a byte position was lost, so keep the coarse peaks.
    if len(set(final)) != len(final):
        return coarse_peaks
    return sorted(final)

# =========================================================
# KEY EXTRACTION (sign of each bit's DPA at each byte peak)
# =========================================================
def extract_bits_from_fixed_peaks(dpas, fixed_peaks, flip_sign=False, reverse_bits=True):
    """Recover one byte per peak from the sign of each DPA row at that peak.

    For each peak, row i of *dpas* contributes bit character i:
    - '1' if the value is > 0 (or < 0 when ``flip_sign``);
    - '0' otherwise, including exact zeros.

    The string is built row 0 first. With ``reverse_bits`` it is reversed
    before parsing, so row i becomes bit i (LSB = row 0). Otherwise row 0 is
    the MSB.

    Returns the bytes in ``fixed_peaks`` order.
    """
    byte_vals = []
    for peak in fixed_peaks:
        samples = [row[peak] for row in dpas]
        if flip_sign:
            flags = [s < 0 for s in samples]
        else:
            flags = [s > 0 for s in samples]
        bit_str = "".join("1" if f else "0" for f in flags)
        if reverse_bits:
            bit_str = bit_str[::-1]
        byte_vals.append(int(bit_str, 2))
    return bytes(byte_vals)

# =========================================================
# MAIN
# =========================================================
def main():
    """Run the whole attack and submit the recovered key.

    Steps: reset the target, capture averaged traces for every mask, compute
    per-bit DPAs, locate and refine the N_BYTES peaks, extract the key bytes,
    and send them with arm_system.
    """
    reset_target()

    print("[*] acquiring traces ...")
    traces = acquire_traces()

    print("[*] computing DPAs ...")
    dpas = compute_dpas(traces)

    print("[*] locating 12 peaks (signal-based) ...")
    coarse_peaks = find_12_peaks(dpas)
    print("[+] coarse peaks:", coarse_peaks)

    print("[*] training ML peak refiner ...")
    # train_peak_refiner always returns a fitted classifier. sklearn is
    # imported at module level, so a missing sklearn fails at import time,
    # never here — no None branch is needed.
    clf = train_peak_refiner(dpas, coarse_peaks)
    final_peaks = ml_refine_peaks_safe(dpas, coarse_peaks, clf)
    print("[+] final peaks:", final_peaks)

    print("[*] extracting key ...")
    secret = extract_bits_from_fixed_peaks(
        dpas,
        final_peaks,
        flip_sign=False,
        reverse_bits=True,
    )
    print("[+] recovered bytes:", list(secret))
    print("[+] recovered hex  :", secret.hex())

    print("[*] sending ...")
    resp = arm_system(secret)
    # simpleserial_read can return None on a read error/timeout; don't crash
    # on .decode() after a full capture run.
    if resp is None:
        print("[!] no response from target")
    else:
        print("[+] response:", resp.decode(errors="ignore"))

# Inside a notebook cell __name__ is "__main__", so executing this cell starts
# the full capture-and-attack run (see the output below).
if __name__ == "__main__":
    main()




[*] acquiring traces ...
[+] capturing mask=0
[+] capturing mask=1
[+] capturing mask=2
[+] capturing mask=3
[+] capturing mask=4
[+] capturing mask=5
[+] capturing mask=6
[+] capturing mask=7
[+] capturing mask=8
[+] capturing mask=9
[+] capturing mask=10
[+] capturing mask=11
[+] capturing mask=12
[+] capturing mask=13
[+] capturing mask=14
[+] capturing mask=15
[+] capturing mask=16
[+] capturing mask=17
[+] capturing mask=18
[+] capturing mask=19
[+] capturing mask=20
[+] capturing mask=21
[+] capturing mask=22
[+] capturing mask=23
[+] capturing mask=24
[+] capturing mask=25
[+] capturing mask=26
[+] capturing mask=27
[+] capturing mask=28
[+] capturing mask=29
[+] capturing mask=30
[+] capturing mask=31
[+] capturing mask=32
[+] capturing mask=33
[+] capturing mask=34
[+] capturing mask=35
[+] capturing mask=36
[+] capturing mask=37
[+] capturing mask=38
[+] capturing mask=39
[+] capturing mask=40
[+] capturing mask=41
[+] capturing mask=42
[+] capturing mask=43
[+] capturing mas