<a href="https://colab.research.google.com/github/anirbansen3027/SecureAndPrivateAI/blob/master/9_encrypted_database.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install syft

Collecting syft
[?25l  Downloading https://files.pythonhosted.org/packages/38/2e/16bdefc78eb089e1efa9704c33b8f76f035a30dc935bedd7cbb22f6dabaa/syft-0.1.21a1-py3-none-any.whl (219kB)
[K     |█▌                              | 10kB 20.2MB/s eta 0:00:01[K     |███                             | 20kB 2.2MB/s eta 0:00:01[K     |████▌                           | 30kB 3.2MB/s eta 0:00:01[K     |██████                          | 40kB 2.1MB/s eta 0:00:01[K     |███████▌                        | 51kB 2.6MB/s eta 0:00:01[K     |█████████                       | 61kB 3.1MB/s eta 0:00:01[K     |██████████▍                     | 71kB 3.6MB/s eta 0:00:01[K     |████████████                    | 81kB 4.1MB/s eta 0:00:01[K     |█████████████▍                  | 92kB 4.6MB/s eta 0:00:01[K     |███████████████                 | 102kB 3.5MB/s eta 0:00:01[K     |████████████████▍               | 112kB 3.5MB/s eta 0:00:01[K     |█████████████████▉              | 122kB 3.5MB/s eta 0:00:

In [0]:
import string
import torch as th
import syft as sy
# Hook PyTorch so torch tensors gain the PySyft methods used below (e.g. .share() and .get()).
hook = sy.TorchHook(th)

W0803 18:31:47.594170 140601560504192 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0803 18:31:47.610126 140601560504192 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [0]:
# One virtual worker per share holder, each registered with the local worker.
bob, alice, secure_worker = (
  sy.VirtualWorker(hook, id=worker_id).add_worker(sy.local_worker)
  for worker_id in ("bob", "alice", "secure_worker")
)

In [0]:
# Character vocabulary: index 0 is the space character, followed by a-z, 0-9 and
# ASCII punctuation. index2char is the inverse mapping of char2index.
# NOTE(review): '.' is in the vocabulary but is also the padding character used by
# the string encoders, so a literal '.' cannot be told apart from padding.
char2index = {
  symbol: position
  for position, symbol in enumerate(" " + string.ascii_lowercase + string.digits + string.punctuation)
}
index2char = {position: symbol for symbol, position in char2index.items()}

In [0]:
def string2values(str_input,max_length=8):
  """Encode a string as a vector of character indices.

  The string is truncated to ``max_length`` characters, lower-cased and
  right-padded with '.' up to ``max_length``. Each character is then mapped
  through ``char2index``.

  Returns:
    A 1-D int64 tensor with one entry per character (``max_length`` entries
    for ordinary ASCII input).
  Raises:
    KeyError: if a character is not in ``char2index``.
  """
  normalised = str_input[:max_length].lower().ljust(max_length, ".")
  return th.tensor([char2index[symbol] for symbol in normalised]).long()

def one_hot(index,length):
  """Return an int64 vector of size ``length`` that is 1 at ``index`` and 0 elsewhere."""
  encoded = th.zeros(length, dtype=th.long)
  encoded[index] = 1
  return encoded

def string2one_hot_matrix(str_input,max_length=8):
  """Encode a string as a one-hot matrix with one row per character.

  The string goes through exactly the same truncate / lower-case / '.'-padding
  normalisation as ``string2values``. This function delegates to it, so the two
  encoders cannot drift apart. Row ``i`` is the one-hot encoding of character
  ``i`` over ``len(char2index)`` classes.

  Returns:
    An int64 tensor of shape ``(n_chars, len(char2index))``. ``n_chars`` is
    ``max_length`` for ordinary ASCII input; an empty normalised string
    yields a ``(0, len(char2index))`` matrix instead of raising.
  Raises:
    KeyError: if a character is not in ``char2index``.
  """
  indices = string2values(str_input, max_length=max_length)
  n_chars = indices.shape[0]
  matrix = th.zeros(n_chars, len(char2index)).long()
  # Put the single 1 of every row at that row's character index.
  matrix[th.arange(n_chars), indices] = 1
  return matrix

def values2string(input_values):
  """Decode an iterable of character indices back into a string via ``index2char``."""
  return "".join(index2char[int(code)] for code in input_values)

def string_equal(str_a,str_b):
  """Compare two encoded strings, returning the product of per-row dot products.

  Assuming both inputs are one-hot character matrices of the same shape (as
  built by ``string2one_hot_matrix``), each row's dot product is 1 when the
  characters agree and 0 otherwise. The result is therefore 1 only if every
  position matches. Only elementwise ``*``, ``sum`` and indexing are used, so
  the same code runs on secret-shared tensors.
  """
  row_matches = (str_a * str_b).sum(1)
  product = row_matches[0]
  for idx in range(1, row_matches.shape[0]):
    product = product * row_matches[idx]
  return product

In [0]:
class EncryptedDB():
  """Toy key-value store whose keys and values are secret-shared across workers.

  Keys are stored as one-hot character matrices and values as character-index
  vectors; both are split across ``owners`` with ``.share()``. A query compares
  the shared query key with every stored key and reconstructs only the
  weighted sum of values.

  Args:
    *owners: workers passed to ``.share()`` for every key, value and query.
    max_key_len: characters that keys and queries are truncated/padded to.
    max_value_len: characters that values are truncated/padded to.
  """
  def __init__(self,*owners,max_key_len=8, max_value_len=8):
    self.max_key_len = max_key_len
    self.max_value_len = max_value_len
    self.keys = list()
    self.values = list()
    self.owners = owners

  def add_entry(self,key,value):
    """Encode ``key``/``value`` with the configured lengths, share them and store them."""
    # Pass the configured lengths explicitly; relying on the helpers' default (8)
    # silently ignored max_key_len / max_value_len.
    key_matrix = string2one_hot_matrix(key, max_length=self.max_key_len)
    self.keys.append(key_matrix.share(*self.owners))

    value_vector = string2values(value, max_length=self.max_value_len)
    self.values.append(value_vector.share(*self.owners))

  def query(self,query_str):
    """Return the value stored under ``query_str``.

    Every stored value is multiplied by its (shared) key-match indicator and
    the products are summed, so only the sum is reconstructed with ``.get()``.
    Caveats: if no key matches, the sum is all zeros (index 0 in every
    position); if several stored keys match, their values are summed.
    Only trailing '.' padding is stripped, so a value that itself ends in '.'
    loses those trailing dots.

    Raises:
      IndexError: if the database has no entries.
    """
    if not self.keys:
      raise IndexError("query on an empty EncryptedDB")
    # Encode with the same length as the stored keys so shapes line up.
    query_matrix = string2one_hot_matrix(query_str, max_length=self.max_key_len)
    query_matrix = query_matrix.share(*self.owners)
    key_matches = [string_equal(key, query_matrix) for key in self.keys]
    result = self.values[0] * key_matches[0]
    for value, match in zip(self.values[1:], key_matches[1:]):
      result += value * match
    result = result.get()
    # Strip only the trailing padding; replace('.', '') also deleted interior dots.
    return values2string(result).rstrip('.')

In [0]:
# Build a database shared between bob, alice and secure_worker, insert four
# key/value pairs, then look one of them up.
db = EncryptedDB(bob, alice, secure_worker)

for entry_key, entry_value in (("key1", "value1"), ("key2", "value2"),
                               ("key3", "value3"), ("key4", "value4")):
  db.add_entry(entry_key, entry_value)

db.query("key1")

'value1'