# A refinement of H.C. Williams' qth root algorithm

In [1]:
from random import *
import math
from sympy import *
import random

In [2]:
# Python3 program Miller-Rabin primality test


# This function is called once for each of the k trials.
# It returns False if n is composite and returns True if n is probably prime.
def miillerTest(d, n):
    """Run a single Miller-Rabin round on n with a randomly drawn base.

    d is the starting exponent; it is doubled after each squaring until it
    reaches n - 1.  isPrime passes the odd part of n - 1 here.

    Returns True when this round does not expose n as composite
    ("probably prime"), and False when the drawn base proves n composite.
    """
    # Base is drawn uniformly from [3, n - 2].
    base = 2 + random.randint(1, n - 4)

    residue = pow(base, d, n)
    if residue in (1, n - 1):
        return True

    # Keep squaring: hitting n - 1 means "probably prime", hitting 1 first
    # (or running out of squarings) means composite.
    exponent = d
    while exponent != n - 1:
        residue = residue * residue % n
        exponent *= 2

        if residue == 1:
            return False
        if residue == n - 1:
            return True

    return False

# It returns false if n is composite and returns true if n is probably prime. 
# k is an input parameter that determines accuracy level. Higher value of k indicates more accuracy.
def isPrime(n, k):
    """Probabilistic Miller-Rabin primality test.

    Parameters:
        n: integer to test.
        k: number of independent rounds; more rounds lower the chance that
           a composite is reported as prime.

    Returns:
        False if n is found to be composite, True if n is probably prime.
    """
    if n <= 1 or n == 4:
        return False
    if n <= 3:
        return True

    # Every even number above 2 is composite.  Without this guard n - 1 is
    # odd, so a round degenerates to testing base**(n - 1) % n, which can
    # equal 1 for an even composite (e.g. n = 28 with base 9) and let it
    # slip through as "probably prime".
    if n % 2 == 0:
        return False

    # Strip the factors of two from n - 1, leaving its odd part d.
    d = n - 1
    while d % 2 == 0:
        d //= 2

    for _ in range(k):
        if not miillerTest(d, n):
            return False

    return True


# Module-level setting: number of Miller-Rabin rounds.
# random_p reads this global when it calls isPrime(p, k).
k = 4 # Number of iterations


In [3]:
def random_p(q):
    """Draw random integers until one is a probable prime with p % q == 1.

    Candidates come from random.randint(2**10, 2**11) and are screened with
    isPrime using the module-level round count k.  The accepted value is
    printed and returned.
    """
    lowest, highest = pow(2, 10), pow(2, 11)
    while True:
        p = random.randint(lowest, highest)
        # Primality is checked before the congruence, as a failed primality
        # test short-circuits the congruence check.
        if not isPrime(p, k):
            continue
        if p % q != 1:
            continue
        print("10 bit random prime number p =", p)
        return p

In [4]:
def compute_E1(a,b,c,q,p,theta):
    """Build the polynomial E1 in theta as a running product of factors X1.

    For each i in 0 .. q-2 the factor is
        X1 = (b - c**i * theta) ** ((-1)**(q-i) * C(q-2, i)),
    and every power theta**j with j >= q is rewritten through
    theta**q = b**q - a.  A negative exponent is handled by first forming
    an explicit inverse polynomial and then raising it to the positive power.

    Progress (the coefficients of X1 and E1 reduced mod p) is printed for
    every i.  The returned E1 is the sympy expression itself; its
    coefficients are NOT reduced mod p here (only the printed lists are).
    """
    print("="*50)
    print("<Computing E1>")
    
    E1=1
    for i in range(0, q-1) :
        print("i = ", i)
        #if the exponent is negative, for example : -n, treat it as (-1)*n
        #(some poly)^(-n) => ((some poly)^(-1))^(n) => (changed poly)^n
        if(pow(-1,q-i) < 0):
            #first compute (poly)^(-1):
            #  a**(p-2) * sum_{j=0}^{q-1} b**(q-1-j) * (c**i)**j * theta**j
            # NOTE(review): presumably a**(p-2) stands in for a^(-1) mod p and
            # the sum is the cofactor of (b - c**i*theta) in
            # b**q - (c**i*theta)**q; this relies on c**q == 1 (mod p) - confirm.
            # a**(p-2) is computed as a plain integer power (not reduced mod p).
            total=0
            for j in range(0,q) :
                total = total+pow(b,q-1-j)*pow(pow(c,i),j)*pow(theta,j)
                total = expand(total)
            t = expand(pow(a,p-2)*total)
            
            #then raise it to the n-th power
            X1 = pow(t,math.comb(q-2,i))
        else:
            X1 = pow((b-pow(c,i)*theta),pow(-1,q-i)*math.comb(q-2,i))
            
        X1 = expand(X1)
        
        #If the exponent of theta is q or larger, substitute theta**q with pow(b,q)-a
        #first, term by term in X1, replace the terms whose power reaches q or more
        subs_val = pow(b,q)-a
        
        # Walk from the highest possible degree down to q so that each
        # substitution only produces lower powers, which are handled by later
        # iterations.
        for j in range((q-1)*abs(pow(-1,q-i)*math.comb(q-2,i)),q-1,-1):
            X1 = X1.subs(theta**j, subs_val*(theta**(j-q)))
            X1 = expand(X1)
        
        
        #then, when multiplying the X1 factors together, substitute again
        #wherever the product reaches power q or more
        E1 = expand((E1) * (X1))
        for k in range(2*q-2, q-1, -1):
            E1 = E1.subs(theta**k, subs_val*(theta**(k-q)))
            E1 = expand(E1)
            
        # Coefficient lists reduced mod p, used only for the progress printout.
        # This inner loop reuses the name i; the outer for-loop rebinds i from
        # its own range on the next iteration, so the outer iteration is
        # unaffected.
        temp_X1=[]
        temp_E1=[]
        for i in range(0,q):
            temp_X1.append(X1.coeff(theta,i)%p)
            temp_E1.append(E1.coeff(theta,i)%p)
        
            
        print("X1 : ", temp_X1)
        print("E1 : ", temp_E1)
    return E1

In [5]:
def modp_E1(E1,p,q,theta):
    """Return the coefficients of theta**0 .. theta**(q-1) in E1, each mod p."""
    return [E1.coeff(theta, degree) % p for degree in range(q)]

In [6]:
def compute_E2(a,b,c,q,p,theta):
    """Build the polynomial E2 in theta as a running product of factors Y2.

    For each i in 1 .. q-1 the factor is
        Y2 = (b - c**(q-i-1) * theta) ** e,
        e  = int((1 - (-1)**i * C(q-1, i)) / q),
    and every power theta**j with j >= q is rewritten through
    theta**q = b**q - a.  A negative e is handled by first forming an
    explicit inverse polynomial and raising it to abs(e).

    Progress (the coefficients of Y2 and E2 reduced mod p) is printed for
    every i.  The returned E2 is the sympy expression itself; its
    coefficients are NOT reduced mod p here (only the printed lists are).
    """
    print("="*50)
    print("Computing E2")
    
    E2=1
    for i in range(1, q) :
        print("i = ",i)
        
        #if the exponent is negative, for example : -n, treat it as (-1)*n
        #(some poly)^(-n) => ((some poly)^(-1))^(n) => (changed poly)^n
        # NOTE(review): the exponent is obtained with float division followed
        # by int(); this is exact only while the numerator is small enough to
        # be represented exactly as a float.
        if (int((1-pow(-1,i)*math.comb(q-1,i))/q) < 0):
            # Inverse polynomial:
            #   a**(p-2) * sum_{j=0}^{q-1} b**(q-1-j) * (c**(q-i-1))**j * theta**j
            # NOTE(review): presumably a**(p-2) stands in for a^(-1) mod p and
            # this relies on c**q == 1 (mod p) - confirm.
            total=0
            for j in range(0,q) :
                total = total+pow(b,q-1-j)*pow(pow(c,q-i-1),j)*pow(theta,j)
            Y2 = pow(a,p-2)*total
            #then raise it to the n-th power
            Y2 = pow(Y2,abs(int((1-pow(-1,i)*math.comb(q-1,i))/q)))
        else:
            Y2 = pow((b-pow(c,q-i-1)*theta),int((1-pow(-1,i)*math.comb(q-1,i))/q))
            
        Y2 = expand(Y2)
            
        #If the exponent of theta is q or larger, substitute theta**q with pow(b,q)-a
        #first, term by term in Y2, replace the terms whose power reaches q or more
        subs_val = pow(b,q)-a
        # Walk from the highest possible degree down to q.
        for j in range((q-1)*abs(int((1-pow(-1,i)*math.comb(q-1,i))/q)),q-1,-1):
            #print("j=",i,"(after reducing powers >= q) Y2 value:", Y2)
            Y2 = Y2.subs(theta**j, subs_val*(theta**(j-q)))
            Y2 = expand(Y2)
            
            
        E2 = expand((E2) * (Y2))
        #then, when multiplying the Y2 factors together, substitute again
        #wherever the product reaches power q or more
        # NOTE(review): unlike compute_E1, E2 is not re-expanded after each
        # substitution here - confirm this is intentional.
        for k in range(2*q-2, q-1, -1):
            E2 = E2.subs(theta**k, subs_val*(theta**(k-q)))
            
        # Coefficient lists reduced mod p, used only for the progress printout.
        # This inner loop reuses the name i; the outer for-loop rebinds i from
        # its own range on the next iteration.
        temp_Y2=[]
        temp_E2=[]
        for i in range(0,q):
            temp_Y2.append(Y2.coeff(theta,i)%p)
            temp_E2.append(E2.coeff(theta,i)%p)
        
        
        print("Y2 : ", temp_Y2)
        print("E2 : ", temp_E2)
            
    return E2

In [7]:
def modp_E2(E2,p,q,theta):
    """Collect the theta**0 .. theta**(q-1) coefficients of E2, reduced mod p."""
    # One entry per power of theta below q, lowest power first.
    return [E2.coeff(theta, power) % p for power in range(0, q)]

In [8]:
def recurrence_ai(coef_E1, p,b,q,a):
    """Raise the element with coefficients coef_E1 to the power (p-1)//q.

    Elements are length-q coefficient lists in theta, multiplied modulo
    theta**q = b**q - a with coefficients reduced mod p.

    Returns the list of every intermediate coefficient list: entry 0 is a
    copy of coef_E1, the next entries are successive squarings (one per
    remaining binary digit of the exponent), and, when the exponent has two
    or more 1-bits, further entries are the running products that combine
    the squarings selected by those bits.  The last entry is the result.
    """
    theta_q = pow(b, q) - a

    def multiply(left, right):
        # Convolution of two coefficient lists, folding degrees >= q back
        # through theta**q = theta_q.
        product = []
        for degree in range(q):
            direct = 0
            for k in range(degree + 1):
                direct = direct + left[k] * right[degree - k]
            folded = 0
            for k in range(degree + 1, q):
                folded = folded + left[k] * right[q + degree - k]
            product.append((direct + theta_q * folded) % p)
        return product

    binary = bin(int((p - 1) // q))   # string with its '0b' prefix
    width = len(binary)

    # powers[u] holds the u-fold repeated squaring of the input.
    powers = [list(coef_E1)]
    for _ in range(1, width - 2):
        powers.append(multiply(powers[-1], powers[-1]))

    # String positions of the 1-digits, most significant first.
    one_positions = [idx for idx in range(2, width) if binary[idx] == '1']

    if len(one_positions) >= 2:
        # A digit at string position idx selects powers[width - idx - 1].
        running = width - one_positions[0] - 1
        for idx in one_positions[1:]:
            powers.append(multiply(powers[running], powers[width - idx - 1]))
            running = len(powers) - 1

    return powers
    

In [9]:
def refinement_H_C_Williams(p,q,a):
    """Solve x**q = a (mod p) and return the list of q roots.

    Intended inputs (per the notebook): p and q with p = 1 (mod q), and a
    not divisible by p.

    Returns False (after printing a message) when a is not a q-th power
    residue mod p; otherwise prints intermediate values and returns the
    roots x, c*x, ..., c**(q-1)*x, each reduced mod p.
    """
    phi = int((p - 1) // q)

    # Step 1: a must satisfy a**phi == 1 (mod p), otherwise there is no root.
    if pow(a, phi, p) != 1:
        print("There are no solutions. Algorithm terminates.")
        return False

    # Step 2: smallest b = 1, 2, 3, ... with (b**q - a)**phi not in {0, 1} mod p.
    b = 1
    c = pow(pow(b, q) - a, phi, p)
    while c in (0, 1):
        b += 1
        c = pow(pow(b, q) - a, phi, p)

    print("b is", b)
    print("c is", c)

    # Step 3: work in K = k[theta] with theta**q = b**q - a; build E1 and E2
    # and read off their coefficients mod p.
    theta = Symbol('theta')
    E1 = compute_E1(a, b, c, q, p, theta)
    E2 = compute_E2(a, b, c, q, p, theta)
    coef_E1 = modp_E1(E1, p, q, theta)
    coef_E2 = modp_E2(E2, p, q, theta)
    print("=" * 50)
    print("coefficient of E1 :", coef_E1)
    print("coefficient of E2:", coef_E2)

    # Step 4: raise E1 to the power (p-1)/q through the recurrence.
    a_list = recurrence_ai(coef_E1, p, b, q, a)
    print("final a_list:")
    print(a_list)

    # Step 5: x is the theta**0 coefficient of (last power of E1) * E2.
    last_power = a_list[-1]
    folded = sum(last_power[i] * coef_E2[q - i] for i in range(1, q))
    x = (last_power[0] * coef_E2[0] + (pow(b, q) - a) * folded) % p

    # Final output: the q roots x * c**j.
    qth_root = [(pow(c, j) * x) % p for j in range(q)]

    print("qth roots of the congruence x^q = a (mod p) : ", qth_root)
    return qth_root

## Test1

In [10]:
# Test1: runs the same five steps as refinement_H_C_Williams inline, at
# notebook top level, for the fixed inputs p=31, q=3, a=2, then verifies
# every root.  All names assigned here are notebook globals.
#Input : p,q satisfying p=1 (mod q), a : integer s.t NOT divisible by p
#x^q = a (mod p)
p=31
q=3
a=2

phi = int((p-1)//q)

#Step1 : check a is qth residue or not.
#If a is NOT a qth residue, the congruence x^q = a (mod p) does not have any solutions.
# NOTE(review): this branch only prints; execution continues with the
# following steps regardless.
if (pow(a,int((p-1)//q),p) != 1) :
    print("There are no solutions. Algorithm terminates.")


#Step2 : find b=1,2,3... until (b^q-a)^phi is not 0 and 1
b=1
val = pow((pow(b,q)-a), phi, p)
while(1):
    if(val!=0 and val!=1) :
        c = val
        break
    else :
        b=b+1
        val = pow((pow(b,q)-a), phi, p)

print("b is", b)    
print("c is",c)

#Step3 : field extension K=k[theta], where theta^q = pow(b,q)-a
#Compute Xi, Yi and products of Xi's(=E1) and Yi's(E2).
theta = Symbol('theta') 
E1 = compute_E1(a,b,c,q,p, theta)
E2 = compute_E2(a,b,c,q,p, theta)
#print("E1 is", E1)
#print("E2 is", E2)
#modular computation on E1 and E2 and obtain a0,...,aq-1 and b0,...,bq-1
coef_E1 = modp_E1(E1,p,q,theta)
coef_E2 = modp_E2(E2,p,q,theta)
print("coefficient a :", coef_E1)
print("coefficient b:", coef_E2)


#Step4 : Use the recurrence relation in extension field K.
#Goal : compute ai((p-1)/q)
goal = bin(phi)
# m is the length of the bin() string, i.e. it includes the '0b' prefix.
m=len(goal)
print(m)

a_list = recurrence_ai(coef_E1, p,b,q,a)
print("final a_list:", a_list)
print("len",len(a_list))

#Step5 : Output solution x of congruence x^q=a (mod p)
# x is the theta^0 coefficient of (last entry of a_list) * E2, with the
# wrapped-around terms weighted by theta^q = b^q - a.
x = a_list[len(a_list)-1][0]*coef_E2[0]
temp=0
for i in range(1,q):
    temp = temp + a_list[len(a_list)-1][i]*coef_E2[q-i]
x = (x+(pow(b,q)-a)*temp)%p
print("x is",x)

#Final output : q solutions
qth_root = []
for j in range(0,q):
    qth_root.append((pow(c,j)*x)%p)

print("qth roots of the congruence x^q = a (mod p) : ", qth_root)

# Verify each reported root r by checking r^q mod p == a.
for i in range(len(qth_root)):
    print("check :",int((qth_root[i]**q)%p)==a)

b is 2
c is 25
<Computing E1>
i =  0
X1 :  [2, 1, 16]
E1 :  [2, 1, 16]
i =  1
X1 :  [2, 6, 0]
E1 :  [22, 14, 7]
Computing E2
i =  1
Y2 :  [2, 6, 0]
E2 :  [2, 6, 0]
i =  2
Y2 :  [1, 0, 0]
E2 :  [2, 6, 0]
coefficient a : [22, 14, 7]
coefficient b: [2, 6, 0]
6
final a_list: [[22, 14, 7], [17, 11, 8], [12, 14, 21], [14, 6, 18], [9, 4, 19]]
len 5
x is 20
qth roots of the congruence x^q = a (mod p) :  [20, 4, 7]
check : True
check : True
check : True


## Main

In [11]:
#main
#Input : p,q satisfying p=1 (mod q), a : integer s.t NOT divisible by p
#x^q = a (mod p)
q=7
p=random_p(q)
# a is built as x^q mod p for a random x in [1, p-1], so it is a q-th power
# residue by construction.
x=randrange(1,p)
a=(x**q)%p
# Commented-out fixed inputs (presumably a saved test case - not verified here).
#p=1451
#a=594
print("a is",a)
qth_root = refinement_H_C_Williams(p,q,a)
# Verify each of the q reported roots r by checking r^q mod p == a.
for i in range(q):
    print("check :",int((qth_root[i]**q)%p)==a)

10 bit random prime number p = 1667
a is 338
b is 1
c is 686
<Computing E1>
i =  0
X1 :  [725, 725, 725, 725, 725, 725, 725]
E1 :  [725, 725, 725, 725, 725, 725, 725]
i =  1
X1 :  [1, 1571, 19, 302, 1435, 1491, 0]
E1 :  [1593, 1497, 1516, 151, 1586, 1410, 1410]
i =  2
X1 :  [166, 307, 103, 420, 977, 1310, 334]
E1 :  [772, 1244, 1112, 684, 410, 1605, 332]
i =  3
X1 :  [433, 1527, 994, 632, 286, 496, 258]
E1 :  [1183, 1619, 1308, 1087, 1533, 1423, 1071]
i =  4
X1 :  [644, 1219, 1267, 245, 933, 389, 4]
E1 :  [253, 1207, 828, 1331, 159, 499, 1577]
i =  5
X1 :  [1, 1491, 0, 0, 0, 0, 0]
E1 :  [1574, 21, 105, 632, 950, 854, 437]
Computing E2
i =  1
Y2 :  [1, 1491, 0, 0, 0, 0, 0]
E2 :  [1, 1491, 0, 0, 0, 0, 0]
i =  2
Y2 :  [957, 971, 872, 1122, 827, 65, 94]
E2 :  [170, 906, 10, 1014, 61, 1209, 323]
i =  3
Y2 :  [1, 424, 469, 1165, 0, 0, 0]
E2 :  [1258, 862, 1067, 1426, 1546, 853, 847]
i =  4
Y2 :  [957, 862, 671, 1205, 373, 318, 512]
E2 :  [1251, 757, 311, 1392, 1262, 1557, 1641]
i =  5
Y2 :  