In [1]:
# Part 11 - Secure Deep Learning Classification

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import syft as sy


Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was 'C:\ProgramData\Anaconda3\lib\site-packages\tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.2.so'





In [2]:
# Run-wide settings: number of training epochs (read by Arguments) and how
# many encrypted test batches the private evaluation scores.
epochs, n_test_batches = 10, 200

In [3]:
hook = sy.TorchHook(torch)

# One syft VirtualWorker per party, created in the same order as before.
# alice and bob receive the secret shares below and crypto_provider is passed
# as the crypto provider to share(); client is not referenced in the cells shown.
client, bob, alice, crypto_provider = [
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("client", "bob", "alice", "crypto_provider")
]

In [4]:
class Arguments():
    """Hyper-parameters shared by the training and evaluation cells.

    Every setting can be overridden by keyword. Calling ``Arguments()`` with
    no arguments reproduces the notebook's original defaults.

    Parameters
    ----------
    batch_size : training mini-batch size.
    test_batch_size : evaluation batch size. NOTE(review): the encrypted
        evaluation counts samples as (number of batches) * test_batch_size,
        so this should divide the test-set size evenly.
    epochs : number of training epochs. ``None`` (the default) falls back to
        the notebook-level ``epochs`` constant, read at construction time.
    lr : learning rate handed to the optimizer.
    log_interval : the training loop prints a log line every this many batches.
    """

    def __init__(self, batch_size=64, test_batch_size=50, epochs=None,
                 lr=0.001, log_interval=100):
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        # The parameter shadows the module-level constant, so look that one up
        # explicitly to keep ``Arguments()`` equivalent to the original.
        self.epochs = globals()['epochs'] if epochs is None else epochs
        self.lr = lr
        self.log_interval = log_interval

In [5]:
# Single hyper-parameter object read by the training and evaluation cells below.
args = Arguments()


In [6]:
# MNIST training split, converted to tensors and normalised with the
# (mean,), (std,) constants used throughout this notebook.
mnist_train_set = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
train_loader = torch.utils.data.DataLoader(
    mnist_train_set,
    batch_size=args.batch_size,
    shuffle=True,
)

In [7]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True)

# Encrypted copy of the test batches: each input and target tensor is encoded
# with fix_precision() and secret-shared between alice and bob, with
# crypto_provider as the crypto provider.
# Only the first n_test_batches batches are shared, because the encrypted
# evaluation reads just test_loader[:n_test_batches]. Encrypting the rest
# would be wasted work.
private_test_loader = []
for batch_idx, (data, target) in enumerate(test_loader):
    if batch_idx >= n_test_batches:
        break
    private_test_loader.append((
        data.fix_precision().share(alice, bob, crypto_provider=crypto_provider),
        target.fix_precision().share(alice, bob, crypto_provider=crypto_provider),
    ))

In [8]:
class Net(nn.Module):
    """Two-layer fully-connected classifier: 784 -> 500 (ReLU) -> 10 raw scores."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Flatten ``x`` into rows of 784 features and return un-normalised class scores."""
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)

In [9]:
def train(args, model, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Parameters
    ----------
    args : object with ``batch_size`` and ``log_interval`` attributes
        (e.g. ``Arguments``).
    model : module mapping a batch to raw class scores; log_softmax is applied
        here along dim 1.
    train_loader : iterable of ``(data, target)`` batches. When it exposes a
        ``dataset`` (as a ``DataLoader`` does), the progress line reports the
        true number of samples.
    optimizer : optimizer over ``model``'s parameters, stepped once per batch.
    epoch : epoch number, used only in the log line.
    """
    model.train()
    n_batches = len(train_loader)
    # len(train_loader) * batch_size over-counts the samples when the last
    # batch is partial, so prefer the dataset size when it is available.
    dataset = getattr(train_loader, 'dataset', None)
    n_samples = len(dataset) if dataset is not None else n_batches * args.batch_size

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        log_probs = F.log_softmax(model(data), dim=1)
        # nll_loss on log-probabilities = cross-entropy on the raw scores.
        loss = F.nll_loss(log_probs, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, n_samples,
                100. * batch_idx / n_batches, loss.item()))

In [10]:
# Seed torch's global RNG so the weight initialisation (and, presumably, the
# shuffling of train_loader, which draws from the same global generator)
# repeats across Restart & Run All. The value itself is arbitrary.
torch.manual_seed(1)

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Epochs are numbered from 1 in the training log.
for epoch in range(1, args.epochs + 1):
    train(args, model, train_loader, optimizer, epoch)



In [11]:
def test(args, model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            output = F.log_softmax(output, dim=1)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [12]:
test(args=args, model=model, test_loader=test_loader)


Test set: Average loss: 0.0843, Accuracy: 9820/10000 (98%)



In [15]:
# Encode the trained model's parameters with fix_precision() and secret-share
# them between alice and bob (crypto_provider supplies the crypto primitives).
# Assigning the result back keeps the data flow explicit instead of relying
# on in-place mutation. share() returned a Net here, presumably the same
# module. It also stops the cell from echoing the model repr.
model = model.fix_precision().share(alice, bob, crypto_provider=crypto_provider)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [16]:
def test(args, model, test_loader):
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader[:n_test_batches]:
            output = model(data)
            pred = output.argmax(dim=1) 
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            n_total += args.test_batch_size
# This 'test' function performs the encrypted evaluation. The model weights, the data inputs, the prediction and the target used for scoring are all encrypted!

# However as you can observe, the syntax is very similar to normal PyTorch testing! Nice!

# The only thing we decrypt from the server side is the final score at the end of our 200 items batches to verify predictions were on average good.      
            n_correct = n_correct_priv.copy().get().float_precision().long().item()
    
            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
                n_correct, n_total,
                100. * n_correct / n_total))

In [17]:
test(args=args, model=model, test_loader=private_test_loader)

Test set: Accuracy: 50/50 (100%)
Test set: Accuracy: 100/100 (100%)
Test set: Accuracy: 149/150 (99%)
Test set: Accuracy: 198/200 (99%)
Test set: Accuracy: 248/250 (99%)
Test set: Accuracy: 298/300 (99%)
Test set: Accuracy: 347/350 (99%)
Test set: Accuracy: 396/400 (99%)
Test set: Accuracy: 443/450 (98%)
Test set: Accuracy: 491/500 (98%)
Test set: Accuracy: 541/550 (98%)
Test set: Accuracy: 591/600 (98%)
Test set: Accuracy: 640/650 (98%)
Test set: Accuracy: 690/700 (99%)
Test set: Accuracy: 738/750 (98%)
Test set: Accuracy: 786/800 (98%)
Test set: Accuracy: 836/850 (98%)
Test set: Accuracy: 886/900 (98%)
Test set: Accuracy: 935/950 (98%)
Test set: Accuracy: 981/1000 (98%)
Test set: Accuracy: 1031/1050 (98%)
Test set: Accuracy: 1081/1100 (98%)
Test set: Accuracy: 1130/1150 (98%)
Test set: Accuracy: 1178/1200 (98%)
Test set: Accuracy: 1227/1250 (98%)
Test set: Accuracy: 1277/1300 (98%)
Test set: Accuracy: 1325/1350 (98%)
Test set: Accuracy: 1374/1400 (98%)
Test set: Accuracy: 1424/1450 (