# Introduction

This demo shows how to log an explainable model in the Snowflake Model Registry. It is adapted from the [Quickstart Guide](https://quickstarts.snowflake.com/guide/intro_to_machine_learning_with,_snowpark_ml_for_python/#5), which trains a model to predict the price of a diamond from its size, color, and cut.

## Setup Variables

In [2]:
from snowflake.snowpark import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

session = Session.builder.configs({
    **SnowflakeLoginOptions(connection_name="<connection_name>"),
    "role": "<role>",
    "database": "<database>",
    "schema": "<schema>",
    "warehouse": "<warehouse>",
}).create()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [3]:
TABLE = "DIAMONDS_DATA"

## Getting the data

In [4]:
train_df = session.table(TABLE)
train_pdf = train_df.to_pandas()  # loads the whole table into a local pandas DataFrame

In [5]:
CATEGORICAL_COLUMNS = ["CUT", "COLOR", "CLARITY"]
NUMERICAL_COLUMNS = ["CARAT", "DEPTH", "TABLE_PCT", "X", "Y", "Z"]

LABEL_COLUMNS = ["PRICE"]
OUTPUT_COLUMNS = ["PREDICTED_PRICE"]

FEATURES = CATEGORICAL_COLUMNS + NUMERICAL_COLUMNS

In [7]:
xs = train_pdf[FEATURES]
ys = train_pdf[LABEL_COLUMNS]
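
Before training, it can help to confirm that the feature frame has the expected shape, since CatBoost rejects null or floating-point values in categorical columns. The following is a minimal sketch, not part of the original notebook: `check_feature_frame` is a hypothetical helper, and the two sample rows are made-up values used only for illustration.

```python
import pandas as pd

CATEGORICAL_COLUMNS = ["CUT", "COLOR", "CLARITY"]
NUMERICAL_COLUMNS = ["CARAT", "DEPTH", "TABLE_PCT", "X", "Y", "Z"]


def check_feature_frame(df, categorical, numerical):
    """Return a list of human-readable problems found in the feature frame."""
    problems = []
    # Every expected feature column should be present.
    for col in categorical + numerical:
        if col not in df.columns:
            problems.append(f"missing column: {col}")
    # Categorical columns should not contain nulls.
    for col in categorical:
        if col in df.columns and df[col].isna().any():
            problems.append(f"null values in categorical column: {col}")
    # Numerical columns should have a numeric dtype.
    for col in numerical:
        if col in df.columns and not pd.api.types.is_numeric_dtype(df[col]):
            problems.append(f"non-numeric dtype in numerical column: {col}")
    return problems


# Made-up rows standing in for train_pdf[FEATURES] (illustrative values only).
sample = pd.DataFrame({
    "CUT": ["IDEAL", "PREMIUM"],
    "COLOR": ["E", "E"],
    "CLARITY": ["SI2", "SI1"],
    "CARAT": [0.23, 0.21],
    "DEPTH": [61.5, 59.8],
    "TABLE_PCT": [55.0, 61.0],
    "X": [3.95, 3.89],
    "Y": [3.98, 3.84],
    "Z": [2.43, 2.31],
})

problems = check_feature_frame(sample, CATEGORICAL_COLUMNS, NUMERICAL_COLUMNS)
print(problems)
```

In the notebook itself, the same helper could be called as `check_feature_frame(xs, CATEGORICAL_COLUMNS, NUMERICAL_COLUMNS)`.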

## Training

In [10]:
from catboost import CatBoostRegressor

# Ten boosting iterations keep the demo fast; CUT, COLOR, and CLARITY are treated as categorical.
catboost_model = CatBoostRegressor(iterations=10, cat_features=CATEGORICAL_COLUMNS)

catboost_model.fit(xs, ys)

Learning rate set to 0.5
0:	learn: 2382.0261721	total: 61.4ms	remaining: 553ms
1:	learn: 1718.2563272	total: 65.4ms	remaining: 262ms
2:	learn: 1491.7641987	total: 69.4ms	remaining: 162ms
3:	learn: 1204.1808573	total: 72.8ms	remaining: 109ms
4:	learn: 989.8381674	total: 76.3ms	remaining: 76.3ms
5:	learn: 861.3793338	total: 78.9ms	remaining: 52.6ms
6:	learn: 790.6427205	total: 81.8ms	remaining: 35ms
7:	learn: 738.7260860	total: 84.3ms	remaining: 21.1ms
8:	learn: 708.9761048	total: 87.7ms	remaining: 9.74ms
9:	learn: 687.4288655	total: 90.4ms	remaining: 0us


<catboost.core.CatBoostRegressor at 0x12ff0aed0>

## Adding to Model Registry

When calling `log_model`, set the `enable_explainability` option to `True`.

In [11]:
from snowflake.ml.registry import Registry

reg = Registry(
    session=session,
    database_name=session.get_current_database(),
    schema_name=session.get_current_schema(),
)

In [12]:
mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,  # used to infer the model signature
    options={"enable_explainability": True},  # also register the explain method
)



The logged model version now exposes a new function, `EXPLAIN`, alongside its inference methods.

In [13]:
mv.show_functions()

[{'name': 'EXPLAIN',
  'target_method': 'explain',
  'target_method_function_type': 'TABLE_FUNCTION',
  'signature': ModelSignature(
                      inputs=[
                          FeatureSpec(dtype=DataType.STRING, name='CUT'),
                          FeatureSpec(dtype=DataType.STRING, name='COLOR'),
                          FeatureSpec(dtype=DataType.STRING, name='CLARITY'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='CARAT'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='DEPTH'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='TABLE_PCT'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='X'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='Y'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='Z')
                      ],
                      outputs=[
                          FeatureSpec(dtype=DataType.DOUBLE, name='CUT_explanation'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='COLOR_explanation'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='CLARITY_explanation'),
                          FeatureSpec(dtype=DataType.DOUBLE, name='CARAT_explanati
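
Because `show_functions()` returns a list of dictionaries, the available methods can also be checked programmatically. The sketch below is an illustration only: `functions` is a hand-written stand-in that reuses the keys visible in the output above, the `PREDICT` entry is an assumption based on the `predict` call used later in this notebook, and `target_methods` is a hypothetical helper.

```python
# Hand-written stand-in for the list returned by mv.show_functions().
# Only keys visible in the notebook output are used; the PREDICT entry is assumed.
functions = [
    {
        "name": "EXPLAIN",
        "target_method": "explain",
        "target_method_function_type": "TABLE_FUNCTION",
    },
    {"name": "PREDICT", "target_method": "predict"},  # assumed entry
]


def target_methods(function_infos):
    """Return the sorted target method names from a show_functions()-style list."""
    return sorted(info["target_method"] for info in function_infos)


methods = target_methods(functions)
explain_available = "explain" in methods
print(methods, explain_available)
```

In the notebook, `target_methods(mv.show_functions())` would be the equivalent call, assuming the real entries carry the same `target_method` key.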

In [14]:
mv.run(xs, function_name="predict")

Unnamed: 0,output_feature_0
0,596.916301
1,638.708848
2,779.811544
3,555.369615
4,481.385219
...,...
53935,2973.209622
53936,2650.645147
53937,2602.661167
53938,2193.898369


## Calling `EXPLAIN`

In [15]:
mv.run(xs, function_name="explain")

Unnamed: 0,CUT_explanation,COLOR_explanation,CLARITY_explanation,CARAT_explanation,DEPTH_explanation,TABLE_PCT_explanation,X_explanation,Y_explanation,Z_explanation
0,0.0,166.760619,-449.627312,-1206.348313,1.619296,-18.104913,-802.440106,-292.247123,-735.327481
1,0.0,168.778358,-201.601291,-1338.039216,1.755842,10.650137,-789.192082,-384.523327,-761.751208
2,0.0,209.698547,249.776097,-1442.396141,1.755842,10.650137,-844.634118,-384.544114,-953.126340
3,0.0,-370.175535,134.006603,-1378.787863,1.755842,10.650137,-774.617378,-384.700362,-615.393463
4,0.0,-470.804986,-370.984831,-1077.890915,1.755842,10.650137,-786.953160,-169.379700,-587.638802
...,...,...,...,...,...,...,...,...,...
53935,0.0,330.942653,-271.085590,-1156.525957,4.640223,3.257916,857.326257,-470.480515,-257.497000
53936,0.0,330.942653,-271.085590,-343.077142,4.671432,-5.538360,-504.289171,-420.764317,-72.845993
53937,0.0,330.942653,-271.085590,-343.077142,4.640223,3.257916,-504.289171,-422.475924,-127.883434
53938,0.0,-255.983618,-665.091902,-940.822052,4.640223,3.257916,849.681084,-496.931863,-237.483054
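
The attributions returned by `explain` can be sanity-checked locally. The sketch below assumes (the notebook does not say so) that they are additive, Shapley-style values, in which case each prediction minus the sum of its row's attributions should give the same baseline for every row. It uses only the first three rows of the `predict` and `explain` outputs shown above; `implied_baseline` and `importance` are names introduced here for illustration.

```python
import pandas as pd

FEATURES = ["CUT", "COLOR", "CLARITY", "CARAT", "DEPTH", "TABLE_PCT", "X", "Y", "Z"]
EXPLANATION_COLUMNS = [f"{name}_explanation" for name in FEATURES]

# First three rows of the predict output shown earlier (output_feature_0).
predictions = pd.Series([596.916301, 638.708848, 779.811544])

# First three rows of the explain output shown above.
explanations = pd.DataFrame(
    [
        [0.0, 166.760619, -449.627312, -1206.348313, 1.619296,
         -18.104913, -802.440106, -292.247123, -735.327481],
        [0.0, 168.778358, -201.601291, -1338.039216, 1.755842,
         10.650137, -789.192082, -384.523327, -761.751208],
        [0.0, 209.698547, 249.776097, -1442.396141, 1.755842,
         10.650137, -844.634118, -384.544114, -953.126340],
    ],
    columns=EXPLANATION_COLUMNS,
)

# Under the additivity assumption: prediction = baseline + sum of attributions,
# so the difference below should be (nearly) identical across rows.
implied_baseline = predictions - explanations.sum(axis=1)
print(implied_baseline.round(2).tolist())

# Rank features by mean absolute attribution over the rows at hand.
importance = explanations.abs().mean().sort_values(ascending=False)
print(importance.index[0])
```

For these three rows the implied baseline agrees to within rounding (about 3932.63), which is consistent with additive attributions. The ranking over three rows is only illustrative; applying the same `abs().mean()` expression to the full `explain` output would give a global view of which features drive the predictions.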
