# Notebook #4: Model Training on Federated Data

#### Import the Rhino Health Python library & Authenticate to the Rhino Cloud
We'll again import the necessary functions from the `rhino_health` library and authenticate to the Rhino Cloud. Please refer to Notebook #1 for an explanation of the `session` interface, which we use to interact with the various endpoints in the Rhino Health ecosystem. You can find more information about the Rhino SDK in our <a target="_blank" href="https://rhinohealth.github.io/rhino_sdk_docs/html/autoapi/index.html">Official SDK Documentation</a> and on our <a target="_blank" href="https://pypi.org/project/rhino-health/">PyPI Repository Page</a>.

```python
import pandas as pd
import numpy as np
import matplotlib.axes
import matplotlib.figure
import matplotlib.pyplot as plt
from PIL import Image
import os
import sys
import getpass
import json
import io
import base64
import rhino_health as rh
from rhino_health.lib.metrics import RocAuc, RocAucWithCI

my_username = "FCP_LOGIN_EMAIL"  # Replace this with the email you use to log into Rhino Health
session = rh.login(username=my_username, password=getpass.getpass())  # Prompts for your password without echoing it
```

#### Load the Evaluation Results Generated in the Previous Notebook
In the previous notebook, we passed a string to the `validation_dataset_inference_suffix` argument. That string was used to name the datasets (one per site) that contain our model's inference results on the validation data. We'll retrieve those datasets now so that we can examine the results of our model validation.

```python
project = session.project.get_project_by_name("YOUR_PROJECT_NAME")  # Replace with your project name
# Replace with the suffix you passed to `validation_dataset_inference_suffix` in the previous notebook
results_datasets = session.dataset.search_for_datasets_by_name("DATASET_SUFFIX")
[dataset.name for dataset in results_datasets]  # List the names of the matching datasets
```
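The ROC cell further down labels each site with the part of its dataset name before the first hyphen. If the name search also returns datasets you don't want (for example, the input datasets themselves), you could filter on the suffix first. A minimal sketch of both steps on plain strings, assuming names of the form `<site>-...<suffix>`; the dataset names and the helpers `filter_by_suffix` and `site_label` are hypothetical:

```python
def filter_by_suffix(names, suffix):
    """Keep only the dataset names that end with the inference suffix (hypothetical helper)."""
    return [name for name in names if name.endswith(suffix)]


def site_label(dataset_name, separator="-"):
    """Return the text before the first separator, assumed to be the site name."""
    return dataset_name.split(separator, 1)[0]


# Hypothetical dataset names; real names depend on how your datasets were created
names = [
    "SiteA-validation-inference",
    "SiteB-validation-inference",
    "SiteA-validation",
]
inference_names = filter_by_suffix(names, "-inference")
print([site_label(name) for name in inference_names])  # ['SiteA', 'SiteB']
```

Splitting with `maxsplit=1` keeps the label stable even if the rest of the name contains more hyphens.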

A **Code Run** object represents a single model run on the Rhino Health FCP. It holds the run's configuration, runtime information, and logs, and a report can be attached to it, which makes it a convenient place to share results and troubleshoot with collaborators. At the end of this notebook we'll attach our ROC visualization to the Code Run from the previous notebook.



#### Generate a Receiver Operating Characteristic (ROC) curve
An **ROC curve** is a graph showing the performance of a binary classification model at all classification thresholds. It plots the true positive rate (TPR, or sensitivity), TP / (TP + FN), against the false positive rate (FPR, or 1 - specificity), FP / (FP + TN), at each threshold. Lowering the classification threshold classifies more items as positive, which increases both false positives and true positives and moves the operating point toward the top-right corner (1, 1).
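To make the threshold sweep concrete, the sketch below computes ROC points from raw labels and scores with NumPy: each distinct score is tried as a threshold, from highest to lowest, and TPR and FPR are recorded at each step. It illustrates the definition only and is not how the `RocAuc` metric is computed on the platform; `roc_points` is a hypothetical helper, and it assumes both classes are present in `y_true`.

```python
import numpy as np


def roc_points(y_true, y_score):
    """Return (fpr, tpr, thresholds), sweeping the threshold over every distinct score."""
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = y_true.sum()
    n_neg = (~y_true).sum()
    # Highest threshold first; +inf classifies nothing as positive, so the curve starts at (0, 0)
    thresholds = np.concatenate(([np.inf], np.unique(y_score)[::-1]))
    fpr, tpr = [], []
    for t in thresholds:
        predicted_pos = y_score >= t
        tp = np.sum(predicted_pos & y_true)   # positives correctly flagged
        fp = np.sum(predicted_pos & ~y_true)  # negatives wrongly flagged
        tpr.append(tp / n_pos)
        fpr.append(fp / n_neg)
    return np.array(fpr), np.array(tpr), thresholds


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
fpr, tpr, thresholds = roc_points(y_true, y_score)
print(list(zip(fpr.tolist(), tpr.tolist())))
# [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
```

As the threshold drops, neither rate ever decreases, which is why the curve only moves up and to the right.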

```python
# Plot one ROC curve per site on shared axes
def plot_roc(results, site_names):
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    linestyle_cycle = ['-', '--']
    fig, ax = plt.subplots(figsize=[6, 4], dpi=200)

    for i, (result, name) in enumerate(zip(results, site_names)):
        roc_metrics = result.output
        # Cycle through the color palette; switch line style once the colors run out
        color = colors[i % len(colors)]
        linestyle = linestyle_cycle[(i // len(colors)) % len(linestyle_cycle)]
        ax.plot(roc_metrics['fpr'], roc_metrics['tpr'], color=color,
                linestyle=linestyle, label=name)

    ax.legend(loc='lower right')  # Add the legend once, after all curves are drawn
    ax.set_title('ROC per Site')
    ax.set_xlabel('1 - Specificity')
    ax.set_ylabel('Sensitivity')
    ax.grid(True)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    fig.canvas.draw()
    return fig
```
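`plot_roc` labels each curve with the site name only. To show each site's AUC in the legend, you could integrate the `fpr`/`tpr` arrays with the trapezoidal rule. This is a sketch, not the SDK's own computation: `trapezoid_auc` and `legend_label` are hypothetical helpers, and they assume the points are ordered by increasing FPR. The `RocAuc` output may already contain an AUC value, but this notebook does not show its key, so the sketch recomputes it.

```python
def trapezoid_auc(fpr, tpr):
    """Area under an ROC curve whose points are sorted by increasing FPR (trapezoidal rule)."""
    if len(fpr) != len(tpr) or len(fpr) < 2:
        raise ValueError("fpr and tpr must have the same length and at least 2 points")
    area = 0.0
    for i in range(1, len(fpr)):
        # Width of the step along the FPR axis times the average height (TPR)
        area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0
    return area


def legend_label(site_name, fpr, tpr):
    """Site name annotated with its AUC, e.g. 'SiteA (AUC = 0.75)' (hypothetical helper)."""
    return f"{site_name} (AUC = {trapezoid_auc(fpr, tpr):.2f})"


print(legend_label("SiteA", [0.0, 0.0, 0.5, 0.5, 1.0], [0.0, 0.5, 0.5, 1.0, 1.0]))
# SiteA (AUC = 0.75)
```

Inside the plotting loop, `label=legend_label(name, roc_metrics['fpr'], roc_metrics['tpr'])`, with `name` the site's label, would replace the plain site-name label.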

```python
results = []
site_names = []
report_data = [{"type": "Title", "data": "ROC Analysis"}]

# The same metric configuration is evaluated on each site's results dataset
metric_configuration = RocAuc(y_true_variable="Pneumonia", y_pred_variable="Model_Score")

for result_dataset in results_datasets:
    dataset = session.dataset.get_dataset(result_dataset.uid)
    # Assumes each dataset name starts with the site name followed by a hyphen
    site_names.append(dataset.name.split('-')[0])
    results.append(dataset.get_metric(metric_configuration))

fig = plot_roc(results, site_names)
# Save the figure directly as a PNG (fig.canvas.tostring_rgb() is deprecated in recent Matplotlib releases)
fig.savefig("ROC_per_site.png", format="png")

# Base64-encode the PNG and add it to the report
with open("ROC_per_site.png", "rb") as temp_image:
    base_64_image = base64.b64encode(temp_image.read()).decode("utf-8")

report_data.append({
    "type": "Image",
    "data": {
        "image_filename": "ROC per site",
        "image_base64": base_64_image,
    },
    "width": 100,
})
```
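Each report entry is a plain dictionary, and images travel as base64-encoded PNG bytes. If you'd rather not write a temporary file, the same entries can be built from an in-memory buffer. A minimal sketch; `title_entry` and `image_entry` are hypothetical helpers that reproduce only the `Title` and `Image` entry shapes used in this notebook, and the PNG bytes in the demo are a stand-in rather than a real image:

```python
import base64
import json


def title_entry(text):
    """Build a report title entry in the shape used in this notebook (hypothetical helper)."""
    return {"type": "Title", "data": text}


def image_entry(png_bytes, image_filename, width=100):
    """Build a report image entry from raw PNG bytes (hypothetical helper)."""
    return {
        "type": "Image",
        "data": {
            "image_filename": image_filename,
            "image_base64": base64.b64encode(png_bytes).decode("utf-8"),
        },
        "width": width,
    }


# With a Matplotlib figure you could skip the temporary file, e.g.:
#   buffer = io.BytesIO()
#   fig.savefig(buffer, format="png")
#   png_bytes = buffer.getvalue()
png_bytes = b"\x89PNG\r\n\x1a\n"  # Stand-in bytes (only the PNG signature), not a real image
report = [title_entry("ROC Analysis"), image_entry(png_bytes, "ROC per site")]
payload = json.dumps(report)  # The report is sent to the platform as a JSON string
```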

#### Upload the visualizations to the Rhino Health Platform
You can generate reports for a code run and attach them to the Code Run object, which makes it easy to share insights and outcomes with collaborators and stakeholders.

In the code block below, we upload our ROC curve visualization to the cloud so that our collaborators can view it.

```python
code_run_uid = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"  # Paste the UID of the Code Run object for your NVFlare run
# Attach the report (title and ROC image) to the Code Run so collaborators can view it
response = session.post(f"code_runs/{code_run_uid}/set_report/", data={"report_data": json.dumps(report_data)})
```
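A mistyped or unreplaced UID only fails once the request reaches the server. A small pre-flight check can catch it earlier. A minimal sketch, assuming Code Run UIDs are standard hyphenated UUID strings, as the placeholder's layout suggests; `is_valid_uid` is a hypothetical helper, not part of the Rhino SDK:

```python
import uuid


def is_valid_uid(uid):
    """Return True if `uid` is a canonical hyphenated UUID string (hypothetical helper)."""
    try:
        # uuid.UUID also accepts forms like "{...}" or 32 hex digits without hyphens,
        # so compare against the canonical string to require the hyphenated layout.
        return str(uuid.UUID(uid)) == uid.lower()
    except (ValueError, TypeError, AttributeError):
        return False


print(is_valid_uid("XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"))  # False: placeholder not replaced
print(is_valid_uid("12345678-1234-5678-1234-567812345678"))  # True
```

Calling `is_valid_uid(code_run_uid)` before `session.post(...)` and raising an error when it returns `False` keeps the unreplaced placeholder from being sent.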