<a href="https://colab.research.google.com/github/Priyanka-Sachan/Complaint-Identification-using-FL/blob/master/Code_with_FL_N_IID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf

# Ask TensorFlow for the name of the GPU device attached to this runtime.
device_name = tf.test.gpu_device_name()
# Abort the notebook early unless the reported device is exactly the first GPU.
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

In [None]:
!pip install transformers
!pip install sentencepiece
!pip install datasets

In [3]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from fastai.text import *
from google.colab import files

import transformers
transformers.logging.set_verbosity_error()
from transformers import AdamW
from transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from transformers import get_scheduler
from datasets import load_metric

from tqdm.auto import tqdm
import math
import random
import pandas as pd
import numpy as np

In [4]:
# For reproducibility: seed the `random`, NumPy and PyTorch generators with one shared value.
SEED=9
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once `counter` reaches `patience`
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss and checkpoint `model` when it improved.

        Args:
            val_loss (float): Validation loss of the epoch that just finished.
            model: Object exposing `state_dict()`; saved to `self.path` on improvement.
        """
        # Losses are negated so that "higher score" means "better".
        score = -val_loss

        if self.best_score is None:
            # First call: always counts as the best result so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No improvement of at least `delta`: count towards the patience limit.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of CUDA devices visible to PyTorch.
n_gpu = torch.cuda.device_count()
# Last expression of the cell: shows the name of CUDA device 0.
torch.cuda.get_device_name(0)

In [None]:
from google.colab import drive
# Mount Google Drive; other cells read and write model files under /content/gdrive/My Drive/Innovation_Lab.
drive.mount('/content/gdrive')

In [8]:
def create_dataloader(input_ids,masks,labels,batch_size=16,shuffle=False):
  """Wrap token ids, attention masks and labels in a batched DataLoader.

  Args:
    input_ids: Array-like of token-id rows (ndarray, or list of ndarrays/lists).
    masks: Array-like of attention-mask rows, aligned with `input_ids`.
    labels: Array-like of labels, aligned with `input_ids`.
    batch_size (int): Rows per batch. Default: 16
    shuffle (bool): If True draw rows with a RandomSampler, otherwise keep
      the given order (SequentialSampler). Default: False

  Returns:
    DataLoader yielding `(input_ids, masks, labels)` tensor batches.
  """
  def _to_tensor(values):
    # Go through one contiguous ndarray first: torch.tensor() on a Python
    # list of ndarrays converts element by element and is very slow.
    return torch.tensor(np.asarray(values))

  data = TensorDataset(_to_tensor(input_ids), _to_tensor(masks), _to_tensor(labels))
  sampler = RandomSampler(data) if shuffle else SequentialSampler(data)

  return DataLoader(data, sampler=sampler, batch_size=batch_size)

In [9]:
# To train the main model initially with Amazon reviews
def train_main_model(model,optimizer):
  """Train `model` on the first 10000 Amazon reviews and save it to Drive.

  Reviews rated above 2 get label 0, all others label 1. The data is split
  80/20 into train/validation sets, trained with train_and_validate_model,
  and the resulting model and optimizer state dicts are written to Drive.
  """
  data_dir = untar_data(URLs.AMAZON_REVIEWS, dest = "Data")
  reviews = pd.read_csv(data_dir/'train.csv', header=None, names=['rating', 'title', 'review'],  nrows=10000)

  # Ratings above 2 map to label 0, everything else to label 1.
  reviews['label'] = [0 if rating > 2 else 1 for rating in reviews.rating.values]

  tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)
  MAX_LEN=49
  encodings = [
    tokenizer(text, padding='max_length', truncation="only_first", max_length=MAX_LEN)
    for text in reviews.review.values
  ]

  input_ids = np.asarray([np.asarray(encoding['input_ids']) for encoding in encodings])
  attention_masks = np.asarray([np.asarray(encoding['attention_mask']) for encoding in encodings])
  labels = reviews.label.values

  # A joint split keeps ids, masks and labels row-aligned.
  splits = train_test_split(input_ids, attention_masks, labels, random_state=42, test_size=0.2)
  train_inputs, valid_inputs, train_masks, valid_masks, train_labels, valid_labels = splits

  train_dataloader = create_dataloader(train_inputs, train_masks, train_labels)
  validation_dataloader = create_dataloader(valid_inputs, valid_masks, valid_labels)

  # The returned metrics are not used here.
  train_and_validate_model(model, optimizer, train_dataloader, validation_dataloader)

  torch.save(model.state_dict(), F"/content/gdrive/My Drive/Innovation_Lab/main_model.pt" )
  torch.save(optimizer.state_dict(), F"/content/gdrive/My Drive/Innovation_Lab/main_optimizer.pt" )

In [10]:
# To extract features from complaints data
def get_features():
  """Upload and tokenize the complaints CSV.

  Opens the Colab upload prompt, then reads "complaints-data.csv" with the
  columns id, tweet, y and industry. Each industry name is mapped to a code
  0-8; any other value gets code 9.

  Returns:
    (input_ids, attention_masks, labels, industry_code) as numpy arrays,
    one row/entry per tweet.
  """
  files.upload()
  complaints = pd.read_csv("complaints-data.csv", header=None, names=['id', 'tweet', 'y', 'industry'])

  # Lookup table replacing the if/elif chain; unknown industries fall back to 9.
  industry_to_code = {
    'apparel': 0,
    'cars': 1,
    'electronics': 2,
    'food': 3,
    'retail': 4,
    'services': 5,
    'software': 6,
    'transport': 7,
    'other': 8,
  }
  complaints['industry_code'] = [industry_to_code.get(name, 9) for name in complaints.industry.values]

  tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)
  MAX_LEN=49
  encodings = [
    tokenizer(tweet, padding='max_length', truncation="only_first", max_length=MAX_LEN)
    for tweet in complaints.tweet.values
  ]

  input_ids = np.asarray([np.asarray(encoding['input_ids']) for encoding in encodings])
  attention_masks = np.asarray([np.asarray(encoding['attention_mask']) for encoding in encodings])

  return input_ids, attention_masks, complaints.y.values, complaints.industry_code.values

In [11]:
# Shuffling and dividing the data into groups (9 by default) based on industry
def split_and_shuffle_samples(industry_codes, seed, number_of_groups=9):
    """Group row positions by industry code and shuffle each group.

    Args:
        industry_codes: 1-D array-like with one integer industry code per row.
        seed (int): Seed for the shuffle; every group is shuffled with a
            generator freshly created from this seed.
        number_of_groups (int): Codes 0..number_of_groups-1 each get a group;
            rows with any other code are dropped. Default: 9

    Returns:
        dict mapping "sample<k>" to a DataFrame with columns "samples" (the
        code) and "i" (the row position in `industry_codes`).
    """
    samples=pd.DataFrame(industry_codes,columns=["samples"])
    samples["i"]=np.arange(len(samples))
    sample_dict = dict()

    for group in range(number_of_groups):
        group_rows=samples[samples["samples"]==group]
        # A local RandomState yields the same permutation as reseeding the
        # global generator would, without clobbering the notebook-wide seed.
        rng=np.random.RandomState(seed)
        shuffled=rng.permutation(group_rows.to_numpy())
        sample_dict["sample" + str(group)]=pd.DataFrame(shuffled, columns=["samples","i"])

    return sample_dict

In [12]:
# Distributes input ids, attention masks and labels to nodes in dictionary
def create_niid_subsamples(sample_dict, input_ids,attention_masks,labels):
    """Slice the full arrays into one subset per node.

    For node k the row positions stored in `sample_dict["sample<k>"]["i"]`
    are sorted and used to index `input_ids`, `attention_masks` and `labels`.

    Returns:
        Three dicts keyed "input_id<k>", "attention_mask<k>" and "label<k>".
    """
    input_ids_dict, attention_masks_dict, labels_dict = {}, {}, {}

    for node in range(len(sample_dict)):
        suffix = str(node)
        # Sorting restores the original row order inside the node's subset.
        row_positions = np.sort(np.array(sample_dict["sample" + suffix]["i"]))

        input_ids_dict["input_id" + suffix] = input_ids[row_positions]
        attention_masks_dict["attention_mask" + suffix] = attention_masks[row_positions]
        labels_dict["label" + suffix] = labels[row_positions]

    return input_ids_dict, attention_masks_dict, labels_dict

In [13]:
def train_valid_test_split(input_ids_dict, attention_masks_dict, labels_dict):
    """Split every node's data into train (80%), validation (10%) and test (10%).

    The number of nodes is taken from `labels_dict` itself rather than from a
    notebook-level global, and the input dicts are left untouched: new dicts
    are built and returned.

    Args:
        input_ids_dict: dict "input_id<k>" -> array of token-id rows.
        attention_masks_dict: dict "attention_mask<k>" -> array of mask rows.
        labels_dict: dict "label<k>" -> array of labels.

    Returns:
        Three dicts with the same keys; each value is a dict with the entries
        'all', 'train', 'valid' and 'test'.
    """
    split_input_ids, split_attention_masks, split_labels = dict(), dict(), dict()

    for i in range(len(labels_dict)):

        input_id_name= "input_id"+str(i)
        attention_mask_name="attention_mask"+str(i)
        label_name= "label"+str(i)

        input_ids=input_ids_dict[input_id_name]
        attention_masks=attention_masks_dict[attention_mask_name]
        labels=labels_dict[label_name]

        # First carve off 20% of the rows, then halve that part into validation and test.
        # Splitting the three arrays jointly keeps their rows aligned.
        train_inputs, val_test_inputs, train_masks, val_test_masks,train_labels, val_test_labels = train_test_split(input_ids,attention_masks, labels,random_state=42, test_size=0.2)
        validation_inputs, test_inputs,validation_masks, test_masks, validation_labels, test_labels = train_test_split(val_test_inputs,val_test_masks, val_test_labels,random_state=42, test_size=0.5)

        split_input_ids[input_id_name]={'all':input_ids, 'train':train_inputs,'valid':validation_inputs,'test':test_inputs}
        split_attention_masks[attention_mask_name]={'all':attention_masks,'train':train_masks,'valid':validation_masks,'test':test_masks}
        split_labels[label_name]={'all':labels,'train':train_labels,'valid':validation_labels,'test':test_labels}

    return split_input_ids, split_attention_masks, split_labels

In [14]:
# To train and validate local models
def train_and_validate_model(model,optimizer,train_dataloader,validation_dataloader,num_epochs=4,patience=3):
    """Fine-tune `model` and validate it after every epoch.

    Args:
        model: Sequence-classification model whose outputs expose `.loss` and `.logits`.
        optimizer: Optimizer over the model's parameters.
        train_dataloader: Batches of (input_ids, attention_mask, labels) for training.
        validation_dataloader: Batches of the same form for validation.
        num_epochs (int): Maximum number of epochs. Default: 4
        patience (int): Epochs without validation-loss improvement before
            training stops early. Default: 3

    Returns:
        (train_accuracy, train_loss, valid_accuracy, valid_loss). When training
        stops early, the losses of the previous epoch are reported.

    The model is moved to `device` for training and back to the CPU afterwards.
    """
    # NOTE(review): EarlyStopping saves the best weights to its checkpoint path,
    # but nothing reloads them - the model keeps the last epoch's weights.
    model.to(device)

    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )
    early_stopping = EarlyStopping(patience=patience, verbose=False)

    train_accuracy_metric=load_metric("accuracy")
    valid_accuracy_metric=load_metric("accuracy")

    train_loss,valid_loss=0,0
    pr_train_loss,pr_valid_loss=0,0
    # Defined up front so the return statement also works when no epoch runs.
    train_accuracy,valid_accuracy=0,0

    progress_bar = tqdm(range(num_training_steps))
    for epoch in range(num_epochs):

        train_losses = []
        valid_losses = []

        model.train()
        for batch in train_dataloader:
            # .to(device) returns the moved tensors; no further .cuda() calls are needed.
            b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)
            outputs = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            logits = outputs.logits
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            train_losses.append(loss.item())
            predictions = torch.argmax(logits, dim=-1)
            train_accuracy_metric.add_batch(predictions=predictions,references=b_labels)

            optimizer.zero_grad()
            progress_bar.update(1)

        model.eval()
        for batch in validation_dataloader:
            b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                outputs = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            logits = outputs.logits

            valid_losses.append(loss.item())
            predictions = torch.argmax(logits, dim=-1)
            valid_accuracy_metric.add_batch(predictions=predictions,references=b_labels)

        # Keep the previous epoch's losses so they can be reported on early stop.
        pr_train_loss=train_loss
        train_loss = np.average(train_losses)
        pr_valid_loss=valid_loss
        valid_loss = np.average(valid_losses)

        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            valid_loss=pr_valid_loss
            train_loss=pr_train_loss
            break

        train_accuracy=train_accuracy_metric.compute()['accuracy']
        valid_accuracy=valid_accuracy_metric.compute()['accuracy']

        print("EPOCH: {}".format(epoch+1),
              "| Train accuracy: {:7.5f}".format(train_accuracy),
              "| Train loss: {:7.5f}".format(train_loss),
              "| Validation accuracy: {:7.5f}".format(valid_accuracy),
              "| Validation loss: {:7.5f}".format(valid_loss))

    progress_bar.close()

    # Free GPU memory for the next client model.
    model.cpu()
    torch.cuda.empty_cache()

    return train_accuracy,train_loss,valid_accuracy,valid_loss

In [15]:
# Evaluate models
def test_model(model,test_dataloader):
    """Evaluate `model` on `test_dataloader`.

    Args:
        model: Sequence-classification model whose outputs expose `.loss` and `.logits`.
        test_dataloader: Batches of (input_ids, attention_mask, labels).

    Returns:
        (accuracy, mean batch loss, precision, recall, f1).

    The model is moved to `device` for evaluation and back to the CPU afterwards.
    """
    # Same device as the batches below (previously a hard-coded .cuda()).
    model.to(device)

    test_losses=[]
    metric_names=("accuracy","precision","recall","f1")
    metrics={name: load_metric(name) for name in metric_names}

    model.eval()
    for batch in test_dataloader:
        b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            outputs = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

        test_losses.append(outputs.loss.item())
        predictions = torch.argmax(outputs.logits, dim=-1)
        for metric in metrics.values():
            metric.add_batch(predictions=predictions,references=b_labels)

    test_loss = np.average(test_losses)
    scores={name: metrics[name].compute()[name] for name in metric_names}

    model.cpu()

    return scores["accuracy"],test_loss,scores["precision"],scores["recall"],scores["f1"]

In [16]:
# Creates a model and optimizer for each node
def create_model_optimizer_dict(number_of_samples):
    """Build one pretrained XLNet classifier and one AdamW optimizer per node.

    Args:
        number_of_samples (int): Number of nodes.

    Returns:
        (model_dict, optimizer_dict) keyed "model<k>" and "optimizer<k>".
    """
    model_dict, optimizer_dict = {}, {}

    for node in range(number_of_samples):
        node_model = XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path="xlnet-base-cased", num_labels=2)
        model_dict["model" + str(node)] = node_model
        optimizer_dict["optimizer" + str(node)] = AdamW(node_model.parameters(), lr=1e-5)

    return model_dict, optimizer_dict

In [17]:
# Load all trained local models
def load_all_models(number_of_samples):
    """Recreate every node's model from its saved client<k>.pt weights.

    Each node gets a fresh pretrained XLNet classifier and a new AdamW
    optimizer; the saved client weights are then loaded into the model.

    Args:
        number_of_samples (int): Number of nodes.

    Returns:
        (model_dict, optimizer_dict) keyed "model<k>" and "optimizer<k>".
    """
    model_dict, optimizer_dict = {}, {}

    for node in range(number_of_samples):
        node_model = XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path="xlnet-base-cased", num_labels=2)
        # The optimizer is created before the saved weights are loaded.
        node_optimizer = AdamW(node_model.parameters(), lr=1e-5)

        saved_weights = torch.load( F"/content/gdrive/My Drive/Innovation_Lab/NIID/client"+str(node)+".pt")
        node_model.load_state_dict(saved_weights)

        model_dict["model" + str(node)] = node_model
        optimizer_dict["optimizer" + str(node)] = node_optimizer

    return model_dict, optimizer_dict

In [18]:
# Trains individual local models in nodes
def start_train_end_node_process(clients_to_be_trained):
    """Train, validate, test and save the local model of every listed client.

    Reads the notebook-level model, optimizer and data dictionaries, prints
    train/validation and test metrics per client, and saves each trained
    model's state dict to Drive as client<i>.pt.

    Args:
        clients_to_be_trained: Iterable of client indices.
    """
    for i in clients_to_be_trained:

        model = model_dict[name_of_models[i]]
        optimizer = optimizer_dict[name_of_optimizers[i]]

        node_input_ids = input_ids_dict[name_of_input_ids_sets[i]]
        node_masks = attention_masks_dict[name_of_attention_masks_sets[i]]
        node_labels = labels_dict[name_of_labels_sets[i]]

        # One dataloader per split, built in the order train, valid, test.
        train_dataloader, validation_dataloader, test_dataloader = [
            create_dataloader(node_input_ids[split], node_masks[split], node_labels[split])
            for split in ('train', 'valid', 'test')
        ]

        train_accuracy, train_loss, valid_accuracy, valid_loss = train_and_validate_model(
            model, optimizer, train_dataloader, validation_dataloader
        )

        train_report = ("CLIENT: {}".format(i+1) +
                        " | Train accuracy: {:7.5f}".format(train_accuracy) +
                        " | Train loss: {:7.5f}".format(train_loss) +
                        " | Validation accuracy: {:7.5f}".format(valid_accuracy) +
                        " | Validation loss: {:7.5f}".format(valid_loss))
        print(train_report)

        test_accuracy, test_loss, test_precision, test_recall, test_f1 = test_model(model, test_dataloader)

        test_report = ("CLIENT: {}".format(i+1) +
                       " | Test accuracy: {:7.5f}".format(test_accuracy) +
                       " | Test loss: {:7.5f}".format(test_loss) +
                       " | Test precision: {:7.5f}".format(test_precision) +
                       " | Test recall: {:7.5f}".format(test_recall) +
                       " | Test f1: {:7.5f}".format(test_f1))
        print(test_report)

        torch.save(model.state_dict(), F"/content/gdrive/My Drive/Innovation_Lab/NIID/client"+str(i)+".pt" )

In [19]:
# Averages the weights of the selected client models and sets them as the new weights of the main model
def get_averaged_weights_and_update_main_model(model_dict, clients_to_be_trained):
    """Overwrite the main model's parameters with the unweighted client mean.

    Every listed client contributes equally (1 / number of clients). The
    notebook-level `main_model` is updated in place. The work runs on the GPU
    when one is available, otherwise on the CPU; all models end up on the CPU.

    Args:
        model_dict: dict "model<k>" -> client model with the same architecture
            as the main model.
        clients_to_be_trained: Sequence of client indices to average.
    """
    number_of_clients = len(clients_to_be_trained)
    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        main_model.to(compute_device)

        for position, i in enumerate(clients_to_be_trained):
            client_model = model_dict["model"+str(i)]
            client_model.to(compute_device)

            for client_parameters,main_model_parameters in zip(client_model.parameters(),main_model.parameters()):

                if position == 0:
                    # The first client overwrites the main model's old weights.
                    main_model_parameters.copy_(client_parameters)
                    main_model_parameters.div_(number_of_clients)
                else:
                    # The remaining clients accumulate their share of the mean.
                    main_model_parameters.add_(client_parameters / number_of_clients)

            client_model.cpu()

        main_model.cpu()

In [20]:
def get_test_data(clients_to_be_tested,group):
  """Build one dataloader over the `group` split of all listed clients.

  Args:
    clients_to_be_tested: Sequence of client indices.
    group (str): Key of the split to use, e.g. 'all', 'train', 'valid' or 'test'.

  Returns:
    The dataloader produced by create_dataloader over the combined rows,
    in client order.
  """
  def _gather(values_dict, names):
    parts = [values_dict[names[i]][group] for i in clients_to_be_tested]
    # One concatenated array instead of a Python list of row arrays, which
    # would make the later tensor conversion very slow.
    return np.concatenate(parts, axis=0) if parts else []

  test_input_ids = _gather(input_ids_dict, name_of_input_ids_sets)
  test_attention_masks = _gather(attention_masks_dict, name_of_attention_masks_sets)
  test_labels = _gather(labels_dict, name_of_labels_sets)

  return create_dataloader(test_input_ids,test_attention_masks,test_labels)

In [22]:
# Compares the accuracy of the main model and the local model running on each node
def compare_local_and_merged_model_performance(clients_to_be_trained,clients_to_be_tested,group):
    """Tabulate test metrics of the main model and of every trained local model.

    The main model is always evaluated on the `group` split of all
    `clients_to_be_tested`. Each local model is evaluated on its own client's
    split when `group == 'test'`, and on the same combined data as the main
    model otherwise.

    Args:
        clients_to_be_trained: Sequence of client indices whose local models are evaluated.
        clients_to_be_tested: Sequence of client indices providing the main model's data.
        group (str): Key of the data split to evaluate on.

    Returns:
        DataFrame with columns Model, Accuracy, Precision, Recall, F1; row 0
        is the main model, followed by one row per local model.
    """
    industry=['apparel','cars','electronics','food','retail','services','software','transport','other']
    columns=["Model", "Accuracy", "Precision","Recall","F1"]

    test_dataloader=get_test_data(clients_to_be_tested,group)

    accuracy,loss,precision,recall,f1 = test_model(main_model, test_dataloader)

    # Rows are collected first and the frame is built once, so the text column
    # never has to be written into a float-initialised frame.
    rows=[["Main Model",accuracy,precision,recall,f1]]

    for i in clients_to_be_trained:

        if group=='test':
            # Local models are scored on their own client's test split.
            test_dataloader=get_test_data([i],group)

        model=model_dict[name_of_models[i]]

        accuracy,loss,precision,recall,f1 = test_model(model, test_dataloader)

        rows.append([industry[i],accuracy,precision,recall,f1])

    return pd.DataFrame(rows, columns=columns)

In [23]:
# To send the merged parameters of the main model to local models
def send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples):
    """Copy the main model's parameters into every client model.

    The copy runs on the GPU when one is available, otherwise on the CPU;
    all models end up on the CPU.

    Args:
        main_model: Source model.
        model_dict: dict "model<k>" -> client model with the same architecture.
        number_of_samples (int): Number of client models to update (k = 0..n-1).
    """
    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        main_model.to(compute_device)

        for i in range(number_of_samples):
            client_model = model_dict["model"+str(i)]
            client_model.to(compute_device)

            for client_parameters,main_model_parameters in zip(client_model.parameters(),main_model.parameters()):
                # In-place copy: each client keeps its own storage, independent of the main model.
                client_parameters.copy_(main_model_parameters)

            client_model.cpu()

        main_model.cpu()

### EXPERIMENTATION

In [None]:
# Load the main model: XLNetForSequenceClassification, the pretrained XLNet model with a single linear classification layer on top.
main_model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=2)
# Optimizer for the main model, with learning rate 1e-5.
main_optimizer = AdamW(params=main_model.parameters(),lr=1e-5)

In [25]:
# Option A (run once): train the main model on Amazon reviews, which also saves it to Drive.
# train_main_model(main_model,main_optimizer)
# Option B: restore the previously saved main-model and optimizer state from Drive.
main_model.load_state_dict(torch.load( F"/content/gdrive/My Drive/Innovation_Lab/main_model.pt"))
main_optimizer.load_state_dict(torch.load( F"/content/gdrive/My Drive/Innovation_Lab/main_optimizer.pt"))

In [None]:
# Upload the complaints CSV and get its input ids, attention masks, labels and industry codes
input_ids,attention_masks,labels,industry_codes=get_features()

In [27]:
# Number of client nodes: one per industry code 0-8.
number_of_samples=9

In [28]:
# Create one client per industry, each with its own train, validation and test split:
# 1) group row positions by industry code, 2) slice the arrays per client, 3) split each client's data.
id_dict=split_and_shuffle_samples(industry_codes, seed=1) 
input_ids_dict,attention_masks_dict,labels_dict = create_niid_subsamples(id_dict,input_ids,attention_masks,labels)
input_ids_dict,attention_masks_dict,labels_dict = train_valid_test_split(input_ids_dict, attention_masks_dict, labels_dict)

In [29]:
# Store the dict keys as lists so that a client's data can be looked up by its integer index

name_of_input_ids_sets=list(input_ids_dict.keys())
name_of_attention_masks_sets=list(attention_masks_dict.keys())
name_of_labels_sets=list(labels_dict.keys())

### Experiment 1
Train a local domain-specific model for every domain, combine them with FedAvg, then test the main model on tweets from all domains and each local model on tweets from its own domain.

In [30]:
# Experiment 1: train and test on all nine clients (indices 0-8).
clients_to_be_trained=[0,1,2,3,4,5,6,7,8]
clients_to_be_tested=[0,1,2,3,4,5,6,7,8]

In [31]:
# Create a fresh model and optimizer for every node, and keep their dict keys as index-addressable lists
model_dict, optimizer_dict = create_model_optimizer_dict(number_of_samples)
name_of_models=list(model_dict.keys())
name_of_optimizers=list(optimizer_dict.keys())

In [None]:
# Initialise every client model with the main model's current weights
send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)

In [None]:
# Train, validate, test and save the local model of each selected client
start_train_end_node_process(clients_to_be_trained)

In [None]:
# Alternative to training: load the previously saved local models from Drive
model_dict, optimizer_dict = load_all_models(number_of_samples)
name_of_models=list(model_dict.keys())
name_of_optimizers=list(optimizer_dict.keys())

In [None]:
# Federated averaging: set the main model's weights to the mean of the trained client models' weights
get_averaged_weights_and_update_main_model(model_dict, clients_to_be_trained)

In [None]:
# Compare the main model (on the combined test split of all tested clients)
# with each local model (on its own client's test split)
compare_local_and_merged_model_performance(clients_to_be_trained,clients_to_be_tested,'test')

In [None]:
# Send the averaged main-model weights back to every client model
send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)

RESULT:
```
   Test data        Accuracy   Precision     Recall       F1
0	All	        0.810945	0.827068	0.880000	0.852713
1	Apparel	    0.962963	1.000000	0.933333	0.965517
2	Cars	       0.800000	0.888889	0.888889	0.888889
3	Electronics	0.862069	0.944444	0.850000	0.894737
4	Food	       0.538462	0.875000	0.583333	0.700000
5	Retail	     0.750000	0.764706	0.928571	0.838710
6	Services	   0.794118	0.789474	0.833333	0.810811
7	Software	   0.900000	0.947368	0.900000	0.923077
8	Transport	  0.880000	0.875000	0.777778	0.823529
9	Other	      0.846154	0.800000	1.000000	0.888889
```

### Experiment 2
Train local domain models on a subset of the domains (e.g. 1, 2, 3), combine them with FedAvg, then test the local and main models on tweets from the remaining domains (e.g. 4, 5, 6).

In [None]:
# Experiment 2: train on clients 5-8 and test on the disjoint clients 0-4.
clients_to_be_trained=[5,6,7,8]
clients_to_be_tested=[0,1,2,3,4]

In [None]:
# Create a fresh model and optimizer for every node, and keep their dict keys as index-addressable lists
model_dict, optimizer_dict = create_model_optimizer_dict(number_of_samples)
name_of_models=list(model_dict.keys())
name_of_optimizers=list(optimizer_dict.keys())

In [None]:
# Initialise every client model with the main model's current weights
send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)

In [None]:
# Train, validate, test and save the local model of each selected client
start_train_end_node_process(clients_to_be_trained)

In [None]:
# # Or load all trained local models
# model_dict, optimizer_dict = load_all_models(number_of_samples)
# name_of_models=list(model_dict.keys())
# name_of_optimizers=list(optimizer_dict.keys())

In [None]:
# Federated averaging: set the main model's weights to the mean of the trained client models' weights
get_averaged_weights_and_update_main_model(model_dict, clients_to_be_trained)

In [None]:
# Compare the main model and every trained local model on all data ('all') of the tested clients
compare_local_and_merged_model_performance(clients_to_be_trained,clients_to_be_tested,'all')

In [None]:
# Send the averaged main-model weights back to every client model
send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)