## Many Model Inference in Snowflake

This notebook accompanies the [Many Model Inference in Snowflake](https://quickstarts.snowflake.com/guide/many-model-inference-in-snowflake/index.html?index=..%2F..index#0) quickstart. In this notebook, we show how to use pretrained models to define a [partitioned custom model](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/partitioned-custom-models) in Snowflake. At inference time, the model chooses which pretrained model to apply based on the value of the partition column, which in our case is the station ID (`START_STATION_ID`).
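Conceptually, partitioned inference splits the input on the partition column and passes each slice to the model separately. The following sketch imitates those semantics locally with pandas. It assumes toy lambda "models" and a hypothetical feature `X` in place of the per-station pretrained models, and it illustrates the idea rather than how Snowflake executes it internally:

```python
import pandas as pd

# Hypothetical stand-in "models": one callable per station ID.
# In the notebook, each station has its own pretrained model instead.
models = {
    101: lambda df: df["X"] * 2.0,
    202: lambda df: df["X"] + 10.0,
}

data = pd.DataFrame({
    "START_STATION_ID": [101, 101, 202, 202, 202],
    "X": [1.0, 2.0, 3.0, 4.0, 5.0],  # hypothetical feature column
})


def predict_partition(part: pd.DataFrame) -> pd.DataFrame:
    # Every row in `part` shares one START_STATION_ID, so exactly one model is chosen.
    station = part["START_STATION_ID"].iloc[0]
    return pd.DataFrame({
        "START_STATION_ID": part["START_STATION_ID"].to_numpy(),
        "DURATION": models[station](part).to_numpy(),
    })


# Split by the partition column, run each slice through its own model, then stack the results.
result = pd.concat(
    [predict_partition(part) for _, part in data.groupby("START_STATION_ID")],
    ignore_index=True,
)
print(result)
```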

We start with imports and define the session and constants. Before starting, add `snowflake-ml-python` from the packages dropdown.

In [None]:
from snowflake.snowpark import Session
from snowflake.ml.model import custom_model
from snowflake.ml.registry import registry

from typing import Optional
import warnings
import pandas as pd

from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Add a query tag to the session. This helps with performance monitoring and troubleshooting.
session.query_tag = {"origin":"sf_sit-is", 
                     "name":"partitioned_models_stateful", 
                     "version":{"major":1, "minor":0},
                     "attributes":{"is_quickstart":1, "source":"notebook"}}

In [None]:
DATABASE = session.get_current_database()
SCHEMA = session.get_current_schema()

_INPUT_COLS = ['WEEKDAY_0', 'WEEKDAY_1', 'WEEKDAY_2', 'WEEKDAY_3', 'WEEKDAY_4',
               'WEEKDAY_5', 'WEEKDAY_6', 'HOUR_0', 'HOUR_1', 'HOUR_2', 'HOUR_3',
               'HOUR_4', 'HOUR_5', 'HOUR_6', 'HOUR_7', 'HOUR_8', 'HOUR_9',
               'HOUR_10', 'HOUR_11', 'HOUR_12', 'HOUR_13', 'HOUR_14', 'HOUR_15',
               'HOUR_16', 'HOUR_17', 'HOUR_18', 'HOUR_19', 'HOUR_20', 'HOUR_21',
               'HOUR_22', 'HOUR_23', 'USERTYPE_Customer', 'USERTYPE_Subscriber', 
               'GENDER_1', 'GENDER_2', 'GENDER_0']

### Define the Partitioned Model

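We now define the custom model. The partitioned custom model class inherits from `snowflake.ml.model.custom_model.CustomModel`, and its inference methods are declared with the `@custom_model.partitioned_inference_api` decorator. Each call to `predict` receives the rows for a single station; the class deserializes that station's model from the `MODEL_PICKLE_BYTES` column and caches it, so the model is only reloaded when the station ID changes.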
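The load-and-cache pattern used in `predict` can be exercised locally without Snowflake. This is a minimal sketch that assumes a hypothetical `ConstantModel` in place of the real pretrained models, plus a hypothetical `CachedModelLoader` that copies the caching logic from the class:

```python
import pickle
import pandas as pd


class ConstantModel:
    """Hypothetical toy model that predicts a fixed duration for every row."""

    def __init__(self, value: float) -> None:
        self.value = value

    def predict(self, X: pd.DataFrame) -> list:
        return [self.value] * len(X)


class CachedModelLoader:
    """Mirrors the caching in predict(): deserialize only when the partition changes."""

    def __init__(self) -> None:
        self.partition_id = None
        self.model = None
        self.load_count = 0

    def get(self, partition: pd.DataFrame):
        station = partition["START_STATION_ID"].iloc[0]
        if self.partition_id != station:
            self.partition_id = station
            # All rows in a partition carry the same bytes, so the first row is enough.
            self.model = pickle.loads(partition["MODEL_PICKLE_BYTES"].iloc[0])
            self.load_count += 1
        return self.model


def make_partition(station: int, value: float, n: int) -> pd.DataFrame:
    # Build a toy partition whose rows all carry the same pickled model.
    blob = pickle.dumps(ConstantModel(value))
    return pd.DataFrame({
        "START_STATION_ID": [station] * n,
        "HOUR_8": [True] * n,
        "MODEL_PICKLE_BYTES": [blob] * n,
    })


loader = CachedModelLoader()
p7 = make_partition(7, 12.5, 3)
p9 = make_partition(9, 30.0, 2)

first = loader.get(p7).predict(p7[["HOUR_8"]])
again = loader.get(p7).predict(p7[["HOUR_8"]])  # same station: no reload
other = loader.get(p9).predict(p9[["HOUR_8"]])  # new station: reload
print(first, other, loader.load_count)
```

Caching on the partition ID avoids paying the deserialization cost more than once per station when the same instance handles repeated calls for that partition.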

In [None]:
import pickle

class BikeTripDurationForecastingModelPickleInput(custom_model.CustomModel):
    def __init__(self, context: Optional[custom_model.ModelContext] = None) -> None:
        super().__init__(context)
        self.partition_id = None
        self.model = None

    @custom_model.partitioned_inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        input_cols = _INPUT_COLS

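        # All rows in a partition share one station ID; reload the model only when it changes
        station_id = input['START_STATION_ID'].iloc[0]
        if self.partition_id != station_id:
            self.partition_id = station_id
            self.model = pickle.loads(input['MODEL_PICKLE_BYTES'].iloc[0])

        model_output = self.model.predict(input[input_cols])
        res = pd.DataFrame(model_output, columns=["DURATION"])
        # .to_numpy() avoids index misalignment; cast to str to match the STRING output signature
        res['START_STATION_ID_OUT'] = input['START_STATION_ID'].astype(str).to_numpy()
        return res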

In [None]:
m = BikeTripDurationForecastingModelPickleInput()

In [None]:
m

### Log Model to Model Registry

Next, we log the model to the Snowflake Model Registry. We first define the signature of the prediction method, then create a registry handle, and finally log the model.
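Before logging, it can help to check locally that a sample of the input matches the declared signature: boolean one-hot features, an integer `START_STATION_ID`, and bytes in `MODEL_PICKLE_BYTES`. This is a minimal pandas sketch. It rebuilds the column names of `_INPUT_COLS` so the block runs standalone and uses a hypothetical helper, `signature_problems`. It mirrors the BOOL/INT64/BYTES declarations but is not the registry's own validation:

```python
import pandas as pd

# Same names and order as _INPUT_COLS in this notebook (one-hot boolean features).
INPUT_COLS = (
    [f"WEEKDAY_{i}" for i in range(7)]
    + [f"HOUR_{i}" for i in range(24)]
    + ["USERTYPE_Customer", "USERTYPE_Subscriber", "GENDER_1", "GENDER_2", "GENDER_0"]
)


def signature_problems(df: pd.DataFrame) -> list:
    """Hypothetical helper: list mismatches against the declared input signature."""
    problems = []
    for col in INPUT_COLS:
        if col not in df.columns:
            problems.append(f"missing {col}")
        elif not pd.api.types.is_bool_dtype(df[col]):
            problems.append(f"{col} is {df[col].dtype}, expected bool")
    if "START_STATION_ID" not in df.columns:
        problems.append("missing START_STATION_ID")
    elif not pd.api.types.is_integer_dtype(df["START_STATION_ID"]):
        problems.append("START_STATION_ID is not an integer column")
    if "MODEL_PICKLE_BYTES" not in df.columns:
        problems.append("missing MODEL_PICKLE_BYTES")
    elif not all(isinstance(v, bytes) for v in df["MODEL_PICKLE_BYTES"]):
        problems.append("MODEL_PICKLE_BYTES contains non-bytes values")
    return problems


# One well-formed toy row, then a copy with a dropped column and a wrong dtype.
row = {c: [False] for c in INPUT_COLS}
row["HOUR_8"] = [True]
good = pd.DataFrame(row)
good["START_STATION_ID"] = [519]
good["MODEL_PICKLE_BYTES"] = [b"\x80\x04N."]

bad = good.drop(columns=["GENDER_0"]).astype({"HOUR_8": "int64"})
print(signature_problems(good), signature_problems(bad))
```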

In [None]:
from snowflake.ml.model.model_signature import FeatureSpec, DataType, ModelSignature


input_signature = [
    FeatureSpec(dtype=DataType.BOOL, name=n) for n in _INPUT_COLS
]
input_signature.append(
    FeatureSpec(dtype=DataType.BYTES, name='MODEL_PICKLE_BYTES')
)
input_signature.append(
    FeatureSpec(dtype=DataType.INT64, name='START_STATION_ID')
)

output_signature = [
    FeatureSpec(dtype=DataType.FLOAT, name='DURATION'),
    FeatureSpec(dtype=DataType.STRING, name='START_STATION_ID_OUT'),
]

signature = ModelSignature(
    inputs=input_signature,
    outputs=output_signature,
)

In [None]:
# Create a registry handle in the current database and schema
reg = registry.Registry(session=session, 
                        database_name=DATABASE, 
                        schema_name=SCHEMA)

In [None]:
reg.show_models()

In [None]:
options = {
    "function_type": "TABLE_FUNCTION",
    "relax_version": False
}

mv = reg.log_model(
    m,
    model_name="biketrip_duration_forecast_model",
    version_name="v1",
    options=options,
    conda_dependencies=["pandas", "xgboost"],
    signatures={"predict": signature}
)

### Run Inference

Finally, we run inference with the partitioned custom model. We load the input data created in the setup notebook, join each row with its station's pickled model, run inference partitioned by station, and save the results to a table in Snowflake.
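Because the inputs are joined to `MODELS_TABLE` with a left join, a station that has rows in `INPUT_DATA` but no trained model keeps those rows with a null `MODEL_PICKLE_BYTES`, which `pickle.loads` cannot deserialize. The following local pandas sketch reproduces that effect with made-up station IDs and toy pickled values, and shows one way to spot such stations before running inference:

```python
import pickle
import pandas as pd

# Toy inputs: station 3 has input rows but no model.
inputs = pd.DataFrame({
    "START_STATION_ID": [1, 1, 2, 3],
    "HOUR_8": [True, False, True, True],
})
models = pd.DataFrame({
    "START_STATION_ID": [1, 2],
    "MODEL_PICKLE_BYTES": [pickle.dumps("model-1"), pickle.dumps("model-2")],
})

# indicator=True adds a _merge column telling which side each row matched.
joined = inputs.merge(models, on="START_STATION_ID", how="left", indicator=True)

# Stations whose rows found no model (their MODEL_PICKLE_BYTES is null).
no_model = sorted(
    joined.loc[joined["_merge"] == "left_only", "START_STATION_ID"].unique().tolist()
)
# Rows that can actually be scored.
scorable = joined[joined["_merge"] == "both"].drop(columns="_merge")
print(no_model, len(scorable))
```

In Snowpark, one option for an equivalent guard is filtering `input_df` with `col("MODEL_PICKLE_BYTES").is_not_null()` (from `snowflake.snowpark.functions`) before calling `mv.run`.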

In [None]:
input_df = session.table(f"{DATABASE}.{SCHEMA}.INPUT_DATA")

model_bytes_table = session.table(f"{DATABASE}.{SCHEMA}.MODELS_TABLE")
input_df = input_df.join(model_bytes_table, on="START_STATION_ID", how="left")
input_df.show()

In [None]:
result = mv.run(input_df, partition_column="START_STATION_ID")

In [None]:
result.write.mode("overwrite").save_as_table("RESULTS")