# Model Fusion

Model Fusion is a framework for federated learning in edge environments. Federated learning (FL) is a distributed machine learning process in which each participant node (or party) retains its data locally and interacts with the other participants via a learning protocol. The main driver behind FL is the need to avoid sharing data with others, mainly because of privacy and confidentiality concerns.

# Import libraries for calling APIs and data serialization

In [None]:
import requests
import pickle
import json
import os

# Model Fusion Requirements

FL requires that all parties use the same model architecture for model fusion. For this example, we will define a PyTorch CNN for all parties to use.

In [None]:
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(1, 32, 3, 1),
                      nn.ReLU(),
                      nn.Conv2d(32, 64, 3, 1),
                      nn.ReLU(),
                      nn.MaxPool2d(2, 2),
                      nn.Dropout2d(p=0.25),
                      nn.Flatten(),
                      nn.Linear(9216, 128),
                      nn.ReLU(),
                      nn.Dropout(p=0.5),  # plain Dropout: the input is 2-D after Flatten
                      nn.Linear(128, 10),
                      nn.LogSoftmax(dim=1))

fname = 'torch_mnist_cnn.pt'
torch.save(model, fname)
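
The `9216` input size of the first `nn.Linear` layer follows from the layer shapes above. The short sketch below reproduces that arithmetic in plain Python, assuming 28x28 MNIST inputs and no padding or dilation; `conv_out` is a hypothetical helper, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a Conv2d or MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                      # MNIST images are 28x28
side = conv_out(side, 3, 1)    # Conv2d(1, 32, 3, 1)  -> 26
side = conv_out(side, 3, 1)    # Conv2d(32, 64, 3, 1) -> 24
side = conv_out(side, 2, 2)    # MaxPool2d(2, 2)      -> 12

# Flatten turns 64 channels of 12x12 feature maps into one vector
flat_features = 64 * side * side
print(flat_features)           # 9216
```

If a party changed the input size or a kernel size, this number would change and the saved model could no longer be fused with the others.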

# Prepare Parameters for Aggregator 

The Aggregator is in charge of running the fusion algorithm. A fusion algorithm queries the registered parties to carry out the federated learning process; the queries sent vary according to the model and algorithm type. In return, each party replies with a model update object, and the aggregator combines these updates according to the fusion algorithm, which is specified via a Fusion Handler class.
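
To make the aggregation step concrete, here is a minimal sketch of federated averaging over flat weight vectors. It is a generic illustration, not the service's implementation: the function name `fed_avg`, the flat-list representation, and the weighting by sample count are all assumptions, since the source does not describe how 'fedavg' combines updates internally.

```python
def fed_avg(updates, sample_counts):
    """Weighted average of per-party weight vectors.

    updates: list of flat weight lists, one per party, all the same length.
    sample_counts: number of training samples each party used.
    """
    if len(updates) != len(sample_counts):
        raise ValueError("one sample count per party is required")
    total = sum(sample_counts)
    n_weights = len(updates[0])
    fused = [0.0] * n_weights
    for weights, count in zip(updates, sample_counts):
        if len(weights) != n_weights:
            raise ValueError("all parties must share the same architecture")
        # Each party contributes in proportion to its share of the data
        for i, w in enumerate(weights):
            fused[i] += w * count / total
    return fused


party0 = [1.0, 2.0]
party1 = [3.0, 6.0]
print(fed_avg([party0, party1], [500, 500]))  # [2.0, 4.0]
```

The length check mirrors the requirement that every party trains the same architecture: updates of different shapes cannot be averaged element by element.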

## Config components
The following is a list of parameters for the Aggregator config:
        
- **fusion_algorithm**: The name of the fusion algorithm for the federated learning process. Options include 'fedavg', 'iter_avg', and 'doc2vec'
  
  
- **model_type**: The type of model used for fusion. Options include 'keras', 'pytorch', and 'doc2vec'
 
 
- **model_file**: The saved initial model to distribute to parties to train in isolation
  
  
- **num_parties** : The number of nodes participating in the fusion


- **rounds** : The number of fusion rounds to complete


- **epochs** : The number of epochs to train for in each fusion round


- **learning_rate** : The learning rate for the parties to use for training


- **optimizer** : The name of the optimizer used for training (not applicable for doc2vec). This should be the name used by the keras or pytorch library (e.g. optim.Adam for pytorch)
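
The parameters above can be checked locally before anything is sent to the service. The sketch below is one possible way to do that; `build_aggregator_payload` is a hypothetical helper, and the allowed values are simply the options listed above, not a list confirmed by the service.

```python
# Options as listed in this notebook; the service may accept others
FUSION_ALGORITHMS = {'fedavg', 'iter_avg', 'doc2vec'}
MODEL_TYPES = {'keras', 'pytorch', 'doc2vec'}


def build_aggregator_payload(fusion_algorithm, model_type, num_parties,
                             rounds, epochs, learning_rate, optimizer=None):
    """Validate the aggregator settings and return them as a form payload."""
    if fusion_algorithm not in FUSION_ALGORITHMS:
        raise ValueError(f"unknown fusion_algorithm: {fusion_algorithm!r}")
    if model_type not in MODEL_TYPES:
        raise ValueError(f"unknown model_type: {model_type!r}")
    for name, value in (('num_parties', num_parties),
                        ('rounds', rounds),
                        ('epochs', epochs)):
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive integer")
    if learning_rate <= 0:
        raise ValueError("learning_rate must be positive")

    payload = {'fusion_algorithm': fusion_algorithm,
               'model_type': model_type,
               'num_parties': num_parties,
               'rounds': rounds,
               'epochs': epochs,
               'learning_rate': learning_rate}
    if optimizer is not None:      # not applicable for doc2vec
        payload['optimizer'] = optimizer
    return payload


payload = build_aggregator_payload('fedavg', 'pytorch', num_parties=2,
                                   rounds=5, epochs=3, learning_rate=1,
                                   optimizer='optim.Adadelta')
print(sorted(payload))
```

The model file is left out on purpose: it is uploaded as a file alongside the payload rather than as a form field.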

In [None]:
fusion = 'fedavg'
model_type = 'pytorch'
model_file = 'torch_mnist_cnn.pt'
num_parties = 2
rounds = 5
epochs = 3
lr = 1
optimizer = 'optim.Adadelta'

# Start the aggregator 

With the aggregator parameters defined, we can start the aggregator from the edgeai_model_fusion service.
If successful, the service returns an ID that the parties use to register with the aggregator.

In [None]:
url = 'URL/start-aggregator' # your URL may differ; change as necessary

payload = {'fusion_algorithm': fusion, 'model_type': model_type, 'num_parties': num_parties, 'rounds': rounds, 'learning_rate': lr, 'epochs': epochs, 'optimizer': optimizer}

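# Upload the saved model; the context manager closes the file after the request.
# verify=False skips TLS certificate verification.
with open(model_file, 'rb') as f:
    r = requests.post(url, data=payload, files={'model_file': f}, verify=False)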

In [None]:
r.text

In [None]:
agg_id = r.json()['aggregator_id']

# Prepare local data for each party 

Since each party retains its own dataset to train a model in isolation, we will draw a uniformly sampled subset of the MNIST dataset for each party.

In [None]:
def print_statistics(i, x_test_pi, x_train_pi, nb_labels, y_train_pi):
    print('Party_', i)
    print('nb_x_train: ', np.shape(x_train_pi),
          'nb_x_test: ', np.shape(x_test_pi))
    for l in range(nb_labels):
        print('* Label ', l, ' samples: ', (y_train_pi == l).sum())

In [None]:
import numpy as np
from keras import utils
from keras.datasets import mnist

img_rows, img_cols = 28, 28
nb_dp_per_party = [500,500]

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32').reshape(x_train.shape[0], 1, img_rows, img_cols) / 255.
y_train = y_train.astype('int64')
x_test = x_test.astype('float32').reshape(x_test.shape[0], 1, img_rows, img_cols) / 255.
y_test = y_test.astype('int64')

labels, train_counts = np.unique(y_train, return_counts=True)
te_labels, test_counts = np.unique(y_test, return_counts=True)
    
# Warn only when some training label is missing from the test set
if not np.all(np.isin(labels, te_labels)):
    print("Warning: test set and train set contain different labels")

num_train = np.shape(y_train)[0]
num_test = np.shape(y_test)[0]
num_labels = np.shape(np.unique(y_test))[0]
nb_parties = len(nb_dp_per_party)


train_probs = {label: 1.0 / len(labels) for label in labels}
test_probs = {label: 1.0 / len(te_labels) for label in te_labels}

for idx, dp in enumerate(nb_dp_per_party):
    # Per-sample sampling weights derived from the per-label probabilities
    train_p = np.array([train_probs[y_train[i]] for i in range(num_train)])
    train_p /= np.sum(train_p)
    train_indices = np.random.choice(num_train, dp, p=train_p)
    test_p = np.array([test_probs[y_test[i]] for i in range(num_test)])
    test_p /= np.sum(test_p)

    # Split test evenly
    test_indices = np.random.choice(num_test, int(num_test / nb_parties), p=test_p)
    x_train_pi = x_train[train_indices]
    y_train_pi = y_train[train_indices]
    x_test_pi = x_test[test_indices]
    y_test_pi = y_test[test_indices]
    
    # Now put it all in an npz
    name_file = 'data_party' + str(idx) + '.npz'
    print(name_file)
    np.savez(name_file, x_train=x_train_pi, y_train=y_train_pi,
             x_test=x_test_pi, y_test=y_test_pi)
 
    print_statistics(idx, x_test_pi, x_train_pi, num_labels, y_train_pi)
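
The sampling weights in the cell above are built from per-label probabilities that are all equal, so after normalization every sample has the same weight. The toy sketch below (plain Python, with a made-up imbalanced label vector) illustrates the consequence: the draw is uniform over samples, so each party's label mix follows the dataset's own label distribution and is not forced to be balanced.

```python
labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]   # toy, imbalanced label vector
classes = sorted(set(labels))

# Same construction as train_probs: every label gets probability 1 / num_classes
label_prob = {c: 1.0 / len(classes) for c in classes}

# Per-sample weights, normalized to sum to 1 (as done for train_p)
p = [label_prob[y] for y in labels]
total = sum(p)
p = [v / total for v in p]

# Expected share of each class in a draw: proportional to its sample count
class_mass = {c: round(sum(pi for pi, y in zip(p, labels) if y == c), 3)
              for c in classes}
print(class_mass)  # {0: 0.6, 1: 0.2, 2: 0.2}
```

Note also that `np.random.choice` samples with replacement by default, so a party's subset can contain repeated images.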

# Start Party Nodes

We will start each party to join the federation. The following are parameters required to register with an aggregator. 

- **aggregator_id** : The ID of the aggregator to connect to for fusion 


- **data** : The path to the data file to train with 


- **data_handler_class_name** : The name of the data handler to train with. Model fusion requires a data handler to preprocess data. The API includes handlers for popular datasets: 'cifar10_keras' and 'mnist_keras' for keras model types, 'mnist_pytorch' for pytorch model types, and '20_newsgroup' and 'wikipedia' for doc2vec. If a custom data handler is provided, give the name of its class instead.


- **custom_data_handler** (optional): The path to the user-created data handler Python module for training. For information on how to create a custom data handler, see [here](https://w3.ibm.com/w3publisher/ffl/ffl-tutorials/prepare-your-data).

If successful, the service reports how many parties still need to register.

In [None]:
data_handler = 'mnist_pytorch'

url = 'URL/start-party'

payload = {'aggregator_id': agg_id, 'data_handler_class_name': data_handler}

# Upload this party's data; the context manager closes the file after the request
with open('data_party0.npz', 'rb') as f:
    r = requests.post(url, data=payload, files={'data': f}, verify=False)

In [None]:
r.text

In [None]:
url = 'URL/start-party'

payload = {'aggregator_id': agg_id, 'data_handler_class_name': 'mnist_pytorch'}

# Upload this party's data; the context manager closes the file after the request
with open('data_party1.npz', 'rb') as f:
    r = requests.post(url, data=payload, files={'data': f}, verify=False)

In [None]:
r.text
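
With more than two parties, the registration cells above can be folded into a loop. The sketch below is one way to do it. `register_parties` is a hypothetical helper; `post` is expected to behave like `requests.post`, and `fake_post` plus the temporary placeholder files exist only so the example runs without the service.

```python
import os
import tempfile


def register_parties(post, url, agg_id, data_files, data_handler, **kwargs):
    """Register one party per data file and return the responses in order.

    `post` is any callable with the calling convention of `requests.post`;
    extra keyword arguments (e.g. verify=False) are passed through to it.
    """
    responses = []
    for path in data_files:
        payload = {'aggregator_id': agg_id,
                   'data_handler_class_name': data_handler}
        with open(path, 'rb') as f:   # closed as soon as the upload finishes
            responses.append(post(url, data=payload,
                                  files={'data': f}, **kwargs))
    return responses


def fake_post(url, data, files, **kwargs):
    """Offline stand-in for requests.post that echoes what it received."""
    return {'aggregator_id': data['aggregator_id'],
            'party_file': os.path.basename(files['data'].name)}


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(2):
        path = os.path.join(tmp, f'data_party{i}.npz')
        with open(path, 'wb') as f:
            f.write(b'placeholder')   # not a real npz; upload stand-in only
        paths.append(path)
    results = register_parties(fake_post, 'URL/start-party', 'demo-id',
                               paths, 'mnist_pytorch', verify=False)

print([r['party_file'] for r in results])  # ['data_party0.npz', 'data_party1.npz']
```

In the notebook, the call would pass `requests.post`, the real `agg_id`, and the `data_party*.npz` files written earlier.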

# Initiate Federated Training (Model Fusion)

After both parties register successfully with the aggregator, the federated learning process can begin. We will issue a **train** command to the model_fusion service to initiate training.

The aggregator_id parameter is required to select the correct aggregator.

Upon successful training, the service returns the weights of the global model obtained through fusion.

In [None]:
url = 'URL/train'

payload = {'aggregator_id': agg_id}
                      
r = requests.post(url,  json=payload, verify=False) 

In [None]:
r

# Get the global model parameters from the training process so that parties can reconstruct the global model

In [None]:
global_model = r.json()['global_model'] # can be saved and provided to pytorch model as weights

# End Federation 

To end the federation, a **stop** command should be issued to the aggregator with the corresponding ID. 

In [None]:
url = 'URL/stop'

payload = {'aggregator_id': agg_id}
                      
r = requests.post(url,  json=payload, verify=False)

In [None]:
r.text