## Many Model Inference in Snowflake -- Setup

This notebook sets up all models required for the [Many Model Inference in Snowflake](https://quickstarts.snowflake.com/guide/many-model-inference-in-snowflake/index.html?index=..%2F..index#0) quickstart. In this notebook, we will train a distinct XGBoost model for each pickup zone in the New York taxi network. These models will be used in part 2 of the quickstart, where we will run inference partitioned by pickup zone.

We will start with the imports and define the session. Please add `snowflake-ml-python` from the packages dropdown before starting.

In [None]:
from snowflake.snowpark import Session
import warnings
import pandas as pd
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Add a query tag to the session. This helps with performance monitoring and troubleshooting.
session.query_tag = {"origin":"sf_sit-is", 
                     "name":"partitioned_models_stateful", 
                     "version":{"major":1, "minor":0},
                     "attributes":{"is_quickstart":1, "source":"notebook"}}

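
The cell above assigns the query tag as a Python dictionary. If you would rather store the tag as a JSON string, so that it can be parsed later when you inspect query history, a minimal sketch using only the standard library could look like the following. The final assignment is left as a comment because it needs a live session, and whether your Snowpark version handles a dictionary and a string identically is not verified here.

```python
import json

# Same tag contents as in the setup cell.
query_tag = {
    "origin": "sf_sit-is",
    "name": "partitioned_models_stateful",
    "version": {"major": 1, "minor": 0},
    "attributes": {"is_quickstart": 1, "source": "notebook"},
}

# Serialize to a compact JSON string that round-trips through json.loads.
query_tag_json = json.dumps(query_tag, separators=(",", ":"))
print(query_tag_json)

# Assumption: run inside the notebook, where `session` already exists.
# session.query_tag = query_tag_json
```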


### Feature Engineering

We define the source data from Snowflake ML New York Taxi example data. Then we will prepare the data for training.

In [4]:
from snowflake.ml.feature_store.examples.example_helper import ExampleHelper

example_helper = ExampleHelper(session, 
                               session.get_current_database(), 
                               session.get_current_schema())
source_tables = example_helper.load_example('new_york_taxi_features')
source_tables

INFO:snowflake.ml.feature_store.examples.example_helper:MANYMODEL_DB.MANYMODEL_SCHEMA.citibike_trips has been created successfully.


['MANYMODEL_DB.MANYMODEL_SCHEMA.citibike_trips']

In [None]:
from snowflake.snowpark import functions as F

taxi_trips = session.table(source_tables[0])


# First, filter out zones with fewer than 100 pickups
location_counts = (
    taxi_trips.group_by("PULocationID")
    .agg(F.count("*").alias("row_count"))
    .filter(F.col("row_count") >= 100)  # Only keep counts >= 100
)
valid_locations = location_counts.select("PULocationID").distinct()
taxi_trips = taxi_trips.join(valid_locations, on="PULocationID", how="inner")

# Let's see how many distinct pickup zones there are in the data.
# We will train a model for each zone and run inference on the partitioned data
pickup_zones = taxi_trips.select(F.col("PULOCATIONID")).distinct().to_pandas()
len(pickup_zones)
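
To check the filtering logic without a Snowflake session, the same count, filter, and join steps can be sketched in pandas. The toy frame and its zone ids are made up for illustration; only the column name `PULOCATIONID` and the threshold of 100 come from the cell above.

```python
import pandas as pd

# Toy stand-in for the trips table (assumption): zone 1 has 120 pickups,
# zone 2 has 99, and zone 3 has exactly 100.
trips = pd.DataFrame({"PULOCATIONID": [1] * 120 + [2] * 99 + [3] * 100})

# Count rows per zone and keep zones with at least 100 pickups.
row_counts = trips.groupby("PULOCATIONID").size()
valid_zones = row_counts[row_counts >= 100].index

# Equivalent of the inner join: keep only trips from the valid zones.
filtered = trips[trips["PULOCATIONID"].isin(valid_zones)]

print(sorted(filtered["PULOCATIONID"].unique().tolist()))
print(len(filtered))
```

Zone 2 falls just under the threshold and is dropped, while zone 3 sits exactly on it and is kept, matching the `>= 100` condition.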

In [None]:
_INPUT_COLS = ['PASSENGER_COUNT', 'TRIP_DISTANCE', 'FARE_AMOUNT', 
               'PAYMENT_TYPE_1', 'PAYMENT_TYPE_2', 'PAYMENT_TYPE_3',
               'PAYMENT_TYPE_4', 'PAYMENT_TYPE_5']

### Model Training

Next, we will train one model for each pickup location in the New York taxi network. We will then serialize the models and save them to a table in Snowflake.

In [None]:
# takes about 4 minutes to train all models
import xgboost
from datetime import timedelta

# do the train & forecast split
forecast_start = pd.to_datetime('2016-01-31')
train_start = pd.to_datetime('2016-01-01')
train_end = forecast_start - timedelta(days=1)
models = {}

for pickup_zone in pickup_zones["PULOCATIONID"]:
    
    taxi_trips_pandas = taxi_trips.filter(F.col("PULOCATIONID") == pickup_zone).to_pandas()
    taxi_trips_pandas["TPEP_PICKUP_DATETIME"] = pd.to_datetime(taxi_trips_pandas["TPEP_PICKUP_DATETIME"])
    
    # Converting payment_type to categories for get_dummies
    taxi_trips_pandas['PAYMENT_TYPE'] = taxi_trips_pandas['PAYMENT_TYPE'].astype(pd.CategoricalDtype(categories=(1, 2, 3, 4, 5), ordered=False))
    taxi_trips_pandas = pd.get_dummies(data=taxi_trips_pandas, columns=['PAYMENT_TYPE'])
    
    # filter by training date and select input and target columns
    # note: train_end is midnight at the start of 2016-01-30, so later trips that day are excluded
    pickup_time = taxi_trips_pandas["TPEP_PICKUP_DATETIME"]
    train = taxi_trips_pandas[(pickup_time >= train_start) & (pickup_time <= train_end)]
    X_train = train[_INPUT_COLS]
    y_train = train['TOTAL_AMOUNT']

    # Train an XGBoost regression model for every pickup location
    model = xgboost.XGBRegressor(n_estimators=50, n_jobs=1, random_state=42)
    model.fit(X_train, y_train, verbose=False)
    models[pickup_zone] = model
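
The training loop converts `PAYMENT_TYPE` to a categorical with fixed categories before calling `get_dummies`. A likely reason is that every zone then produces the same five indicator columns, even when some payment types never occur in that zone, so `_INPUT_COLS` can be selected without a missing-column error. The sketch below illustrates the difference on a small made-up frame; the fare values are hypothetical.

```python
import pandas as pd

# Hypothetical zone whose trips only ever use payment types 1 and 2.
zone_trips = pd.DataFrame(
    {"PAYMENT_TYPE": [1, 2, 1], "FARE_AMOUNT": [7.5, 12.0, 5.0]}
)

# Without fixed categories, only the observed values become columns.
loose = pd.get_dummies(data=zone_trips, columns=["PAYMENT_TYPE"])

# With fixed categories, all five indicator columns are created.
payment_dtype = pd.CategoricalDtype(categories=(1, 2, 3, 4, 5), ordered=False)
fixed = zone_trips.copy()
fixed["PAYMENT_TYPE"] = fixed["PAYMENT_TYPE"].astype(payment_dtype)
fixed = pd.get_dummies(data=fixed, columns=["PAYMENT_TYPE"])

print(list(loose.columns))
print(list(fixed.columns))
```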

Now that we have trained the models, we will save the serialized model data to a table in Snowflake. In the next notebook, we will read from this table during inference to run the model that matches each pickup location partition.

In [10]:
import pickle
# One row per zone: the zone id and the pickled model bytes
models_bytes_df = pd.DataFrame([{"PULOCATIONID": zone_id, "MODEL_PICKLE_BYTES": pickle.dumps(m)} for zone_id, m in models.items()])

In [None]:
models_bytes_df["PULOCATIONID"] = models_bytes_df["PULOCATIONID"].astype(str)
models_df = session.create_dataframe(models_bytes_df)

In [None]:
models_df.write.mode("overwrite").save_as_table("MODELS_TABLE")
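
To illustrate how rows in this layout can be turned back into usable objects, here is a minimal local sketch of the pickle round trip. Plain dictionaries stand in for the fitted `XGBRegressor` objects, the zone ids are made up, and `load_model_for_zone` is a hypothetical helper, not part of the quickstart.

```python
import pickle

import pandas as pd

# Stand-in "models" (assumption): plain dicts instead of fitted XGBRegressor objects.
models = {132: {"bias": 2.5}, 161: {"bias": 4.0}}

# Same layout as MODELS_TABLE: zone id as a string, model as pickled bytes.
models_bytes_df = pd.DataFrame(
    [
        {"PULOCATIONID": str(zone_id), "MODEL_PICKLE_BYTES": pickle.dumps(m)}
        for zone_id, m in models.items()
    ]
)


def load_model_for_zone(df, zone_id):
    """Hypothetical helper: return the unpickled model stored for one zone."""
    row = df.loc[df["PULOCATIONID"] == str(zone_id)].iloc[0]
    return pickle.loads(row["MODEL_PICKLE_BYTES"])


restored = load_model_for_zone(models_bytes_df, 161)
print(restored)
```

Because `pickle.loads` can execute arbitrary code, only unpickle bytes that come from a table you control.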

### Create Example Data

Finally, we will define the input data that we will use in part 2 of this quickstart.

In [23]:
# Get the input data
input_df_pandas = taxi_trips.filter(F.col("TPEP_PICKUP_DATETIME") >= '2016-01-31').to_pandas()
input_df_pandas['PAYMENT_TYPE'] = input_df_pandas['PAYMENT_TYPE'].astype(pd.CategoricalDtype(categories=(1, 2, 3, 4, 5), ordered=False))
input_df_pandas = pd.get_dummies(data=input_df_pandas, columns=['PAYMENT_TYPE'])
input_df_pandas = input_df_pandas[_INPUT_COLS + ["PULOCATIONID"]]

In [24]:
input_df_pandas

Unnamed: 0,START_STATION_ID,WEEKDAY_0,WEEKDAY_1,WEEKDAY_2,WEEKDAY_3,WEEKDAY_4,WEEKDAY_5,WEEKDAY_6,HOUR_0,HOUR_1,...,HOUR_19,HOUR_20,HOUR_21,HOUR_22,HOUR_23,USERTYPE_Subscriber,USERTYPE_Customer,GENDER_0,GENDER_1,GENDER_2
0,382,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,True,False,False,True,False
1,428,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,True,False,False,False,True
2,432,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,True,False,False,True,False
3,450,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,True,False,False,True,False
4,509,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,True,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
201827,350,False,True,False,False,False,False,False,False,False,...,False,False,False,False,True,True,False,False,True,False
201828,538,False,True,False,False,False,False,False,False,False,...,False,False,False,False,True,True,False,False,True,False
201829,281,False,True,False,False,False,False,False,False,False,...,False,False,False,False,True,True,False,False,True,False
201830,492,False,True,False,False,False,False,False,False,False,...,False,False,False,False,True,True,False,False,True,False


In [25]:
# Write a table with the initial dataset
INPUT_TABLE_NAME = "INPUT_DATA"
input_df = session.create_dataframe(input_df_pandas)
input_df.write.mode("overwrite").save_as_table(INPUT_TABLE_NAME)