# Muller's square root algorithm for q=1 (mod 4)

In [1]:
import timeit

In [25]:
#Find prime p which is 100 bits
#In order to find such p, I use Miller-Rabin primality test.
#In Muller's algorithm, we use p as 1 mod 4, so I add this condition in the end.

# Python3 program Miller-Rabin primality test
import random


# This function is called for all k trials. 
# It returns false if n is composite and returns false if n is probably prime.
def miillerTest(d, n):

    a = 2 + random.randint(1, n - 4)

    x = pow(a, d, n)

    if (x == 1 or x == n - 1):
        return True

    while (d != n - 1):
        x = (x * x) % n
        d *= 2

        if (x == 1):
            return False
        if (x == n - 1):
            return True

    return False;

# It returns false if n is composite and returns true if n is probably prime. 
# k is an input parameter that determines accuracy level. Higher value of k indicates more accuracy.
def isPrime( n, k):

    if (n <= 1 or n == 4):
        return False
    if (n <= 3):
        return True

    d = n - 1
    while (d % 2 == 0):
        d //= 2
    
    for i in range(k):
        if (miillerTest(d, n) == False):
            return False

    return True


# Driver Code
k = 4 # Number of iterations

while 1:
    p = random.randint(pow(2,100),pow(2,101))
    if (isPrime(p, k) and (p%4==1)):
        print("2^100 bit random prime number p =",p)
        break

2^100 bit random prime number p = 1791611784636012329615370947573


In [26]:
print(p%4)
n=625
print("Our sample p and n, respectively :",p,n)

1
Our sample p and n, respectively : 1791611784636012329615370947573 625


In [27]:
#Takes quadratic residue n and odd prime p=1 (mod 4)
#Returns both square roots of n modulo p as a pair (a,b)
#Returns () if no root

def muller(n,p):
    n %= p
    
    #Step1
    if(n==4):
        return (2,-2)
    if(n == 0 or n == 1):
        return (n,-n%p)
    
    #Step2
    phi=p-1
    temp = pow(n-4, int(phi//2), p)
    if (temp==-1) :
        t=1
    elif(temp==1) :
        t=1
        while (pow(n*t*t-4, int(phi//2), p)==1) :
            t = (t+1)%p #it stops when t is quadratic-nonresidue
    
    #Step3
    largep = (n*t*t-2)%p
    
    #Step4
    v = largep%p
    w = (largep*largep-2)%p
    m = bin(int(phi//4))
    l = len(m)
    for i in range(3,l-1) :
        if(m[i]=='1'):
            v = (v*w - largep)%p
            w = (w*w-2)%p
        else:
            w = (v*w-largep)%p
            v = (v*v-2)%p
            
    #Step5
    w1 = (v*w-largep)%p
    w2 = (v*v-2)%p
    if(m[l-1]=='1'):
        return (int(w1//t)%p, (-int(w1//t))%p)
    else:
        return (int(w2//t)%p, (-int(w2//t))%p)
    

print ("Roots of", n, "mod", p,":" + str(muller(n,p)))


Roots of 625 mod 1791611784636012329615370947573 :(25, 1791611784636012329615370947548)


In [31]:
t1 = timeit.timeit(str(muller(n,p)),setup='pass', number=1)
print("cpu time for executing Muller's algorithm :", t1)

cpu time for executing Muller's algorithm : 3.999998625658918e-07
