# CS295/395: Secure Distributed Computation
## In-Class Exercise, week of 9/19/2022

## References

- Overview of the BGW protocol: [Pragmatic MPC, Section 3.3](https://securecomputation.org/docs/pragmaticmpc.pdf)
- Vandermonde Matrices for polynomial evaluation: [Asharov & Lindell, 2011, Section 3.3, Definition 3.6](https://eprint.iacr.org/2011/136.pdf)
- Formal protocol description (GRR protocol): [Lindell & Nof, 2017, Appendix B.3 (Protocol B.3)](https://eprint.iacr.org/2017/816.pdf)

In [1]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import galois
GF = galois.GF(2 ** 13 - 1)  # prime field of order 8191 (2^13 - 1 is a Mersenne prime)

# Library for arithmetic and Boolean circuits
from collections import namedtuple
AddGate = namedtuple('AddGate', ['in1', 'in2'])
MultGate = namedtuple('MultGate', ['in1', 'in2'])

from dataclasses import dataclass
from typing import Any

@dataclass
class Gate:
    type: str
    in1: int
    in2: int
    out: int

@dataclass
class Circuit:
    inputs: Any
    outputs: Any
    gates: Any
        
def print_circuit(c):
    print('inputs:', c.inputs)
    print('outputs:', c.outputs)
    print('gates:')
    for g in c.gates:
        print('  ', g)

In [2]:
class Party:
    """A participant in a multiparty computation protocol."""
    def __init__(self):
        """Initialize the input, the output, and a dictionary mapping each round to the messages received in it."""
        self.input = None
        self.output = None
        self.received = defaultdict(list)
    
    def send(self, other, round, msg):
        """Simulate sending a message `msg` to another party `other` during round `round`"""
        other.received[round].append(msg)

    def get_view(self):
        """Returns the view of this party: its input, output, and received messages."""
        return (self.input, self.output, dict(self.received))

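# Generate n Shamir shares of secret v with threshold t: the shares are the points (x, f(x)),
# x = 1..n, of a random polynomial f of degree t-1 with f(0) = v, so any t shares reconstruct v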
def shamir_share(v, t, n):
    coefficients = GF([GF.Random() for _ in range(t-1)] + [v])
    poly = galois.Poly(coefficients)
    shares = [(GF(x), poly(GF(x))) for x in range(1, n+1)]
    return shares

# Reconstruct the secret from at least t Shamir shares
def reconstruct(shares):
    xs = GF([s[0] for s in shares])
    ys = GF([s[1] for s in shares])
    poly = galois.lagrange_poly(xs, ys)
    #print(poly)
    secret = poly(0)
    
    return secret

In [3]:
n = 5
V_a = GF(np.vander(range(1,n+1), increasing=True))
V_a_inv = np.linalg.inv(V_a)
lambda_js = V_a_inv[0]
lambda_js

GF([   5, 8181,   10, 8186,    1], order=8191)
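Row 0 of $V^{-1}$ holds the Lagrange coefficients for interpolating at $x = 0$: if $V c = y$ maps a coefficient vector $c$ (constant term first) to the evaluations $y_j = f(\alpha_j)$, then $f(0) = c_0 = \sum_j (V^{-1})_{0j}\, y_j$. As a cross-check, the sketch below computes $\lambda_j = \prod_{m \ne j} \frac{\alpha_m}{\alpha_m - \alpha_j}$ directly modulo 8191 in plain Python, independently of `galois`. The helper name `lagrange_at_zero` is hypothetical, and `pow(den, -1, p)` needs Python 3.8 or later.

```python
P = 2 ** 13 - 1  # same prime as the field GF above


def lagrange_at_zero(xs, p):
    """Return lambda_j such that f(0) = sum_j lambda_j * f(xs[j]) mod p
    for every polynomial f of degree < len(xs)."""
    lambdas = []
    for j, x_j in enumerate(xs):
        num, den = 1, 1
        for m, x_m in enumerate(xs):
            if m != j:
                num = num * x_m % p
                den = den * (x_m - x_j) % p
        # pow(den, -1, p) is the modular inverse of den
        lambdas.append(num * pow(den, -1, p) % p)
    return lambdas


lambdas = lagrange_at_zero(list(range(1, 6)), P)
print(lambdas)  # [5, 8181, 10, 8186, 1]

# Interpolating f(x) = 7 + 3x + 2x^2 at 0 from f(1), ..., f(5) recovers f(0) = 7
f = lambda x: (7 + 3 * x + 2 * x * x) % P
assert sum(l * f(x) for l, x in zip(lambdas, range(1, 6))) % P == 7
```

The printed list is $[5, -10, 10, -5, 1]$ reduced mod 8191, matching row 0 of `V_a_inv` above.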

## Question 1

Describe a protocol to multiply two input numbers. The input numbers will be secret-shared according to a $(t,n)$ Shamir secret sharing scheme before the protocol starts, and each party will receive one share of each number. Each party should output *one share of the product*, using a $(t, n)$ Shamir secret sharing scheme (i.e. the threshold for the output should be the same as the threshold for the input).

\begin{equation*}
\textbf{Functionality: Multiply Two Numbers}\\
\fbox{$\mathcal{F}(a, b) = a \cdot b$}
\end{equation*}




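Multiply the Shamir shares locally, then reduce the degree of the resulting sharing.

- R1
  - Each party $P_i$ receives shares $a_i = f_a(\alpha_i)$ and $b_i = f_b(\alpha_i)$ as input, where $f_a$ and $f_b$ are the sharing polynomials of $a$ and $b$.
  - $P_i$ computes $s_i = a_i \cdot b_i$, which is exactly $q(\alpha_i)$ for $q(x) = f_a(x) \cdot f_b(x)$. The degree of $q$ is at most the sum of the degrees of $f_a$ and $f_b$, so the $s_i$ form a sharing of $a \cdot b$ whose threshold is at most $2t$ rather than $t$ (with the degree-$(t-1)$ polynomials that `shamir_share` uses, $\deg q \le 2t - 2$).
  - $P_i$ computes fresh shares $h_i^1, \dots, h_i^n$ = `shamir_share`($s_i$, t, n).
  - $P_i$ sends share $h_i^j$ to party $P_j$.
- R2
  - Each party $P_i$ receives the share $h_j^i$ from every $P_j$ (note the flipped sub- and superscripts: $P_i$ now holds the $i$-th share of every party's $s_j$).
  - $P_i$ computes $\sum_j \lambda_j h_j^i$, where the $\lambda_j$ are the Lagrange coefficients for interpolating at $x = 0$ (row 0 of $V^{-1}$, computed above), and outputs this value as its own share of the original product with threshold $t$. This requires $n \ge 2t - 1$, so that the $n$ points $s_j = q(\alpha_j)$ determine $q$.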
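
The degree reduction in R2 works because the recombined shares lie on a single fresh polynomial. Writing $h_j(x)$ for the degree-$(t-1)$ polynomial that $P_j$ uses to reshare $s_j$ (so $h_j(0) = s_j$ and $h_j^i = h_j(\alpha_i)$), let $H(x) = \sum_{j=1}^{n} \lambda_j\, h_j(x)$. Then

\begin{align*}
\deg H &\le t - 1,\\
H(0) &= \sum_{j=1}^{n} \lambda_j\, h_j(0) = \sum_{j=1}^{n} \lambda_j\, s_j = \sum_{j=1}^{n} \lambda_j\, q(\alpha_j) = q(0) = a \cdot b,\\
H(\alpha_i) &= \sum_{j=1}^{n} \lambda_j\, h_j(\alpha_i) = \sum_{j=1}^{n} \lambda_j\, h_j^i .
\end{align*}

So the value $P_i$ outputs in R2 is $H(\alpha_i)$, a share of $a \cdot b$ on a polynomial of degree $t - 1$. The step $\sum_j \lambda_j\, q(\alpha_j) = q(0)$ is where $n \ge 2t - 1$ is needed: Lagrange interpolation through $n$ points is exact only for polynomials of degree at most $n - 1$.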

## Question 2

Implement your protocol from question 1.

In [4]:
class MultTwoParty(Party):
    def round1(self, parties, a_shr, b_shr, t):
        self.input = (a_shr, b_shr)
        self.parties = parties
        n = len(parties)
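        # the product polynomial has degree 2(t-1), so degree reduction needs n >= 2t - 1;
        # t <= n/2 is a slightly stronger, sufficient condition
        assert t <= n/2

        # - Each party $P_i$ receives shares $a_i$, $b_i$ as input
        # - Let $s_i = a_i * b_i$ - threshold for $s_i$ will be less than or equal to $2t$ where t is the initial threshold
        #     - $s_i$ is exactly $q(\alpha_i)$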

        a_x, a_y = a_shr
        b_x, b_y = b_shr

        # they better have the same x coord
        assert a_x == b_x

        s_i = a_y * b_y #q(x_i) higher degree than we'd like (degree 2t at most)
        self.x_coord = a_x # save this for round 2

        # - $P_i$ computes $h_i^1 ... h_i^n$ = `share`($s_i$, t, n)
        h_i_js = shamir_share(s_i, t, n)
        # - $P_i$ sends share $h_i^j$ to party $j$
        for party, share in zip(self.parties, h_i_js):
            self.send(party, 1, share)

    def round2(self):
        n = len(self.parties)
        
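        # - Each party $P_i$ receives the shares $h_j^i$ (yes the sub and superscripts are flipped)
        # In this simulation received[1][j] comes from party j, because round1 runs for the
        # parties in index order; the Lagrange coefficients below rely on that ordering
        h_j_is = self.received[1]
        h_j_is_y = [s[1] for s in h_j_is]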

        # $P_i$ computes $\sum_j (h_j^i * \lambda_j)$ and outputs this value as its own share of the original product with threshold $t$
        # Lagrange coefficients for interpolating at x = 0: row 0 of V^{-1} for x-coordinates 1..n
        V_a = GF(np.vander(range(1,n+1), increasing=True))
        V_a_inv = np.linalg.inv(V_a)
        lambda_js = V_a_inv[0]

        prods = [h_j_is_y[i] * lambda_js[i] for i in range(n)]

        self.output = self.x_coord, GF(prods).sum()
        return self.output

        

In [5]:
NUM_PARTIES = 6
# (t, n)-Shamir scheme
n = NUM_PARTIES
t = 3

shares1 = shamir_share(5, t, n)
shares2 = shamir_share(6, t, n)

parties = [MultTwoParty() for _ in range(NUM_PARTIES)]

for p,s1,s2 in zip(parties, shares1, shares2):
    p.round1(parties, s1, s2, t)
for p in parties:
    p.round2()
for p in parties:
    print(p.get_view())

output_shares = [p.output for p in parties]
print('Reconstruction, with all shares:', reconstruct(output_shares))
print('Reconstruction, with 3 shares:', reconstruct(output_shares[:3]))
print('Reconstruction, with 2 shares:', reconstruct(output_shares[:2]))

assert reconstruct(output_shares) == 30
assert reconstruct(output_shares[:3]) == 30
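# 2 shares are too few for threshold 3: this check fails only if the random leading
# coefficient of the output sharing happens to be 0 (probability 1/8191)
assert reconstruct(output_shares[:2]) != 30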

(((GF(1, order=8191), GF(6623, order=8191)), (GF(1, order=8191), GF(4358, order=8191))), (GF(1, order=8191), GF(5317, order=8191)), {1: [(GF(1, order=8191), GF(5901, order=8191)), (GF(1, order=8191), GF(3428, order=8191)), (GF(1, order=8191), GF(2461, order=8191)), (GF(1, order=8191), GF(4508, order=8191)), (GF(1, order=8191), GF(3758, order=8191)), (GF(1, order=8191), GF(7390, order=8191))]})
(((GF(2, order=8191), GF(7863, order=8191)), (GF(2, order=8191), GF(7059, order=8191))), (GF(2, order=8191), GF(2032, order=8191)), {1: [(GF(2, order=8191), GF(4117, order=8191)), (GF(2, order=8191), GF(8132, order=8191)), (GF(2, order=8191), GF(1330, order=8191)), (GF(2, order=8191), GF(7483, order=8191)), (GF(2, order=8191), GF(5591, order=8191)), (GF(2, order=8191), GF(4220, order=8191))]})
(((GF(3, order=8191), GF(3725, order=8191)), (GF(3, order=8191), GF(8109, order=8191))), (GF(3, order=8191), GF(6557, order=8191)), {1: [(GF(3, order=8191), GF(789, order=8191)), (GF(3, order=8191), GF(431,

## Question 3

Write a function `sum_circuit` that builds an arithmetic circuit for summing up a set of `n` inputs.

In [10]:
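def sum_circuit(n):
    # wires 0..n-1 carry the inputs; gate outputs get fresh wire numbers n, n+1, ...
    input_wires = list(range(n))
    total = input_wires[0]

    w = n
    gates = []
    # chain the ADD gates: total = ((x0 + x1) + x2) + ...
    for i in input_wires[1:]:
        g = Gate("ADD", total, i, w)
        total = w
        w += 1
        gates.append(g)
    circuit = Circuit(input_wires, [total], gates)
    return circuit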
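
To sanity-check circuits like the one `sum_circuit` builds, it can help to evaluate them in the clear. The sketch below does this over $\mathbb{Z}_p$; it assumes flat input-wire lists with one value per input wire, gates of type `'ADD'` or `'MULT'`, and gates listed so that each gate's inputs are computed before it. The function name `eval_arith_circuit` is hypothetical, and `Gate` and `Circuit` are re-declared with the same fields as above so the block runs on its own.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Gate:       # same fields as the notebook's Gate
    type: str
    in1: int
    in2: int
    out: int


@dataclass
class Circuit:    # same fields as the notebook's Circuit
    inputs: Any
    outputs: Any
    gates: Any


def eval_arith_circuit(circuit, values, p):
    """Evaluate an ADD/MULT circuit mod p; values[k] feeds wire circuit.inputs[k]."""
    wires = dict(zip(circuit.inputs, (v % p for v in values)))
    for g in circuit.gates:  # assumes each gate's inputs are already computed
        if g.type == 'ADD':
            wires[g.out] = (wires[g.in1] + wires[g.in2]) % p
        elif g.type == 'MULT':
            wires[g.out] = (wires[g.in1] * wires[g.in2]) % p
        else:
            raise ValueError(f'unsupported gate type: {g.type}')
    return [wires[w] for w in circuit.outputs]


# (x0 + x1) * x2, built by hand
example = Circuit(inputs=[0, 1, 2], outputs=[4],
                  gates=[Gate('ADD', 0, 1, 3), Gate('MULT', 3, 2, 4)])
print(eval_arith_circuit(example, [3, 4, 5], 8191))  # [35]
```

With the notebook's `sum_circuit`, `eval_arith_circuit(sum_circuit(6), [1, 2, 3, 4, 5, 6], 8191)` should return `[21]`.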

In [11]:
print_circuit(sum_circuit(6))

inputs: [[0], [1], [2], [3], [4], [5]]
outputs: [10]
gates:
   Gate(type='ADD', in1=[0], in2=[1], out=6)
   Gate(type='ADD', in1=6, in2=[2], out=7)
   Gate(type='ADD', in1=7, in2=[3], out=8)
   Gate(type='ADD', in1=8, in2=[4], out=9)
   Gate(type='ADD', in1=9, in2=[5], out=10)
   Gate(type='MULT', in1=10, in2=10, out=11)


In [30]:
# TEST CASE

assert sum_circuit(2) == \
  Circuit(inputs=[0, 1], outputs=[3], gates=[Gate(type='ADD', in1=0, in2=1, out=3)])

AssertionError: 

In [31]:
import urllib.request
adder_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/adder64.txt"
adder_txt = urllib.request.urlopen(adder_url).read().decode("utf-8")

In [32]:
# Parse a circuit from a Bristol Fashion specification
def parse_circuit(bristol_fashion_text):
    # drop blank (including whitespace-only) lines and surrounding whitespace
    lines = [l.strip() for l in bristol_fashion_text.split('\n') if l.strip() != '']
    # header line 1: "<num_gates> <num_wires>"
    total_wires = int(lines[0].split()[1])
    # header lines 2 and 3: "<num_bundles> <size_1> ... <size_k>" for inputs and outputs
    inputs_txt = lines[1]
    outputs_txt = lines[2]
    gates_txt = lines[3:]
    gates = []

    # parse the gates: "<num_in> <num_out> <input wires> <output wire> <TYPE>"
    for g_txt in gates_txt:
        sp = g_txt.split()
        gate_type = sp[-1]
        if gate_type in ['XOR', 'AND']:
            _, _, in1, in2, out, typ = sp
        elif gate_type == 'INV':
            _, _, in1, out, typ = sp
            in2 = -1  # INV has a single input; -1 marks the unused slot
        else:
            raise RuntimeError(f'unknown gate type: {gate_type}')
        gates.append(Gate(typ, int(in1), int(in2), int(out)))

    # generate the bundles of input wires (inputs occupy the first wires)
    input_bundle_sizes = [int(x) for x in inputs_txt.split()[1:]]
    w = 0
    inputs = []
    for bundle_size in input_bundle_sizes:
        inputs.append(list(range(w, w + bundle_size)))
        w += bundle_size

    # generate the bundles of output wires (outputs occupy the last wires)
    output_bundle_sizes = [int(x) for x in outputs_txt.split()[1:]]
    total_output_wires = sum(output_bundle_sizes)
    w = total_wires - total_output_wires
    outputs = []
    for bundle_size in output_bundle_sizes:
        outputs.append(list(range(w, w + bundle_size)))
        w += bundle_size

    return Circuit(inputs, outputs, gates)

In [33]:
print_circuit(parse_circuit(adder_txt))

inputs: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]]
outputs: [[440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503]]
gates:
   Gate(type='XOR', in1=63, in2=127, out=376)
   Gate(type='XOR', in1=62, in2=126, out=375)
   Gate(type='XOR', in1=61

## Question 4

Implement a function `eval_circuit` for evaluating circuits.

In [None]:
def int_to_bitstring(i, n):
    # n-bit little-endian encoding: element k is the coefficient of 2**k
    return [int(x) for x in reversed('{0:0b}'.format(i).zfill(n))]

def bitstring_to_int(bs):
    # inverse of int_to_bitstring
    return sum(x * (2 ** i) for i, x in enumerate(bs))

def eval_circuit(inputs, circuit):
    # YOUR CODE HERE
    raise NotImplementedError()
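
One way `eval_circuit` could work is sketched below, under the assumptions that `inputs` holds one list of bits per input bundle (matching `circuit.inputs`), that each gate's inputs are computed before the gate appears, and that only the `XOR`, `AND` and `INV` gates accepted by `parse_circuit` occur. It is named `eval_circuit_sketch` (a hypothetical name) so it does not replace the stub, it re-declares `Gate` and `Circuit` with the same fields to stay self-contained, and the half-adder circuit is a made-up example.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Gate:       # same fields as the notebook's Gate
    type: str
    in1: int
    in2: int
    out: int


@dataclass
class Circuit:    # same fields as the notebook's Circuit
    inputs: Any
    outputs: Any
    gates: Any


def eval_circuit_sketch(inputs, circuit):
    """Evaluate a Boolean XOR/AND/INV circuit on bundled bit inputs."""
    wires = {}
    # assign each input bit to its wire, bundle by bundle
    for bundle_wires, bundle_bits in zip(circuit.inputs, inputs):
        for w, bit in zip(bundle_wires, bundle_bits):
            wires[w] = bit
    # assumes each gate's inputs are computed before the gate appears
    for g in circuit.gates:
        if g.type == 'XOR':
            wires[g.out] = wires[g.in1] ^ wires[g.in2]
        elif g.type == 'AND':
            wires[g.out] = wires[g.in1] & wires[g.in2]
        elif g.type == 'INV':
            wires[g.out] = 1 - wires[g.in1]  # in2 is unused (-1) for INV
        else:
            raise ValueError(f'unsupported gate type: {g.type}')
    # read each output bundle back out as a list of bits
    return [[wires[w] for w in bundle] for bundle in circuit.outputs]


# Half adder: wire 2 = a XOR b (sum bit), wire 3 = a AND b (carry bit)
half_adder = Circuit(inputs=[[0], [1]], outputs=[[2, 3]],
                     gates=[Gate('XOR', 0, 1, 2), Gate('AND', 0, 1, 3)])
print(eval_circuit_sketch([[1], [1]], half_adder))  # [[0, 1]]
```

Read least-significant bit first, as `bitstring_to_int` does, the output bundle `[0, 1]` encodes $2 = 1 + 1$.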

In [None]:
# TEST CASE
# Example: 5 + 6 = 11
circuit = parse_circuit(adder_txt)
inputs = [int_to_bitstring(5, 64), int_to_bitstring(6, 64)]
outputs = eval_circuit(inputs, circuit)
assert [bitstring_to_int(b) for b in outputs] == [11]

In [None]:
sha256_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/sha256.txt"
sha256_txt = urllib.request.urlopen(sha256_url).read().decode("utf-8")

sha256_circuit = parse_circuit(sha256_txt)

In [None]:
# Example: SHA256 hash of a bunch of 1s
test_inputs = [[1 for x in y] for y in sha256_circuit.inputs]
outputs = eval_circuit(test_inputs, sha256_circuit)
bitstring_to_int(outputs[0])

## Question 5

Sketch the BGW protocol for evaluating an arithmetic or boolean circuit with $n$ parties.

YOUR ANSWER HERE
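
As an illustration of how the pieces above can compose into BGW, here is a plaintext simulation of the semi-honest protocol flow over $\mathbb{Z}_p$ with $p = 8191$. It assumes an arithmetic circuit with flat input wires, one input value per wire, and `'ADD'`/`'MULT'` gates in topological order. Each `'ADD'` gate is computed locally on shares, and each `'MULT'` gate uses the degree reduction from Question 1 (local multiplication, resharing, Lagrange recombination). All function names are hypothetical, every party runs in one process so the sketch shows data flow only, and Python's `random` module is not cryptographically secure (`secrets` would be the choice for real randomness).

```python
import random
from dataclasses import dataclass
from typing import Any

P = 2 ** 13 - 1  # prime field Z_p, same order as GF above


@dataclass
class Gate:       # same fields as the notebook's Gate
    type: str
    in1: int
    in2: int
    out: int


@dataclass
class Circuit:    # same fields as the notebook's Circuit
    inputs: Any
    outputs: Any
    gates: Any


def share(v, t, n):
    """Shamir-share v: random degree-(t-1) polynomial with f(0) = v, evaluated at x = 1..n."""
    coeffs = [v % P] + [random.randrange(P) for _ in range(t - 1)]
    return [sum(c * pow(x, k, P) for k, c in enumerate(coeffs)) % P for x in range(1, n + 1)]


def lagrange_at_zero(xs):
    """Coefficients lambda_j with f(0) = sum_j lambda_j f(xs[j]) for deg f < len(xs)."""
    lams = []
    for j, xj in enumerate(xs):
        num = den = 1
        for m, xm in enumerate(xs):
            if m != j:
                num, den = num * xm % P, den * (xm - xj) % P
        lams.append(num * pow(den, -1, P) % P)
    return lams


def bgw_eval(circuit, values, t, n):
    """Simulate n parties evaluating an ADD/MULT circuit on Shamir shares."""
    assert 2 * t - 1 <= n, 'degree reduction needs n >= 2t - 1'
    lams = lagrange_at_zero(list(range(1, n + 1)))
    # input phase: shares[w][i] is party i's share of the value on wire w
    shares = {w: share(v, t, n) for w, v in zip(circuit.inputs, values)}
    for g in circuit.gates:
        a, b = shares[g.in1], shares[g.in2]
        if g.type == 'ADD':
            # addition is local: no communication needed
            shares[g.out] = [(x + y) % P for x, y in zip(a, b)]
        elif g.type == 'MULT':
            # R1: each party j multiplies locally (degree 2(t-1)) and reshares the product;
            # resharings[j][i] is the subshare party j sends to party i
            resharings = [share(x * y, t, n) for x, y in zip(a, b)]
            # R2: party i recombines its subshares with the Lagrange coefficients
            shares[g.out] = [sum(lams[j] * resharings[j][i] for j in range(n)) % P
                             for i in range(n)]
        else:
            raise ValueError(f'unsupported gate type: {g.type}')
    # output phase: reconstruct each output wire from the shares of parties 1..t
    lams_t = lagrange_at_zero(list(range(1, t + 1)))
    return [sum(l * s for l, s in zip(lams_t, shares[w][:t])) % P for w in circuit.outputs]


# (x0 + x1) * x2 with 5 parties and threshold 3
example = Circuit(inputs=[0, 1, 2], outputs=[4],
                  gates=[Gate('ADD', 0, 1, 3), Gate('MULT', 3, 2, 4)])
print(bgw_eval(example, [3, 4, 5], t=3, n=5))  # [35]
```

Because each `'MULT'` gate returns a fresh degree-$(t-1)$ sharing, multiplications can be chained; the cost is one round of communication per layer of multiplication gates, while additions are free.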