# EPA-based reverse engineering

In [None]:
import io
import tabulate
import secrets
from tqdm.notebook import tqdm, trange
from itertools import product
from IPython.display import HTML, display
from sympy.ntheory import factorint
from sympy.ntheory.modular import crt

from pyecsca.ec.model import ShortWeierstrassModel
from pyecsca.ec.coordinates import AffineCoordinateModel
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.params import DomainParameters, load_params_ectester
from pyecsca.ec.mod import Mod, miller_rabin, gcd
from pyecsca.ec.point import Point, InfinityPoint
from pyecsca.ec.error import NonInvertibleError
from pyecsca.ec.mult import LTRMultiplier, AccumulationOrder
from pyecsca.ec.formula.fake import FakeAdditionFormula, FakeDoublingFormula, FakePoint
from pyecsca.sca.re.rpa import MultipleContext
from pyecsca.sca.re.zvp import unroll_formula_expr
from pyecsca.ec.context import local
from pyecsca.ec.error import UnsatisfiedAssumptionError

In [None]:
# Test curves over composite moduli, one ECTester-format CSV string each; they
# are parsed with load_params_ectester further down. Each string holds seven
# comma-separated hex fields -- presumably p, a, b, gx, gy, n, h in ECTester's
# field order (TODO confirm). The number above each entry is the author's
# recorded value of phi(p)/p for that curve's modulus.
curves = [
    # phi(p)/p =
    # 0.8286039438617044
    "dfb2da5e1b7bd7bb098cb975966293ed,d9c4372806e8131b18d0036e8f832749,bcae41be8e808acdc04bb769dead91e2,0e2f983c0f852bef381f567448f0d488,1599bba77ed1cb8dec41555098958492,10fcabd48fffc71e6300d44acc236157d,0001",
    # 0.633325621696952
    "cca6f6718a06cad7094962b2a35f067d,67aa9464eb493fbb7b509d29381b9a9d,cafc69aa517b654a6a608644996cc8d1,4c092beb06cc00751eec39675f680cb8,82800378a47dd6f26ff6a50f69e4c4e6,18a22d20b6de3ff6bdc49329c21163f77,0001",
    # 0.8508806646440022
    "b3755d654bad73114e4191e9f5f36af9,9fe4f88cfbacba71f4b767ace8580c74,4610526fdcfbd69aed453ac2ee6efeef,542d8e0bbafe40dae36f25cbc350add6,68a65f5a5dc304bfd0d8fe963c250206,118a34a1ea295e78b3a3c960b6f680ee1,0001",
    
    # 0.9845701775215489 (has a = 0 for a subcurve)
    "de1406450d5d7e91d81907956019c0c1,5fbe46b9f1086011e18f5d823c6110ce,a859c36ceeadb39c7a978f7b1b0563ee,1cba89c3f099c29401ecf3fe1806e822,345d7282a0114070be91f95fe3db1faa,0fcd24d24e57a40547814b6766b9ea735,0001",
    # 0.980582605794486  (does not have a = 0 for any subcurve)
    "cab298b495875d4ab2c8ee3eb03016a7,a7c4f56f286d9eae44424c85c8b2fcb9,5e8c439d939273fdcb5503acbda7d3f8,816c9f865c831223067a88046bf00d75,972ce29ed18d5d73f15cef31187659be,0b0e97ff8c3e72e7ae75eb3f5e759fe03,0001",
    
    # 0.9547100843537808
    "f1a8a441b6d0e9600e33ccf16f9b8291,b3f55185bd6a63528e3d560c6a7b729a,c2fee2d65350e870eda0ac5e2b96b810,29b3e793822fad03a3c2ebca3cf62c12,b937d5389b6c5d0212d0f53e26843092,1153442389f9e1da8dd130bc93c6ef42b,0001",
    # 0.7214369438844093
    "a4dfa4b6b065c40b45980474266c9fbb,2c3486e725755b44a7c119473c5b9c64,329078ab070fc18edc6ce53047e00a39,9f6209be91b66943d9e8e0b61c4aae4e,05271c9ac628351b9add9e1be69a9fa4,0cef2e52ffe86ebc6dd323912ac7d9a87,0001",
    # 0.4716485170445178 (160-bit)
    "db49063db56b7783fa01dd62077c5a88dfa28009,aee572fdd4790bcd4729bb3b612b52a573df46e9,dab9e68366a593ca1df9cb2f20890a578729d6ef,d4a3aaf43bdb25be7c308b69ae54f639e6e32e8c,7b6c82140bb427ac6e2a64507f60775949b2c8ce,34a9fbe62b272f930b2e5027780a32300feb0dd8f,0001"
]

In [None]:
# Work in the short Weierstrass model; `affine` is the affine coordinate model
# used when building points from CRT-combined x/y coordinates.
model = ShortWeierstrassModel()
affine = AffineCoordinateModel(model)
# Coordinate system the curve is loaded into and whose formulas are examined.
which = "projective"
coords = model.coordinates[which]

# Load curves[4] in `which` coordinates and expose the pieces later cells use:
# p = the curve's (composite) modulus, g = generator, n = generator order.
params = load_params_ectester(io.BytesIO(curves[4].encode()), which)
curve = params.curve
p = params.curve.prime
g = params.generator
n = params.order

In [None]:
# All addition ("add*") and doubling ("dbl*") formulas of the chosen
# coordinate system, and every (add, dbl) pairing of them.
adds = [formula for formula in coords.formulas.values() if formula.name.startswith("add")]
dbls = [formula for formula in coords.formulas.values() if formula.name.startswith("dbl")]
formula_pairs = list(product(adds, dbls))

# Multiplier built from the fake formulas and initialized with a FakePoint.
# It is only run under a MultipleContext, to record which multiples of the
# input the LTR ladder computes for a given scalar. Its configuration
# matches the real multipliers built later.
fake_add = FakeAdditionFormula(params.curve.coordinate_model)
fake_dbl = FakeDoublingFormula(params.curve.coordinate_model)
# Bug fix: this previously referenced the undefined names `fadd`/`fdbl`.
fake_mult = LTRMultiplier(fake_add, fake_dbl, None, False, AccumulationOrder.PeqPR, True, True)
fake_mult.init(params, FakePoint(params.curve.coordinate_model))

In [None]:
def random_scalar(n):
    """Return a uniformly random integer in ``[0, n)``.

    Drawn with the ``secrets`` CSPRNG; ``secrets.randbelow`` rejects
    ``n <= 0`` with ``ValueError``.
    """
    value = secrets.randbelow(n)
    return value

def random_scalar_trivial(n):
    """Return a random integer in ``[0, n)`` that is coprime to ``n``.

    Candidates come from ``secrets.randbelow`` and are rejected until one has
    ``gcd(candidate, n) == 1``.
    """
    while True:
        candidate = secrets.randbelow(n)
        if gcd(candidate, n) == 1:
            return candidate

def random_scalar_fully_trivial(n, mult):
    """Return a random scalar coprime to ``n`` whose recorded multiples are all coprime to ``n``.

    Each candidate comes from ``random_scalar_trivial(n)``. It is run through
    ``mult.multiply`` inside a ``MultipleContext``, and accepted only when
    every multiple recorded in ``ctx.points.values()`` has gcd 1 with ``n``.
    Otherwise a fresh candidate is drawn.

    NOTE(review): ``mult`` presumably must already be ``init``-ed, as done
    elsewhere in this notebook; confirm.
    """
    while True:
        candidate = random_scalar_trivial(n)
        with local(MultipleContext()) as ctx:
            mult.multiply(candidate)
        multiples = ctx.points.values()
        if all(gcd(k, n) == 1 for k in multiples):
            return candidate

def fixed_point(params):
    """Return ``params.generator`` as-is (the deterministic point source)."""
    generator = params.generator
    return generator

def random_point(splitted, top, randomized=False):
    """Sample a random point on the composite-modulus curve ``top`` via the CRT.

    A random affine point is drawn on each subcurve in ``splitted``. Their x
    and y coordinates are glued together with ``crt`` into an affine point
    modulo ``top.curve.prime``, which is then converted to ``top``'s
    coordinate model.

    :param splitted: mapping prime factor -> subcurve domain parameters
        (as returned by ``split_params``).
    :param top: domain parameters of the full curve.
    :param randomized: forwarded to ``Point.to_model``.
    :return: the combined point in ``top.curve.coordinate_model``.
    """
    moduli = []
    x_residues = []
    y_residues = []
    for modulus, sub_params in splitted.items():
        sampled = sub_params.curve.affine_random()
        moduli.append(modulus)
        x_residues.append(int(sampled.x))
        y_residues.append(int(sampled.y))
    prime = top.curve.prime
    combined = Point(
        affine,
        x=Mod(int(crt(moduli, x_residues)[0]), prime),
        y=Mod(int(crt(moduli, y_residues)[0]), prime),
    )
    return combined.to_model(top.curve.coordinate_model, top.curve, randomized=randomized)

In [None]:
def project_down(point, subcurve):
    """Reduce every coordinate of ``point`` modulo ``subcurve.prime``.

    Returns a point in ``subcurve.coordinate_model`` with the same coordinate
    names as ``point``.
    """
    reduced = {}
    for coord_name, coord_value in point.coords.items():
        reduced[coord_name] = Mod(int(coord_value), subcurve.prime)
    return Point(subcurve.coordinate_model, **reduced)

def lift_up(point, topcurve):
    """Re-express ``point`` with each coordinate taken modulo ``topcurve.prime``.

    Each coordinate's integer representative (``int(value)``) is reused
    directly. This is not a CRT lift, so the result is only meaningful where
    that representative is the intended value.
    """
    modulus = topcurve.prime
    lifted = {coord: Mod(int(val), modulus) for coord, val in point.coords.items()}
    return Point(topcurve.coordinate_model, **lifted)

def split_params(params):
    """Split domain parameters over a squarefree composite modulus into per-prime parts.

    For each prime factor ``p_i`` of ``params.curve.prime``, a subcurve is
    built by reducing the curve parameters and the generator modulo ``p_i``.
    The generator order on each subcurve is then chosen among the prime
    factors of ``params.order``. A factor ``q`` is taken as that order when
    affinely multiplying the subcurve generator by ``q`` raises
    ``NonInvertibleError`` (presumably because the computation reaches the
    point at infinity). Each order factor is assigned to at most one subcurve.

    :param params: domain parameters over a composite, squarefree modulus.
    :return: dict mapping each prime factor of the modulus (ascending) to the
        ``DomainParameters`` of the corresponding subcurve (cofactor 1).
    :raises ValueError: if the modulus is not squarefree, or if no remaining
        prime factor of ``params.order`` matches some subcurve. Previously
        that subcurve's order silently stayed 0, which only failed later in
        ``scalar % order``.
    """
    factors = factorint(params.curve.prime)
    if set(factors.values()) != {1}:
        raise ValueError("Not squarefree")
    results = {}
    # Construct the subcurves.
    for p_i in sorted(factors):
        parameters_i = {name: Mod(int(value), p_i) for name, value in params.curve.parameters.items()}
        # NOTE(review): the neutral element is passed through unreduced (mod p, not p_i); confirm intended.
        curve_i = EllipticCurve(params.curve.model, params.curve.coordinate_model, p_i, params.curve.neutral, parameters_i)
        generator_i = project_down(params.generator, curve_i)
        # Order 0 is a placeholder, filled in below.
        results[p_i] = DomainParameters(curve_i, generator_i, 0, 1)
    # Map the prime factors of the full order onto the subcurves.
    remaining_orders = sorted(factorint(params.order))
    for p_i, params_i in results.items():
        for order in remaining_orders:
            try:
                params_i.curve.affine_multiply(params_i.generator.to_affine(), order)
            except NonInvertibleError:
                params_i.order = order
                # Safe despite iterating remaining_orders: we break immediately.
                remaining_orders.remove(order)
                break
        else:
            raise ValueError(f"Could not determine the generator order on the subcurve modulo {p_i}")
    return results

def split_scalarmult(splitted, top, point, scalar):
    """Compute ``scalar * point`` on the composite curve via its subcurves.

    ``point`` is reduced onto every subcurve in ``splitted`` and multiplied
    there by ``scalar`` reduced modulo that subcurve's order. The affine
    results are then recombined coordinate-wise with ``crt`` into an affine
    point modulo ``top.curve.prime``.

    :return: an affine ``Point`` over ``top``, or an ``InfinityPoint`` in
        ``top``'s coordinate model if any subcurve result is at infinity.
        ``NonInvertibleError`` may propagate from the subcurve arithmetic;
        the caller in this notebook catches it.
    """
    partials = {}
    for modulus, sub_params in splitted.items():
        sub_order = sub_params.order
        reduced = project_down(point, sub_params.curve)
        sub_scalar = scalar % sub_order
        if sub_scalar == 0:
            partials[modulus] = InfinityPoint(sub_params.curve.coordinate_model)
        else:
            partials[modulus] = sub_params.curve.affine_multiply(reduced.to_affine(), sub_scalar)
    if any(isinstance(partial, InfinityPoint) for partial in partials.values()):
        return InfinityPoint(top.curve.coordinate_model)
    moduli = list(partials.keys())
    x_residues = [int(partials[m].x) for m in moduli]
    y_residues = [int(partials[m].y) for m in moduli]
    prime = top.curve.prime
    combined_x = Mod(int(crt(moduli, x_residues)[0]), prime)
    combined_y = Mod(int(crt(moduli, y_residues)[0]), prime)
    return Point(affine, x=combined_x, y=combined_y)

In [None]:
# Split the composite-modulus curve into its prime-modulus subcurves.
split = split_params(params)
# 50 random scalars coprime to n, and 50 random points on the full curve built
# via CRT from the subcurves (randomized representation in params' coordinates).
scalars = [random_scalar_trivial(n) for _ in trange(50, desc="Generate scalars")]
points = [random_point(split, params, randomized=True) for _ in trange(50, desc="Generate points")]
results = []  # reference results from split_scalarmult (None if it raised NonInvertibleError)
chains = []  # per scalar: the multiples recorded while running fake_mult


gcds = []  # per scalar: gcd(scalar, n) == 1
fgcds = []  # per scalar: every recorded multiple has gcd 1 with n
for scalar, point in tqdm(zip(scalars, points), desc="Precomp", total=len(scalars)):
    try:
        result = split_scalarmult(split, params, point, scalar)
    except NonInvertibleError:
        result = None
    results.append(result)
    # Record which multiples the LTR ladder computes for this scalar.
    with local(MultipleContext()) as ctx:
        fake_mult.multiply(scalar)
    chains.append(list(ctx.points.values()))
    scalar_trivial_gcd = gcd(scalar, n) == 1
    all_subscalars_trivial_gcd = all(map(lambda x: gcd(x, n) == 1, ctx.points.values()))
    gcds.append(scalar_trivial_gcd)
    fgcds.append(all_subscalars_trivial_gcd)


# Summary table: one row per (add, dbl) pair. The gcd columns do not depend on
# the pair, since they come from the fake multiplier above.
table = [["Pair", "scalars with trivial gcd", "scalars with all multiples with trivial gcds", "scalars with invertible final zs", "scalars with all multiples's zs invertible", "scalars with correct result"]]
# pair_table[i][j]: number of scalars for which adds[i]/dbls[j] produced a
# final result that converts to affine.
pair_table = [[None for _ in dbls] for _ in adds]
for pair in tqdm(formula_pairs):
    # Real multiplier with the same configuration as fake_mult.
    mult = LTRMultiplier(*pair, None, False, AccumulationOrder.PeqPR, True, True)
    inv = []  # final result converts to affine
    correct = []  # ... and equals the split_scalarmult reference
    zs = []  # every recorded intermediate point has Z coprime to p
    for scalar, point, result in tqdm(zip(scalars, points, results), leave=None, total=len(scalars)):
        mult.init(params, point)
        with local(MultipleContext()) as ctx:
            res = mult.multiply(scalar)

        all_submultiples_invertible_z = all(map(lambda x: gcd(int(x.Z), p) == 1, ctx.points.keys()))
        result_invertible_z = False
        result_correct = False
        try:
            res_aff = res.to_affine()
            result_invertible_z = True
            if res_aff == result:
                result_correct = True
        except NonInvertibleError as e:
            # Non-invertible final Z: both flags deliberately stay False.
            pass
        zs.append(all_submultiples_invertible_z)
        inv.append(result_invertible_z)
        correct.append(result_correct)
    pair_table[adds.index(pair[0])][dbls.index(pair[1])] = sum(inv)
    # One character per scalar: "x" = final result invertible, "." = not.
    for i in inv:
        print("x" if i else ".", end="")
    print()
    table.append([f"{pair[0].name}, {pair[1].name}", sum(gcds), sum(fgcds), sum(inv), sum(zs), sum(correct)])
# Add row labels (add formula names) and a header row (dbl formula names).
for pl, add in zip(pair_table, adds):
    pl.insert(0, add.name)
pair_table.insert(0, [None] + [dbl.name for dbl in dbls])
display(HTML(tabulate.tabulate(table, tablefmt="html", headers="firstrow")))
display(HTML(tabulate.tabulate(pair_table, tablefmt="html", headers="firstrow")))

In [None]:
def simulate_epa_oracle(affine_params, affine_point, scalar, coords_name="projective",
                        add_name="add-2007-bl", dbl_name="dbl-2007-bl"):
    """Simulate the EPA oracle for a single scalar multiplication.

    The simulated target computes ``scalar * affine_point`` with a
    left-to-right multiplier, using the named addition/doubling formulas in
    the named coordinate system of the curve's own model. The oracle reveals
    only whether the final result could be converted back to affine
    coordinates.

    :param affine_params: domain parameters in affine coordinates.
    :param affine_point: input point in affine coordinates.
    :param scalar: the scalar to multiply by.
    :param coords_name: coordinate system of the simulated target.
    :param add_name: addition formula of the simulated target.
    :param dbl_name: doubling formula of the simulated target.
        The defaults reproduce the originally hard-coded target.
    :return: True if ``to_affine()`` succeeded, False if it raised
        ``NonInvertibleError``.
    """
    # Use the curve's own model rather than the notebook-global `model`, so
    # the coordinate system always belongs to the curve being converted.
    real_coords = affine_params.curve.model.coordinates[coords_name]
    real_add = real_coords.formulas[add_name]
    real_dbl = real_coords.formulas[dbl_name]
    real_mult = LTRMultiplier(real_add, real_dbl, None, False, AccumulationOrder.PeqPR, True, True)
    params = affine_params.to_coords(real_coords)
    point = affine_point.to_model(real_coords, params.curve)
    real_mult.init(params, point)
    res = real_mult.multiply(scalar)
    try:
        res.to_affine()
    except NonInvertibleError:
        return False
    return True

def _matches_responses(mult, scalars, responses):
    """Return True iff ``mult`` reproduces every oracle response.

    A response is True when the result of the multiplication could be
    converted to affine coordinates, and False when ``to_affine()`` raised
    ``NonInvertibleError``. The candidate must agree on every scalar.
    """
    for scalar, target in zip(scalars, responses):
        res = mult.multiply(scalar)
        try:
            res.to_affine()
            invertible = True
        except NonInvertibleError:
            invertible = False
        if invertible != target:
            # The first disagreement already rules this candidate out.
            return False
    return True


def epa_distinguish(oracle, mult_factory, curve_data=None, num_scalars=100):
    """Find which (coordinates, add, dbl) triples are consistent with an EPA oracle.

    The oracle is queried once per random scalar, always on the generator of
    the test curve. Then every addition/doubling pair of every coordinate
    system of that curve's model is simulated with ``mult_factory(add, dbl)``.
    A pair is kept as a candidate only if its "result converts to affine"
    outcome matches the oracle on every scalar.

    :param oracle: callable ``oracle(affine_params, affine_point, scalar) -> bool``.
    :param mult_factory: callable building an uninitialized multiplier from an
        addition formula and a doubling formula.
    :param curve_data: ECTester-format curve string; defaults to ``curves[3]``
        (the originally hard-coded curve).
    :param num_scalars: number of random scalars to query (default 100).
    :return: set of ``(coords_name, add_name, dbl_name)`` candidate triples.
    """
    if curve_data is None:
        curve_data = curves[3]
    affine_params = load_params_ectester(io.BytesIO(curve_data.encode()), "affine")
    model = affine_params.curve.model
    scalars = [int(Mod.random(affine_params.order)) for _ in range(num_scalars)]
    responses = [oracle(affine_params, affine_params.generator, scalar) for scalar in scalars]
    candidates = set()
    print("Got responses")
    total = 0
    for coords in model.coordinates.values():
        adds = [formula for formula in coords.formulas.values() if formula.name.startswith("add")]
        dbls = [formula for formula in coords.formulas.values() if formula.name.startswith("dbl")]
        formula_pairs = list(product(adds, dbls))
        # Pairs of coordinate systems that get skipped still count towards the total.
        total += len(formula_pairs)
        try:
            params = affine_params.to_coords(coords)
        except UnsatisfiedAssumptionError:
            print(f"Skipping {coords.name}, does not fit")
            continue

        for add, dbl in formula_pairs:
            print(f"Trying {coords.name} {add.name} {dbl.name}", end="")
            mult = mult_factory(add, dbl)
            try:
                mult.init(params, params.generator)
                matches = _matches_responses(mult, scalars, responses)
            except UnsatisfiedAssumptionError:
                # A formula whose assumptions do not hold here cannot be the
                # target; rule it out instead of aborting the whole search.
                print(" unsatisfied assumption")
                continue
            if not matches:
                print(" not")
                continue
            print(" candidate")
            candidates.add((coords.name, add.name, dbl.name))
    print(f"Got {len(candidates)} out of {total} total")
    return candidates
                

In [None]:
# Run the distinguisher against the simulated oracle, trying every formula
# pair inside an LTR multiplier configured like the rest of the notebook.
c = epa_distinguish(
    simulate_epa_oracle,
    lambda add, dbl: LTRMultiplier(
        add, dbl, None, False, AccumulationOrder.PeqPR, True, True
    ),
)

In [None]:
# Recompute the per-prime split of the current domain parameters and display
# it; the bare expression on the last line is rendered as the cell output.
split = split_params(params)
split

In [None]:
# For each addition formula, print the unrolled expression of its "Z3" entry.
# If several entries are named "Z3" the last one wins; None if there is none.
# NOTE(review): assumes unroll_formula_expr yields indexable (name, expr) items.
for add in adds:
    z3_entries = [iv[1] for iv in unroll_formula_expr(add) if iv[0] == "Z3"]
    r = z3_entries[-1] if z3_entries else None
    print(add, r)
    print("---")