In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [11]:
import syft as sy

# Hook PyTorch so plain tensors/modules gain syft methods such as
# fix_precision() and share(), then create the three virtual workers:
# the two share-holders (client, server) and the crypto provider.
hook = sy.TorchHook(torch)
client, server, crypto_provider = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("client", "server", "crypto_provider")
)

In [12]:
class Net(nn.Module):
    """Two-layer MLP classifier: 784 -> 500 -> ReLU -> 10 logits.

    The input is reshaped to rows of 784 values (one flattened 28x28 MNIST
    image per row), so any tensor whose element count is a multiple of 784
    is accepted. Returns raw, un-normalized class scores.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)  # flattened pixels -> hidden units
        self.fc2 = nn.Linear(500, 10)  # hidden units -> class logits

    def forward(self, x):
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)

In [13]:
model = Net()
# map_location="cpu": a checkpoint saved from a GPU run would otherwise fail to
# load on a CPU-only machine, and nothing in this notebook moves tensors to a GPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load("server_trained_model.pt", map_location="cpu"))
model.eval()

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [14]:
# Convert the model's parameters to fixed precision and secret-share them
# between client and server (crypto_provider is syft's crypto provider).
# The return value is discarded and later cells keep using `model` directly,
# so this call converts `model` in place: run this cell only once per kernel —
# re-running it would presumably re-encode already-encoded weights.
model.fix_precision().share(client, server, crypto_provider=crypto_provider)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [20]:
# Plaintext MNIST test split. The Normalize constants are the commonly used
# MNIST mean/std — presumably the same preprocessing that was used when
# training server_trained_model.pt (verify).
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=64,
    # Evaluation does not need shuffling; a fixed order makes the per-batch
    # accuracy printout reproducible across runs.
    shuffle=False,
)

# Encrypt every batch up front: images and labels are both converted to fixed
# precision and secret-shared between client and server. Labels are encoded
# too, presumably so they can be compared with the model's fixed-precision
# predictions inside `test`.
# A comprehension (rather than a for-loop) avoids leaking `data`/`target`
# into the notebook's global namespace.
private_test_loader = [
    (
        data.fix_precision().share(client, server, crypto_provider=crypto_provider),
        target.fix_precision().share(
            client, server, crypto_provider=crypto_provider
        ),
    )
    for data, target in test_loader
]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


9920512it [00:02, 4492198.36it/s]                                                                                      


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


32768it [00:00, 181516.16it/s]                                                                                         


Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


1654784it [00:00, 2551579.84it/s]                                                                                      


Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


8192it [00:00, 62302.79it/s]                                                                                           


Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [21]:
def test(model, test_loader):
    """Test the model."""
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            n_total += 64
            n_correct = n_correct_priv.copy().get().float_precision().long().item()
            print(
                "Test set: Accuracy: {}/{} ({:.0f}%)".format(
                    n_correct, n_total, 100.0 * n_correct / n_total
                )
            )

In [22]:
# Encrypted evaluation of the shared model on the shared test batches;
# prints the cumulative accuracy as batches are processed.
test(model, private_test_loader)

Test set: Accuracy: 62/64 (97%)
Test set: Accuracy: 123/128 (96%)
Test set: Accuracy: 186/192 (97%)
Test set: Accuracy: 249/256 (97%)
Test set: Accuracy: 313/320 (98%)
Test set: Accuracy: 376/384 (98%)
Test set: Accuracy: 439/448 (98%)
Test set: Accuracy: 503/512 (98%)
Test set: Accuracy: 564/576 (98%)
Test set: Accuracy: 624/640 (98%)
Test set: Accuracy: 688/704 (98%)
Test set: Accuracy: 751/768 (98%)
Test set: Accuracy: 812/832 (98%)
Test set: Accuracy: 872/896 (97%)
Test set: Accuracy: 936/960 (98%)
Test set: Accuracy: 997/1024 (97%)
Test set: Accuracy: 1061/1088 (98%)
Test set: Accuracy: 1124/1152 (98%)
Test set: Accuracy: 1186/1216 (98%)
Test set: Accuracy: 1250/1280 (98%)
Test set: Accuracy: 1312/1344 (98%)
Test set: Accuracy: 1372/1408 (97%)
Test set: Accuracy: 1435/1472 (97%)
Test set: Accuracy: 1498/1536 (98%)
Test set: Accuracy: 1561/1600 (98%)
Test set: Accuracy: 1624/1664 (98%)
Test set: Accuracy: 1686/1728 (98%)
Test set: Accuracy: 1749/1792 (98%)
Test set: Accuracy: 1813/