In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [0]:
!pip install syft

# Create Hook

In [0]:
import syft as sy

hook = sy.TorchHook(torch)

# Create worker

In [0]:
client = sy.VirtualWorker(hook, id='client')
bob = sy.VirtualWorker(hook, id='bob')
alice = sy.VirtualWorker(hook, id='alice')
crypto_provider = sy.VirtualWorker(hook, id='crypto_provider')

# Hyperparameters

In [0]:
batch_size = 128
test_batch_size = 1024
epochs = 10
lr = 0.001
log_interval = 100

# Data loading

In [0]:
transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))])

train_data = datasets.MNIST('data', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [0]:
transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))])

test_data = datasets.MNIST('data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Share data with two workers

In [0]:
private_test_loader = []
for data, target in test_loader:
    private_test_loader.append((
        data.fix_prec().share(alice, bob, crypto_provider=crypto_provider),
        target.fix_prec().share(alice, bob, crypto_provider=crypto_provider)
    ))

# Network

In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# Train

In [0]:
def train(model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        output = F.log_softmax(output, dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * batch_size, len(train_loader) * batch_size,
                100. * batch_idx / len(train_loader)))

In [31]:
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, epoch)



# Now model reday to serve
send the model to virtual worker

In [32]:
model.fix_precision().share(alice, bob, crypto_provider=crypto_provider)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

# Checek (only decrypt servide side)

In [0]:
def test(model, test_loader):
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1) 
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            n_total += test_batch_size
            
            n_correct = n_correct_priv.copy().get().float_precision().long().item()
    
            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
                n_correct, n_total,
                100. * n_correct / n_total))

In [34]:
test(model, private_test_loader)

Test set: Accuracy: 128/1024 (12%)
Test set: Accuracy: 255/2048 (12%)
Test set: Accuracy: 382/3072 (12%)
Test set: Accuracy: 510/4096 (12%)
Test set: Accuracy: 638/5120 (12%)
Test set: Accuracy: 766/6144 (12%)
Test set: Accuracy: 893/7168 (12%)
Test set: Accuracy: 1020/8192 (12%)
Test set: Accuracy: 1148/9216 (12%)
Test set: Accuracy: 1275/10240 (12%)
Test set: Accuracy: 1402/11264 (12%)
Test set: Accuracy: 1530/12288 (12%)
Test set: Accuracy: 1658/13312 (12%)
Test set: Accuracy: 1786/14336 (12%)
Test set: Accuracy: 1914/15360 (12%)
Test set: Accuracy: 2041/16384 (12%)
Test set: Accuracy: 2169/17408 (12%)
Test set: Accuracy: 2297/18432 (12%)
Test set: Accuracy: 2425/19456 (12%)
Test set: Accuracy: 2553/20480 (12%)
Test set: Accuracy: 2680/21504 (12%)
Test set: Accuracy: 2806/22528 (12%)
Test set: Accuracy: 2934/23552 (12%)
Test set: Accuracy: 3062/24576 (12%)
Test set: Accuracy: 3190/25600 (12%)
Test set: Accuracy: 3318/26624 (12%)
Test set: Accuracy: 3446/27648 (12%)
Test set: Accurac

KeyboardInterrupt: ignored