In [1]:
import syft as sy
from syft.lib import python
sy.VERBOSE = False  # presumably turns off syft's verbose logging output



In [2]:
# Create a syft VirtualMachine named "alice" and obtain a root client for it.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_root_client()
# Make the root client's verify key the VM's root verify key, presumably so
# requests signed by alice_client are treated as root
# (original note: "inject 📡🔑 as 📍🗝").
alice.root_verify_key = alice_client.verify_key

In [3]:
# The standard PyTorch MNIST example imports torch, torch.nn,
# torch.nn.functional, torch.optim, torchvision datasets/transforms and StepLR.
# Here the same names are bound to alice's remote modules instead, so calls
# made through them run on alice and return pointers.
# torch.optim / StepLR are not aliased yet (the training loop is not ported).
torch = alice_client.torch
nn = torch.nn
F = nn.functional

torchvision = alice_client.torchvision
transforms, datasets = torchvision.transforms, torchvision.datasets

In [4]:
# helper to create and send a primitive
def make(thing, client=None):
    """Wrap a plain Python value as a syft primitive and send it to a client.

    Args:
        thing: the value to wrap (e.g. an int, float, bool or str); it is
            handed to ``PrimitiveFactory.generate_primitive``.
        client: the client to send the primitive to. Defaults to
            ``alice_client``, resolved at call time so existing callers keep
            their behavior.

    Returns:
        The pointer returned by ``send``, referencing the value on ``client``.
    """
    target = alice_client if client is None else client
    prim = python.primitive_factory.PrimitiveFactory.generate_primitive(value=thing)
    return prim.send(target)

In [5]:
# Create a Conv2d layer on alice. The positional arguments follow
# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride); each one is
# sent to alice with make() first. The result is a pointer (see output below).
conv1 = nn.Conv2d(make(1), make(32), make(3), make(1))
print(conv1, type(conv1))

# TODO: make remote modules callable so the layer can be applied, e.g.
# conv1()

<syft.proxy.torch.nn.Conv2dPointer object at 0x13dc2bb90> <class 'syft.proxy.torch.nn.Conv2dPointer'>


In [6]:
# Fake single-image input: an all-ones float tensor of shape (1, 3, 9), built
# with the real (local) torch package and sent to alice; `g` is the pointer.
import torch as local_torch
g = local_torch.ones(1, 3, 9).send(alice_client)

In [7]:
# Smoke-test the remote functional ops on the pointer `g`. Every call runs on
# alice and returns a new pointer. relu is applied more than once, presumably
# just to exercise the op; the arguments mirror the MNIST forward pass.
x = F.relu(g)
x = F.relu(x)
x = F.max_pool2d(x, make(2))  # pooling kernel size 2, sent as a primitive
x = torch.flatten(x, make(1))  # flatten everything from dim 1 onward
x = F.relu(x)
output = F.log_softmax(x, dim=make(1))

In [8]:
# `output` is still a pointer to a tensor on alice; .get() downloads the
# actual tensor so it can be printed locally.
print(output, type(output))

result = output.get()
print(result)

<syft.proxy.torch.TensorPointer object at 0x13dc2b0d0> <class 'syft.proxy.torch.TensorPointer'>
tensor([[-1.3863, -1.3863, -1.3863, -1.3863]])


In [9]:
# ORIGINAL MNIST NET
# class Net(nn.Module):
#     def __init__(self) -> None:
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, 3, 1)
#         self.conv2 = nn.Conv2d(32, 64, 3, 1)
#         self.dropout1 = nn.Dropout2d(0.25)
#         self.dropout2 = nn.Dropout2d(0.5)
#         self.fc1 = nn.Linear(9216, 128)
#         self.fc2 = nn.Linear(128, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = F.relu(x)
#         x = self.conv2(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.dropout1(x)
#         x = torch.flatten(x, 1)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.dropout2(x)
#         x = self.fc2(x)
#         output = F.log_softmax(x, dim=1)
#         return output

# Remote port of the MNIST network above: every layer is constructed on alice,
# and every constructor argument is first sent there as a primitive with make().
# NOTE: the nn.Module base class is commented out, so Net is a plain Python
# class (it has no parameters()/to()/train()). forward() cannot run yet
# because remote module pointers are not callable (see "make modules callable").
# class Net(nn.Module):
class Net:
    """MNIST CNN whose layers are pointers to modules living on alice."""

    def __init__(self) -> None:
        # Same layer configuration as the commented-out original network; the
        # make() calls run in order, each sending one argument to alice.
        self.conv1 = nn.Conv2d(make(1), make(32), make(3), make(1))
        self.conv2 = nn.Conv2d(make(32), make(64), make(3), make(1))
        self.dropout1 = nn.Dropout2d(make(0.25))
        self.dropout2 = nn.Dropout2d(make(0.5))
        self.fc1 = nn.Linear(make(9216), make(128))
        self.fc2 = nn.Linear(make(128), make(10))

    def forward(self, x):
        """Run the network on ``x`` (a remote tensor pointer); return log-probabilities.

        NOTE(review): presumably fails at ``self.conv1(x)`` until remote
        module pointers become callable.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, make(2))
        x = self.dropout1(x)
        x = torch.flatten(x, make(1))
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=make(1))
        return output

In [10]:
# Instantiate the network; Net.__init__ creates each layer on alice.
net = Net()

In [11]:
# TODO: run net.forward(g) once remote modules are callable.
# net.forward(g)

In [12]:
# ORIGINAL TRAIN
# def train(args, model, device, train_loader, optimizer, epoch):
#     model.train()
#     for batch_idx, (data, target) in enumerate(train_loader):
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         loss = F.nll_loss(output, target)
#         loss.backward()
#         optimizer.step()
#         if batch_idx % args.log_interval == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                 100. * batch_idx / len(train_loader), loss.item()))
#             if args.dry_run:
#                 break

In [13]:
# ORIGINAL TEST
# def test(model, device, test_loader):
#     model.eval()
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()

#     test_loss /= len(test_loader.dataset)

#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))

In [14]:
# MNIST preprocessing transforms, created on alice. Normalize takes
# (mean, std); 0.1307 / 0.3081 are the values used by the original MNIST
# example, each sent as a primitive with make().
# NOTE(review): the original comment here was truncated ("fix issue with").
# main() below builds the same two transforms again.
transform_1 = torchvision.transforms.ToTensor()
transform_2 = torchvision.transforms.Normalize(make(0.1307), make(0.3081))

# TODO: Compose is not built yet; the original note says the list of
# transforms "cant store pointer", so make() on the list presumably fails.
# transform = torchvision.transforms.Compose(make([transform_1, transform_2]))

In [15]:
# Utility to turn a plain config dict into a dict of remote pointers.
# TODO: presumably replace this with a single remote (syft) dict.
def wrap_args_dict(args: dict) -> dict:
    """Send every primitive value in ``args`` to alice via ``make``.

    Args:
        args: mapping of option name to value.

    Returns:
        A new dict with the same keys as ``args``. Values that syft
        recognises as primitives are replaced by the pointer returned by
        ``make``. Any other value (e.g. something that is already a pointer)
        is passed through unchanged. Previously such keys were silently
        dropped, so a later ``args[key]`` raised KeyError.
    """
    return {
        key: make(value) if sy.lib.python.primitive_factory.isprimitive(value) else value
        for key, value in args.items()
    }

def main():
    """Port of the PyTorch MNIST example's ``main`` onto alice (work in progress).

    Sends the training settings to alice, seeds alice's torch RNG, picks the
    device, and builds the MNIST datasets, data loaders and model remotely.
    The optimizer, scheduler and train/test loop are not ported yet.
    """
    # Training settings. The plain values drive local control flow; `args`
    # holds the remote copies that are passed into remote calls.
    plain_args = {"batch_size": 64, "test_batch_size": 1000,
                  "epochs": 14, "lr": 1.0, "gamma": 0.7,
                  "no_cuda": True, "dry_run": False, "seed": 42,
                  "log_interval": 10, "save_model": False}

    args = wrap_args_dict(plain_args)

    # Decide locally from the plain flag: args["no_cuda"] is a pointer whose
    # local truthiness does not reflect the remote boolean, so the flag used
    # to be ignored. The remote CUDA check is only fetched when CUDA is allowed.
    use_cuda = not plain_args["no_cuda"] and bool(torch.cuda.is_available().get())
    device_type = sy.lib.python.String("cuda" if use_cuda else "cpu")
    device_type_ptr = device_type.send(alice_client)
    torch.manual_seed(args["seed"])

    device = torch.device(device_type_ptr)

    # Separate loader settings so the test loader uses test_batch_size, as in
    # the upstream MNIST example (it previously reused the train batch size).
    train_kwargs = {'batch_size': args["batch_size"]}
    test_kwargs = {'batch_size': args["test_batch_size"]}
    if use_cuda:
        cuda_kwargs = wrap_args_dict(
            {
                'num_workers': 1,
                'pin_memory': True,
                'shuffle': True
            }
        )
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform_1 = torchvision.transforms.ToTensor()
    transform_2 = torchvision.transforms.Normalize(make(0.1307), make(0.3081))

    # TODO: build the Compose once a list of pointers can be sent; until then
    # the datasets below are created without `transform=transform`.
    # transform = torchvision.transforms.Compose(make([transform_1, transform_2]))

    dataset1 = datasets.MNIST(make('../data'), train=make(True), download=make(True))
    dataset2 = datasets.MNIST(make('../data'), train=make(False))
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    # TODO: model = Net().to(device) once Net is an nn.Module.
    model = Net()
    # Remaining steps of the original example, not ported yet. Local control
    # flow must use plain_args; remote calls take the pointers in args.
    # optimizer = optim.Adadelta(model.parameters(), lr=args["lr"])
    # scheduler = StepLR(optimizer, step_size=1, gamma=args["gamma"])
    # for epoch in range(1, plain_args["epochs"] + 1):
    #     train(args, model, device, train_loader, optimizer, epoch)
    #     test(model, device, test_loader)
    #     scheduler.step()
    #
    # if plain_args["save_model"]:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")

In [16]:
# Run the (partially ported) MNIST setup. dataset1 is created with
# download=True, so this may download MNIST into '../data'.
main()