In [1]:
import functools
import itertools
import operator
import os
import pickle

### Retrieve prime factorizations for each $n$, with $1\leq n\leq 10^7$, pre-computed using Pollard's Rho algorithm  

In [2]:
# Cache indexed as prime_factorization[n][p] -> exponent of prime p in n, for 1 <= n <= 10^7.
# NOTE: pickle.load can execute arbitrary code; only load this cache if it was generated by a trusted source.
# os.path.join keeps the relative path portable (hard-coded '\\' separators only resolve on Windows).
with open(os.path.join('..', 'Computation Caches', 'prime_factorizations_1_10000000.pkl'), 'rb') as file:
    prime_factorization = pickle.load(file)

### Define functions for the Extended Euclidean algorithm, the Chinese Remainder Theorem, and the lifting of roots modulo $p^k$ to $p^{k+1}$ by Hensel's Lemma 

In [3]:
# Compute modular inverse of a (mod b) via extended_euclidean_algorithm(a, b)[1] % b (requires gcd(a, b) == 1)
def extended_euclidean_algorithm(a, b):
    """
    Return (g, x, y) such that a * x + b * y == g, where g = gcd(a, b) and g >= 0.

    The remainder sequence b, a, b % a, ... is walked forward while tracking the Bezout
    coefficients of a and b. When the inputs are negative the raw remainder chain can end
    on a negative gcd; all three values are then negated so that g >= 0 and x really is an
    inverse of a modulo b when g == 1 (e.g. a = -1, which occurs as 2 * 0 - 1).
    """
    old_r, r = b, a
    old_s, s = 1, 0  # running coefficients of b
    old_t, t = 0, 1  # running coefficients of a
    while r != 0:
        q, rem = divmod(old_r, r)
        old_r, r = r, rem
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    if old_r < 0:
        # Flip signs so the gcd is non-negative; a * x + b * y == g still holds
        return -old_r, -old_t, -old_s
    return old_r, old_t, old_s

# Pg. 22 of http://www.personal.psu.edu/rcv4/CENT.pdf
def chinese_remainder_theorem(linear_congruences):
    """
    Solve the system x ≡ c_i (mod m_i) for positive, pairwise coprime moduli m_i.

    Uses the constructive formula x = sum(c_i * y_i * M_i) mod M, where M = prod(m_i),
    M_i = M // m_i and y_i is an inverse of M_i modulo m_i (a Bezout coefficient from the
    extended Euclidean algorithm).

    :param linear_congruences: iterable of (c_i, m_i) pairs
    :return: (x, M) with 0 <= x < M; an empty system yields (0, 1)
    :raises ValueError: if the moduli are not positive and pairwise coprime
    """
    # Materialize once: the pairs are traversed twice (product, then summation)
    congruences = list(linear_congruences)
    # Initializer 1 makes the empty system well defined instead of raising in reduce()
    M = functools.reduce(operator.mul, (m_i for _, m_i in congruences), 1)
    result = 0
    for c_i, m_i in congruences:
        M_i = M // m_i
        gcd, _, y_i = extended_euclidean_algorithm(m_i, M_i)
        # m_i * x + M_i * y_i == gcd, so y_i inverts M_i modulo m_i only when gcd == 1
        if gcd != 1:
            raise ValueError('moduli must be positive and pairwise coprime')
        result += c_i * y_i * M_i
    return result % M, M

# Hensel's Lemma (solutions to polynomial(x) ≡ 0 (mod p^k)): https://github.com/p4-team/crypto-commons/blob/master/crypto_commons/rsa/rsa_commons.py
def lift(f, df, p, k, previous):
    """
    Perform one Hensel step: lift roots of f modulo p^(k-1) to roots of f modulo p^k.

    :param f: polynomial function f(x)
    :param df: derivative f'(x)
    :param p: prime number
    :param k: target exponent (k >= 2); each r in `previous` must satisfy f(r) ≡ 0 (mod p^(k-1))
    :param previous: roots of f modulo p^(k-1)
    :return: list of roots modulo p^k, each of the form r + t * p^(k-1) with 0 <= t < p
    """
    step = p ** (k - 1)
    result = []
    for lower_solution in previous:
        dfr = df(lower_solution)
        fr = f(lower_solution)
        if dfr % p != 0:
            # Non-singular root: exactly one lift, t ≡ -(f(r) / p^(k-1)) * f'(r)^(-1) (mod p).
            # pow(dfr, -1, p) is the true inverse even when f'(r) is negative (e.g. f'(0) = -1
            # for f = x^2 - x); a raw Bezout coefficient can carry the wrong sign there.
            t = (-pow(dfr, -1, p) * (fr // step)) % p
            result.append(lower_solution + t * step)
        elif fr % (step * p) == 0:
            # Singular root that is already a root modulo p^k: every r + t * p^(k-1) is a root
            result.extend(lower_solution + t * step for t in range(p))
        # Otherwise: singular root that does not lift, contributes nothing
    return result

def hensel_lifting(f, df, p, k, base_solution: int | list | None = None):
    """
    Calculate solutions to f(x) ≡ 0 (mod p^k) for prime p, where f is a polynomial function in x.

    Roots modulo p (the base solutions) are lifted one power at a time, p^(i-1) -> p^i for
    i = 2..k, via `lift`.

    :param f: function
    :param df: derivative of function
    :param p: prime number
    :param k: power of prime (for k <= 1 the base solutions are returned, deduplicated)
    :param base_solution: solution(s) for k = 1; a single value or a list/tuple/set of values
                          (assumed to be [1, p - 1] if not supplied)
    :return: sorted list of distinct solutions to f(x) ≡ 0 (mod p^k)
    """
    if base_solution is None:
        solution = [1, p - 1]
    elif isinstance(base_solution, (list, tuple, set, frozenset)):
        # Any collection of roots is accepted; copying also avoids aliasing the caller's object
        solution = list(base_solution)
    else:
        solution = [base_solution]
    for i in range(2, k + 1):
        solution = lift(f, df, p, i, solution)
    # Sorted for a deterministic order (set iteration order is an implementation detail)
    return sorted(set(solution))

### In our application of Hensel's Lemma, $f(x)=x^2-x$ and $f'(x)=2x-1$, because we are solving $a^2\equiv a\,(\text{mod }n)$

In [4]:
def h(x):
    """Polynomial f(x) = x^2 - x; its roots modulo n are exactly the a with a^2 ≡ a (mod n)."""
    return x * x - x


def dh(x):
    """Derivative f'(x) = 2x - 1 of h, needed for the Hensel lifting step."""
    return 2 * x - 1

### Compute $M(n)$ for $n=\prod p_i^{\alpha_i}$ by first solving the congruences $\{x_i^2\equiv x_i\,(\text{mod }p_i^{\alpha_i})\}$ individually using Hensel's Lemma, then combining the resulting solutions for all $i$ via the Chinese Remainder Theorem to obtain the candidate values of $M(n)$, and choosing the largest of them in $\mathbb{Z}_n$
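- Note that for each prime factor $p_i$ of $n$, the base solutions are $a\equiv 0,\,1\,(\text{mod }p_i)$. They satisfy $a^2\equiv a$, and since $p_i$ is prime, $a(a-1)\equiv 0\,(\text{mod }p_i)$ has no other roots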
- The algorithm can be sped up further by [noting](https://math.stackexchange.com/a/1661780) that $M(p^k)=1$ for any prime $p$ and $k\in\mathbb{N}$ 

In [7]:
def M(n):
    """
    Return M(n), the largest a with 0 <= a < n and a^2 ≡ a (mod n).

    prime_factorization[n] maps each prime p dividing n to its exponent k. For each p^k the
    roots of h(x) = x^2 - x are found by Hensel lifting from the base roots 0 and 1. Every
    combination of local roots is then merged with the Chinese Remainder Theorem, and the
    largest result is returned.
    """
    # M(1) = 0 because 0 is the only residue modulo 1
    if n == 1:
        return 0
    factors = prime_factorization[n]
    # Modulo a prime power the only idempotents are 0 and 1, so M(p^k) = 1 (no CRT needed)
    if len(factors) == 1:
        return 1
    # One (p^k, roots of h modulo p^k) pair per prime power, built once per n
    local_solutions = [(p ** k, hensel_lifting(h, dh, p, k, [0, 1])) for p, k in factors.items()]
    moduli = [modulus for modulus, _ in local_solutions]
    candidates = (
        chinese_remainder_theorem(list(zip(residues, moduli)))[0]
        for residues in itertools.product(*(roots for _, roots in local_solutions))
    )
    # default=1 is the fallback when some prime power yields no lifted roots
    # (does not happen for h, since 0 and 1 always lift to themselves)
    return max(candidates, default=1)

# Upper limit of the sum; must not exceed the range covered by the prime_factorization cache
N = 10 ** 7
# Report progress every 1% of the range using integer arithmetic
# (a float step like N / 100 never divides n when N is not a multiple of 100);
# max(..., 1) avoids a zero step when N < 100
progress_step = max(N // 100, 1)
s = 0
for n in range(1, N + 1):
    s += M(n)

    # Display progress of function execution for large values of N
    if n % progress_step == 0:
        print('Progress:', n * 100 // N, '%', end='\r', flush=True)

s

Progress: 100.0 %

39782849136421