In [1]:
# Number of passes over the (reduced) training set.
epochs = 1
# For efficiency we do not use the whole dataset; feel free to increase these numbers.
# They are upper bounds on how many examples are turned into private batches.
n_train_items = 64
n_test_items = 64

In [2]:
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import time

In [26]:
class Arguments:
    """Run settings shared by the training and evaluation cells."""

    def __init__(self):
        # Batching
        self.batch_size = 16
        self.test_batch_size = 16

        # Optimisation; `epochs` is taken from the notebook-level configuration
        self.epochs = epochs
        self.lr = 0.02

        # Reproducibility and logging
        self.seed = 1
        self.log_interval = 1  # Log info at each batch

        # Number of fractional digits passed to .fix_precision(...)
        self.precision_fractional = 4


args = Arguments()

# Seed torch's RNG; the result is bound to `_` so the cell shows no output.
_ = torch.manual_seed(args.seed)

In [4]:
import syft as sy  # import the PySyft library
hook = sy.TorchHook(torch)  # hook PyTorch to add extra functionalities like Federated and Encrypted Learning


# simulation functions
def connect_to_workers(n_workers):
    """Return a list of `n_workers` virtual workers with ids "worker1" ... "worker<n_workers>"."""
    virtual_workers = []
    for number in range(1, n_workers + 1):
        virtual_workers.append(sy.VirtualWorker(hook, id=f"worker{number}"))
    return virtual_workers


def connect_to_crypto_provider():
    """Return the virtual worker registered under the id "crypto_provider"."""
    provider = sy.VirtualWorker(hook, id="crypto_provider")
    return provider


workers = connect_to_workers(n_workers=2)
crypto_provider = connect_to_crypto_provider()

In [5]:
# Sanity check: backward pass of max_pool2d on a fixed-precision, shared tensor.
# `t` holds random integers in [0, 1024) with shape (4, 1, 8, 8).
t = th.randint(1024, (4, 1, 8, 8))
# requires_grad=True so that x.grad is available after backward().
x = t.fix_precision(precision_fractional=args.precision_fractional).share(*workers, crypto_provider=crypto_provider, protocol="fss", requires_grad=True)
print(x)
# 2x2 max-pooling followed by a sum gives a scalar to back-propagate from.
y = F.max_pool2d(x, 2)
z = y.sum()
print(z)
z.backward()

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:85658272934 -> worker1:19456649238]
	-> [PointerTensor | me:68430559625 -> worker2:61686137092]
	*crypto provider: crypto_provider*
(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:67035910710 -> worker1:93236807433]
	-> [PointerTensor | me:92520734188 -> worker2:35722020429]
	*crypto provider: crypto_provider*
Backward MAXPOOL2D


In [6]:
# Same max-pool + sum on the plain tensor, to obtain a reference gradient.
t = t.float()
t.requires_grad = True
F.max_pool2d(t, 2).sum().backward()
# NOTE(review): re-running this cell presumably accumulates into t.grad
# (t is already float with requires_grad set) - confirm before relying on a re-run.
t.grad

tensor([[[[0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0., 0., 1., 0.],
          [1., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 1., 0., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0.],
          [0., 1., 1., 0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 1., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 1., 0., 0., 0.],
          [1., 0., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 0., 0., 1.],
          [0., 0., 1., 0., 0., 1., 0., 0.]]],


        [[[1., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 0.],
          [0., 0., 0., 0., 0., 1., 0., 1.],
          [1., 0., 1., 0., 0., 0., 0., 0.],
          [0., 1., 0., 1., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 1., 0

In [7]:
# Gradient of the shared tensor, decrypted for display; compare it with t.grad of the plain tensor.
x.grad.decrypt()

tensor([[[[0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0., 0., 1., 0.],
          [1., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 1., 0., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0.],
          [0., 1., 1., 0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 1., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 1., 0., 0., 0.],
          [1., 0., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 0., 0., 1.],
          [0., 0., 1., 0., 0., 1., 0., 0.]]],


        [[[1., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 0.],
          [0., 0., 0., 0., 0., 1., 0., 1.],
          [1., 0., 1., 0., 0., 0., 0., 0.],
          [0., 1., 0., 1., 1., 0., 1., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 1., 0

In [8]:
def get_private_data_loaders(precision_fractional, workers, crypto_provider):
    """
    Build fixed-precision, secret-shared MNIST train and test batches.

    Only the leading batches needed to cover `n_train_items` / `n_test_items`
    are read: iteration over the underlying DataLoader stops as soon as enough
    batches were collected, instead of walking through the whole dataset.

    Args:
        precision_fractional: forwarded to `.fix_precision(...)` for every tensor.
        workers: workers forwarded to `.share(...)`.
        crypto_provider: crypto provider forwarded to `.share(...)`.

    Returns:
        (private_train_loader, private_test_loader): two lists of
        (data, target) pairs of shared tensors. Training targets are one-hot
        encoded; test targets are the class indices cast to float.

    Note:
        Reads the notebook-level `args` (batch sizes), `n_train_items` and
        `n_test_items`.
    """
    # Local imports keep this cell runnable without touching the import cell.
    from itertools import islice
    from math import ceil

    def one_hot_of(index_tensor):
        """
        Transform to one hot tensor

        Example:
            [0, 3, 9]
            =>
            [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]

        """
        onehot_tensor = torch.zeros(*index_tensor.shape, 10)  # 10 classes for MNIST
        onehot_tensor = onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)
        return onehot_tensor

    def secret_share(tensor):
        """
        Transform to fixed precision and secret share a tensor
        """
        return (
            tensor
            .fix_precision(precision_fractional=precision_fractional)
            .share(*workers, crypto_provider=crypto_provider, protocol="fss", requires_grad=True)
        )

    def first_batches(loader, n_items, batch_size):
        """Iterate over only the leading batches of `loader` needed to cover `n_items`."""
        # Keeping batch i while `i < n_items / batch_size` keeps exactly
        # ceil(n_items / batch_size) batches (none for a non-positive ratio).
        n_batches = max(0, ceil(n_items / batch_size))
        return islice(loader, n_batches)

    transformation = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transformation),
        batch_size=args.batch_size
    )

    private_train_loader = [
        (secret_share(data), secret_share(one_hot_of(target)))
        for data, target in first_batches(train_loader, n_train_items, args.batch_size)
    ]

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=transformation),
        batch_size=args.test_batch_size
    )

    private_test_loader = [
        (secret_share(data), secret_share(target.float()))
        for data, target in first_batches(test_loader, n_test_items, args.test_batch_size)
    ]

    return private_train_loader, private_test_loader


private_train_loader, private_test_loader = get_private_data_loaders(
    precision_fractional=args.precision_fractional,
    workers=workers,
    crypto_provider=crypto_provider
)

In [9]:
class Net(nn.Module):
    """
    Small convolutional classifier with 10 outputs.

    Two 5x5 convolutions with 16 output channels, each followed by 2x2
    max-pooling and then ReLU, feed two fully connected layers
    (256 -> 100 -> 10). Every layer, including the last one, is followed by
    a ReLU.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=0)
        # Classifier head; 256 is the flattened feature size expected by forward()
        self.fc1 = nn.Linear(256, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return the (non-negative) 10-dimensional output for each input in the batch."""
        for conv in (self.conv1, self.conv2):
            # Pooling is applied before ReLU (inverted with respect to the usual order)
            pooled = F.max_pool2d(conv(x), kernel_size=2, stride=2)
            x = F.relu(pooled)
        # reshape is used because, for some weird reason, .view doesn't work after the 1st batch
        flat = x.reshape(-1, 256)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))

In [27]:
import numpy as np

## TEST Backprop: compare gradients of the plain model with those of the shared model
batch_size = 2

# Plain (unencrypted) reference run; the seed fixes the model initialisation.
torch.manual_seed(42)
model = Net()
model.train()
# Random inputs scaled to [-1, 1) and random 0/1 targets of shape (batch_size, 10).
data = ((th.randint(1024, (batch_size, 1, 28, 28)).float() - 512) / 512)
target = th.randint(2, (batch_size, 10)).float()

output = model(data.float())
# Summed squared error divided by the batch size, as in train().
loss = ((output - target)**2).sum()/batch_size
loss.backward()

# Print a slice of the gradient of each layer, rounded to 3 decimals.
print(np.round(model.fc2.weight.grad[5][10: 26], 3))
print(np.round(model.fc1.weight.grad[5][8: 26], 3))
print(np.round(model.conv2.weight.grad[0][3], 3))
print(np.round(model.conv1.weight.grad[0], 3))

# Same seed so that the second model starts from the same initialisation,
# then model, data and target are converted to fixed precision and shared.
torch.manual_seed(42)
model = Net()
model = model.fix_precision(precision_fractional=args.precision_fractional).share(*workers, crypto_provider=crypto_provider, protocol="fss", requires_grad=True)
data = data.fix_precision(precision_fractional=args.precision_fractional).share(*workers, crypto_provider=crypto_provider, protocol="fss", requires_grad=True)
target = target.fix_precision(precision_fractional=args.precision_fractional).share(*workers, crypto_provider=crypto_provider, protocol="fss", requires_grad=True)
output = model(data)
loss = ((output - target)**2).sum().refresh()/batch_size
loss.backward()
# Same slices as above, fetched with .get() and converted back with .float_prec() before printing.
print(np.round(model.fc2.weight.grad.get().float_prec()[5][10: 26], 3))
print(np.round(model.fc1.weight.grad.get().float_prec()[5][8: 26], 3))
print(np.round(model.conv2.weight.grad.get().float_prec()[0][3], 3))
print(np.round(model.conv1.weight.grad.get().float_prec()[0], 3))


tensor([-0.0860, -0.0510,  0.0000,  0.0000,  0.0000, -0.3610, -0.0210,  0.0000,
        -0.2650, -0.0960,  0.0000,  0.0000,  0.0000, -0.4180, -0.1110,  0.0000])
tensor([0.0290, 0.0400, 0.0160, 0.0100, 0.0670, 0.0180, 0.0250, 0.0460, 0.0440,
        0.0560, 0.0400, 0.0480, 0.0510, 0.0430, 0.0520, 0.0640, 0.0740, 0.0720])
tensor([[ 0.0590,  0.0720, -0.0060,  0.1200,  0.0140],
        [ 0.0820,  0.0990, -0.0240, -0.0070,  0.0590],
        [-0.0120,  0.0500,  0.0680,  0.0620,  0.0360],
        [ 0.0260, -0.0030,  0.1120,  0.0550,  0.0090],
        [ 0.0500,  0.0470, -0.0360,  0.0510,  0.0030]])
tensor([[[-0.0610, -0.0430,  0.1020, -0.0320, -0.0380],
         [ 0.0900,  0.0540,  0.0120,  0.0560,  0.0360],
         [ 0.0220, -0.0530, -0.0560,  0.0020, -0.0070],
         [-0.0530, -0.0590, -0.0120,  0.0150, -0.0630],
         [-0.0410,  0.0810,  0.0260,  0.0170, -0.0110]]])
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward C

In [10]:
def train(args, model, private_train_loader, optimizer, epoch):
    """
    Run one training epoch over the private batches.

    For each (data, target) pair: reset the gradients, run the model, compute
    the summed squared error divided by the batch size, back-propagate and
    take one optimizer step. Every `args.log_interval` batches the loss is
    fetched with `.get().float_precision()` and printed together with the
    time spent on the batch.

    Args:
        args: settings object; `batch_size` and `log_interval` are read.
        model: callable model exposing `.train()`.
        private_train_loader: sized iterable of (data, target) pairs.
        optimizer: object exposing `.zero_grad()` and `.step()`.
        epoch: epoch number, only used in the log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(private_train_loader):  # <-- now it is a private dataset
        tic = time.time()

        optimizer.zero_grad()
        output = model(data)

        # loss = F.nll_loss(output, target)  <-- not possible here
        batch_size = output.shape[0]
        squared_error = (output - target) ** 2
        loss = squared_error.sum().refresh() / batch_size

        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue

        # Only the loss value is fetched, and only when a log line is due.
        loss = loss.get().float_precision()
        n_batches = len(private_train_loader)
        log_line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.3f}s'.format(
            epoch,
            batch_idx * args.batch_size,
            n_batches * args.batch_size,
            100. * batch_idx / n_batches,
            loss.item(),
            time.time() - tic,
        )
        print(log_line)
            

In [11]:
def test(args, model, private_test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in private_test_loader:
            start_time = time.time()
            
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum()

    correct = correct.get().float_precision()
    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct.item(), len(private_test_loader)* args.test_batch_size,
        100. * correct.item() / (len(private_test_loader) * args.test_batch_size)))

In [12]:
# Fresh model, converted to fixed precision and shared between the workers.
model = Net().fix_precision(precision_fractional=args.precision_fractional).share(
    *workers,
    crypto_provider=crypto_provider,
    protocol="fss",
    requires_grad=True,
)

# SGD over the shared parameters; the optimizer is converted to fixed precision as well.
optimizer = optim.SGD(model.parameters(), lr=args.lr).fix_precision(
    precision_fractional=args.precision_fractional
)

# Epochs are numbered from 1; evaluate after every training epoch.
for epoch in range(1, args.epochs + 1):
    train(args, model, private_train_loader, optimizer, epoch)
    test(args, model, private_test_loader)

Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward RELU
Backward RELU
Backward MAXPOOL2D
Backward CONV2D
Backward RELU
Backward MAXPOOL2D
Backward