# 安全模型分类实例

在此笔记本中，我们将使用到目前为止所学到的所有技术来执行神经网络训练（和预测），同时对模型和数据进行加密。

## 1 设置 syft，创建数据和模型

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import syft as sy

In [2]:
# Set everything up: create the syft hook for torch, then the three virtual
# workers that the rest of the notebook refers to by name.
hook = sy.TorchHook(torch)

# Workers are created in the order alice, bob, james, all sharing one hook.
alice, bob, james = (
    sy.VirtualWorker(id=worker_id, hook=hook)
    for worker_id in ("alice", "bob", "james")
)

In [3]:
# A toy dataset: four samples with two features each. The label of every
# sample is equal to its first feature.
data = torch.tensor(
    [
        [0.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
        [1.0, 1.0],
    ]
)
target = torch.tensor([[0.0], [0.0], [1.0], [1.0]])

# A toy model
class Net(nn.Module):
    """Tiny two-layer network: Linear(2, 2) -> ReLU -> Linear(2, 1)."""

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the random number generator in the same sequence.
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return the output of the second linear layer (no final activation)."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


model = Net()

## 2 加密模型和数据

In [4]:
# Encode everything the same way, in two explicit steps per object:
#   1) fix_precision() on the plain object,
#   2) share() between bob and alice, with james as the crypto provider.
data = data.fix_precision()
data = data.share(bob, alice, crypto_provider=james, requires_grad=True)

target = target.fix_precision()
target = target.share(bob, alice, crypto_provider=james, requires_grad=True)

model = model.fix_precision()
model = model.share(bob, alice, crypto_provider=james, requires_grad=True)

In [5]:
# Show what `data` looks like after sharing; the recorded output below lists
# the wrapper chain and one pointer each to bob and alice.
print(data)

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:37577169362 -> bob:13920372678]
	-> [PointerTensor | me:72102385192 -> alice:93469225639]
	*crypto provider: james*


## 3 训练

In [6]:
# Plain SGD over the model parameters, converted with fix_precision() just
# like the data, target and model were.
opt = optim.SGD(params=model.parameters(), lr=0.1).fix_precision()

# The loop variable is named `epoch` rather than `iter` so that the builtin
# iter() is not shadowed for the rest of the notebook session.
for epoch in range(20):
    # 1) erase previous gradients (if they exist)
    opt.zero_grad()

    # 2) make a prediction
    pred = model(data)

    # 3) calculate how much we missed (sum of squared errors)
    loss = ((pred - target)**2).sum()

    # 4) figure out which weights caused us to miss
    loss.backward()

    # 5) change those weights
    opt.step()

    # 6) print our progress: get() the loss back and convert it with
    #    float_precision() so it prints as an ordinary tensor value
    print(loss.get().float_precision())

tensor(1.0490)
tensor(0.9820)
tensor(0.9470)
tensor(0.8940)
tensor(0.8300)
tensor(0.7630)
tensor(0.6830)
tensor(0.5860)
tensor(0.5040)
tensor(0.3940)
tensor(0.2920)
tensor(0.2050)
tensor(0.1720)
tensor(0.1150)
tensor(0.0940)
tensor(0.0740)
tensor(0.0550)
tensor(0.0550)
tensor(0.0330)
tensor(0.0370)
