In [1]:
import os
import struct
import socket
import pickle
import time
import sys
import copy
import numpy as np

import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from threading import Thread
from threading import Lock

In [2]:
# Number of client connections the server accepts, and a RAM value per client
# (units not stated here; presumably GB — confirm).
users = 2
user_ram = [8,4]

In [3]:
# Federated learning: number of FedAvg rounds and local epochs per round (sent to each federated client).
# Split learning: number of training epochs (sent to the split client).
rounds = 10
local_epoch = 1
epochs = 10

## Device Selection

In [4]:
# Device is forced to CPU; to use CUDA when available, restore the commented line and drop the override.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(device)

cpu


In [5]:
# def hyperparameter_federated(userlist):
#     clientsoclist = [0]*userlist
    
#     start_time = 0
#     weight_count = 0
    
#     global_weights = copy.deepcopy(ecg_net.state_dict())
    
#     datasetsize = [0]*userlist
#     weights_list = [0]*userlist

#     return clientsoclist,start_time,weight_count,global_weights,datasetsize,weights_list
# ## if user > 3 then put one user in split user ram = [10,8,4]

# Lock guarding the shared federated-learning state (weight_count, datasetsize,
# global_weights). Note: the notebook rebinds `lock` again in a later cell.
lock = Lock()

In [6]:
# Traffic accounting: sizes in bytes of the pickled payloads (the 4-byte length
# header is not counted), as returned by send_msg / recv_msg.
# total_*: every message; client_*: per client index; train_*: only messages
# exchanged inside the training loops (not the initial handshake).
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = []
train_receivesize_list = []

In [7]:
# def hyperparameter_split():
#     criterion = nn.CrossEntropyLoss()
#     lr = 0.001
#     optimizer_server = Adam(ecg_net_splitserver.parameters(), lr=lr)
    
#     clientsoclist = []
#     train_total_batch = []
#     val_acc = []
#     client_weights = copy.deepcopy(ecg_net_splitclient.state_dict())
    
#     train_acc = []
#     val_acc = []
#     return criterion,lr,optimizer_server,clientsoclist,train_total_batch, val_acc,client_weights,train_acc

## Data Loading

In [8]:
# Base directory for data; the ECG HDF5 files are read from root_path/ecg_data/.
root_path = 'models/'

In [9]:
# Used only in split learning
class ECG(Dataset):
    """ECG samples loaded fully into memory from an HDF5 file.

    With ``train=True`` reads ``x_train``/``y_train`` from
    ``root_path/ecg_data/train_ecg.hdf5``; otherwise ``x_test``/``y_test`` from
    ``root_path/ecg_data/test_ecg.hdf5``. Each item is
    ``(float32 tensor of x[idx], tensor of y[idx])``.
    """

    def __init__(self, train=True):
        if train:
            file_name, x_key, y_key = 'train_ecg.hdf5', 'x_train', 'y_train'
        else:
            file_name, x_key, y_key = 'test_ecg.hdf5', 'x_test', 'y_test'
        hdf5_path = os.path.join(root_path, 'ecg_data', file_name)
        with h5py.File(hdf5_path, 'r') as hdf:
            # Slicing with [:] copies the data out before the file is closed.
            self.x = hdf[x_key][:]
            self.y = hdf[y_key][:]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        signal = torch.tensor(self.x[idx], dtype=torch.float)
        target = torch.tensor(self.y[idx])
        return signal, target

In [10]:
# Mini-batch size of the server-side train/test DataLoaders.
batch_size = 32

In [11]:
# Server-side copies of the ECG train/test sets; train_split uses these loaders to
# evaluate the joined split model after every epoch.
train_dataset = ECG(train=True)
test_dataset = ECG(train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## Construct 1D-CNN ECG classification model

In [12]:
class EcgConv1d_Federated(nn.Module):
    """1D-CNN ECG classifier whose weights are exchanged with federated clients.

    The attribute names below become the state_dict keys sent over the wire,
    so they must not be renamed. Output is a softmax over 5 classes (dim=1).

    Per the shape comments, the layers assume a (batch, 1, 130) input.
    NOTE(review): the split-learning models in this notebook annotate a
    128-sample input; confirm which signal length the data actually has.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (sequence length x channels after each stage).
        self.conv1 = nn.Conv1d(1, 16, 7)   # 124 x 16
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)       # 62 x 16
        self.conv2 = nn.Conv1d(16, 16, 5)  # 58 x 16
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv1d(16, 16, 5)  # 54 x 16
        self.relu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv1d(16, 16, 5)  # 50 x 16
        self.relu4 = nn.LeakyReLU()
        self.pool4 = nn.MaxPool1d(2)       # 25 x 16
        # Classifier head.
        self.linear5 = nn.Linear(25 * 16, 128)
        self.relu5 = nn.LeakyReLU()
        self.linear6 = nn.Linear(128, 5)
        self.softmax6 = nn.Softmax(dim=1)

    def forward(self, x):
        features = x
        for layer in (self.conv1, self.relu1, self.pool1,
                      self.conv2, self.relu2,
                      self.conv3, self.relu3,
                      self.conv4, self.relu4, self.pool4):
            features = layer(features)
        flat = features.view(-1, 25 * 16)
        hidden = self.relu5(self.linear5(flat))
        return self.softmax6(self.linear6(hidden))

In [13]:
# Server-side federated model: its state_dict seeds global_weights, and the
# final averaged weights are loaded back into it at the end of the notebook.
ecg_net_Federated = EcgConv1d_Federated()
ecg_net_Federated.to('cpu')

EcgConv1d_Federated(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu2): LeakyReLU(negative_slope=0.01)
  (conv3): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu3): LeakyReLU(negative_slope=0.01)
  (conv4): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu4): LeakyReLU(negative_slope=0.01)
  (pool4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear5): Linear(in_features=400, out_features=128, bias=True)
  (relu5): LeakyReLU(negative_slope=0.01)
  (linear6): Linear(in_features=128, out_features=5, bias=True)
  (softmax6): Softmax(dim=1)
)

In [14]:
class EcgConv1d_SplitServer(nn.Module):
    """Server half of the split-learning ECG model.

    Consumes the split client's activations. Per the shape comments these are
    16 channels x 64 samples, and conv1/conv2 live on the client. The model
    returns a softmax over 5 classes (dim=1).

    NOTE(review): train_split feeds this output to nn.CrossEntropyLoss, which
    applies log-softmax itself, so the probabilities are normalised twice and
    the gradients are flattened. It is left unchanged because removing it
    alters the gradients sent back to the split client; evaluate first.
    """

    def __init__(self):
        super().__init__()
        self.conv3 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu4 = nn.LeakyReLU()
        self.conv5 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu5 = nn.LeakyReLU()
        self.pool5 = nn.MaxPool1d(2)                  # 32 x 16
        self.linear6 = nn.Linear(32 * 16, 128)
        self.relu6 = nn.LeakyReLU()
        self.linear7 = nn.Linear(128, 5)
        self.softmax7 = nn.Softmax(dim=1)

    def forward(self, x):
        h = x
        for layer in (self.conv3, self.relu3,
                      self.conv4, self.relu4,
                      self.conv5, self.relu5, self.pool5):
            h = layer(h)
        h = h.view(-1, 32 * 16)
        h = self.relu6(self.linear6(h))
        return self.softmax7(self.linear7(h))


In [15]:
# Server half of the split model, trained by train_split on the client's activations.
ecg_net_splitserver = EcgConv1d_SplitServer().to(device)
print(ecg_net_splitserver)

EcgConv1d_SplitServer(
  (conv3): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu3): LeakyReLU(negative_slope=0.01)
  (conv4): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu4): LeakyReLU(negative_slope=0.01)
  (conv5): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu5): LeakyReLU(negative_slope=0.01)
  (pool5): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear6): Linear(in_features=512, out_features=128, bias=True)
  (relu6): LeakyReLU(negative_slope=0.01)
  (linear7): Linear(in_features=128, out_features=5, bias=True)
  (softmax7): Softmax(dim=1)
)


In [16]:
class EcgConv1d_SplitClient(nn.Module):
    """Client half of the split-learning ECG model (conv1 .. relu2).

    Per the shape comments, a 128-sample single-channel signal becomes
    16 channels x 64 samples. Attribute names are the state_dict keys exchanged
    with the split client, so they must stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, 7, padding=3)   # 128 x 16
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)                  # 64 x 16
        self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = nn.LeakyReLU()

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.pool1(out)
        return self.relu2(self.conv2(out))


In [17]:
# Server-side copy of the client half: train_split loads the weights the split
# client sends back into it and uses it to evaluate the joined model.
ecg_net_splitclient = EcgConv1d_SplitClient().to('cpu')
print(ecg_net_splitclient)

EcgConv1d_SplitClient(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,), padding=(3,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu2): LeakyReLU(negative_slope=0.01)
)


In [18]:
# criterion = nn.CrossEntropyLoss()
# lr = 0.001
# optimizer_splitserver = Adam(ecg_net_splitserver.parameters(), lr=lr)
# optimizer_Federated = Adam(ecg_net_Federated.parameters(), lr=lr)


## Socket Initialization

In [19]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it on ``sock`` framed by a 4-byte big-endian length.

    Returns:
        Size in bytes of the pickled payload (the 4-byte header is not counted).
    """
    payload = pickle.dumps(msg)
    payload_len = len(payload)
    header = struct.pack('>I', payload_len)
    sock.sendall(header + payload)
    return payload_len

def recv_msg(sock):
    """Receive one length-prefixed pickled message produced by ``send_msg``.

    Returns:
        ``(msg, msglen)``, where ``msglen`` is the payload size in bytes (the
        header is excluded). Returns ``None`` if the peer closes the connection
        before a complete message arrives.

    NOTE(review): ``pickle.loads`` on bytes received from the network can
    execute arbitrary code; only use this with trusted peers.
    """
    raw_msglen = recvall(sock, 4)
    if raw_msglen is None:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    payload = recvall(sock, msglen)
    if payload is None:
        # Peer closed mid-message: report it like EOF on the header instead of
        # crashing inside pickle.loads(None).
        return None
    return pickle.loads(payload), msglen


def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``; return None if EOF comes first."""
    chunks = []
    remaining = n
    while remaining > 0:
        packet = sock.recv(remaining)
        if not packet:
            return None
        chunks.append(packet)
        remaining -= len(packet)
    # Join once instead of repeatedly concatenating bytes objects.
    return b''.join(chunks)

In [20]:
def mergeweights(dict1,dict2):
    """Copy every entry of ``dict2`` into ``dict1`` (overwriting) and return ``dict1``.

    ``dict1`` is modified in place; the returned object is that same dict.
    """
    dict1.update(dict2)
    return dict1

In [21]:
# Split-learning server state: loss, optimizer over the server half of the
# model, and bookkeeping lists filled by train_split().
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer_server = Adam(ecg_net_splitserver.parameters(), lr=lr)

train_total_batch = []  # batch counts reported by split clients (one per train_split call)
train_acc = []          # per-epoch train accuracy (%)
val_acc = []            # per-epoch test accuracy (%)

# Initial client-side weights sent to the split client in the first epoch.
client_weights_split = copy.deepcopy(ecg_net_splitclient.state_dict())

# Split Learning Part

In [22]:
# Weights that train_split can hand back to its caller. Nothing in this notebook
# populates it, so train_split currently returns -1.
total_weights = {}


def _log_split_traffic(userid, datasize, sent, training):
    """Append one payload size (bytes) to the module-level traffic lists.

    ``sent`` selects the send or receive lists. ``training`` also records the
    size in the train_* lists; handshake messages are not counted there.
    """
    if sent:
        total_sendsize_list.append(datasize)
        client_sendsize_list[userid].append(datasize)
        if training:
            train_sendsize_list.append(datasize)
    else:
        total_receivesize_list.append(datasize)
        client_receivesize_list[userid].append(datasize)
        if training:
            train_receivesize_list.append(datasize)


def _evaluate_split_model(loader):
    """Run client half + server half over ``loader`` without gradients.

    Returns:
        ``(accuracy_percent, mean_batch_loss)``; both are 0.0 for an empty loader.
    """
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for batch_x, batch_label in loader:
            batch_x = batch_x.to(device)
            batch_label = batch_label.to(device).long()
            output = ecg_net_splitserver(ecg_net_splitclient(batch_x))
            loss_sum += criterion(output, batch_label).item()
            correct += (output.argmax(dim=1) == batch_label).sum().item()
            seen += batch_label.size(0)
    if seen == 0:
        return 0.0, 0.0
    return correct / seen * 100, loss_sum / len(loader)


def train_split(userid, num_users, conn,client_weights):
    """Serve one split-learning client over ``conn`` for ``epochs`` epochs.

    Protocol, as implemented below:
      1. Send ``epochs``; receive the client's number of training batches.
      2. Each epoch:
         a. Send the current client-side weights.
         b. For every batch, receive ``{'client_output', 'label'}``, run the
            server half, backpropagate, and send back the gradient w.r.t.
            ``client_output``.
         c. Receive the updated client-side weights and load them into
            ``ecg_net_splitclient``.
         d. Report train/test accuracy of the joined model on the server's
            loaders.
      3. Save the server half to ``./ecg_sp_server.pth``.

    Args:
        userid: index into the per-client traffic lists.
        num_users: unused; kept for compatibility with the thread launcher.
        conn: connected socket of the split client.
        client_weights: client-side state_dict sent at the start of epoch 1.

    Returns:
        ``total_weights`` if it is non-empty, otherwise -1.
    """
    datasize = send_msg(conn, epochs)  # send epoch count
    _log_split_traffic(userid, datasize, sent=True, training=False)

    total_batch, datasize = recv_msg(conn)  # number of training batches on the client
    _log_split_traffic(userid, datasize, sent=False, training=False)
    train_total_batch.append(total_batch)  # kept as a notebook-level record

    start_time = time.time()  # store start time
    print("Timer start!")

    for e in range(epochs):
        datasize = send_msg(conn, client_weights)
        _log_split_traffic(userid, datasize, sent=True, training=True)

        # Use this call's batch count; train_total_batch[0] would be stale if
        # train_split has already run earlier in the same kernel.
        for _ in tqdm(range(total_batch), ncols=100, desc='Epoch {} Client{} '.format(e+1, userid)):
            optimizer_server.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(conn)  # receive client message from socket
            _log_split_traffic(userid, datasize, sent=False, training=True)

            client_output_cpu = msg['client_output']  # client output tensor
            client_output = client_output_cpu.to(device)
            label = msg['label'].clone().detach().long().to(device)

            output = ecg_net_splitserver(client_output)  # forward propagation
            loss = criterion(output, label)  # cross-entropy loss
            loss.backward()  # backward propagation

            # assumes the client sends client_output with requires_grad=True, so
            # backward() fills its .grad; TODO confirm against the client code.
            grad = client_output_cpu.grad.clone().detach()
            datasize = send_msg(conn, grad)
            _log_split_traffic(userid, datasize, sent=True, training=True)

            optimizer_server.step()

        client_weights, datasize = recv_msg(conn)
        _log_split_traffic(userid, datasize, sent=False, training=True)

        ecg_net_splitclient.load_state_dict(client_weights)
        ecg_net_splitclient.to(device)
        ecg_net_splitclient.eval()

        train_accuracy, r_train_loss = _evaluate_split_model(train_loader)
        print("train_acc: {:.2f}%, train_loss: {:.4f}".format(train_accuracy, r_train_loss))
        train_acc.append(train_accuracy)

        test_accuracy, test_loss = _evaluate_split_model(test_loader)
        print("test_acc: {:.2f}%, test_loss: {:.4f}".format(test_accuracy, test_loss))
        val_acc.append(test_accuracy)

    print('train is done')

    end_time = time.time()  # store end time
    print("TrainingTime: {} sec".format(end_time - start_time))

    # Save the trained server half.
    PATH = './ecg_sp_server.pth'
    torch.save(ecg_net_splitserver.state_dict(), PATH)

    if total_weights:
        return total_weights
    return -1
        

    


# Federated Learning Part

In [23]:
import copy

def average_weights(w, datasize):
    """Return the dataset-size-weighted average (FedAvg) of several state_dicts.

    For every key ``k`` of ``w[0]``:
        sum_i(datasize[i] * w[i][k]) / sum(datasize)

    The input dicts are left untouched; the previous version scaled them in
    place with ``*=``, which also fails on integer tensors. The result is a
    shallow copy of ``w[0]`` (same mapping type and attributes, e.g. an
    OrderedDict's ``_metadata``) with every value replaced.

    Args:
        w: list of state_dicts (name -> tensor) sharing the same keys.
        datasize: per-client sample counts, aligned with ``w``.

    Returns:
        A new mapping holding the averaged tensors.
    """
    total = float(sum(datasize))
    w_avg = copy.copy(w[0])
    for key in w[0].keys():
        acc = w[0][key] * float(datasize[0])
        for weights, size in zip(w[1:], datasize[1:]):
            acc = acc + weights[key] * float(size)
        w_avg[key] = torch.div(acc, total)
    return w_avg

## Client Threads

In [24]:
#divide users into split and federated
# clientsoclist_federated = [0] #*users
# Sockets of connected clients, filled by run_thread_federated. The split list has
# a single slot, which run_thread_federated overwrites for the split connection.
clientsoclist_federated = []
clientsoclist_split = [0]  #*users

start_time = 0
# Number of federated threads that have reported in the current round (guarded by `lock`).
weight_count = 0

# Initial global model broadcast to the federated clients.
global_weights = copy.deepcopy(ecg_net_Federated.state_dict())

# Per-federated-client dataset size and latest weights, indexed by userid.
# NOTE(review): length 1, so only userid 0 is valid; more federated clients would
# need longer lists and distinct userids.
datasetsize = [0]  #*users
weights_list = [0]  #*users

## if user > 3 then put one user in split user ram = [10,8,4]

lock = Lock()

In [25]:
def run_thread_federated(recieve_federated,train_split,num_user):
    """Accept ``num_user`` connections on the global listening socket ``s`` and
    serve each on its own thread, then wait for every thread to finish.

    Routing:
      - The connection with loop index 1 goes to ``train_split``, with
        ``client_weights_split`` as its initial client weights.
      - Every other connection goes to ``recieve_federated``.
      - Both kinds of thread receive userid 0 and num_users 1.

    NOTE(review): routing is hard-coded by connection order (the RAM-based
    routing via user_ram is not implemented). Every thread also gets userid 0,
    so this looks like it only supports one split client plus one federated
    client; confirm before raising num_user.
    """
    global clientsoclist_federated
    global clientsoclist_split
    global start_time

    workers = []
    for idx in range(num_user):
        conn, addr = s.accept()
        print('Connected with', addr)

        if idx == 1:
            # Split-learning client.
            clientsoclist_split[0] = conn
            worker = Thread(target=train_split, args=(0, 1, conn, client_weights_split))
            workers.append(worker)
            worker.start()
            print("Training Split")
        else:
            # Federated-learning client.
            clientsoclist_federated.append(conn)
            worker = Thread(target=recieve_federated, args=(0, 1, conn))
            workers.append(worker)
            worker.start()
            print("Training Federated")

    print("Timer start!")
    start_time = time.time()    # store start time
    for worker in workers:
        worker.join()
    end_time = time.time()  # store end time
    print("TrainingTime: {} sec".format(end_time - start_time))

In [26]:
def receive_federated(userid, num_users, conn):
    """Handshake with one federated client, register it, then run its rounds.

    Steps:
      1. Send ``{'rounds', 'client_id', 'local_epoch'}`` to the client.
      2. Receive the client's training-set size.
      3. Under ``lock``, store that size in ``datasetsize[userid]`` and
         increment ``weight_count``.
      4. Hand off to ``train_federated``.
    """
    global weight_count

    config = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch,
    }
    sent_bytes = send_msg(conn, config)
    total_sendsize_list.append(sent_bytes)
    client_sendsize_list[userid].append(sent_bytes)

    train_dataset_size, recv_bytes = recv_msg(conn)
    total_receivesize_list.append(recv_bytes)
    client_receivesize_list[userid].append(recv_bytes)

    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1

    train_federated(userid, train_dataset_size, num_users, conn)

In [27]:


# global weights_list_federated
# global global_weights_federated
# global weight_count_federated
# global ecg_net_federated
# global val_acc_federated

In [28]:


def train_federated(userid, train_dataset_size, num_users, client_conn):
    """Run ``rounds`` FedAvg rounds for one federated client connection.

    Each round:
      1. Under ``lock``: if ``weight_count == num_users`` (every federated
         thread has reported), send ``global_weights`` to all federated
         clients and reset the counter.
      2. Receive this client's updated weights into ``weights_list[userid]``.
      3. Under ``lock``: increment ``weight_count``. The thread that brings it
         to ``num_users`` averages ``weights_list`` (weighted by
         ``datasetsize``) into ``global_weights``.

    ``train_dataset_size`` is not used here; the size registered by the
    handshake in ``datasetsize`` is what weights the average.
    """
    global global_weights
    global weight_count

    for round_idx in range(rounds):
        # Broadcast phase.
        with lock:
            if weight_count == num_users:
                for peer_idx, peer_conn in enumerate(clientsoclist_federated):
                    nbytes = send_msg(peer_conn, global_weights)
                    total_sendsize_list.append(nbytes)
                    client_sendsize_list[peer_idx].append(nbytes)
                    train_sendsize_list.append(nbytes)
                    weight_count = 0

        # Collect this client's locally trained weights.
        client_weights, nbytes = recv_msg(client_conn)
        total_receivesize_list.append(nbytes)
        client_receivesize_list[userid].append(nbytes)
        train_receivesize_list.append(nbytes)

        weights_list[userid] = client_weights
        print("User" + str(userid) + "'s Round " + str(round_idx + 1) +  " is done")

        # Aggregation phase: the last thread to report averages the weights.
        with lock:
            weight_count += 1
            if weight_count == num_users:
                global_weights = average_weights(weights_list, datasetsize)
                
# with lock:
#     weight_count += 1
#     if weight_count == num_users:
#         #average
#         global_weights = average_weights(weights_list, datasetsize)



In [29]:
# Address the server listens on.
# NOTE(review): gethostbyname(gethostname()) may resolve to a loopback address on
# some machines; confirm clients can reach it.
host = socket.gethostbyname(socket.gethostname())
port = 10081
print(host)

172.70.103.230


In [30]:
# Listening socket shared by all client threads. SO_REUSEADDR lets this cell be
# re-run right after a previous run without "Address already in use".
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
s.listen(5)

In [None]:
# Accept `users` connections: connection index 1 trains with split learning, all
# others with federated learning. Blocks until every client thread finishes.
run_thread_federated(receive_federated, train_split, users)

In [None]:
# `s` is a listening socket (it was put in listen mode when created), so recv_msg
# cannot read from it directly: accept a connection first. recv_msg returns
# (payload, payload_size), so unpack the payload on its own.
# NOTE(review): assumes the client sends a (user_ids, user_ram) pair; confirm
# against the client code.
reg_conn, reg_addr = s.accept()
(user, user_ram), _ = recv_msg(reg_conn)

In [None]:
def runmodel(users, user_ram, ram_threshold=5):
    """Split clients into federated and split-learning groups by available RAM.

    Clients with ``ram <= ram_threshold`` go to split learning; the rest go to
    federated learning. The previous version read an unrelated global ``user``
    instead of the ``users`` argument, and failed for an int ``users``.

    Args:
        users: sequence of client identifiers, or an int ``n`` meaning ids 0..n-1.
        user_ram: RAM value per client, aligned with ``users``.
        ram_threshold: inclusive upper bound for routing a client to split learning.

    Returns:
        ``(federated_users, split_users)`` as lists of identifiers.

    Raises:
        ValueError: if ``users`` and ``user_ram`` differ in length.
    """
    ids = list(range(users)) if isinstance(users, int) else list(users)
    if len(ids) != len(user_ram):
        raise ValueError("users and user_ram must have the same length "
                         "({} != {})".format(len(ids), len(user_ram)))
    federated_users = []
    split_users = []
    for uid, ram in zip(ids, user_ram):
        if ram <= ram_threshold:
            split_users.append(uid)
        else:
            federated_users.append(uid)
    return federated_users, split_users

In [None]:
# Partition clients by RAM into federated and split-learning groups.
# NOTE(review): `users` is set to an int (2) earlier in the notebook; confirm
# runmodel accepts that form.
federated_users, split_users = runmodel(users,user_ram)

In [None]:
# NOTE(review): hyperparameter_federated and hyperparameter_split exist only as
# commented-out cells earlier in the notebook, so this cell raises NameError as
# written. There, hyperparameter_split took no arguments, and
# hyperparameter_federated sized its lists with [0]*userlist (it needs an int,
# not a list of users).
clientsoclist_federated,start_time_federated,weight_count_federated,global_weights_federated,datasetsize_federated,weights_list_federated = hyperparameter_federated(federated_users)
criterion_split,lr_split,optimizer_server_split,clientsoclist_split,train_total_batch_split, val_acc_split,client_weights_split,train_acc_split = hyperparameter_split(split_users)

In [None]:


# for i in range(split_users):
#     conn, addr = s.accept()
#     print('Connected with', addr)
#     clientsoclist.append(conn)    # append client socket on list

#     datasize = send_msg(conn, epochs)    #send epoch
#     total_sendsize_list.append(datasize)
#     client_sendsize_list[i].append(datasize)

#     total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
#     total_receivesize_list.append(datasize)
#     client_receivesize_list[i].append(datasize)

#     train_total_batch.append(total_batch)    # append on list

In [None]:
# run_thread_federated takes (receiver, split trainer, number of connections to
# accept); `receive` is not defined anywhere in this notebook.
run_thread_federated(receive_federated, train_split, len(federated_users) + len(split_users))

In [None]:
# Elapsed time since the global start_time was last set (run_thread_federated
# sets it); this includes time spent running the cells in between.
end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

In [None]:
# Load the latest averaged (FedAvg) global weights into the server's federated model.
ecg_net_Federated.load_state_dict(global_weights)


In [None]:
# train_split(criterion_split,lr_split,optimizer_server_split,clientsoclist_split,train_total_batch_split, val_acc_split,client_weights_split,train_acc_split)

In [None]:
# def train(split_user,federated_user,epochs):
#     federated_weights = run_thread_Federated()
#     split_weights = train_split()
    
#     average_weight(