In [None]:
import numpy as np
from torch import nn
import torch
from Model_IHM import *
from data_loaders import *
from utils import *
import pandas as pd
import warnings
from Client import *
from FL_Server import find_common_intersection,aggregate_models
import random

# Silence warnings for cleaner notebook output.
# NOTE(review): a blanket 'ignore' also hides deprecation / FutureWarnings from
# torch and pandas (e.g. API default changes); consider narrowing by category.
warnings.filterwarnings('ignore')

# Seed every RNG the experiment uses (Python, PyTorch, NumPy) from one constant,
# so the run is reproducible and the seed can be changed in a single place.
SEED = 200
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the per-site data and split it non-IID into one group of loaders per client.
# NOTE(review): later cells index Loaders[k][0] and Loaders[k][2], so each entry is
# presumably a (train, val, test) triple of DataLoaders — confirm in data_loaders.
num_clients = 5
Loaders = create_non_iid_loaders(
    './physio_data/Data.npy',
    './physio_data/Labels.npy',
    num_clients,
)

In [None]:
device = get_device()
print(f'Device: {device}')

# Input width for the model: axis 1 of the first sample of client 0's first loader.
# NOTE(review): assumes each sample is shaped (seq_len, n_features) — confirm.
n_features = Loaders[0][0].dataset[0][0].shape[1]
hidden = 64  # hidden width passed to LSTMClassifier

global_model = LSTMClassifier(n_features, hidden, device)
global_model.to(device)

# BCELoss expects probabilities, so the model presumably ends in a sigmoid — verify.
criterion = nn.BCELoss().to(device)
best = 0
print(global_model)

# Detached copy of the initial global weights; passed to client_update each round.
global_model_params = [weight.detach().clone() for weight in global_model.parameters()]


In [None]:
num_rounds = 10

# One metrics log per client; the training cell passes DF[k] to evaluate_models
# and stores the frame it returns.
DF = [
    pd.DataFrame(columns=['Train_Loss', 'Val_Loss', 'Val_AUC', 'Val_APR'])
    for _ in range(num_clients)
]

# Best validation AUC / APR seen so far, one slot per client.
# NOTE(review): Val_APR is never updated in the visible cells; only Val_AUC is.
Val_AUC = [0] * num_clients
Val_APR = [0] * num_clients

In [None]:
    # Federated training. Each round:
    #   1. every client runs client_update starting from the current global parameters;
    #   2. the server calls find_common_intersection on the clients' low-loss paths and
    #      losses, then aggregate_models builds the new global model from that point;
    #   3. local models are validated per client, and the global model is saved whenever
    #      its mean AUC across clients improves.
    import os
    from tqdm.auto import tqdm  # renders as a widget in notebooks, text elsewhere

    CHECKPOINT_DIR = './trained_models/FedMode'
    # torch.save does not create missing directories; ensure the target exists.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    best = 0
    for round_num in tqdm(range(num_rounds)):
        client_paths = []
        client_losses = []
        LOSS = []
        client_models = []

        # Local updates: client_update receives the current global parameters.
        for client_id in range(num_clients):
            print(f"Client {client_id + 1}/{num_clients}")
            low_loss_path, losses, TL, local_model = client_update(global_model_params, Loaders[client_id], criterion)
            LOSS.append(TL)
            client_paths.append(low_loss_path)
            client_losses.append(losses)
            client_models.append(local_model)

        # Server step: combine the client paths into a new global model and snapshot
        # its weights for the next round.
        intersection_point = find_common_intersection(client_paths, client_losses, global_model_params)
        global_model = aggregate_models(global_model, intersection_point)
        global_model_params = [param.data.clone() for param in global_model.parameters()]

        # Validate each client's local model; evaluate_models returns the best AUC so far.
        for k in range(num_clients):
            local_model = client_models[k]
            DF[k], Val_AUC[k], cur, cur_apr = evaluate_models(k, Loaders, local_model, criterion, device, DF[k], Val_AUC, LOSS[k], 'FedMode')
            print(f'Node : {k} || Train Loss {LOSS[k]:.3f} || Best Val AUC {Val_AUC[k]:.3f} || Current AUC {cur:.3f}|| Curr APR {cur_apr:.3f}')
        print('=======================================')

        # Model selection: keep the global model with the highest mean AUC over clients.
        # NOTE(review): confirm which split evaluate_models_test scores. If it is the test
        # split (Loaders[k][2], used for final reporting below), this leaks test data
        # into model selection.
        AUC = 0
        for k in range(num_clients):
            Vloss, cur, cur_apr = evaluate_models_test(k, Loaders, global_model, criterion, device)
            AUC = AUC + cur

        G = AUC / num_clients
        if G > best:
            torch.save(global_model, f'{CHECKPOINT_DIR}/global_model')
            best = G



In [None]:
from utils import *

# Evaluate each client's saved local model on that client's Loaders[k][2] split,
# then print the mean AUC and mean APR over clients.
sum_auc = 0
sum_apr = 0

for k in range(num_clients):
    # The checkpoint is used directly as a model, so it is a whole pickled nn.Module.
    # Since PyTorch 2.6 torch.load defaults to weights_only=True, which refuses such
    # objects; weights_only=False (available since torch 1.13) restores loading.
    # Unpickling can run arbitrary code, so only load checkpoints you produced.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    local_model = torch.load(f'./trained_models/FedMode/node{k}', map_location=device, weights_only=False)
    local_model.to(device)
    local_model.eval()  # inference mode for dropout / batch-norm layers, if any
    val_loss, val_auc, val_apr = prediction_binary(local_model, Loaders[k][2], criterion, device)
    sum_auc = sum_auc + val_auc
    sum_apr = sum_apr + val_apr
    print(f'Node : {k} || AUC {val_auc:.4f}|| APR {val_apr:.4f}')
    print('=======================================')

print(sum_auc / num_clients)
print(sum_apr / num_clients)

In [None]:
from utils import *

# Evaluate the best saved global model on every client's Loaders[k][2] split,
# then print the mean AUC and mean APR over clients.
sum_auc = 0
sum_apr = 0

# The training cell saves the whole nn.Module with torch.save, so weights_only=False
# is required on PyTorch >= 2.6 (default became True). Only load trusted checkpoints:
# unpickling can run arbitrary code. map_location + .to(device) ensure the model is
# on the current device even if the checkpoint was written on a different one.
global_model = torch.load('./trained_models/FedMode/global_model', map_location=device, weights_only=False)
global_model.to(device)
global_model.eval()  # inference mode for dropout / batch-norm layers, if any

for k in range(num_clients):
    val_loss, val_auc, val_apr = prediction_binary(global_model, Loaders[k][2], criterion, device)
    sum_auc = sum_auc + val_auc
    sum_apr = sum_apr + val_apr
    print(f'Node : {k} || AUC {val_auc:.4f}|| APR {val_apr:.4f}')
    print('=======================================')

print(sum_auc / num_clients)
print(sum_apr / num_clients)