In [1]:
# Standard Libraries
import os

# External Libraries
import click
import pandas as pd
import altair as alt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import mlflow.sklearn
import altair as alt
import altair_viewer

# Internal Libraries
import mlflow_vismod


# Remove Max Rows Limit in Altair
# Altair's default data transformer refuses to embed data frames above a row cap.
# The diamonds data (row labels above 40,000 appear in its preview below) is
# presumably past that cap, so the limit is lifted for the whole session.
# NOTE(review): every row is then inlined into the chart spec / saved HTML,
# which can make the logged artifacts large.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [2]:
# MLflow
# Both MLflow settings can be overridden through environment variables, so the
# notebook is not tied to one user's workspace. The defaults preserve the
# original setup.
MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', 'http://localhost:5000')
EXPERIMENT = os.getenv(
    'MLFLOW_EXPERIMENT_NAME',
    '/Users/james.hibbard@seattlechildrens.org/iris',
)

# Sklearn Model
TEST_SIZE = 0.33  # fraction of rows held out by train_test_split for evaluation
RANDOM_STATE = 42  # shared seed for the data splits and the random forest

  and should_run_async(code)


In [3]:
# Publish the configured tracking server through the environment variable MLflow
# consults, then make EXPERIMENT the active experiment for the runs below.
os.environ.update(MLFLOW_TRACKING_URI=MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXPERIMENT)

In [4]:
# Preprocess Example Iris Dataset
# as_frame=True returns pandas objects, so the splits keep column names.
iris = datasets.load_iris(as_frame=True)
X_train_iris, X_test_iris, y_train_iris, y_test_iris = train_test_split(
    iris['data'], iris['target'],
    test_size=TEST_SIZE, random_state=RANDOM_STATE,
)

# Preprocess Example Diamonds Dataset
# Price is selected with a list, so the label split stays a one-column DataFrame.
diamond_feature_columns = ['carat', 'cut', 'color', 'clarity', 'depth', 'table', 'x', 'y', 'z']
diamonds = pd.read_csv('diamonds.csv')
X_train_diamond, X_test_diamond, y_train_diamond, y_test_diamond = train_test_split(
    diamonds[diamond_feature_columns], diamonds[['price']],
    test_size=TEST_SIZE, random_state=RANDOM_STATE,
)

# Viz Model Flavor
#### Use Cases and API

- The concept of a model extends far beyond machine learning, encompassing any approximation of a system or process we wish to understand.
- Machine learning models fit into a category of mathematically rigorous models that can be developed and used with the scientific method (hypothesis testing, etc.), but they are by no means the beginning or end of the model-space continuum.
- Visualizations can be used as models themselves, or as a means of visually communicating the results, quality, etc. of machine learning or other models, reducing cognitive load on the end user.
- Visualizations require code/concept development and design like any other model; this process needs to be captured so the visualization is reusable and reproducible.
- Reusability and deployability can be enhanced by including the visualization code in a registry, ready to accept compliant data to visualize, much as a machine learning model can be serialized and deployed, ready to accept new instances of data for inference.
- This pattern could be very effective at orchestrating decentralized data science workflows, which are typical of academia, healthcare, and small startups. It allows developers to push and pull visualizations from a central registry (e.g. MLflow deployed with Databricks or as part of a service stack on cloud VMs).

## General Viz Model Flavor API

The Viz Model flavor API conceptually follows the general MLflow API. The flavor is the top-level API for visualization models; it can accept user-submitted styles to extend and customize display and other functionality specific to the base visualization library and/or use case.

In [5]:
# Example MLflow Model Run:
# - sklearn flavor
with mlflow.start_run() as run:
    # Fit a depth-limited random forest on the iris training split
    clf = RandomForestClassifier(random_state=RANDOM_STATE, max_depth=7)
    clf.fit(X_train_iris, y_train_iris)

    # Score on the held-out split and record it on the active run
    iris_test_accuracy = clf.score(X_test_iris, y_test_iris)
    mlflow.log_metric('accuracy', value=iris_test_accuracy)

    # Store the fitted estimator under the run's 'model' artifact path
    mlflow.sklearn.log_model(sk_model=clf, artifact_path='model')

#### Preprocess Iris and Diamond Datasets
- merge dataframes for features and labels
- subset
- stash column names for chart labeling

TODO: handle chart relabeling automatically as part of the flavor

In [6]:
# Generic chart axis -> original iris column, stashed for relabeling the chart later.
column_map_iris = dict(
    x='sepal length (cm)',
    y='sepal width (cm)',
    z='target',
)

# Join training features with labels, keep only the mapped columns, and rename
# them to the generic x/y/z schema shared by every dataset in this notebook.
df_iris = (
    pd.concat([X_train_iris, y_train_iris], axis=1, sort=False)
    .loc[:, list(column_map_iris.values())]
    .rename(columns={original: generic for generic, original in column_map_iris.items()})
)

df_iris.head()

  and should_run_async(code)


Unnamed: 0,x,y,z
96,5.7,2.9,1
105,7.6,3.0,2
66,5.6,3.0,1
0,5.1,3.5,0
122,7.7,2.8,2


In [7]:
# Generic chart axis -> original diamonds column, stashed for relabeling the chart later.
column_map_diamond = dict(
    x='carat',
    y='clarity',
    z='price',
)

# Join training features with price, keep only the mapped columns, and rename
# them to the same generic x/y/z schema used for iris.
df_diamond = (
    pd.concat([X_train_diamond, y_train_diamond], axis=1, sort=False)
    .loc[:, list(column_map_diamond.values())]
    .rename(columns={source: axis for axis, source in column_map_diamond.items()})
)

df_diamond.head()

  and should_run_async(code)


Unnamed: 0,x,y,z
241,1.01,I1,2788
17398,0.32,SI1,612
36608,0.34,SI2,477
44731,0.56,VS1,1616
18104,1.02,VS1,7324


#### Log visualization model and artifacts:
- serialize visualization code
- log artifacts at an artifact path (typically relevant JSON, HTML)

In [8]:
# Example MLflow Model Run
# - vizmod flavor
# - vegalite `style`
with mlflow.start_run() as run_iris:
    # Define Viz: a scatter plot over the generic x/y/z schema, with z
    # encoded as a nominal (':N') color category.
    viz_iris = alt.Chart(
        df_iris
    ).mark_circle(size=60).encode(
        x='x',
        y='y',
        color='z:N',
    ).properties(
        height=375,
        width=575,
    ).interactive()

    # Log Model: the chart still carries its generic x/y/z titles here, so the
    # logged viz model can be reloaded and pointed at any x/y/z-shaped frame.
    mlflow_vismod.log_model(
        model=viz_iris,
        artifact_path='viz',
        style='vegalite',
        input_example=df_iris,
    )

    # Label Chart
    #
    # Replace the generic x,y,z labels with their dataset-specific counterparts.
    # This happens before the HTML is saved so the logged artifact matches the
    # chart displayed below.
    viz_iris.title = 'Iris Classifications'
    viz_iris.encoding.x.title = column_map_iris['x']
    viz_iris.encoding.y.title = column_map_iris['y']
    viz_iris.encoding.color.title = column_map_iris['z']

    # Optional: log the labeled, rendered chart as an HTML artifact
    viz_iris.save('./example.html')
    mlflow.log_artifact(local_path='./example.html', artifact_path='charts')

viz_iris

  and should_run_async(code)


#### Load and Reuse Iris Visualization for Diamonds

In [9]:
# Locate the 'viz' model logged in the iris run and load it back with the
# vegalite style.
# The artifact root is a URI, not a local path, so join with '/' explicitly:
# os.path.join would insert '\' on Windows and produce an invalid URI.
artifact_root = run_iris.info.artifact_uri
model_uri = artifact_root + ('viz' if artifact_root.endswith('/') else '/viz')
loaded_viz_iris = mlflow_vismod.load_model(
    model_uri=model_uri,
    style='vegalite',
)

#### Update visualization data object:

Just as an ML model must be able to accept new instances of data for inference, and those instances must be compliant with the model in terms of features, a visualization must be able to accept and render new instances of compliant data.

In [10]:
# Example MLflow Model Run
# - vizmod flavor
# - vegalite `style`
with mlflow.start_run() as run_diamond:
    # Reuse Iris Viz: render the loaded iris chart against the diamond data,
    # which shares the same generic x/y/z columns.
    viz_diamond = loaded_viz_iris.display(df_diamond)

    # Label Chart: swap the generic x,y,z titles for diamond-specific names
    # before logging, so the stored model and HTML artifact are labeled.
    viz_diamond.title = 'Diamond Classifications'
    for channel, generic_name in (('x', 'x'), ('y', 'y'), ('color', 'z')):
        getattr(viz_diamond.encoding, channel).title = column_map_diamond[generic_name]

    # Log the labeled chart as a vizmod model
    mlflow_vismod.log_model(
        model=viz_diamond,
        artifact_path='viz',
        style='vegalite',
        input_example=df_diamond,
    )

    # Optional: log the rendered chart as an HTML artifact
    viz_diamond.save('./example.html')
    mlflow.log_artifact(local_path='./example.html', artifact_path='charts')

viz_diamond

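#### Register visualization model and artifacts:

Send logged visualizations to the MLflow registry for use in development, orchestration, and deployment. (The cell below only displays the two logged charts side by side. Registering them, e.g. with `mlflow.register_model`, is not performed in this notebook.)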

In [11]:
# Show the labeled iris and diamond charts side by side
alt.hconcat(viz_iris, viz_diamond)

## Typical Model Styles:

- **ggplot:** an R library based on the grammar of graphics, which is particularly effective for leveraging R dataframes.
- **Altair:** a declarative Python library that serves as an API to Vega-Lite, a declarative JS-based visualization library based on the grammar of graphics. This library is useful for simple to complex visualizations leveraging pandas, for display in Python environments or export to HTML for display on the web.
- **matplotlib/seaborn:** popular Python visualization libraries

## Example Use Cases 
 
 - Visualize Performance Metrics and Register With ML Model:
     + create a scatter plot of predicted vs. ground truth and color by correctness
     + use the display function to wrap an example in HTML
     + register the visualization and HTML artifact along with the model
     + another developer or user can see the diagnostic in the registry and pull down the model plus visualization for reuse and extension

In [12]:
# Classifier Predictions on Iris Dataset
# Elementwise comparison marks each held-out sample True (correct) or False.
# The result is aligned to y_test_iris and keeps its 'target' column name, so
# column_map_iris still selects it as the z (color) column.
iris_prediction_correct = clf.predict(X_test_iris) == y_test_iris

df_iris_predictions = (
    pd.concat([X_test_iris, iris_prediction_correct], axis=1, sort=False)
    .loc[:, list(column_map_iris.values())]
    .rename(columns={src: dst for dst, src in column_map_iris.items()})
)

df_iris_predictions.head()

Unnamed: 0,x,y,z
73,6.1,2.8,True
18,5.7,3.8,True
118,7.7,2.6,True
78,6.0,2.9,True
76,6.8,2.8,True


In [13]:
# Reuse the loaded iris viz model to plot held-out predictions colored by correctness
viz_iris_predictions = loaded_viz_iris.display(df_iris_predictions)

viz_iris_predictions.title = 'Iris Classification Predictions'
prediction_channel_titles = {
    'x': column_map_iris['x'],
    'y': column_map_iris['y'],
    'color': 'predicted correctly',
}
for channel, label in prediction_channel_titles.items():
    getattr(viz_iris_predictions.encoding, channel).title = label

viz_iris_predictions

  and should_run_async(code)


- Visualize a key message of a model for end-user consumption:
    + Pull a generalized model from a registry (e.g. a scatter plot that accepts x, y, and color-by variables)
    + Point the model at a compliant instance of data or an entire data set to visualize
    + Deploy the rendered model to the end user
    
    Note: in scenarios like this, the data set could be dynamic (new data points coming in), and the code associated with the visualization could be an entire transformation pipeline, just as sklearn or other models can be associated with data transformations as part of sklearn pipeline objects. This is particularly effective with Altair models, which can include pandas and related pipelines out of the box.


- Deploy an interactive web visualization: this use case is particularly effective with Vega-Lite, since models are encoded as JSON objects that accept pointers to data. This means you can store the JSON object as an artifact and use the visualization API to point it at new data.
    + Encode a visualization (e.g. a geo map of ICU bed utilization and COVID cases for each county in the USA, with size/color encodings and tooltips that help the end user decide where patients or resources should be sent)
    + register the visualization
    + pull the visualization and point it at a compliant data set of interest
    + deploy to a website (the website would already have Panel or some other mechanism for accepting the registered visualization)
    + alter the visualization as needed, using registry tags to point to the new version of the model

In [14]:
import altair_viewer


# Render the prediction chart through altair_viewer instead of the notebook's
# inline Altair renderer.
# NOTE(review): altair_viewer presumably serves the chart from a local server or
# browser tab. Confirm the behavior for the installed version.
altair_viewer.display(viz_iris_predictions)