Skip to content

Commit

Permalink
Move to function based Mod dispatch.
Browse files Browse the repository at this point in the history
  • Loading branch information
J08nY committed Jul 15, 2024
1 parent ba894fe commit 06e005a
Show file tree
Hide file tree
Showing 29 changed files with 251 additions and 256 deletions.
20 changes: 10 additions & 10 deletions pyecsca/ec/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from pyecsca.ec.coordinates import CoordinateModel, AffineCoordinateModel
from pyecsca.ec.error import raise_unsatisified_assumption
from pyecsca.ec.mod import Mod
from pyecsca.ec.mod import Mod, mod
from pyecsca.ec.model import CurveModel
from pyecsca.ec.point import Point, InfinityPoint

Expand All @@ -37,12 +37,12 @@ class EllipticCurve:
>>> from pyecsca.ec.coordinates import AffineCoordinateModel
>>> affine = AffineCoordinateModel(curve.model)
>>> points_P = sorted(curve.affine_lift_x(Mod(5, curve.prime)), key=lambda p: int(p.y))
>>> points_P = sorted(curve.affine_lift_x(mod(5, curve.prime)), key=lambda p: int(p.y))
>>> points_P # doctest: +NORMALIZE_WHITESPACE
[Point([x=5, y=31468013646237722594854082025316614106172411895747863909393730389177298123724] in shortw/affine),
Point([x=5, y=84324075564118526167843364924090959423913731519542450286139900919689799730227] in shortw/affine)]
>>> P = points_P[0]
>>> Q = Point(affine, x=Mod(106156966968002564385990772707119429362097710917623193504777452220576981858057, curve.prime), y=Mod(89283496902772247016522581906930535517715184283144143693965440110672128480043, curve.prime))
>>> Q = Point(affine, x=mod(106156966968002564385990772707119429362097710917623193504777452220576981858057, curve.prime), y=mod(89283496902772247016522581906930535517715184283144143693965440110672128480043, curve.prime))
>>> curve.affine_add(P, Q)
Point([x=110884201872336783252492544257507655322265785208411447156687491781308462893723, y=17851997459724035659875545393642578516937407971293368958749928013979790074156] in shortw/affine)
>>> curve.affine_multiply(P, 10)
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(
if value.n != prime:
raise ValueError(f"Parameter {name} has wrong modulus.")
else:
value = Mod(value, prime)
value = mod(value, prime)
self.parameters[name] = value
self.neutral = neutral
self.__validate_coord_assumptions()
Expand Down Expand Up @@ -147,9 +147,9 @@ def _execute_base_formulas(self, formulas: List[Module], *points: Point) -> Poin
for line in formulas:
exec(compile(line, "", mode="exec"), None, locls) # exec is OK here, skipcq: PYL-W0122
if not isinstance(locls["x"], Mod):
locls["x"] = Mod(locls["x"], self.prime)
locls["x"] = mod(locls["x"], self.prime)
if not isinstance(locls["y"], Mod):
locls["y"] = Mod(locls["y"], self.prime)
locls["y"] = mod(locls["y"], self.prime)
return Point(AffineCoordinateModel(self.model), x=locls["x"], y=locls["y"])

def affine_add(self, one: Point, other: Point) -> Point:
Expand Down Expand Up @@ -234,9 +234,9 @@ def affine_neutral(self) -> Optional[Point]:
for line in self.model.base_neutral:
exec(compile(line, "", mode="exec"), None, locls) # exec is OK here, skipcq: PYL-W0122
if not isinstance(locls["x"], Mod):
locls["x"] = Mod(locls["x"], self.prime)
locls["x"] = mod(locls["x"], self.prime)
if not isinstance(locls["y"], Mod):
locls["y"] = Mod(locls["y"], self.prime)
locls["y"] = mod(locls["y"], self.prime)
return Point(AffineCoordinateModel(self.model), x=locls["x"], y=locls["y"])

@property
Expand Down Expand Up @@ -314,15 +314,15 @@ def decode_point(self, encoded: bytes) -> Point:
raise ValueError("Encoded point has bad length")
coords = {}
for var in sorted(self.coordinate_model.variables):
coords[var] = Mod(int.from_bytes(data[:coord_len], "big"), self.prime)
coords[var] = mod(int.from_bytes(data[:coord_len], "big"), self.prime)
data = data[coord_len:]
return Point(self.coordinate_model, **coords)
elif encoded[0] in (0x02, 0x03):
if isinstance(self.coordinate_model, AffineCoordinateModel):
data = encoded[1:]
if len(data) != coord_len:
raise ValueError("Encoded point has bad length")
x = Mod(int.from_bytes(data, "big"), self.prime)
x = mod(int.from_bytes(data, "big"), self.prime)
loc = {**self.parameters, "x": x}
rhs = eval(compile(self.model.ysquared, "", mode="eval"), loc) # eval is OK here, skipcq: PYL-W0123
if not rhs.is_residue():
Expand Down
8 changes: 4 additions & 4 deletions pyecsca/ec/divpoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import networkx as nx

from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.mod import Mod
from pyecsca.ec.mod import Mod, mod
from pyecsca.ec.model import ShortWeierstrassModel

has_pari = False
Expand Down Expand Up @@ -90,9 +90,9 @@ def a_invariants(curve: EllipticCurve) -> Tuple[Mod, ...]:
:return: A tuple of 5 a-invariants (a1, a2, a3, a4, a6).
"""
if isinstance(curve.model, ShortWeierstrassModel):
a1 = Mod(0, curve.prime)
a2 = Mod(0, curve.prime)
a3 = Mod(0, curve.prime)
a1 = mod(0, curve.prime)
a2 = mod(0, curve.prime)
a3 = mod(0, curve.prime)
a4 = curve.parameters["a"]
a6 = curve.parameters["b"]
return a1, a2, a3, a4, a6
Expand Down
8 changes: 4 additions & 4 deletions pyecsca/ec/formula/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyecsca.ec.context import ResultAction
from pyecsca.ec import context
from pyecsca.ec.error import UnsatisfiedAssumptionError, raise_unsatisified_assumption
from pyecsca.ec.mod import Mod, SymbolicMod
from pyecsca.ec.mod import Mod, SymbolicMod, mod
from pyecsca.ec.op import CodeOp, OpType
from pyecsca.misc.cfg import getconfig
from pyecsca.misc.cache import sympify
Expand Down Expand Up @@ -191,7 +191,7 @@ def __validate_assumption_simple(self, lhs, rhs, field, params):
domain = FF(field)
numerator, denominator = expr.as_numer_denom()
val = int(domain.from_sympy(numerator) / domain.from_sympy(denominator))
params[lhs] = Mod(val, field)
params[lhs] = mod(val, field)
_assumption_cache[cache_key] = params[lhs]
return True

Expand All @@ -216,7 +216,7 @@ def __validate_assumption_generic(self, lhs, rhs, field, params, assumption_stri
poly = Poly(numerator, symbols(param), domain=domain)
roots = poly.ground_roots()
for root in roots:
params[param] = Mod(int(domain.from_sympy(root)), field)
params[param] = mod(int(domain.from_sympy(root)), field)
return
raise UnsatisfiedAssumptionError(
f"Unsatisfied assumption in the formula ({assumption_string}).\n"
Expand Down Expand Up @@ -274,7 +274,7 @@ def __call__(self, field: int, *points: Any, **params: Mod) -> Tuple[Any, ...]:
f"Bad stuff happened in op {op}, floats will pollute the results."
)
if not isinstance(op_result, Mod):
op_result = Mod(op_result, field)
op_result = mod(op_result, field)
if context.current is not None:
action.add_operation(op, op_result)
params[op.result] = op_result
Expand Down
6 changes: 3 additions & 3 deletions pyecsca/ec/formula/switch_sign.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyecsca.ec.formula.base import Formula
from pyecsca.ec.formula.graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula
from pyecsca.ec.point import Point
from pyecsca.ec.mod import Mod
from pyecsca.ec.mod import Mod, mod


@public
Expand Down Expand Up @@ -59,7 +59,7 @@ def sign_test(output_signs: Dict[str, int], coordinate_model: Any):
out_var = out[: out.index(ind)]
if not out_var.isalpha():
continue
point_dict[out_var] = Mod(sign, p)
point_dict[out_var] = mod(sign, p)
point = Point(coordinate_model, **point_dict)
try:
apoint = point.to_affine()
Expand All @@ -68,7 +68,7 @@ def sign_test(output_signs: Dict[str, int], coordinate_model: Any):
if scale is None:
raise BadSignSwitch
apoint = scale(p, point)[0]
if set(apoint.coords.values()) != {Mod(1, p)}:
if set(apoint.coords.values()) != {mod(1, p)}:
raise BadSignSwitch


Expand Down
44 changes: 23 additions & 21 deletions pyecsca/ec/mod/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
import secrets
from abc import ABC
from functools import lru_cache, wraps

from public import public
Expand Down Expand Up @@ -123,7 +124,7 @@ def __repr__(self):


@public
class Mod:
class Mod(ABC):
"""
An element x of ℤₙ.
Expand All @@ -134,8 +135,8 @@ class Mod:
Has all the usual special methods that upcast integers automatically:
>>> a = Mod(3, 5)
>>> b = Mod(2, 5)
>>> a = mod(3, 5)
>>> b = mod(2, 5)
>>> a + b
0
>>> a * 2
Expand Down Expand Up @@ -163,21 +164,9 @@ class Mod:
n: Any
__slots__ = ("x", "n")

def __new__(cls, *args, **kwargs) -> "Mod":
if cls != Mod:
return cls.__new__(cls, *args, **kwargs)
if not _mod_classes:
raise ValueError("Cannot find any working Mod class.")
selected_class = getconfig().ec.mod_implementation
if selected_class not in _mod_classes:
# Fallback to something
for fallback in _mod_order:
if fallback in _mod_classes:
selected_class = fallback
break
return _mod_classes[selected_class].__new__(
_mod_classes[selected_class], *args, **kwargs
)
def __init__(self, x, n):
self.x = x
self.n = n

@_check
def __add__(self, other) -> "Mod":
Expand Down Expand Up @@ -269,7 +258,7 @@ def random(cls, n: int) -> "Mod":
:return: The random :py:class:`Mod`.
"""
with RandomModAction(n) as action:
return action.exit(cls(secrets.randbelow(n), n))
return action.exit(mod(secrets.randbelow(n), n))

def __pow__(self, n) -> "Mod":
return NotImplemented
Expand All @@ -288,8 +277,7 @@ def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(self):
self.x = None
self.n = None
super().__init__(None, None)

def __add__(self, other):
return NotImplemented
Expand Down Expand Up @@ -359,3 +347,17 @@ def __hash__(self):

def __pow__(self, n):
return NotImplemented


@public
def mod(x, n) -> Mod:
if not _mod_classes:
raise ValueError("Cannot find any working Mod class.")
selected_class = getconfig().ec.mod_implementation
if selected_class not in _mod_classes:
# Fallback to something
for fallback in _mod_order:
if fallback in _mod_classes:
selected_class = fallback
break
return _mod_classes[selected_class](x, n)
3 changes: 0 additions & 3 deletions pyecsca/ec/mod/flint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ class FlintMod(Mod):
_ctx: flint.fmpz_mod_ctx
__slots__ = ("x", "_ctx")

def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(
self,
x: Union[int, flint.fmpz_mod],
Expand Down
3 changes: 0 additions & 3 deletions pyecsca/ec/mod/gmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ class GMPMod(Mod):
n: gmpy2.mpz
__slots__ = ("x", "n")

def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(
self,
x: Union[int, gmpy2.mpz],
Expand Down
5 changes: 1 addition & 4 deletions pyecsca/ec/mod/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class RawMod(Mod):
n: int
__slots__ = ("x", "n")

def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(self, x: int, n: int):
self.x = x % n
self.n = n
Expand Down Expand Up @@ -115,6 +112,6 @@ def __pow__(self, n) -> "RawMod":
return RawMod(pow(self.x, n, self.n), self.n)


from pyecsca.ec.mod.base import _mod_classes
from pyecsca.ec.mod.base import _mod_classes # noqa

_mod_classes["python"] = RawMod
5 changes: 1 addition & 4 deletions pyecsca/ec/mod/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ class SymbolicMod(Mod):
n: int
__slots__ = ("x", "n")

def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(self, x: Expr, n: int):
self.x = x
self.n = n
Expand Down Expand Up @@ -117,6 +114,6 @@ def __pow__(self, n) -> "SymbolicMod":
return self.__class__(pow(self.x, n), self.n)


from pyecsca.ec.mod.base import _mod_classes
from pyecsca.ec.mod.base import _mod_classes # noqa

_mod_classes["symbolic"] = SymbolicMod
14 changes: 9 additions & 5 deletions pyecsca/ec/mod/test.pyx
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import cython
cdef class Test:
def __init__(self):
print("here")

cdef class SubTest(Test):
def __init__(self):
print("sub init")

@cython.cclass
class Test:
def __new__(cls, *args, **kwargs):
pass
cdef class OtherTest(Test):
def __init__(self):
print("other init")
12 changes: 6 additions & 6 deletions pyecsca/ec/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pyecsca.ec.coordinates import AffineCoordinateModel, CoordinateModel
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.error import raise_unsatisified_assumption
from pyecsca.ec.mod import Mod
from pyecsca.ec.mod import Mod, mod
from pyecsca.ec.model import (
CurveModel,
ShortWeierstrassModel,
Expand Down Expand Up @@ -189,7 +189,7 @@ def _create_params(curve, coords, infty):
else:
raise ValueError("Unknown curve model.")
params = {
name: Mod(int(curve["params"][name]["raw"], 16), field) for name in param_names
name: mod(int(curve["params"][name]["raw"], 16), field) for name in param_names
}

# Check coordinate model name and assumptions
Expand Down Expand Up @@ -233,7 +233,7 @@ def _create_params(curve, coords, infty):
poly = Poly(numerator, symbols(param), domain=k)
roots = poly.ground_roots()
for root in roots:
params[param] = Mod(int(k.from_sympy(root)), field)
params[param] = mod(int(k.from_sympy(root)), field)
break
else:
raise_unsatisified_assumption(
Expand All @@ -258,16 +258,16 @@ def _create_params(curve, coords, infty):
)
value = ilocals[coordinate]
if isinstance(value, int):
value = Mod(value, field)
value = mod(value, field)
infinity_coords[coordinate] = value
infinity = Point(coord_model, **infinity_coords)
elliptic_curve = EllipticCurve(model, coord_model, field, infinity, params) # type: ignore[arg-type]
if "generator" not in curve:
raise ValueError("Cannot construct curve, missing generator.")
affine = Point(
AffineCoordinateModel(model),
x=Mod(int(curve["generator"]["x"]["raw"], 16), field),
y=Mod(int(curve["generator"]["y"]["raw"], 16), field),
x=mod(int(curve["generator"]["x"]["raw"], 16), field),
y=mod(int(curve["generator"]["y"]["raw"], 16), field),
)
if not isinstance(coord_model, AffineCoordinateModel):
generator = affine.to_model(coord_model, elliptic_curve)
Expand Down
4 changes: 2 additions & 2 deletions pyecsca/ec/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pyecsca.ec.context import ResultAction
from pyecsca.ec.coordinates import AffineCoordinateModel, CoordinateModel
from pyecsca.ec.mod import Mod, Undefined
from pyecsca.ec.mod import Mod, Undefined, mod
from pyecsca.ec.op import CodeOp


Expand Down Expand Up @@ -131,7 +131,7 @@ def to_model(
for var in coordinate_model.variables:
if var in locls:
result[var] = (
Mod(locls[var], curve.prime)
mod(locls[var], curve.prime)
if not isinstance(locls[var], Mod)
else locls[var]
)
Expand Down
Loading

0 comments on commit 06e005a

Please sign in to comment.