# ECG Hybrid 1D-CNN Server Side
This code is the server part of ECG Hybrid 1D-CNN model for **multi** client and a server.

## Setting variables

In [1]:
rounds = 400
local_epoch = 1
users = 2 # number of users

## Import required packages

In [2]:
import os
import socket
import struct
import pickle
from threading import Thread
from threading import Lock
import time
import sys


import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

## Set CUDA

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

torch.manual_seed(777)
if device =="cuda:0":
    torch.cuda.manual_seed_all(777)

In [4]:
root_path = '../../models/'

## datasets

In [5]:
class ECG(Dataset):
    def __init__(self, train=True):
        if train:
            with h5py.File(os.path.join(root_path, 'ecg_data', 'train_ecg.hdf5'), 'r') as hdf:
                self.x = hdf['x_train'][:]
                self.y = hdf['y_train'][:]
        else:
            with h5py.File(os.path.join(root_path, 'ecg_data', 'test_ecg.hdf5'), 'r') as hdf:
                self.x = hdf['x_test'][:]
                self.y = hdf['y_test'][:]
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=torch.float), torch.tensor(self.y[idx])

### Set batch size

In [6]:
batch_size = 32

## Make train and test dataset batch generator

In [7]:
train_dataset = ECG(train=True)
test_dataset = ECG(train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [8]:
server_models = [0] * users

## Define ECG server model
Server side has **1 convolutional layer** and **2 fully connected layers**.


In [9]:
class EcgServer(nn.Module):
    def __init__(self):
        super(EcgServer, self).__init__()
#         self.conv1 = nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
#         self.relu1 = nn.LeakyReLU()
#         self.pool1 = nn.MaxPool1d(2)  # 64 x 16
#         self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
#         self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu4 = nn.LeakyReLU()
        self.pool4 = nn.MaxPool1d(2)  # 32 x 16
        self.linear5 = nn.Linear(32 * 16, 128)
        self.relu5 = nn.LeakyReLU()
        self.linear6 = nn.Linear(128, 5)
        self.softmax6 = nn.Softmax(dim=1)
    
    def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu1(x)
#         x = self.pool1(x)
#         x = self.conv2(x)
#         x = self.relu2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.pool4(x)
        x = x.view(-1, 32 * 16)
        x = self.linear5(x)
        x = self.relu5(x)
        x = self.linear6(x)
        x = self.softmax6(x)
        return x   
    
class Ecgnet(nn.Module):
    def __init__(self):
        super(Ecgnet, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu4 = nn.LeakyReLU()
        self.pool4 = nn.MaxPool1d(2)  # 32 x 16
        self.linear5 = nn.Linear(32 * 16, 128)
        self.relu5 = nn.LeakyReLU()
        self.linear6 = nn.Linear(128, 5)
        self.softmax6 = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.pool4(x)
        x = x.view(-1, 32 * 16)
        x = self.linear5(x)
        x = self.relu5(x)
        x = self.linear6(x)
        x = self.softmax6(x)
        return x 

In [10]:
for i in range(users):
    server_models[i] = EcgServer().to(device)
ecg_net = Ecgnet().to(device)
ecg_server = EcgServer().to(device)

In [11]:
# from torchsummary import summary

# print('ECG 1D CNN server')
# summary(ecg_server, (16, 65))

## client

In [12]:
class Ecgclient(nn.Module):
    def __init__(self):
        super(Ecgclient, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, 7, padding=3)  # 128 x 16
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)  # 64 x 16
        self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # 64 x 16
        self.relu2 = nn.LeakyReLU()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        return x   

In [13]:
ecg_client = Ecgclient().to(device)
print(ecg_client)

Ecgclient(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,), padding=(3,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu2): LeakyReLU(negative_slope=0.01)
)


In [14]:
# from torchsummary import summary

# print('ECG 1D CNN client')
# summary(ecg_client, (1, 130))

### Set other hyperparameters in the model
Hyperparameters here should be same with the client side.

In [15]:
import copy

criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer_server_list = []
for i in range(users):
    optimizer_server_list.append(Adam(server_models[i].parameters(), lr=lr))

datasetsize = [0]*users
clientsoclist = [0] * users

client_weights = [0] * users
server_weights = [0] * users

weight_count = 0
global_c_weights = copy.deepcopy(ecg_client.state_dict())
global_s_weights = copy.deepcopy(server_models[0].state_dict())

# for _ in range(users):
#     client_weights.append(c_weights)
#     server_weights.append(s_weights)

start_time = 0
lock = Lock()
    
###########################################################################

############################################################################

In [16]:
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

## Socket initialization
### Set host address and port number

### Required socket functions

In [17]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    l_send = len(msg)
    msg = struct.pack('>I', l_send) + msg
    sock.sendall(msg)
    return l_send

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data


In [18]:
def average_weights(w, datasize):
    """
    Returns the average of the weights.
    """
        
    for i, data in enumerate(datasize):
        for key in w[i].keys():
            w[i][key] *= float(data)
    
    w_avg = copy.deepcopy(w[0])
    
    

# when client use only one kinds of device

    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], float(sum(datasize)))

# when client use various devices (cpu, gpu) you need to use it instead
#
#     for key, val in w_avg.items():
#         common_device = val.device
#         break
#     for key in w_avg.keys():
#         for i in range(1, len(w)):
#             if common_device == 'cpu':
#                 w_avg[key] += w[i][key].cpu()
#             else:
#                 w_avg[key] += w[i][key].cuda()
#         w_avg[key] = torch.div(w_avg[key], float(sum(datasize)))

    return w_avg

In [19]:
def run_thread(func, num_user):
    global clientsoclist
    global start_time
    thrs = []
    for i in range(num_user):
        conn, addr = s.accept()
        print('Conntected with', addr)
        # append client socket on list
        clientsoclist[i] = conn
        args = (i, num_user, conn)
        thread = Thread(target=func, args=args)
        thrs.append(thread)
        thread.start()
    print("timmer start!")
    start_time = time.time()    # store start time
    for thread in thrs:
        thread.join()
    end_time = time.time()  # store end time
    print("TrainingTime: {} sec".format(end_time - start_time))

## Receive client

In [20]:
def receive(userid, num_users, conn): #thread for receive clients
    global weight_count
    global datasetsize

    msg = {
        'rounds': rounds,
        'local_epoch': local_epoch,
        'client_id': userid
    }

    datasize = send_msg(conn, msg)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[userid].append(datasize)

    
    msg, datasize = recv_msg(conn)            # get total_batch of train dataset
    total_batch = msg['total_batch']
    train_dataset_size = msg['train_dataset_size']
    total_receivesize_list.append(datasize)
    client_receivesize_list[userid].append(datasize)
    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1
    
    train(userid, total_batch, num_users, conn)

In [21]:
train_acc = []
test_acc = []

## traning

In [22]:
def train(userid, total_batch, num_users, client_conn):
    global client_weights
    global weight_count
    global global_c_weights
    global global_s_weights
    
    for r in range(rounds):
        
        with lock:
            if weight_count == num_users:
                for i, conn in enumerate(clientsoclist):
                    datasize = send_msg(conn, global_c_weights)
                    total_sendsize_list.append(datasize)
                    client_sendsize_list[i].append(datasize)
                    train_sendsize_list.append(datasize)
                    weight_count = 0
                    server_models[i].load_state_dict(global_s_weights)
                    server_models[i].eval()
                
        for l in range(local_epoch):
            
                        
            for i in range(total_batch):
                optimizer_server_list[userid].zero_grad()  # initialize all gradients to zero

                msg, datasize = recv_msg(client_conn)  # receive client message from socket
                total_receivesize_list.append(datasize)
                client_receivesize_list[userid].append(datasize)
                train_receivesize_list.append(datasize)

                client_output_cpu = msg['client_output']  # client output tensor
                label = msg['label']  # label

                client_output = client_output_cpu.to(device)
                label = label.clone().detach().long().to(device)

                output = server_models[userid](client_output)  # forward propagation
                loss = criterion(output, label)  # calculates cross-entropy loss
                loss.backward()  # backward propagation
                msg = client_output_cpu.grad.clone().detach()

                datasize = send_msg(client_conn, msg)
                total_sendsize_list.append(datasize)
                client_sendsize_list[userid].append(datasize)
                train_sendsize_list.append(datasize)

                optimizer_server_list[userid].step()
                
            
        c_weights, datasize = recv_msg(client_conn)
        total_receivesize_list.append(datasize)
        client_receivesize_list[userid].append(datasize)
        train_receivesize_list.append(datasize)
        with lock:
            client_weights[userid] = c_weights
            server_weights[userid] = copy.deepcopy(server_models[userid].state_dict())
            weight_count += 1
            if weight_count == num_users:
                #average
                global_c_weights = average_weights(client_weights, datasetsize)
                global_s_weights = average_weights(server_weights, datasetsize)
                
                # acc
                

#                 ecg_client.load_state_dict(global_c_weights)
#                 ecg_client.to(device)

#                 ecg_server.load_state_dict(global_s_weights)
#                 ecg_server.to(device)
                
#                 # train acc
#                 with torch.no_grad():
#                     corr_num = 0
#                     total_num = 0
#                     train_loss = 0.0
#                     for j, trn in enumerate(train_loader):
#                         trn_x, trn_label = trn
#                         trn_x = trn_x.to(device)

#                         trn_label = trn_label.clone().detach().long().to(device)


#                         #trn_output = ecg_net(trn_x)
#                         trn_output = ecg_client(trn_x)
#                         trn_output = ecg_server(trn_output)

#                         loss = criterion(trn_output, trn_label)

#                         train_loss += loss.item()
#                         model_label = trn_output.argmax(dim=1)
#                         corr = trn_label[trn_label == model_label].size(0)
#                         corr_num += corr
#                         total_num += trn_label.size(0)
#                     train_accuracy = corr_num / total_num * 100
#                     r_train_loss = train_loss / len(train_loader)
#                     print("Round{}'s train_acc: {:.2f}%, train_loss: {:.4f}".format(r, train_accuracy, r_train_loss))
#                     train_acc.append(train_accuracy)
#                 # test acc
#                 with torch.no_grad():
#                     corr_num = 0
#                     total_num = 0
#                     val_loss = 0.0
#                     for j, val in enumerate(test_loader):
#                         val_x, val_label = val
#                         val_x = val_x.to(device)
#                         val_label = val_label.to(device)

#                         #val_output = ecg_net(val_x)
#                         val_output = ecg_client(val_x)
#                         val_output = ecg_server(val_output)

#                         val_label = val_label.long()
#                         loss = criterion(val_output, val_label)
#                         val_loss += loss.item()
#                         model_label = val_output.argmax(dim=1)
#                         corr = val_label[val_label == model_label].size(0)
#                         corr_num += corr
#                         total_num += val_label.size(0)
#                     test_accuracy = corr_num / total_num * 100
#                     test_loss = val_loss / len(test_loader)
#                     print("Round{}'s test_acc: {:.2f}%, test_loss: {:.4f}".format(r, test_accuracy, test_loss))
#                     test_acc.append(test_accuracy)
                
        print("round {}'s user {} is done".format(r, userid))      
            
    print('{} is complite'.format(userid))
            
                
        

In [23]:
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

192.168.83.1


### Open the server socket

In [24]:
s = socket.socket()
try:
    s.bind((host, port))
    print('Success to connect')
except:
    print('Fail to connect')
    
s.listen(5)

Success to connect


In [25]:
run_thread(receive, users)

Conntected with ('192.168.83.1', 5148)
Conntected with ('192.168.83.1', 5149)
timmer start!




round 0's user 1 is done
Round0's train_acc: 61.93%, train_loss: 1.2947
Round0's test_acc: 62.97%, test_loss: 1.2893
round 0's user 0 is done
round 1's user 1 is done
Round1's train_acc: 78.55%, train_loss: 1.1207
Round1's test_acc: 79.63%, test_loss: 1.1125
round 1's user 0 is done
round 2's user 1 is done
Round2's train_acc: 85.65%, train_loss: 1.0488
Round2's test_acc: 86.51%, test_loss: 1.0410
round 2's user 0 is done
round 3's user 1 is done
Round3's train_acc: 86.27%, train_loss: 1.0430
Round3's test_acc: 86.89%, test_loss: 1.0375
round 3's user 0 is done
round 4's user 1 is done
Round4's train_acc: 85.21%, train_loss: 1.0543
Round4's test_acc: 86.01%, test_loss: 1.0476
round 4's user 0 is done
round 5's user 0 is done
Round5's train_acc: 87.13%, train_loss: 1.0325
Round5's test_acc: 88.03%, test_loss: 1.0252
round 5's user 1 is done
round 6's user 1 is done
Round6's train_acc: 87.38%, train_loss: 1.0313
Round6's test_acc: 87.73%, test_loss: 1.0267
round 6's user 0 is done
round 

Round56's test_acc: 90.18%, test_loss: 1.0013
round 56's user 1 is done
round 57's user 1 is done
Round57's train_acc: 89.97%, train_loss: 1.0042
Round57's test_acc: 90.25%, test_loss: 1.0010
round 57's user 0 is done
round 58's user 1 is done
Round58's train_acc: 89.88%, train_loss: 1.0041
Round58's test_acc: 90.18%, test_loss: 1.0010
round 58's user 0 is done
round 59's user 0 is done
Round59's train_acc: 90.06%, train_loss: 1.0018
Round59's test_acc: 90.37%, test_loss: 0.9992
round 59's user 1 is done
round 60's user 1 is done
Round60's train_acc: 89.94%, train_loss: 1.0039
Round60's test_acc: 90.25%, test_loss: 1.0007
round 60's user 0 is done
round 61's user 1 is done
Round61's train_acc: 90.18%, train_loss: 1.0006
Round61's test_acc: 90.39%, test_loss: 0.9984
round 61's user 0 is done
round 62's user 1 is done
Round62's train_acc: 89.92%, train_loss: 1.0044
Round62's test_acc: 90.20%, test_loss: 1.0014
round 62's user 0 is done
round 63's user 1 is done
Round63's train_acc: 90.04

Round112's test_acc: 90.71%, test_loss: 0.9945
round 112's user 1 is done
round 113's user 0 is done
Round113's train_acc: 90.54%, train_loss: 0.9958
Round113's test_acc: 90.70%, test_loss: 0.9943
round 113's user 1 is done
round 114's user 0 is done
Round114's train_acc: 90.61%, train_loss: 0.9953
Round114's test_acc: 90.77%, test_loss: 0.9939
round 114's user 1 is done
round 115's user 0 is done
Round115's train_acc: 90.60%, train_loss: 0.9953
Round115's test_acc: 90.76%, test_loss: 0.9939
round 115's user 1 is done
round 116's user 0 is done
Round116's train_acc: 90.62%, train_loss: 0.9951
Round116's test_acc: 90.74%, test_loss: 0.9940
round 116's user 1 is done
round 117's user 0 is done
Round117's train_acc: 90.71%, train_loss: 0.9943
Round117's test_acc: 90.88%, test_loss: 0.9928
round 117's user 1 is done
round 118's user 1 is done
Round118's train_acc: 90.59%, train_loss: 0.9951
Round118's test_acc: 90.63%, test_loss: 0.9950
round 118's user 0 is done
round 119's user 1 is done

Round167's train_acc: 91.67%, train_loss: 0.9843
Round167's test_acc: 91.62%, test_loss: 0.9848
round 167's user 0 is done
round 168's user 1 is done
Round168's train_acc: 91.70%, train_loss: 0.9839
Round168's test_acc: 91.56%, test_loss: 0.9855
round 168's user 0 is done
round 169's user 1 is done
Round169's train_acc: 91.74%, train_loss: 0.9833
Round169's test_acc: 91.63%, test_loss: 0.9845
round 169's user 0 is done
round 170's user 0 is done
Round170's train_acc: 91.66%, train_loss: 0.9845
Round170's test_acc: 91.53%, test_loss: 0.9857
round 170's user 1 is done
round 171's user 1 is done
Round171's train_acc: 91.57%, train_loss: 0.9853
Round171's test_acc: 91.60%, test_loss: 0.9852
round 171's user 0 is done
round 172's user 1 is done
Round172's train_acc: 92.00%, train_loss: 0.9811
Round172's test_acc: 91.82%, test_loss: 0.9828
round 172's user 0 is done
round 173's user 1 is done
Round173's train_acc: 91.85%, train_loss: 0.9824
Round173's test_acc: 91.75%, test_loss: 0.9836
roun

round 222's user 1 is done
Round222's train_acc: 93.18%, train_loss: 0.9689
Round222's test_acc: 92.73%, test_loss: 0.9736
round 222's user 0 is done
round 223's user 1 is done
Round223's train_acc: 93.33%, train_loss: 0.9675
Round223's test_acc: 92.87%, test_loss: 0.9718
round 223's user 0 is done
round 224's user 0 is done
Round224's train_acc: 93.39%, train_loss: 0.9670
Round224's test_acc: 92.78%, test_loss: 0.9729
round 224's user 1 is done
round 225's user 1 is done
Round225's train_acc: 93.22%, train_loss: 0.9686
Round225's test_acc: 93.01%, test_loss: 0.9710
round 225's user 0 is done
round 226's user 1 is done
Round226's train_acc: 93.34%, train_loss: 0.9672
Round226's test_acc: 93.03%, test_loss: 0.9706
round 226's user 0 is done
round 227's user 1 is done
Round227's train_acc: 93.40%, train_loss: 0.9667
Round227's test_acc: 93.04%, test_loss: 0.9709
round 227's user 0 is done
round 228's user 1 is done
Round228's train_acc: 93.42%, train_loss: 0.9667
Round228's test_acc: 93.

round 277's user 0 is done
Round277's train_acc: 98.84%, train_loss: 0.9165
Round277's test_acc: 98.17%, test_loss: 0.9232
round 277's user 1 is done
round 278's user 0 is done
Round278's train_acc: 98.81%, train_loss: 0.9166
Round278's test_acc: 98.19%, test_loss: 0.9229
round 278's user 1 is done
round 279's user 1 is done
Round279's train_acc: 98.52%, train_loss: 0.9193
Round279's test_acc: 97.92%, test_loss: 0.9255
round 279's user 0 is done
round 280's user 0 is done
Round280's train_acc: 98.87%, train_loss: 0.9162
Round280's test_acc: 98.27%, test_loss: 0.9224
round 280's user 1 is done
round 281's user 1 is done
Round281's train_acc: 98.81%, train_loss: 0.9168
Round281's test_acc: 97.96%, test_loss: 0.9250
round 281's user 0 is done
round 282's user 0 is done
Round282's train_acc: 98.84%, train_loss: 0.9161
Round282's test_acc: 98.13%, test_loss: 0.9236
round 282's user 1 is done
round 283's user 0 is done
Round283's train_acc: 98.84%, train_loss: 0.9166
Round283's test_acc: 98.

round 332's user 1 is done
Round332's train_acc: 99.01%, train_loss: 0.9146
Round332's test_acc: 98.31%, test_loss: 0.9218
round 332's user 0 is done
round 333's user 1 is done
Round333's train_acc: 98.83%, train_loss: 0.9166
Round333's test_acc: 98.06%, test_loss: 0.9244
round 333's user 0 is done
round 334's user 1 is done
Round334's train_acc: 99.06%, train_loss: 0.9142
Round334's test_acc: 98.15%, test_loss: 0.9230
round 334's user 0 is done
round 335's user 0 is done
Round335's train_acc: 99.10%, train_loss: 0.9139
Round335's test_acc: 98.29%, test_loss: 0.9219
round 335's user 1 is done
round 336's user 0 is done
Round336's train_acc: 98.89%, train_loss: 0.9160
Round336's test_acc: 98.14%, test_loss: 0.9233
round 336's user 1 is done
round 337's user 1 is done
Round337's train_acc: 99.03%, train_loss: 0.9145
Round337's test_acc: 98.34%, test_loss: 0.9216
round 337's user 0 is done
round 338's user 1 is done
Round338's train_acc: 99.09%, train_loss: 0.9141
Round338's test_acc: 98.

round 387's user 1 is done
Round387's train_acc: 99.06%, train_loss: 0.9142
Round387's test_acc: 98.38%, test_loss: 0.9209
round 387's user 0 is done
round 388's user 0 is done
Round388's train_acc: 99.03%, train_loss: 0.9145
Round388's test_acc: 98.26%, test_loss: 0.9223
round 388's user 1 is done
round 389's user 0 is done
Round389's train_acc: 99.09%, train_loss: 0.9139
Round389's test_acc: 98.35%, test_loss: 0.9213
round 389's user 1 is done
round 390's user 1 is done
Round390's train_acc: 99.04%, train_loss: 0.9143
Round390's test_acc: 98.33%, test_loss: 0.9214
round 390's user 0 is done
round 391's user 0 is done
Round391's train_acc: 98.90%, train_loss: 0.9156
Round391's test_acc: 98.18%, test_loss: 0.9230
round 391's user 1 is done
round 392's user 1 is done
Round392's train_acc: 99.03%, train_loss: 0.9145
Round392's test_acc: 98.22%, test_loss: 0.9225
round 392's user 0 is done
round 393's user 1 is done
Round393's train_acc: 99.05%, train_loss: 0.9144
Round393's test_acc: 98.

## Print commmunication overheads 

In [26]:
print('---total_sendsize_list---')
total_size = 0
for size in total_sendsize_list:
#     print(size)
    total_size += size
print("total_sendsize size: {} bytes".format(total_size))
print("number of total_send: ", len(total_sendsize_list))
print('\n')

print('---total_receivesize_list---')
total_size = 0
for size in total_receivesize_list:
#     print(size)
    total_size += size
print("total receive sizes: {} bytes".format(total_size) )
print("number of total receive: ", len(total_receivesize_list) )
print('\n')

for i in range(users):
    print('---client_sendsize_list(user{})---'.format(i))
    total_size = 0
    for size in client_sendsize_list[i]:
#         print(size)
        total_size += size
    print("total client_sendsizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_sendsize_list[i]))
    print('\n')

    print('---client_receivesize_list(user{})---'.format(i))
    total_size = 0
    for size in client_receivesize_list[i]:
#         print(size)
        total_size += size
    print("total client_receive sizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_receivesize_list[i]))
    print('\n')

print('---train_sendsize_list---')
total_size = 0
for size in train_sendsize_list:
#     print(size)
    total_size += size
print("total train_sendsizes: {} bytes".format(total_size))
print("number of train_send: ", len(train_sendsize_list) )
print('\n')

print('---train_receivesize_list---')
total_size = 0
for size in train_receivesize_list:
#     print(size)
    total_size += size
print("total train_receivesizes: {} bytes".format(total_size))
print("number of train_receive: ", len(train_receivesize_list) )
print('\n')




---total_sendsize_list---
total_sendsize size: 22109196148 bytes
number of total_send:  166402


---total_receivesize_list---
total receive sizes: 22184537712 bytes
number of total receive:  166402


---client_sendsize_list(user0)---
total client_sendsizes(user0): 11054598074 bytes
number of client_send(user0):  83201


---client_receivesize_list(user0)---
total client_receive sizes(user0): 11092268856 bytes
number of client_send(user0):  83201


---client_sendsize_list(user1)---
total client_sendsizes(user1): 11054598074 bytes
number of client_send(user1):  83201


---client_receivesize_list(user1)---
total client_receive sizes(user1): 11092268856 bytes
number of client_send(user1):  83201


---train_sendsize_list---
total train_sendsizes: 22109196024 bytes
number of train_send:  166400


---train_receivesize_list---
total train_receivesizes: 22184537600 bytes
number of train_receive:  166400




# Validation after trainning

### acc of each acc 

In [27]:
ecg_client.load_state_dict(global_c_weights)
ecg_client.to(device)

ecg_server.load_state_dict(global_s_weights)
ecg_server.to(device)

ecg_client_dict = ecg_client.state_dict()
ecg_server_dict = ecg_server.state_dict()
ecg_original_dict = ecg_net.state_dict()

ecg_original_dict.update(ecg_client_dict)
ecg_original_dict.update(ecg_server_dict)

ecg_net.load_state_dict(ecg_original_dict)
ecg_net.eval()

Ecgnet(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,), padding=(3,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu2): LeakyReLU(negative_slope=0.01)
  (conv3): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu3): LeakyReLU(negative_slope=0.01)
  (conv4): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu4): LeakyReLU(negative_slope=0.01)
  (pool4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear5): Linear(in_features=512, out_features=128, bias=True)
  (relu5): LeakyReLU(negative_slope=0.01)
  (linear6): Linear(in_features=128, out_features=5, bias=True)
  (softmax6): Softmax(dim=1)
)

In [28]:
# train acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    train_loss = 0.0
    for j, trn in enumerate(train_loader):
        trn_x, trn_label = trn
        trn_x = trn_x.to(device)

        trn_label = trn_label.clone().detach().long().to(device)


        #trn_output = ecg_net(trn_x)
        trn_output = ecg_client(trn_x)
        trn_output = ecg_server(trn_output)
        
        loss = criterion(trn_output, trn_label)

        train_loss += loss.item()
        model_label = trn_output.argmax(dim=1)
        corr = trn_label[trn_label == model_label].size(0)
        corr_num += corr
        total_num += trn_label.size(0)
    train_accuracy = corr_num / total_num * 100
    r_train_loss = train_loss / len(train_loader)
    print("train_acc: {:.2f}%, train_loss: {:.4f}".format(train_accuracy, r_train_loss))

# test acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    val_loss = 0.0
    for j, val in enumerate(test_loader):
        val_x, val_label = val
        val_x = val_x.to(device)
        val_label = val_label.to(device)

        #val_output = ecg_net(val_x)
        val_output = ecg_client(val_x)
        val_output = ecg_server(val_output)
        
        val_label = val_label.long()
        loss = criterion(val_output, val_label)
        val_loss += loss.item()
        model_label = val_output.argmax(dim=1)
        corr = val_label[val_label == model_label].size(0)
        corr_num += corr
        total_num += val_label.size(0)
    test_accuracy = corr_num / total_num * 100
    test_loss = val_loss / len(test_loader)
    print("test_acc: {:.2f}%, test_loss: {:.4f}".format(test_accuracy, test_loss))

train_acc: 99.10%, train_loss: 0.9138
test_acc: 98.26%, test_loss: 0.9222


In [29]:
classes = ['N', 'L', 'R', 'A', 'V']

class_correct = list(0. for i in range(5))
class_total = list(0. for i in range(5))

with torch.no_grad():
    for data in test_loader:
        x, labels = data
        x = x.to(device)
        labels = labels.to(device)

#         outputs = ecg_net(x)
        outputs = ecg_client(x)
        outputs = ecg_server(outputs)

        labels = labels.long()
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(5):
    print("Accuracy of %5s : %2d %%" % (classes[i], 100 * class_correct[i] / class_total[i]))
print('\n')

Accuracy of     N : 97 %
Accuracy of     L : 99 %
Accuracy of     R : 99 %
Accuracy of     A : 90 %
Accuracy of     V : 99 %




In [None]:
# Let's quickly save our trained model:
PATH = './ecg_hy_model.pth'
torch.save(ecg_net.state_dict(), PATH)

In [None]:
end_time = time.time()  # store end time
print("WorkingTime: {} sec".format(end_time - start_time))