In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
import syft as sy
# Hook PyTorch so tensors and modules gain the PySyft methods used later in this
# notebook (.fix_precision(), .share(), .get(), ...).
hook = sy.TorchHook(torch)

# Three in-process virtual workers: `client` and `server` receive the secret
# shares, and `crypto_provider` is passed as crypto_provider= to every .share() call.
client, server, crypto_provider = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("client", "server", "crypto_provider")
)

In [3]:
class Net(nn.Module):
    """Two-layer fully connected classifier for flattened 28x28 MNIST images.

    Architecture: 784 -> 500 (ReLU) -> 10 logits. The attribute names
    ``fc1`` / ``fc2`` must match the keys in the checkpoint that is loaded
    into this model with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)  # hidden layer
        self.fc2 = nn.Linear(500, 10)   # one logit per digit class

    def forward(self, x):
        """Flatten ``x`` to (batch, 784) and return unnormalised (batch, 10) scores."""
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)

In [4]:
# Rebuild the architecture and restore the weights trained on the server.
# map_location="cpu": nothing in this notebook moves tensors to a GPU, so load
# onto CPU even if the checkpoint was saved from CUDA tensors (otherwise the
# load fails on a CPU-only machine).
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = Net()
model.load_state_dict(torch.load("server_trained_model.pt", map_location="cpu"))
model.eval()

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [5]:
# Encrypt the model: fixed-precision encode it, then secret-share it between
# `client` and `server`, with `crypto_provider` supplying the crypto primitives.
# NOTE(review): the result is not re-assigned, so this relies on PySyft's
# nn.Module.fix_precision()/share() updating `model` in place — later cells
# pass `model` itself to test().
(
    model
    .fix_precision()
    .share(client, server, crypto_provider=crypto_provider)
)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [6]:
# MNIST test split; download=True so the notebook also runs on a fresh machine.
# NOTE(review): (0.1307,), (0.3081,) are presumably the normalisation used when
# server_trained_model.pt was trained — keep them in sync with training.
mnist_test = datasets.MNIST(
    'data', train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
# shuffle=False: evaluation does not need a random order, and a fixed order makes
# the running-accuracy printout reproducible across Restart & Run All.
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=False)


def _encrypt(tensor):
    """Fixed-precision encode ``tensor`` and secret-share it between client and server."""
    return tensor.fix_precision().share(client, server, crypto_provider=crypto_provider)


# Encrypt every (data, target) batch up front; test() consumes these pairs later.
private_test_loader = [(_encrypt(data), _encrypt(target)) for data, target in test_loader]

In [9]:
def test(model, test_loader):
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1) 
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            n_total += 64
            n_correct = n_correct_priv.copy().get().float_precision().long().item()
            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
                n_correct, n_total,
                100. * n_correct / n_total))

In [10]:
# Run encrypted inference over the whole shared test set, printing the running accuracy.
test(model=model, test_loader=private_test_loader)

Test set: Accuracy: 63/64 (98%)
Test set: Accuracy: 123/128 (96%)
Test set: Accuracy: 185/192 (96%)
Test set: Accuracy: 248/256 (97%)
Test set: Accuracy: 310/320 (97%)
Test set: Accuracy: 373/384 (97%)
Test set: Accuracy: 433/448 (97%)
Test set: Accuracy: 496/512 (97%)
Test set: Accuracy: 558/576 (97%)
Test set: Accuracy: 620/640 (97%)
Test set: Accuracy: 684/704 (97%)
Test set: Accuracy: 748/768 (97%)
Test set: Accuracy: 808/832 (97%)
Test set: Accuracy: 871/896 (97%)
Test set: Accuracy: 935/960 (97%)
Test set: Accuracy: 998/1024 (97%)
Test set: Accuracy: 1060/1088 (97%)
Test set: Accuracy: 1123/1152 (97%)
Test set: Accuracy: 1182/1216 (97%)
Test set: Accuracy: 1244/1280 (97%)
Test set: Accuracy: 1306/1344 (97%)
Test set: Accuracy: 1370/1408 (97%)
Test set: Accuracy: 1433/1472 (97%)
Test set: Accuracy: 1496/1536 (97%)
Test set: Accuracy: 1558/1600 (97%)
Test set: Accuracy: 1620/1664 (97%)
Test set: Accuracy: 1683/1728 (97%)
Test set: Accuracy: 1746/1792 (97%)
Test set: Accuracy: 1809/