In [9]:
import math

In [10]:
def windowed_approximate_modular_exponentiation(
        g: int,
        Q_e: int,
        N: int,
        P: list[int],
        f: int,
        ne: int
) -> int:
    """Computes an approximate modular exponentiation.

    Args:
        g: The base of the exponentiation.
        Q_e: The exponent.
        N: The modulus.
        P: Small primes for the residue arithmetic.
        f: Kept bits during approximate accumulation.

    Returns:
        An approximation of pow(g, Q_e, N).

        The modular deviation of the approximation is
        at most 3 * Q_e.bit_length() / 2**f .
    """
    L = math.prod(P)
    print(f"L has {L.bit_length()} bits")
    ell = max(p.bit_length() for p in P)
    m = Q_e.bit_length()
    t = N.bit_length() - f
    assert L == math.lcm(*P) and L >= N ** m and (L % N) < (N >> f)
    we = m//ne
    print(f"bit length is {m}")
    Q_total = 1
    Q_inter_lst = []
    Q_ref_lst = []
    for l in range(ne):
        Q_inter = 0
        for p in P:
            Q_residue = 1
            for k in range(we):
                precomputed = pow(g, 1 << (k + l*we), N) % p

                # controlled inplace modular multiplication:
                if Q_e & (1 << k + l*we):
                    Q_residue *= precomputed
                    Q_residue %= p

            u = (L // p) * pow(L // p, -1, p)
            for k in range(ell):
                precomputed = (((u << k) % L % N) >> t) % (N >> t)

                # controlled inplace modular addition:
                if Q_residue & (1 << k):
                    Q_inter += precomputed
                    Q_inter %= N >> t
        if l == 0:
            Q_ref = Q_e - ((Q_e >> we)<< we)
        else:
            Q_ref = (Q_e >> we)<<we
        Q_ref = pow(g, Q_ref, N)
        print(Q_ref - (Q_inter<<t))
        error = (Q_ref - (Q_inter<<t))
        deviation = min(error % N, -error % N)
        print(f"loop {l} has relative error {deviation/N}")
        Q_inter_lst.append(Q_inter<<t)
        Q_ref_lst.append(Q_ref)
        Q_total *= Q_inter
        Q_total %= N >> t
    return Q_total << t, Q_inter_lst, Q_ref_lst


In [11]:
def approximate_modular_exponentiation(
        g: int,
        Q_e: int,
        N: int,
        P: list[int],
        f: int,
) -> int:
    """Computes an approximate modular exponentiation.

    Args:
        g: The base of the exponentiation.
        Q_e: The exponent.
        N: The modulus.
        P: Small primes for the residue arithmetic.
        f: Kept bits during approximate accumulation.

    Returns:
        An approximation of pow(g, Q_e, N).

        The modular deviation of the approximation is
        at most 3 * Q_e.bit_length() / 2**f .
    """
    L = math.prod(P)
    print(f"L has {L.bit_length()} bits")
    ell = max(p.bit_length() for p in P)
    m = Q_e.bit_length()
    t = N.bit_length() - f
    assert L == math.lcm(*P) and L >= N ** m and (L % N) < (N >> f)

    Q_total = 0
    for p in P:
        Q_residue = 1
        for k in range(Q_e.bit_length()):
            precomputed = pow(g, 1 << k, N) % p

            # controlled inplace modular multiplication:
            if Q_e & (1 << k):
                Q_residue *= precomputed
                Q_residue %= p

        u = (L // p) * pow(L // p, -1, p)

        Q_inter = 0
        for k in range(ell):
            precomputed = (((u << k) % L % N) >> t) % (N >> t)
            # controlled inplace modular addition:
            if Q_residue & (1 << k):
                Q_inter += precomputed
                Q_inter %= N >> t
                # Q_total += precomputed
                # Q_total %= N >> t
        Q_total += Q_inter
        Q_ref = (pow(g, Q_e, p) * u) % L % N
        print((pow(g, Q_e, p) - Q_residue)/Q_residue, "test")
        Q_inter = Q_inter << t
        error = (Q_ref - (Q_inter))
        deviation = min(error % N, -error % N)
        print(f"loop {1} has relative error {deviation/N}")
    return Q_total << t

In [12]:


def test_approximate_modular_exponentiation():
    import random
    import sympy

    # RSA100 challenge number
    N = int(
        "15226050279225333605356183781326374297180681149613"
        "80688657908494580122963258952897654000350692006139"
    )
    g = random.randrange(2, N)
    Q_e = 1
    while (Q_e.bit_length()%2 != 0):
        Q_e = random.randrange(2**100)  # small exponent for quicker test
    print(Q_e.bit_length(), "hi")
    P = [
        *sympy.primerange(239382, 2**18),
        131101,
        131111,
        131113,
        131129,
        131143,
        131149,
        131947,
        182341,
        239333,
        239347,
    ]
    print(f"there are {len(P)} primes")
    f = 24
    ne = 2
    result, Q_inter_lst, Q_ref_lst = windowed_approximate_modular_exponentiation(g, Q_e, N, P, f, ne)
    # print(f"there are {len(P)} primes")
    # result = approximate_modular_exponentiation(g, Q_e, N, P, f)
    error = result - pow(g, Q_e, N)
    deviation = min(error % N, -error % N) / N
    ell = max(p.bit_length() for p in P)
    print(deviation)
    # assert deviation <= 3 * ne * len(P) * ell / 2**f
    return Q_inter_lst, Q_ref_lst,Q_e, g

In [13]:
Q_inter_lst, Q_ref_lst,Q_e, g = test_approximate_modular_exponentiation()

100 hi
there are 1841 primes
L has 33011 bits
bit length is 100
-268798923431390192759053929913641667364162336602478543691904807991984640949471241367899338823566
loop 0 has relative error 0.00017653883870207872
-274322875630384817868136695889517148871761437660305653665878772346962784740412588626457108841940
loop 1 has relative error 0.0001801668000562663
0.10402339177372584


In [14]:
N = int(
        "15226050279225333605356183781326374297180681149613"
        "80688657908494580122963258952897654000350692006139"
    )
print(Q_inter_lst)
print(Q_ref_lst)

[372992825038614360924945030024136292006428058397214651354400919376292359202348810441283882308337664, 1273192854074318725591891926078233604327570684345273884379761948966576333132393770399743476626358272]
[372724026115182970732185976094222650339063896060612172810709014568300374561399339199915982969514098, 1272918531198688340774023789382344087178698922907613578726096070194229370347653357811117019517516332]


In [15]:
print((Q_inter_lst[0]*Q_inter_lst[1] -Q_ref_lst[0]*Q_ref_lst[1])/N)

2.9192048299132614e+95


In [16]:
print((Q_ref_lst[0]*Q_ref_lst[1])%N - pow(g, Q_e, N))
print((Q_inter_lst[0]*Q_inter_lst[1])%N - pow(g, Q_e, N))
print(pow(g, Q_e, N)-(Q_inter_lst[0]*Q_ref_lst[1])%N
      )

print(((Q_ref_lst[0] - Q_inter_lst[0]) * Q_ref_lst[1])/N)

0
588156653704867734532369500377418533534142523234631447175943462571380343590336953085394866271397561
-19529156813001142660180776559716230085016030968857857051690494206545106656780838784849080937002112
-2.247195592601722e+95
