In [1]:
%matplotlib inline
import sys
import numpy as np
import pandas as pd
import random
import galois  

# Field modulus: a toy-sized prime (real deployments use a much larger, ~256-bit prime).
Q = 7883
# The prime field GF(Q) as a galois array class (Q**1 == Q, i.e. a prime field, not an extension field).
GF = galois.GF(Q)

def get_multiplicative_inverse(x, modulus=None):
    '''Return the multiplicative inverse of x modulo the field modulus.

    x: integer to invert. It is reduced modulo the modulus first, so negative
       values and values >= modulus are accepted.
    modulus: field modulus; defaults to the global Q.

    Returns the unique y in [0, modulus) with (x * y) % modulus == 1.
    Raises ValueError if x has no inverse (e.g. x is congruent to 0).
    '''
    m = Q if modulus is None else modulus
    # Three-argument pow (Python 3.8+) runs extended Euclid and, unlike the raw
    # Bezout coefficient of an egcd, always returns a canonical value in [0, m).
    # It raises ValueError instead of silently returning 0 for x == 0 (mod m).
    return pow(x % m, -1, m)

# Additive secret sharing of an integer

In [2]:
def additive_share(secret, N, modulus=None):
    '''Split an integer secret into N additive shares modulo the field modulus.

    The first N-1 shares are drawn uniformly from [0, modulus). The last share
    is chosen so that all N shares sum to the secret modulo the modulus.

    secret: integer to share (only its value modulo the modulus matters).
    N: number of shares; must be at least 1.
    modulus: field modulus; defaults to the global Q.

    Returns a list of N integers in [0, modulus).
    Raises ValueError if N < 1.

    Note: `random` is not a cryptographically secure generator. That is fine
    for this demo, but real shares should come from e.g. `secrets.randbelow`.
    '''
    if N < 1:
        raise ValueError('N must be at least 1, got {}'.format(N))
    m = Q if modulus is None else modulus
    shares = [random.randrange(m) for _ in range(N - 1)]
    # The last share absorbs the difference so the total is `secret` mod m.
    shares.append((secret - sum(shares)) % m)
    return shares


def additive_reconstruct(shares, modulus=None):
    '''Recover a secret from all of its additive shares.

    shares: iterable of integer shares, e.g. the output of additive_share.
    modulus: field modulus; defaults to the global Q.

    Returns the sum of the shares reduced modulo the modulus (0 for no shares).
    '''
    m = Q if modulus is None else modulus
    return sum(shares) % m


In [3]:
# Split the secret 65 into 5 additive shares, then add all shares up (mod Q)
# to get 65 back.
shares = additive_share(65, 5)
print('5 Shares for secret 65: ', shares)
print('Reconstructing secret from shares by adding them up, sum = ', additive_reconstruct(shares))


5 Shares for secret 65:  [1564, 6859, 6566, 1657, 7068]
Reconstructing secret from shares by adding them up, sum =  65


# Additive secret sharing of vector

In [4]:
def packed_additive_share(secrets, N):
    '''Vector version of additive secret sharing.

    Each entry of `secrets` is shared independently with additive_share. The
    result is regrouped per party: share vector i holds party i's share of
    every entry.

    secrets: sequence of integers to share.
    N: number of parties, i.e. number of share vectors returned.

    Returns a list of N lists, each of length len(secrets). Summing them
    element-wise modulo Q gives back `secrets`.
    '''
    per_secret = [additive_share(secret, N) for secret in secrets]
    if not per_secret:
        # Nothing to transpose, but every party still gets an (empty) share vector.
        return [[] for _ in range(N)]
    # Transpose rows-per-secret into rows-per-party with plain Python lists,
    # so the shares stay exact Python ints (no numpy dtype coercion).
    return [list(column) for column in zip(*per_secret)]


def packed_additive_reconstruct(shares, modulus=None):
    '''Recover a secret vector from the parties' share vectors.

    shares: N share vectors of equal length (e.g. the output of
        packed_additive_share), as a list of lists or a 2-D array.
    modulus: field modulus; defaults to the global Q.

    Returns a list of Python ints: the element-wise sum of the share vectors
    reduced modulo the modulus. Returns [] when `shares` is empty.
    '''
    m = Q if modulus is None else modulus
    # Sum each column with exact Python integers: a numpy int64 sum silently
    # overflows once N * modulus exceeds 2**63.
    return [sum(int(v) for v in column) % m for column in zip(*shares)]

In [5]:
# Share a length-4 vector among 4 parties. Each party receives one share vector,
# and summing the share vectors element-wise (mod Q) recovers the original vector.
secret_vector = [11, 2, 1, 112]
num_shares = 4
print('Splitting a vector {} into {} share vectors\n'.format(secret_vector, num_shares))
shares = packed_additive_share(secret_vector, num_shares)
for idx, share in enumerate(shares):
    print('\tShare {}\t{}'.format(idx+1, share))
print('\nReconstruction', packed_additive_reconstruct(shares))




Splitting a vector [11, 2, 1, 112] into secrets 4 secrets

	Secret 1	[3752, 3806, 326, 3371]
	Secret 2	[7215, 3629, 6303, 2959]
	Secret 3	[5670, 3359, 2602, 1891]
	Secret 4	[7023, 4974, 6536, 7657]

Reconstruction [11, 2, 1, 112]


# Polynomial helpers: interpolation and evaluation mod Q

In [6]:
def get_coeffiecients(points):
    '''
    Given N points in a finite field, find the coefficients of the polynomial
    of degree at most N-1 that passes through all the points.

    We solve A x = b over GF(Q), where
      - A is the N x N Vandermonde matrix with A[i][k] = x_i**k mod Q,
      - x is the vector of coefficients, constant term first,
      - b is the vector of polynomial values p(x_i).

    A Vandermonde matrix is invertible exactly when the x_i's are distinct (its
    determinant is the product of (x_j - x_i) over i < j). The x_i's must
    therefore be distinct modulo Q, otherwise the solve fails.

    points: sequence of N (x, p(x)) integer pairs.

    Returns a GF(Q) array of N coefficients, constant term first.
    '''
    num_coeffs = len(points)
    A = []
    b = []
    for x, p_x in points:
        # pow(x, k, Q) keeps every intermediate below Q instead of building x**k.
        A.append([pow(int(x), k, Q) for k in range(num_coeffs)])
        # Reduce so negative values or values >= Q are valid field elements.
        b.append(int(p_x) % Q)

    # Solving the system directly avoids forming the inverse and multiplying.
    return np.linalg.solve(GF(A), GF(b))

def evaluate_at_point(coefs, point):
    '''Evaluate a polynomial at `point` modulo Q with Horner's rule.

    Horner's rule needs only n additions and n multiplications for a
    degree-n polynomial.

    coefs : polynomial coefficients, constant term first
    point : value at which to evaluate the polynomial

    Returns p(point) mod Q (0 for an empty coefficient list).
    '''
    value = 0
    # Fold from the highest-degree coefficient down to the constant term.
    for coefficient in reversed(coefs):
        value = (coefficient + point * value) % Q
    return value

def interpolate_at_point(points_values, point):
    '''Lagrange-interpolate a polynomial at a new point, modulo Q.

    points_values: n+1 (x, p(x)) pairs of a degree-n polynomial p; the x's
        must be distinct modulo Q
    point: x-coordinate at which to evaluate p

    Returns p(point) reduced modulo Q.
    '''
    xs, ys = zip(*points_values)
    weights = lagrange_constants_for_point(xs, point)
    total = 0
    for weight, y in zip(weights, ys):
        total += weight * y
    return total % Q


def lagrange_constants_for_point(points, point):
    '''Lagrange basis coefficients L_i(point) modulo Q.

    For x-coordinates x_0..x_n in `points`,
        L_i(point) = prod_{j != i} (x_j - point) / (x_j - x_i)   (mod Q),
    so that p(point) = sum_i L_i(point) * p(x_i) for any polynomial p of
    degree <= n.

    Returns a list with one coefficient per x-coordinate.
    Raises ValueError (from pow) if two x-coordinates coincide modulo Q,
    because the denominator is then not invertible.
    '''
    coefficients = []
    for i, x_i in enumerate(points):
        numerator, denominator = 1, 1
        for j, x_j in enumerate(points):
            if j == i:
                continue
            numerator = (numerator * (x_j - point)) % Q
            denominator = (denominator * (x_j - x_i)) % Q
        coefficients.append((numerator * pow(denominator, -1, Q)) % Q)
    return coefficients

# A client's vote

A vote is a one-hot encoded vector of length M with a 1 at the index of the chosen candidate and 0 everywhere else. Indexing starts at 0.

In [7]:
def encode(x, M):
    '''One-hot encode a vote: the vector voters send to the electoral system.

    These checks need no cryptography. An input of the wrong type or size is
    trivially detected by the servers.

    x: index of the candidate voted for (0-based int)
    M: number of candidates

    Returns a list of M ints with a 1 at index x and 0 everywhere else.
    Raises AssertionError if x is not an int in [0, M). Assert statements are
    skipped when Python runs with -O.
    '''
    # Check the type first, so a non-int x fails with AssertionError like any
    # other invalid input instead of a TypeError from the comparison below.
    assert isinstance(x, int)
    assert 0 <= x < M

    return [1 if i == x else 0 for i in range(M)]


In [8]:
# One-hot encoding of a vote for candidate 2 out of 10 candidates.
# NOTE: this sets the global M; the global-parameters cell further down
# overwrites it with M = 7.
M = 10 # number of candidates
encode(2, M)

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

# The Proof

In [9]:
def honest_client_proof_generation(encoded_vote, M):
    '''Simulate an honest client submitting a vote and K servers validating it.

    The client splits its one-hot vote into additive shares (one share vector
    per server). The servers then run two checks on those shares, mod Q:

    1. SUM TEST: every server sums its share vector; the K local sums must add
       up to 1, i.e. the entries of the vote sum to 1.
    2. BIT TEST: for each candidate entry v the servers evaluate the gate
       v * (v - 1) with Beaver's multiplication protocol. The client supplies
       shares h_i of the gate's true output as help, and the servers check that
         - sum_i h_i              == 0  (so v is 0 or 1), and
         - sum_i (hat_fg_i - h_i) == 0  (the help h matches the real product).

    Reads the module-level globals Q (field modulus) and K (number of servers),
    so K must be defined before this is called.

    encoded_vote: the client's one-hot encoded vote vector (must have length M)
    M: number of candidates

    Returns True if all checks pass. Otherwise prints which check failed and
    returns False.
    '''
    # Constant in the Beaver share formula; the same for every gate and server.
    inverse_K = get_multiplicative_inverse(K)

    def valid0_circuit_for_server(h_i, a_i, b_i, c_i, d, e):
        '''What a single server computes for the gate v * (v - 1).

        Since this circuit is just one multiplication, all polynomials are
        constants.

        h_i : this server's share of the client's claimed gate output
        a_i, b_i, c_i : this server's shares of the Beaver triple (c = a * b)
        d   : opened value (left input - a), summed over all servers
        e   : opened value (right input - b), summed over all servers

        Returns (h_i, hat_fg_i). Summed over all K servers, hat_fg_i gives
            d*e + d*b + e*a + c = (left input) * (right input).
        '''
        hat_fg = d * e * inverse_K + d * b_i + e * a_i + c_i
        return h_i, hat_fg % Q

    # A vector of the wrong length is rejected outright (no crypto needed).
    # Otherwise entries past index M-1 would never be bit-checked.
    if len(encoded_vote) != M:
        print('Vote has {} entries, expected {}'.format(len(encoded_vote), M))
        return False

    #-------------------------------------------------------------
    # This check needs no further proof apart from input shares.
    print('Checking if all input bits sum to 1')
    input_shares = packed_additive_share(encoded_vote, K)
    server_sums = [sum(share_vector) % Q for share_vector in input_shares]
    if sum(server_sums) % Q != 1:
        print('SUM TEST failed')
        return False
    print('SUM TEST PASSED\n')

    #-------------------------------------------------------------
    print('Checking if all input bits are 0 or 1')
    for candidate_idx in range(M):
        vote_bit = encoded_vote[candidate_idx]

        # A Beaver triple may be used for ONE multiplication only. Reusing the
        # same (a, b) for every candidate would let the servers compare the
        # opened d = v - a across candidates and read off which entry is 1.
        a, b = random.randrange(Q), random.randrange(Q)
        c = (a * b) % Q
        a_shares = additive_share(a, K)
        b_shares = additive_share(b, K)
        c_shares = additive_share(c, K)

        # Gate inputs come from the same shares the SUM TEST used, so both
        # checks are about the same shared vote. Subtracting the public
        # constant 1 only has to happen on one server's share.
        left_input_shares = [share_vector[candidate_idx] for share_vector in input_shares]
        right_input_shares = [(share - 1) % Q if server_idx == 0 else share
                              for server_idx, share in enumerate(left_input_shares)]

        # The servers broadcast d_i = left_i - a_i and e_i = right_i - b_i and
        # add them up to open d and e (masked by the fresh random a and b).
        d = sum(l_i - a_i for l_i, a_i in zip(left_input_shares, a_shares)) % Q
        e = sum(r_i - b_i for r_i, b_i in zip(right_input_shares, b_shares)) % Q
        # Sanity check to see d and e are what they are supposed to be
        assert (vote_bit - a) % Q == d
        assert (vote_bit - 1 - b) % Q == e

        # The client's help: shares of the true output of the gate v * (v - 1).
        h_shares = additive_share((vote_bit * (vote_bit - 1)) % Q, K)

        output_total = 0
        sigma_total = 0
        for server_idx in range(K):
            h_i, hat_fg_i = valid0_circuit_for_server(h_shares[server_idx],
                                                      a_shares[server_idx],
                                                      b_shares[server_idx],
                                                      c_shares[server_idx],
                                                      d,
                                                      e)
            output_total += h_i
            sigma_total += hat_fg_i - h_i

        if output_total % Q != 0:
            print("Servers output claims candidate {}'s vote was illegal".format(candidate_idx))
            return False
        if sigma_total % Q != 0:
            print("Servers output claims candidate {}'s proof to show it is 1 or 0 was fudged by client".format(candidate_idx))
            return False
    print('Bit check passed')

    print('Severs claim input is clean')
    return True
    
    

In [10]:
# Global parameters
# M: number of candidates (passed explicitly to encode and the proof functions).
# K: number of servers; honest_/dishonest_client_proof_generation read it as a
# global, so this cell must run before they are called.
M = 7
K = 4


In [11]:
# Legal votes: every one-hot vote must be accepted by the servers.
for x in range(0, M):
    encoded_vote = encode(x, M)
    print(encoded_vote)
    accepted = honest_client_proof_generation(encoded_vote, M)
    # Check the verdict instead of discarding it, so a wrongly rejected legal
    # vote cannot hide in the printed log.
    assert accepted, 'legal vote {} was rejected'.format(encoded_vote)
    print('='*60)

[1, 0, 0, 0, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 1, 0, 0, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 0, 1, 0, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 0, 0, 1, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 0, 0, 0, 1, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 0, 0, 0, 0, 1, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Bit check passed
Severs claim input is clean
[0, 0, 0, 0, 0, 0, 1]
Checking if all input bits sum

In [12]:
# Ballot stuffing : voting for 2 candidates
# Start from a legal vote for candidate 1, then also set candidate 0's entry to 1.
# The entries now sum to 2, so the SUM TEST alone rejects the vote.
x = 1
encoded_vote = encode(x, M)
encoded_vote[0] = 1
print(encoded_vote)
print('vote acceppted: {}'.format(honest_client_proof_generation(encoded_vote, M)))


[1, 1, 0, 0, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST failed
vote acceppted: False


In [13]:
# Ballot stuffing: give candidate 0 the value Q-1 and candidate 1 the value 2.
# The entries still sum to 1 mod Q ((Q-1) + 2 = Q + 1), so the SUM TEST passes.
# Only the BIT TEST (each entry must be 0 or 1) catches this vote.
x = 0
encoded_vote = encode(x, M)
encoded_vote[0] = Q-1
encoded_vote[1] = 2
print(encoded_vote)
print('vote acceppted: {}'.format(honest_client_proof_generation(encoded_vote, M)))


[7882, 2, 0, 0, 0, 0, 0]
Checking if all input bits sum to 1
SUM TEST PASSED

Checking if all input bits are 0 or 1
Servers output claims candidate 0's vote was illegal
vote acceppted: False


# Bad proofs

We perturb the client's proof (its shares of the multiplication gate's output) and show that the servers catch it with extremely high probability.

In [14]:
def dishonest_client_proof_generation(encoded_vote, M, debug=True):
    '''Run the same validation as honest_client_proof_generation, but with a
    client that fudges its help, and without printing anything.

    The client adds a uniformly random offset to every share h_i of the claimed
    gate output v * (v - 1). The offsets sum to 0 mod Q only with probability
    1/Q, so the servers reject with overwhelming probability.

    Reads the module-level globals Q (field modulus) and K (number of servers).

    encoded_vote: the client's one-hot encoded vote vector (must have length M)
    M: number of candidates
    debug: accepted for interface compatibility; currently unused.

    Returns True if the servers accept the submission, False otherwise.
    '''
    # Constant in the Beaver share formula; the same for every gate and server.
    inverse_K = get_multiplicative_inverse(K)

    def valid0_circuit_for_server(h_i, a_i, b_i, c_i, d, e):
        '''What a single server computes for the gate v * (v - 1).

        h_i : this server's share of the client's claimed gate output
        a_i, b_i, c_i : this server's shares of the Beaver triple (c = a * b)
        d   : opened value (left input - a), summed over all servers
        e   : opened value (right input - b), summed over all servers

        Returns (h_i, hat_fg_i). Summed over all K servers, hat_fg_i gives
            d*e + d*b + e*a + c = (left input) * (right input).
        '''
        hat_fg = d * e * inverse_K + d * b_i + e * a_i + c_i
        return h_i, hat_fg % Q

    # A vector of the wrong length is rejected outright (no crypto needed).
    if len(encoded_vote) != M:
        return False

    #-------------------------------------------------------------
    # SUM TEST: this check needs no further proof apart from input shares.
    input_shares = packed_additive_share(encoded_vote, K)
    if sum(sum(share_vector) % Q for share_vector in input_shares) % Q != 1:
        return False

    #-------------------------------------------------------------
    # BIT TEST, with a fudged proof for every candidate.
    for candidate_idx in range(M):
        vote_bit = encoded_vote[candidate_idx]

        # A fresh Beaver triple per multiplication: reusing one (a, b) across
        # candidates would let the servers compare the opened d = v - a values
        # and learn the vote.
        a, b = random.randrange(Q), random.randrange(Q)
        c = (a * b) % Q
        a_shares = additive_share(a, K)
        b_shares = additive_share(b, K)
        c_shares = additive_share(c, K)

        # Gate inputs come from the same shares the SUM TEST used. Subtracting
        # the public constant 1 only has to happen on one server's share.
        left_input_shares = [share_vector[candidate_idx] for share_vector in input_shares]
        right_input_shares = [(share - 1) % Q if server_idx == 0 else share
                              for server_idx, share in enumerate(left_input_shares)]

        # The servers open d and e from their broadcast differences.
        d = sum(l_i - a_i for l_i, a_i in zip(left_input_shares, a_shares)) % Q
        e = sum(r_i - b_i for r_i, b_i in zip(right_input_shares, b_shares)) % Q
        # Sanity check to see d and e are what they are supposed to be
        assert (vote_bit - a) % Q == d
        assert (vote_bit - 1 - b) % Q == e

        # The true output of the gate v * (v - 1), shared among the servers.
        h_shares = additive_share((vote_bit * (vote_bit - 1)) % Q, K)

        #-----------------------IMPORTANT----------------------------
        # FUDGING PROOF: add a random offset to every share, reducing the sum
        # (not just the offset) mod Q.
        # Because h also serves as the claimed gate output, a random offset is
        # already caught by the output check. The sigma check catches a client
        # that instead picks offsets summing to -v(v-1) to hide an illegal vote.
        h_shares = [(h_i + random.randrange(Q)) % Q for h_i in h_shares]
        #-----------------------IMPORTANT----------------------------

        output_total = 0
        sigma_total = 0
        for server_idx in range(K):
            h_i, hat_fg_i = valid0_circuit_for_server(h_shares[server_idx],
                                                      a_shares[server_idx],
                                                      b_shares[server_idx],
                                                      c_shares[server_idx],
                                                      d,
                                                      e)
            output_total += h_i
            sigma_total += hat_fg_i - h_i

        if output_total % Q != 0:
            return False
        if sigma_total % Q != 0:
            return False
    return True
    
    

In [16]:
# Legal votes, each submitted with a fudged proof: the servers should reject
# every one of them.
from collections import Counter

NUM_TRIALS = 5000
outputs = []
for trial in range(NUM_TRIALS):
    sys.stdout.write("\r")
    sys.stdout.write("{}/{}".format(trial+1, NUM_TRIALS))
    # randrange excludes its upper bound, so randrange(M) covers every candidate
    # 0..M-1 (randrange(M-1) never voted for the last candidate).
    encoded_vote = encode(random.randrange(M), M)
    outputs.append(dishonest_client_proof_generation(encoded_vote, M))

print()
print("Pass/Test fraction: ", {k: v/NUM_TRIALS for k,v in Counter(outputs).items()})
    

5000/5000
Pass/Test fraction:  {False: 1.0}
