<a href="https://colab.research.google.com/github/MKrupauskas/colab/blob/master/federated-learning-aggregation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
!pip install syft



In [29]:
import syft as sy
import torch as th
from torch import nn, optim

hook = sy.TorchHook(th)

W0818 11:23:47.730978 140031779673984 hook.py:98] Torch was already hooked... skipping hooking process


In [0]:
bob = sy.VirtualWorker(hook, id = "bob")
alice = sy.VirtualWorker(hook, id = "alice")

secure_worker = sy.VirtualWorker(hook, id = "secure_worker")

In [31]:
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])

W0818 11:23:47.765182 140031779673984 base.py:646] Worker alice already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:23:47.768253 140031779673984 base.py:646] Worker secure_worker already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:23:47.769680 140031779673984 base.py:646] Worker bob already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:23:47.771459 140031779673984 base.py:646] Worker secure_worker already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:23:47.773218 140031779673984 base.py:646] Worker alice already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:23:47.774308 140031779673984 base.py:646] Worker bob already exists. Replacing old worker which could cause                     unexpected behavior


<VirtualWorker id:secure_worker #objects:0>

In [0]:
data = th.tensor([[1., 1], [0, 1,], [1, 0], [0, 0]], requires_grad = True)
target = th.tensor([[1.], [1], [0], [0]], requires_grad = True)

In [0]:
bobs_data = data[0 : 2].send(bob)
bobs_target = target[0 : 2].send(bob)

alices_data = data[2:].send(alice)
alices_target = target[2:].send(alice)

In [0]:
model = nn.Linear(2, 1)

In [0]:
bobs_model = model.copy().send(bob)
alices_model = model.copy().send(alice)

In [0]:
bobs_optimizer = optim.SGD(params = bobs_model.parameters(), lr = 0.1)
alices_optimizer = optim.SGD(params = alices_model.parameters(), lr = 0.1)

In [37]:
bobs_optimizer.zero_grad()

bobs_prediction = bobs_model(bobs_data)

bobs_loss = ((bobs_prediction - bobs_target) ** 2).sum()

bobs_loss.backward()

bobs_optimizer.step()

bobs_loss = bobs_loss.get().data

bobs_loss

tensor(0.5376)

In [38]:
alices_optimizer.zero_grad()

alices_prediction = alices_model(alices_data)

alices_loss = ((alices_prediction - alices_target) ** 2).sum()

alices_loss.backward()

alices_optimizer.step()

alices_loss = alices_loss.get().data

alices_loss

tensor(0.2122)

In [0]:
alices_model.move(secure_worker)
bobs_model.move(secure_worker)

In [40]:
model.weight.data.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
model.bias.data.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())

secure_worker.clear_objects()

RuntimeError: ignored

In [0]:
import random

Q = 23740629843760239486723

def encrypt(x = 5, n_shares = 3):
  shares = list()

  for i in range(n_shares - 1):
    shares.append(random.randint(0, Q))

  final_share = Q - (sum(shares) % Q) + x

  shares.append(final_share)

  return tuple(shares)

In [42]:
encrypt(5, 10)

(2757908212675146006271,
 15963086721551369973671,
 22562357258713406599980,
 9627748376416124569161,
 10455108573637732144252,
 18255004447222748708162,
 1128125268672474577226,
 4318212741151429249323,
 15532160845696815487177,
 18103436773063950118397)

In [0]:
def decrypt(shares):
  return sum(shares) % Q

In [44]:
decrypt(encrypt())

5

In [0]:
def add(a, b):
  c = list()

  assert(len(a) == len(b))

  for i in range(len(a)):
    c.append((a[i] + b[i]) % Q)

  return tuple(c)

In [46]:
decrypt(add(encrypt(5), encrypt(2)))

7

For floating point numbers

In [0]:
BASE = 10
PRECISION = 4

def encode(x_dec):
  return int(x_dec * (BASE ** PRECISION)) % Q

def decode(x_fp):
  return (x_fp if x_fp <= Q / 2 else x_fp - Q) / BASE ** PRECISION

In [0]:
bob = bob.clear_objects()
alice = alice.clear_objects()
secure_worker = secure_worker.clear_objects()

In [49]:
x = th.tensor([1,2,3,4,5])

x = x.share(bob, alice, secure_worker)

y = x + x

y.get()

tensor([ 2,  4,  6,  8, 10])

In [50]:
x = th.tensor([0.1,0.2,0.3,0.4,0.5])

x = x.fix_prec()

x

(Wrapper)>FixedPrecisionTensor>tensor([100, 200, 300, 400, 500])

In [0]:
def sub(a, b):
  c = list()

  assert(len(a) == len(b))

  for i in range(len(a)):
    c.append((a[i] - b[i]) % Q)

  return tuple(c)

In [0]:
def imul(a, scalar):
  c = list()

  for i in range(len(a)):
    c.append((a[i] * scalar) % Q)

  return tuple(c)

In [53]:
decode(decrypt(imul(encrypt(encode(5.5)), 2)))

11.0

Secure NPC

In [60]:
bob = sy.VirtualWorker(hook, id = "bob").add_worker(sy.local_worker)
alice = sy.VirtualWorker(hook, id = "alice").add_worker(sy.local_worker)
secure_worker = sy.VirtualWorker(hook, id = "secure_worker").add_worker(sy.local_worker)

W0818 11:28:04.954257 140031779673984 base.py:646] Worker me already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:28:04.956845 140031779673984 base.py:646] Worker me already exists. Replacing old worker which could cause                     unexpected behavior
W0818 11:28:04.966185 140031779673984 base.py:646] Worker me already exists. Replacing old worker which could cause                     unexpected behavior


In [0]:
x = th.tensor([1,2,3,4])
y = th.tensor([2,-1,1,0])

In [0]:
x = x.share(bob, alice, crypto_provider=secure_worker)
y = y.share(bob, alice, crypto_provider=secure_worker)

In [66]:
z = x * y

z.get()

tensor([ 2, -2,  3,  0])

In [69]:
z = x == y

z.get()

tensor([0, 0, 0, 0])

Encrypted database

In [0]:
import string

index_to_char = {}
char_to_index = {}

In [0]:
for index, char in enumerate(' '  + string.ascii_lowercase + '0123456789' + string.punctuation):
  index_to_char[index] = char
  char_to_index[char] = index

In [0]:
def string_to_values(str_input = "hello", max_len = 8):
  str_input = str_input[:max_len].lower()

  if (len(str_input)):
    str_input += "." * (max_len - len(str_input))

  values = list()

  for char in str_input:
    values.append(char_to_index[char])

  return th.tensor(values).long()

In [0]:
def one_hot(index, length):
  vector = th.zeros(length).long()
  vector[index] = 1

  return vector;

In [0]:
def string_to_one_hot_matrix(str_input = "hello", max_len = 8):
  str_input = str_input[:max_len].lower()

  if (len(str_input)):
    str_input += "." * (max_len - len(str_input))

  values = list()

  for char in str_input:
    values.append(one_hot(char_to_index[char], len(char_to_index)).unsqueeze(0))

  return th.cat(values, dim = 0)

In [0]:
def strings_equal(a, b):
  return (a * b).sum() == len(a)

In [0]:
def values_to_string(input_values):
  string = ""

  for value in input_values:
    string += index_to_char[int(value)]

  return string

In [0]:
class EncryptedDB():
  def __init__(self, *owners, max_key_len = 8, max_val_len = 8):
    self.max_key_len = max_key_len
    self.max_val_len = max_val_len

    self.owners = owners

    self.keys = list()
    self.values = list()

  def add_entry(self, key = "keys1", value = "value1"):
    key = string_to_one_hot_matrix(key, self.max_key_len)
    key = key.share(*self.owners)

    self.keys.append(key)

    values = string_to_values(value, self.max_val_len)
    values = values.share(*self.owners)

    self.values.append(values)

  def query(self, query_str = "keys1"):
    query_matrix = string_to_one_hot_matrix(query_str, self.max_key_len)

    query_matrix = query_matrix.share(*self.owners)

    key_matches = list()

    for key in self.keys:
      key_match = strings_equal(key, query_matrix)

      key_matches.append(key_match)

    result = self.values[0] * key_matches[0]

    for i in range(len(self.values)):
      result += self.values[i] * key_matches[i]

    result = result.get()

    return values_to_string(result).replace(".", "")

In [161]:
db = EncryptedDB(bob, alice, secure_worker)

db.add_entry()

db.query()

TypeError: ignored