In [3]:
# %load_ext autoreload

# %autoreload 2

# load libraries

import sys
import syft as sy
from syft.workers.virtual import VirtualWorker
from syft.workers import WebsocketClientWorker
from syft import FederatedDataset, FederatedDataLoader, BaseDataset
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms, models, utils
from syft.frameworks.torch.federated import utils

In [4]:
# import model to share and other client nn-related functionalities such as: next batch, train, get params...
import run_websocket_client as rwc

In [6]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


In [7]:
# websocket clients and workers instantiation. This step will fall if the websocket server workers are not running
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)

W0818 18:12:42.679344 4705670592 hook.py:98] Torch was already hooked... skipping hooking process


[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 0>]


# Prepare and distribute the training data

In [None]:
# number of subprocesses to use for data loading
num_workers = 4
# how many samples per batch to load
batch_size = 1
# Images size to rescale
img_size = (512,512)
# percentage of training set to use as validation
valid_size = 0.2

data_dir = '~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/'
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

In [None]:
eye_dataset = simpleImageLoader(csv_file='~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/labels/trainLabels15.csv',
                                    root_dir='~/Documents/SecureAndPrivateChallenge/sg-intro-ai-challenge/CNN - Eye Diseases/Data 15/train 15',
                                    transform = transform)  

In [None]:
eye_dataloader = torch.utils.data.DataLoader(eye_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)

In [10]:
#run this box only if the the next box gives pipeline error
torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=True,download=True))



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


<torch.utils.data.dataloader.DataLoader at 0x143bc8780>

In [11]:
# Download the MNIST dataset and use federated dataloader

federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/MNIST",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)



In [12]:
# instantiate the model, imported from run_websocket_client.py
# it is a 2 layers conv net
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [13]:
import logging
import sys
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
handler.setFormatter(formatter)
logger.handlers = [handler]

In [14]:
# start the training

for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    rwc.test(model, device, test_loader)

Starting epoch 1/2


2019-08-18 18:16:32,669 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [0, 50]
2019-08-18 18:17:05,735 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [50, 100]
2019-08-18 18:17:38,089 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [100, 150]
2019-08-18 18:18:10,274 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [150, 200]
2019-08-18 18:18:42,446 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [200, 250]
2019-08-18 18:19:15,038 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [250, 300]
2019-08-18 18:19:38,354 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [300, 350]
2019-08-18 18:19:49,831 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [350, 400]
2019-08-18 18:19:49,842 DEBUG run_websocket_client.py(l:136) - At least one worker ran out of data, stopping.
2019-08-18 18:19:53,920 DEBUG run_webs

Starting epoch 2/2


2019-08-18 18:20:05,853 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [0, 50]
2019-08-18 18:20:39,716 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [50, 100]
2019-08-18 18:21:12,556 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [100, 150]
2019-08-18 18:21:44,935 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [150, 200]
2019-08-18 18:22:17,318 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [200, 250]
2019-08-18 18:22:49,987 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [250, 300]
2019-08-18 18:23:13,548 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [300, 350]
2019-08-18 18:23:25,161 DEBUG run_websocket_client.py(l:123) - Starting training round, batches [350, 400]
2019-08-18 18:23:25,172 DEBUG run_websocket_client.py(l:136) - At least one worker ran out of data, stopping.
2019-08-18 18:23:29,217 DEBUG run_webs