## Many Model Inference in Snowflake

This notebook accompanies the [Many Model Inference in Snowflake](https://quickstarts.snowflake.com/guide/many-model-inference-in-snowflake/index.html?index=..%2F..index#0) quickstart. It shows how to use pretrained models to define a [partitioned custom model](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/partitioned-custom-models) in Snowflake. At inference time, the model selects which pretrained model to apply from the value in the partition column, which in our case is the station ID (`PULOCATIONID`).

We will start with imports and defining the session and constants. Please add `snowflake-ml-python` and `cloudpickle==2.2.1` from the packages dropdown before starting. 

In [None]:
from snowflake.snowpark import Session
from snowflake.ml.model import custom_model
from snowflake.ml.registry import registry

from typing import Optional
import warnings
import pandas as pd

from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Add a query tag to the session. This helps with performance monitoring and troubleshooting.
session.query_tag = {"origin":"sf_sit-is", 
                     "name":"partitioned_models_stateful", 
                     "version":{"major":1, "minor":0},
                     "attributes":{"is_quickstart":1, "source":"notebook"}}

In [None]:
DATABASE = session.get_current_database()
SCHEMA = session.get_current_schema()

_INPUT_COLS = ['PASSENGER_COUNT', 'TRIP_DISTANCE', 'FARE_AMOUNT', 
               'PAYMENT_TYPE_1', 'PAYMENT_TYPE_2', 'PAYMENT_TYPE_3',
               'PAYMENT_TYPE_4', 'PAYMENT_TYPE_5']

### Define the Partitioned Model

We will now define the custom model. The partitioned custom model class inherits from `snowflake.ml.model.custom_model.CustomModel`, and inference methods are declared with the `@custom_model.partitioned_inference_api` decorator. Writing the model this way allows it to run in parallel, one partition at a time.

In [None]:
import pickle

class TaxiForecastingModelPickleInput(custom_model.CustomModel):
    def __init__(self, context: Optional[custom_model.ModelContext] = None) -> None:
        super().__init__(context)
        self.partition_id = None
        self.model = None

    @custom_model.partitioned_inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        input_cols = _INPUT_COLS

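        # Reload the model only when the partition changes. Use .iloc so the
        # lookup is positional and does not depend on the index labels.
        partition_id = input['PULOCATIONID'].iloc[0]
        if self.partition_id != partition_id:
            self.partition_id = partition_id
            self.model = pickle.loads(input['MODEL_PICKLE_BYTES'].iloc[0])

        model_output = self.model.predict(input[input_cols])
        res = pd.DataFrame(model_output, columns=["TOTAL_AMOUNT"])
        # Echo the partition key so each prediction can be traced to its station.
        res['PULOCATIONID_OUT'] = input['PULOCATIONID'].to_numpy()
        return res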
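
The caching in `predict` can be tried outside Snowflake. The sketch below is a plain-Python analogue, not the notebook's model: `CachingPredictor`, `make_model_bytes`, and the coefficient dictionaries are hypothetical stand-ins for the pickled models, and rows are dictionaries rather than a DataFrame. It only illustrates that a model is deserialized when the partition key changes and reused otherwise.

```python
import pickle


# Hypothetical stand-in for a trained model: plain coefficients pickled to bytes.
def make_model_bytes(slope: float, intercept: float) -> bytes:
    return pickle.dumps({"slope": slope, "intercept": intercept})


class CachingPredictor:
    """Plain-Python analogue of the caching logic in `predict`."""

    def __init__(self) -> None:
        self.partition_id = None
        self.model = None
        self.load_count = 0  # how many times a model was deserialized

    def predict(self, rows: list) -> list:
        first = rows[0]
        # Deserialize only when the partition key differs from the cached one.
        if self.partition_id != first["PULOCATIONID"]:
            self.partition_id = first["PULOCATIONID"]
            self.model = pickle.loads(first["MODEL_PICKLE_BYTES"])
            self.load_count += 1
        return [
            {
                "TOTAL_AMOUNT": self.model["slope"] * r["FARE_AMOUNT"]
                + self.model["intercept"],
                "PULOCATIONID_OUT": r["PULOCATIONID"],
            }
            for r in rows
        ]


bytes_1 = make_model_bytes(1.0, 2.0)
bytes_2 = make_model_bytes(2.0, 0.5)

predictor = CachingPredictor()
# Two batches from station 1, then one batch from station 2.
out_a = predictor.predict(
    [{"PULOCATIONID": 1, "FARE_AMOUNT": 10.0, "MODEL_PICKLE_BYTES": bytes_1}]
)
out_b = predictor.predict(
    [{"PULOCATIONID": 1, "FARE_AMOUNT": 20.0, "MODEL_PICKLE_BYTES": bytes_1}]
)
out_c = predictor.predict(
    [{"PULOCATIONID": 2, "FARE_AMOUNT": 5.0, "MODEL_PICKLE_BYTES": bytes_2}]
)
print(predictor.load_count)
```

The second batch for station 1 reuses the cached model, so only two deserializations happen across the three calls.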

In [None]:
m = TaxiForecastingModelPickleInput()

In [None]:
m

### Log Model to Model Registry

Next we will log the model to Snowflake Model Registry. We will first define the signature for our prediction method, then define the registry, and finally log the model.

In [None]:
from snowflake.ml.model.model_signature import FeatureSpec, DataType, ModelSignature

input_signature = [
  FeatureSpec(dtype=DataType.INT64, name='PULOCATIONID', nullable=True),
  FeatureSpec(dtype=DataType.INT64, name='PASSENGER_COUNT', nullable=True),
  FeatureSpec(dtype=DataType.DOUBLE, name='TRIP_DISTANCE', nullable=True),
  FeatureSpec(dtype=DataType.DOUBLE, name='FARE_AMOUNT', nullable=True),
  FeatureSpec(dtype=DataType.BOOL, name='PAYMENT_TYPE_1', nullable=True),
  FeatureSpec(dtype=DataType.BOOL, name='PAYMENT_TYPE_2', nullable=True),
  FeatureSpec(dtype=DataType.BOOL, name='PAYMENT_TYPE_3', nullable=True),
  FeatureSpec(dtype=DataType.BOOL, name='PAYMENT_TYPE_4', nullable=True),
  FeatureSpec(dtype=DataType.BOOL, name='PAYMENT_TYPE_5', nullable=True),
  FeatureSpec(dtype=DataType.BYTES, name='MODEL_PICKLE_BYTES', nullable=True),
]

output_signature = [
    FeatureSpec(dtype=DataType.FLOAT, name='TOTAL_AMOUNT'),
    FeatureSpec(dtype=DataType.STRING, name='PULOCATIONID_OUT'),
]

signature = ModelSignature(
    inputs=input_signature,
    outputs=output_signature,
)
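
The `predict` method reads `PULOCATIONID`, `MODEL_PICKLE_BYTES`, and every column in `_INPUT_COLS`, so each of those names should appear in the input signature. As a minimal sanity check, assuming the names are copied by hand into plain lists, the sketch below reports any column that `predict` needs but the signature does not declare. The helper `missing_from_signature` is hypothetical and not part of `snowflake-ml-python`.

```python
# Column names copied from the notebook as plain strings, so this check
# runs without a Snowflake session.
feature_cols = [
    "PASSENGER_COUNT", "TRIP_DISTANCE", "FARE_AMOUNT",
    "PAYMENT_TYPE_1", "PAYMENT_TYPE_2", "PAYMENT_TYPE_3",
    "PAYMENT_TYPE_4", "PAYMENT_TYPE_5",
]
signature_input_names = [
    "PULOCATIONID", "PASSENGER_COUNT", "TRIP_DISTANCE", "FARE_AMOUNT",
    "PAYMENT_TYPE_1", "PAYMENT_TYPE_2", "PAYMENT_TYPE_3",
    "PAYMENT_TYPE_4", "PAYMENT_TYPE_5", "MODEL_PICKLE_BYTES",
]

# Everything `predict` reads from its input frame.
columns_read_by_predict = ["PULOCATIONID", "MODEL_PICKLE_BYTES"] + feature_cols


def missing_from_signature(required: list, declared: list) -> list:
    """Return the required names that the signature does not declare."""
    declared_set = set(declared)
    return [name for name in required if name not in declared_set]


missing = missing_from_signature(columns_read_by_predict, signature_input_names)
print(missing)
```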

In [None]:
# Open the model registry in the current database and schema
reg = registry.Registry(session=session, 
                        database_name=DATABASE, 
                        schema_name=SCHEMA)

In [None]:
reg.show_models()

In [None]:
options = {
    "function_type": "TABLE_FUNCTION",
    "relax_version": False
}

mv = reg.log_model(
    m,
    model_name="taxi_total_amount_forecast_model",
    version_name="v1",
    options=options,
    conda_dependencies=["pandas", "xgboost", "cloudpickle==2.2.1"], # the notebook's cloudpickle version should also be greater than 2.0.0
    signatures={"predict": signature}
)

### Run Inference

Finally, we will run inference using our custom partitioned model. We will pull the input data we defined in the setup notebook, then run inference, and save the results to a table in Snowflake. 

In [None]:
input_df = session.table(f"{DATABASE}.{SCHEMA}.INPUT_DATA")

model_bytes_table = session.table(f"{DATABASE}.{SCHEMA}.MODELS_TABLE")
# Snowpark's DataFrame.join takes the join type through `how`
input_df = input_df.join(model_bytes_table, on="PULOCATIONID", how="left")
input_df.show()
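
With a left join, input rows whose station has no row in the models table keep a NULL `MODEL_PICKLE_BYTES`, which `pickle.loads` would likely reject at inference time. The sketch below uses made-up rows and a hypothetical `left_join` helper in plain Python (not Snowpark) to show how such stations can be spotted before running inference.

```python
def left_join(left_rows: list, right_rows: list, key: str) -> list:
    """Left join for lists of dicts; assumes `key` is unique in `right_rows`."""
    lookup = {r[key]: r for r in right_rows}
    right_cols = {c for r in right_rows for c in r if c != key}
    joined = []
    for row in left_rows:
        match = lookup.get(row[key], {})
        merged = dict(row)
        for col in right_cols:
            merged[col] = match.get(col)  # None when no right row matches
        joined.append(merged)
    return joined


# Made-up data: station 3 has trips but no stored model.
input_rows = [
    {"PULOCATIONID": 1, "FARE_AMOUNT": 10.0},
    {"PULOCATIONID": 2, "FARE_AMOUNT": 7.5},
    {"PULOCATIONID": 3, "FARE_AMOUNT": 4.0},
]
model_rows = [
    {"PULOCATIONID": 1, "MODEL_PICKLE_BYTES": b"model-1"},
    {"PULOCATIONID": 2, "MODEL_PICKLE_BYTES": b"model-2"},
]

joined = left_join(input_rows, model_rows, "PULOCATIONID")
stations_without_model = sorted(
    {r["PULOCATIONID"] for r in joined if r["MODEL_PICKLE_BYTES"] is None}
)
print(stations_without_model)
```

An inner join would drop those rows instead, so the choice depends on whether unmatched stations should be reported or silently skipped.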

Let's see how many distinct stations there are in the test data. Each station has its own pretrained model, which the join above attached to the station's rows. When we run inference, each station's partition is scored with its own model, and partitions run in parallel. This speeds up inference and means every prediction comes from a model trained only on that station's data.

In [None]:
input_df.select("PULOCATIONID").distinct().count()

We will now run inference on the entire input dataframe. Because the model is partitioned, the data is split by pickup location (`PULOCATIONID`), and each partition is scored with the model we joined in from the models table.

In [None]:
result = mv.run(input_df, partition_column="PULOCATIONID")
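
Conceptually, a partitioned run groups the rows by the partition column, scores each group independently, and combines the outputs. The sketch below mimics that flow locally with a thread pool. It is an illustration only and says nothing about how Snowflake schedules the work; `partition_rows`, `score_partition`, and the surcharge "models" are hypothetical.

```python
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor


def partition_rows(rows: list, key: str) -> dict:
    """Group rows by the value of `key`, keeping first-seen order."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row)
    return dict(groups)


# Hypothetical per-station "models": a flat surcharge added to the fare.
surcharge_by_station = {1: 2.0, 2: 3.5}


def score_partition(item: tuple) -> list:
    """Score one partition with the model that belongs to its station."""
    station_id, rows = item
    surcharge = surcharge_by_station[station_id]
    return [
        {"PULOCATIONID_OUT": station_id, "TOTAL_AMOUNT": r["FARE_AMOUNT"] + surcharge}
        for r in rows
    ]


rows = [
    {"PULOCATIONID": 1, "FARE_AMOUNT": 10.0},
    {"PULOCATIONID": 2, "FARE_AMOUNT": 5.0},
    {"PULOCATIONID": 1, "FARE_AMOUNT": 20.0},
]

partitions = partition_rows(rows, "PULOCATIONID")
with ThreadPoolExecutor() as pool:
    # pool.map returns results in the same order as the input partitions.
    per_partition = list(pool.map(score_partition, partitions.items()))

results = [r for part in per_partition for r in part]
print(results)
```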

Finally, we will save the results to a table and view them.

In [None]:
result.write.mode("overwrite").save_as_table("RESULTS")

In [None]:
result_df = session.table("RESULTS")
result_df.show()