There are two parties G (garbler) and E (evaluator)

main idea:
 1. G constructs a garbled table for each gate in the circuit
 2. G sends the garbled tables to E
 3. E evaluates the whole circuit using garbled tables

How to build a garbled table for a gate:
 1. Build a truth table for the gate
 2. Fill in the wire labels for the gate's inputs and outputs, using the label corresponding to the appropriate values in the truth table
 3. Fill in the garbled table by encrypting the wire label for the output in each row using the key = (label for input 1, label for input 2)
 4. Shuffle the rows of the table and send only the last (encrypted output label) column

How to evaluate a gate in the circuit:
 1. E knows the current labels for the two input wires (but doesn't know the values they represent)
 2. E tries to decrypt each row of the corresponding garbled table using the key = (enc(l1), enc(l2)). Only one of the rows will decrypt successfully; the result is enc(l3), the current label of the gate's output wire
 3. E uses enc(l3) to continue evaluating the circuit

How does the evaluator learn current labels of circuit's input wires?
 - Inputs known to G: G just sends the current label to E
 - Inputs known to E: use oblivious transfer
   - Sender (G) inputs two possible labels for the input wire
   - receiver (E) inputs value of the input wire as a select bit
   - E receives current label for the input wire, G learns nothing

How do we get values of output wires using the current labels of the output wires?
 1. E sends labels of circuit output wires to G
 2. G looks up values corresponding to those labels, and sends them to E

What kind of encryption do we use?
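 1. Must be randomized symmetric-key encryption (so the same key can safely be used more than once)
 2. Must be able to encrypt things using a combination of two keys
   - This implementation encrypts with the key for input 1 and then with the key for input 2, so decryption must apply the key for input 2 first and then the key for input 1.
 3. Must be authenticated, so that decrypting with the wrong keys fails detectably; this is how E recognizes the one row that decrypts correctly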

In [12]:
# Imports and definitions
from collections import defaultdict
import numpy as np
import galois

import nacl.secret
import nacl.utils

# For oblivious Transfer
from nacl.public import PrivateKey, Box, SealedBox

import random # for shuffling lists

import urllib.request

GF_2 = galois.GF(2) # the binary field GF(2), for working with circuits; NOTE(review): not referenced by the code visible here — possibly unused

# Library for circuits
from dataclasses import dataclass

@dataclass
class Gate:
    """One gate of a Bristol-Fashion circuit, as built by parse_circuit."""
    type: str  # gate name: 'XOR', 'AND' or 'INV' (the types parse_circuit accepts)
    in1: int  # wire index of the first input
    in2: int  # wire index of the second input; parse_circuit stores -1 for single-input INV gates
    out: int  # wire index of the output

@dataclass
class Circuit:
    """A parsed boolean circuit.

    The fields are annotated with ``list``. The original annotation was the
    builtin *function* ``any``, which is not a type.
    """
    inputs: list  # one list of wire indices per input bundle (party)
    outputs: list  # one list of wire indices per output bundle
    gates: list  # Gate objects in evaluation order
        
def print_circuit(c):
    """Dump circuit `c` to stdout: its input bundles, its output bundles, then one gate per line."""
    for label, bundles in (('inputs:', c.inputs), ('outputs:', c.outputs)):
        print(label, bundles)
    print('gates:')
    for gate in c.gates:
        # two-space marker plus print's separator gives a three-space indent
        print('  ', gate)

In [13]:
import urllib.request
adder_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/adder64.txt"
# Use a context manager so the HTTP response is closed after reading. Use a
# timeout so a stalled server cannot hang the notebook indefinitely.
with urllib.request.urlopen(adder_url, timeout=30) as response:
    adder_txt = response.read().decode("utf-8")

In [14]:
# Parse a circuit from a Bristol-Fashion specification
def parse_circuit(bristol_fashion_text):
    """Parse Bristol-Fashion circuit text into a Circuit.

    Expected layout (blank lines are ignored):
      line 0: "<num_gates> <num_wires>"
      line 1: "<num_input_bundles> <size_1> ... <size_k>"
      line 2: "<num_output_bundles> <size_1> ... <size_m>"
      then one gate per line, e.g. "2 1 0 64 169 XOR" or "1 1 5 200 INV".

    Input bundles are numbered consecutively from wire 0. Output bundles
    occupy the last wires of the circuit.

    Returns:
        Circuit(inputs, outputs, gates). `inputs` and `outputs` are lists of
        wire-index lists. `gates` is a list of Gate objects, and INV gates get
        in2 = -1.

    Raises:
        RuntimeError: on a gate type other than XOR, AND or INV.
    """
    # Strip *before* filtering, so whitespace-only lines (e.g. a lone '\r' in
    # CRLF files) are dropped instead of being parsed as gates.
    lines = [line.strip() for line in bristol_fashion_text.splitlines()]
    lines = [line for line in lines if line]

    # split() with no argument tolerates repeated spaces and tabs.
    total_wires = int(lines[0].split()[1])
    input_bundle_sizes = [int(x) for x in lines[1].split()[1:]]
    output_bundle_sizes = [int(x) for x in lines[2].split()[1:]]

    # parse the gates
    gates = []
    for g_txt in lines[3:]:
        fields = g_txt.split()
        gate_type = fields[-1]
        if gate_type in ('XOR', 'AND'):
            _, _, in1, in2, out, _ = fields
        elif gate_type == 'INV':
            _, _, in1, out, _ = fields
            in2 = -1  # placeholder: INV has a single input
        else:
            raise RuntimeError('unknown gate type:', gate_type)
        gates.append(Gate(gate_type, int(in1), int(in2), int(out)))

    inputs = _consecutive_bundles(0, input_bundle_sizes)
    # output wires are the highest-numbered wires in the circuit
    outputs = _consecutive_bundles(total_wires - sum(output_bundle_sizes), output_bundle_sizes)
    return Circuit(inputs, outputs, gates)


def _consecutive_bundles(first_wire, bundle_sizes):
    """Split the wires from `first_wire` upward into consecutive bundles of the given sizes."""
    bundles = []
    w = first_wire
    for size in bundle_sizes:
        bundles.append(list(range(w, w + size)))
        w += size
    return bundles

def int_to_bitstring(i, n):
    """Return the binary digits of `i`, least-significant first, zero-padded to at least `n` bits.

    Values that need more than `n` bits are returned in full (not truncated).
    """
    msb_first = format(i, '0b').zfill(n)
    return [int(digit) for digit in msb_first[::-1]]

def bitstring_to_int(bs):
    """Inverse of int_to_bitstring: read `bs` as little-endian bits (element k has weight 2**k)."""
    return sum(int(bit) << position for position, bit in enumerate(bs))

In [15]:
# Parse the downloaded adder circuit once; run_yao in the test cell evaluates it.
adder = parse_circuit(adder_txt)

In [16]:
def eval_circuit(inputs, circuit):
    """Evaluate `circuit` in the clear, as a reference for the garbled protocol.

    Args:
        inputs: one list of bits per input bundle, aligned with `circuit.inputs`.
        circuit: an object with `inputs`, `outputs` (lists of wire-index lists)
            and `gates` (objects with `type`, `in1`, `in2`, `out`).

    Returns:
        One list of output bits per output bundle, aligned with `circuit.outputs`.

    Raises:
        Exception: on a gate type other than XOR, AND or INV.
    """
    wire_vals = {}
    # first set values on input wires
    for input_bundle, input_bundle_vals in zip(circuit.inputs, inputs):
        for wire, val in zip(input_bundle, input_bundle_vals):
            wire_vals[wire] = val

    for g in circuit.gates:
        x1 = wire_vals[g.in1]
        if g.type == 'INV':
            # Unary gate: parse_circuit stores in2 = -1, which is not a real
            # wire, so in2 must not be looked up.
            wire_vals[g.out] = x1 ^ 1
            continue
        x2 = wire_vals[g.in2]
        if g.type == 'XOR':
            wire_vals[g.out] = x1 ^ x2  # python bitwise
        elif g.type == 'AND':
            wire_vals[g.out] = x1 & x2  # another bitwise computation
        else:
            raise Exception('unknown gate', g)

    # get the outputs
    return [[wire_vals[w] for w in output_bundle] for output_bundle in circuit.outputs]

In [17]:
class Party:
    """A participant in a (simulated, in-process) multiparty computation protocol.

    Messages are delivered by appending them to the recipient's ``received``
    dict, keyed by round number.
    """
    def __init__(self):
        """Initialize an empty input/output and an empty per-round inbox."""
        self.input = None
        self.output = None
        # round number -> list of messages received in that round, in arrival order
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Deliver `msg` to party `other` as a message of round `round`.

        Delivery is immediate; the message object is not copied.
        """
        other.received[round].append(msg)

    def get_view(self):
        """Return this party's view as (input, output, {round: [messages]})."""
        return (self.input, self.output, dict(self.received))

In [18]:
# creates a garbled truth table for use in Yao's garbled circuits protocol

# Gate types that can be garbled, mapping a pair of input bits to the output bit.
_GARBLED_GATE_OPS = {
    'XOR': lambda a, b: a ^ b,
    'AND': lambda a, b: a & b,
}


def create_truth_table(gate, wire_keys):
    """Return the shuffled garbled table (four ciphertexts) for `gate`.

    Args:
        gate: a Gate whose type is 'XOR' or 'AND'.
        wire_keys: dict mapping wire index -> (label_for_0, label_for_1); each
            label is used as a SecretBox key.

    Returns:
        A list of four ciphertexts in random order. Row (wi, wj) holds the
        output wire's label for op(wi, wj). It is encrypted first under in1's
        label for wi and then under in2's label for wj, so it opens only with
        that pair of labels.

    Raises:
        ValueError: for any other gate type (e.g. 'INV'). Without this check
            such a gate would yield an empty table and evaluation would fail
            later, far from the cause.
    """
    try:
        op = _GARBLED_GATE_OPS[gate.type]
    except KeyError:
        raise ValueError('cannot garble gate of type %r: %r' % (gate.type, gate)) from None

    in1_keys = wire_keys[gate.in1]
    in2_keys = wire_keys[gate.in2]
    out_keys = wire_keys[gate.out]

    table = []
    for (wi, wj) in [(0, 0), (0, 1), (1, 0), (1, 1)]:
        box1 = nacl.secret.SecretBox(in1_keys[wi])
        box2 = nacl.secret.SecretBox(in2_keys[wj])
        # encrypt the output value's label using the input1 key, then the input2 key
        table.append(box2.encrypt(box1.encrypt(out_keys[op(wi, wj)])))

    # Shuffle so row position leaks nothing about the wire values. Use the OS
    # CSPRNG: random.shuffle's default Mersenne Twister is not cryptographically
    # secure.
    random.SystemRandom().shuffle(table)
    return table

In [19]:
class Garbler(Party):
    """Party G in Yao's protocol: garbles the circuit and decodes the output labels."""

    def __init__(self):
        super().__init__()

    def round1(self, evaluator, circuit, my_inputs):
        """Create two random labels per wire and send E one garbled table per gate.

        Args:
            evaluator: the Evaluator party to send messages to.
            circuit: the Circuit to garble; G's input wires are circuit.inputs[0].
            my_inputs: G's input bits, one per wire in circuit.inputs[0].
        """
        self.evaluator = evaluator
        self.circuit = circuit
        self.my_inputs = my_inputs
        # -1 is the placeholder in2 of single-input gates; it has no labels
        self.wire_keys = {-1: None}

        # Highest wire index. In Bristol Fashion the last output wire is the
        # last wire of the circuit.
        max_wire = self.circuit.outputs[-1][-1]

        # G creates 2 labels (encryption keys) for each wire: (label for 0, label for 1)
        for wire_num in range(max_wire + 1):
            bit0key = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
            bit1key = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
            self.wire_keys[wire_num] = (bit0key, bit1key)

        # One shuffled garbled table per gate, sent in gate order; E pairs
        # tables with gates by that order.
        for gate in self.circuit.gates:
            self.send(evaluator, 1, create_truth_table(gate, self.wire_keys))

    def round2(self):
        """Send E the labels for G's own input bits, plus OT-encrypted labels for E's input wires."""
        # Labels for G's input wires, selected by G's actual bits
        # (wire_keys[w] is (label_for_0, label_for_1), indexed by the bit).
        wire_vals = {-1: None}
        for wire_num, wire_value in zip(self.circuit.inputs[0], self.my_inputs):
            wire_vals[wire_num] = self.wire_keys[wire_num][wire_value]

        # E's only round-1 message:
        # {wire_num: [public_key_for_0, public_key_for_1]} for E's input wires.
        ot_public_keys = self.received[1][0]

        # Oblivious transfer: encrypt label_for_0 under the first public key and
        # label_for_1 under the second. E keeps only the private key for its
        # actual bit (see Evaluator.round1), so it can open just that label.
        encrypted_keys = {-1: None}
        for wire_num in self.circuit.inputs[1]:
            pub_keys = ot_public_keys[wire_num]
            label0, label1 = self.wire_keys[wire_num]
            encrypted_bit0key = SealedBox(pub_keys[0]).encrypt(label0)
            encrypted_bit1key = SealedBox(pub_keys[1]).encrypt(label1)
            encrypted_keys[wire_num] = (encrypted_bit0key, encrypted_bit1key)

        self.send(self.evaluator, 2, wire_vals)
        self.send(self.evaluator, 2, encrypted_keys)

    def round3(self):
        # G is idle while E evaluates the garbled circuit
        pass

    def round4(self):
        """Decode the output-wire labels E sent in round 3 into bits, then send the bits to E."""
        self.output = []
        for wire_num, label in zip(self.circuit.outputs[0], self.received[3]):
            # The label's position in (label_for_0, label_for_1) is the wire's
            # bit. index() raises ValueError if the label matches neither.
            self.output.append(self.wire_keys[wire_num].index(label))

        # Output decoding, step 2: G tells E the values of the output wires.
        self.send(self.evaluator, 4, list(self.output))

class Evaluator(Party):
    """Party E in Yao's protocol: gets its input labels by OT and evaluates the garbled circuit."""

    def __init__(self):
        super().__init__()

    def round1(self, garbler, circuit, my_inputs):
        """Send G two OT public keys per E input wire, keeping the private key only for E's bit.

        Args:
            garbler: the Garbler party to send messages to.
            circuit: the Circuit being evaluated; E's input wires are circuit.inputs[1].
            my_inputs: E's input bits, one per wire in circuit.inputs[1].
        """
        self.garbler = garbler
        self.circuit = circuit
        self.my_inputs = my_inputs

        wire_keys_OT = {-1: None}  # public keys only; this is what G receives
        self.actual_keys = {-1: None}  # private key matching E's bit, per input wire

        for wire_value, wire_num in zip(my_inputs, self.circuit.inputs[1]):
            # my_inputs is expected to be a list of bits (0/1), used as indices
            key_pairs = [PrivateKey.generate() for _ in range(2)]
            wire_keys_OT[wire_num] = [k.public_key for k in key_pairs]
            # E only saves the key that decrypts the label for its actual wire value
            self.actual_keys[wire_num] = key_pairs[wire_value]

        # send the oblivious transfer public keys to the Garbler, which uses
        # them to encrypt both labels of each of E's input wires
        self.send(garbler, 1, wire_keys_OT)

    def round2(self):
        # nothing for the evaluator to do in round 2; it waits for G's round-2 messages
        pass

    def round3(self):
        """Recover the input labels, evaluate every garbled gate, and send the output labels to G.

        Raises:
            RuntimeError: if no row of a gate's garbled table decrypts under the
                current input labels (e.g. an unsupported gate or a corrupted table).
        """
        from nacl.exceptions import CryptoError

        # current label of each wire; -1 is the placeholder in2 of unary gates
        self.wire_labels = {-1: None}

        # G's round-2 messages: plain labels for G's inputs, OT-encrypted label pairs for E's
        garbler_labels, encrypted_evaluator_labels = self.received[2]

        for wire_num in self.circuit.inputs[0]:
            self.wire_labels[wire_num] = garbler_labels[wire_num]

        # Open the label for E's own bit. Only this one is decryptable with the
        # private key E kept.
        for wire_num, wire_value in zip(self.circuit.inputs[1], self.my_inputs):
            encrypted_label = encrypted_evaluator_labels[wire_num][wire_value]
            self.wire_labels[wire_num] = SealedBox(self.actual_keys[wire_num]).decrypt(encrypted_label)

        # self.received[1] holds the garbled tables, one per gate, in gate order.
        # The garbler encrypted with input1's label then input2's label, so
        # decrypt with input2 first.
        for enc_gate, gate in zip(self.received[1], self.circuit.gates):
            box1 = nacl.secret.SecretBox(self.wire_labels[gate.in1])
            box2 = nacl.secret.SecretBox(self.wire_labels[gate.in2])

            # SecretBox is authenticated: a row encrypted under other labels
            # raises CryptoError, so exactly one row opens.
            for enc_row in enc_gate:
                try:
                    out = box1.decrypt(box2.decrypt(enc_row))
                except CryptoError:
                    continue
                break
            else:
                raise RuntimeError('no row of the garbled table decrypted for gate %r' % (gate,))

            self.wire_labels[gate.out] = out

        # All wires now carry labels; send the output-wire labels to G for decoding.
        for output_wire in self.circuit.outputs[0]:
            self.send(self.garbler, 3, self.wire_labels[output_wire])

    def round4(self):
        """Record the decoded output bits if G sent them in round 4."""
        # .get avoids inserting an empty round-4 entry into the defaultdict inbox
        decoded = self.received.get(4)
        if decoded:
            self.output = decoded[0]

In [20]:
# Driver function for the protocol

def run_yao(circuit, p1_input, p1_bitwidth, p2_input, p2_bitwidth):
    """Run the two-party garbled-circuit protocol on `circuit` and return the garbler's output as an int.

    p1 (the Garbler) supplies `p1_input` as `p1_bitwidth` little-endian bits.
    p2 (the Evaluator) supplies `p2_input` as `p2_bitwidth` little-endian bits.
    """
    garbler_bits = int_to_bitstring(p1_input, p1_bitwidth)
    evaluator_bits = int_to_bitstring(p2_input, p2_bitwidth)

    garbler = Garbler()
    evaluator = Evaluator()

    # round 1 wires the parties together and exchanges tables / OT keys
    garbler.round1(evaluator, circuit, garbler_bits)
    evaluator.round1(garbler, circuit, evaluator_bits)

    # rounds 2-4 take no arguments; in every round the garbler acts before the evaluator
    schedule = (
        (garbler.round2, evaluator.round2),
        (garbler.round3, evaluator.round3),
        (garbler.round4, evaluator.round4),
    )
    for garbler_step, evaluator_step in schedule:
        garbler_step()
        evaluator_step()

    return bitstring_to_int(garbler.output)

In [21]:
## ADDER TEST CASE
# End-to-end check: garbled evaluation of the adder must match ordinary addition.
# n1 is drawn first, then n2.
n1, n2 = (np.random.randint(0, 1000) for _ in range(2))

output = run_yao(adder, n1, 64, n2, 64)
assert output == n1 + n2