In [64]:
import copy
from math import gcd

In [65]:
def trial_division(n):
    """Return True if *n* is prime, using trial division.

    Candidate divisors are tested while ``d * d <= n`` using exact integer
    arithmetic, so the result stays correct for integers too large to have an
    accurate float square root.

    Args:
        n: integer to test.

    Returns:
        True if n is prime; False otherwise (including every n < 2).
    """
    if n < 2:
        # 0, 1 and negative numbers are not prime.
        return False
    d = 2
    # Integer comparison instead of n ** 0.5: no rounding, no float overflow.
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True

In [66]:
def prime_factors(n):
    """Return the prime factorization of *n* as (prime, exponent) pairs.

    Each prime appears exactly once, in increasing order, paired with its
    full multiplicity, e.g. ``prime_factors(360) == [(2, 3), (3, 2), (5, 1)]``.
    For n <= 1 the list is empty.
    """
    i = 2
    result = []
    while i * i <= n:
        count = 0
        while n % i == 0:
            count += 1
            n //= i
        # Record the prime once, after its full exponent is known.
        if count:
            result.append((i, count))
        i += 1
    if n > 1:
        # Whatever remains is a single prime factor larger than sqrt(original n).
        result.append((n, 1))
    return result

In [67]:
def v(q, t):
    """Return the q-adic valuation of t: the largest e such that q**e divides t.

    Raises:
        ValueError: if t == 0 or q is 1 or -1; the valuation is unbounded
            there and the division loop would never terminate.
    """
    if t == 0 or q in (1, -1):
        raise ValueError(f"valuation undefined for q={q}, t={t}")
    ans = 0
    while t % q == 0:
        ans += 1
        t //= q
    return ans

In [68]:
def e(t):
    """Return ``(e(t), q_list)`` for the APR-CL test.

    e(t) = 2 * prod(q ** (1 + v(q, t))) over all primes q with (q - 1) | t,
    and q_list holds those primes in increasing order.

    Only q = d + 1 for a divisor d of t can qualify, so the divisors of t are
    enumerated in O(sqrt(t)) steps instead of scanning every q up to t + 1.
    """
    divisors = set()
    d = 1
    while d * d <= t:
        if t % d == 0:
            divisors.add(d)
            divisors.add(t // d)
        d += 1

    s = 1
    q_list = []
    # Ascending divisors give q_list in ascending order.
    for d in sorted(divisors):
        q = d + 1
        if trial_division(q):
            s *= q ** (1 + v(q, t))
            q_list.append(q)
    return 2 * s, q_list

In [69]:
class JacobiSum:
    """Element of Z[zeta] for zeta a primitive p**k-th root of unity.

    The element is stored in the power basis 1, zeta, ..., zeta**(m-1), where
    m = (p-1) * p**(k-1). Any exponent e with m <= e < p**k is rewritten via
    zeta**e = -(zeta**(e - s) + zeta**(e - 2s) + ... + zeta**(e - (p-1)s))
    with s = p**(k-1), which is how products and conjugates are reduced below.

    Attributes:
        p, k: zeta has order p**k.
        q: auxiliary prime carried along to new instances (not used here).
        m: number of basis coefficients, (p-1) * p**(k-1).
        pk: p**k.
        coef: list of m integer coefficients; coef[i] multiplies zeta**i.
    """

    def __init__(self, p, k, q):
        self.p = p
        self.k = k
        self.q = q
        self.m = (p - 1) * p ** (k - 1)
        self.pk = p ** k
        self.coef = [0] * self.m

    def one(self):
        """Set this element to 1 in place and return self."""
        self.coef[0] = 1
        for i in range(1, self.m):
            self.coef[i] = 0
        return self

    def mul(self, jac):
        """Return the ring product self * jac as a new, unreduced JacobiSum."""
        m = self.m
        pk = self.pk
        j_ret = JacobiSum(self.p, self.k, self.q)
        for i in range(m):
            for j in range(m):
                if (i + j) % pk < m:
                    j_ret.coef[(i + j) % pk] += self.coef[i] * jac.coef[j]
                else:
                    # Exponent outside the basis: subtract at every index
                    # e - s, e - 2s, ... >= 0 (cyclotomic relation).
                    r = (i + j) % pk - self.p ** (self.k - 1)
                    while r >= 0:
                        j_ret.coef[r] -= self.coef[i] * jac.coef[j]
                        r -= self.p ** (self.k - 1)
        return j_ret

    def __mul__(self, right):
        """Scale by an int, or multiply by another JacobiSum; returns a new object."""
        if isinstance(right, int):
            j_ret = JacobiSum(self.p, self.k, self.q)
            for i in range(self.m):
                j_ret.coef[i] = self.coef[i] * right
            return j_ret
        return self.mul(right)

    def modpow(self, x, n):
        """Return self**x with coefficients reduced mod n (square-and-multiply).

        self is not modified (a deep copy is squared).
        """
        j_ret = JacobiSum(self.p, self.k, self.q)
        j_ret.coef[0] = 1
        j_a = copy.deepcopy(self)
        while x > 0:
            if x % 2 == 1:
                j_ret = (j_ret * j_a).mod(n)
            j_a = j_a * j_a
            j_a.mod(n)
            x //= 2
        return j_ret

    def mod(self, n):
        """Reduce every coefficient mod n in place and return self."""
        for i in range(self.m):
            self.coef[i] %= n
        return self

    def sigma(self, x):
        """Return the image of self under zeta -> zeta**x as a new JacobiSum."""
        m = self.m
        pk = self.pk
        j_ret = JacobiSum(self.p, self.k, self.q)
        for i in range(m):
            if (i * x) % pk < m:
                j_ret.coef[(i * x) % pk] += self.coef[i]
            else:
                r = (i * x) % pk - self.p ** (self.k - 1)
                while r >= 0:
                    j_ret.coef[r] -= self.coef[i]
                    r -= self.p ** (self.k - 1)
        return j_ret

    def sigma_inv(self, x):
        """Return the image of self under the inverse of zeta -> zeta**x.

        Output exponent i receives the coefficient of zeta**(i*x mod p**k).
        x should be coprime to p so that this index map is a bijection.
        """
        m = self.m
        pk = self.pk
        j_ret = JacobiSum(self.p, self.k, self.q)
        for i in range(pk):
            if i < m:
                if (i * x) % pk < m:
                    j_ret.coef[i] += self.coef[(i * x) % pk]
            else:
                r = i - self.p ** (self.k - 1)
                while r >= 0:
                    if (i * x) % pk < m:
                        j_ret.coef[r] -= self.coef[(i * x) % pk]
                    r -= self.p ** (self.k - 1)
        return j_ret

    def is_root_of_unity(self, N):
        """Test whether self is congruent mod N to some power zeta**h.

        The 1/0 checks are exact comparisons, so coefficients should already
        be reduced into [0, N) (e.g. via ``mod(N)``). In this basis zeta**h
        with h < m is a single coefficient 1 with all others 0; for h >= m it
        is -1 at every index congruent to h mod p**(k-1) and 0 elsewhere.

        Returns:
            (True, h) with 0 <= h < p**k if self == zeta**h (mod N),
            otherwise (False, None).
        """
        m = self.m
        p = self.p
        k = self.k
        one = 0
        minus_one = 0
        for i in range(m):
            if self.coef[i] == 1:
                one += 1
                h = i
            elif self.coef[i] == 0:
                continue
            elif (self.coef[i] - (-1)) % N != 0:
                return False, None
            else:
                minus_one += 1
        # A lone 1 is zeta**h only when every other coefficient is 0; a mix of
        # +1 and -1 coefficients (e.g. 1 - zeta) is not a root of unity.
        if one == 1 and minus_one == 0:
            return True, h
        for i in range(m):
            if self.coef[i] != 0:
                break
        r = i % (p ** (k - 1))
        for i in range(m):
            if i % (p ** (k - 1)) == r:
                if (self.coef[i] - (-1)) % N != 0:
                    return False, None
            else:
                if self.coef[i] != 0:
                    return False, None
        return True, (p - 1) * p ** (k - 1) + r

In [70]:
def smallest_primitive_root(q):
    """Return the smallest primitive root modulo the prime q, or None.

    r is a primitive root when its multiplicative order mod q is q - 1. The
    order is found by repeated multiplication, stopping as soon as the power
    returns to 1 (or after q - 1 steps), instead of always collecting all
    q - 1 powers. For q == 2 the only unit is 1, which is returned. None is
    returned when no primitive root of order q - 1 exists (e.g. composite q).
    """
    if q == 2:
        return 1
    for r in range(2, q):
        power = r % q
        order = 1
        while power != 1 and order < q - 1:
            power = power * r % q
            order += 1
        if power == 1 and order == q - 1:
            return r
    return None

In [71]:
def calc_f(q):
    """Return the map f with 1 - g**x == g**f(x) (mod q) for x in 1..q-2.

    g is the smallest primitive root of q (q is expected to be prime). The
    result is a dict keyed by x in increasing order.
    """
    g = smallest_primitive_root(q)
    exponents = range(1, q - 1)
    # Discrete-log table for g**1 .. g**(q-2); 1 - g**x never equals 0 or 1.
    dlog = {pow(g, x, q): x for x in exponents}
    return {x: dlog[(1 - pow(g, x, q)) % q] for x in exponents}

In [72]:
def calc_J_ab(p, k, q, a, b):
    """Return sum over x = 1..q-2 of zeta**(a*x + b*f(x)) as a JacobiSum.

    zeta is a primitive p**k-th root of unity and f comes from calc_f(q).
    Exponents that fall outside the stored basis are rewritten with the
    cyclotomic relation (subtracting at e - s, e - 2s, ... >= 0,
    s = p**(k-1)).
    """
    result = JacobiSum(p, k, q)
    f = calc_f(q)
    pk = p ** k
    step = p ** (k - 1)
    for x in range(1, q - 1):
        exponent = (a * x + b * f[x]) % pk
        if exponent < result.m:
            result.coef[exponent] += 1
            continue
        for idx in range(exponent - step, -1, -step):
            result.coef[idx] -= 1
    return result


def calc_J(p, k, q):
    """Return the Jacobi sum J_{1,1}(q) over zeta_{p**k} (see calc_J_ab)."""
    a = b = 1
    return calc_J_ab(p, k, q, a, b)
       

def calc_J3(p, k, q):
    """Return J_{1,1}(q) * J_{2,1}(q) over zeta_{p**k}.

    This product is the J used by the p = 2, k >= 3 branch of Step 4.
    """
    return calc_J_ab(p, k, q, 1, 1) * calc_J_ab(p, k, q, 2, 1)

def calc_J2(p, k, q):
    """Return the square of J_{3,1}(q), computed over zeta_8 and lifted to zeta_{p**k}.

    The lift maps zeta_8**i to zeta_{p**k}**(i * p**k // 8). It is called from
    the p = 2, k >= 3 branch of Step 4, where p**k is a multiple of 8.
    """
    j31 = calc_J_ab(2, 3, q, 3, 1)
    lifted = JacobiSum(p, k, q)
    pk = p ** k
    for i, c in enumerate(j31.coef):
        lifted.coef[i * pk // 8] = c
    return lifted * lifted

In [73]:
def APRtest_step4a(p, k, q, N):
    """Step 4a of the APR-CL test, for an odd prime p.

    With J = J(p, q), r = N mod p**k and x running over [0, p**k) coprime to p:
        s1      = prod sigma_x^{-1}(J) ** x
        J_alpha = prod sigma_x^{-1}(J) ** floor(r*x / p**k)
        S       = s1 ** floor(N / p**k) * J_alpha   (all mod N)

    Returns:
        (False, None) if S is not a power of zeta mod N (the caller treats
        this as N composite); otherwise (True, l_p) where l_p is 1 when the
        exponent h of S is not divisible by p, else 0.
    """
    pk = p ** k
    J = calc_J(p, k, q)
    r = N % pk
    s1 = JacobiSum(p, k, q).one()
    J_alpha = JacobiSum(p, k, q).one()
    for x in range(pk):
        if x % p == 0:
            continue
        # The conjugate is shared by both products.
        conj = J.sigma_inv(x)
        s1 = (s1 * conj.modpow(x, N)).mod(N)
        J_alpha = (J_alpha * conj.modpow((r * x) // pk, N)).mod(N)

    s2 = s1.modpow(N // pk, N)
    S = (s2 * J_alpha).mod(N)
    exist, h = S.is_root_of_unity(N)
    if not exist:
        return False, None
    return True, (1 if h % p != 0 else 0)

In [74]:
def APRtest_step4b(p, k, q, N):
    """Step 4b of the APR-CL test, for p = 2 and k >= 3.

    Same shape as Step 4a but with J = calc_J3(p, k, q), x restricted to
    x % 8 in {1, 3}, and an extra factor calc_J2(p, k, q) whenever
    N % 8 is not 1 or 3.

    Returns:
        (False, None) if S is not a power of zeta mod N; otherwise
        (True, l_p) where l_p is 1 when h % p != 0 and
        q**((N-1)//2) == -1 (mod N), else 0.
    """
    pk = p ** k
    J = calc_J3(p, k, q)
    r = N % pk
    s1 = JacobiSum(p, k, q).one()
    J_alpha = JacobiSum(p, k, q).one()
    for x in range(pk):
        if x % 8 not in (1, 3):
            continue
        conj = J.sigma_inv(x)
        s1 = (s1 * conj.modpow(x, N)).mod(N)
        J_alpha = (J_alpha * conj.modpow((r * x) // pk, N)).mod(N)

    s2 = s1.modpow(N // pk, N)
    if N % 8 in (1, 3):
        S = (s2 * J_alpha).mod(N)
    else:
        S = (s2 * J_alpha * calc_J2(p, k, q)).mod(N)

    exist, h = S.is_root_of_unity(N)
    if not exist:
        return False, None
    l_p = 1 if h % p != 0 and (pow(q, (N - 1) // 2, N) + 1) % N == 0 else 0
    return True, l_p

In [75]:
def APRtest_step4c(p, k, q, N):
    """Step 4c of the APR-CL test, for p = 2 and k = 2.

    With J = J(2, q): s1 = q * J**2, s2 = s1 ** (N // 4), and
    S = s2 if N % 4 == 1, else S = s2 * J**2 (all mod N).

    Returns:
        (False, None) if S is not a power of zeta mod N; otherwise
        (True, l_p) where l_p is 1 when h % p != 0 and
        q**((N-1)//2) == -1 (mod N), else 0.

    Raises:
        ValueError: if N is even (N % 4 must be 1 or 3).
    """
    if N % 2 == 0:
        raise ValueError(f"APRtest_step4c requires odd N, got {N}")
    J2q = calc_J(p, k, q)
    # J**2 is needed twice; compute it once.
    J2q_sq = J2q * J2q
    s1 = (J2q_sq * q).mod(N)
    s2 = s1.modpow(N // 4, N)

    if N % 4 == 1:
        S = s2
    else:  # N % 4 == 3
        S = (s2 * J2q_sq).mod(N)

    exist, h = S.is_root_of_unity(N)
    if not exist:
        return False, None
    if h % p != 0 and (pow(q, (N - 1) // 2, N) + 1) % N == 0:
        l_p = 1
    else:
        l_p = 0
    return True, l_p

In [76]:
def APRtest_step4d(p, k, q, N):
    """Step 4d of the APR-CL test, for p = 2 and k = 1.

    Computes (-q) ** ((N-1)//2) mod N, which must be congruent to +1 or -1.

    Returns:
        (False, None) if it is neither; otherwise (True, l_p) where l_p is 1
        when the value is -1 and N == 1 (mod 4), else 0.
    """
    power = pow(-q, (N - 1) // 2, N)
    is_plus_one = (power - 1) % N == 0
    is_minus_one = (power + 1) % N == 0
    if not (is_plus_one or is_minus_one):
        return False, None
    return True, (1 if is_minus_one and (N - 1) % 4 == 0 else 0)

In [77]:
def APRtest_step4(p, k, q, N):
    """Dispatch Step 4 of the APR-CL test to the variant for (p, k).

    p >= 3 -> step 4a; p == 2 with k >= 3 -> 4b, k == 2 -> 4c, k == 1 -> 4d.

    Returns:
        (result, l_p) from the selected variant: result is False when the
        check fails; l_p is 1 when the check also establishes l_p.

    Raises:
        ValueError: for (p, k) with no variant (p < 2, or p == 2 with k < 1).
    """
    if p >= 3:
        return APRtest_step4a(p, k, q, N)
    if p == 2:
        if k >= 3:
            return APRtest_step4b(p, k, q, N)
        if k == 2:
            return APRtest_step4c(p, k, q, N)
        if k == 1:
            return APRtest_step4d(p, k, q, N)
    raise ValueError(f"no Step 4 variant for p={p}, k={k}")

In [78]:
def APRtest(N):
    """Return True if N is prime, using the APR-CL (Jacobi sum) test.

    Outline: choose the smallest t in the table with e(t)**2 > N (Step 0);
    require gcd(t*e(t), N) == 1 (Step 1); initialise the l_p flags (Step 2);
    run Step 4 for every q in e(t)'s prime list and every prime p | q-1 with
    k = v(p, q-1) (Step 3/4); try to settle any remaining l_p with up to 30
    extra primes q = p*i + 1 (Step 5); finally reject N if some N**i mod e(t),
    1 <= i <= t-1, is a non-trivial divisor of N (Step 6).

    Args:
        N: integer to test.

    Returns:
        True if N is proven prime. False if N < 2, if N is shown composite,
        or if Step 5 finds no suitable auxiliary prime within 30 candidates.

    Raises:
        ValueError: if N >= e(t)**2 for the largest t in the table, so the
            test cannot be applied (rather than wrongly reporting composite).
    """
    t_list = [2, 12, 60, 180, 840, 1260, 1680, 2520, 5040, 15120, 55440,
              110880, 720720, 1441440, 4324320, 24504480, 73513440]
    if N < 2:
        return False
    if N == 2 or N == 3:
        return True
    for t in t_list:
        et, q_list = e(t)
        if N < et * et:
            break
    else:
        raise ValueError(
            f"N is too large for the built-in t table (N >= e({t_list[-1]})**2)")

    # Step 1: N must be coprime to t * e(t).
    if gcd(t * et, N) != 1:
        return False

    # Step 2: for odd p | t, l_p already holds when N**(p-1) != 1 (mod p**2).
    l_flags = {}
    for p in dict.fromkeys(f for f, _ in prime_factors(t)):
        l_flags[p] = 1 if p >= 3 and pow(N, p - 1, p * p) != 1 else 0

    # Step 3 & Step 4: one check per (p, q), with p**k exactly dividing q - 1.
    for q in q_list:
        if q == 2:
            continue
        for p in dict.fromkeys(f for f, _ in prime_factors(q - 1)):
            result, l_p = APRtest_step4(p, v(p, q - 1), q, N)
            if not result:
                return False
            if l_p == 1:
                l_flags[p] = 1

    # Step 5: settle each l_p still 0 with extra primes q = p*i + 1.
    q_set = set(q_list)
    for p, value in l_flags.items():
        if value != 0:
            continue
        count = 0
        i = 1
        found = False
        while count < 30:
            q = p * i + 1
            if N % q != 0 and trial_division(q) and q not in q_set:
                count += 1
                result, l_p = APRtest_step4(p, v(p, q - 1), q, N)
                if not result:
                    return False
                if l_p == 1:
                    found = True
                    break
            i += 1
        if not found:
            return False

    # Step 6: no residue N**i mod e(t), 1 <= i <= t-1, may divide N non-trivially.
    r = 1
    for _ in range(t - 1):
        r = (r * N) % et
        if r != 1 and r != N and N % r == 0:
            return False
    return True