# Library Imports

In [3]:
# Standard Libraries
import os

# External Libraries
import click
import pandas as pd
import altair as alt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import mlflow.sklearn
from mlflow.models.signature import infer_signature
import altair as alt
import altair_viewer

# Internal Libraries
import mlflow_vismod

# Constants

In [4]:
# MLflow
MLFLOW_TRACKING_URI = 'http://localhost:5000'
EXPERIMENT = 'iris'

# Sklearn Model
TEST_SIZE = 0.33
RANDOM_STATE = 42

  and should_run_async(code)


# Configurations

In [5]:
os.environ['MLFLOW_TRACKING_URI'] = MLFLOW_TRACKING_URI
mlflow.set_experiment(EXPERIMENT)

# Data Preparation

In [6]:
# Iris
iris = datasets.load_iris(as_frame=True, )
X_train, X_test, y_train, y_test = train_test_split(
    iris['data'],
    iris['target'],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

In [17]:
# Diamonds
diamonds = pd.read_csv('./diamonds.csv')
diamonds

Unnamed: 0.1,Unnamed: 0,carat,cut,color,clarity,depth,table,price,x,y,z
0,1,0.23,Ideal,E,SI2,61.5,55.0,326,3.95,3.98,2.43
1,2,0.21,Premium,E,SI1,59.8,61.0,326,3.89,3.84,2.31
2,3,0.23,Good,E,VS1,56.9,65.0,327,4.05,4.07,2.31
3,4,0.29,Premium,I,VS2,62.4,58.0,334,4.20,4.23,2.63
4,5,0.31,Good,J,SI2,63.3,58.0,335,4.34,4.35,2.75
...,...,...,...,...,...,...,...,...,...,...,...
53935,53936,0.72,Ideal,D,SI1,60.8,57.0,2757,5.75,5.76,3.50
53936,53937,0.72,Good,D,SI1,63.1,55.0,2757,5.69,5.75,3.61
53937,53938,0.70,Very Good,D,SI1,62.8,60.0,2757,5.66,5.68,3.56
53938,53939,0.86,Premium,H,SI2,61.0,58.0,2757,6.15,6.12,3.74


# 3) Save/Load

In [33]:
with mlflow.start_run() as run:
    # Define Viz
    viz = alt.Chart(
        pd.concat([X_train, y_train], axis=1, sort=False)
    ).mark_circle(size=60).encode(
        x='sepal length (cm)',
        y='sepal width (cm)',
        color='target:N'
    ).interactive()
    
    # Log Model
    mlflow_vismod.log_model(
        model=viz, 
        artifact_path='viz',
        style='vegalite',
        signature=infer_signature(X_train, None),
        input_example=pd.concat([X_train, y_train], axis=1, sort=False),
    )

    
viz

  and should_run_async(code)


# 2) Change Data

# 2) Updating Viz

In [20]:
diamonds[['price', 'carat', 'color']]

  and should_run_async(code)


Unnamed: 0,price,carat,color
0,326,0.23,E
1,326,0.21,E
2,327,0.23,E
3,334,0.29,I
4,335,0.31,J
...,...,...,...
53935,2757,0.72,D
53936,2757,0.72,D
53937,2757,0.70,D
53938,2757,0.86,H


In [30]:
alt.data_transformers.disable_max_rows()
with mlflow.start_run() as run:
    # Define Viz
    viz = alt.Chart(
        diamonds[['price', 'carat', 'color']]
    ).mark_circle(size=60).encode(
        x=alt.X('carat', scale=alt.Scale(type='log')),
        y=alt.X('price', scale=alt.Scale(type='log')),
        color='color:N'
    )
    
    # Log Model
    mlflow_vismod.log_model(
        model=viz, 
        artifact_path='viz',
        style='vegalite',
        signature=infer_signature(X_train, None),
        input_example=pd.concat([X_train, y_train], axis=1, sort=False),
    )

    
viz

In [27]:
alt.data_transformers.disable_max_rows()
with mlflow.start_run() as run:
    # Define Viz
    viz = alt.Chart(
        diamonds[['price', 'carat', 'color']]
    ).mark_circle(size=60).encode(
        x='carat',
        y='price',
        color='color:N'
    )
    
    # Log Model
    mlflow_vismod.log_model(
        model=viz, 
        artifact_path='viz',
        style='vegalite',
        signature=infer_signature(X_train, None),
        input_example=pd.concat([X_train, y_train], axis=1, sort=False),
    )

    
viz

In [31]:
with mlflow.start_run() as run:
    # Define Viz
    viz = alt.Chart(
        pd.concat([X_train, y_train], axis=1, sort=False)
    ).mark_circle(size=60).encode(
        x='sepal length (cm)',
        y='sepal width (cm)',
        color='target:N'
    ).interactive()
    
    # Log Model
    mlflow_vismod.log_model(
        model=viz, 
        artifact_path='viz',
        style='vegalite',
        signature=infer_signature(X_train, None),
        input_example=pd.concat([X_train, y_train], axis=1, sort=False),
    )

    
viz