In [1]:
import syft as sy
from syft.lib import python


In [2]:
def get_permission(obj):
    """Let `alice_client` read the remote object that `obj` points to.

    Demo-only shortcut. In a real-world setup, reading a remote value such as
    the training loss means sending a request and waiting for the data owner
    to approve it. Training produces many such requests and the VM here is
    local, so we approve them ourselves by writing a read permission for
    `alice_client.verify_key` directly onto the stored object.

    Args:
        obj: a pointer whose `id_at_location` is the key of the target object
            in `alice.store`.
    """
    store_key = obj.id_at_location
    stored = alice.store[store_key]
    stored.read_permissions[alice_client.verify_key] = store_key

In [3]:
# Create alice's local VM and a root client that talks to it.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_root_client()

# NOTE(review): an earlier revision also ran
#   alice.root_verify_key = alice_client.verify_key
# It is disabled here and it is unclear whether it is still needed; confirm.

# Handle to syft's python types (List, ...) living on alice.
remote_python = alice_client.syft.lib.python

In [4]:
# Remote stand-ins for the original MNIST example's imports:
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#   import torch.optim as optim
#   from torchvision import datasets, transforms
#   from torch.optim.lr_scheduler import StepLR
# Each name is resolved through `alice_client`, so the rest of the notebook can
# keep the familiar spellings.
torch, torchvision = alice_client.torch, alice_client.torchvision
transforms, datasets = torchvision.transforms, torchvision.datasets
nn, F, optim = torch.nn, torch.nn.functional, torch.optim
StepLR = torch.optim.lr_scheduler.StepLR

In [5]:
# MODIFIED MNIST NET
# class Net(nn.Module):
class Net:
    """MNIST CNN from the PyTorch example, assembled from alice's remote layers.

    Replaces the commented-out `nn.Module` base class with a plain class that
    implements only what the training loop uses: `train`/`eval`, `__call__`,
    `parameters`, `cpu` and `cuda`. Each of these forwards to every sub-module
    listed in `self.modules`.
    """

    # NOTE(review): `nn.Module` instances start in training mode, but this flag
    # starts False until `train()` is called. Confirm that is intended.
    training = False

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # Per-instance list. The former class-level `modules = []` was shared by
        # every Net, so each new instance appended to the same list and then
        # also trained, moved and collected parameters from earlier instances.
        self.modules = [
            self.conv1,
            self.conv2,
            self.dropout1,
            self.dropout2,
            self.fc1,
            self.fc2,
        ]

    def train(self, mode: bool = True) -> "Net":
        """Set `mode` on this wrapper and on every sub-module; returns self."""
        self.training = mode
        for module in self.modules:
            module.train(mode)
        return self

    def eval(self) -> "Net":
        """Switch to evaluation mode, like `nn.Module.eval`; returns self."""
        return self.train(False)

    def forward(self, x):
        """Forward pass, identical to the ORIGINAL MNIST NET cell's `forward`."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def __call__(self, input):
        """Run `forward`, so `model(data)` works as with an `nn.Module`."""
        return self.forward(input)

    def parameters(self, recurse: bool = True):
        """Collect every sub-module's parameter pointers into one remote List.

        Starts from an empty `remote_python.List()` (created on alice) and
        extends it with each module's `parameters()` result.

        Args:
            recurse: unused; presumably kept to match `nn.Module.parameters`.

        Returns:
            The remote List pointer holding all parameter pointers.
        """
        params_list = remote_python.List()
        for module in self.modules:
            param_pointers = module.parameters()
            params_list += param_pointers

        return params_list

    def cuda(self, device) -> "Net":
        """Call `cuda(device)` on every sub-module; returns self like `nn.Module.cuda`."""
        for module in self.modules:
            module.cuda(device)
        return self

    def cpu(self) -> "Net":
        """Call `cpu()` on every sub-module; returns self like `nn.Module.cpu`."""
        for module in self.modules:
            module.cpu()
        return self

In [6]:
# Training settings (same defaults as the PyTorch MNIST example's argparse flags).
args = {
    "batch_size": 64,
    "test_batch_size": 1000,
    "epochs": 14,
    "lr": 1.0,
    "gamma": 0.7,
    "no_cuda": True,
    "dry_run": False,
    "seed": 42,
    "log_interval": 10,
    "save_model": False,
}

# `torch.cuda.is_available()` is evaluated on alice and returns a pointer. In a
# real deployment we would call `ptr.request()` and wait for the data owner to
# approve it before `ptr.get()`. For this local demo we grant the read
# permission ourselves.
ptr = torch.cuda.is_available()
get_permission(ptr)

# `and` short-circuits: the remote value is only fetched when CUDA is not
# disabled in `args`.
use_cuda = not args["no_cuda"] and ptr.get()

torch.manual_seed(args["seed"])

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'batch_size': args["batch_size"]}
if use_cuda:
    # Plain dict: `wrap_args_dict` is not defined anywhere in this notebook, so
    # the previous call raised NameError whenever CUDA was enabled.
    kwargs.update(
        {
            'num_workers': 1,
            'pin_memory': True,
            'shuffle': True,
        }
    )

In [7]:
# DATA
transform_1 = torchvision.transforms.ToTensor()  # we need this to convert to torch.Tensor
transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)

# Compose receives a List built on alice, so the whole pipeline lives remotely.
lst = remote_python.List()
lst.append(transform_1)
lst.append(transform_2)
transform = torchvision.transforms.Compose(lst)

dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
# Evaluate on the same normalized inputs the model is trained on. Previously
# the test set used `transform_1` only and skipped Normalize.
dataset2 = datasets.MNIST('../data', train=False, transform=transform)

# Same loader options as training, but with the evaluation batch size from
# `args` (previously `test_batch_size` was never used).
test_kwargs = {**kwargs, 'batch_size': args["test_batch_size"]}
train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [8]:
# Net has no `.to()`; device placement goes through its cpu()/cuda() methods.
model = Net()

In [9]:
# Move every remote sub-module to the CPU. The call is made for its side effect;
# binding its result to a throwaway name kept nothing useful.
model.cpu();

In [10]:
# Gather all parameter pointers into one remote List and show what came back.
params = model.parameters()
print(params, type(params), sep=" ")

<syft.proxy.syft.lib.python.ListPointer object at 0x7fb1806a04f0> <class 'syft.proxy.syft.lib.python.ListPointer'>


In [11]:
# Adadelta runs on alice over the remote parameter List.
optimizer = optim.Adadelta(params, lr=args["lr"])
print(optimizer, type(optimizer), sep=" ")

<syft.proxy.torch.optim.AdadeltaPointer object at 0x7fb166207f70> <class 'syft.proxy.torch.optim.AdadeltaPointer'>


In [12]:
# Learning-rate schedule: scale the LR by `gamma` after each step.
scheduler = StepLR(optimizer, step_size=1, gamma=args["gamma"])
print(scheduler, type(scheduler), sep=" ")

<syft.proxy.torch.optim.lr_scheduler.StepLRPointer object at 0x7fb166207c10> <class 'syft.proxy.torch.optim.lr_scheduler.StepLRPointer'>


In [13]:
# MODIFIED TRAIN
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch on alice and log the loss every few batches.

    Pointer-based port of the ORIGINAL TRAIN cell. Each batch yielded by
    `train_loader` is indexed (`[0]`, `[1]`) to obtain the data and target.

    Args:
        args: settings dict; reads "log_interval" and "dry_run".
        model: a `Net`-like callable with a `train()` method.
        device: passed to `.to()` on data and target.
        train_loader: iterable of batches indexable as `[data, target]`.
        optimizer: object with `zero_grad()` and `step()`.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    for batch_idx, batch_ptr in enumerate(train_loader):
        data, target = batch_ptr[0], batch_ptr[1]
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args["log_interval"] == 0:
            # Only materialize the loss when it is logged. Each `.item()`
            # creates a new remote object and each `.get()` is a round-trip,
            # which the previous version paid on every batch.
            loss_item = loss.item()
            # Usually the data owner would have to approve a request for this
            # value. Since training generates many such requests, the demo
            # grants the read permission locally instead.
            get_permission(loss_item)
            local_loss = loss_item.get()
            print('Train Epoch: {} {} {:.4}'.format(epoch, batch_idx, local_loss))
            if args["dry_run"]:
                break

In [None]:
# `test` exists only as the commented-out ORIGINAL TEST cell, so calling it
# unconditionally raised NameError as soon as the first epoch finished
# training. Evaluate only once a pointer-compatible
# `test(model, device, test_loader)` has been defined; otherwise report it once
# and keep training.
test_fn = globals().get("test")
if not callable(test_fn):
    print("No `test` function defined; skipping evaluation.")
for epoch in range(1, args["epochs"] + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    if callable(test_fn):
        test_fn(model, device, test_loader)
    scheduler.step()

In [None]:
# A standalone remote conv layer, used by the smoke test below.
conv1 = nn.Conv2d(1, 32, 3, 1)
print(conv1, type(conv1), sep=" ")

In [None]:
# Smoke test: push a dummy all-ones batch laid out as N x C x H x W through a
# cut-down version of Net.forward, to check that the remote ops chain together.
g = torch.ones([32, 1, 24, 24])
print(g, type(g))

x = conv1(g)
# A single ReLU suffices: the second back-to-back F.relu was a no-op remote call.
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = F.relu(x)
output = F.log_softmax(x, dim=1)

In [None]:
# `output` is still a pointer at this point; show the pointer and its type.
print(output, type(output), sep=" ")

In [None]:
# Download the log-probabilities from alice and print the local value.
result = output.get()
print(str(result))

In [None]:
# ORIGINAL MNIST NET
# class Net(nn.Module):
#     def __init__(self) -> None:
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, 3, 1)
#         self.conv2 = nn.Conv2d(32, 64, 3, 1)
#         self.dropout1 = nn.Dropout2d(0.25)
#         self.dropout2 = nn.Dropout2d(0.5)
#         self.fc1 = nn.Linear(9216, 128)
#         self.fc2 = nn.Linear(128, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = F.relu(x)
#         x = self.conv2(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.dropout1(x)
#         x = torch.flatten(x, 1)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.dropout2(x)
#         x = self.fc2(x)
#         output = F.log_softmax(x, dim=1)
#         return output

In [None]:
# ORIGINAL TRAIN
# def train(args, model, device, train_loader, optimizer, epoch):
#     model.train()
#     for batch_idx, (data, target) in enumerate(train_loader):
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         loss = F.nll_loss(output, target)
#         loss.backward()
#         optimizer.step()
#         if batch_idx % args.log_interval == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                 100. * batch_idx / len(train_loader), loss.item()))
#             if args.dry_run:
#                 break

In [None]:
# ORIGINAL TEST
# def test(model, device, test_loader):
#     model.eval()
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()

#     test_loss /= len(test_loader.dataset)

#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))