# Part 2 - WML Federated Learning with MNIST for Party 

### Learning Goals

After you complete Part 2 - WML Federated Learning with MNIST for Party, you should know how to:

- Load the data that you intend to use in the Federated Learning experiment.
- Install IBM Federated Learning libraries.
- Define a data handler. For more details on data handlers, see <a href = "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-cus-dh.html?audience=wdp&context=cpdaas" target="_blank" rel="noopener noreferrer">Customizing the data handler</a>.
- Configure the party to train data with the aggregator.

<div class="alert alert-block alert-info">This notebook is intended to be run by the administrator or connecting party of the Federated Learning experiment.
</div>

## Table of Contents

- [1. Input Variables](#input-vars)<br>
- [2. Download the Data](#download-data)<br>
- [3. Install Federated Learning libraries](#install)<br>
    - [3.1 Install the IBM WML SDK](#install-sdk)
    - [3.2 Import IBM WML client](#import-sdk)
- [4. Define the Data Handler](#data-handler)<br>
- [5. Configure the party](#config)<br>
- [6. Train with Federated Learning](#train)<br>
    - [6.1 Create the Party](#create-party)
    - [6.2 Connect to the Aggregator](#connect)

<div class="alert alert-block alert-warning">Before you run this notebook, you must have already run <a href = "https://dataplatform.cloud.ibm.com/exchange/public/entry/view/029d77a73d72a4134c81383d6f020f6f?context=cpdaas">Part 1 - WML Federated Learning with MNIST for Admin</a>. If you have not, open that notebook and run through it first.
</div>

In [None]:
import psutil

mem_recommended = 4
mem_total = round(psutil.virtual_memory().total / 1073741824, 2)

print(f"System has {mem_total}GB of memory.")
if mem_total < mem_recommended:
    print(f"WARNING: Running this notebook with less than {mem_recommended}GB of memory may cause unexpected errors.")

<a id = "input-vars"></a>
## 1. Input Variables

Paste in the ID credentials you got from the end of the Part 1 notebook. If you have not run through Part 1, open the notebook and run through it first.

In [None]:
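# Replace each empty string with the value printed at the end of the Part 1 notebook.
CP4D_HOST = ""      # host name only; "https://" is prepended later
WS_USER = ""
WS_PASSWORD = ""
PROJECT_ID = ""
RTS_ID = ""
TRAINING_ID = ""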
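
Every later cell depends on these six values, so it can help to confirm that none of them was left blank before continuing. The following is a minimal sketch, not part of the IBM SDK: `find_unset` is a hypothetical helper, and the sample values are placeholders for illustration only.

```python
def find_unset(values):
    """Return the names whose value is None or an empty/blank string."""
    return [name for name, value in values.items()
            if value is None or not str(value).strip()]


# Placeholder values for illustration; in the notebook, pass the real variables.
example_inputs = {
    "CP4D_HOST": "cpd.example.com",
    "WS_USER": "party-user",
    "WS_PASSWORD": "",
    "PROJECT_ID": "0000",
    "RTS_ID": None,
    "TRAINING_ID": "1111",
}

unset = find_unset(example_inputs)
if unset:
    print("Still missing:", ", ".join(unset))
```

In the notebook itself, the dictionary would be built from the actual variables, for example `{"CP4D_HOST": CP4D_HOST, ...}`.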

<a id = "download-data"></a>
## 2. Download MNIST handwritten digits dataset

As the party, you must provide the dataset used to train the Federated Learning model. This tutorial provides a default dataset: the MNIST handwritten digits dataset.

In [None]:
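import requests

dataset_resp = requests.get("https://api.dataplatform.cloud.ibm.com/v2/gallery-assets/entries/903188bb984a30f38bb889102a1baae5/data",
                            allow_redirects=True)
dataset_resp.raise_for_status()  # stop early if the download failed

# Write the archive to disk; the context manager closes the file automatically.
with open('MNIST-pkl.zip', 'wb') as f:
    f.write(dataset_resp.content)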

In [None]:
import zipfile

with zipfile.ZipFile("MNIST-pkl.zip","r") as file:
    file.extractall()
    
!ls -lh
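
Before relying on the extracted files, it can be useful to confirm that the archive contains the pickle files that the later cells expect. The following is a minimal sketch using only the standard library. `missing_members` is a hypothetical helper, the file names are the ones passed to the data handler later in this notebook, and the demonstration builds a small stand-in archive in memory rather than using the real download.

```python
import io
import zipfile

EXPECTED_FILES = ["mnist-keras-train.pkl", "mnist-keras-test.pkl"]


def missing_members(archive, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in the zip archive.

    `archive` may be a path or a file-like object. Names are compared
    exactly, so a file stored inside a subfolder counts as missing.
    """
    with zipfile.ZipFile(archive, "r") as zf:
        present = set(zf.namelist())
    return [name for name in expected if name not in present]


# Stand-in archive for illustration; in the notebook, pass "MNIST-pkl.zip".
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("mnist-keras-train.pkl", b"placeholder")

print(missing_members(buffer))
```

An empty list means that every expected file is present at the top level of the archive.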

<a id = "install"></a>
## 3. Install Federated Learning libraries

In this section, you install the libraries and packages needed to run Federated Learning with the Python client.

<a id = "install-sdk"></a>
### 3.1 Install the IBM WML SDK with FL

This installs the `ibm-watsonx-ai` Python SDK together with the Federated Learning extra named in the command below (`fl-rt24.1-py3.11`).

In [None]:
import sys
!{sys.executable} -m pip install --upgrade 'ibm-watsonx-ai[fl-rt24.1-py3.11]'

<a id = "import-sdk"></a>
### 3.2 Import the IBM Watson Machine Learning client

The following code imports `APIClient`, authenticates with the credentials from section 1, and sets the default project for the party.

In [None]:
from ibm_watsonx_ai import APIClient

wml_credentials = {
        "username": WS_USER,
        "password": WS_PASSWORD,
        "instance_id" : "openshift",
        "url": "https://" + CP4D_HOST,
        "version": "5.0"
}

wml_client = APIClient(wml_credentials)
wml_client.set.default_project(PROJECT_ID)

<a id = "data-handler"></a>
## 4. Define a Data Handler

The party should run a data handler to ensure that their datasets are in a compatible and consistent format. In this tutorial, an example data handler for the MNIST dataset is provided.

For more details on data handlers, see <a href = "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-cus-dh.html?audience=wdp&context=cpdaas" target="_blank" rel="noopener noreferrer">Customizing the data handler</a>.



In [None]:
import logging
import pickle
import numpy as np

from ibm_watsonx_ai.federated_learning.data_handler import DataHandler

logger = logging.getLogger("DataHandler")


class MnistTFDataHandler(DataHandler):
    """
    Data handler for MNIST dataset.
    """

    def __init__(self, data_config=None, channels_first=False):
        super().__init__()
        # File paths stay None unless they are supplied through data_config.
        self.train_file_name = None
        self.test_file_name = None
        if data_config is not None:
            if 'train_file' in data_config:
                self.train_file_name = data_config['train_file']
            if 'test_file' in data_config:
                self.test_file_name = data_config['test_file']
        try:
            logger.info('Loading training data from {}'.format(self.train_file_name))
            with open(self.train_file_name, 'rb') as f:
                (self.x_train, self.y_train) = pickle.load(f)
            logger.info('Loading test data from {}'.format(self.test_file_name))
            with open(self.test_file_name, 'rb') as f:
                (self.x_test, self.y_test) = pickle.load(f)

            logger.info('Loaded {} train samples'.format(self.x_train.shape[0]))
            logger.info('Loaded {} test samples'.format(self.x_test.shape[0]))

            # Subset for limited memory and scale pixel values to [0, 1]
            nb_points = 500
            self.x_train = self.x_train[:nb_points] / 255.0
            self.y_train = self.y_train[:nb_points]
            self.x_test = self.x_test[:nb_points] / 255.0
            self.y_test = self.y_test[:nb_points]

            # Add a channels dimension (np.newaxis behaves like tf.newaxis)
            self.x_train = self.x_train[..., np.newaxis]
            self.x_test = self.x_test[..., np.newaxis]

            logger.info('Using {} train samples'.format(self.x_train.shape[0]))
            logger.info('Using {} test samples'.format(self.x_test.shape[0]))

        except Exception as e:
            # Report both paths and keep the original error as the cause.
            raise IOError('Unable to load data from the paths provided in '
                          'config file: train_file={}, test_file={}'.format(
                              self.train_file_name, self.test_file_name)) from e

    def get_data(self, nb_points=500):
        """
        Gets the pre-processed MNIST training and testing data.

        :param nb_points: Number of data points to include in each set.
            Kept for interface compatibility; this example ignores it and
            always returns the subset prepared in `__init__`.
        :type nb_points: `int`
        :return: (x_train, y_train), (x_test, y_test)
        :rtype: `tuple`
        """

        # This example returns the same data set each time and ignores
        # the nb_points parameter.  

        logger.info('Training on {} samples'.format(self.x_train.shape[0]))

        return (self.x_train, self.y_train), (self.x_test, self.y_test)

The party can test the data handler before training.

In [None]:
dh = MnistTFDataHandler(data_config={"train_file": "./mnist-keras-train.pkl", "test_file": "./mnist-keras-test.pkl"})
(x_train, y_train), (x_test, y_test) = dh.get_data()
print('x_train shape:', x_train.shape)
print('x_train[0]:', x_train[0])
print('y_train shape:', y_train.shape)
print('y_train[0]:', y_train[0])
# Drop the references so the arrays can be garbage-collected before training.
dh = x_train = y_train = x_test = y_test = None
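
Beyond printing shapes, a party might want a few automated checks on the handler output. The following is an illustrative sketch, not part of the IBM SDK: `check_split` is a hypothetical helper, and the random arrays stand in for the real MNIST data, assuming 28x28 images scaled to [0, 1] with a trailing channels dimension, as the handler above produces.

```python
import numpy as np


def check_split(x, y, name="train"):
    """Raise ValueError if a (features, labels) pair looks malformed."""
    if x.shape[0] != y.shape[0]:
        raise ValueError(f"{name}: {x.shape[0]} samples but {y.shape[0]} labels")
    if x.ndim != 4 or x.shape[-1] != 1:
        raise ValueError(f"{name}: expected shape (N, H, W, 1), got {x.shape}")
    if x.min() < 0.0 or x.max() > 1.0:
        raise ValueError(f"{name}: pixel values fall outside [0, 1]")
    return True


# Synthetic stand-in data that mimics the handler's pre-processing steps.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(500, 28, 28), dtype=np.uint8)
x_demo = (raw / 255.0)[..., np.newaxis]   # scale to [0, 1], add channels axis
y_demo = rng.integers(0, 10, size=500)    # digit labels 0-9

print(check_split(x_demo, y_demo))
```

With the real handler, the same function could be called on each tuple returned by `get_data()`.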


<a id = "config"></a>
## 5. Configure the party

Each party must run their party configuration file to call out to the aggregator. Here is an example of a party configuration.

Because you have already defined the training ID, RTS ID, and data handler in the previous sections of this notebook, and the SDK defines the local training and protocol handlers, you only need to provide the dataset file information under `wml_client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER`.

In this tutorial, the data paths are already known because the example MNIST dataset was downloaded and extracted in section 2.

In [None]:


party_config = {
    wml_client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "train_file": "./mnist-keras-train.pkl",
            "test_file": "./mnist-keras-test.pkl"
        },
        "class": MnistTFDataHandler,
    }
}
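
It can be worth verifying that the paths under `info` point to existing files before creating the party. The following is a minimal sketch, assuming the same `train_file` and `test_file` keys used above. `missing_data_files` is a hypothetical helper, and the temporary directory stands in for the notebook's working directory.

```python
import os
import tempfile


def missing_data_files(info):
    """Return {key: path} for every entry in `info` whose file does not exist."""
    return {key: path for key, path in info.items() if not os.path.isfile(path)}


# Demonstration in a temporary directory; in the notebook, pass the "info"
# dictionary from party_config instead.
with tempfile.TemporaryDirectory() as workdir:
    train_path = os.path.join(workdir, "mnist-keras-train.pkl")
    test_path = os.path.join(workdir, "mnist-keras-test.pkl")
    with open(train_path, "wb") as f:
        f.write(b"placeholder")  # only the training file is created

    demo_info = {"train_file": train_path, "test_file": test_path}
    missing = missing_data_files(demo_info)
    print(sorted(missing))
```

An empty dictionary means that every configured data file was found.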



<a id = "train"></a>
## 6. Connect and train with Federated Learning

Here you can finally connect to the aggregator to begin training.

<a id = "create-party"></a>
### 6.1 Create the party 

In [None]:
party = wml_client.remote_training_systems.create_party(RTS_ID, party_config)
party.monitor_logs()

<a id = "connect"></a>
### 6.2 Connect to the aggregator and start training

In [None]:
party.run(aggregator_id=TRAINING_ID, asynchronous=False, verify=False)

<a id = "summary"></a>
## Summary

Congratulations! You have learned to:

1. Start a Federated Learning experiment
2. Load a template model
3. Create an RTS and launch the experiment job
4. Load a dataset for training
5. Define the data handler
6. Configure the party
7. Connect to the aggregator
8. Train your Federated Learning model

### Learn more

- For more details about setting up Federated Learning, terminology, and running Federated Learning from the UI, see the <a href = "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fed-lea.html?audience=wdp" target="_blank" rel="noopener noreferrer">Federated Learning documentation</a> for Cloud.
- For more information on the Keras model template, see the <a href = "https://www.tensorflow.org/tutorials/quickstart/advanced" target="_blank" rel="noopener noreferrer">TensorFlow advanced quickstart tutorial</a>.

<hr>
Copyright © 2020-2024 IBM. This notebook and its source code are released under the terms of the MIT License.
<br>
