In [13]:
# Resources and tutorials:
# OpenMined Advanced tutorials: https://github.com/OpenMined/PySyft/tree/master/examples/tutorials/advanced/websockets-example-MNIST
# Andrew Trask YouTube lessons: https://www.youtube.com/watch?v=TWa6wFarCeI
# OpenMined Blog about FL: https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines/
# OpenMined blog about setting FL and RNN with RPi: https://blog.openmined.org/federated-learning-of-a-rnn-on-raspberry-pis/
# Udacity and Facebook "Secure and Private AI challenge"

In [1]:
# %load_ext autoreload

# %autoreload 2

# load libraries





import sys
import syft as sy
#from syft.workers.virtual import VirtualWorker
from syft.workers import WebsocketClientWorker
import torch
from torchvision import datasets, transforms, models, utils
from syft.frameworks.torch.federated import utils

W0830 17:47:26.513465 4634990016 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/jluissamper/.virtualenvs/pytorch/lib/python3.6/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0830 17:47:26.528890 4634990016 deprecation_wrapper.py:119] From /Users/jluissamper/.virtualenvs/pytorch/lib/python3.6/site-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [2]:
# import model to share and other client nn-related functionalities such as: next batch, train, get params...
import run_websocket_client as rwc

In [3]:
# Load the client's default arguments. An empty list is passed explicitly,
# presumably so the parser does not read the notebook kernel's own argv
# (confirm in run_websocket_client.define_and_get_arguments).
args = rwc.define_and_get_arguments(args=[])

# CUDA is used only when it is both requested and actually available.
use_cuda = args.cuda and torch.cuda.is_available()

# Seed PyTorch so runs are repeatable.
torch.manual_seed(args.seed)

device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


In [4]:
# Create the PySyft TorchHook for this torch module; the websocket client
# workers created later receive it through their "hook" argument.
hook = sy.TorchHook(torch)

In [5]:
# Instantiate one websocket client per remote worker.
# This step will fail if the websocket server workers are not running.


def _websocket_kwargs(host):
    """Return the keyword arguments for a WebsocketClientWorker on `host`.

    Only the host differs between workers; the hook and verbosity are shared.
    """
    return {"host": host, "hook": hook, "verbose": args.verbose}


kwargs_websocket_alice = _websocket_kwargs("localhost")
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket_alice)
print("alice set")

kwargs_websocket_bob = _websocket_kwargs("192.168.1.36")
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket_bob)
print("bob set")

# NOTE(review): this id is capitalised ("Charlie") unlike the other two;
# presumably it has to match the id the remote server was started with.
kwargs_websocket_charlie = _websocket_kwargs("192.168.1.35")
charlie = WebsocketClientWorker(id="Charlie", port=8779, **kwargs_websocket_charlie)
print("charlie set")

workers = [alice, bob, charlie]
print(workers)

alice set
bob set
charlie set
[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 101>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:Charlie #objects local:0 #objects remote: 0>]


# Prepare and distribute the training data

In [7]:
# Data-loading settings for the eye-disease image set.
num_workers = 4        # number of subprocesses to use for data loading
batch_size = 1         # how many samples per batch to load
img_size = (512, 512)  # size the images are rescaled to
valid_size = 0.2       # share of the training set to use as validation

data_dir = '~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/'

# Preprocessing pipeline: convert to a PIL image, resize, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
_channel_half = (0.5, 0.5, 0.5)
_preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(_channel_half, _channel_half),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
# Build the eye-disease dataset from its label CSV and image folder.
# NOTE(review): `simpleImageLoader` is neither defined nor imported in the
# cells visible here -- confirm where it comes from before running this cell.
_eye_labels_csv = '~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/labels/trainLabels15.csv'
_eye_images_dir = '~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/train 15'

eye_dataset = simpleImageLoader(
    csv_file=_eye_labels_csv,
    root_dir=_eye_images_dir,
    transform=transform,
)

In [None]:
# Wrap the eye dataset in a sequential (unshuffled) loader, using the
# notebook-level batch_size and num_workers settings.
_eye_loader_options = {
    "batch_size": batch_size,
    "shuffle": False,
    "num_workers": num_workers,
}
eye_dataloader = torch.utils.data.DataLoader(eye_dataset, **_eye_loader_options)

In [6]:
# Run this cell only if the next cell gives a pipeline error.
# It downloads the MNIST training files ahead of time; the loader itself is
# not kept, it is only the cell's displayed value.
_mnist_predownload = datasets.MNIST("../data/MNIST", train=True, download=True)
torch.utils.data.DataLoader(_mnist_predownload)



<torch.utils.data.dataloader.DataLoader at 0x107151ba8>

In [7]:
# Download MNIST and build two loaders: a federated one for training and a
# plain torch one for testing.

_mnist_root = "../data/MNIST"


def _mnist_transform():
    """Return the preprocessing applied to every MNIST image.

    Converts the image to a tensor, then normalises it with mean 0.1307 and
    std 0.3081.
    """
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )


# Training split, distributed over the websocket workers via .federate().
_mnist_train = datasets.MNIST(
    _mnist_root, train=True, download=True, transform=_mnist_transform()
)
_mnist_train_federated = _mnist_train.federate(tuple(workers))
federated_train_loader = sy.FederatedDataLoader(
    _mnist_train_federated,
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True,
)

# Test split: no download flag here, the files are expected to be on disk already.
_mnist_test = datasets.MNIST(_mnist_root, train=False, transform=_mnist_transform())
test_loader = torch.utils.data.DataLoader(
    _mnist_test,
    batch_size=args.test_batch_size,
    shuffle=True,
)



In [8]:
# Build the network defined in run_websocket_client.py (two conv layers
# followed by two fully connected layers) and move it to the chosen device.
_net = rwc.Net()
model = _net.to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [9]:
import logging
import sys
# Route all log records (DEBUG and above) from the root logger to stderr.

# Build the handler first: timestamp, level, source file and line, message.
formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
)
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)

# Replace (rather than append to) the root logger's handlers so that
# re-running this cell does not duplicate every log line.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.handlers = [handler]

In [10]:
# Start the training: each epoch runs one federated training pass and then
# evaluates the returned model on the local test loader.
# The run captured in this notebook's output stopped during the first epoch
# with WebSocketConnectionClosedException.
_total_epochs = args.epochs
for epoch in range(1, _total_epochs + 1):
    _banner = "Starting epoch {}/{}".format(epoch, _total_epochs)
    print(_banner)
    model = rwc.train(
        model,
        device,
        federated_train_loader,
        args.lr,
        args.federate_after_n_batches,
    )
    rwc.test(model, device, test_loader)

Starting epoch 1/2


2019-08-30 17:51:06,338 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [0, 50]
2019-08-30 17:56:23,219 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [50, 100]
2019-08-30 18:01:47,530 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [100, 150]
2019-08-30 18:09:20,597 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [150, 200]
2019-08-30 18:18:46,845 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [200, 250]
2019-08-30 18:28:02,075 DEBUG run_websocket_client.py(l:125) - Starting training round, batches [250, 300]


WebSocketConnectionClosedException: Connection is already closed.