In [None]:
import random
import numpy as np
from torch import nn
import torch
from Model import GATT
from loaders_FL import *
from torch_geometric.nn.conv import GCN2Conv
from utilsGHomo import *
import pandas as pd
import warnings
from collections import OrderedDict
import os
from MakeGraph import MakegraphH

warnings.filterwarnings('ignore')

# Seed every RNG used in this notebook (stdlib, NumPy, PyTorch) so runs are reproducible.
SEED = 210
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
## Get Data Loaders for CURIAL

DATA_DIR = './final_data/'

# Only the header row is needed to select the blood-test feature columns,
# so read zero data rows instead of loading the whole UHB table.
uhb_columns = pd.read_csv(f'{DATA_DIR}UHB.csv', nrows=0).columns
columns2 = [col for col in uhb_columns if 'Blood_Test' in col]
cols = columns2

# One entry per client site; Loaders[c] holds the loaders for files[c].
files = ['BH', 'OUH', 'PUH', 'UHB']
Loaders = get_loaders_structured(cols, files, path=DATA_DIR, batch=64, unstruct=1)

In [None]:
device = get_device()
print(f'Device: {device}')

# Raw feature count of the first sample of client 0's training set.
n_features = Loaders[0][0].dataset[0][0].shape[0]

hidden_dim = 128
in_dim = 100  ### fixed dimension across clients after augmentation

# Training pads each client's batches up to in_dim columns, so no client may
# have more raw features than in_dim; fail here with a clear message instead
# of a negative-size error deep inside training.
client_n_features = [Loaders[c][0].dataset[0][0].shape[0] for c in range(len(files))]
assert max(client_n_features) <= in_dim, (
    f'in_dim={in_dim} is smaller than the widest client input '
    f'({max(client_n_features)} features)'
)

global_model = GATT(in_dim, hidden_dim, device)
global_model.to(device)
# Server-side optimizer: it is stepped with the aggregated client deltas.
optimizer = torch.optim.Adam(params=global_model.parameters(), lr=0.001)

print(global_model)

In [None]:


#simulating client-side operation
def _augment_features(inputs, target_dim):
    '''Pad a 2-D batch (columns = features) to exactly ``target_dim`` columns.

    The extra columns are randomly chosen (with replacement) columns of a copy
    of ``inputs`` z-scored with a single mean/std over the whole batch. They are
    appended after the raw ``inputs`` columns.

    Raises:
        ValueError: if ``inputs`` already has more than ``target_dim`` columns.
    '''
    n_cols = inputs.shape[1]
    if n_cols > target_dim:
        raise ValueError(f'batch has {n_cols} features, more than in_dim={target_dim}')
    normalized = (inputs - inputs.mean()) / inputs.std()
    indices = torch.randint(0, normalized.size(1), (target_dim - n_cols,))
    # NOTE(review): raw `inputs` are concatenated with *normalised* padding
    # columns (unchanged from the original); confirm this mix of scales is intended.
    return torch.cat((inputs, normalized[:, indices]), dim=1)


def client_training(global_weights, Loaders, client_id, inner_epochs=25, device=device):
    ''' Train one client "locally", starting from the global weights sent by the server.

    The client's training loader is ``Loaders[client_id][0]``. Each batch except
    the last is processed as follows:
    - pad it to ``in_dim`` features (a module-level setting);
    - turn it into a graph with ``MakegraphHT``;
    - use it for ``inner_epochs`` consecutive optimisation steps of a fresh
      ``GATT(in_dim, hidden_dim, device)`` initialised from ``global_weights``.

    Args:
        global_weights: state_dict of the global model.
        Loaders: per-client loaders; ``Loaders[client_id][0]`` is the training loader.
        client_id: index of the client to train.
        inner_epochs: number of optimisation steps taken on each batch.
        device: torch device for the local model, graph and labels.

    Returns:
        (av_train_loss, delta_theta, net):
        - av_train_loss: mean BCE loss over all optimisation steps taken
          (0.0 if no step was taken).
        - delta_theta: OrderedDict of (initial - final) weights for every key
          of ``global_weights``.
        - net: the locally trained (personalised) model.
    '''
    train_loader = Loaders[client_id][0]
    n_batches = len(train_loader)

    # Build the local model and initialise it with the global weights.
    net = GATT(in_dim, hidden_dim, device)
    net.to(device)
    net.load_state_dict(global_weights)

    inner_opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
    # Private snapshot of the starting weights, so the delta stays correct even
    # if the caller's tensors are modified afterwards.
    inner_state = OrderedDict({k: tensor.detach().clone() for k, tensor in global_weights.items()})

    loss_fn = nn.BCELoss().to(device)

    train_loss = 0.0
    n_steps = 0
    net.train()

    for i, (inputs, label) in enumerate(train_loader):
        # The final batch is never trained on (kept from the original notebook,
        # presumably because it is a smaller remainder batch — confirm). Stop
        # before augmenting/graph-building it, since that work would be discarded.
        if i >= n_batches - 1:
            break

        features = _augment_features(inputs, in_dim)
        g = MakegraphHT(features).to(device)
        # Labels are identical across inner epochs: convert them once per batch.
        target = label.to(torch.float32).to(device)

        for _ in range(inner_epochs):
            inner_opt.zero_grad()
            logits = net(g.edge_index, g.x.float())  # note that we pass edge_index and x (node features)
            loss = loss_fn(logits, target)
            # NOTE(review): the forward pass is recomputed every step, so
            # retain_graph looks unnecessary; kept in case GATT caches graph tensors.
            loss.backward(retain_graph=True)
            inner_opt.step()

            train_loss += loss.detach().cpu().item()
            n_steps += 1

    final_state = net.state_dict()
    # delta = initial - final weights, which the server uses as a pseudo-gradient.
    delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in global_weights.keys()})

    # Average over the optimisation steps actually taken, not over loader length.
    av_train_loss = train_loss / n_steps if n_steps else 0.0
    return av_train_loss, delta_theta, net

In [None]:
models_dir = 'Ab'
# Model checkpoints are read from ./trained_models/<models_dir> and per-client
# metric CSVs are written to ./Dynamic/<models_dir>; create both if missing.
for base_dir in ('./trained_models', './Dynamic'):
    os.makedirs(f'{base_dir}/{models_dir}', exist_ok=True)


In [None]:
# One federated client per data file, so the count cannot drift from `files`.
nodes = len(files)

# Per-client metric dataframes.
DF = [pd.DataFrame(columns=['Train_Loss', 'Val_Loss', 'Val_AUC']) for _ in range(nodes)]

# Best validation AUC at each client.
Val_AUC = [0]*nodes

loss_fn = nn.BCELoss().to(device)

In [None]:

### simulating server side

num_rounds = 100
for i in range(num_rounds):
    print(f'------------------------------ STARTING TRAINING ROUND: {i+1} ... ---------------------------------')

    # Average client training loss for this round, one entry per client.
    TL = [0]*nodes

    # Put the global model in training mode; its current weights are sent to every client.
    global_model.train()
    optimizer.zero_grad()
    global_weights = global_model.state_dict()

    GRADS = [0]*nodes  # per-client (initial - final) weight deltas
    NETS = [0]*nodes   # per-client personalised models

    # Train each client locally, one after another, from the same global weights.
    for j in range(nodes):
        TL[j], GRADS[j], NETS[j] = client_training(global_weights, Loaders, j, inner_epochs=1)
        print(f'Node : {j}/{nodes} training complete...', end="\r")

    # Aggregate the client deltas into a single pseudo-gradient.
    grad = combine_grads(GRADS)

    # Install the aggregated delta as the gradient of each trainable global
    # parameter. Only parameters are updated here, not state_dict buffers.
    for name, par in global_model.named_parameters():
        if par.requires_grad:
            par.grad = grad[name]

    # Step the server optimizer (Adam, created with the global model) on the pseudo-gradient.
    optimizer.step()

    # Evaluate each client's personalised model NETS[k] (not the global model).
    for k in range(nodes):
        DF[k], Val_AUC[k] = evaluate_modelsT(k, Loaders, NETS[k], TL, loss_fn, device, DF[k], Val_AUC, models_dir)
        print(f'Node : {k} || Val AUC {Val_AUC[k]:.4f}')
        

In [None]:
# Evaluating client-side personalised models on each client's test loader (Loaders[i][2])
for i in range(nodes):
    # The checkpoint appears to be a whole pickled model (it is passed straight
    # to prediction_binaryT), presumably saved by evaluate_modelsT during training.
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which cannot
    # unpickle a full nn.Module, so opt out explicitly. Only do this for trusted,
    # locally produced files. map_location lets GPU-saved models load on `device`.
    model = torch.load(
        f'./trained_models/{models_dir}/node{i}',
        map_location=device,
        weights_only=False,
    )

    _, test_auc = prediction_binaryT(model, Loaders[i][2], loss_fn, device)
    DF[i].to_csv(f'./Dynamic/{models_dir}/node{i}.csv')

    print(f'Node : {i} || Val AUC {Val_AUC[i]:.4f} || Test AUC {test_auc:.4f}')