In [1]:
from tinysmpc import VirtualMachine, PrivateScalar, SharedScalar
from dpf import NaiveDPF
import galois

# Optimised Private 2 Party State Machine

The protocol that will be described in this notebook is an optimised version of the [Naive Private State Machine](./naive_private_state_machine.ipynb). The optimisation centers around using a [distributed point function](https://en.wikipedia.org/wiki/Distributed_point_function) (DPF) to reduce the number of communication rounds.

Although we are using a DPF the same privacy properties are maintained. Therefore for this optimised protocol "privacy" is defined as:
- The **servers** do not learn the **input** of the client
- The **servers and client** do not know the **current state** of the machine (apart from the initial state)

## General Protocol Overview

### Preprocessing: Convert Regex to Arithmetic Circuit
1. Regex is converted into a [DFA](https://en.wikipedia.org/wiki/Deterministic_finite_automaton) (state machine)
2. DFA is converted into a polynomial equation over a finite field (or ring)
3. 2D polynomial describing the DFA is converted into a 1D polynomial for each input token
4. Polynomial converted into arithmetic circuit
5. Arithmetic circuit is given each server

<!-- TODO: add diagram -->

### Evaluation: Client Inputs Token
1. Client inputs a token (e.g. a character)
2. The token is converted into a distributed point function and keys are distributed to servers ($k_0$ shared with $server_0$ and $k_1$ shared with $server_1$)
3. Servers evaluate all arithmetic circuits using their own $DFP.Eval(k, i)$ as well as their share of the current state
4. If more input go to step 1 
5. If evaluation is complete combine the results from the servers to get the final state

<!-- TODO: add diagram -->

## STEP 3: Convert 2D Polynomial into 1D Polynomial for Each Input Token

$$f(x, y) = 2 + 1y + 10922x + 54602xy + 65520x^2 + 65518x^2y + 54601x^3 + 10921x^3y$$

$f_i(x)$ is the state transition function where $i$ is the input token. 

Essentially we choose the specific polynomial depending on the input token.

### State Transition for 'a'

For input 'b' we have the following state transition table:

| **STATE** | **NEXT STATE** |
| --------- | -------------- |
| 0         | 2              |
| 1         | 3              |
| 2         | 3              |
| 3         | 3              |

Converting into a polynomial we have:

$$f_0(x) = 2 + 10922x + 65520x^2 + 54601x^3$$

### State Transition for 'b'

For input 'b' we have the following state transition table:

| **STATE** | **NEXT STATE** |
| --------- | -------------- |
| 0         | 3              |
| 1         | 3              |
| 2         | 1              |
| 3         | 3              |

Converting into a polynomial we have:

$$f_1(x) = 3 + 3x + 65517x^2 + x^3$$


<!-- 
TODO: add

$$DPF.Eval(k_0, {i'}) * f_{i'}(x_0) + DPF.Eval(k_0, i) * f_i(x_0)$$
$$DPF.Eval(k_1, {i'}) * f_{i'}(x_1) + DPF.Eval(k_1, i) * f_i(x_1)$$

$$0 * f_{i'}(x) + 1 * f_i(x) = f_i(x)$$ -->


In [2]:
PRIME = 65521
GF = galois.GF(PRIME)


def tf_0(s: int):
    """
    Transition function for the regex example 'ab' over a Galois ring where p = 65521. 
    This function is only for the input 'a'.

    (STATE, 'a') -> NEXT STATE

    Args:
        s : The current state
    """
    return 2 + 10922*s + 65520*(s**2) + 54601*(s**3)


def tf_1(s: int):
    """
    Transition function for the regex example 'ab' over a Galois ring where p = 65521. 
    This function is only for the input 'b'.
    
    (STATE, 'b') -> NEXT STATE

    Args:
        s : The current state
    """
    return 3 + 3*s + 65517*(s**2) + (s**3)


def tf(s: int, t: int):
    """
    Transition function for the regex example 'ab' over a Galois ring where p = 65521. 
    This function is for both inputs 'a' and 'b'.

    (STATE, 'a') -> NEXT STATE
    (STATE, 'b') -> NEXT STATE

    Args:
        s : The current state
        t : The input token (0 for 'a' and 1 for 'b')
    """
    if t == 0: return tf_0(s)
    elif t == 1: return tf_1(s)


def token_to_int(token: str):
    """Converts a token to an integer so that it can be read by the state machine
    
    Args:
        token : The token to convert
    """
    if token == 'a':
        return 0
    elif token == 'b':
        return 1
    else:
        raise ValueError('Invalid token')


def eval_state_machine(token, shared_state, tf_list, PRIME, alice, bob, charlie):
    """Evaluates the state machine using the given transition function
    
    Args:
        token(str) : The token to evaluate
        shared_state : The current shared state of the state machine
        tf_list : List of state transition functions where the index is the input token
        PRIME : The prime used for the Galois ring
        alice : Server 0
        bob : Server 1
        charlie : The client sending the token
    """
    
    DPF = NaiveDPF(galois.GF(PRIME), len(tf_list))
    
    # ASSUMPTION: These keys are sent to the servers Alice and Bob
    k_0, k_1 = DPF.gen_keys(
        x=token_to_int(token), 
        y=GF(1),
    )
    
    shared_ns = PrivateScalar(0, alice).share([alice, bob], PRIME) # next state
    for i in range(len(tf_list)):
        # Select transition function from list of transition functions
        tf_i = tf_list[i]

        # Evaluate the transition function using the previous shared state
        # TODO: when using a ring of 2^n elements, this can be sped up by precomputing x, x^2, x^3, ..., x^n
        # TODO: and then just using the precomputed values
        shared_ns_i = tf_i(shared_state)

        # Evaluate the DPF for input token `i`
        eval_dpf_bob = PrivateScalar(int(DPF.eval_key(k_0, i)), bob)
        eval_dpf_alice = PrivateScalar(int(DPF.eval_key(k_1, i)), bob)

        # Share the DPF evaluation with the servers
        share_eval_dpf_alice = eval_dpf_alice.share([alice, bob], PRIME)
        share_eval_dpf_bob = eval_dpf_bob.share([alice, bob], PRIME)

        # Add the transition function evaluation to the next state
        # INFO: Due to the DPF evaluation, only the correct transition function will be added
        # From the perspective of both servers it is impossible to determine which transition function was evaluated
        shared_ns += (share_eval_dpf_alice * shared_ns_i) + (share_eval_dpf_bob * shared_ns_i)
    
    return shared_ns

In [3]:
alice = VirtualMachine('alice')     # server 0
bob = VirtualMachine('bob')         # server 1
charlie = VirtualMachine('charlie') # client

shared_state = PrivateScalar(0, alice).share([alice, bob], PRIME) # initial state

# Evaluate the state machine on the input 'aaa'
shared_state = eval_state_machine('a', shared_state, [tf_0, tf_1], PRIME, alice, bob, charlie)
shared_state = eval_state_machine('a', shared_state, [tf_0, tf_1], PRIME, alice, bob, charlie)
shared_state = eval_state_machine('a', shared_state, [tf_0, tf_1], PRIME, alice, bob, charlie)

print(f"Reconstructed state from shares: {shared_state.reconstruct(alice)}") # reconstruct state from shares

self.value - mul 56742
other_value - mul 10922
self.value - mul 8779
other_value - mul 10922
self.value 38506
other_value 2
self.value - mul 34655
other_value - mul -1
self.value - mul 33674
other_value - mul -1
self.value 56742
other_value 30866
self.value 8779
other_value 31847
self.value - mul 1946
other_value - mul -1
self.value - mul 51589
other_value - mul -1
self.value 56742
other_value 63575
self.value 8779
other_value 13932
self.value 22087
other_value 0
self.value 22087
other_value 40626
self.value 54796
other_value 0
self.value 54796
other_value 22711
self.value - mul 1946
other_value - mul 62713
self.value - mul 51589
other_value - mul 62713
self.value 2278
other_value 39396
self.value 18828
other_value 5019
self.value - mul 34655
other_value - mul 11986
self.value - mul 33674
other_value - mul 11986
self.value 41674
other_value 37211
self.value 23847
other_value 7204
self.value 13364
other_value 751678018


ValueError: GF(65521) scalars must be in `0 <= x < 65521`, not 751678018.

In [None]:
state = 0 # initial state

state = tf(state, token_to_int('a')) % PRIME
state = tf(state, token_to_int('a')) % PRIME
state = tf(state, token_to_int('a')) % PRIME

print(f"State from direct evaluation: {state}")
print("State reconstructed from shares matches state from direct evaluation:", shared_state.reconstruct(alice).value == state)

State from direct evaluation: 3
State reconstructed from shares matches state from direct evaluation: True
