# CS295/395: Secure Distributed Computation
## In-Class Exercise, week of 9/26/2022

In [41]:
# For later this week:
from nacl.public import PrivateKey, Box, SealedBox

# PyNaCl is a library for (traditional) encryption
# It is easiest to install using: `conda install pynacl`
# It can also be installed using: `pip install pynacl`
# but the conda version is more likely to work cleanly.
# See documentation here: https://pynacl.readthedocs.io/en/latest/

In [42]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import numpy as np
import galois
GF = galois.GF(2**13 - 1)

# Library for circuits
from dataclasses import dataclass

@dataclass
class Gate:
    type: str
    in1: int
    in2: int
    out: int

@dataclass
class Circuit:
    inputs: any
    outputs: any
    gates: any
        
def print_circuit(c):
    print('inputs:', c.inputs)
    print('outputs:', c.outputs)
    print('gates:')
    for g in c.gates:
        print('  ', g)

## Party Class and Shamir sharing

In [43]:
class Party:
    """A participant in a multiparty computation protocol."""
    def __init__(self):
        """Initialize the field size and dictionary to hold received messages."""
        self.input = None
        self.output = None
        self.received = defaultdict(list)
    
    def send(self, other, round, msg):
        """Simulate sending a message `msg` to another party `other` during round `round`"""
        other.received[round].append(msg)

    def get_view(self):
        """Returns the view of this party: its input, output, and received messages."""
        return (self.input, self.output, dict(self.received))

# Generate Shamir shares for secret v with threshold t and number of shares n
def shamir_share(v, t, n):
    coefficients = GF([GF.Random() for _ in range(t-1)] + [v])
    poly = galois.Poly(coefficients)
    shares = [(GF(x), poly(GF(x))) for x in range(1, n+1)]
    return shares

# Reconstruct the secret from at least t Shamir shares
def reconstruct(shares):
    xs = GF([s[0] for s in shares])
    ys = GF([s[1] for s in shares])
    poly = galois.lagrange_poly(xs, ys)
    #print(poly)
    secret = poly(0)
    
    return secret

## Question 1

Implement a function `sum_sq_circuit` that returns the sum and the squared sum of a list of numbers.

In [44]:
def sum_sq_circuit(n):

    inputs = [[i] for i in range(n)]
    input_wires = [i for i in range(n)]

    total = input_wires[0]

    w = n
    gates = []
    for i in input_wires[1:]:
        g = Gate("ADD", total, i, w)
        total = w
        w += 1
        gates.append(g)

    squared_sum_wire = w
    gates.append(Gate("MULT", total, total, squared_sum_wire))

    circuit = Circuit(inputs, [total, squared_sum_wire], gates)
    return circuit

print(sum_sq_circuit(6))

Circuit(inputs=[[0], [1], [2], [3], [4], [5]], outputs=[10, 11], gates=[Gate(type='ADD', in1=0, in2=1, out=6), Gate(type='ADD', in1=6, in2=2, out=7), Gate(type='ADD', in1=7, in2=3, out=8), Gate(type='ADD', in1=8, in2=4, out=9), Gate(type='ADD', in1=9, in2=5, out=10), Gate(type='MULT', in1=10, in2=10, out=11)])


## Question 2

Implement a function `eval_circuit` for evaluating circuits.

In [45]:
def eval_circuit(inputs, circuit):
    # this hsould output correct output values for circuit based on correct input values

    # create a dictuonary to store the values of the wires
    wire_vals = {}

    for part_input_values, party_wires in zip(inputs, circuit.inputs):
        for input_value, wire in zip(part_input_values, party_wires):
            wire_vals[wire] = GF(input_value)
    print(wire_vals)

    for gate in circuit.gates:
        v1 = wire_vals[gate.in1]
        v2 = wire_vals[gate.in2]

        if gate.type == "ADD":
            wire_vals[gate.out] = v1 + v2
        elif gate.type == "MULT":
            wire_vals[gate.out] = v1 * v2
        else:
            print("unknown gate type")
    
    return [wire_vals[wire] for wire in circuit.outputs]

    

In [46]:
# TEST CASE
# Example: sum of 0 to 6 = 15
circuit = sum_sq_circuit(6)
inputs = [[i + 10] for i in range(6)]
outputs = eval_circuit(inputs, circuit)
print(outputs)
# assert outputs == [GF(15), GF(225)]

{0: GF(10, order=8191), 1: GF(11, order=8191), 2: GF(12, order=8191), 3: GF(13, order=8191), 4: GF(14, order=8191), 5: GF(15, order=8191)}
[GF(75, order=8191), GF(5625, order=8191)]


## Question 3

Sketch the BGW protocol for evaluating an arithmetic or boolean circuit with $n$ parties.

- wire_vals will now map wire numbers to *one share* of a wire's value
- ADD will work normally
- MULT will require mult and then degree reduction
- Broadcast shares of output wires and reconstruct output values

# Approach 

each party maintinas a dictionary mapping wires

- Round 1 each party P_i generates shamir share of each of its secret inputs, sends one share to each party
- Round 2 each party P_i receives shares of its inputs from each party and initializes the wire_vals dict
- Round n
    - evaluarte the next gate in the circuit
    - if add, add the values of the two inputs and store in the output wire
    - if mult, multiply the values of the two inputs and store in the output wire, perform degree reduction
- #round n+1
    - each party P_i broadcasts its shares of output wires
- round n+2
    - each party P_i reconstructs the output wires from the shares it received

## Question 4

Implement the BGW protocol.

In [47]:
class BGWParty(Party):
    def round1(self, parties, circuit, my_inputs):
        self.parties = parties
        self.is_done = False
        self.circuit = circuit
        n = len(parties)
        t = int(n/2)

        # Round 1 (phase 1): Each party P_i create n Shamir shares of ech of its secret inputs ,
        # and sends one share to each other party
        my_id = parties.index(self)
        my_input_wires = circuit.inputs[my_id]
        # print(f"party num {my_id} my input wires {my_inputs}")

        # input shares will map each party to the shares of my inputs destined for that party
        input_shares = {p: {} for p in parties}

        for wire, value in zip(my_input_wires, my_inputs):
            shares = shamir_share(value, t, n)
            for p, s in zip(parties, shares):
                input_shares[p][wire] = s

        for p in parties:
            self.send(p, 1, input_shares[p])


    def round2(self, my_id):
        self.wire_vals = {}

        # Round 2 (phase 2): EAch party recieves one share for
        # each input wire and initializs the wire_vals dict
        received_shares = self.received[1]

        for received_dict in received_shares:
            for key, value in received_dict.items():
                self.wire_vals[key] = value

        self.phase = 3
        self.current_gate = 0
        self.need_degree_reduction = False

    def roundn(self, round_num):
        n = len(self.parties)
        t = int(n/2)

        if self.need_degree_reduction:
            # finish the degree reduction and
            # update wire_vals
            # - each party $P_i$ receives shares $h_j^i$
            h_j_is = self.received[round_num - 1]
            h_js_is_y = [s[1] for s in h_j_is]

            V_a = GF(np.vander(range(1, n+1), increasing=True))
            V_a_inv = np.linalg.inv(V_a)
            lambda_js = V_a_inv[0] # first row

            prods = [lambda_j * s for lambda_j, s in zip(lambda_js, h_js_is_y)]

            #
            g = self.circuit.gates[self.current_gate]
            self.wire_vals[g.out] = \
                (self.x_coord, GF(prods).sum())

            self.current_gate += 1
            self.need_degree_reduction = False

        if self.current_gate >= len(self.circuit.gates) and self.phase == 3:
            self.phase = 4

        if self.phase == 3:
            # Evaluate the next gate in the circuit
            # If it is an ADD gate, look up the shares of its input sin the dict and add them together,then update the dict to map its output to the resulting share
            # If it is a MULT gate then, look up the shares of its input sin the dict and multiply them together,then perform degree reduction
            g = self.circuit.gates[self.current_gate]


            x1, y1 = self.wire_vals[g.in1] # lookup the value of the first input
            x2, y2 = self.wire_vals[g.in2]
            assert x1 == x2

            if g.type == 'ADD':
                self.wire_vals[g.out] = (x1, y1 + y2)
                self.current_gate += 1

            elif g.type == 'MULT':
                mult_result = y1 * y2
                # remember this value
                # setup the degree reduction
                # next time I enter `roundn` function,
                # finish the degree reduction and
                # update wire_vals
                self.x_coord = x1
                self.need_degree_reduction = True

                h_i_js = shamir_share(mult_result, t, n)
                for party, share, in zip(self.parties, h_i_js):
                    self.send(party, round_num, share)


        elif self.phase == 4:
            # Round k: (phase 4) When all gates have been evaluated
            # Each party P_u broadcasts its shares of output wires
            output_wires = self.circuit.outputs
            output_shares = [self.wire_vals[w] for w in output_wires]

            for p in self.parties:
                self.send(p, round_num, output_shares)

            self.phase = 5

        elif self.phase == 5:
            # Round k+1: (phase 5) each party receives n shares of each output wire value, reconstructs each wire’s actual value, and outputs the values
            received_shares = self.received[round_num - 1]

            output_shares = [ [] for _ in self.circuit.outputs]
            # arrange the shares
            for shares in received_shares:
                # shares received from a single party p_i
                for j, wire_share in enumerate(shares):
                    # this is the share for wire j
                    output_shares[j].append(wire_share)

            # do the reconstruction
            output_vals = []
            for shares in output_shares:
                output_vals.append(reconstruct(shares))

            self.output = output_vals

            self.is_done = True


In [48]:
def run_bgw_protocol():
    NUM_PARTIES = 6
    n = NUM_PARTIES
    
    circuit = sum_sq_circuit(6)
    
    inputs = [[i] for i in range(6)]
    print('Inputs:', inputs)
    parties = [BGWParty() for _ in range(NUM_PARTIES)]
    
    for p, i in zip(parties, inputs):
        p.round1(parties, circuit, i)
    for p in parties:
        p.round2(parties)
    round_num = 3

    while not parties[0].is_done:
        for p in parties:
            p.roundn(round_num)       
        round_num += 1
        
    for p in parties:
        print('Output:', p.output)

    outputs = [p.output for p in parties]
    return outputs

In [49]:
# TEST CASE
outputs = run_bgw_protocol()
for o in outputs:
    assert o == [GF(15), GF(225)]

Inputs: [[0], [1], [2], [3], [4], [5]]
Output: [GF(15, order=8191), GF(225, order=8191)]
Output: [GF(15, order=8191), GF(225, order=8191)]
Output: [GF(15, order=8191), GF(225, order=8191)]
Output: [GF(15, order=8191), GF(225, order=8191)]
Output: [GF(15, order=8191), GF(225, order=8191)]
Output: [GF(15, order=8191), GF(225, order=8191)]


# GMW
- binary circuits
- two parties
- uses oblivious transfer
- uses additive secret sharing
  - Additive homomorphism
  - No multiplicative homomorphism


# OT
- S has two secrets
- R wants to select one to recieve
  - R does not learn the other secret
  - S does not learn which was selected

# Public Key Encryption
- Anyone can encrypt something with the public key
- Need secret key to decrypt
- is this like $f(x)$ is public and $f^{-1}(x)$ is secret

In OT, S encrypts both secrets with two public keys, but R only has one private key and S doesnt know which. R can decrypt only one.
R generates two key-pairs, sends the public keys to the sender, then R "forgets" one of the secret keys.

## Question 5

Describe the 1-out-of-2 *oblivious transfer* (OT) protocol. Reference Section 3.7 in Pragmatic MPC.

Ideal functionality:
- R has a secret selection bit b, S has two secrets X1 and X2
- R recieves x1 if b=0 else x2
- S receives nothing

## Question 6

Why is the oblivious transfer protocol secure against semi-honest adversaries? Why is it not secure against malicious adversaries?

YOUR ANSWER HERE

## Question 7

Implement 1-out-of-2 OT.

In [None]:
class OT_Sender(Party):
    # x1 and x2 are the secrets
    def round1(self, x1, x2, receiver):
        self.x1 = x1
        self.x2 = x2
        self.receiver = receiver

    def round2(self):
        # YOUR CODE HERE
        raise NotImplementedError()
    
    def round3(self):
        pass

class OT_Receiver(Party):
    def round1(self, b, sender):
        self.sender = sender
        self.b = b
        # YOUR CODE HERE
        raise NotImplementedError()
    
    def round2(self):
        pass
    
    def round3(self):
        # YOUR CODE HERE
        raise NotImplementedError()

In [None]:
# TEST CASE
GF_2 = galois.GF(2)

sender = OT_Sender()
receiver = OT_Receiver()

# Round 1
sender.round1(GF_2(0), GF_2(1), receiver)
receiver.round1(GF_2(1), sender)

# Round 2
sender.round2()
receiver.round2()

# Round 3
sender.round3()
output = receiver.round3()

print("Receiver's output:", output)
assert output == 1

## Question 8

Describe 1-out-of-4 OT.

YOUR ANSWER HERE

## Question 9

Describe the GMW protocol for evaluating a binary circuit.

YOUR ANSWER HERE