# Summary

We present an algorithm that takes as input an arbitrarily long sequence of positive integers $a_1,a_2,\ldots,a_\ell$ and a positive integer $m$ and computes
$$a_1^{a_2^{\cdot^{\cdot^{a_\ell}}}}\bmod m$$
efficiently (that is, without computing the value of the nested exponent).

# Notation

For convenience, we define an operator $E$ as a shorthand for nested exponentiation.

---

> **Definition.** Given a tuple of $\ell$ positive integers $(a_1,a_2,\ldots,a_\ell)$, define the operator $E$ recursively as follows.
$$E(a_1,a_2,\ldots,a_\ell)=\begin{cases}1&\ell=0\\a_1^{E(a_2,\ldots,a_\ell)}&\ell\gt0\end{cases}$$
We call $a_1$ the **base** and $E(a_2,\ldots,a_\ell)$ the **exponent** of $E(a_1,a_2,\ldots,a_\ell)$.

---

We are interested in computing $E(a_1,a_2,\ldots,a_\ell)\bmod m$.
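For tiny inputs, $E$ can be evaluated directly from the definition, which is handy as a reference when checking a faster method. The following is a minimal sketch. `naive_E` is a hypothetical helper, not part of the `mod-nest-exp` package, and it is only usable while the tower stays small, which is exactly the situation the algorithm below is designed to avoid.

```python
from functools import reduce

def naive_E(seq):
    """Evaluate E(a_1, ..., a_l) exactly, folding from the innermost exponent outwards.

    Hypothetical reference helper: only feasible while the tower stays small.
    """
    # E() == 1 and E(a_1, ..., a_l) == a_1 ** E(a_2, ..., a_l)
    return reduce(lambda exponent, base: base ** exponent, reversed(seq), 1)

# Exponentiation associates to the right: 2^(3^2) = 2^9 = 512, not (2^3)^2 = 64.
print(naive_E((2, 3, 2)))      # 512
print(naive_E((2, 3, 2)) % 7)  # 1
print(naive_E(()))             # 1
```

Already $E(2,2,2,2,2)=2^{65536}$ has 19,729 decimal digits, so direct evaluation is only useful as a test oracle.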

# Preliminaries

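We first check the Python version, then install the system libraries that `gmpy2` builds against (GMP, MPFR and MPC) and the Python packages used below.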

In [1]:
!python -V

Python 3.9.1+


In [2]:
!apt install libgmp-dev libmpfr-dev libmpc-dev &> /dev/null # for gmpy2

In [3]:
!pip install -U gmpy2 sympy mod-nest-exp

Requirement already up-to-date: gmpy2 in ./env/lib/python3.9/site-packages (2.0.8)
Requirement already up-to-date: modular-towers in ./env/lib/python3.9/site-packages (0.1.4)


# The algorithm

## `pow_lt`

`pow_lt` takes as input a sequence of positive integers $e_1,e_2,\ldots,e_\ell$ and a positive number $k$ and returns `True` iff $E(e_1,e_2,\ldots,e_\ell)\lt k$.
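The two shortcuts inside `pow_lt` follow from monotonicity. A short derivation, assuming $e_1\ge2$ and $\ell\ge2$ (the other cases are answered directly), and writing $\operatorname{bl}(x)$ for Python's `x.bit_length()`, so that $2^{\operatorname{bl}(x)-1}\le x<2^{\operatorname{bl}(x)}$ for every integer $x\ge1$. Since all entries are positive integers, $E(e_2,\ldots,e_\ell)\ge e_2$, hence
$$E(e_1,\ldots,e_\ell)=e_1^{E(e_2,\ldots,e_\ell)}\ge e_1^{e_2}\ge2^{e_2(\operatorname{bl}(e_1)-1)}.$$
So if $e_2(\operatorname{bl}(e_1)-1)\ge\operatorname{bl}(\lceil k\rceil)$, then $E(e_1,\ldots,e_\ell)\ge2^{\operatorname{bl}(\lceil k\rceil)}>\lceil k\rceil\ge k$ and the answer is `False`. Otherwise, since $x\mapsto e_1^x$ is strictly increasing,
$$E(e_1,\ldots,e_\ell)<k\iff E(e_2,\ldots,e_\ell)<\log_{e_1}k=\frac{\ln k}{\ln e_1},$$
and because $E(e_2,\ldots,e_\ell)\ge1$, the answer is `False` whenever $\log_{e_1}k\le1$.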

In [4]:
from decimal import Decimal
from math import ceil

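def pow_lt(seq, k):
    """Return True iff E(seq) < k, without evaluating E(seq)."""
    if not len(seq): # if len(seq) == 0, then E() == 1
        return 1 < k

    def _pow_lt(seq, k):
        # E(seq) == seq[0] when there is no exponent, and 1**x == 1
        if len(seq) == 1 or seq[0] == 1:
            return seq[0] < k
        # seq[0] >= 2 here, so E(seq) >= seq[0]**seq[1] >= 2**(seq[1]*(seq[0].bit_length()-1));
        # once that exponent reaches ceil(k).bit_length(), E(seq) > k
        if seq[1]*(seq[0].bit_length()-1) >= ceil(k).bit_length():
            return False
        # E(seq) < k iff E(seq[1:]) < log_{seq[0]}(k), and E(seq[1:]) >= 1
        l = Decimal(k).ln()/Decimal(seq[0]).ln() # 28 significant digits in the default Decimal context
        return _pow_lt(seq[1:], l) if l > 1 else False

    return _pow_lt(seq, k)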
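Since the main function below only calls `pow_lt` with an integer bound $k$, the logarithm could also be replaced by exact integer arithmetic: for integer $x$ and $e_1\ge2$, $e_1^x<k$ holds exactly when $x<t$, where $t$ is the smallest non-negative integer with $e_1^t\ge k$. The sketch below is our own variant of that idea, not the package's code (`pow_lt_int` is a hypothetical name); it sidesteps the finite precision of `Decimal` at the cost of a short multiplication loop.

```python
def pow_lt_int(seq, k):
    """Return True iff E(seq) < k, for positive integers seq and an integer k.

    Hypothetical integer-only variant of pow_lt (not part of mod-nest-exp):
    for b >= 2 and integer x, b**x < k holds exactly when x < t, where t is
    the smallest non-negative integer with b**t >= k. No logarithms needed.
    """
    if not seq:                    # E() == 1
        return 1 < k
    b = seq[0]
    if len(seq) == 1 or b == 1:    # E(seq) == b here (1**x == 1)
        return b < k
    if k <= 1:                     # E(seq) >= 1 >= k
        return False
    # b >= 2 and k >= 2: find t = ceil(log_b(k)) by repeated multiplication.
    t, p = 0, 1
    while p < k:
        p *= b
        t += 1
    # The exponent E(seq[1:]) must stay below t, roughly log_b(k).
    return pow_lt_int(seq[1:], t)


print(pow_lt_int([2, 3, 2], 513))          # True  (E(2, 3, 2) = 512)
print(pow_lt_int([2, 3, 2], 512))          # False
print(pow_lt_int([3, 10**100, 7], 10**9))  # False; 10**100 is only compared, never used as an exponent
```

Each level replaces the bound $k$ by roughly $\log_{e_1}k$, so the recursion ends after a few levels, and the loop never builds a number larger than $e_1k$.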

## `pow_list`

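`pow_list` takes as input a sequence of positive integers $e_1,e_2,\ldots,e_\ell$ and returns the exact value of $E(e_1,e_2,\ldots,e_\ell)$. This is only practical when that value is small, which is why the main function calls it only after `pow_lt` has confirmed that the tower lies below a small bound.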

In [5]:
def pow_list(seq):
    l = len(seq)
    if not l: # if len(seq) == 0
        return 1
    elif l == 1: # if len(seq) == 1
        return seq[0]

    def _pow_list(seq):
        if seq[0] == 1: # 1**x == 1, so the exponent is never evaluated
            return 1
        if len(seq) == 2:
            return seq[0]**seq[1]
        return seq[0]**_pow_list(seq[1:])

    return _pow_list(seq)

## The main function

`mod_nest_exp` takes as input a sequence of positive integers $a_1,a_2,\ldots,a_\ell$ and a positive integer $m$ and returns $E(a_1,a_2,\ldots,a_\ell)\bmod m$.
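The recursion in `mod_nest_exp` rests on Euler's theorem and the Chinese remainder theorem. A sketch of the reasoning, with our own shorthand $b=a_1$, $t=E(a_2,\ldots,a_\ell)$ and $g=\gcd(b,m)$ (the names $n$, $h$, $k$, $x$, $y$ match the variables in the code, while $t$, $r_n$, $r_h$ are introduced only for this explanation).

If $g=1$, then $b^{\varphi(m)}\equiv1\pmod m$, so
$$b^t\equiv b^{\,t\bmod\varphi(m)}\pmod m,$$
and $t\bmod\varphi(m)$ is the same problem one level down, with modulus $\varphi(m)$.

If $g>1$, the loop writes $m=hn$, where $n$ is what remains of $m$ after removing every prime factor of $g$ and $k$ counts the divisions, so that $\gcd(b,n)=\gcd(n,h)=1$ and $h\mid g^k\mid b^k$. Then
$$r_n=b^{\,t\bmod\varphi(n)}\bmod n,\qquad r_h=\begin{cases}b^t\bmod h&\text{if }t<k,\\0&\text{if }t\ge k,\end{cases}$$
and with $nx+hy=1$ from `gcdext`,
$$b^t\equiv h\,(y\bmod n)\,r_n+n\,(x\bmod h)\,r_h\pmod m.$$
Only the case $t<k$ needs the exact value of $t$, and there $t$ is below the small bound $k$, which `pow_lt` checks before `pow_list` is called.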

In [6]:
from gmpy2 import gcd, powmod, gcdext as ext_gcd
from sympy.ntheory import totient

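def mod_nest_exp(seq, m):
    if m == 1: # 1 divides every integer
        return 0
    l = len(seq)
    if not l: # if len(seq) == 0
        return 1%m
    elif l == 1: # if len(seq) == 1
        return seq[0]%m

    def _mod_nest_exp(seq, m):
        if m == 1: # 1 divides every integer
            return 0
        if len(seq) == 2: # recursive base case
            return powmod(seq[0], seq[1], m)

        b, e = seq[0], seq[1:] # base and exponent
        g = gcd(b, m)
        if g == 1:
            # Euler's theorem: b**t ≡ b**(t mod phi(m)) (mod m) when gcd(b, m) == 1
            return powmod(b, _mod_nest_exp(e, totient(m)), m)

        # Split m = h*n: n is m with every prime factor of g removed (so gcd(b, n) == 1),
        # h = m//n, and k counts the divisions, so that h divides g**k
        n, k = m//g, 1
        g_ = gcd(g, n)
        while g_ > 1:
            n //= g_
            k += 1
            g_ = gcd(g, n)
        h = m//n
        # Bezout coefficients n*x + h*y == 1 for the CRT recombination
        _, x, y = ext_gcd(n, h)
        # mod n: Euler's theorem again; mod h: b**t ≡ 0 once t >= k, since h | g**k | b**k
        return (h*(y%n)*powmod(b, _mod_nest_exp(e, totient(n)), n)+
                n*(x%h)*(powmod(b, pow_list(e), h) if pow_lt(e, k) else 0))%m

    return _mod_nest_exp(seq, m)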
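To see the non-coprime branch on concrete numbers, the sketch below re-creates its two steps (splitting $m$ and recombining via the Chinese remainder theorem) with the standard library only. `split_modulus` and `crt_pair` are hypothetical helper names; `crt_pair` uses Python's built-in modular inverse `pow(a, -1, mod)` (Python 3.8 or newer) instead of `gcdext`, and $\varphi(5)=4$ is hard-coded for the example. With $b=6$, $m=360$ and exponent $t=E(3,2)=9$, the split gives $n=5$, $h=72$, $k=3$.

```python
from math import gcd

def split_modulus(b, m):
    """Split m = h * n as in the non-coprime branch of _mod_nest_exp.

    Hypothetical helper. Returns (n, h, k) with gcd(n, b) == 1, gcd(n, h) == 1
    and h dividing gcd(b, m) ** k.
    """
    g = gcd(b, m)
    n, k = m // g, 1
    g_ = gcd(g, n)
    while g_ > 1:              # strip the remaining factors shared with g
        n //= g_
        k += 1
        g_ = gcd(g, n)
    return n, m // n, k

def crt_pair(r_n, n, r_h, h):
    """Combine r_n (mod n) and r_h (mod h), gcd(n, h) == 1, into one residue mod n*h."""
    if n == 1:
        return r_h % h
    if h == 1:
        return r_n % n
    # h * h^{-1} ≡ 1 (mod n) and ≡ 0 (mod h); symmetrically for n * n^{-1}
    return (h * pow(h, -1, n) * r_n + n * pow(n, -1, h) * r_h) % (n * h)


b, m, t = 6, 360, 3 ** 2            # t plays the role of E(a_2, ..., a_l) = E(3, 2)
n, h, k = split_modulus(b, m)       # (5, 72, 3)
r_n = pow(b, t % 4, n)              # Euler: gcd(6, 5) == 1 and phi(5) == 4
r_h = pow(b, t, h) if t < k else 0  # h divides 6**3, so 6**t ≡ 0 (mod h) once t >= 3
print(crt_pair(r_n, n, r_h, h), pow(b, t, m))  # 216 216
```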

In [7]:
from mod_nest_exp.core.tests import test_core

test_core(
    list_lengths=(10, 100, 1000),
    bit_lengths=(16, 128, 1024),
    mod_bit_lengths=(16, 32, 64)
)

Running mod_tower on sequences of l pseudorandom b-bit positive integers over a B-bit modulus (1000 runs per table entry)
                            sequence length l                   
                  10               100               1000      
          ----------------- ----------------- -----------------
  B     b     mean    stdev     mean    stdev     mean    stdev
-----------------------------------------------------------------
       16 |   0.25     0.16     0.24     0.15     0.31     0.17  
 16   128 |   0.19     0.11     0.21     0.12     0.33     0.21  
     1024 |   0.24     0.16     0.27     0.17     0.39     0.38  
-----------------------------------------------------------------
       16 |   0.75     0.59     0.72     1.02     0.85     0.62  
 32   128 |   0.75     0.61     0.78     0.67     0.83     0.56  
     1024 |   0.80     0.70     0.90     0.95     0.93     0.61  
-----------------------------------------------------------------
       16 |  19.25    82.73