In [4]:
import sympy
from utils import *
import math
import random

def setup(n, t, p_dash_arg=13, q_dash_arg=18):
    """Generate the public parameters and the secret-exponent shares for n nodes.

    Construction (as computed below):
      * p' = get_germain_prime(p_dash_arg), q' = get_germain_prime(q_dash_arg),
        p = 2p' + 1, q = 2q' + 1, N = p*q, phi = 4p'q' = (p-1)(q-1).
        NOTE(review): presumably p', q' are Sophie Germain primes so that p, q
        are prime -- confirm in utils.
      * s is a random prime with n < s < min(p', q').
      * v = (n! * s)^(-1) mod p'q'.
      * f(x) = v + a_1 x + ... + a_{t-1} x^(t-1) with random prime coefficients
        a_j < N; node i (1..n) receives the share sk[i] = f(i).

    Args:
        n: Number of nodes / shares.
        t: Threshold (number of shares combined per round); must satisfy 1 <= t <= n.
        p_dash_arg: Argument forwarded to ``get_germain_prime`` for p'.
        q_dash_arg: Argument forwarded to ``get_germain_prime`` for q'.
            NOTE(review): presumably a size parameter -- confirm in utils.
            Both defaults are the previously hard-coded values.

    Returns:
        ``(N, phi, s, sk)`` where ``sk`` is a dict ``{i: f(i)}`` for i in 1..n,
        or ``None`` (after printing "s is not valid") if the sampled s fails
        validation.

    Raises:
        ValueError: if ``t`` is not in ``[1, n]``. ``sympy.randprime`` /
            ``sympy.mod_inverse`` may also raise ValueError when no suitable
            prime / inverse exists.
    """
    if not 1 <= t <= n:
        raise ValueError(f"need 1 <= t <= n, got t={t}, n={n}")

    p_dash = get_germain_prime(p_dash_arg)
    q_dash = get_germain_prime(q_dash_arg)

    p = 2*p_dash + 1
    q = 2*q_dash + 1
    N = p*q
    phi = 4*p_dash*q_dash  # == (p - 1) * (q - 1)

    n_fact = math.factorial(n)

    # Randomly sample a prime s with n < s < min(p_dash, q_dash).
    # sympy.randprime(a, b) samples from the half-open range [a, b), so here
    # s <= min(p_dash, q_dash) - 2.
    s = sympy.randprime(n+1, min(p_dash, q_dash)-1)

    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would silently accept an invalid s. Older sympy
    # versions return None from randprime on an empty range.
    if s is None or not (n < s < min(p_dash, q_dash)) or phi % s == 0:
        print("s is not valid")
        return None

    # v is the inverse of n!*s modulo p'q' (== phi // 4).
    v = sympy.mod_inverse(n_fact*s, phi//4)
    # Coefficients a_1..a_{t-1} of the sharing polynomial (t-1 of them).
    # NOTE(review): randprime appears to use a non-cryptographic RNG --
    # acceptable for a demo, confirm before any real use.
    a_coeff = [sympy.randprime(1, N) for _ in range(t-1)]

    # Share for node i is f(i) = v + sum_{j>=1} a_j * i**j.
    sk = {
        i: v + sum(a * i**j for j, a in enumerate(a_coeff, start=1))
        for i in range(1, n+1)
    }
    return N, phi, s, sk

def gen(n_fact, N):
    """Sample a random starting element x_0 = seed^(n_fact^2) mod N.

    The seed is drawn from [2, N-2]. This excludes 1 and N-1, which give the
    trivial x_0 = 1 when n_fact is even. It also excludes N, which
    ``randint(1, N)`` could return and which gives x_0 = 0. When n_fact is
    even (n >= 2), x_0 is a perfect square mod N because the exponent
    n_fact**2 is even.

    Args:
        n_fact: n! for the number of nodes n.
        N: Public modulus; must be at least 4 so that the seed range is non-empty.

    Returns:
        x_0 as an int in [0, N).

    Raises:
        ValueError: if N < 4 (empty seed range).
    """
    # NOTE(review): `random` is not a CSPRNG; use `secrets` for real deployments.
    seed = random.randrange(2, N - 1)
    n_fact_sq = n_fact**2
    x_0 = pow(seed, n_fact_sq, N)
    return x_0

def eval(x_curr, sk_i, N):
    """Return one node's partial evaluation: x_curr raised to its share sk_i, mod N.

    Note: the name shadows Python's builtin ``eval`` within this module.
    """
    return pow(x_curr, sk_i, N)

def combine(x_next_array, selected_indices, N, n, n_fact):
    """Combine the partial evaluations of the selected nodes into one value mod N.

    Computes  prod_{i in selected_indices} x_next_array[i] ** c_i  (mod N), where
    c_i = lagrange_basis_polynomial(i, 0, selected_indices, n_fact).
    NOTE(review): judging by the original variable name ``n_fact_times_L_0``,
    c_i is presumably n! times the Lagrange basis polynomial L_i evaluated at 0
    (an integer, possibly negative) -- confirm in utils. A negative exponent
    requires x_next_array[i] to be invertible mod N (Python >= 3.8 pow).

    Args:
        x_next_array: Mapping from node index to that node's partial evaluation.
        selected_indices: Node indices whose evaluations are combined.
        N: Public modulus.
        n: Total number of nodes; unused, kept for interface compatibility.
        n_fact: n!, forwarded to ``lagrange_basis_polynomial``.

    Returns:
        The combined value as an int in [0, N).
    """
    # Start from 1 % N so the empty selection matches the original (1 mod N).
    x_next = 1 % N
    # Only the selected nodes contribute, so only their coefficients are
    # computed. Reducing mod N each step keeps intermediates small.
    for i in selected_indices:
        exponent = lagrange_basis_polynomial(i, 0, selected_indices, n_fact)
        x_next = (x_next * pow(x_next_array[i], exponent, N)) % N
    return x_next


In [5]:
######## Test run ########
# One round end to end: deal shares, let every node evaluate on x_0, combine
# a random t-subset, and check that raising the result to s recovers x_0.
n = 10  # number of nodes
t = 8   # threshold: shares combined per round
n_fact = math.factorial(n)
N, phi, s, sk = setup(n, t)
x_0 = gen(n_fact, N)

# partial evaluation from every node, keyed by node index 1..n
x_1_array = {node: eval(x_0, sk[node], N) for node in range(1, n + 1)}

# t distinct node indices drawn at random (TODO: received from different nodes)
selected_indices = random.sample(range(1, n + 1), t)

x_1 = combine(x_1_array, selected_indices, N, n, n_fact)

# the combined output must map back to x_0 under the public exponent s
assert pow(x_1, s, N) == x_0

The Germain prime is: 8243
The Germain prime is: 262193


In [6]:
# Chain further rounds: each round's combined output becomes the next round's
# input, and every round must satisfy x_curr == x_next^s (mod N).
x_curr = x_1
for i in range(1, 10):
    # every node evaluates its share on the current value
    x_next_array = {node: eval(x_curr, sk[node], N) for node in range(1, n + 1)}
    # a fresh random t-subset of nodes combines this round
    selected_indices = random.sample(range(1, n + 1), t)
    x_next = combine(x_next_array, selected_indices, N, n, n_fact)
    assert pow(x_next, s, N) == x_curr
    print(f"Passed iteration: {i}")
    x_curr = x_next

Passed iteration: 1
Passed iteration: 2
Passed iteration: 3
Passed iteration: 4
Passed iteration: 5
Passed iteration: 6
Passed iteration: 7
Passed iteration: 8
Passed iteration: 9
