In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
import syft as sy
# Hook PyTorch so tensors gain Syft methods (.send, .share, .fix_precision, ...),
# then create in-process virtual workers. alice and bob hold the secret shares;
# crypto_provider is passed as the helper party to every .share(...) call.
hook = sy.TorchHook(torch)
client, bob, alice, crypto_provider = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("client", "bob", "alice", "crypto_provider")
)

In [3]:
class Arguments():
    """Hyper-parameters shared by the training and evaluation cells.

    Every value can be overridden by keyword, e.g. ``Arguments(epochs=1)`` for a
    quick run; the defaults reproduce the notebook's original configuration.

    Attributes:
        batch_size: mini-batch size of the training ``DataLoader``.
        test_batch_size: mini-batch size of the test ``DataLoader``; the
            encrypted ``test`` loop also adds it per batch to count samples.
        epochs: number of training epochs.
        lr: learning rate passed to the Adam optimizer.
        log_interval: print training progress every this many batches.
    """

    def __init__(self, batch_size=64, test_batch_size=50, epochs=10,
                 lr=0.001, log_interval=100):
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.epochs = epochs
        self.lr = lr
        self.log_interval = log_interval

args = Arguments()

In [4]:
# MNIST training split: images -> tensors, normalised with the dataset's
# mean/std, served in shuffled mini-batches of args.batch_size.
mnist_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train_set = datasets.MNIST('../data', train=True, download=True,
                                 transform=mnist_train_transform)
train_loader = torch.utils.data.DataLoader(
    mnist_train_set, batch_size=args.batch_size, shuffle=True)

In [5]:
# Plaintext MNIST test split. shuffle=False keeps the batch order identical on
# every run, so encrypted evaluations (including partial/interrupted ones) are
# reproducible and comparable; shuffling adds nothing when only evaluating.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=False)

# Encrypted copy of the test set, built once up front as a list of
# (data, target) pairs: both inputs and labels are converted to fixed precision
# and secret-shared between alice and bob, with crypto_provider as helper party.
private_test_loader = [
    (
        data.fix_precision().share(alice, bob, crypto_provider=crypto_provider),
        target.fix_precision().share(alice, bob, crypto_provider=crypto_provider),
    )
    for data, target in test_loader
]

In [6]:
class Net(nn.Module):
    """Two-layer MLP classifier over flattened 784-feature MNIST inputs.

    Layers: 784 -> 500 (ReLU) -> 10. ``forward`` returns raw logits; the
    training/evaluation code applies ``log_softmax`` itself.
    """

    def __init__(self):
        super().__init__()
        # Declaration order fixes the parameter initialisation order and the
        # state_dict keys ('fc1.*', 'fc2.*') used by the saved checkpoint.
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)

In [7]:
def train(args, model, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    Applies ``log_softmax`` to the model's raw outputs and minimises the
    negative log-likelihood; prints progress every ``args.log_interval``
    batches.
    """
    model.train()
    n_samples = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # Use the real dataset size as denominator: len(loader) * batch_size
            # over-counts whenever the final batch is partial.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, n_samples,
                100. * batch_idx / len(train_loader), loss.item()))


def evaluate(args, model, test_loader):
    """Print average NLL loss and accuracy of plaintext ``model`` on ``test_loader``.

    ``args`` is accepted for symmetry with ``train`` and is not used.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = F.log_softmax(model(data), dim=1)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples, 100. * correct / n_samples))


# Plaintext model: reuse the saved checkpoint when it exists; otherwise train
# it once, report plaintext test metrics and save it for later runs.
model = Net()

PATH = './trained_model.pickle'
try:
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load(PATH))
    print('Model loaded from path', PATH, '!')
except FileNotFoundError:
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(1, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch)
    evaluate(args, model, test_loader)
    torch.save(model.state_dict(), PATH)

model.eval()

Model loaded from path ./trained_model.pickle !


Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [20]:
class PlanNet(sy.Plan):
    """Syft ``Plan`` with the same architecture as ``Net`` (784 -> 500 -> 10).

    The layer attribute names match ``Net``'s, so the checkpoint saved from
    ``Net`` loads directly with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

        # Declare the layers as the Plan's state. NOTE(review): Syft-specific
        # API, presumably so their parameters travel with the Plan -- confirm
        # against the installed Syft version.
        self.state += ['fc1', 'fc2']

    def forward(self, x):
        flat = x.view(-1, 784)
        return self.fc2(F.relu(self.fc1(flat)))

In [23]:
# Build the Plan version of the network and load the weights trained for `Net`.
# Printing the first few fc1 bias values before and after shows that the
# checkpoint replaced the random initialisation.
model = PlanNet()
first_five = slice(5)
print(model.fc1.bias[first_five])
model.load_state_dict(torch.load(PATH))
print(model.fc1.bias[first_five])

tensor([ 0.0117, -0.0003, -0.0307, -0.0185,  0.0088], grad_fn=<SliceBackward>)
tensor([-0.0339, -0.0175, -0.0268,  0.0149, -0.0039], grad_fn=<SliceBackward>)


In [None]:
# Experimental cell: send the Plan to a remote worker and fetch it back.
# NOTE(review): the Plan send/fetch API differs between Syft versions -- confirm
# these calls against the installed version before running. Secret-sharing the
# model is left to the next cell; doing it here as well would apply
# fix_precision twice.
james = sy.VirtualWorker(hook, id="james")

# Make james own the Plan, keeping the pointer that `send` returns.
sent_plan = model.send(james)

# Fetch plan: the local worker asks james (the plan's location) for a copy.
fetched_plan = hook.local_worker.fetch_plan(sent_plan.id, james)

hook.local_worker.clear_objects()

get_plan = sent_plan.get()


In [24]:
# Convert the Plan's weights to fixed precision, then secret-share them between
# alice and bob with crypto_provider as helper party. The value of the last
# expression is shown as the cell output.
fixed_precision_model = model.fix_precision()
fixed_precision_model.share(alice, bob, crypto_provider=crypto_provider)

<PlanNet PlanNet id:19738549030 owner:me>

In [25]:
def test(args, model, test_loader):
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1) 
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            n_total += args.test_batch_size
# This 'test' function performs the encrypted evaluation. The model weights, the data inputs, the prediction and the target used for scoring are all encrypted!

# However as you can observe, the syntax is very similar to normal PyTorch testing! Nice!

# The only thing we decrypt from the server side is the final score at the end of our 200 items batches to verify predictions were on average good.      
            n_correct = n_correct_priv.copy().get().float_precision().long().item()
    
            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
                n_correct, n_total,
                100. * n_correct / n_total))


In [26]:
# Encrypted evaluation of the shared model on the secret-shared test set.
test(args=args, model=model, test_loader=private_test_loader)

Test set: Accuracy: 6/50 (12%)
Test set: Accuracy: 8/100 (8%)
Test set: Accuracy: 20/150 (13%)
Test set: Accuracy: 27/200 (14%)
Test set: Accuracy: 33/250 (13%)
Test set: Accuracy: 41/300 (14%)
Test set: Accuracy: 46/350 (13%)
Test set: Accuracy: 54/400 (14%)
Test set: Accuracy: 58/450 (13%)
Test set: Accuracy: 70/500 (14%)
Test set: Accuracy: 77/550 (14%)
Test set: Accuracy: 84/600 (14%)
Test set: Accuracy: 89/650 (14%)


KeyboardInterrupt: 