# Federated Learning basic pipeline using the FL API from the Azure ML SDK

This notebook:
1. reads a YAML config file specifying the number of silos and their parameters,
2. reads the components from a given folder,
3. uses the Scatter-Gather API to build a flexible pipeline depending on the config, and to configure each step to read from and write to the right storage account.

## General imports

Here, we import all the packages we'll need except the Azure ones, which are imported further below, after a required environment variable has been set.

In [None]:
import os
import argparse
import random
import string
import datetime
import webbrowser
import time
import json
import sys

# to handle yaml config easily
from omegaconf import OmegaConf

## Activate Private Preview features

__This needs to happen *before* importing the Azure ML SDK.__

In [None]:
os.environ["AZURE_ML_CLI_PRIVATE_FEATURES_ENABLED"] = "True"

## Azure Imports

Now we can import the AzureML SDK, **after** the environment variable above has been set up.

In [None]:
# generic Azure ML sdk v2 imports
import azure
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml import MLClient, Input, Output
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.dsl import pipeline
from azure.ai.ml import load_component

# FL-specific Azure ML sdk v2 imports
import azure.ai.ml.dsl._fl_scatter_gather_node as fl
from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo

## Configure the notebook

This part is for reading the parameters from a config file, and for stating where we will look for the components. This is fairly standard in our FL examples and has nothing to do with the Scatter-Gather API.

In [None]:
# choose the example to run (only 'MNIST' is supported in this sample)
example = "MNIST"

# load the config from a local yaml file
YAML_CONFIG = OmegaConf.load("./config.yaml")

# path to the components
COMPONENTS_FOLDER = os.path.join(
    ".", "..", "..", "components", example
)

# path to the shared components
SHARED_COMPONENTS_FOLDER = os.path.join(
    ".", "..", "..", "components", "utils"
)
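The notebook expects `config.yaml` to provide a handful of specific keys. Here is a sketch of that layout written as a plain Python dict, with a small check for missing keys. The key names are the ones this notebook reads. Every value is a placeholder, and the input `type` values are an assumption (`uri_folder` is one valid Azure ML input type). `SAMPLE_CONFIG`, `missing_keys` and the key lists are hypothetical helpers, not part of the SDK. To run the check on the real config, `OmegaConf.to_container(YAML_CONFIG)` converts it to plain dicts and lists.

```python
from typing import Any, Dict, List

# Hypothetical config.yaml contents as a Python dict. Key names match what the
# notebook reads; all values are placeholders, and "uri_folder" is an assumed type.
SAMPLE_CONFIG: Dict[str, Any] = {
    "aml": {
        "subscription_id": "<SUBSCRIPTION_ID>",
        "resource_group_name": "<RESOURCE_GROUP>",
        "workspace_name": "<WORKSPACE_NAME>",
    },
    "orchestrator": {"compute": "<ORCHESTRATOR_COMPUTE>", "datastore": "<ORCHESTRATOR_DATASTORE>"},
    "strategy": {
        "horizontal": [
            {
                "name": "silo_1",
                "computes": ["<SILO_1_COMPUTE>"],  # the notebook uses the first entry
                "datastore": "<SILO_1_DATASTORE>",
                "inputs": {
                    "training_data": {"type": "uri_folder", "path": "<SILO_1_TRAIN_PATH>"},
                    "testing_data": {"type": "uri_folder", "path": "<SILO_1_TEST_PATH>"},
                },
            },
        ],
    },
    "silo_training_parameters": {"lr": 0.01, "epochs": 3, "batch_size": 64},
    "general_training_parameters": {"num_of_iterations": 2},
}

# Dotted paths read from the top level of the config.
REQUIRED_KEYS = [
    "aml.subscription_id",
    "aml.resource_group_name",
    "aml.workspace_name",
    "orchestrator.compute",
    "orchestrator.datastore",
    "strategy.horizontal",
    "silo_training_parameters",
    "general_training_parameters.num_of_iterations",
]

# Dotted paths read from each entry of strategy.horizontal.
REQUIRED_SILO_KEYS = ["name", "computes", "datastore", "inputs.training_data", "inputs.testing_data"]


def missing_keys(config: Dict[str, Any], required: List[str]) -> List[str]:
    """Return the dotted paths from `required` that are absent from `config`."""
    missing = []
    for dotted in required:
        node: Any = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


problems = missing_keys(SAMPLE_CONFIG, REQUIRED_KEYS)
for silo in SAMPLE_CONFIG["strategy"]["horizontal"]:
    problems += [f"{silo.get('name', '?')}: {key}" for key in missing_keys(silo, REQUIRED_SILO_KEYS)]
print(problems or "config looks complete")
```

The `aml` keys are not strictly required. If they are missing, `connect_to_aml` falls back to a local `config.json`.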

## Connect to Azure ML

This part is for connecting to the AzureML workspace. This is fairly standard in our FL examples and has nothing to do with the Scatter-Gather API.

In [None]:
def connect_to_aml():
    try:
        credential = DefaultAzureCredential()
        # Check that the credential can get a token successfully.
        credential.get_token("https://management.azure.com/.default")
    except Exception:
        # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential does not work
        credential = InteractiveBrowserCredential()

    # Get a handle to the workspace
    try:
        # tries to connect using the workspace details from config.yaml
        ml_client = MLClient(
            subscription_id=YAML_CONFIG.aml.subscription_id,
            resource_group_name=YAML_CONFIG.aml.resource_group_name,
            workspace_name=YAML_CONFIG.aml.workspace_name,
            credential=credential,
        )

    except Exception:
        print("Could not connect using config.yaml, trying the local config.json instead.")
        # falls back to the local config.json
        ml_client = MLClient.from_config(credential=credential)

    return ml_client

print("Connecting to the Azure ML workspace...")
ML_CLIENT = connect_to_aml()
print("Connected!")
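`connect_to_aml` applies the same "try one option, fall back to the next" pattern twice: once for the credential and once for the workspace handle. Here is a minimal sketch of that pattern as a generic helper. `first_successful`, `connect_from_yaml` and `connect_from_json` are hypothetical names. In the notebook, the factories would build `DefaultAzureCredential`/`InteractiveBrowserCredential` or `MLClient`/`MLClient.from_config` instead of returning strings.

```python
from typing import Callable, Optional, Sequence, TypeVar

T = TypeVar("T")


def first_successful(attempts: Sequence[Callable[[], T]]) -> T:
    """Call each zero-argument factory in order and return the first result
    that does not raise; re-raise the last error if every attempt fails."""
    last_error: Optional[Exception] = None
    for attempt in attempts:
        try:
            return attempt()
        except Exception as err:  # broad on purpose, like connect_to_aml
            last_error = err
    if last_error is None:
        raise ValueError("no attempts were provided")
    raise last_error


# Hypothetical stand-ins for the two workspace-connection strategies.
def connect_from_yaml() -> str:
    raise KeyError("aml")  # simulates a config.yaml without an `aml` section


def connect_from_json() -> str:
    return "client built from config.json"


client = first_successful([connect_from_yaml, connect_from_json])
print(client)
```

Catching `Exception` mirrors the notebook. A stricter version could catch only the errors each option is expected to raise, so that unrelated bugs are not silently skipped.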

## Load the pipeline components

This part is for loading the components we will be using in the pipeline. This is fairly standard in our FL examples and has nothing to do with the Scatter-Gather API.

In [None]:
# Loading the components from their yaml specifications
preprocessing_component = load_component(
    source=os.path.join(COMPONENTS_FOLDER, "preprocessing", "spec.yaml")
)

training_component = load_component(
    source=os.path.join(COMPONENTS_FOLDER, "traininsilo", "spec.yaml")
)

aggregate_component = load_component(
    source=os.path.join(SHARED_COMPONENTS_FOLDER, "aggregatemodelweights_mltable", "spec.yaml")
)

## Create the silos

Here we create a list containing the information for all silos, which will later on be passed to the Scatter-Gather API. We read the values from a config file.

Note that using a config file is NOT mandatory. We usually find it more convenient to put all parameters in one file, but if you prefer you can also just create the list of silos directly in the notebook.

In [None]:
silo_list = [
    FederatedLearningSilo(
        compute=silo_config["computes"][0],
        datastore=silo_config["datastore"],
        inputs= {
            "silo_name": silo_config["name"],
            "raw_train_data": Input(**dict(silo_config["inputs"])["training_data"]),
            "raw_test_data": Input(**dict(silo_config["inputs"])["testing_data"]),
        },
    )
    for silo_config in YAML_CONFIG.strategy.horizontal
]
print("Silo list created")

## Create arguments mappings

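Create the argument mappings. The silo output `model` is fed to the aggregation input `from_silo_input`, and the aggregation output `aggregated_output` is fed back to the silos as their `checkpoint` input for the next iteration. You should not have to modify anything in this cell.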

In [None]:
silo_to_aggregation_argument_map = {"model" : "from_silo_input"}
aggregation_to_silo_argument_map = {"aggregated_output" : "checkpoint"}
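To see what these two dictionaries do, here is a plain-Python sketch of how an output name is renamed to the input name the next step expects. It assumes a single silo. How the real API presents the outputs of several silos to the one aggregation input is not modelled here. `remap` and the output references are hypothetical.

```python
from typing import Any, Dict

# The same mappings as in the cell above.
silo_to_aggregation_argument_map = {"model": "from_silo_input"}
aggregation_to_silo_argument_map = {"aggregated_output": "checkpoint"}


def remap(outputs: Dict[str, Any], argument_map: Dict[str, str]) -> Dict[str, Any]:
    """Rename each output listed in `argument_map` to the input name the next
    step expects. Outputs missing from the map are dropped in this sketch."""
    return {argument_map[name]: value for name, value in outputs.items() if name in argument_map}


# Hypothetical output references, standing in for real pipeline outputs.
silo_outputs = {"model": "silo_1/model"}
aggregation_inputs = remap(silo_outputs, silo_to_aggregation_argument_map)

aggregation_outputs = {"aggregated_output": "iteration_1/aggregated_model"}
next_silo_inputs = remap(aggregation_outputs, aggregation_to_silo_argument_map)

print(aggregation_inputs)  # silo output "model" becomes "from_silo_input"
print(next_silo_inputs)    # aggregated model becomes the next iteration's "checkpoint"
```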

## Create kwargs inputs mappings

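Create the keyword-argument mappings. `silo_kwargs` holds the training parameters from `silo_training_parameters` in the config, which are shared by all silos. `agg_kwargs` holds extra arguments for the aggregation step (none here). You should not have to modify anything in this cell.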

In [None]:
silo_kwargs = dict(YAML_CONFIG.silo_training_parameters)
agg_kwargs = {}
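Here is a minimal sketch of how shared parameters and each silo's own inputs might combine into one set of keyword arguments per silo. The assumption is that `shared_silo_kwargs` is passed to every silo subgraph alongside that silo's `inputs`. This sketch chooses to raise an error on a name clash. That policy is not taken from the SDK. `silo_call_kwargs` and the values are hypothetical.

```python
from typing import Any, Dict, List

# Placeholder values: silo_training_parameters from the config, and two silos'
# own inputs (names and paths are hypothetical).
silo_kwargs = {"lr": 0.01, "epochs": 3, "batch_size": 64}
silo_inputs = [
    {"silo_name": "silo_1", "raw_train_data": "<SILO_1_TRAIN>", "raw_test_data": "<SILO_1_TEST>"},
    {"silo_name": "silo_2", "raw_train_data": "<SILO_2_TRAIN>", "raw_test_data": "<SILO_2_TEST>"},
]


def silo_call_kwargs(own_inputs: Dict[str, Any], shared: Dict[str, Any]) -> Dict[str, Any]:
    """Merge one silo's own inputs with the parameters shared by all silos,
    refusing to let a shared parameter silently overwrite a silo input."""
    clashes = sorted(set(own_inputs) & set(shared))
    if clashes:
        raise ValueError(f"shared kwargs clash with silo inputs: {clashes}")
    return {**own_inputs, **shared}


calls: List[Dict[str, Any]] = [silo_call_kwargs(inputs, silo_kwargs) for inputs in silo_inputs]
for call in calls:
    print(call["silo_name"], call["lr"], call["epochs"], call["batch_size"])
```

The keys in `silo_training_parameters` need to match parameter names of the silo subgraph defined below (`lr`, `epochs`, `batch_size`, `dp`, ...). Parameters left out would presumably keep the subgraph's defaults.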

## Create silo subgraph

Here we build the subgraph defining what job(s) will happen in each silo. This will be passed to the Scatter-Gather API later on.

In [None]:
@pipeline(
    name="Silo Federated Learning Subgraph",
    description="It includes all steps that need to be executed in the silo",
)
def silo_scatter_subgraph(
    # user defined inputs
    raw_train_data: Input,
    raw_test_data: Input,
    checkpoint: Input(optional=True),
    silo_name: str,
    # user defined training arguments
    lr: float = 0.01,
    epochs: int = 3,
    batch_size: int = 64,
    dp: bool = False,
    dp_target_epsilon: float = 50.0,
    dp_target_delta: float = 1e-5,
    dp_max_grad_norm: float = 1.0,
) -> dict:
    """Create scatter/silo subgraph.

    Args:
        raw_train_data (Input): raw train data
        raw_test_data (Input): raw test data
        checkpoint (Input): if not None, the checkpoint obtained from previous iteration
        silo_name (str): name of the silo
        lr (float, optional): Learning rate. Defaults to 0.01.
        epochs (int, optional): Number of epochs. Defaults to 3.
        batch_size (int, optional): Batch size. Defaults to 64.
        dp (bool, optional): Differential Privacy
        dp_target_epsilon (float, optional): DP target epsilon
        dp_target_delta (float, optional): DP target delta
        dp_max_grad_norm (float, optional): DP max gradient norm

    Returns:
        Dict[str, Outputs]: a map of the outputs
    """
    # we're using our own preprocessing component
    silo_pre_processing_step = preprocessing_component(
        # this consumes the user-defined inputs
        raw_training_data=raw_train_data,
        raw_testing_data=raw_test_data,
        # here we're using the name of the silo as a metrics prefix
        metrics_prefix=silo_name,
    )

    # we're using our own training component
    silo_training_step = training_component(
        # with the train_data from the pre_processing step
        train_data=silo_pre_processing_step.outputs.processed_train_data,
        # with the test_data from the pre_processing step
        test_data=silo_pre_processing_step.outputs.processed_test_data,
        # and the checkpoint from previous iteration (or None if iteration == 1)
        checkpoint=checkpoint,
        # Learning rate for local training
        lr=lr,
        # Number of epochs
        epochs=epochs,
        # Dataloader batch size
        batch_size=batch_size,
        # Differential Privacy
        dp=dp,
        # DP target epsilon
        dp_target_epsilon=dp_target_epsilon,
        # DP target delta
        dp_target_delta=dp_target_delta,
        # DP max gradient norm
        dp_max_grad_norm=dp_max_grad_norm,
        # Silo name/identifier
        metrics_prefix=silo_name,
    )

    # IMPORTANT: we will assume that any output provided here can be exfiltrated into the orchestrator/gather
    return {
        # NOTE: The key you use is custom. A mapping must be provided to map the name used here to the input expected by the gather step.
        # This was done in the argument-mappings cell above (silo_to_aggregation_argument_map).

        "model": silo_training_step.outputs.model
    }
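The subgraph above chains preprocessing and training, and it exposes only the trained model. Here is a toy, plain-Python analogue of that data flow, with made-up data and a made-up model, to show the role of the optional `checkpoint`. The model is a single float fitted to the data by gradient descent. The test split, metrics prefix and DP parameters are left out. `preprocess`, `train` and `silo_subgraph` are hypothetical stand-ins for the real components.

```python
from typing import Dict, List, Optional


def preprocess(raw: List[Optional[float]]) -> List[float]:
    """Toy preprocessing: drop missing values."""
    return [x for x in raw if x is not None]


def train(data: List[float], checkpoint: Optional[float], lr: float, epochs: int) -> float:
    """Toy training: fit one weight w minimising mean((w - x)^2) by gradient
    descent, starting from the checkpoint when there is one."""
    w = 0.0 if checkpoint is None else checkpoint  # first iteration: no checkpoint
    mean = sum(data) / len(data)
    for _ in range(epochs):
        w -= lr * 2.0 * (w - mean)  # gradient of mean((w - x)^2) is 2 * (w - mean)
    return w


def silo_subgraph(raw_train: List[Optional[float]], checkpoint: Optional[float] = None,
                  lr: float = 0.1, epochs: int = 3) -> Dict[str, float]:
    """Preprocess, then train; return only what the orchestrator may see."""
    train_data = preprocess(raw_train)
    return {"model": train(train_data, checkpoint, lr, epochs)}


print(silo_subgraph([1.0, None, 3.0], lr=0.25, epochs=1))  # starts from w = 0
print(silo_subgraph([1.0, None, 3.0], checkpoint=2.0))     # already at the optimum
```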

## Build the FL pipeline

This is where the magic happens. We just use the `fl_scatter_gather` API, which will build the whole FL graph for us, and anchor the silos' components appropriately. See how we use the list of silos and subgraph we created earlier, along with the mappings, and some additional arguments from the config file.

In [None]:
fl_node = fl.fl_scatter_gather(
    silo_configs=silo_list,
    silo_component=silo_scatter_subgraph,
    aggregation_component=aggregate_component,
    aggregation_compute=YAML_CONFIG.orchestrator.compute,
    aggregation_datastore=YAML_CONFIG.orchestrator.datastore,
    shared_silo_kwargs=silo_kwargs,
    aggregation_kwargs=agg_kwargs,
    silo_to_aggregation_argument_map=silo_to_aggregation_argument_map,
    aggregation_to_silo_argument_map=aggregation_to_silo_argument_map,
    max_iterations=YAML_CONFIG.general_training_parameters.num_of_iterations,
)
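As a mental model of the graph `fl_scatter_gather` builds, here is a sequential, in-process sketch of the loop. It assumes that each of `max_iterations` rounds scatters to every silo from the same checkpoint, then gathers the silo models into one aggregation step whose output becomes the next checkpoint, as the argument mappings describe. The unweighted average is a stand-in aggregation rule. The exact rule of the `aggregatemodelweights_mltable` component is not shown in this notebook. The real API creates pipeline steps on the silo and orchestrator computes instead of calling Python functions. `local_update`, `average` and `scatter_gather` are hypothetical.

```python
from typing import Callable, List, Optional


def local_update(data: List[float], checkpoint: Optional[float], lr: float = 0.25) -> float:
    """Toy silo training: one gradient step on mean((w - x)^2)."""
    w = 0.0 if checkpoint is None else checkpoint
    mean = sum(data) / len(data)
    return w - lr * 2.0 * (w - mean)


def average(models: List[float]) -> float:
    """Stand-in aggregation: unweighted mean of the silo models."""
    return sum(models) / len(models)


def scatter_gather(silo_data: List[List[float]],
                   silo_fn: Callable[[List[float], Optional[float]], float],
                   aggregate_fn: Callable[[List[float]], float],
                   max_iterations: int) -> float:
    """Run max_iterations rounds of scatter (local training) and gather (aggregation)."""
    checkpoint: Optional[float] = None  # no checkpoint before the first round
    for _ in range(max_iterations):
        # scatter: each silo trains on its own data from the same checkpoint
        silo_models = [silo_fn(data, checkpoint) for data in silo_data]
        # gather: the orchestrator aggregates, and the result becomes the next checkpoint
        checkpoint = aggregate_fn(silo_models)
    if checkpoint is None:
        raise ValueError("max_iterations must be at least 1")
    return checkpoint


silos = [[1.0, 3.0], [5.0, 7.0]]  # hypothetical local datasets (means 2 and 6)
for rounds in (1, 2, 3):
    print(rounds, scatter_gather(silos, local_update, average, rounds))
```

With these toy numbers, the aggregated model moves 2.0, 3.0, 3.5 over the first three rounds, halving its distance to 4.0 (the mean over both silos) each round. No silo ever shares its raw data, only its model.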

## Submit the job

This part is for submitting the job to AzureML. This is fairly standard in our FL examples and has nothing to do with the Scatter-Gather API.

In [None]:
print("Submitting the pipeline job to your AzureML workspace...")
pipeline_job = ML_CLIENT.jobs.create_or_update(
    fl_node.scatter_gather_graph, experiment_name="example_fl_pipeline_with_sdk_accelerator"
)

print("The URL where you can follow your job live is returned by the SDK:")
print(pipeline_job.services["Studio"].endpoint)

webbrowser.open(pipeline_job.services["Studio"].endpoint)