In [None]:
# SLH-DSA parameter-set name handed to spxplus.SLH_DSA in the next cell.
params = "SLH-DSA-SHAKE-256s"

In [None]:
import spxplus

# Build the parameter-set object and lift its parameters into notebook globals
# that the analysis cells below read (lg_w, len1, len2, w, hp, ...).
# NOTE(review): names appear to follow FIPS 205 notation (hp presumably h') — confirm in spxplus.
slh = spxplus.SLH_DSA(params)
(a, d, hp, n, k,
 lg_w, len1, len2, w) = (slh.a, slh.d, slh.hp, slh.n, slh.k,
                         slh.lg_w, slh.len1, slh.len2, slh.w)

In [None]:
# FIPS205 utility functions

from tools.cryptanalysis_worker import to_byte, to_int, base_2b, cksum
import math

In [None]:
from tools.cryptanalysis_worker import candidate_ratio, signable_messages, hash_complexity, cksum_chain

In [None]:
def test_signable():
    """Cross-check ``signable_messages`` against a brute-force enumeration.

    Uses a toy configuration (lg_w=2, len1=3) small enough that every
    message digest can be enumerated. A digest counts as signable when each of
    its message chains is >= the matching entry of ``msg_chains`` and each of
    its checksum chains (from ``cksum_chain``) is >= the matching entry of
    ``cksum_chains``. The fraction reported by ``signable_messages`` must agree
    with the brute-force fraction.

    Raises:
        AssertionError: if the two fractions disagree.
    """
    import itertools
    import math

    # Toy parameters; len2 uses the same formula as FIPS 205.
    lg_w = 2
    w = 2**lg_w
    len1 = 3
    len2 = math.floor(math.log2(len1 * (w - 1)) / lg_w) + 1

    # Minimum chain positions a forged digest must reach.
    msg_chains = [1, 1, 1]
    cksum_chains = [0, 1]

    # Brute force: enumerate all w**len1 digests. This is exponential in len1,
    # so it is only feasible for toy parameters.
    digests = list(itertools.product(range(w), repeat=len1))
    brute_force_count = sum(
        1
        for m in digests
        if all(mi >= ti for mi, ti in zip(m, msg_chains))
        and all(ci >= ti for ci, ti in zip(cksum_chain(m, w, len2, lg_w), cksum_chains))
    )
    brute_force_fraction = brute_force_count / len(digests)
    print("Brute-force fraction of signable messages:", brute_force_fraction)

    # signable_messages yields (key, count) pairs; only the counts are summed.
    per_bucket = signable_messages(msg_chains, cksum_chains, w, len1, len2)
    signable_count = sum(val for (_, val) in per_bucket)
    signable_fraction = signable_count / w**len1
    print("Fraction of signable messages:", signable_fraction)

    assert math.isclose(brute_force_fraction, signable_fraction), (
        f"signable_messages disagrees with brute force: "
        f"{signable_fraction} != {brute_force_fraction}"
    )
test_signable()

In [None]:
import concurrent.futures

# Run the candidate_ratio experiment many times in parallel. The worker lives in
# tools.cryptanalysis_worker so it is importable by spawned worker processes.

# TODO: document why this particular experiment count was chosen.
n_experiments = 126*12*4
# Forwarded to candidate_ratio and shown in the plot title.
# TODO: document what M controls inside candidate_ratio.
M = 4

# One self-contained argument tuple per experiment (index first).
args_list = [(i, M, lg_w, len1, len2, hp) for i in range(n_experiments)]

print(f"Running {n_experiments} experiments with multiprocessing...")

# executor.map preserves input order, so results[i] belongs to experiment i.
with concurrent.futures.ProcessPoolExecutor() as executor:
    results = list(executor.map(candidate_ratio, args_list))

# Each result is unpacked as (message chains, checksum chains, ratio).
msg_chains = [r[0] for r in results]
cksum_chains = [r[1] for r in results]
ratios = [r[2] for r in results]

# The per-candidate hash cost is the same for every experiment, so compute it once.
hash_cost = hash_complexity(hp, len1+len2, w)
complexities = [math.log2(hash_cost * 1.0/ratio) for ratio in ratios]

In [None]:
from statistics import fmean
from statistics import stdev, median
# Summary statistics over the per-experiment log2 attack complexities.
report = (
    ("Average complexity: 2^", fmean),
    ("Min: ", min),
    ("Max: ", max),
    ("Stddev: ", stdev),
    ("Median: ", median),
)
for label, statistic in report:
    print(f"{label}{statistic(complexities)}")

In [None]:
from matplotlib import pyplot as plt


# One point per experiment, in the order the experiments were submitted.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(complexities, marker='o', linestyle='-')
ax.set_title(f"Expected Attack Complexity (n={n_experiments}, M={M})")
ax.set_xlabel("Experiment Index")
ax.set_ylabel("Complexity (log2)")
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
import random


# Pick one experiment at random and empirically check its prediction by drawing
# random message digests until one dominates every target chain position.
target_idx = random.randrange(len(complexities))
target_chains = msg_chains[target_idx] + cksum_chains[target_idx]
# complexities[i] = log2(hash_cost / ratio); removing log2(hash_cost) leaves
# log2(1 / ratio), the expected number of random candidates per match.
target_complexity = complexities[target_idx] - math.log2(hash_complexity(hp, len1+len2, w))
print("Expected attack complexity (log2):", target_complexity)

# As in FIPS 205 wots_sign: left-shift the checksum so its len2*lg_w bits are
# byte-aligned before splitting it into base-2^lg_w digits.
cksum_shift = (8 - ((len2 * lg_w) % 8)) % 8
cksum_len_bytes = (len2 * lg_w + 7) // 8

max_attempts = math.ceil(2**(target_complexity))
# attempt is 1-based so log2(attempt) is defined even when the first draw matches.
for attempt in range(1, max_attempts + 1):
    candidate = [random.randint(0, w-1) for _ in range(len1)]
    csum = cksum(candidate, w) << cksum_shift
    csum_digits = base_2b(to_byte(csum, cksum_len_bytes), lg_w, len2)
    chains = candidate + csum_digits
    if all(c >= t for c, t in zip(chains, target_chains)):
        print(f"Found match after 2^{math.log2(attempt)} repetitions")
        break
else:
    print(f"No match found after {max_attempts} repetitions")
    