In [1]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import copy
import numpy as np
import syft as sy
from syft.frameworks.torch.fl import utils
from syft.workers.websocket_client import WebsocketClientWorker

Initiating the training parameters

In [2]:
class Parser:
    def __init__(self):
        self.epochs = 500
        self.lr = 0.0001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
    
args = Parser()
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1b01c9c41d0>

In [3]:
#import pandas as pd
#df = pd.read_csv('diabetes.csv')
#df.to_pickle('diabetes.pkl')

Dataset Preprocessing

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
df = pd.read_csv('diabetes.csv')

print(df.shape)
print(df.head(10))
X = df.iloc[:, :-1]
y = df.iloc[:, -1]
print(X.shape)
print(y.shape)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
x = torch.from_numpy(X_train.to_numpy()).float()
y = torch.from_numpy(y_train.to_numpy()).float()
x_test = torch.from_numpy(X_test.to_numpy()).float()
y_test = torch.from_numpy(y_test.to_numpy()).float()

(768, 9)
   Pregnancies  Glucose  BloodPressure  SkinThickness  Insulin   BMI  \
0            6      148             72             35        0  33.6   
1            1       85             66             29        0  26.6   
2            8      183             64              0        0  23.3   
3            1       89             66             23       94  28.1   
4            0      137             40             35      168  43.1   
5            5      116             74              0        0  25.6   
6            3       78             50             32       88  31.0   
7           10      115              0              0        0  35.3   
8            2      197             70             45      543  30.5   
9            8      125             96              0        0   0.0   

   DiabetesPedigreeFunction  Age  Outcome  
0                     0.627   50        1  
1                     0.351   31        0  
2                     0.672   32        1  
3                     

In [5]:
mean = x.mean(0, keepdim=True)
dev = x.std(0, keepdim=True)
mean[:, 3] = 0.
dev[:, 3] = 1.
x = (x - mean) / dev
x_test = (x_test - mean) / dev
train = TensorDataset(x, y)

test = TensorDataset(x_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True)

Creating Neural Network with PyTorch

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(8, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Connecting the data with the remote mobile devices

In [7]:
hook = sy.TorchHook(torch)
bob_worker = sy.VirtualWorker(hook, id="bob")
alice_worker = sy.VirtualWorker(hook, id="alice")
# kwargs_websocket = {"host": "localhost", "hook": hook}
# alice = WebsocketClientWorker(id='alice', port=8779, **kwargs_websocket)
# bob = WebsocketClientWorker(id='bob', port=8778, **kwargs_websocket)
compute_nodes = [bob_worker, alice_worker]

In [8]:
remote_dataset = (list(), list())
train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))

In [9]:
bobs_model = Net()
alices_model = Net()
bobs_optimizer = optim.Adam(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.Adam(alices_model.parameters(), lr=args.lr)

In [10]:
models = [bobs_model, alices_model]
optimizers = [bobs_optimizer, alices_optimizer]

In [11]:
model = Net()
model

Net(
  (fc1): Linear(in_features=8, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=10, bias=True)
  (fc3): Linear(in_features=10, out_features=1, bias=True)
)

Training the Neural Network

In [12]:
def update(data, target, model, optimizer):
    model.send(data.location)
    optimizer.zero_grad()
    prediction = model(data)
    loss = F.mse_loss(prediction.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

def train():
    for data_index in range(len(remote_dataset[0])-1):
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
        for model in models:
            model.get()
        return utils.federated_avg({
            "bob": models[0],
            "alice": models[1]
        })

In [13]:
def test(federated_model):
    federated_model.eval()
    test_loss = 0
    for data, target in test_loader:
        output = federated_model(data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='mean').item()
        predection = output.data.max(1, keepdim=True)[1]
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}'.format(test_loss))

In [14]:
for epoch in range(args.epochs):
    start_time = time.time()
    print(f"Epoch Number {epoch + 1}")
    federated_model = train()
    model = federated_model
    test(federated_model)
    total_time = time.time() - start_time
    print('Communication time over the network', round(total_time, 2), 's\n')

Epoch Number 1
Test set: Average loss: 0.0444
Communication time over the network 0.11 s

Epoch Number 2
Test set: Average loss: 0.3187
Communication time over the network 0.08 s

Epoch Number 3
Test set: Average loss: 0.1295
Communication time over the network 0.1 s

Epoch Number 4
Test set: Average loss: 0.0330
Communication time over the network 0.07 s

Epoch Number 5
Test set: Average loss: 0.0552
Communication time over the network 0.07 s

Epoch Number 6
Test set: Average loss: 0.0440
Communication time over the network 0.07 s

Epoch Number 7
Test set: Average loss: 0.1890
Communication time over the network 0.08 s

Epoch Number 8
Test set: Average loss: 0.1047
Communication time over the network 0.1 s

Epoch Number 9
Test set: Average loss: 0.0533
Communication time over the network 0.1 s

Epoch Number 10
Test set: Average loss: 0.1869
Communication time over the network 0.1 s

Epoch Number 11
Test set: Average loss: 0.0398
Communication time over the network 0.1 s

Epoch Number 

Test set: Average loss: 0.0932
Communication time over the network 0.1 s

Epoch Number 93
Test set: Average loss: 0.0474
Communication time over the network 0.1 s

Epoch Number 94
Test set: Average loss: 0.1327
Communication time over the network 0.09 s

Epoch Number 95
Test set: Average loss: 0.0295
Communication time over the network 0.09 s

Epoch Number 96
Test set: Average loss: 0.1251
Communication time over the network 0.08 s

Epoch Number 97
Test set: Average loss: 0.1068
Communication time over the network 0.08 s

Epoch Number 98
Test set: Average loss: 0.0655
Communication time over the network 0.08 s

Epoch Number 99
Test set: Average loss: 0.0341
Communication time over the network 0.09 s

Epoch Number 100
Test set: Average loss: 0.0333
Communication time over the network 0.08 s

Epoch Number 101
Test set: Average loss: 0.1542
Communication time over the network 0.09 s

Epoch Number 102
Test set: Average loss: 0.0347
Communication time over the network 0.07 s

Epoch Number 1

Test set: Average loss: 0.0317
Communication time over the network 0.09 s

Epoch Number 185
Test set: Average loss: 0.0315
Communication time over the network 0.1 s

Epoch Number 186
Test set: Average loss: 0.0484
Communication time over the network 0.1 s

Epoch Number 187
Test set: Average loss: 0.0699
Communication time over the network 0.1 s

Epoch Number 188
Test set: Average loss: 0.0559
Communication time over the network 0.1 s

Epoch Number 189
Test set: Average loss: 0.1046
Communication time over the network 0.11 s

Epoch Number 190
Test set: Average loss: 0.1633
Communication time over the network 0.09 s

Epoch Number 191
Test set: Average loss: 0.1324
Communication time over the network 0.11 s

Epoch Number 192
Test set: Average loss: 0.0707
Communication time over the network 0.12 s

Epoch Number 193
Test set: Average loss: 0.0305
Communication time over the network 0.13 s

Epoch Number 194
Test set: Average loss: 0.0498
Communication time over the network 0.12 s

Epoch Num

Test set: Average loss: 0.3941
Communication time over the network 0.1 s

Epoch Number 276
Test set: Average loss: 0.0615
Communication time over the network 0.09 s

Epoch Number 277
Test set: Average loss: 0.1170
Communication time over the network 0.09 s

Epoch Number 278
Test set: Average loss: 0.0302
Communication time over the network 0.1 s

Epoch Number 279
Test set: Average loss: 0.0491
Communication time over the network 0.09 s

Epoch Number 280
Test set: Average loss: 0.0657
Communication time over the network 0.1 s

Epoch Number 281
Test set: Average loss: 0.0738
Communication time over the network 0.09 s

Epoch Number 282
Test set: Average loss: 0.0398
Communication time over the network 0.09 s

Epoch Number 283
Test set: Average loss: 0.1360
Communication time over the network 0.09 s

Epoch Number 284
Test set: Average loss: 0.0350
Communication time over the network 0.1 s

Epoch Number 285
Test set: Average loss: 0.0901
Communication time over the network 0.09 s

Epoch Num

Test set: Average loss: 0.1305
Communication time over the network 0.1 s

Epoch Number 366
Test set: Average loss: 0.1871
Communication time over the network 0.1 s

Epoch Number 367
Test set: Average loss: 0.0445
Communication time over the network 0.1 s

Epoch Number 368
Test set: Average loss: 0.0496
Communication time over the network 0.1 s

Epoch Number 369
Test set: Average loss: 0.1935
Communication time over the network 0.09 s

Epoch Number 370
Test set: Average loss: 0.0765
Communication time over the network 0.1 s

Epoch Number 371
Test set: Average loss: 0.0706
Communication time over the network 0.11 s

Epoch Number 372
Test set: Average loss: 0.0358
Communication time over the network 0.1 s

Epoch Number 373
Test set: Average loss: 0.1386
Communication time over the network 0.1 s

Epoch Number 374
Test set: Average loss: 0.0534
Communication time over the network 0.11 s

Epoch Number 375
Test set: Average loss: 0.0569
Communication time over the network 0.11 s

Epoch Number

Test set: Average loss: 0.2209
Communication time over the network 0.11 s

Epoch Number 456
Test set: Average loss: 0.0602
Communication time over the network 0.11 s

Epoch Number 457
Test set: Average loss: 0.0413
Communication time over the network 0.1 s

Epoch Number 458
Test set: Average loss: 0.0313
Communication time over the network 0.1 s

Epoch Number 459
Test set: Average loss: 0.1614
Communication time over the network 0.1 s

Epoch Number 460
Test set: Average loss: 0.1102
Communication time over the network 0.11 s

Epoch Number 461
Test set: Average loss: 0.1958
Communication time over the network 0.09 s

Epoch Number 462
Test set: Average loss: 0.0433
Communication time over the network 0.09 s

Epoch Number 463
Test set: Average loss: 0.0432
Communication time over the network 0.1 s

Epoch Number 464
Test set: Average loss: 0.4984
Communication time over the network 0.1 s

Epoch Number 465
Test set: Average loss: 0.0556
Communication time over the network 0.1 s

Epoch Numbe