# RPA-based reverse-engineering
This notebook showcases the RPA-based reverse-engineering technique for scalar multipliers.
 - [Exploration](#Exploration)
 - [Reverse-engineering](#Reverse-engineering)
   - [Oracle simulation](#Oracle-simulation)
   - [Method simulation](#Method-simulation)

In [None]:
from collections import Counter
from math import sqrt
import numpy as np
import holoviews as hv
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from functools import partial, lru_cache
from scipy.stats import bernoulli
from concurrent.futures import ProcessPoolExecutor, as_completed

from IPython.display import HTML, display
from tqdm.auto import tqdm, trange
import tabulate
from anytree import LevelOrderGroupIter, RenderTree

from pyecsca.ec.model import ShortWeierstrassModel
from pyecsca.ec.coordinates import AffineCoordinateModel
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.params import DomainParameters, get_params
from pyecsca.ec.formula import FormulaAction
from pyecsca.ec.point import Point
from pyecsca.ec.mod import Mod
from pyecsca.ec.mult import *
from pyecsca.misc.utils import silent
from pyecsca.sca.trace.sampling import downsample_average, downsample_max
from pyecsca.sca.trace.process import normalize, rolling_mean
from pyecsca.sca.trace.combine import average, subtract
from pyecsca.sca.trace.test import welch_ttest
from pyecsca.sca.attack.leakage_model import HammingWeight, NormalNoice
from pyecsca.ec.context import DefaultContext, local
from pyecsca.sca.re.rpa import MultipleContext, rpa_distinguish, RPA
from pyecsca.sca.trace import Trace
from pyecsca.sca.trace.plot import plot_trace, plot_traces

In [None]:
# Select the ipympl backend for matplotlib figures in this notebook.
%matplotlib ipympl
# Select the bokeh backend for holoviews output.
hv.extension("bokeh")

In [None]:
# Curve model plus the affine and projective coordinate systems used below.
model = ShortWeierstrassModel()
coordsaff = AffineCoordinateModel(model)
coords = model.coordinates["projective"]

# Formulas handed to the multipliers (the formulas are irrelevant for this method).
add, dbl, neg = (coords.formulas[name] for name in ("add-2007-bl", "dbl-2007-bl", "neg"))

# A 64-bit prime order curve for testing things out
p = 0xc50de883f0e7b167
a, b = Mod(0x4833d7aa73fa6694, p), Mod(0xa6c44a61c5323f6a, p)
gx, gy = Mod(0x5fd1f7d38d4f2333, p), Mod(0x21f43957d7e20ceb, p)
n, h = 0xc50de885003b80eb, 1

# A (0, y) RPA point on the above curve, in affine coords.
P0_aff = Point(coordsaff, x=Mod(0, p), y=Mod(0x1742befa24cd8a0d, p))

# Generator and the point at infinity, both in projective coords.
g = Point(coords, X=gx, Y=gy, Z=Mod(1, p))
infty = Point(coords, X=Mod(0, p), Y=Mod(1, p), Z=Mod(0, p))

curve = EllipticCurve(model, coords, p, infty, {"a": a, "b": b})
params = DomainParameters(curve, g, n, h)

## Exploration
First select a bunch of multipliers. We will be trying to distinguish among these.

In [None]:
# The candidate scalar multipliers that the rest of the notebook tries to
# distinguish among. The position of a multiplier in this list is what the
# oracles later receive as `simulate_mult_id`. All entries share the same
# add/dbl(/neg) formulas and differ in algorithm, window width, processing
# direction and the positional flags.
# NOTE(review): the meaning of the positional boolean/None arguments is defined
# by the pyecsca multiplier constructors - confirm against pyecsca.ec.mult.
multipliers = [
    LTRMultiplier(add, dbl, None, False, AccumulationOrder.PeqPR, True, True),
    LTRMultiplier(add, dbl, None, True, AccumulationOrder.PeqPR, True, True),
    RTLMultiplier(add, dbl, None, False, AccumulationOrder.PeqPR, True),
    RTLMultiplier(add, dbl, None, True, AccumulationOrder.PeqPR, False),
    SimpleLadderMultiplier(add, dbl, None, True, True),
    BinaryNAFMultiplier(add, dbl, neg, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BinaryNAFMultiplier(add, dbl, neg, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    WindowNAFMultiplier(add, dbl, neg, 3, None, AccumulationOrder.PeqPR, True, True),
    WindowNAFMultiplier(add, dbl, neg, 4, None, AccumulationOrder.PeqPR, True, True),
    WindowNAFMultiplier(add, dbl, neg, 5, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 3, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 4, None, AccumulationOrder.PeqPR, True, True),
    WindowBoothMultiplier(add, dbl, neg, 5, None, AccumulationOrder.PeqPR, True, True),
    SlidingWindowMultiplier(add, dbl, 3, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 4, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 5, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 3, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 4, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    SlidingWindowMultiplier(add, dbl, 5, None, ProcessingDirection.RTL, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 3, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 4, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 5, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 8, None, AccumulationOrder.PeqPR, True),
    FixedWindowLTRMultiplier(add, dbl, 16, None, AccumulationOrder.PeqPR, True),
    FullPrecompMultiplier(add, dbl, None, True, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True, True),
    FullPrecompMultiplier(add, dbl, None, False, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True, True),
    BGMWMultiplier(add, dbl, 2, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 3, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 4, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    BGMWMultiplier(add, dbl, 5, None, ProcessingDirection.LTR, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 2, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 3, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 4, None, AccumulationOrder.PeqPR, True),
    CombMultiplier(add, dbl, 5, None, AccumulationOrder.PeqPR, True)
]
# Number of candidate multipliers.
print(len(multipliers))

Then select a scalar, simulate the computation of `[scalar]G` using all of the multipliers, track the multiples that each of them computes, and display those multiples in a table.

In [None]:
# Candidate scalars to experiment with; only the last assignment takes effect.
# Reorder the lines to try a different one.
# NOTE(review): the multiple k = 108 picked further below is presumably chosen
# for 123456789123456789, so that value is kept last - confirm before changing.
scalar = 0b1000000000000000000000000000000000000000000000000
scalar = 0b1111111111111111111111111111111111111111111111111
scalar = 0b1010101010101010101010101010101010101010101010101
scalar = 0b1111111111111111111111110000000000000000000000000
scalar = 8750920244948492046
scalar = 123456789123456789
# multiples is a mapping from a multiple (integer) to a set of scalar multipliers that compute said multiple when doing [scalar]P
multiples = {}

table = [["Multiplier", "multiples"]]

for mult in multipliers:
    # MultipleContext records, for every point computed, which multiple of the
    # input point it is (ctx.points maps point -> multiple).
    with local(MultipleContext()) as ctx:
        mult.init(params, g)
        mult.multiply(scalar)
    computed = list(ctx.points.values())
    for m in computed:
        multiples.setdefault(m, set()).add(mult)
    table.append([str(mult), str(computed)])

display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))

Pick a multiple `k` that is computed by some multiplier for the scalar,
invert it mod n, and do `[k^-1]P0` to obtain a point `P0_target`,
such that `[k]P0_target = P0` and `P0` has a zero coordinate.

In [None]:
# The multiple to target and its inverse modulo the group order.
k = 108
kinv = Mod(k, n).inverse()

# [k^-1]P0 computed in affine coords, then moved to the projective model.
P0_target = curve.affine_multiply(P0_aff, int(kinv)).to_model(coords, curve)
P0_target_affine = P0_target.to_affine()

print("Original P0", P0_aff)
print("P0_target  ", P0_target_affine)
# Multiplying back by k should give P0 again.
print("Verify P0  ", curve.affine_multiply(P0_target.to_affine(), k))

Now go over the multipliers with P0_target and the original scalar as input.
Then look whether a zero coordinate point was computed.
Also look at whether the multiple "k" was computed. These two should be the same.

In [None]:
def _bold_if_true(flag):
    # Highlight positive answers in the rendered HTML table.
    return f"<b>{flag}</b>" if flag else flag


table = [["Multiplier", "zero present", "multiple computed"]]

for mult in multipliers:
    with local(MultipleContext()) as ctx:
        mult.init(params, P0_target)
        res = mult.multiply(scalar)
    # Was a point with a zero X or Y coordinate computed?
    zero = any(pt.X == 0 or pt.Y == 0 for pt in ctx.points.keys())
    # Was the targeted multiple k computed?
    multiple = k in ctx.points.values()
    table.append([str(mult), _bold_if_true(zero), _bold_if_true(multiple)])

display(HTML(tabulate.tabulate(table, tablefmt="unsafehtml", headers="firstrow", colalign=("left", "center", "center"))))

Now let's look at the relation of multiples to multipliers.

In [None]:
# One row per multiple: its binary form and the class names of the multipliers computing it.
table = [["Multiple", "Multipliers"]] + [
    [bin(multiple), [type(mult).__name__ for mult in mults]]
    for multiple, mults in multiples.items()
]

display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))

Note that all of the exploration so far was in a context of a fixed scalar. Even though for a given scalar some multipliers might be indistinguishable from the perspective of the multiples they compute, there may be other scalars that distinguish them.

## Reverse-engineering

### Oracle simulation
The `simulated_oracle` function simulates an RPA oracle that detects a zero coordinate point in the scalar multiplication.
This can be used by the `rpa_distinguish` function to distinguish the true scalar multiplier. The oracle is parametrized with the simulated multiplier index in the table of multipliers (it simulates this "real" multiplier). Furthermore, let's also examine a `noisy_oracle` (with a flip probability) and a `biased_oracle` (with asymmetric flip probability).

Note that the oracle has two additional parameters `measure_init` and `measure_multiply` which determine whether the oracle considers the zero coordinate point in scalar multiplier initialization (precomputation) and in scalar multiplier multiplication, respectively. This is important for scalar multipliers with precomputation as there one might be able to separate the precomputation and multiplication stages and obtain oracle answers on both separately.

In [None]:
def simulated_oracle(scalar, affine_point, simulate_mult_id=0, measure_init=True, measure_multiply=True, randomize=False):
    """
    Simulate an RPA oracle for the multiplier `multipliers[simulate_mult_id]`.

    :param scalar: The scalar to multiply with.
    :param affine_point: The affine input point; converted to the curve's coordinate model.
    :param simulate_mult_id: Index of the "real" multiplier in `multipliers`.
    :param measure_init: Whether a zero seen during `init` counts towards the answer.
    :param measure_multiply: Whether a zero seen during `multiply` counts towards the answer.
    :param randomize: Passed as `randomized` to the coordinate conversion of the input point.
    :return: Whether a point with a zero X or Y coordinate was an input (parent)
             of some computed point in a measured stage.
    """
    def parents_of(context, children):
        # Union of the recorded parent points of all the given points.
        return set().union(*(context.parents[child] for child in children))

    def has_zero_coordinate(points):
        return any(pt.X == 0 or pt.Y == 0 for pt in points)

    target_mult = multipliers[simulate_mult_id]
    ec = params.curve
    projective_point = affine_point.to_model(ec.coordinate_model, ec, randomized=randomize)

    # Stage 1: multiplier initialization (precomputation).
    with local(MultipleContext()) as ctx:
        target_mult.init(params, projective_point)
    seen_in_init = set(ctx.parents.keys())
    init_zero = has_zero_coordinate(parents_of(ctx, seen_in_init))

    # Stage 2: the multiplication itself, continuing in the same context so
    # that points first seen here can be told apart from the init ones.
    with local(ctx) as ctx:
        target_mult.multiply(scalar)
    seen_in_multiply = set(ctx.parents.keys()) - seen_in_init
    multiply_zero = has_zero_coordinate(parents_of(ctx, seen_in_multiply))

    return (init_zero and measure_init) or (multiply_zero and measure_multiply)

def noisy_oracle(oracle, flip_proba=0):
    """
    Wrap an oracle so that each of its answers is flipped with probability `flip_proba`.

    :param oracle: The callable to wrap; its result is treated as a boolean.
    :param flip_proba: Probability of inverting an answer.
    :return: A callable with the same arguments as `oracle` that returns a `bool`.
    """
    def noisy(*args, **kwargs):
        answer = oracle(*args, **kwargs)
        # One Bernoulli draw per query decides whether to flip this answer.
        flip = bernoulli.rvs(flip_proba)
        return bool(answer ^ flip)
    return noisy

def biased_oracle(oracle, flip_0=0, flip_1=0):
    """
    Wrap an oracle with asymmetric noise.

    :param oracle: The callable to wrap; its result is treated as a boolean.
    :param flip_0: Probability of flipping a falsy answer.
    :param flip_1: Probability of flipping a truthy answer.
    :return: A callable with the same arguments as `oracle` that returns a `bool`.
    """
    def biased(*args, **kwargs):
        answer = oracle(*args, **kwargs)
        # The flip probability depends on what the true answer was.
        flip_proba = flip_1 if answer else flip_0
        return bool(answer ^ bernoulli.rvs(flip_proba))
    return biased

We can see how the RPA-RE method distinguishes a given multiplier:

In [None]:
# secp256r1 parameters in projective coordinates; used when building the distinguishing tree below.
p256 = get_params("secg", "secp256r1", "projective")
# Run RPA-RE with the noise-free simulated oracle on the 64-bit test curve.
# The oracle's default simulate_mult_id=0 means it simulates multipliers[0].
res = rpa_distinguish(params, multipliers, simulated_oracle)

Let's see if the result is correct.

In [None]:
# The oracle simulated multipliers[0] (its default), so that one should be among the remaining candidates.
print(multipliers[0] in res)

We can also have a look at the distinguishing tree that the method builds for this set of multipliers.

In [None]:
# NOTE: `re` here is an RPA instance, not the standard-library regex module.
re = RPA(set(multipliers))
# Build the distinguishing tree on secp256r1 inside silent().
with silent():
    re.build_tree(p256, tries=10)
print(re.tree.describe())

We can also look at the rough tree structure.

In [None]:
# Print the tree built above in its basic rendering.
print(re.tree.render_basic())

#### What about (symmetric) noise?
Now we can examine how the method performs in the presence of noise and with various majority vote parameters. Note that the code below spawns several processes (`num_cores`) and saturates their CPU fully, so set this to something appropriate.

In [None]:
# Oracle flip probabilities to evaluate.
errs = (0, 0.1, 0.2, 0.3, 0.4, 0.5)
# Values passed as the `majority` argument of RPA.run.
majs = (1, 3, 5, 7, 9, 11)
# Number of RPA-RE runs per (multiplier, error, majority) combination.
num_tries = 100

In [None]:
# Accumulators with one cell per (error rate, majority) combination, summed over all multipliers.
# Runs in which the true multiplier was among the returned candidates.
correct_tries = np.zeros((len(errs), len(majs)))
# Runs in which exactly one candidate was returned.
precise_tries = np.zeros((len(errs), len(majs)))
# Number of oracle calls made.
query_tries = np.zeros((len(errs), len(majs)))
# Number of tries accumulated so far (per multiplier and parameter combination).
total_tries = 0

In [None]:
# Number of worker processes to spawn.
num_cores = 30


def measure_mult(params, multipliers, simulated_oracle, i, mult, err, majority):
    """
    Run RPA-RE `num_tries` times against a symmetric noisy oracle simulating `mult`.

    :param params: Domain parameters used to build the distinguishing tree.
    :param multipliers: The candidate multipliers.
    :param simulated_oracle: The noise-free oracle; called with `simulate_mult_id=i`.
    :param i: Index of the simulated multiplier.
    :param mult: The simulated ("true") multiplier, looked up in each result.
    :param err: Probability of the oracle answer being flipped.
    :param majority: Passed as `majority` to `RPA.run`.
    :return: Tuple of (runs containing `mult`, runs with exactly one candidate, oracle calls).
    """
    # Cache the noise-free answers; the noise is applied on top of the cache.
    cached_oracle = lru_cache(maxsize=2)(partial(simulated_oracle, simulate_mult_id=i))
    flipping_oracle = noisy_oracle(cached_oracle, flip_proba=err)
    calls = 0

    def oracle(scalar, affine_point):
        # Count every query made by the method, cached or not.
        nonlocal calls
        calls += 1
        return flipping_oracle(scalar, affine_point)

    rpa = RPA(set(multipliers))
    rpa.build_tree(params, tries=10)

    correct = 0
    precise = 0
    for _ in range(num_tries):
        candidates = rpa.run(oracle, majority=majority)
        correct += mult in candidates
        precise += len(candidates) == 1
    return correct, precise, calls

# Fan one measure_mult job out per (multiplier, error rate, majority) combination.
# `args[j]` holds the arguments of `futures[j]` and `results[j]` its result.
with silent(), ProcessPoolExecutor(max_workers=num_cores) as pool:
    futures = []
    args = []
    for i, mult in enumerate(multipliers):
        for err in errs:
            for majority in majs:
                # Named `job` (not `a`) so the curve coefficient `a` is not overwritten.
                job = (params, multipliers, simulated_oracle, i, mult, err, majority)
                futures.append(pool.submit(measure_mult, *job))
                args.append(job)
    # Futures complete out of order; look up their submission index in O(1)
    # instead of scanning the list for every completed future.
    position = {future: idx for idx, future in enumerate(futures)}
    results = [None for _ in futures]
    for future in tqdm(as_completed(futures), total=len(futures), smoothing=0):
        results[position[future]] = future.result()

Now we accumulate the results across the error rate and majority vote parameters.

In [None]:
# Fold every job's (correct, precise, calls) result into the accumulators.
# The loop variable is named `job` so the curve coefficient `a` is not overwritten.
for job, result in zip(args, results):
    *_, err, majority = job
    i = errs.index(err)
    # The majority axis is stored reversed; the plots below label it with reversed(majs).
    j = len(majs) - majs.index(majority) - 1
    correct_tries[i, j] += result[0]
    precise_tries[i, j] += result[1]
    query_tries[i, j] += result[2]
total_tries += num_tries

# Normalize by the number of tries and multipliers; the first two are percentages.
correct_rate = (correct_tries * 100) / (total_tries * len(multipliers))
precise_rate = (precise_tries * 100) / (total_tries * len(multipliers))
query_rate = query_tries / (total_tries * len(multipliers))

And save the results for later.

In [None]:
# Persist the rate matrices in the working directory (np.save appends ".npy" to these names).
np.save("rpa_re_correct_rate", correct_rate)
np.save("rpa_re_precise_rate", precise_rate)
np.save("rpa_re_query_rate", query_rate)

We can plot several heatmaps:
 - One for the average number of queries to the oracle.
 - One for the success rate of the reverse-engineering.
 - One for the precision of the reverse-engineering.

In [None]:
# Heatmap of the average number of oracle queries; error rate on x, majority on y.
fig, ax = plt.subplots()
im = ax.imshow(query_rate.T, cmap="plasma")
cbar_ax = fig.add_axes((0.85, 0.15, 0.04, 0.69))
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel("Oracle query rate", rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(errs)), labels=errs)
ax.set_yticks(np.arange(len(majs)), labels=majs[::-1])
ax.set_xlabel("error probability")
ax.set_ylabel("majority vote")
# Annotate every cell with its value.
for (i, j), rate in np.ndenumerate(query_rate):
    text_color = "w" if i - j <= 2 else "black"
    text = ax.text(i, j, f"{rate:.1f}", ha="center", va="center", color=text_color)
fig.savefig("rpa_re_query_rate.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Heatmap of the success rate; error rate on x, majority on y.
fig, ax = plt.subplots()
im = ax.imshow(correct_rate.T, vmin=0, cmap="viridis")
cbar_ax = fig.add_axes((0.85, 0.15, 0.04, 0.69))
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel("Success rate", rotation=-90, va="bottom")
# Dashed line on the colorbar at 100 / number of multipliers.
cbar.ax.axhline(100 / len(multipliers), color="red", linestyle="--")

ax.set_xticks(np.arange(len(errs)), labels=errs)
ax.set_yticks(np.arange(len(majs)), labels=majs[::-1])
ax.set_xlabel("error probability")
ax.set_ylabel("majority vote")
# Annotate every cell with its value.
for (i, j), c_rate in np.ndenumerate(correct_rate):
    text_color = "w" if c_rate < 80 else "black"
    text = ax.text(i, j, f"{c_rate:.1f}%", ha="center", va="center", color=text_color)
fig.savefig("rpa_re_success_rate.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Heatmap of the precision (share of runs ending with a single candidate).
fig, ax = plt.subplots()
im = ax.imshow(precise_rate.T, vmin=0, cmap="viridis")
cbar_ax = fig.add_axes((0.85, 0.15, 0.04, 0.69))
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel("Precision", rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(errs)), labels=errs)
ax.set_yticks(np.arange(len(majs)), labels=majs[::-1])
ax.set_xlabel("error probability")
ax.set_ylabel("majority vote")
# Annotate every cell with its value.
for (i, j), p_rate in np.ndenumerate(precise_rate):
    text_color = "w" if p_rate < 80 else "black"
    text = ax.text(i, j, f"{p_rate:.1f}%", ha="center", va="center", color=text_color)
fig.savefig("rpa_re_precision.pdf", bbox_inches="tight")
plt.show()

Another way to look at these metrics is a scatter plot.

In [None]:
# Success rate against the number of oracle queries, one series per error rate.
fig, ax = plt.subplots()
ax.grid()
for err, qrs, crs in zip(errs, query_rate, correct_rate):
    ax.scatter(qrs, crs, label=f"error = {err}")
ax.set_xlabel("oracle queries")
ax.set_ylabel("success rate")
ax.legend()
fig.savefig("rpa_re_scatter.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Success rate against the majority parameter, one line per error rate.
fig, ax = plt.subplots()
# The majority axis of correct_rate is stored reversed, hence the reversed x values.
majs_reversed = list(majs[::-1])
for err, crs in zip(errs, correct_rate):
    ax.plot(majs_reversed, crs, label=f"error = {err}")
ax.set_xlabel("majority vote")
ax.set_ylabel("success rate")
ax.set_xticks(majs)
ax.legend()
fig.savefig("rpa_re_plot.pdf", bbox_inches="tight")
plt.show()

#### What about (asymmetric) noise?

In [None]:
# Accumulators for the asymmetric-noise experiment, one cell per
# (flip_0 error, flip_1 error, majority) combination, summed over all multipliers.
correct_tries_b = np.zeros((len(errs), len(errs), len(majs)))
precise_tries_b = np.zeros((len(errs), len(errs), len(majs)))
query_tries_b =   np.zeros((len(errs), len(errs), len(majs)))
# Number of tries accumulated so far (per multiplier and parameter combination).
total_tries_b = 0

# Number of RPA-RE runs per (multiplier, flip_0, flip_1, majority) combination.
num_tries_b = 100

In [None]:
# Number of worker processes to spawn.
num_cores = 30


def measure_mult(params, multipliers, simulated_oracle, i, mult, err_0, err_1, majority):
    """
    Run RPA-RE `num_tries_b` times against a biased oracle simulating `mult`.

    :param params: Domain parameters used to build the distinguishing tree.
    :param multipliers: The candidate multipliers.
    :param simulated_oracle: The noise-free oracle; called with `simulate_mult_id=i`.
    :param i: Index of the simulated multiplier.
    :param mult: The simulated ("true") multiplier, looked up in each result.
    :param err_0: Probability of flipping a negative oracle answer.
    :param err_1: Probability of flipping a positive oracle answer.
    :param majority: Passed as `majority` to `RPA.run`.
    :return: Tuple of (runs containing `mult`, runs with exactly one candidate, oracle calls).
    """
    # Cache the noise-free answers; the bias is applied on top of the cache.
    cached_oracle = lru_cache(maxsize=2)(partial(simulated_oracle, simulate_mult_id=i))
    flipping_oracle = biased_oracle(cached_oracle, flip_0=err_0, flip_1=err_1)
    calls = 0

    def oracle(scalar, affine_point):
        # Count every query made by the method, cached or not.
        nonlocal calls
        calls += 1
        return flipping_oracle(scalar, affine_point)

    rpa = RPA(set(multipliers))
    rpa.build_tree(params, tries=10)

    correct = 0
    precise = 0
    for _ in range(num_tries_b):
        candidates = rpa.run(oracle, majority=majority)
        correct += mult in candidates
        precise += len(candidates) == 1
    return correct, precise, calls

# Fan one measure_mult job out per (multiplier, flip_0, flip_1, majority) combination.
# `args[j]` holds the arguments of `futures[j]` and `results[j]` its result.
with silent(), ProcessPoolExecutor(max_workers=num_cores) as pool:
    futures = []
    args = []
    for i, mult in enumerate(multipliers):
        for err_0 in errs:
            for err_1 in errs:
                for majority in majs:
                    # Named `job` (not `a`) so the curve coefficient `a` is not overwritten.
                    job = (params, multipliers, simulated_oracle, i, mult, err_0, err_1, majority)
                    futures.append(pool.submit(measure_mult, *job))
                    args.append(job)
    # Futures complete out of order; look up their submission index in O(1)
    # instead of scanning the (several thousand entries long) list each time.
    position = {future: idx for idx, future in enumerate(futures)}
    results = [None for _ in futures]
    for future in tqdm(as_completed(futures), total=len(futures), smoothing=0):
        results[position[future]] = future.result()

Now we accumulate the results across the error rate and majority vote parameters.

In [None]:
# Fold every job's (correct, precise, calls) result into the accumulators.
# Deliberately avoids the names `k` and `a`: `k` is the targeted multiple used
# by the exploration cells and `a` is the curve coefficient.
for job, result in zip(args, results):
    *_, err_0, err_1, majority = job
    cell = (errs.index(err_0), errs.index(err_1), majs.index(majority))
    correct_tries_b[cell] += result[0]
    precise_tries_b[cell] += result[1]
    query_tries_b[cell] += result[2]
total_tries_b += num_tries_b

# Normalize by the number of tries and multipliers; the first two are percentages.
correct_rate_b = (correct_tries_b * 100) / (total_tries_b * len(multipliers))
precise_rate_b = (precise_tries_b * 100) / (total_tries_b * len(multipliers))
query_rate_b = query_tries_b / (total_tries_b * len(multipliers))

And save the results for later.

In [None]:
# Persist the asymmetric-noise rate arrays in the working directory (np.save appends ".npy" to these names).
np.save("rpa_re_correct_rate_b", correct_rate_b)
np.save("rpa_re_precise_rate_b", precise_rate_b)
np.save("rpa_re_query_rate_b", query_rate_b)

In [None]:
# Grid of heatmaps of the oracle query rate, one panel per majority setting
# (the 2x3 grid assumes six entries in `majs`). In every panel e_1 runs along
# x and e_0 along y, with e_0 flipped so that it grows upwards.
fig, axs = plt.subplots(nrows=2, ncols=3, sharex="col", sharey="row")
vmin = np.min(query_rate_b)
vmax = np.max(query_rate_b)

for level, ax in enumerate(axs.flat):
    # First axis of query_rate_b is e_0, second is e_1.
    panel = query_rate_b[::-1, :, level]
    im = ax.imshow(panel, cmap="plasma", vmin=vmin, vmax=vmax)
    ax.set_xticks(np.arange(len(errs)), labels=errs)
    ax.set_yticks(np.arange(len(errs)), labels=list(reversed(errs)))
    # imshow draws panel[row, col] at (x=col, y=row). Annotate from the very
    # array that is drawn so the numbers always match the cell colours.
    for (y, x), q_rate in np.ndenumerate(panel):
        ax.text(x, y, f"{q_rate:.0f}", ha="center", va="center")
    ax.set_xlabel("$e_1$")
    ax.set_ylabel("$e_0$")
    ax.set_title(majs[level])
fig.set_size_inches((10,6))
fig.tight_layout(h_pad=1.5, rect=(0, 0, 0.9, 1))
# All panels share vmin/vmax, so the last image can drive the common colorbar.
cbar_ax = fig.add_axes((0.9, 0.10, 0.02, 0.84))
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel("Oracle query rate", rotation=-90, va="bottom")
fig.savefig("rpa_re_asymmetric_query_rate.pdf")
plt.show()

In [None]:
# Grid of heatmaps of the success rate, one panel per majority setting
# (the 2x3 grid assumes six entries in `majs`). In every panel e_1 runs along
# x and e_0 along y, with e_0 flipped so that it grows upwards.
fig, axs = plt.subplots(nrows=2, ncols=3, sharex="col", sharey="row")
for level, ax in enumerate(axs.flat):
    # First axis of correct_rate_b is e_0, second is e_1.
    panel = correct_rate_b[::-1, :, level]
    im = ax.imshow(panel, cmap="viridis", vmin=0, vmax=100)
    ax.set_xticks(np.arange(len(errs)), labels=errs)
    ax.set_yticks(np.arange(len(errs)), labels=list(reversed(errs)))
    # imshow draws panel[row, col] at (x=col, y=row). Annotate from the very
    # array that is drawn so the numbers always match the cell colours.
    for (y, x), c in np.ndenumerate(panel):
        ax.text(x, y, f"{c:.0f}%", ha="center", va="center", color="w" if c < 50 else "black")
    ax.set_xlabel("$e_1$")
    ax.set_ylabel("$e_0$")
    ax.set_title(majs[level])
fig.set_size_inches((10,6))
fig.tight_layout(h_pad=1.5, rect=(0, 0, 0.9, 1))
# All panels share vmin/vmax, so the last image can drive the common colorbar.
cbar_ax = fig.add_axes((0.9, 0.10, 0.02, 0.84))
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel("Success rate", rotation=-90, va="bottom")
# Dashed line on the colorbar at 100 / number of multipliers.
cbar.ax.axhline(100 / len(multipliers), color="red", linestyle="--")
fig.savefig("rpa_re_asymmetric_success_rate.pdf")
plt.show()

In [None]:
# Success rate against the majority parameter, averaged over all (e_0, e_1)
# pairs that share the same total error e_0 + e_1.
fig, ax = plt.subplots()
ax.grid()
crs_accumulated = {}
for err_0, rates_for_err_0 in zip(errs, correct_rate_b):
    for err_1, crs in zip(errs, rates_for_err_0):
        # Round away float noise so that equal sums land in the same bucket.
        total_err = round(err_0 + err_1, 1)
        crs_accumulated.setdefault(total_err, []).append(crs)
for total_err, bucket in crs_accumulated.items():
    ax.plot(majs, np.mean(bucket, axis=0), label=f"total_error = {total_err}")
ax.set_xticks(majs)
ax.set_xlabel("majority")
ax.set_ylabel("success rate")
ax.legend(bbox_to_anchor=(1, 1.02))
fig.tight_layout()
plt.show()

### Method simulation

The `simulate_trace` function simulates a Hamming weight leakage trace of a given multiplier computing a scalar multiple.
This is used by the `simulated_rpa_trace` function that does the RPA attack on simulated traces and returns the differential
trace. This is in turn used to build the `simulated_rpa_oracle` which can be used by the `rpa_distinguish` function to perform
RPA-RE and distinguish the true scalar multiplier. The oracle is parametrized with the simulated multiplier index in the table of multipliers (it simulates this "real" multiplier).

In [None]:
def simulate_trace(mult, scalar, point):
    """
    Simulate a Hamming weight leakage trace of `mult` computing `[scalar]point`.

    :param mult: The scalar multiplier to run (initialized with the global `params`).
    :param scalar: The scalar to multiply with.
    :param point: The input point.
    :return: A `Trace` with one sample per intermediate value of every formula action.
    """
    with local(DefaultContext()) as ctx:
        mult.init(params, point)
        mult.multiply(scalar)

    leakage = HammingWeight()
    samples = []

    def collect(action):
        # Only formula executions contribute samples.
        if not isinstance(action, FormulaAction):
            return
        samples.extend(leakage(intermediate.value) for intermediate in action.op_results)

    ctx.actions.walk(collect)
    return Trace(np.array(samples))

def simulated_rpa_trace(mult, scalar, affine_point, noise):
    """
    Simulate the RPA attack on `mult` and return the differential trace.

    :param mult: The scalar multiplier to simulate.
    :param scalar: The scalar to multiply with.
    :param affine_point: The affine target point.
    :param noise: Callable applied to every simulated trace.
    :return: Average of 10 noisy random-point traces minus the average of 500
             noisy target-point traces, max-downsampled by a factor of 25.
    """
    ec = params.curve
    coordinate_model = ec.coordinate_model
    target_point = affine_point.to_model(coordinate_model, ec)
    # A single random point is drawn and reused for all the random traces.
    random_point = ec.affine_random().to_model(coordinate_model, ec)

    def noisy_traces(point, count):
        return [noise(simulate_trace(mult, scalar, point)) for _ in range(count)]

    random_avg = average(*noisy_traces(random_point, 10))
    target_avg = average(*noisy_traces(target_point, 500))
    return downsample_max(subtract(random_avg, target_avg), 25)

def simulated_rpa_oracle(scalar, affine_point, simulate_mult_id = 0):
    """
    RPA oracle based on simulated traces of `multipliers[simulate_mult_id]`.

    :param scalar: The scalar to multiply with.
    :param affine_point: The affine target point.
    :param simulate_mult_id: Index of the "real" multiplier in `multipliers`.
    :return: Whether the normalized differential trace has a peak of height at least 4.
    """
    target_mult = multipliers[simulate_mult_id]
    differential = simulated_rpa_trace(target_mult, scalar, affine_point, NormalNoice(0, 1))
    peaks, _ = find_peaks(normalize(differential).samples, height=4)
    return len(peaks) > 0

In [None]:
# TemporaryConfig is not among the notebook's top-level imports, so bring it
# into scope here; without it this cell fails with a NameError.
from pyecsca.misc.cfg import TemporaryConfig

# For every multiplier, run RPA-RE against the trace-based oracle simulating it
# and record what the method returned.
table = [["True multiplier", "Reversed", "Correct", "Remaining"]]
with TemporaryConfig() as cfg:
    # Turn pyecsca logging off for the duration of the runs.
    cfg.log.enabled = False
    for i, mult in tqdm(enumerate(multipliers), total=len(multipliers)):
        res = rpa_distinguish(params, multipliers, partial(simulated_rpa_oracle, simulate_mult_id = i))
        table.append([mult, res, mult in res, len(res)])
display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))

Note that the oracle function above has several parameters, like noise amplitude, amount of traces simulated, amount of downsampling and peak finding height threshold. The cell below compares the differential RPA trace when the multiple is computed in the simulation vs when it is not.

In [None]:
# Differential traces for two different multipliers on the same target point and scalar.
trace_real = simulated_rpa_trace(multipliers[0], scalar, P0_target.to_affine(), NormalNoice(0, 1))
diff_real = normalize(trace_real)
trace_nothing = simulated_rpa_trace(multipliers[7], scalar, P0_target.to_affine(), NormalNoice(0, 1))
diff_nothing = normalize(trace_nothing)
plot_traces(diff_real, diff_nothing).opts(responsive=True, height=600)