In [1]:
# Syft supplies the VirtualMachine, its client, and the pointer types used in this notebook.
import syft as sy
from syft.lib import python
# Switch Syft's VERBOSE flag off for this notebook.
sy.VERBOSE = False
# from syft.core.common.uid import UID
# syrange is used by Net.parameters() to iterate over a remote List pointer.
from syft.util import syrange

In [2]:
# Create a VirtualMachine named "alice" and take its root client; every remote
# torch / torchvision handle in this notebook is looked up on this client.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_root_client()
#alice.root_verify_key = alice_client.verify_key  # inject 📡🔑 as 📍🗝
# Handle to the client's syft.lib.python namespace (Net.parameters() uses remote_python.List).
remote_python = alice_client.syft.lib.python

In [3]:
# original MNIST imports
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# from torchvision import datasets, transforms
# from torch.optim.lr_scheduler import StepLR

# get imports from remote client to match
# NOTE: from here on `torch`, `torchvision`, `nn`, `F`, ... are attributes of
# alice_client, not the locally imported libraries, so the rest of the notebook
# can keep the same spelling as the original MNIST example.
torch = alice_client.torch
torchvision = alice_client.torchvision
transforms = torchvision.transforms
datasets = torchvision.datasets
nn = torch.nn
F = torch.nn.functional
optim = torch.optim
StepLR = torch.optim.lr_scheduler.StepLR

In [4]:
# MODIFIED MNIST NET
# class Net(nn.Module):
class Net:
    """MNIST conv net written as a plain class instead of an ``nn.Module`` subclass.

    The layers are built through the notebook-level ``nn`` handle (bound to
    ``alice_client.torch.nn`` in this notebook). Because the class does not
    inherit from ``nn.Module``, it keeps its own list of layers and provides
    its own ``train``, ``parameters`` and ``__call__``.

    Attributes:
        training: Last mode passed to ``train``; ``False`` until ``train`` is called.
        modules: Per-instance list of the six layers, in construction order.
    """

    training = False

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # The layer list must be an instance attribute. As a class-level list
        # it was shared by every Net, so each additional Net() (for example
        # re-running the `model = Net()` cell) appended six more layers and
        # train()/parameters() then visited stale and duplicated entries.
        self.modules = [
            self.conv1,
            self.conv2,
            self.dropout1,
            self.dropout2,
            self.fc1,
            self.fc2,
        ]

    def train(self, mode: bool = True):
        """Record ``mode`` and forward it to every layer's ``train``; returns ``self``."""
        self.training = mode
        for module in self.modules:
            module.train(mode)
        return self

    def forward(self, x):
        """Apply conv -> relu -> conv -> relu -> max-pool -> dropout -> flatten
        -> fc -> relu -> dropout -> fc and return ``log_softmax`` over dim 1."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def __call__(self, input):
        """Make the instance callable like a module: delegates to ``forward``."""
        return self.forward(input)

    # local list of remote ListPointers of TensorPointers
    # (an earlier variant chained the per-layer remote lists together with
    # __add__; this version collects the individual pointers locally instead)
    def parameters(self, recurse: bool = True):
        """Return a local ``list`` with the parameter pointers of every layer.

        Args:
            recurse: Accepted for signature parity with ``nn.Module.parameters``;
                not used by this implementation.
        """
        params_list = []
        for module in self.modules:
            param_pointers = module.parameters()
            # hack to work around remote generator
            param_pointers = remote_python.List(param_pointers)
            for pointer in syrange(param_pointers):
                params_list.append(pointer)

        return params_list

In [5]:
# Training settings
args = {
    "batch_size": 64,
    "test_batch_size": 1000,
    "epochs": 14,
    "lr": 1.0,
    "gamma": 0.7,
    "no_cuda": True,
    "dry_run": False,
    "seed": 42,
    "log_interval": 10,
    "save_model": False,
}

# CUDA is used only when it is not disabled in args and torch reports it available.
use_cuda = not args["no_cuda"] and torch.cuda.is_available()
torch.manual_seed(args["seed"])

device = torch.device("cuda" if use_cuda else "cpu")

# Keyword arguments handed to the DataLoader constructors.
kwargs = {'batch_size': args["batch_size"]}
if use_cuda:
    # This branch previously went through `wrap_args_dict`, which is neither
    # defined nor imported in this notebook and raised NameError as soon as
    # CUDA was enabled. A plain dict update needs no helper.
    kwargs.update(
        {
            'num_workers': 1,
            'pin_memory': True,
            'shuffle': True
        }
    )

In [6]:
# DATA
transform_1 = torchvision.transforms.ToTensor()  # we need this to convert to torch.Tensor

# Unable to compose currently, need to fix storing pointers
#transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)
#transform = torchvision.transforms.Compose([transform_1, transform_2])

# dataset1 is the training split (downloaded on demand); dataset2 is the test split.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform_1)
dataset2 = datasets.MNIST('../data', train=False, transform=transform_1)
train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
# The test loader shares the loader options but uses the configured
# test_batch_size; previously that setting was ignored and the test split was
# batched with the training batch size.
test_kwargs = dict(kwargs, batch_size=args["test_batch_size"])
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
# Build the network; its layers are created through the remote `nn` handle.
model = Net()
# model = model.to(device)

In [8]:
# Net.parameters() returns a local Python list of parameter pointers
# (see the recorded output: a `list` of TensorPointer objects).
params = model.parameters()
print(params, type(params))

[<syft.proxy.torch.TensorPointer object at 0x7fb8c8d39950>, <syft.proxy.torch.TensorPointer object at 0x7fb8aa778390>, <syft.proxy.torch.TensorPointer object at 0x7fb8c8d29c10>, <syft.proxy.torch.TensorPointer object at 0x7fb8aa778190>, <syft.proxy.torch.TensorPointer object at 0x7fb8c8d32410>, <syft.proxy.torch.TensorPointer object at 0x7fb8b8824110>, <syft.proxy.torch.TensorPointer object at 0x7fb8c8d29a90>, <syft.proxy.torch.TensorPointer object at 0x7fb88802d510>] <class 'list'>


In [9]:
# need to fix
# optimizer = optim.Adadelta(params_list, lr=args["lr"])
# print(optimizer, type(optimizer))

In [10]:
# need to fix
# scheduler = StepLR(optimizer, step_size=1, gamma=args["gamma"])

In [11]:
# MODIFIED TRAIN
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``: forward, NLL loss, backward.

    Args:
        args: Mapping providing ``"log_interval"`` and ``"dry_run"``.
        model: Callable model exposing ``train()``.
        device: Currently unused (the ``.to(device)`` transfer is disabled).
        train_loader: Iterable whose items are indexable as ``[0]`` (data)
            and ``[1]`` (target).
        optimizer: Optional optimizer. When it is not ``None`` its
            ``zero_grad()`` / ``step()`` are called around each backward
            pass; ``None`` keeps the forward/backward-only behaviour.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, tensor_ptr in enumerate(train_loader):
        # destructure by using __getitem__
        data, target = tensor_ptr[0], tensor_ptr[1]
#         data, target = data.to(device), target.to(device)
        # The optimizer argument used to be silently ignored; honour it when
        # one is actually supplied so the weights get updated.
        if optimizer is not None:
            optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        if optimizer is not None:
            optimizer.step()
        if batch_idx % args["log_interval"] == 0:
            print('Train Epoch: {} {}'.format(epoch, batch_idx))
            # A dry run stops at the first logged batch.
            if args["dry_run"]:
                break

In [12]:
# epoch = 0
# optimizer = None
# train(args, model, device, train_loader, optimizer, epoch)

In [13]:
# The optimizer is still disabled (see the "need to fix" cells above), so each
# epoch only runs the forward/backward passes inside train().
optimizer = None
for epoch in range(1, args["epochs"] + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    # TODO: re-enable once a test() exists. The only test() in this notebook
    # is the commented-out ORIGINAL TEST, so calling it here raised NameError
    # as soon as the first epoch finished.
    # test(model, device, test_loader)
    # scheduler.step()

Train Epoch: 1 0
Train Epoch: 1 10
Train Epoch: 1 20
Train Epoch: 1 30
Train Epoch: 1 40
Train Epoch: 1 50
Train Epoch: 1 60
Train Epoch: 1 70
Train Epoch: 1 80
Train Epoch: 1 90
Train Epoch: 1 100
Train Epoch: 1 110
Train Epoch: 1 120
Train Epoch: 1 130
Train Epoch: 1 140
Train Epoch: 1 150
Train Epoch: 1 160
Train Epoch: 1 170
Train Epoch: 1 180
Train Epoch: 1 190
Train Epoch: 1 200
Train Epoch: 1 210
Train Epoch: 1 220
Train Epoch: 1 230
Train Epoch: 1 240
Train Epoch: 1 250
Train Epoch: 1 260
Train Epoch: 1 270
Train Epoch: 1 280
Train Epoch: 1 290
Train Epoch: 1 300
Train Epoch: 1 310
Train Epoch: 1 320
Train Epoch: 1 330
Train Epoch: 1 340
Train Epoch: 1 350
Train Epoch: 1 360
Train Epoch: 1 370
Train Epoch: 1 380
Train Epoch: 1 390
Train Epoch: 1 400


KeyboardInterrupt: 

In [None]:
# Scratch check: build a single Conv2d through the remote `nn` handle and
# print the resulting object and its type.
conv1 = nn.Conv2d(1, 32, 3, 1)
print(conv1, type(conv1))

In [None]:
# N X C, H, W
# make some data
# (an all-ones tensor of shape [32, 1, 24, 24], created through the remote `torch`)
g = torch.ones([32, 1, 24, 24])
print(g, type(g))

# Step through a cut-down version of Net.forward by hand, using only conv1
# (no conv2, dropout or linear layers).
x = conv1(g)
x = F.relu(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = F.relu(x)
output = F.log_softmax(x, dim=1)

In [None]:
# get the output
# (prints the object produced by the manual forward pass together with its type)
print(output, type(output))

In [None]:
# Call .get() on the output and print what comes back.
# NOTE(review): presumably this retrieves the remote value behind the pointer — confirm.
result = output.get()
print(result)

In [None]:
# ORIGINAL MNIST NET
# class Net(nn.Module):
#     def __init__(self) -> None:
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, 3, 1)
#         self.conv2 = nn.Conv2d(32, 64, 3, 1)
#         self.dropout1 = nn.Dropout2d(0.25)
#         self.dropout2 = nn.Dropout2d(0.5)
#         self.fc1 = nn.Linear(9216, 128)
#         self.fc2 = nn.Linear(128, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = F.relu(x)
#         x = self.conv2(x)
#         x = F.relu(x)
#         x = F.max_pool2d(x, 2)
#         x = self.dropout1(x)
#         x = torch.flatten(x, 1)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.dropout2(x)
#         x = self.fc2(x)
#         output = F.log_softmax(x, dim=1)
#         return output

In [None]:
# ORIGINAL TRAIN
# def train(args, model, device, train_loader, optimizer, epoch):
#     model.train()
#     for batch_idx, (data, target) in enumerate(train_loader):
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         loss = F.nll_loss(output, target)
#         loss.backward()
#         optimizer.step()
#         if batch_idx % args.log_interval == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                 100. * batch_idx / len(train_loader), loss.item()))
#             if args.dry_run:
#                 break

In [None]:
# ORIGINAL TEST
# def test(model, device, test_loader):
#     model.eval()
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()

#     test_loss /= len(test_loader.dataset)

#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))