In [1]:
import string
import torch as th

In [44]:
import syft as sy
import torch as th
hook = sy.TorchHook(th)

from torch import nn, optim

W0711 22:08:51.714949 21564 secure_random.py:22] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow (1.14.0). Fix this by compiling custom ops.
W0711 22:08:51.749859 21564 deprecation_wrapper.py:119] From e:\anaconda3\envs\pysyft\lib\site-packages\tf_encrypted\session.py:28: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [45]:
# Create the three virtual workers in one pass; each one is constructed with
# the shared hook and has the local worker registered on it via add_worker().
bob, alice, secure_worker = (
    sy.VirtualWorker(hook, id=worker_id).add_worker(sy.local_worker)
    for worker_id in ("bob", "alice", "secure_worker")
)

In [2]:
# Two-way lookup tables between characters and integer codes. They start
# empty and are populated by the alphabet loop in a later cell.
char2index, index2char = {}, {}

In [3]:
# Supported alphabet: space, lowercase letters, digits, then punctuation.
# Each character's position in this string becomes its integer code, so the
# space character is code 0.
alphabet = ' ' + string.ascii_lowercase + '0123456789' + string.punctuation
for position, symbol in enumerate(alphabet):
    char2index[symbol] = position
    index2char[position] = symbol

In [55]:
def string2values(str_input, max_len=8):
    """Encode a string as a fixed-length vector of character codes.

    The input is truncated to ``max_len`` characters, lower-cased, and
    right-padded with "." until it is ``max_len`` long. Each character is
    then looked up in the module-level ``char2index`` table.

    Args:
        str_input: Text to encode.
        max_len: Length the text is truncated / padded to.

    Returns:
        A 1-D long tensor holding one code per character.

    Raises:
        KeyError: If a character is missing from ``char2index``.
    """
    # ljust only pads when the text is shorter than max_len, exactly like an
    # explicit "append the missing dots" branch would.
    padded = str_input[:max_len].lower().ljust(max_len, ".")
    codes = [char2index[symbol] for symbol in padded]
    return th.tensor(codes).long()

def values2string(input_values):
    """Decode an iterable of character codes back into a string.

    Each element is converted with ``int()`` and looked up in the
    module-level ``index2char`` table; the characters are concatenated in
    iteration order.

    Args:
        input_values: Iterable of integer-like codes (e.g. a 1-D tensor).

    Returns:
        The decoded string; empty if ``input_values`` is empty.

    Raises:
        KeyError: If a code is missing from ``index2char``.
    """
    return "".join(index2char[int(code)] for code in input_values)

def one_hot(index, length):
    """Return a long vector of ``length`` zeros with a 1 at ``index``.

    Args:
        index: Position to set to 1.
        length: Number of elements in the vector.

    Returns:
        A 1-D long tensor of size ``length``.
    """
    encoded = th.zeros(length, dtype=th.long)
    encoded[index] = 1
    return encoded

def string2one_hot_matrix(str_input, max_len=8):
    """Encode a string as a matrix with one one-hot row per character.

    The input is truncated to ``max_len`` characters, lower-cased, and
    right-padded with "." until it is ``max_len`` long. Every character is
    then turned into a one-hot row whose width is ``len(index2char)``.

    Args:
        str_input: Text to encode.
        max_len: Length the text is truncated / padded to.

    Returns:
        A long tensor with one row per character and ``len(index2char)``
        columns.

    Raises:
        KeyError: If a character is missing from ``char2index``.
    """
    padded = str_input[:max_len].lower().ljust(max_len, ".")
    alphabet_size = len(index2char)

    # Each one-hot vector gets a leading axis so the rows can be concatenated
    # along dim 0.
    rows = [
        one_hot(char2index[symbol], alphabet_size).unsqueeze(0)
        for symbol in padded
    ]
    return th.cat(rows, dim=0)

def strings_equal(str_a, str_b):
    """Return 1 if two one-hot encoded strings are identical, otherwise 0.

    Only multiplication and summation are used, with integer indexing.

    Args:
        str_a: One-hot matrix with one row per character.
        str_b: One-hot matrix of the same shape as ``str_a``.

    Returns:
        A scalar: the product of the per-row matches.
    """
    # Row-wise dot product: 1 where both rows select the same character,
    # 0 otherwise (for one-hot rows).
    per_char = (str_a * str_b).sum(1)

    # Multiply the per-character results together: a single mismatch forces
    # the whole product to 0.
    match = per_char[0]
    for position in range(1, per_char.shape[0]):
        match = match * per_char[position]

    return match

In [56]:
class EncryptedDB():
    """Key-value store whose keys and values are shared across ``owners``.

    Keys are encoded with ``string2one_hot_matrix`` and values with
    ``string2values``; both are split with ``.share(*owners)`` before being
    stored, so only shared tensors are kept in ``keys`` / ``values``.

    Caveats visible from the encoding:
      * "." is the padding character and is stripped from query results, so
        a "." that is part of a stored value is lost as well.
      * A query that matches no key decodes an all-zero vector, i.e. a string
        of ``max_val_len`` copies of ``index2char[0]``.
      * If the same key is added more than once, the matching value vectors
        are summed, so the decoded result is not meaningful.
    """

    def __init__(self, *owners, max_key_len=8, max_val_len=8):
        """Create an empty database.

        Args:
            *owners: Workers that every key and value is shared with.
            max_key_len: Length keys (and queries) are truncated / padded to.
            max_val_len: Length values are truncated / padded to.
        """
        # Store the caller's lengths. They used to be hard-coded to 8, which
        # silently ignored e.g. ``max_val_len=256``.
        self.max_key_len = max_key_len
        self.max_val_len = max_val_len

        self.keys = list()
        self.values = list()
        self.owners = owners

    def add_entry(self, key, value):
        """Encode, share and store one key/value pair.

        Args:
            key: Key string; truncated / padded to ``max_key_len``.
            value: Value string; truncated / padded to ``max_val_len``.
        """
        key = string2one_hot_matrix(key, max_len=self.max_key_len)
        key = key.share(*self.owners)
        self.keys.append(key)

        value = string2values(value, max_len=self.max_val_len)
        value = value.share(*self.owners)
        self.values.append(value)

    def query(self, query_str):
        """Look up ``query_str`` and return the decoded value.

        The query is encoded and shared exactly like a stored key, compared
        against every stored key with ``strings_equal``, and the resulting
        0/1 match is used to mask each value. Only the summed result is
        fetched with ``.get()`` and decoded.

        Args:
            query_str: Key to look up; truncated / padded to ``max_key_len``.

        Returns:
            The decoded value with every "." removed.
        """
        if not self.keys:
            # Empty sum: behave like a query that matched nothing instead of
            # failing with an IndexError on ``self.values[0]``.
            return values2string([0] * self.max_val_len).replace(".", "")

        query_matrix = string2one_hot_matrix(query_str, max_len=self.max_key_len)
        query_matrix = query_matrix.share(*self.owners)

        result = None
        for key, value in zip(self.keys, self.values):
            # key_match is 1 for the matching key and 0 otherwise, so only
            # the matching value contributes to the sum.
            key_match = strings_equal(key, query_matrix)
            masked_value = value * key_match
            result = masked_value if result is None else result + masked_value

        result = result.get()

        return values2string(result).replace(".", "")

In [57]:
# Create the database with bob, alice and secure_worker as the share owners.
# NOTE(review): confirm that EncryptedDB actually honours max_val_len=256.
db = EncryptedDB(bob, alice, secure_worker, max_val_len=256)

In [58]:
# Populate the database with four sample key/value pairs, in order.
for entry_key, entry_value in (
    ("key1", "value1"),
    ("key2", "value2"),
    ("key3", "value3"),
    ("key4", "value4"),
):
    db.add_entry(entry_key, entry_value)

In [62]:
# Look up one of the stored keys; the cell output is the decoded value.
db.query("key2")

'value2'

In [60]:
# Inspect the stored values: the database holds them in shared form
# (see the AdditiveSharingTensor reprs in the output), not as plaintext.
db.values

[(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:33724769132 -> bob:51508208133]
 	-> (Wrapper)>[PointerTensor | me:49456846578 -> alice:94653780467]
 	-> (Wrapper)>[PointerTensor | me:75410666437 -> secure_worker:74827746001]
 	*crypto provider: me*, (Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:79596864977 -> bob:30449414366]
 	-> (Wrapper)>[PointerTensor | me:48761183427 -> alice:13644266682]
 	-> (Wrapper)>[PointerTensor | me:6589754021 -> secure_worker:40661810868]
 	*crypto provider: me*, (Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:52259506507 -> bob:87446705114]
 	-> (Wrapper)>[PointerTensor | me:90713215157 -> alice:6398205814]
 	-> (Wrapper)>[PointerTensor | me:38927723653 -> secure_worker:32849796054]
 	*crypto provider: me*, (Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:74695980674 -> bob:81743105141]
 	-> (Wrapper)>[PointerTensor | me:75855056964 -> alice:62396218578]
 	-> (Wrapper)>