# ECG Federated 1D-CNN Server Side
This code is the server side of a federated-learning ECG 1D-CNN model with **multiple** clients and one server.

## Setting variables

In [1]:
# Federated-training settings; `rounds` and `local_epoch` are sent to every
# client during the handshake, `users` is how many clients the server waits for.
rounds, local_epoch, users = 400, 1, 2

In [2]:
import os
import h5py

import socket
import struct
import pickle
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from threading import Thread
from threading import Lock


import time

from tqdm import tqdm

import copy

## Cuda

In [3]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


## PyTorch layer modules for the *Conv1D* network



### `Conv1d` layer
- `torch.nn.Conv1d(in_channels, out_channels, kernel_size)`

### `MaxPool1d` layer
- `torch.nn.MaxPool1d(kernel_size, stride=None)`
- Parameter `stride` defaults to `kernel_size` when it is omitted.

### `ReLU` layer
- `torch.nn.ReLU()` (the model below actually uses the leaky variant, `torch.nn.LeakyReLU()`)

### `Linear` layer
- `torch.nn.Linear(in_features, out_features, bias=True)`

### `Softmax` layer
- `torch.nn.Softmax(dim=None)`
- Parameter `dim` is usually set to `1`.

## Construct 1D-CNN ECG classification model

In [4]:
class EcgConv1d(nn.Module):
    """1D-CNN that maps a single-channel ECG segment to 5 class probabilities.

    The per-layer size comments imply an input of shape (batch, 1, 130):
    conv(7) -> 124, pool -> 62, conv(5) x3 -> 50, pool -> 25, giving
    25 * 16 = 400 features for the fully connected head.

    NOTE(review): forward() ends in Softmax, while this file evaluates the
    model with nn.CrossEntropyLoss (which applies log-softmax itself), so the
    reported losses are computed on probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as always so that random
        # initialisation and state_dict keys stay stable across versions.
        self.conv1 = nn.Conv1d(1, 16, 7)      # -> 16 x 124
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)          # -> 16 x 62
        self.conv2 = nn.Conv1d(16, 16, 5)     # -> 16 x 58
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv1d(16, 16, 5)     # -> 16 x 54
        self.relu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv1d(16, 16, 5)     # -> 16 x 50
        self.relu4 = nn.LeakyReLU()
        self.pool4 = nn.MaxPool1d(2)          # -> 16 x 25
        self.linear5 = nn.Linear(25 * 16, 128)
        self.relu5 = nn.LeakyReLU()
        self.linear6 = nn.Linear(128, 5)
        self.softmax6 = nn.Softmax(dim=1)

    def forward(self, x):
        feature_extractor = (
            self.conv1, self.relu1, self.pool1,
            self.conv2, self.relu2,
            self.conv3, self.relu3,
            self.conv4, self.relu4, self.pool4,
        )
        for stage in feature_extractor:
            x = stage(x)
        # Flatten channels x length into the 400-wide feature vector.
        x = x.view(-1, 25 * 16)
        for stage in (self.linear5, self.relu5, self.linear6, self.softmax6):
            x = stage(x)
        return x

In [5]:
# Build the global model on the CPU; its initial state_dict is deep-copied
# into `global_weights` further down.
ecg_net = EcgConv1d()
ecg_net.to('cpu')

EcgConv1d(
  (conv1): Conv1d(1, 16, kernel_size=(7,), stride=(1,))
  (relu1): LeakyReLU(negative_slope=0.01)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu2): LeakyReLU(negative_slope=0.01)
  (conv3): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu3): LeakyReLU(negative_slope=0.01)
  (conv4): Conv1d(16, 16, kernel_size=(5,), stride=(1,))
  (relu4): LeakyReLU(negative_slope=0.01)
  (pool4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear5): Linear(in_features=400, out_features=128, bias=True)
  (relu5): LeakyReLU(negative_slope=0.01)
  (linear6): Linear(in_features=128, out_features=5, bias=True)
  (softmax6): Softmax(dim=1)
)

## variables

In [6]:
# One slot per client; run_thread() stores each accepted socket here.
clientsoclist = [0 for _ in range(users)]

# start_time: set by run_thread() once all clients are connected.
# weight_count: number of clients that have reported in for the current
# step; always modified while holding `lock`.
start_time = 0
weight_count = 0

# Current aggregated weights, initialised from the freshly built model.
global_weights = copy.deepcopy(ecg_net.state_dict())

# Per-client size reported at handshake (used as the averaging weight) and
# the most recent weights uploaded by each client.
datasetsize = [0 for _ in range(users)]
weights_list = [0 for _ in range(users)]

lock = Lock()

## Communication overhead

In [7]:
# Pickled-payload sizes in bytes (the 4-byte length prefix is not counted)
# of every message the server sends / receives.
total_sendsize_list, total_receivesize_list = [], []

# The same sizes broken down per client index.
client_sendsize_list = [[] for _ in range(users)]
client_receivesize_list = [[] for _ in range(users)]

# Sizes of the messages exchanged inside train() only (weights traffic).
train_sendsize_list, train_receivesize_list = [], []

## Socket initialization
### Set host address and port number

### Required socket functions

In [8]:
def send_msg(sock, msg):
    """Pickle *msg* and send it on *sock*, prefixed by its length.

    The prefix is a 4-byte unsigned big-endian integer ('>I').

    Returns:
        The size in bytes of the pickled payload (the prefix is not counted).
    """
    payload = pickle.dumps(msg)
    sock.sendall(struct.pack('>I', len(payload)) + payload)
    return len(payload)


def recv_msg(sock):
    """Receive one length-prefixed pickled message from *sock*.

    Returns:
        ``(message, payload_size)``, or ``None`` if the peer closes the
        connection before a complete message (header *or* payload) arrives.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    data; only use this with trusted peers.
    """
    raw_msglen = recvall(sock, 4)
    if raw_msglen is None:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    payload = recvall(sock, msglen)
    if payload is None:
        # Connection dropped mid-message: report EOF the same way as a
        # missing header instead of crashing inside pickle.loads(None).
        return None
    return pickle.loads(payload), msglen


def recvall(sock, n):
    """Read exactly *n* bytes from *sock*.

    Returns:
        The bytes read, or ``None`` if EOF is hit before *n* bytes arrive.
    """
    # bytearray appends are amortised O(1); repeated `bytes +=` re-copies the
    # whole buffer on every chunk.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

In [9]:
import copy

def average_weights(w, datasize):
    """Return the size-weighted average of several model state_dicts.

    For every key the result is
    ``sum_i datasize[i] * w[i][key] / sum(datasize)``.

    Args:
        w: sequence of state_dicts (parameter name -> tensor), one per client.
        datasize: aggregation weight of each entry of ``w``, in the same order.

    Returns:
        A new mapping of the same type as ``w[0]`` (state_dict metadata is
        kept) holding the averaged tensors.  The inputs are NOT modified.

    Raises:
        ValueError: if ``w`` is empty or the weights sum to zero.
    """
    if not w:
        raise ValueError("average_weights() needs at least one state_dict")
    total = float(sum(datasize))
    if total == 0:
        raise ValueError("average_weights(): datasize must not sum to zero")

    # Shallow copy keeps the mapping type and any state_dict `_metadata`
    # attribute; each tensor is then replaced by a freshly computed one.
    w_avg = copy.copy(w[0])
    for key, first in w[0].items():
        acc = first * float(datasize[0])
        for state, size in zip(w[1:], datasize[1:]):
            # Move each upload to w[0]'s device so clients may mix CPU and
            # GPU; this is a no-op when the devices already match.
            acc = acc + state[key].to(first.device) * float(size)
        w_avg[key] = torch.div(acc, total)
    return w_avg

## Thread definitions

## Accept clients before training

In [10]:
def run_thread(func, num_user):
    """Accept `num_user` clients on the global socket `s` and serve each one.

    Every accepted connection is stored in `clientsoclist[index]`, and
    `func(index, num_user, conn)` is started on its own Thread.  The global
    `start_time` is recorded only after the last client has connected, so
    work already done by earlier threads is not included.  The call blocks
    until every thread finishes, then prints the elapsed time.
    """
    global clientsoclist
    global start_time

    workers = []
    for idx in range(num_user):
        conn, addr = s.accept()
        print('Conntected with', addr)
        clientsoclist[idx] = conn
        worker = Thread(target=func, args=(idx, num_user, conn))
        workers.append(worker)
        worker.start()

    print("timmer start!")
    start_time = time.time()
    for worker in workers:
        worker.join()
    finished_at = time.time()
    print("TrainingTime: {} sec".format(finished_at - start_time))

In [11]:
def receive(userid, num_users, conn):
    """Per-client thread entry point: handshake, then run the training loop.

    Sends the client its configuration (`rounds`, `client_id`,
    `local_epoch`), waits for the size the client reports back, stores it in
    `datasetsize[userid]`, increments the shared `weight_count` under `lock`,
    and finally hands the connection to train().
    """
    global weight_count
    global datasetsize

    handshake = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch
    }
    sent_bytes = send_msg(conn, handshake)
    total_sendsize_list.append(sent_bytes)
    client_sendsize_list[userid].append(sent_bytes)

    # NOTE(review): confirm on the client side whether this value counts
    # samples or batches; it is used directly as the averaging weight.
    train_dataset_size, recv_bytes = recv_msg(conn)
    total_receivesize_list.append(recv_bytes)
    client_receivesize_list[userid].append(recv_bytes)

    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1

    train(userid, train_dataset_size, num_users, conn)

## Train

In [12]:
def train(userid, train_dataset_size, num_users, client_conn):
    """Run the server side of `rounds` federated rounds for one client.

    Each round has three steps:
      1. Broadcast: a thread that finds `weight_count == num_users` (under
         `lock`) sends `global_weights` to every client and resets the count.
      2. Upload: block until this client sends its weights, and store them in
         `weights_list[userid]`.
      3. Aggregate: increment `weight_count`; the thread that brings it to
         `num_users` replaces `global_weights` with average_weights(...).

    `train_dataset_size` is not read here; the averaging weights come from
    the global `datasetsize` list.
    """
    global weights_list
    global global_weights
    global weight_count

    for round_no in range(1, rounds + 1):
        with lock:
            if weight_count == num_users:
                for idx, sock in enumerate(clientsoclist):
                    nbytes = send_msg(sock, global_weights)
                    for log in (total_sendsize_list,
                                client_sendsize_list[idx],
                                train_sendsize_list):
                        log.append(nbytes)
                    weight_count = 0  # idempotent after the first client

        client_weights, nbytes = recv_msg(client_conn)
        for log in (total_receivesize_list,
                    client_receivesize_list[userid],
                    train_receivesize_list):
            log.append(nbytes)

        weights_list[userid] = client_weights
        print(f"User{userid}'s Round {round_no} is done")

        with lock:
            weight_count += 1
            if weight_count == num_users:
                global_weights = average_weights(weights_list, datasetsize)
                
        
    

In [13]:
# Bind to the address this machine's hostname resolves to.
# NOTE(review): gethostbyname(gethostname()) may resolve to loopback or to a
# single interface only; confirm the clients can reach this address.
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

192.168.83.1


In [14]:
# Listening TCP socket (socket.socket() defaults to AF_INET / SOCK_STREAM),
# with a backlog of up to 5 pending connections.
s = socket.socket()
s.bind((host, port))
s.listen(5)

### Open the server socket

In [15]:
# Blocks until `users` clients have connected and all their threads finish.
run_thread(receive, users)

Conntected with ('192.168.83.1', 5455)
Conntected with ('192.168.83.1', 5555)
timmer start!




User0's Round 1 is done
User1's Round 1 is done
User1's Round 2 is done
User0's Round 2 is done
User0's Round 3 is done
User1's Round 3 is done
User0's Round 4 is done
User1's Round 4 is done
User0's Round 5 is done
User1's Round 5 is done
User1's Round 6 is done
User0's Round 6 is done
User1's Round 7 is done
User0's Round 7 is done
User1's Round 8 is done
User0's Round 8 is done
User1's Round 9 is done
User0's Round 9 is done
User1's Round 10 is done
User0's Round 10 is done
User0's Round 11 is done
User1's Round 11 is done
User1's Round 12 is done
User0's Round 12 is done
User1's Round 13 is done
User0's Round 13 is done
User0's Round 14 is done
User1's Round 14 is done
User0's Round 15 is done
User1's Round 15 is done
User1's Round 16 is done
User0's Round 16 is done
User1's Round 17 is done
User0's Round 17 is done
User1's Round 18 is done
User0's Round 18 is done
User1's Round 19 is done
User0's Round 19 is done
User1's Round 20 is done
User0's Round 20 is done
User1's Round 21 i

In [16]:
# Measured from the start_time set in run_thread(), so this also counts any
# time that passed between the end of training and running this line.
end_time = time.time()
print(f"TrainingTime: {end_time - start_time} sec")

TrainingTime: 351.4171006679535 sec


## Print all of communication overhead

In [17]:
def _print_size_total(header, summary_template, sizes, *label):
    """Print *header*, then the byte total of *sizes* via *summary_template*.

    Any *label* values fill the template's leading placeholders; the summed
    size fills the last one.  A blank-line separator follows.
    """
    print(header)
    print(summary_template.format(*label, sum(sizes)))
    print('\n')


print('\n')
_print_size_total('---total_sendsize_list---',
                  "total_sendsize size: {} bytes", total_sendsize_list)
_print_size_total('---total_receivesize_list---',
                  "total receive sizes: {} bytes", total_receivesize_list)

for uid in range(users):
    _print_size_total('---client_sendsize_list(user{})---'.format(uid),
                      "total client_sendsizes(user{}): {} bytes",
                      client_sendsize_list[uid], uid)
    _print_size_total('---client_receivesize_list(user{})---'.format(uid),
                      "total client_receive sizes(user{}): {} bytes",
                      client_receivesize_list[uid], uid)

_print_size_total('---train_sendsize_list---',
                  "total train_sendsizes: {} bytes", train_sendsize_list)
_print_size_total('---train_receivesize_list---',
                  "total train_receivesizes: {} bytes", train_receivesize_list)




---total_sendsize_list---
total_sendsize size: 45635122 bytes


---total_receivesize_list---
total receive sizes: 45635012 bytes


---client_sendsize_list(user0)---
total client_sendsizes(user0): 22817561 bytes


---client_receivesize_list(user0)---
total client_receive sizes(user0): 22817506 bytes


---client_sendsize_list(user1)---
total client_sendsizes(user1): 22817561 bytes


---client_receivesize_list(user1)---
total client_receive sizes(user1): 22817506 bytes


---train_sendsize_list---
total train_sendsizes: 45635000 bytes


---train_receivesize_list---
total train_receivesizes: 45635000 bytes




In [18]:
# Directory containing the `ecg_data/` folder with the HDF5 train/test files.
root_path = '../../models/'

## Defining `ECG` Dataset Class


In [19]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

In [20]:
class ECG(Dataset):
    """ECG segments and labels, read fully into memory from an HDF5 file.

    With ``train=True`` reads ``<root_path>/ecg_data/train_ecg.hdf5``
    (datasets ``x_train`` / ``y_train``); otherwise reads
    ``<root_path>/ecg_data/test_ecg.hdf5`` (``x_test`` / ``y_test``).
    Items are returned as ``(float32 tensor of x[idx], tensor of y[idx])``.
    """

    def __init__(self, train=True):
        file_name, x_key, y_key = (
            ('train_ecg.hdf5', 'x_train', 'y_train') if train
            else ('test_ecg.hdf5', 'x_test', 'y_test')
        )
        with h5py.File(os.path.join(root_path, 'ecg_data', file_name), 'r') as hdf:
            self.x = hdf[x_key][:]
            self.y = hdf[y_key][:]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        signal = torch.tensor(self.x[idx], dtype=torch.float)
        label = torch.tensor(self.y[idx])
        return signal, label

## Making Batch Generator

In [21]:
# Mini-batch size for both the train and test DataLoaders.
batch_size = 32

### `DataLoader` for batch generating
`torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)`

In [22]:
# Load both splits into memory; only the training loader shuffles.
# NOTE(review): in this file the server uses these loaders only to evaluate
# the aggregated model, not to train it.
train_dataset = ECG(train=True)
test_dataset = ECG(train=False)
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size)

### Number of total batches

In [23]:
# Number of mini-batches per pass over each split.
train_total_batch, test_batch = len(trainloader), len(testloader)
for batch_count in (train_total_batch, test_batch):
    print(batch_count)

414
414


In [24]:
lr = 0.001
# NOTE(review): `optimizer` is never stepped in the visible server code; only
# `criterion` is used, to report evaluation losses.
optimizer = Adam(ecg_net.parameters(), lr=lr)
# NOTE(review): CrossEntropyLoss applies log-softmax to its input, but
# EcgConv1d already ends in Softmax, so these losses are computed on
# probabilities rather than logits.
criterion = nn.CrossEntropyLoss()

## Accuracy on the training set and for each class

In [25]:
# Evaluate the final aggregated model on the train/test sets, report
# per-class test accuracy, save the weights and print the total working time.


def _evaluate(model, loader, loss_fn):
    """Return ``(accuracy_percent, mean_batch_loss)`` of *model* on *loader*.

    Inputs are moved to the global ``device``; labels are cast to int64
    class indices as required by ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.long().to(device)
            outputs = model(inputs)
            loss_sum += loss_fn(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        raise ValueError("cannot evaluate on an empty DataLoader")
    return correct / seen * 100, loss_sum / len(loader)


def _per_class_counts(model, loader, num_classes):
    """Return ``(correct, total)`` lists of per-class prediction counts.

    The comparison is kept 1-D (no ``.squeeze()``), so a final batch of
    size 1 is handled instead of producing an un-indexable 0-d tensor.
    """
    correct = [0.0] * num_classes
    total = [0.0] * num_classes
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            labels = labels.to(device).long()
            _, predicted = torch.max(outputs, 1)
            hits = (predicted == labels).tolist()
            for label, hit in zip(labels.tolist(), hits):
                correct[label] += hit
                total[label] += 1
    return correct, total


ecg_net.load_state_dict(global_weights)
ecg_net.eval()
ecg_net = ecg_net.to(device)

# Overall accuracy / mean batch loss on both splits.
train_acc, train_avg_loss = _evaluate(ecg_net, trainloader, criterion)
print("train_acc: {:.2f}%, train_loss: {:.4f}".format(train_acc, train_avg_loss))

accuracy, test_loss = _evaluate(ecg_net, testloader, criterion)
print("test_acc: {:.2f}%, test_loss: {:.4f}".format(accuracy, test_loss))

# Test accuracy for each class.
classes = ['N', 'L', 'R', 'A', 'V']
class_correct, class_total = _per_class_counts(ecg_net, testloader, len(classes))
for name, n_correct, n_total in zip(classes, class_correct, class_total):
    if n_total:
        print('Accuracy of %5s : %2d %%' % (name, 100 * n_correct / n_total))
    else:
        # Avoid ZeroDivisionError when a class is absent from the test set.
        print('Accuracy of %5s : n/a (no test samples)' % name)

# Save the aggregated model.
PATH = './ecg_fd.pth'
torch.save(ecg_net.state_dict(), PATH)

end_time = time.time()  # store end time
print("WorkingTime: {} sec".format(end_time - start_time))
#     sys.exit(0)

train_acc: 98.14%, train_loss: 0.9237
test_acc: 97.35%, test_loss: 0.9317
Accuracy of     N : 97 %
Accuracy of     L : 99 %
Accuracy of     R : 99 %
Accuracy of     A : 82 %
Accuracy of     V : 98 %
WorkingTime: 365.4895222187042 sec
