In [1]:
import random

# Primes

In [2]:
# Small primes in ascending order, from 2 up to 997. is_prime uses them for
# trial division and prime_factor uses them as its only candidate factors.
SMALL_PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

In [3]:
# see https://inventwithpython.com/rabinMiller.py

def rewrite(num):
    """Decompose num - 1 as s * 2**t with s odd.

    num: integer whose predecessor is decomposed.

    Returns the pair (s, t).
    Raises ValueError when num - 1 is zero: zero is divisible by every power
    of two, so the halving loop below would never terminate.
    """
    s = num - 1
    if s == 0:
        raise ValueError("num - 1 is zero and has no odd part")
    t = 0
    while s % 2 == 0:
        s = s // 2
        t += 1
    return s, t

def rabin_miller(num, iterations=10):
    """Probabilistic (Rabin-Miller) primality test.

    num: integer to test.
    iterations: number of random witnesses to try.

    Returns False as soon as a witness proves num composite, True if every
    round passes. Numbers below 2 and even numbers above 2 are answered
    directly: the witness loop needs an odd num >= 5 (randrange(2, num - 1)
    is empty for num <= 3, and for even num the squaring loop has t == 0 and
    can run forever).
    """
    if num < 2:
        return False
    if num < 4:
        return True  # 2 and 3
    if num % 2 == 0:
        return False

    # num - 1 == s * 2**t with s odd
    s, t = rewrite(num)
    for _ in range(iterations):
        a = random.randrange(2, num - 1)
        v = pow(a, s, num)
        if v != 1:
            # square up to t - 1 times looking for -1 (i.e. num - 1)
            i = 0
            while v != (num - 1):
                if i == t - 1:
                    return False
                else:
                    i = i + 1
                    v = pow(v, 2, num)
    return True

In [4]:
def is_prime(num):
    """Return whether num is prime (probabilistically for large num).

    Anything below 2 is rejected. Trial division by SMALL_PRIMES settles
    every number that equals or is divisible by one of them; whatever
    survives is handed to rabin_miller.
    """
    if num < 2:
        return False
    for p in SMALL_PRIMES:
        if p == num:
            return True
        if num % p == 0:
            return False
    return rabin_miller(num)

# Spot checks: is_prime must classify every integer from 1 to 9 correctly.
assert(not is_prime(1))
assert(    is_prime(2))
assert(    is_prime(3))
assert(not is_prime(4))
assert(    is_prime(5))
assert(not is_prime(6))
assert(    is_prime(7))
assert(not is_prime(8))
assert(not is_prime(9))

# Parameter generation

In [5]:
def is_power_of(x, base):
    """Return True if x equals base**k for some integer k >= 0.

    x: candidate value.
    base: base whose powers are enumerated, starting from base**0 == 1.

    Powers of -1, 0 and 1 never exceed 1, so for x > 1 the answer is False
    right away; without that guard the loop below would never reach x and
    would spin forever.
    """
    if x > 1 and -1 <= base <= 1:
        return False
    p = 1
    while p < x:
        p = p * base
    return p == x

# Spot checks for base 2: among 0..9 only 1, 2, 4 and 8 are powers of two.
assert(not is_power_of(0, 2))
assert(    is_power_of(1, 2))
assert(    is_power_of(2, 2))
assert(not is_power_of(3, 2))
assert(    is_power_of(4, 2))
assert(not is_power_of(5, 2))
assert(not is_power_of(6, 2))
assert(not is_power_of(7, 2))
assert(    is_power_of(8, 2))
assert(not is_power_of(9, 2))

# Spot checks for base 3: among 0..9 only 1, 3 and 9 are powers of three.
assert(not is_power_of(0, 3))
assert(    is_power_of(1, 3))
assert(not is_power_of(2, 3))
assert(    is_power_of(3, 3))
assert(not is_power_of(4, 3))
assert(not is_power_of(5, 3))
assert(not is_power_of(6, 3))
assert(not is_power_of(7, 3))
assert(not is_power_of(8, 3))
assert(    is_power_of(9, 3))

In [6]:
def sample_prime(bitsize):
    """Sample a random prime with exactly `bitsize` bits.

    Candidates are drawn uniformly from [2**(bitsize-1), 2**bitsize) until
    one passes is_prime (which is probabilistic for large values).

    Raises ValueError if bitsize < 2: the only 1-bit candidate is 1, which is
    not prime, so the rejection loop would never end.
    """
    if bitsize < 2:
        raise ValueError("bitsize must be at least 2, got %r" % (bitsize,))
    lower = 1 << (bitsize-1)
    upper = 1 << (bitsize)
    while True:
        candidate = random.randrange(lower, upper)
        if is_prime(candidate):
            return candidate

# Smoke test: sampled values of a small and a large bit size pass is_prime.
assert(is_prime(sample_prime(5)))
assert(is_prime(sample_prime(128)))

In [7]:
def remove_factor(x, factor):
    """Divide x by factor for as long as the division is exact.

    x: non-zero integer to strip.
    factor: divisor to remove; must not be 1 or -1.

    Returns x with every occurrence of factor divided out.
    Raises ValueError for x == 0 or factor in (1, -1): in those cases the
    division is always exact and never shrinks x, so the loop would not end.
    A factor of 0 still raises ZeroDivisionError from the modulo.
    """
    if x == 0 or factor in (1, -1):
        raise ValueError(
            "cannot remove factor %r from %r: the division never terminates" % (factor, x)
        )
    while x % factor == 0:
        x //= factor
    return x
    
# Spot checks: powers of two collapse to 1, odd numbers are left untouched.
assert(remove_factor(1, 2) == 1)
assert(remove_factor(2, 2) == 1)
assert(remove_factor(3, 2) == 3)
assert(remove_factor(4, 2) == 1)

def prime_factor(x):
    """Return the distinct prime factors of x, ascending, without multiplicity.

    Only the primes in SMALL_PRIMES are tried. An AssertionError is raised
    when something other than 1 is left over after dividing them out, i.e.
    when x is not a product of those primes (this includes x < 1).
    """
    remaining = x
    found = []
    for p in SMALL_PRIMES:
        if p > remaining:
            break
        if remaining % p != 0:
            continue
        found.append(p)
        remaining = remove_factor(remaining, p)
    assert(remaining == 1) # fail if we were trying to factor a too large number
    return found

# Spot checks: each prime factor is listed once, regardless of its
# multiplicity (e.g. 4 -> [2], 9 -> [3]).
assert(prime_factor(1) == [])
assert(prime_factor(2) == [2])
assert(prime_factor(3) == [3])
assert(prime_factor(4) == [2])
assert(prime_factor(5) == [5])
assert(prime_factor(6) == [2,3])
assert(prime_factor(7) == [7])
assert(prime_factor(8) == [2])
assert(prime_factor(9) == [3])

In [8]:
def find_prime(min_bitsize, order_divisor):
    """Search for a prime q of the form k1 * k2 * order_divisor + 1.

    min_bitsize: bit size of the randomly sampled prime k1.
    order_divisor: value that must divide q - 1; it has to be factorable by
        prime_factor.

    For each sampled k1 the multipliers k2 = 0..127 are tried in order (k2 = 0
    yields q = 1, which is_prime rejects). Returns (q, factors) where factors
    is k1 followed by the prime factors of k2 and of order_divisor; the same
    prime may appear more than once in that list.
    """
    while True:
        sampled = sample_prime(min_bitsize)
        for multiplier in range(128):
            candidate = sampled * multiplier * order_divisor + 1
            if not is_prime(candidate):
                continue
            factors = [sampled]
            factors.extend(prime_factor(multiplier))
            factors.extend(prime_factor(order_divisor))
            return candidate, factors

def find_generator(q, order_prime_factors):
    """Return the smallest candidate in [2, q) that passes the generator test.

    q: modulus; the group order is taken to be q - 1.
    order_prime_factors: primes p for which candidate**((q-1)//p) mod q must
        differ from 1. When q is prime and the list covers every prime factor
        of q - 1, a passing candidate generates the whole multiplicative group.

    Returns None if no candidate in the range passes.
    """
    group_order = q - 1
    for g in range(2, q):
        passes = all(
            pow(g, group_order // p, q) != 1
            for p in order_prime_factors
        )
        if passes:
            return g
    return None
            
def find_prime_field(min_bitsize, order_divisor):
    """Pick a prime modulus via find_prime and a generator via find_generator.

    Returns the pair (q, g); the arguments are passed straight to find_prime.
    """
    modulus, factors = find_prime(min_bitsize, order_divisor)
    generator = find_generator(modulus, factors)
    return modulus, generator

# Regression check: with a 2-bit sampled prime and order divisor 72 the search
# is expected to end at the field of size 433 with generator 5.
assert(find_prime_field(2, 8*9) == (433, 5))

In [9]:
def generate_parameters(min_bitsize, order2, order3):
    """Generate a prime field together with two roots of unity.

    min_bitsize: bit size forwarded to find_prime_field.
    order2: power of 2; order of the first root.
    order3: power of 3; order of the second root.

    Returns (q, omega2, omega3) where omega2 = g**((q-1)//order2) mod q and
    omega3 = g**((q-1)//order3) mod q for the generator g found by
    find_prime_field. The preconditions are checked with assertions.
    """
    assert(is_power_of(order2, 2))
    assert(is_power_of(order3, 3))

    q, g = find_prime_field(min_bitsize, order2 * order3)
    assert(is_prime(q))

    group_order = q - 1
    assert(group_order % order2 == 0)
    assert(group_order % order3 == 0)

    exponent2 = group_order // order2
    exponent3 = group_order // order3
    return q, pow(g, exponent2, q), pow(g, exponent3, q)
    
# Regression check: the smallest setting (orders 8 and 9) is expected to give
# Q = 433 with roots of unity 354 and 150.
assert(generate_parameters(2, 8, 9) == (433, 354, 150))

### Either generate new parameters ...

In [10]:
# Generate a fresh parameter set: a prime Q built around an 80-bit sampled
# prime, plus a root of unity for each of the two orders.
# NOTE: generation draws from `random`, so the printed values differ between runs.
ORDER2 = 512
ORDER3 = 729
Q, OMEGA2, OMEGA3 = generate_parameters(80, ORDER2, ORDER3)

print("Prime is %d" % Q)
print("%d-th principal root of unity is %d" % (ORDER2, OMEGA2))
print("%d-th principal root of unity is %d" % (ORDER3, OMEGA3))

Prime is 10931067715358518584417802183681
512-th principal root of unity is 1452790765625403579292312919844
729-th principal root of unity is 10486827228965628789878829027445


### ... or use a fixed choice ...

In [11]:
# Small fixed parameter set: the values generate_parameters(2, 8, 9) is
# asserted to return. Running this cell replaces any generated parameters.
ORDER2 = 8
ORDER3 = 9

Q = 433
OMEGA2 = 354
OMEGA3 = 150

### ... but in both cases the following should hold

In [12]:
# The modulus must be prime.
assert(is_prime(Q))

# The two transform sizes must be powers of 2 and 3 respectively.
assert(is_power_of(ORDER2, 2))
assert(is_power_of(ORDER3, 3))

# Both sizes must divide the order of the multiplicative group.
ORDER = Q - 1
assert(ORDER % ORDER2 == 0)
assert(ORDER % ORDER3 == 0)

assert(pow(OMEGA2, ORDER2, Q) == 1) # root of unity
assert(1 not in ( pow(OMEGA2, e, Q) for e in range(1, ORDER2) )) # principal: no smaller positive power is 1

assert(pow(OMEGA3, ORDER3, Q) == 1) # root of unity
assert(1 not in ( pow(OMEGA3, e, Q) for e in range(1, ORDER3) )) # principal: no smaller positive power is 1

# Basic algorithms

In [13]:
# from http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf
def egcd_binary(a, b):
    """Binary extended Euclidean algorithm.

    a, b: positive integers.

    Returns (g, s, t) with g == gcd(a, b) and s * a + t * b == g. The
    coefficients s and t may be negative.
    Raises ValueError if a or b is not positive: the halving and subtraction
    loops below do not terminate for zero or negative arguments.
    """
    if a <= 0 or b <= 0:
        raise ValueError("egcd_binary requires positive arguments, got %r and %r" % (a, b))

    # Invariants (after common factors of two are removed):
    #   a == u * alpha + v * beta   and   b == s * alpha + t * beta
    u, v, s, t, r = 1, 0, 0, 1, 0
    # strip the common factors of two; they are multiplied back in at the end
    while (a % 2 == 0) and (b % 2 == 0):
        a, b, r = a//2, b//2, r+1
    alpha, beta = a, b
    # make a odd, halving its coefficients along with it
    while (a % 2 == 0):
        a = a//2
        if (u % 2 == 0) and (v % 2 == 0):
            u, v = u//2, v//2
        else:
            # adding (beta, -alpha) leaves the represented value unchanged
            # and makes both coefficients even
            u, v = (u + beta)//2, (v - alpha)//2
    while a != b:
        if (b % 2 == 0):
            b = b//2
            if (s % 2 == 0) and (t % 2 == 0):
                s, t = s//2, t//2
            else:
                s, t = (s + beta)//2, (t - alpha)//2
        elif b < a:
            # keep b as the larger value so the subtraction stays positive
            a, b, u, v, s, t = b, a, s, t, u, v
        else:
            b, s, t = b - a, s - u, t - v
    return (2 ** r) * a, s, t


def inverse(a):
    """Return the multiplicative inverse of a modulo Q, reduced to [0, Q).

    a: any integer; it is reduced modulo Q first, so negative values and
        values >= Q are accepted.

    Raises ValueError if a has no inverse modulo Q (a is a multiple of Q, or
    shares a factor with Q).
    """
    a = a % Q
    if a == 0:
        # egcd_binary cannot handle a zero argument
        raise ValueError("0 is not invertible modulo Q")
    g, b, _ = egcd_binary(a, Q)
    if g != 1:
        raise ValueError("%r is not invertible modulo Q" % (a,))
    # the Bezout coefficient may be negative or exceed Q; return the canonical residue
    return b % Q


# Check inverse() on every non-zero residue below Q, capped at 999,999 values
# so the loop stays bounded when Q is large.
for a in range(1, min(Q, 1000000)):
    assert(a * inverse(a) % Q == 1)

In [14]:
def eval_poly_at(coefs, point):
    """Evaluate a polynomial at `point`, modulo Q.

    coefs: coefficients, lowest degree first.
    point: evaluation point.

    Returns the sum of coefs[d] * point**d over all d, reduced modulo Q.
    """
    total = 0
    for degree, coef in enumerate(coefs):
        term = coef * pow(point, degree, Q) % Q
        total += term
    return total % Q

# Spot checks on 1 + 2x + 3x^2 at x = 0, 1, 2; only run when Q exceeds 17.
if Q > 17:
    assert(eval_poly_at([1,2,3], 0) == 1 % Q)
    assert(eval_poly_at([1,2,3], 1) == 6 % Q)
    assert(eval_poly_at([1,2,3], 2) == 17 % Q)

# Fast Fourier Transform (unoptimised)

In [15]:
# len(A_coeffs) must be a power of 2
def fft2_forward(A_coeffs, omega=OMEGA2):
    """Evaluate a polynomial at the powers of omega (radix-2 FFT modulo Q).

    A_coeffs: coefficients, lowest degree first; the length must be a power of 2.
    omega: principal root of unity modulo Q whose order equals len(A_coeffs).
        The default OMEGA2 therefore fits inputs of length ORDER2.

    Returns a list whose entry j is the polynomial evaluated at omega**j mod Q.
    A length-1 input is returned unchanged (the same list object).
    Raises ValueError if the length is not a power of 2; without the check an
    empty input recursed without bound and other lengths gave wrong values.
    """
    if len(A_coeffs) == 1:
        return A_coeffs
    if len(A_coeffs) == 0 or len(A_coeffs) % 2 != 0:
        raise ValueError("fft2_forward: input length must be a power of 2")

    # split A into B and C such that A(x) = B(x^2) + x C(x^2)
    B_coeffs = A_coeffs[0::2]
    C_coeffs = A_coeffs[1::2]

    # apply recursively
    omega_squared = pow(omega, 2, Q)
    B_values = fft2_forward(B_coeffs, omega_squared)
    C_values = fft2_forward(C_coeffs, omega_squared)

    # combine subresults
    A_values = [0] * len(A_coeffs)
    L_half = len(A_coeffs) // 2
    for i in range(0, L_half):

        j = i
        x = pow(omega, j, Q)
        A_values[j] = (B_values[i] + x * C_values[i]) % Q

        # the half-size transforms repeat with period L_half, so index i is reused
        j = i + L_half
        x = pow(omega, j, Q)
        A_values[j] = (B_values[i] + x * C_values[i]) % Q

    return A_values

def fft2_backward(A_values, omega=None):
    """Invert fft2_forward: recover coefficients from values.

    A_values: values at the powers of omega; the length must be a power of 2.
    omega: the root of unity used for the forward transform. When omitted,
        the current OMEGA2 is used, which fits inputs of length ORDER2 only.

    Returns the coefficient list, lowest degree first, reduced modulo Q.
    """
    if omega is None:
        # resolved at call time so a later change of OMEGA2 is picked up
        omega = OMEGA2
    # the inverse transform is a forward transform with omega^-1, scaled by 1/len
    L_inv = inverse(len(A_values))
    A_coeffs = [ (a * L_inv) % Q for a in fft2_forward(A_values, inverse(omega)) ]
    return A_coeffs


# Test polynomial with coefficients 0, 1, ..., ORDER2 - 1.
coefs = [ x for x in range(ORDER2) ]
assert(len(coefs) == ORDER2)

# The transform must agree with direct evaluation at every power of OMEGA2.
values = fft2_forward(coefs)
assert(len(values) == ORDER2)
assert(values == [ eval_poly_at(coefs, pow(OMEGA2, degree, Q)) for degree in range(ORDER2) ])

# Round trip: the backward transform must return the original coefficients.
coefs_recovered = fft2_backward(values)
assert(coefs == coefs_recovered)

In [16]:
# len(A_coeffs) must be a power of 3
def fft3_forward(A_coeffs, omega=OMEGA3):
    """Evaluate a polynomial at the powers of omega (radix-3 FFT modulo Q).

    A_coeffs: coefficients, lowest degree first; the length must be a power of 3.
    omega: principal root of unity modulo Q whose order equals len(A_coeffs).
        The default OMEGA3 therefore fits inputs of length ORDER3.

    Returns a list whose entry j is the polynomial evaluated at omega**j mod Q.
    A length-1 input is returned unchanged (the same list object).
    Raises ValueError if the length is not a power of 3; without the check
    such inputs ended in unbounded recursion or wrong values.
    """
    if len(A_coeffs) == 1:
        return A_coeffs
    if len(A_coeffs) == 0 or len(A_coeffs) % 3 != 0:
        raise ValueError("fft3_forward: input length must be a power of 3")

    # split A(x) into B(x), C(x), and D(x) such that A(x) = B(x^3) + x C(x^3) + x^2 D(x^3)
    B_coeffs = A_coeffs[0::3]
    C_coeffs = A_coeffs[1::3]
    D_coeffs = A_coeffs[2::3]

    # apply recursively
    omega_cubed = pow(omega, 3, Q)
    B_values = fft3_forward(B_coeffs, omega_cubed)
    C_values = fft3_forward(C_coeffs, omega_cubed)
    D_values = fft3_forward(D_coeffs, omega_cubed)

    # combine subresults; the third-size transforms repeat with period L_third,
    # so index i serves all three output positions
    A_values = [0] * len(A_coeffs)
    L_third = len(A_coeffs) // 3
    for i in range(L_third):

        j = i
        x = pow(omega, j, Q)
        xx = (x * x) % Q
        A_values[j] = (B_values[i] + x * C_values[i] + xx * D_values[i]) % Q

        j = i + L_third
        x = pow(omega, j, Q)
        xx = (x * x) % Q
        A_values[j] = (B_values[i] + x * C_values[i] + xx * D_values[i]) % Q

        j = i + L_third + L_third
        x = pow(omega, j, Q)
        xx = (x * x) % Q
        A_values[j] = (B_values[i] + x * C_values[i] + xx * D_values[i]) % Q

    return A_values

def fft3_backward(A_values, omega=None):
    """Invert fft3_forward: recover coefficients from values.

    A_values: values at the powers of omega; the length must be a power of 3.
    omega: the root of unity used for the forward transform. When omitted,
        the current OMEGA3 is used, which fits inputs of length ORDER3 only.

    Returns the coefficient list, lowest degree first, reduced modulo Q.
    """
    if omega is None:
        # resolved at call time so a later change of OMEGA3 is picked up
        omega = OMEGA3
    # the inverse transform is a forward transform with omega^-1, scaled by 1/len
    L_inv = inverse(len(A_values))
    A_coeffs = [ (a * L_inv) % Q for a in fft3_forward(A_values, inverse(omega)) ]
    return A_coeffs


# Test polynomial with coefficients 0, 1, ..., ORDER3 - 1.
coefs = [ x for x in range(ORDER3) ]
assert(len(coefs) == ORDER3)

# The transform must agree with direct evaluation at every power of OMEGA3.
values = fft3_forward(coefs)
assert(len(values) == ORDER3)
assert(values == [ eval_poly_at(coefs, pow(OMEGA3, degree, Q)) for degree in range(ORDER3) ])

# Round trip: the backward transform must return the original coefficients.
coefs_recovered = fft3_backward(values)
assert(coefs == coefs_recovered)

# Fast Fourier Transform (optimised)

In [None]:
# len(aX) must be a power of 2
def fft2_forward(aX, omega=OMEGA2):
    """Radix-2 FFT modulo Q, with the twiddle factor updated incrementally.

    aX: coefficients, lowest degree first; the length must be a power of 2.
    omega: principal root of unity modulo Q whose order equals len(aX).
        The default OMEGA2 therefore fits inputs of length ORDER2.

    Returns a list whose entry j is the polynomial evaluated at omega**j mod Q.
    A length-1 input is returned unchanged (the same list object).
    Raises ValueError if the length is not a power of 2; without the check an
    empty input recursed without bound and other lengths gave wrong values.
    """
    if len(aX) == 1:
        return aX
    if len(aX) == 0 or len(aX) % 2 != 0:
        raise ValueError("fft2_forward: input length must be a power of 2")

    # split A(x) into B(x) and C(x) -- A(x) = B(x^2) + x C(x^2)
    bX = aX[0::2]
    cX = aX[1::2]

    # apply recursively
    omega_squared = pow(omega, 2, Q)
    B = fft2_forward(bX, omega_squared)
    C = fft2_forward(cX, omega_squared)

    # combine subresults
    A = [0] * len(aX)
    Nhalf = len(aX) >> 1
    point = 1  # omega**i, maintained by one multiplication per iteration
    for i in range(0, Nhalf):

        x = point
        A[i]         = (B[i] + x * C[i]) % Q
        # omega**(i + Nhalf) == -omega**i because omega**Nhalf == -1 (mod Q)
        # for a principal root of this order modulo a prime
        A[i + Nhalf] = (B[i] - x * C[i]) % Q

        point = (point * omega) % Q

    return A

def fft2_backward(A, omega=None):
    """Invert fft2_forward: recover coefficients from values.

    A: values at the powers of omega; the length must be a power of 2.
    omega: the root of unity used for the forward transform. When omitted,
        the current OMEGA2 is used, which fits inputs of length ORDER2 only.

    Returns the coefficient list, lowest degree first, reduced modulo Q.
    """
    if omega is None:
        # resolved at call time so a later change of OMEGA2 is picked up
        omega = OMEGA2
    # the inverse transform is a forward transform with omega^-1, scaled by 1/len
    N_inv = inverse(len(A))
    return [ (a * N_inv) % Q for a in fft2_forward(A, inverse(omega)) ]


# Test polynomial with coefficients 1..ORDER2. Sizing it from ORDER2 (rather
# than hard-coding eight entries) keeps the check valid for generated parameters.
coefs = list(range(1, ORDER2 + 1))
assert(len(coefs) == ORDER2)

# The transform must agree with direct evaluation at every power of OMEGA2.
values = fft2_forward(coefs)
assert(len(values) == ORDER2)
assert(values == [ eval_poly_at(coefs, pow(OMEGA2, degree, Q)) for degree in range(ORDER2) ])

# Round trip: the backward transform must return the original coefficients.
coefs_recovered = fft2_backward(values)
assert(coefs == coefs_recovered)

In [None]:
# len(aX) must be a power of 3
def fft3_forward(aX, omega=OMEGA3):
    """Radix-3 FFT modulo Q, with the twiddle factors updated incrementally.

    aX: coefficients, lowest degree first; the length must be a power of 3.
    omega: principal root of unity modulo Q whose order equals len(aX).
        The default OMEGA3 therefore fits inputs of length ORDER3.

    Returns a list whose entry j is the polynomial evaluated at omega**j mod Q.
    A length-1 input is returned unchanged (the same list object).
    Raises ValueError if the length is not a power of 3; without the check
    such inputs ended in unbounded recursion or wrong values.
    """
    if len(aX) == 1:
        return aX
    if len(aX) == 0 or len(aX) % 3 != 0:
        raise ValueError("fft3_forward: input length must be a power of 3")

    # split A(x) into B(x), C(x), and D(x): A(x) = B(x^3) + x C(x^3) + x^2 D(x^3)
    bX = aX[0::3]
    cX = aX[1::3]
    dX = aX[2::3]

    # apply recursively
    omega_cubed = pow(omega, 3, Q)
    B = fft3_forward(bX, omega_cubed)
    C = fft3_forward(cX, omega_cubed)
    D = fft3_forward(dX, omega_cubed)

    # combine subresults
    A = [0] * len(aX)
    Nthird = len(aX) // 3
    omega_Nthird = pow(omega, Nthird, Q)  # step from index i to i + Nthird
    point = 1  # omega**i, maintained by one multiplication per iteration
    for i in range(Nthird):

        x = point
        xx = (x * x) % Q
        A[i                  ] = (B[i] + x * C[i] + xx * D[i]) % Q

        # x becomes omega**(i + Nthird)
        x = x * omega_Nthird % Q
        xx = (x * x) % Q
        A[i + Nthird         ] = (B[i] + x * C[i] + xx * D[i]) % Q

        # x becomes omega**(i + 2 * Nthird)
        x = x * omega_Nthird % Q
        xx = (x * x) % Q
        A[i + Nthird + Nthird] = (B[i] + x * C[i] + xx * D[i]) % Q

        point = (point * omega) % Q

    return A

def fft3_backward(A, omega=None):
    """Invert fft3_forward: recover coefficients from values.

    A: values at the powers of omega; the length must be a power of 3.
    omega: the root of unity used for the forward transform. When omitted,
        the current OMEGA3 is used, which fits inputs of length ORDER3 only.

    Returns the coefficient list, lowest degree first, reduced modulo Q.
    """
    if omega is None:
        # resolved at call time so a later change of OMEGA3 is picked up
        omega = OMEGA3
    # the inverse transform is a forward transform with omega^-1, scaled by 1/len
    N_inv = inverse(len(A))
    return [ (a * N_inv) % Q for a in fft3_forward(A, inverse(omega)) ]


# Test polynomial with coefficients 1..ORDER3. Sizing it from ORDER3 (rather
# than hard-coding nine entries) keeps the check valid for generated parameters.
coefs = list(range(1, ORDER3 + 1))
assert(len(coefs) == ORDER3)

# The transform must agree with direct evaluation at every power of OMEGA3.
values = fft3_forward(coefs)
assert(len(values) == ORDER3)
assert(values == [ eval_poly_at(coefs, pow(OMEGA3, degree, Q)) for degree in range(ORDER3) ])

# Round trip: the backward transform must return the original coefficients.
coefs_recovered = fft3_backward(values)
assert(coefs == coefs_recovered)

# Packed Secret Sharing

In [17]:
# Sharing parameters derived from the two transform sizes:
#   N - number of shares produced by share()
#   T - number of random values mixed into each sharing
#   K - number of secrets packed into one sharing
N = ORDER3 - 1
T = N // 2
K = ORDER2 - T - 1

assert(ORDER2 == T + K + 1)
assert(ORDER3 == N + 1)

# Evaluation points of the two domains, leaving out exponent 0 (the point 1)
# in both. The two sets must not overlap.
POINTS_SECRETS = set( pow(OMEGA2, e, Q) for e in range(1, ORDER2) ) # secrets, incl randomness
POINTS_SHARES  = set( pow(OMEGA3, e, Q) for e in range(1, ORDER3) )
assert(POINTS_SECRETS.intersection(POINTS_SHARES) == set([]))

print("Points used for secrets: %s" % sorted(POINTS_SECRETS))
print("Points used for shares:  %s" % sorted(POINTS_SHARES))

Points used for secrets: [79, 148, 179, 254, 285, 354, 432]
Points used for shares:  [27, 150, 153, 198, 234, 256, 296, 417]


In [19]:
def share(secrets):
    """Split K secrets into N shares.

    secrets: list of exactly K values.

    The values [0] + secrets + T random field elements are interpolated with
    fft2_backward, the coefficients are zero-padded to length ORDER3, and the
    padded polynomial is evaluated with fft3_forward. Every evaluation except
    the first is returned as a share.
    """
    assert(len(secrets) == K)
    fixed_part = [0] + secrets
    points_small = fixed_part + [random.randrange(Q) for _ in range(T)]
    coeffs_small = fft2_backward(points_small)
    zero_padding = [0] * (ORDER3-ORDER2)
    coeffs_large = coeffs_small + zero_padding
    points_large = fft3_forward(coeffs_large)
    result = points_large[1:]
    assert(len(points_small) == ORDER2)
    assert(len(coeffs_large) == ORDER3)
    assert(len(result) == N)
    return result

def reconstruct(shares):
    """Recover the K secrets from all N shares.

    shares: list of exactly N values, in the order share() returned them.

    A 0 is put back in front of the shares, the coefficients are recovered
    with fft3_backward, and the first ORDER2 of them are evaluated with
    fft2_forward. The secrets are entries 1..K of that result. An assertion
    checks that the coefficients from index ORDER2 onwards are all zero.
    """
    assert(len(shares) == N)
    points_large = [0] + shares
    coeffs_large = fft3_backward(points_large)
    coeffs_small = coeffs_large[:ORDER2]
    points_small = fft2_forward(coeffs_small)
    recovered = points_small[1:K+1]
    assert(len(points_large) == ORDER3)
    high_part = coeffs_large[ORDER2:]
    assert(high_part == [0] * (ORDER3-ORDER2))
    assert(len(coeffs_small) == ORDER2)
    assert(len(recovered) == K)
    return recovered

# Round trip: sharing the secrets 1..K and reconstructing from all shares
# must give the secrets back.
secrets = [ s for s in range(1, K+1) ]
shares = share(secrets)
recovered_secrets = reconstruct(shares)
assert(recovered_secrets == secrets)