In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import syft as sy

In [2]:
# Set everything up: install PySyft's hook on torch, then create the workers.
hook = sy.TorchHook(torch)

# bob and alice will hold the secret shares; james is passed as the crypto
# provider when sharing.
alice, bob, james = (
    sy.VirtualWorker(id=worker_id, hook=hook)
    for worker_id in ("alice", "bob", "james")
)

In [3]:
# A toy dataset: four 2-feature rows; the target is simply the first feature.
# Float literals keep both tensors in torch's default floating dtype.
data = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
target = torch.tensor([[0.], [0.], [1.], [1.]])

# A toy model: fully connected 2 -> 2 -> 1 network with a ReLU in between.
class Net(nn.Module):
    """Tiny two-layer MLP used for the encrypted-training demo.

    Architecture: fc1 (2 -> 2) -> ReLU -> fc2 (2 -> 1).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return fc2(relu(fc1(x))); expects 2 input features in the last dim."""
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


# Seed torch's RNG so the random weight initialisation (and therefore the
# loss trace printed during training) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = Net()

In [4]:
# We encode everything: fixed-precision encode each object, then additively
# secret-share it between bob and alice, with james as crypto provider.
# requires_grad=True adds the AutogradTensor layer visible in the printouts
# below, so gradients can flow through the encrypted computation.
def _encrypt(obj):
    """Fixed-precision encode `obj` and share it between bob and alice."""
    return obj.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)


data = _encrypt(data)
target = _encrypt(target)
model = _encrypt(model)

In [5]:
# The printed form shows only the wrapper chain and pointers to the remote
# shares held by bob and alice; the underlying values are not revealed here.
print(data)

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:3628573935 -> bob:91600533220]
	-> [PointerTensor | me:79760290746 -> alice:90859236382]
	*crypto provider: james*


In [6]:
# Like `data`, the target is now just pointers to shares on bob and alice.
print(target)

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:13164392956 -> bob:97964324076]
	-> [PointerTensor | me:45736577317 -> alice:79156148178]
	*crypto provider: james*


In [7]:
# The layer structure is still visible locally; the `locations` footer in the
# output lists the workers (bob, alice) holding the model's shared parameters.
print(model)

Net(
  (fc1): Linear(in_features=2, out_features=2, bias=True)
  (fc2): Linear(in_features=2, out_features=1, bias=True)
)
---
locations: [<VirtualWorker id:bob #objects:6>, <VirtualWorker id:alice #objects:6>]


In [8]:
# Training hyperparameters.
LEARNING_RATE = 0.1
N_ITERATIONS = 20

# The optimizer is fixed-precision encoded too (presumably so lr and the
# update arithmetic match the fixed-precision shared weights).
opt = optim.SGD(params=model.parameters(), lr=LEARNING_RATE).fix_precision()

# The loop index is unused; `_` avoids shadowing the built-in `iter` in the
# notebook namespace.
for _ in range(N_ITERATIONS):
    # 1) erase previous gradients (if they exist)
    opt.zero_grad()

    # 2) make a prediction on the secret-shared data
    pred = model(data)

    # 3) calculate how much we missed (sum of squared errors over the batch)
    loss = ((pred - target)**2).sum()

    # 4) figure out which weights caused us to miss
    loss.backward()

    # 5) change those weights
    opt.step()

    # 6) print our progress: .get() brings the loss back locally and
    #    float_precision() decodes it from fixed-point to a float tensor
    print(loss.get().float_precision())

tensor(1.1470)
tensor(0.8940)
tensor(0.8340)
tensor(0.7760)
tensor(0.7110)
tensor(0.6360)
tensor(0.5560)
tensor(0.4710)
tensor(0.3850)
tensor(0.3040)
tensor(0.2310)
tensor(0.1700)
tensor(0.1220)
tensor(0.0840)
tensor(0.0590)
tensor(0.0400)
tensor(0.0260)
tensor(0.0180)
tensor(0.0130)
tensor(0.0090)
