In [1]:
import socket
import pickle
import struct
import copy
from tqdm import tqdm

from sklearn.preprocessing import MinMaxScaler
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

In [3]:
device = "cpu"

torch.manual_seed(777)
if device =="cuda:0":
    torch.cuda.manual_seed_all(777)

In [4]:
with (open("recsys_data/data.p", "rb")) as openfile:
    train_data = pickle.load(openfile)
with (open("recsys_data/lab_ratings.p", "rb")) as openfile:
    lab_ratings = pickle.load(openfile)
with (open("recsys_data/fb_data_train.p", "rb")) as openfile:
    fb_data_train = pickle.load(openfile)
with (open("recsys_data/fb_labels_train.p", "rb")) as openfile:
    fb_labels_train = pickle.load(openfile)
with (open("recsys_data/fb_data_test.p", 'rb')) as openfile:
    fb_data_test = pickle.load(openfile)
with (open("recsys_data/fb_labels_test.p", "rb")) as openfile:
    fb_labels_test = pickle.load(openfile)

In [5]:
EMBED_FILE = 'recsys_data/processed_full.embed'

In [6]:
def load_embeddings(embed_file):
    #also normalizes the embeddings
    W = []
    with open(embed_file) as ef:
        for line in ef:
            line = line.rstrip().split()
            vec = np.array(line[1:]).astype(np.float)
            vec = vec / float(np.linalg.norm(vec) + 1e-6)
            W.append(vec)
        #UNK embedding, gaussian randomly initialized 
        print("adding unk embedding")
        vec = np.random.randn(len(W[-1]))
        vec = vec / float(np.linalg.norm(vec) + 1e-6)
        W.append(vec)
    W = np.array(W)
    return W

In [7]:
embedding_matrix = load_embeddings(EMBED_FILE)
# embedding_matrix.shape()

adding unk embedding


In [8]:
import pandas as pd
vocab = pd.read_csv('recsys_data/vocab.csv', header=None)
len(set(sorted(vocab[0].tolist())))

51917

In [9]:
def load_vocab_dict(vocab_file):
    vocab_df = pd.read_csv(vocab_file, header=None)
    vocab = sorted(set(vocab_df[0].tolist()))
    ind2w = {i+1:w for i,w in enumerate(vocab)}
    w2ind = {w:i for i,w in ind2w.items()}
    return ind2w, w2ind

In [10]:
idx2w, w2idx = load_vocab_dict('recsys_data/vocab.csv')

In [11]:
def clean_text(text):
    s = text.replace('[', "")
    s = s.replace(']', "")
    s = s.replace("'", "")
    s = s.replace(",", "")
    s = s.split()
    return s

def encoding_disease(data, w2idx):
    idx_total = []
    for i in range(len(data['age'])):
        text = data['disease'][i]
        cleaned_text = clean_text(text)
        idx = []
        for st in cleaned_text:
            if st not in w2idx:
                idx.append(len(w2idx)+1)
            else:
                idx.append(w2idx[st])
        idx_total.append(idx)
    return np.array(idx_total)

In [12]:
def pad_sequences(sequences, max_seq_len: int = 0):
    max_seq_len = max(max_seq_len, max(len(sequence) for sequence in sequences))
    # Pad
    padded_sequences = np.zeros((len(sequences), max_seq_len))
    for i, sequence in enumerate(sequences):
        padded_sequences[i][: len(sequence)] = sequence
    return padded_sequences


In [13]:
train_encoded_disease = encoding_disease(train_data, w2idx)
train_padded_encoded_disease = pad_sequences(train_encoded_disease)
fb_train_encoded_disease = encoding_disease(fb_data_train, w2idx)
fb_train_padded_encoded_disease = pad_sequences(fb_train_encoded_disease)
fb_test_encoded_disease = encoding_disease(fb_data_test, w2idx)
fb_test_padded_encoded_disease = pad_sequences(fb_test_encoded_disease)

train_data['encoded_disease'] = train_padded_encoded_disease
fb_data_train['encoded_disease'] = fb_train_padded_encoded_disease
fb_data_test['encoded_disease'] = fb_test_padded_encoded_disease

  return np.array(idx_total)


In [14]:
EPOCHS = 50
users = 2

In [15]:
train_data['age'] = train_data['age'].reshape((-1, 1))
train_data['weight'] = train_data['weight'].reshape((-1, 1))

In [16]:
scaler_age = MinMaxScaler()
scaler_weight = MinMaxScaler()
train_data['age'] = scaler_age.fit_transform(train_data['age'])
train_data['weight'] = scaler_weight.fit_transform(train_data['weight'])

In [17]:
class InitialDataset(Dataset):
    def __init__(self):
        self.data = train_data
        self.ratings = lab_ratings
    def __getitem__(self, index):
        age = self.data['age'][index]
        weight = self.data['weight'][index]
        icd_codes = self.data['codes'][index]
        disease = self.data['encoded_disease'][index]
        target = self.ratings[index]
        return {
            'age': torch.tensor(age, dtype=float),
            'weight': torch.tensor(weight, dtype=float),
            'disease': torch.tensor(disease, dtype=torch.long),
            'icd_codes': torch.tensor(icd_codes, dtype=float),
            'target': torch.tensor(target, dtype=float)
        }
    def __len__(self):
        return len(self.data['hadm_id'])

In [18]:
fb_data_train['age'] = fb_data_train['age'].reshape((-1, 1))
fb_data_train['weight'] = fb_data_train['weight'].reshape((-1, 1))

scaler_age = MinMaxScaler()
scaler_weight = MinMaxScaler()
fb_data_train['age'] = scaler_age.fit_transform(fb_data_train['age'])
fb_data_train['weight'] = scaler_weight.fit_transform(fb_data_train['weight'])

In [19]:
fb_data_test['age'] = fb_data_test['age'].reshape((-1, 1))
fb_data_test['weight'] = fb_data_test['weight'].reshape((-1, 1))

scaler_age = MinMaxScaler()
scaler_weight = MinMaxScaler()
fb_data_test['age'] = scaler_age.fit_transform(fb_data_test['age'])
fb_data_test['weight'] = scaler_weight.fit_transform(fb_data_test['weight'])

In [20]:
class FeedbackDataset(Dataset):
    def __init__(self, train=False):
        if train:
            self.data = fb_data_train
            self.ratings = fb_labels_train
        else:
            self.data = fb_data_test
            self.ratings = fb_labels_test
            
    def __getitem__(self, index):
        age = self.data['age'][index]
        weight = self.data['weight'][index]
        icd_codes = self.data['codes'][index]
        disease = self.data['encoded_disease'][index]
        target = self.ratings[index]
        return {
            'age': torch.tensor(age, dtype=float),
            'weight': torch.tensor(weight, dtype=float),
            'disease': torch.tensor(disease, dtype=torch.long),
            'icd_codes': torch.tensor(icd_codes, dtype=float),
            'target': torch.tensor(target, dtype=float)
        }
    def __len__(self):
        return len(self.data['age'])

In [21]:
initial_dataset = InitialDataset()
train_loader = DataLoader(initial_dataset, batch_size=64, shuffle=True)

fb_train_dataset = FeedbackDataset(train=True)
fb_train_loader = DataLoader(fb_train_dataset, batch_size=64, shuffle=True)

fb_test_dataset = FeedbackDataset(train=False)
fb_test_loader = DataLoader(fb_test_dataset, batch_size=64)

In [22]:
class RecSysServer(nn.Module):
    def __init__(self):
        super().__init__()
#             self.fc1 = nn.Linear(52, 1024)
#             self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 25)
    def forward(self, x):
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        return x
            

In [23]:
server = RecSysServer()
print(server)

RecSysServer(
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=25, bias=True)
)


In [24]:
class RecSysClient(nn.Module):
    def __init__(self):
        super().__init__()
#         W = torch.Tensor(load_embeddings(EMBED_FILE))
#         self.embed = nn.Embedding(W.size()[0], W.size()[1], padding_idx=0)
#         self.embed.weight.data = W.clone()
        W = torch.Tensor(embedding_matrix)
        self.embed = nn.Embedding(W.size()[0], W.size()[1], padding_idx=0)
        self.embed.weight.data = W.clone()
        
        self.fc1 = nn.Linear(152, 1024)
        self.fc2 = nn.Linear(1024, 512)
#         self.fc3 = nn.Linear(512, 256)
#         self.fc4 = nn.Linear(256, 25)
    def forward(self, age, weight, icd_codes, disease):
#         age = x['age']
#         weight = x['weight']
#         icd_codes = x['icd_codes']
        
        embedded = self.embed(disease)
        embedded = torch.mean(embedded, 1)
        
        x = torch.cat((embedded, icd_codes, age, weight), 1).float()
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x

In [25]:
client = RecSysClient()
print(client)

RecSysClient(
  (embed): Embedding(51919, 100, padding_idx=0)
  (fc1): Linear(in_features=152, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
)


In [26]:
criterion = nn.MSELoss()
lr = 0.001
optimizer_server = Adam(server.parameters(), lr=lr)

clientsoclist = []
train_total_batch = []
client_weights = copy.deepcopy(client.state_dict())

In [27]:
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

In [28]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    l_send = len(msg)
    msg = struct.pack('>I', l_send) + msg
    sock.sendall(msg)
    return l_send

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

In [29]:
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

10.10.7.64


In [30]:
s = socket.socket()
s.bind((host, port))
s.listen(5)

In [31]:
for i in range(users):
    conn, addr = s.accept()
    print('Conntected with', addr)
    clientsoclist.append(conn)    # append client socket on list

    datasize = send_msg(conn, EPOCHS)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

Conntected with ('10.10.7.64', 62701)
Conntected with ('10.10.7.64', 62702)


In [32]:
for e in range(EPOCHS):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer_server.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().float().to(device)

            output = server(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer_server.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        total_receivesize_list.append(datasize)
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
        

    client.load_state_dict(client_weights)
    client.to(device)
    client.eval()


    # train loss
    with torch.no_grad():
        train_loss = 0.0
        for j, trn in enumerate(train_loader):
            trn_age, trn_weight, trn_icd_codes, trn_disease= trn['age'], trn['weight'], trn['icd_codes'], trn['disease']
            trn_target = trn['target']
            
            trn_age = trn_age.to(device)
            trn_weight = trn_weight.to(device)
            trn_icd_codes = trn_icd_codes.to(device)
            trn_disease = trn_disease.to(device)
            trn_target = trn_target.to(device)

            trn_output = client(trn_age, trn_weight, trn_icd_codes, trn_disease)
            trn_output = server(trn_output)
            trn_target = trn_target.float()
            loss = criterion(trn_output, trn_target)
            train_loss += loss.item()

        r_train_loss = train_loss / len(train_loader)
        print("train_loss: {:.4f}".format(r_train_loss))

Epoch 1 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 23.59it/s]
Epoch 1 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 22.87it/s]
Epoch 2 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.7572


Epoch 2 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.12it/s]
Epoch 2 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.16it/s]
Epoch 3 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.6990


Epoch 3 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.82it/s]
Epoch 3 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.94it/s]
Epoch 4 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.6446


Epoch 4 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.96it/s]
Epoch 4 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.22it/s]
Epoch 5 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.5273


Epoch 5 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.30it/s]
Epoch 5 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.69it/s]
Epoch 6 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.4560


Epoch 6 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.14it/s]
Epoch 6 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.00it/s]
Epoch 7 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.3689


Epoch 7 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.58it/s]
Epoch 7 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.24it/s]
Epoch 8 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.2507


Epoch 8 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 24.65it/s]
Epoch 8 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.13it/s]
Epoch 9 Client0 :   0%|                                                     | 0/181 [00:00<?, ?it/s]

train_loss: 5.1312


Epoch 9 Client0 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 25.47it/s]
Epoch 9 Client1 : 100%|███████████████████████████████████████████| 181/181 [00:07<00:00, 23.92it/s]
Epoch 10 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.9576


Epoch 10 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.18it/s]
Epoch 10 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.94it/s]
Epoch 11 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.7966


Epoch 11 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.29it/s]
Epoch 11 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.41it/s]
Epoch 12 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.6468


Epoch 12 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.05it/s]
Epoch 12 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.17it/s]
Epoch 13 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.5568


Epoch 13 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.83it/s]
Epoch 13 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.76it/s]
Epoch 14 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.3445


Epoch 14 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.19it/s]
Epoch 14 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.71it/s]
Epoch 15 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.1541


Epoch 15 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.36it/s]
Epoch 15 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.90it/s]
Epoch 16 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 4.0283


Epoch 16 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.83it/s]
Epoch 16 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.35it/s]
Epoch 17 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 3.8022


Epoch 17 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.97it/s]
Epoch 17 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.21it/s]
Epoch 18 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 3.6182


Epoch 18 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.67it/s]
Epoch 18 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.00it/s]
Epoch 19 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 3.4554


Epoch 19 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.42it/s]
Epoch 19 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.58it/s]
Epoch 20 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 3.2509


Epoch 20 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.90it/s]
Epoch 20 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.61it/s]
Epoch 21 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 3.0561


Epoch 21 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.51it/s]
Epoch 21 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.26it/s]
Epoch 22 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 2.9401


Epoch 22 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.37it/s]
Epoch 22 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.69it/s]
Epoch 23 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 2.7077


Epoch 23 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.75it/s]
Epoch 23 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.77it/s]
Epoch 24 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.74it/s]

train_loss: 2.5448


Epoch 24 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.29it/s]
Epoch 24 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.65it/s]
Epoch 25 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.42it/s]

train_loss: 2.4047


Epoch 25 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.39it/s]
Epoch 25 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.80it/s]
Epoch 26 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 2.2260


Epoch 26 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.92it/s]
Epoch 26 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.77it/s]
Epoch 27 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 2.1528


Epoch 27 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.88it/s]
Epoch 27 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.16it/s]
Epoch 28 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 2.0204


Epoch 28 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.68it/s]
Epoch 28 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.75it/s]
Epoch 29 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.8851


Epoch 29 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.89it/s]
Epoch 29 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.88it/s]
Epoch 30 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.8175


Epoch 30 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.84it/s]
Epoch 30 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.08it/s]
Epoch 31 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.6323


Epoch 31 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.66it/s]
Epoch 31 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 23.79it/s]
Epoch 32 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.6301


Epoch 32 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.25it/s]
Epoch 32 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.35it/s]
Epoch 33 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.5202


Epoch 33 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.23it/s]
Epoch 33 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.36it/s]
Epoch 34 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.4196


Epoch 34 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.21it/s]
Epoch 34 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.06it/s]
Epoch 35 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 1.3729


Epoch 35 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.65it/s]
Epoch 35 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.95it/s]
Epoch 36 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.83it/s]

train_loss: 1.2933


Epoch 36 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.87it/s]
Epoch 36 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.85it/s]
Epoch 37 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.73it/s]

train_loss: 1.2151


Epoch 37 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.56it/s]
Epoch 37 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.09it/s]
Epoch 38 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 17.24it/s]

train_loss: 1.1627


Epoch 38 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.02it/s]
Epoch 38 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.12it/s]
Epoch 39 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.65it/s]

train_loss: 1.0938


Epoch 39 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.48it/s]
Epoch 39 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.46it/s]
Epoch 40 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.29it/s]

train_loss: 1.0624


Epoch 40 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.40it/s]
Epoch 40 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.86it/s]
Epoch 41 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.63it/s]

train_loss: 1.0550


Epoch 41 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.93it/s]
Epoch 41 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.60it/s]
Epoch 42 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.65it/s]

train_loss: 0.9816


Epoch 42 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.76it/s]
Epoch 42 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.75it/s]
Epoch 43 Client0 :   1%|▍                                           | 2/181 [00:00<00:10, 16.65it/s]

train_loss: 0.9720


Epoch 43 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 25.89it/s]
Epoch 43 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.52it/s]
Epoch 44 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.9201


Epoch 44 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.21it/s]
Epoch 44 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.26it/s]
Epoch 45 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.8577


Epoch 45 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.37it/s]
Epoch 45 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.05it/s]
Epoch 46 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.8609


Epoch 46 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.31it/s]
Epoch 46 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.58it/s]
Epoch 47 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.8092


Epoch 47 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.40it/s]
Epoch 47 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.71it/s]
Epoch 48 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.7574


Epoch 48 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 24.65it/s]
Epoch 48 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.16it/s]
Epoch 49 Client0 :   1%|▍                                           | 2/181 [00:00<00:11, 16.24it/s]

train_loss: 0.7425


Epoch 49 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.34it/s]
Epoch 49 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.00it/s]
Epoch 50 Client0 :   0%|                                                    | 0/181 [00:00<?, ?it/s]

train_loss: 0.7332


Epoch 50 Client0 : 100%|██████████████████████████████████████████| 181/181 [00:07<00:00, 25.06it/s]
Epoch 50 Client1 : 100%|██████████████████████████████████████████| 181/181 [00:06<00:00, 26.02it/s]


train_loss: 0.6945


In [33]:
# host = socket.gethostbyname(socket.gethostname())
# port = 10080
# print(host)

## Train Feedback dataset

In [34]:
train_total_batch = []

total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

In [35]:
for i in range(users):

    datasize = send_msg(clientsoclist[i], EPOCHS)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(clientsoclist[i])    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

In [36]:
for e in range(EPOCHS):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer_server.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().float().to(device)

            output = server(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer_server.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        total_receivesize_list.append(datasize)
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
        

    client.load_state_dict(client_weights)
    client.to(device)
    client.eval()


    # train acc
    with torch.no_grad():
        train_loss = 0.0
        for j, trn in enumerate(fb_train_loader):
            trn_age, trn_weight, trn_icd_codes, trn_disease = trn['age'], trn['weight'], trn['icd_codes'], trn['disease']
            trn_target = trn['target']
            
            trn_age = trn_age.to(device)
            trn_weight = trn_weight.to(device)
            trn_icd_codes = trn_icd_codes.to(device)
            trn_disease = trn_disease.to(device)
            
            trn_target = trn_target.to(device)

            trn_output = client(trn_age, trn_weight, trn_icd_codes, trn_disease)
            trn_output = server(trn_output)
            trn_target = trn_target.float()
            loss = criterion(trn_output, trn_target)
            train_loss += loss.item()

        r_train_loss = train_loss / len(train_loader)
        print("train_loss: {:.4f}".format(r_train_loss))
    
    # test acc
    with torch.no_grad():
        val_loss = 0.0
        for j, val in enumerate(fb_test_loader):
            val_age, val_weight, val_icd_codes, val_disease = val['age'], val['weight'], val['icd_codes'], val['disease']
            val_target = val['target']
            
            val_age = val_age.to(device)
            val_weight = val_weight.to(device)
            val_icd_codes = val_icd_codes.to(device)
            val_disease = val_disease.to(device)
            
            val_target = val_target.to(device)

            val_output = client(val_age, val_weight, val_icd_codes, val_disease)
            val_output = server(val_output)
        
            val_label = val_target.float()
            loss = criterion(val_output, val_label)
            val_loss += loss.item()
            
        test_loss = val_loss / len(fb_test_loader)
        print("test_loss: {:.4f}".format(test_loss))

Epoch 1 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.74it/s]
Epoch 1 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.92it/s]


train_loss: 3.7204


Epoch 2 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 4.7099


Epoch 2 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.25it/s]
Epoch 2 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.24it/s]


train_loss: 2.8704


Epoch 3 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 3.7330


Epoch 3 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.60it/s]
Epoch 3 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.96it/s]


train_loss: 2.4114


Epoch 4 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 3.2287


Epoch 4 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.71it/s]
Epoch 4 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.38it/s]


train_loss: 2.1356


Epoch 5 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 2.9481


Epoch 5 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 24.79it/s]
Epoch 5 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 24.62it/s]


train_loss: 1.8729


Epoch 6 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 2.6487


Epoch 6 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.24it/s]
Epoch 6 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.45it/s]


train_loss: 1.5278


Epoch 7 Client0 :   3%|█▎                                            | 2/73 [00:00<00:04, 15.99it/s]

test_loss: 2.2627


Epoch 7 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.40it/s]
Epoch 7 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.53it/s]


train_loss: 1.3156


Epoch 8 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 2.0278


Epoch 8 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:03<00:00, 24.17it/s]
Epoch 8 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.11it/s]


train_loss: 1.1231


Epoch 9 Client0 :   0%|                                                      | 0/73 [00:00<?, ?it/s]

test_loss: 1.8122


Epoch 9 Client0 : 100%|█████████████████████████████████████████████| 73/73 [00:02<00:00, 25.28it/s]
Epoch 9 Client1 : 100%|█████████████████████████████████████████████| 73/73 [00:03<00:00, 23.46it/s]


train_loss: 0.9576


Epoch 10 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 1.6268


Epoch 10 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.51it/s]
Epoch 10 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.59it/s]


train_loss: 0.8018


Epoch 11 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 1.4456


Epoch 11 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.54it/s]
Epoch 11 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.61it/s]


train_loss: 0.6611


Epoch 12 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 1.2893


Epoch 12 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 22.73it/s]
Epoch 12 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.44it/s]


train_loss: 0.5637


Epoch 13 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 1.1852


Epoch 13 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.13it/s]
Epoch 13 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.61it/s]


train_loss: 0.4880


Epoch 14 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 1.0983


Epoch 14 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 24.79it/s]
Epoch 14 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.57it/s]


train_loss: 0.4016


Epoch 15 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 16.80it/s]

test_loss: 0.9991


Epoch 15 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.26it/s]
Epoch 15 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.65it/s]


train_loss: 0.3213


Epoch 16 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 16.68it/s]

test_loss: 0.8972


Epoch 16 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 24.92it/s]
Epoch 16 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.02it/s]


train_loss: 0.3002


Epoch 17 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 0.8750


Epoch 17 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 24.28it/s]
Epoch 17 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 21.98it/s]


train_loss: 0.2765


Epoch 18 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 0.8547


Epoch 18 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.42it/s]
Epoch 18 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.16it/s]


train_loss: 0.2505


Epoch 19 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.46it/s]

test_loss: 0.8198


Epoch 19 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.27it/s]
Epoch 19 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.42it/s]


train_loss: 0.2202


Epoch 20 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 0.7777


Epoch 20 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 24.09it/s]
Epoch 20 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.71it/s]


train_loss: 0.2130


Epoch 21 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.01it/s]

test_loss: 0.7794


Epoch 21 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.68it/s]
Epoch 21 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.63it/s]


train_loss: 0.2005


Epoch 22 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.23it/s]

test_loss: 0.7628


Epoch 22 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.57it/s]
Epoch 22 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.88it/s]


train_loss: 0.1835


Epoch 23 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.13it/s]

test_loss: 0.7386


Epoch 23 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.40it/s]
Epoch 23 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.37it/s]


train_loss: 0.1750


Epoch 24 Client0 :   0%|                                                     | 0/73 [00:00<?, ?it/s]

test_loss: 0.7233


Epoch 24 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.02it/s]
Epoch 24 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.76it/s]


train_loss: 0.1701


Epoch 25 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.24it/s]

test_loss: 0.7211


Epoch 25 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 23.20it/s]
Epoch 25 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 21.76it/s]


train_loss: 0.1603


Epoch 26 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.52it/s]

test_loss: 0.7029


Epoch 26 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.49it/s]
Epoch 26 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.21it/s]


train_loss: 0.1526


Epoch 27 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.09it/s]

test_loss: 0.6874


Epoch 27 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.82it/s]
Epoch 27 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.08it/s]


train_loss: 0.1483


Epoch 28 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.29it/s]

test_loss: 0.6828


Epoch 28 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.62it/s]
Epoch 28 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.94it/s]


train_loss: 0.1445


Epoch 29 Client0 :   3%|█▏                                           | 2/73 [00:00<00:03, 17.81it/s]

test_loss: 0.6731


Epoch 29 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.93it/s]
Epoch 29 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.01it/s]


train_loss: 0.1427


Epoch 30 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.70it/s]

test_loss: 0.6750


Epoch 30 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.83it/s]
Epoch 30 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.92it/s]


train_loss: 0.1402


Epoch 31 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.42it/s]

test_loss: 0.6686


Epoch 31 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.05it/s]
Epoch 31 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.11it/s]


train_loss: 0.1397


Epoch 32 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.14it/s]

test_loss: 0.6691


Epoch 32 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.09it/s]
Epoch 32 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.21it/s]


train_loss: 0.1374


Epoch 33 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.39it/s]

test_loss: 0.6562


Epoch 33 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.71it/s]
Epoch 33 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.68it/s]


train_loss: 0.1331


Epoch 34 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 16.65it/s]

test_loss: 0.6594


Epoch 34 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.20it/s]
Epoch 34 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.00it/s]


train_loss: 0.1316


Epoch 35 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.36it/s]

test_loss: 0.6575


Epoch 35 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.96it/s]
Epoch 35 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.22it/s]


train_loss: 0.1328


Epoch 36 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.55it/s]

test_loss: 0.6526


Epoch 36 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.74it/s]
Epoch 36 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.82it/s]


train_loss: 0.1320


Epoch 37 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.64it/s]

test_loss: 0.6492


Epoch 37 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 24.31it/s]
Epoch 37 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.91it/s]


train_loss: 0.1316


Epoch 38 Client0 :   3%|█▏                                           | 2/73 [00:00<00:03, 17.82it/s]

test_loss: 0.6482


Epoch 38 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.95it/s]
Epoch 38 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:03<00:00, 24.24it/s]


train_loss: 0.1315


Epoch 39 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.71it/s]

test_loss: 0.6491


Epoch 39 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.04it/s]
Epoch 39 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.70it/s]


train_loss: 0.1329


Epoch 40 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.48it/s]

test_loss: 0.6493


Epoch 40 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.25it/s]
Epoch 40 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.59it/s]


train_loss: 0.1314


Epoch 41 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.36it/s]

test_loss: 0.6489


Epoch 41 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.59it/s]
Epoch 41 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.11it/s]


train_loss: 0.1274


Epoch 42 Client0 :   3%|█▏                                           | 2/73 [00:00<00:03, 17.85it/s]

test_loss: 0.6418


Epoch 42 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.62it/s]
Epoch 42 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.06it/s]


train_loss: 0.1334


Epoch 43 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.33it/s]

test_loss: 0.6517


Epoch 43 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.03it/s]
Epoch 43 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.13it/s]


train_loss: 0.1296


Epoch 44 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.50it/s]

test_loss: 0.6564


Epoch 44 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.98it/s]
Epoch 44 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.31it/s]


train_loss: 0.1288


Epoch 45 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 16.89it/s]

test_loss: 0.6517


Epoch 45 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.95it/s]
Epoch 45 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.28it/s]


train_loss: 0.1272


Epoch 46 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.41it/s]

test_loss: 0.6434


Epoch 46 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.58it/s]
Epoch 46 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.01it/s]


train_loss: 0.1309


Epoch 47 Client0 :   3%|█▏                                           | 2/73 [00:00<00:03, 17.89it/s]

test_loss: 0.6451


Epoch 47 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.51it/s]
Epoch 47 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.59it/s]


train_loss: 0.1294


Epoch 48 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.32it/s]

test_loss: 0.6387


Epoch 48 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.83it/s]
Epoch 48 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.76it/s]


train_loss: 0.1308


Epoch 49 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.71it/s]

test_loss: 0.6416


Epoch 49 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.12it/s]
Epoch 49 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.12it/s]


train_loss: 0.1254


Epoch 50 Client0 :   3%|█▏                                           | 2/73 [00:00<00:04, 17.71it/s]

test_loss: 0.6366


Epoch 50 Client0 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 26.85it/s]
Epoch 50 Client1 : 100%|████████████████████████████████████████████| 73/73 [00:02<00:00, 25.93it/s]


train_loss: 0.1274
test_loss: 0.6309


In [37]:
def precision_at_k(y_true, y_pred, k):
    topk_true = np.flip(np.argsort(y_true), 1)[:, :k]
    topk_pred = np.flip(np.argsort(y_pred), 1)[:, :k]

    n_relevant = 0
    n_recommend = 0

    for t, p in zip(topk_true, topk_pred):
        # print(f"t:{t}")
        # print(f"p:{p}")
        n_relevant += len(np.intersect1d(t, p))
        # print(f"rev:{n_relevant}")
        n_recommend += len(p)
        # print(n_recommend)

    return float(n_relevant) / n_recommend

In [38]:
with torch.no_grad():
        val_loss = 0.0
        pred = torch.empty((0, 25))
        for j, val in enumerate(fb_test_loader):
            val_age, val_weight, val_icd_codes, val_disease = val['age'], val['weight'], val['icd_codes'], val['disease']
            val_target = val['target']
            
            val_age = val_age.to(device)
            val_weight = val_weight.to(device)
            val_icd_codes = val_icd_codes.to(device)
            val_disease = val_disease.to(device)
            
            val_target = val_target.to(device)

            val_output = client(val_age, val_weight, val_icd_codes, val_disease)
            val_output = server(val_output)
        
            val_label = val_target.float()
            loss = criterion(val_output, val_label)
            val_loss += loss.item()
            pred = torch.cat((pred, val_output))
            
        test_loss = val_loss / len(fb_test_loader)
        print("test_loss: {:.4f}".format(test_loss))

test_loss: 0.6309


In [39]:
pred = pred.numpy()
print(f"After training feedback dataset Precision@5: {precision_at_k(fb_labels_test, pred, 5)}")

After training feedback dataset Precision@5: 0.8190189328743546
