In [1]:
#Part 12: Train an Encrypted NN on Encrypted Data

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import syft as sy

Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was 'C:\ProgramData\Anaconda3\lib\site-packages\tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.2.so'





In [2]:
#Step 1: Create Workers and Toy Data

# Hook PyTorch into PySyft (this is what later lets plain torch tensors call
# .fix_precision() / .share()), then create three simulated workers:
# bob and alice will hold the secret shares, james is the crypto provider.
hook = sy.TorchHook(torch)

alice, bob, james = (
    sy.VirtualWorker(id=worker_id, hook=hook)
    for worker_id in ("alice", "bob", "james")  # same creation order as before
)

In [3]:
# A Toy Dataset: all four combinations of two binary features.
# The label is simply the first feature, so the task is linearly separable.
data = torch.tensor([
    [0., 0.],
    [0., 1.],
    [1., 0.],
    [1., 1.],
])
target = torch.tensor([[0.], [0.], [1.], [1.]])

In [4]:
# A Toy Model: a 2 -> 2 -> 1 fully connected network with a ReLU hidden layer.
class Net(nn.Module):
    """Two-layer fully connected network computing ``fc2(relu(fc1(x)))``.

    Takes inputs with 2 features and produces 1 output per sample.
    No activation is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # fc1 is created before fc2 so parameter registration order (and the
        # RNG draws used for default initialisation) stay fc1 -> fc2.
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; assumes the last dim of ``x`` has 2 features."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [5]:
# Seed PyTorch's global RNG so the randomly initialised weights are identical
# on every Restart & Run All (a prerequisite for a reproducible loss trace).
SEED = 0
torch.manual_seed(SEED)
model = Net()

In [6]:
#Step 2: Encrypt the Model and Data

# Encode each object as fixed-precision values, then additively secret-share it
# between bob and alice with james as crypto provider. requires_grad=True wraps
# the result in an AutogradTensor (see the printed repr) so we can train on it.
# NOTE(review): this cell rebinds data/target/model to their encrypted versions,
# so re-running it would try to encode already-encrypted objects; restart the
# kernel and run from the top instead.
share_kwargs = dict(crypto_provider=james, requires_grad=True)
data = data.fix_precision().share(bob, alice, **share_kwargs)
target = target.fix_precision().share(bob, alice, **share_kwargs)
model = model.fix_precision().share(bob, alice, **share_kwargs)

print(data)

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:20670923291 -> bob:45670312454]
	-> [PointerTensor | me:29277605191 -> alice:58678621064]
	*crypto provider: james*


In [7]:
#Step 3: Train

N_EPOCHS = 20        # full-batch passes over the toy dataset
LEARNING_RATE = 0.1

# The optimizer is fixed-precision encoded too, presumably so its
# hyper-parameters share the fixed-point encoding of the encrypted gradients.
opt = optim.SGD(params=model.parameters(), lr=LEARNING_RATE).fix_precision()

losses = []  # decrypted loss after each epoch, kept for later inspection/plotting

# The loop variable is `epoch`, not `iter`: binding `iter` at notebook scope would
# shadow the builtin iter() for every cell run afterwards.
for epoch in range(N_EPOCHS):
    # 1) erase previous gradients (if they exist)
    opt.zero_grad()

    # 2) make a prediction (on the whole batch at once)
    pred = model(data)

    # 3) calculate how much we missed: sum of squared errors over the batch
    loss = ((pred - target)**2).sum()

    # 4) figure out which weights caused us to miss
    loss.backward()

    # 5) change those weights
    opt.step()

    # 6) decrypt the loss locally and report progress. Kept after backward()/step():
    #    .get() is assumed to pull the loss shares off the workers.
    loss_value = loss.get().float_precision()
    losses.append(loss_value)
    print(loss_value)

tensor(1.4130)
tensor(0.9320)
tensor(1.2300)
tensor(0.7050)
tensor(0.8820)
tensor(0.5460)
tensor(0.6980)
tensor(0.4620)
tensor(0.6290)
tensor(0.4360)
tensor(0.6200)
tensor(0.3290)
tensor(0.4290)
tensor(0.2290)
tensor(0.2800)
tensor(0.1540)
tensor(0.1820)
tensor(0.1040)
tensor(0.1120)
tensor(0.0670)
