In [1]:
import numpy as np
from torch import nn
import torch
from Model import *
from loaders_FL import *
from utils import *
import pandas as pd
import warnings
from collections import OrderedDict
import os
import copy
import pickle

# Reproducibility: silence library warnings and seed every RNG this notebook
# touches. `random` is imported explicitly because nothing above binds it by
# name; without this the seed call only works if one of the star imports
# happens to re-export the stdlib module.
import random

SEED = 200

warnings.filterwarnings('ignore')
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Only the header of the UHB file is needed: it defines the blood-test feature
# list shared by every client, so read zero data rows instead of the whole CSV.
# (The data files themselves are not included with this notebook.)
uhb_header = pd.read_csv('./final_data/UHB.csv', nrows=0)
# Blood-test column names, in file order; every client uses this same list so
# the features line up across clients.
cols = [col for col in uhb_header.columns if 'Blood_Test' in col]
files = ['BH', 'OUH', 'PUH', 'UHB']  # client names, one data file per client
# unstruct=1 adds vital signs to some of the clients along with the blood tests.
Loaders = get_loaders_structured(cols, files, path='./final_data/', batch=64, unstruct=1)

(1865, 21)
------------
(161955, 28)
------------
(38717, 21)
------------
(95236, 28)
------------


In [3]:
device = get_device()
print(f'Device: {device}')

# --- Global (server-side) prediction model ---
latent_dim = 32   # output width of the filtering module / size of the knowledge vector
hidden = 128      # hidden units in the prediction head
dnn = DNN(latent_dim, hidden, device)

# COVID_Pred wraps the prediction head (dnn) with a filtering module and a
# knowledge vector. The server never filters real inputs, so len(cols) only
# sizes a placeholder filtering module here.
global_model = COVID_Pred(len(cols), latent_dim, dnn, device).to(device)
print(global_model)

# Server-side optimiser; later cells feed it the aggregated client deltas as gradients.
optimizer = torch.optim.Adam(params=global_model.parameters(), lr=0.001)

Device: cuda:0
COVID_Pred(
  (cca): CCA(
    (mlp_x): Sequential(
      (0): Flatten()
      (1): Linear(in_features=32, out_features=32, bias=True)
    )
    (mlp_g): Sequential(
      (0): Flatten()
      (1): Linear(in_features=21, out_features=32, bias=True)
    )
    (relu): ReLU(inplace=True)
  )
  (model): DNN(
    (fc1): Linear(in_features=32, out_features=128, bias=True)
    (activation1): ReLU()
    (drop): Dropout(p=0.25, inplace=False)
    (fc2): Linear(in_features=128, out_features=1, bias=True)
    (activation2): Sigmoid()
  )
)


In [4]:

### Store a freshly initialised model state for every client on disk.
# Clients differ in input width, so each needs its own filtering (CCA) module.
# client_training later reloads these files and overwrites only the non-CCA
# weights with the server's.
models_dir = 'temp'
os.makedirs(f'./{models_dir}', exist_ok=True)

nodes = 4
# Input width per client, read from each client's training data the same way
# client_training sizes its network, so the stored state always matches it.
# (For the current data this gives [21, 28, 21, 28].)
A = [Loaders[i][0].dataset[0][0].shape[0] for i in range(nodes)]

for i in range(nodes):
    dnn = DNN(latent_dim, hidden, device)
    g_model = COVID_Pred(A[i], latent_dim, dnn, device)
    g_model.to(device)
    # Plain OrderedDict of the initial weights; client_training loads it back.
    state = OrderedDict({k: tensor.data for k, tensor in g_model.state_dict().items()})
    name = './temp/' + str(i) + '_order.pickle'
    with open(name, 'wb') as handle:
        pickle.dump(state, handle, protocol=pickle.HIGHEST_PROTOCOL)
    del g_model, dnn


In [5]:
### local training at each client

def client_training(global_weights, Loaders, client_id, inner_epochs=25, device=device):
    """Run one round of local training for a single client.

    The client rebuilds its network and restores its own previously saved
    state, including its private filtering (CCA) module. It replaces every
    non-CCA weight with the server's, trains locally, saves the new state,
    and returns a pseudo-gradient for the server.

    Args:
        global_weights: server model ``state_dict``. It supplies the shared
            (non-CCA) weights and the reference point for the deltas.
        Loaders: per-client loaders; ``Loaders[client_id][0]`` is used as the
            training loader.
        client_id: index of the client; also selects
            ``./temp/<client_id>_order.pickle``.
        inner_epochs: number of passes over the client's training loader.
        device: torch device for the local network and batches.

    Returns:
        (av_train_loss, delta_theta, net):
        av_train_loss -- mean batch loss over all local epochs
            (NaN if no batch was seen);
        delta_theta -- ``initial - final`` for every shared weight, plus
            all-zero, server-shaped entries for the CCA keys, so the server's
            placeholder filtering module receives no update;
        net -- the trained, personalised client model.
    """
    train_loader = Loaders[client_id][0]
    n_features = train_loader.dataset[0][0].shape[0]

    dnn = DNN(latent_dim, hidden, device)
    net = COVID_Pred(n_features, latent_dim, dnn, device)
    net.to(device)

    # Restore this client's own stored state; its CCA module never leaves the client.
    # NOTE: pickle.load can execute arbitrary code -- only load files this notebook wrote.
    name = './temp/' + str(client_id) + '_order.pickle'
    with open(name, 'rb') as handle:
        old = pickle.load(handle)

    # Split parameter names into shared (prediction model) and filtering-module
    # (CCA) keys. They are taken from the weights passed in, not from a
    # notebook-level variable, so the function has no hidden dependency.
    shared_keys = [key for key in global_weights if 'cca' not in key.lower()]
    cca_keys = [key for key in global_weights if 'cca' in key.lower()]

    # Overwrite only the shared weights with copies of the server's; the
    # client keeps its own filtering module.
    global_w = copy.deepcopy(global_weights)
    for key in shared_keys:
        old[key] = global_w[key]
    net.load_state_dict(old)

    inner_opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
    # Reference point for the deltas: detached views of the server tensors,
    # which local training does not modify.
    inner_state = OrderedDict({k: tensor.data for k, tensor in global_weights.items()})

    loss_fn = nn.BCELoss().to(device)

    train_loss = 0.0
    n_batches = 0
    net.train()  # nothing in the loop switches modes, so set it once
    for _ in range(inner_epochs):
        for inputs, label in train_loader:
            inner_opt.zero_grad()
            inputs = inputs.to(torch.float32).to(device)
            label = label.to(torch.float32).to(device)

            pred = net(inputs)[:, 0]
            loss = loss_fn(pred, label)
            loss.backward()
            inner_opt.step()

            train_loss += loss.detach().cpu().item()
            n_batches += 1

    # Persist the client's trained model state for the next round.
    final_state = net.state_dict()
    with open(name, 'wb') as handle:
        pickle.dump(final_state, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Pseudo-gradient for the shared weights: initial minus final.
    delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in shared_keys})

    # Zero (not the weights themselves) for the filtering module: client CCA
    # shapes can differ from the server's, and the server must not move its
    # placeholder CCA. Fresh tensors also avoid aliasing live model storage.
    for key in cca_keys:
        delta_theta[key] = torch.zeros_like(global_weights[key])

    # Average over every batch of every epoch, not just one epoch's batch count.
    av_train_loss = train_loss / n_batches if n_batches else float('nan')
    return av_train_loss, delta_theta, net

In [6]:
## Create the folders that hold the best models and the per-round training dynamics.
models_dir = 'unstruct'
for root in ('./trained_models', './Dynamic'):
    target = f'{root}/{models_dir}'
    # Only create when nothing exists at that path yet.
    if not os.path.exists(target):
        os.makedirs(target)

In [7]:
import time
start_time = time.time()

nodes = len(files)  # one federated client per data file

# Per-client training-dynamics tables and best validation scores so far.
DF = [pd.DataFrame(columns=['Train_Loss', 'Val_Loss', 'Val_AUC', 'Val_APR']) for _ in range(nodes)]
Val_AUC = [0] * nodes
Val_APR = [0] * nodes

loss_fn = nn.BCELoss().to(device)

#### Loop for "rounds" of client training
num_rounds = 5

for i in range(num_rounds):
    print(f'------------------------------ STARTING TRAINING ROUND: {i+1} ... ---------------------------------')

    TL = [0] * nodes     # average local training loss per client this round
    GRADS = [0] * nodes  # pseudo-gradients (weight deltas) returned by each client
    NETS = [0] * nodes   # personalised client models, kept only for evaluation

    # Snapshot the server weights that every client receives this round.
    global_model.train()
    optimizer.zero_grad()
    global_weights = global_model.state_dict()

    for j in range(nodes):
        TL[j], GRADS[j], NETS[j] = client_training(global_weights, Loaders, j, inner_epochs=2)
        print(f'Node : {j+1}/{nodes} training complete...', end="\r")

    # Aggregate the client deltas into one meta-gradient and apply it to the
    # server model through the server's Adam optimiser.
    grad = combine_grads(GRADS)
    for name, par in global_model.named_parameters():
        if par.requires_grad:
            par.grad = grad[name]
    optimizer.step()

    # Evaluate each client's personalised model (not the global model).
    # Per the printed labels, evaluate_models reports validation metrics and
    # returns the best AUC seen so far -- confirm against its definition.
    for k in range(nodes):
        DF[k], Val_AUC[k], cur, cur_apr = evaluate_models(k, Loaders, NETS[k], TL, loss_fn, device, DF[k], Val_AUC, Val_APR, models_dir)
        print(f'Node : {k:.1f} || Best Val AUC {Val_AUC[k]:.4f} || Current AUC {cur:.4f}|| Curr APR {cur_apr:.4f}')

    print(f'Round {i+1} finished, elapsed {time.time() - start_time:.1f}s')

------------------------------ STARTING TRAINING ROUND: 1 ... ---------------------------------
Node : 0.0 || Best Val AUC 0.4309 || Current AUC 0.4309|| Curr APR 0.1456
Node : 1.0 || Best Val AUC 0.8273 || Current AUC 0.8273|| Curr APR 0.1897
Node : 2.0 || Best Val AUC 0.7990 || Current AUC 0.7990|| Curr APR 0.2331
Node : 3.0 || Best Val AUC 0.8092 || Current AUC 0.8092|| Curr APR 0.0360
------------------------------ STARTING TRAINING ROUND: 2 ... ---------------------------------
Node : 0.0 || Best Val AUC 0.5974 || Current AUC 0.5974|| Curr APR 0.1846
Node : 1.0 || Best Val AUC 0.8525 || Current AUC 0.8525|| Curr APR 0.2415
Node : 2.0 || Best Val AUC 0.8201 || Current AUC 0.8201|| Curr APR 0.2634
Node : 3.0 || Best Val AUC 0.8102 || Current AUC 0.8102|| Curr APR 0.0358
------------------------------ STARTING TRAINING ROUND: 3 ... ---------------------------------
Node : 0.0 || Best Val AUC 0.5974 || Current AUC 0.4629|| Curr APR 0.1560
Node : 1.0 || Best Val AUC 0.8565 || Current A

In [8]:
# Score every client's saved best model on its held-out split, write that
# client's training dynamics to CSV, and report the best validation AUC next
# to the test metrics.
# NOTE(review): this assumes the checkpoint is a whole pickled module; on
# PyTorch >= 2.6 torch.load defaults to weights_only=True, which rejects that.
for i in range(nodes):
    checkpoint_path = f'./trained_models/{models_dir}/node{i}'
    model = torch.load(checkpoint_path)
    test_loader = Loaders[i][2]  # presumably the test loader -- confirm with get_loaders_structured
    _, test_auc, apr = prediction_binary(model, test_loader, loss_fn, device)
    DF[i].to_csv(f'./Dynamic/{models_dir}/node{i}.csv')

    print(f'Node : {i:.1f} || Val AUC {Val_AUC[i]:.4f} || Test AUC {test_auc:.4f}|| Test APR {apr:.4f}')

Node : 0.0 || Val AUC 0.6968 || Test AUC 0.7612|| Test APR 0.2679
Node : 1.0 || Val AUC 0.8608 || Test AUC 0.8504|| Test APR 0.3163
Node : 2.0 || Val AUC 0.8464 || Test AUC 0.8654|| Test APR 0.4137
Node : 3.0 || Val AUC 0.8396 || Test AUC 0.7963|| Test APR 0.0353
