# CS295/395: Secure Distributed Computation
## Homework 5

## Definitions

In [2]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import urllib.request
import galois
from nacl.public import PrivateKey, Box, SealedBox

# GF(2): addition is XOR and multiplication is AND, so additive shares of a bit are XOR shares.
GF_2 = galois.GF(2) # we work in the binary field this week!

# Library for circuits
from dataclasses import dataclass

@dataclass()
class Gate:
    """One Boolean gate of a parsed Bristol-Fashion circuit, identified by wire ids."""
    type: str  # gate name, e.g. 'XOR', 'AND' or 'INV'
    in1: int  # first input wire id
    in2: int  # second input wire id; parse_circuit stores -1 for single-input INV gates
    out: int  # output wire id

@dataclass
class Circuit:
    """A parsed Boolean circuit: wire-id bundles for inputs/outputs plus its gates."""
    # NOTE: these were annotated with the builtin function `any`, which is not a type.
    inputs: list  # one list of wire ids per input bundle (bundle k belongs to party k+1)
    outputs: list  # one list of wire ids per output bundle
    gates: list  # Gate objects, in evaluation order

class Party:
    """A participant in a simulated multiparty computation protocol.

    Messages are "sent" by appending them directly to the recipient's
    `received` mailbox, keyed by round number.
    """

    def __init__(self):
        """Start with no input/output and an empty mailbox of received messages."""
        self.input = None
        self.output = None
        # round number -> list of messages delivered to this party in that round
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Deliver `msg` into party `other`'s mailbox for round `round`."""
        mailbox = other.received[round]
        mailbox.append(msg)

    def get_view(self):
        """Return this party's view as (input, output, {round: [messages]})."""
        snapshot = dict(self.received)
        return self.input, self.output, snapshot

# Parsing Circuits

In [3]:
import urllib.request
def fetch_text(url, timeout=60):
    """Download `url` and decode the body as UTF-8.

    The response is used as a context manager so the connection is always
    closed, and `timeout` (seconds) keeps a dead server from hanging the notebook.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return response.read().decode("utf-8")

adder_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/adder64.txt"
adder_txt = fetch_text(adder_url)
sha256_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/sha256.txt"
sha256_txt = fetch_text(sha256_url)

In [4]:
# Parse a circuit from a Bristol-Fashion specification
def parse_circuit(bristol_fashion_text):
    """Parse Bristol-Fashion circuit text into a Circuit.

    Layout (blank / whitespace-only lines are ignored):
      line 1: two numbers; the second is the total number of wires
      line 2: "<num_input_bundles> <size_1> ... <size_k>"
      line 3: "<num_output_bundles> <size_1> ... <size_m>"
      rest:   one gate per line:
                "<2 ignored fields> in1 in2 out XOR|AND"
                "<2 ignored fields> in1 out INV"
    Input bundles take the lowest wire ids in order; output bundles take the
    highest wire ids in order. INV gates get in2 = -1.

    Raises RuntimeError for any gate type other than XOR, AND or INV.
    """
    # Strip *before* dropping empty lines, so whitespace-only lines (e.g. a
    # bare '\r' left by CRLF text) never reach the gate parser.
    lines = [line.strip() for line in bristol_fashion_text.splitlines()]
    lines = [line for line in lines if line]

    total_wires = int(lines[0].split()[1])
    input_bundle_sizes = [int(x) for x in lines[1].split()[1:]]
    output_bundle_sizes = [int(x) for x in lines[2].split()[1:]]

    gates = []
    for g_txt in lines[3:]:
        fields = g_txt.split()  # tolerate repeated spaces / tabs
        gate_type = fields[-1]
        if gate_type in ('XOR', 'AND'):
            _, _, in1, in2, out, _ = fields
        elif gate_type == 'INV':
            _, _, in1, out, _ = fields
            in2 = -1  # INV has a single input
        else:
            raise RuntimeError(f'unknown gate type: {gate_type}')
        gates.append(Gate(gate_type, int(in1), int(in2), int(out)))

    # Input bundles occupy wires 0, 1, 2, ... in order.
    inputs = []
    w = 0
    for size in input_bundle_sizes:
        inputs.append(list(range(w, w + size)))
        w += size

    # Output bundles occupy the last sum(output_bundle_sizes) wires, in order.
    outputs = []
    w = total_wires - sum(output_bundle_sizes)
    for size in output_bundle_sizes:
        outputs.append(list(range(w, w + size)))
        w += size

    return Circuit(inputs, outputs, gates)

def int_to_bitstring(i, n):
    """Return the binary digits of non-negative `i` as ints, least-significant bit first.

    The result is zero-padded to at least `n` digits; values needing more than
    `n` bits are not truncated.
    """
    msb_first = format(i, 'b').zfill(n)
    return [int(digit) for digit in msb_first[::-1]]

def bitstring_to_int(bs):
    """Interpret a sequence of bits (least-significant first) as a non-negative integer."""
    return sum(int(bit) << pos for pos, bit in enumerate(bs))

In [5]:
# Parse both downloaded Bristol-Fashion descriptions into Circuit objects.
adder, sha256 = parse_circuit(adder_txt), parse_circuit(sha256_txt)

In [6]:
# AND-gate output value reconstructed from both parties' additive input shares
def S(s1_i, s1_j, s2_i, s2_j):
    """Return (s1_i + s2_i) * (s1_j + s2_j).

    Over GF(2) this rebuilds each input wire from its two shares and ANDs them.
    """
    left = s1_i + s2_i
    right = s1_j + s2_j
    return left * right

# Truth table of P2's output share of an AND gate, for every possible pair of P2 input shares
def T_G(r, s1_i, s1_j):
    """Return [r + S(s1_i, s1_j, a, b)] for (a, b) in (0,0), (0,1), (1,0), (1,1).

    r: P1's random output share for the gate
    s1_i, s1_j: P1's shares of the gate's two input wires
    Entry k is what P2's output share must be when P2's input shares are the k-th pair.
    """
    p2_share_choices = GF_2([(0, 0), (0, 1), (1, 0), (1, 1)])
    return [r + S(s1_i, s1_j, a, b) for a, b in p2_share_choices]

# Question 1

Implement the two-party GMW protocol for Boolean circuits over GF(2).

Reference the following exercise questions:
- The definition of 1-out-of-4 Oblivious Transfer (OT) from the 10/03/2022 exercise
- The definition of the BGW protocol from the 9/26/2022 exercise
- The definition of circuit evaluation from the 9/26/2022 exercise

In [105]:
def make_additive_shares(n, x):
    """Split `x` into `n` additive shares over GF(2).

    The first n - 1 shares are drawn with GF_2.Random; the last is chosen so
    that all n shares sum to `x`.
    n: number of shares to create
    x: value to share
    Returns a GF_2 array of length n.
    """
    random_part = GF_2.Random(n - 1)
    correction = x - random_part.sum()
    return GF_2([*random_part, correction])

class GMW_P1(Party):
    """Party 1 in two-party GMW over GF(2); acts as the OT sender for AND gates.

    `roundn` is called once per round and advances `phase`:
      1. additively share P1's input bits, keep one share, send the other to P2
      2. store the input-wire shares P2 sent in the previous round
      3. evaluate `circuit.gates` in order: XOR/INV take one local round each,
         AND takes three rounds of 1-out-of-4 OT (tracked by `ot_phase`)
      4. send P2 this party's shares of the output wires (and nothing else)
      5. reconstruct the first output bundle from both parties' shares
    """

    def __init__(self):
        super().__init__()
        self.is_done = False
        self.phase = 1
        self.ot_phase = 1
        # wire id -> P1's additive share of that wire. Key -1 presumably mirrors
        # the in2 = -1 placeholder of INV gates; it is never read during evaluation.
        self.wire_vals = {-1: None}
        self.current_gate = 0  # index into circuit.gates of the gate being evaluated

    def roundn(self, round_num, circuit, inputs, p2):
        """Run P1's part of protocol round `round_num`.

        round_num: current round number (run_gmw starts at 1 and increments by 1)
        circuit: the Circuit being evaluated; P1 owns the wires in circuit.inputs[0]
        inputs: P1's input bits (only used in phase 1)
        p2: the other party, to which messages are sent
        Returns the list of reconstructed output bits in phase 5, otherwise None.
        """
        if self.phase == 1:
            self.phase = 2
            self._share_inputs(round_num, circuit, inputs, p2)
        elif self.phase == 2:
            for wire, share in self.received[round_num - 1]:
                self.wire_vals[wire] = share
            # A circuit without gates has nothing to evaluate.
            self.phase = 3 if circuit.gates else 4
        elif self.phase == 3:
            self._eval_step(round_num, circuit, p2)
            if self.current_gate == len(circuit.gates):
                self.phase = 4
        elif self.phase == 4:
            self.phase = 5
            self._send_output_shares(round_num, circuit, p2)
        elif self.phase == 5:
            self.is_done = True
            [other_shares] = self.received[round_num - 1]
            self.output = [self.wire_vals[w] + other_shares[w] for w in circuit.outputs[0]]
            return self.output

    def _share_inputs(self, round_num, circuit, inputs, p2):
        """Split each P1 input bit into two additive shares; keep one, send the other to P2."""
        # Converted here once instead of at the top of every round: evaluation
        # takes at least one round per gate, so per-round conversion adds up.
        for wire, value in zip(circuit.inputs[0], GF_2(inputs)):
            s1, s2 = make_additive_shares(2, value)
            self.wire_vals[wire] = s1
            self.send(p2, round_num, (wire, s2))

    def _eval_step(self, round_num, circuit, p2):
        """Advance evaluation of gate `self.current_gate` by one round."""
        g = circuit.gates[self.current_gate]
        if self.ot_phase == 1:
            if g.type == 'XOR':
                # XOR is linear: each party adds its own shares locally.
                self.wire_vals[g.out] = self.wire_vals[g.in1] + self.wire_vals[g.in2]
                self.current_gate += 1
            elif g.type == 'INV':
                # NOT(x) = x + 1: exactly one party (P1) adds 1 to its share;
                # P2 must copy its share unchanged.
                self.wire_vals[g.out] = self.wire_vals[g.in1] + GF_2(1)
                self.current_gate += 1
            elif g.type == 'AND':
                # P2 sends its four OT public keys this round; P1 answers next round.
                self.ot_phase = 2
            else:
                # Without this, an unknown gate would never advance and the protocol would loop forever.
                raise ValueError(f'unsupported gate type: {g.type}')
        elif self.ot_phase == 2:
            self.ot_phase = 3
            self._send_encrypted_truth_table(round_num, g, p2)
        elif self.ot_phase == 3:
            # P2 decrypts its output share this round; P1 already holds its share r.
            self.ot_phase = 1
            self.current_gate += 1

    def _send_encrypted_truth_table(self, round_num, g, p2):
        """OT sender step for AND gate `g`: encrypt T_G's i-th entry under P2's i-th public key."""
        [pks] = self.received[round_num - 1]
        s1_i = self.wire_vals[g.in1]
        s1_j = self.wire_vals[g.in2]
        # P1's output share is a fresh random bit r; the table gives P2 the
        # complementary share r + (x AND y) for each possible pair of P2 shares.
        r = GF_2.Random()
        self.wire_vals[g.out] = r
        truth_table = T_G(r, s1_i, s1_j)
        encrypted_truth_table = [
            SealedBox(pk).encrypt(int(entry).to_bytes(1, 'little'))
            for pk, entry in zip(pks, truth_table)
        ]
        self.send(p2, round_num, encrypted_truth_table)

    def _send_output_shares(self, round_num, circuit, p2):
        """Send P2 this party's shares of the output wires only."""
        # Sending all of wire_vals would hand P2 P1's kept shares of its own
        # input wires (P2 already holds the other halves), revealing P1's input.
        output_shares = {w: self.wire_vals[w] for bundle in circuit.outputs for w in bundle}
        self.send(p2, round_num, output_shares)
            
                

class GMW_P2(Party):
    """Party 2 in two-party GMW over GF(2); acts as the OT receiver for AND gates.

    `roundn` is called once per round (after P1's) and advances `phase`:
      1. additively share P2's input bits, keep one share, send the other to P1
      2. store the input-wire shares P1 sent in the previous round
      3. evaluate `circuit.gates` in order: XOR/INV take one local round each,
         AND takes three rounds of 1-out-of-4 OT (tracked by `ot_phase`)
      4. send P1 this party's shares of the output wires (and nothing else)
      5. reconstruct the first output bundle from both parties' shares
    """

    def __init__(self):
        super().__init__()
        self.is_done = False
        self.phase = 1
        self.ot_phase = 1
        # wire id -> P2's additive share of that wire. Key -1 presumably mirrors
        # the in2 = -1 placeholder of INV gates; it is never read during evaluation.
        self.wire_vals = {-1: None}
        self.current_gate = 0  # index into circuit.gates of the gate being evaluated

    def roundn(self, round_num, circuit, inputs, p1):
        """Run P2's part of protocol round `round_num`.

        round_num: current round number (run_gmw starts at 1 and increments by 1)
        circuit: the Circuit being evaluated; P2 owns the wires in circuit.inputs[1]
        inputs: P2's input bits (only used in phase 1)
        p1: the other party, to which messages are sent
        Returns the list of reconstructed output bits in phase 5, otherwise None.
        """
        if self.phase == 1:
            self.phase = 2
            self._share_inputs(round_num, circuit, inputs, p1)
        elif self.phase == 2:
            for wire, share in self.received[round_num - 1]:
                self.wire_vals[wire] = share
            # A circuit without gates has nothing to evaluate.
            self.phase = 3 if circuit.gates else 4
        elif self.phase == 3:
            self._eval_step(round_num, circuit, p1)
            if self.current_gate == len(circuit.gates):
                self.phase = 4
        elif self.phase == 4:
            self.phase = 5
            self._send_output_shares(round_num, circuit, p1)
        elif self.phase == 5:
            self.is_done = True
            [other_shares] = self.received[round_num - 1]
            self.output = [self.wire_vals[w] + other_shares[w] for w in circuit.outputs[0]]
            return self.output

    def _share_inputs(self, round_num, circuit, inputs, p1):
        """Split each P2 input bit into two additive shares; keep one, send the other to P1."""
        # Converted here once instead of at the top of every round: evaluation
        # takes at least one round per gate, so per-round conversion adds up.
        for wire, value in zip(circuit.inputs[1], GF_2(inputs)):
            s1, s2 = make_additive_shares(2, value)
            self.wire_vals[wire] = s1
            self.send(p1, round_num, (wire, s2))

    @staticmethod
    def _ot_index(s_i, s_j):
        """Position of the share pair (s_i, s_j) in T_G's order (0,0), (0,1), (1,0), (1,1)."""
        return 2 * int(s_i) + int(s_j)

    def _eval_step(self, round_num, circuit, p1):
        """Advance evaluation of gate `self.current_gate` by one round."""
        g = circuit.gates[self.current_gate]
        if self.ot_phase == 1:
            if g.type == 'XOR':
                # XOR is linear: each party adds its own shares locally.
                self.wire_vals[g.out] = self.wire_vals[g.in1] + self.wire_vals[g.in2]
                self.current_gate += 1
            elif g.type == 'INV':
                # NOT(x) = x + 1 is applied by P1 alone adding 1 to its share. If
                # P2 flipped too, the two +1s would cancel and INV would be identity.
                self.wire_vals[g.out] = self.wire_vals[g.in1]
                self.current_gate += 1
            elif g.type == 'AND':
                self.ot_phase = 2
                self._send_ot_public_keys(round_num, g, p1)
            else:
                # Without this, an unknown gate would never advance and the protocol would loop forever.
                raise ValueError(f'unsupported gate type: {g.type}')
        elif self.ot_phase == 2:
            # P1 builds and sends the encrypted truth table this round.
            self.ot_phase = 3
        elif self.ot_phase == 3:
            self.ot_phase = 1
            self._receive_output_share(round_num, g)
            self.current_gate += 1

    def _send_ot_public_keys(self, round_num, g, p1):
        """OT receiver step 1: send 4 public keys, keeping only the private key at P2's choice index."""
        choice = self._ot_index(self.wire_vals[g.in1], self.wire_vals[g.in2])
        keys = [PrivateKey.generate() for _ in range(4)]
        # Only the private key at `choice` is kept; the other three are dropped,
        # so P2 can decrypt exactly one truth-table entry.
        self.saved_key = keys[choice]
        self.send(p1, round_num, tuple(k.public_key for k in keys))

    def _receive_output_share(self, round_num, g):
        """OT receiver step 2: decrypt the chosen truth-table entry as P2's output share."""
        [ciphertexts] = self.received[round_num - 1]
        choice = self._ot_index(self.wire_vals[g.in1], self.wire_vals[g.in2])
        plaintext = SealedBox(self.saved_key).decrypt(ciphertexts[choice])
        self.wire_vals[g.out] = GF_2(int.from_bytes(plaintext, 'little'))

    def _send_output_shares(self, round_num, circuit, p1):
        """Send P1 this party's shares of the output wires only."""
        # Sending all of wire_vals would hand P1 P2's kept shares of its own
        # input wires (P1 already holds the other halves), revealing P2's input.
        output_shares = {w: self.wire_vals[w] for bundle in circuit.outputs for w in bundle}
        self.send(p1, round_num, output_shares)
              

In [106]:
# Driver function for the protocol
def run_gmw(circuit, p1_input, p1_bitwidth, p2_input, p2_bitwidth):
    """Run two-party GMW on `circuit` and return (P1 output, P2 output) as integers.

    p1_input / p2_input: non-negative integers supplied by P1 and P2
    p1_bitwidth / p2_bitwidth: minimum number of bits used to encode each input
    (least-significant bit first, via int_to_bitstring)
    Rounds are numbered from 1; in each round P1 acts before P2.
    """
    p1_bits = int_to_bitstring(p1_input, p1_bitwidth)
    p2_bits = int_to_bitstring(p2_input, p2_bitwidth)
    parties = (GMW_P1(), GMW_P2())
    p1, p2 = parties

    round_num = 1
    while not (p1.is_done or p2.is_done):
        p1.roundn(round_num, circuit, p1_bits, p2)
        p2.roundn(round_num, circuit, p2_bits, p1)
        round_num += 1

    return tuple(bitstring_to_int(party.output) for party in parties)

In [107]:
## ADDER TEST CASE
# Fixed edge cases (zero, the largest 64-bit operand, high-order bits) plus
# random small operands from a seeded generator, so a failure is reproducible.
# No sum below exceeds 2**64 - 1, so no carry-out / wrap-around behavior is assumed.
adder_rng = np.random.default_rng(295)
adder_cases = [(0, 0), (2**64 - 1, 0), (2**40 + 7, 2**32 + 5)]
adder_cases += [(int(a), int(b)) for a, b in adder_rng.integers(0, 1000, size=(10, 2))]
for n1, n2 in adder_cases:
    o1, o2 = run_gmw(adder, n1, 64, n2, 64)
    assert o1 == o2 == n1 + n2, f'Mismatch! Inputs {n1}, {n2}, outputs {o1}, {o2}'

In [108]:
### SHA256 TEST CASE
### Warning: takes about a minute to run
# Hard-coded reference output of the sha256 circuit when P1 supplies 1 (as 512
# bits) and P2 supplies 2 (as 256 bits).
SHA256_EXPECTED_OUTPUT = 62635937818952219496566001010706647480343244544051980721954351996715678910351

o1, o2 = run_gmw(sha256, 1, 512, 2, 256)
assert o1 == o2 == SHA256_EXPECTED_OUTPUT, f'Mismatch! Outputs {o1}, {o2}'