## Many Model Inference in Snowflake -- Setup

This notebook sets up all models required for the [Many Model Inference in Snowflake](https://quickstarts.snowflake.com/guide/many-model-inference-in-snowflake/index.html?index=..%2F..index#0) quickstart. In this notebook, we train a distinct XGBoost model for each station in the Citibike network. Part 2 of the quickstart uses these models to run inference based on the station ID.

We start with the imports and get the active session. Before running the notebook, add `snowflake-ml-python` from the packages dropdown.

In [1]:
from snowflake.snowpark import Session
import warnings
import pandas as pd
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Add a query tag to the session. This helps with performance monitoring and troubleshooting.
session.query_tag = {"origin":"sf_sit-is", 
                     "name":"partitioned_models_stateful", 
                     "version":{"major":1, "minor":0},
                     "attributes":{"is_quickstart":1, "source":"notebook"}}



### Feature Engineering

We load the Citibike trip data that ships as an example dataset with Snowflake ML (through the Feature Store `ExampleHelper`), and then prepare it for training.

In [4]:
from snowflake.ml.feature_store.examples.example_helper import ExampleHelper

example_helper = ExampleHelper(session,
                               session.get_current_database(),
                               session.get_current_schema())
# Load the Citibike example into the current database and schema; returns the created table names
source_tables = example_helper.load_example('citibike_trip_features')
source_tables

INFO:snowflake.ml.feature_store.examples.example_helper:MANYMODEL_DB.MANYMODEL_SCHEMA.citibike_trips has been created successfully.


['MANYMODEL_DB.MANYMODEL_SCHEMA.citibike_trips']

In [5]:
from snowflake.snowpark import functions as F

citibike_trips = session.table(source_tables[0])
stations = citibike_trips.select(F.col("START_STATION_ID")).distinct().to_pandas()

In [None]:
def prepare_input_data(input_df):
    # Work on a copy so the caller's DataFrame is not modified
    input_df = input_df.copy()

    # Derive hour-of-day and day-of-week features from the trip start time
    input_df['HOUR'] = input_df['STARTTIME'].dt.hour
    input_df['WEEKDAY'] = input_df['STARTTIME'].dt.weekday

    # Cast to categoricals with a fixed category list so get_dummies always
    # produces the same columns, even when a station lacks some values
    input_df['WEEKDAY'] = input_df['WEEKDAY'].astype(pd.CategoricalDtype(categories=list(range(7))))
    input_df['HOUR'] = input_df['HOUR'].astype(pd.CategoricalDtype(categories=list(range(24))))
    input_df['USERTYPE'] = input_df['USERTYPE'].astype(pd.CategoricalDtype(categories=["Subscriber", "Customer"]))
    input_df['GENDER'] = input_df['GENDER'].astype(pd.CategoricalDtype(categories=[0, 1, 2]))

    # One-hot encode the categorical features
    output_df = pd.get_dummies(data=input_df, columns=['WEEKDAY', 'HOUR', 'USERTYPE', 'GENDER'])

    # Drop raw columns that are not model inputs; START_STATION_ID and TRIPDURATION are kept
    output_df = output_df.drop(['STARTTIME', 'STOPTIME', 'START_STATION_NAME',
        'START_STATION_LATITUDE', 'START_STATION_LONGITUDE', 'END_STATION_ID',
        'END_STATION_NAME', 'END_STATION_LATITUDE', 'END_STATION_LONGITUDE',
        'TRIP_ID', 'MEMBERSHIP_TYPE', 'BIRTH_YEAR', 'BIKEID'], axis=1)

    return output_df
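Why cast to categoricals before one-hot encoding? A single station may have no trips in some hours or weekdays. `pd.get_dummies` only creates columns for the values it actually sees, unless the column has a categorical dtype with a fixed category list. The small illustration below uses synthetic timestamps, not Citibike data, to show the difference. This fixed category list is what keeps every station's feature matrix compatible with the fixed `_INPUT_COLS` list defined next.

```python
import pandas as pd

# Two synthetic trips, both starting in hour 8
trips = pd.DataFrame({"STARTTIME": pd.to_datetime(["2013-12-16 08:05", "2013-12-16 08:40"])})
trips["HOUR"] = trips["STARTTIME"].dt.hour

# Plain integer column: only the observed value becomes a dummy column
plain = pd.get_dummies(trips[["HOUR"]], columns=["HOUR"])

# Fixed categories 0..23: every hour gets a column, observed or not
trips["HOUR"] = trips["HOUR"].astype(pd.CategoricalDtype(categories=list(range(24))))
fixed = pd.get_dummies(trips[["HOUR"]], columns=["HOUR"])

print(list(plain.columns))  # ['HOUR_8']
print(len(fixed.columns))   # 24
```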

In [8]:
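# One-hot feature columns produced by prepare_input_data; the models are trained on exactly these columns, in this order
_INPUT_COLS = ['WEEKDAY_0', 'WEEKDAY_1', 'WEEKDAY_2', 'WEEKDAY_3', 'WEEKDAY_4',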
       'WEEKDAY_5', 'WEEKDAY_6', 'HOUR_0', 'HOUR_1', 'HOUR_2', 'HOUR_3',
       'HOUR_4', 'HOUR_5', 'HOUR_6', 'HOUR_7', 'HOUR_8', 'HOUR_9',
       'HOUR_10', 'HOUR_11', 'HOUR_12', 'HOUR_13', 'HOUR_14', 'HOUR_15',
       'HOUR_16', 'HOUR_17', 'HOUR_18', 'HOUR_19', 'HOUR_20', 'HOUR_21',
       'HOUR_22', 'HOUR_23', 'USERTYPE_Customer', 'USERTYPE_Subscriber', 'GENDER_1', 'GENDER_2', 'GENDER_0']
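Hard-coding the 36 column names works, but the list can drift from the category lists in `prepare_input_data`. The sketch below derives the names from one shared definition instead. `FEATURE_CATEGORIES` is a hypothetical name, not part of the original notebook.

The generated list follows the order in which `get_dummies` creates the columns. This differs from the hand-written `_INPUT_COLS` in the `USERTYPE` and `GENDER` entries. Whichever list you choose, use the same one (in the same order) for training and inference.

```python
# Hypothetical single source of truth for the categorical features,
# mirroring the category lists used in prepare_input_data
FEATURE_CATEGORIES = {
    "WEEKDAY": list(range(7)),
    "HOUR": list(range(24)),
    "USERTYPE": ["Subscriber", "Customer"],
    "GENDER": [0, 1, 2],
}

# get_dummies names each column "<column>_<category>", so the same
# pattern reproduces the one-hot column names
input_cols = [f"{col}_{value}" for col, values in FEATURE_CATEGORIES.items() for value in values]

print(len(input_cols))  # 36
print(input_cols[-5:])  # ['USERTYPE_Subscriber', 'USERTYPE_Customer', 'GENDER_0', 'GENDER_1', 'GENDER_2']
```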

### Model Training

Next, we train one model per station in the Citibike network. We then serialize the models with `pickle` and save them to a table in Snowflake.

In [9]:
import xgboost
from datetime import timedelta
#from tqdm import tqdm

# Training window: trips from train_start up to, but not including, forecast_start.
# Trips on or after forecast_start are held out as inference input.
forecast_start = pd.to_datetime('2013-12-15')
train_start = pd.to_datetime('2013-11-15')
train_end = forecast_start - timedelta(days=1)  # last full day of training data

models = {}

for station_id in stations["START_STATION_ID"]:
    # Pull this station's trips into pandas (one query per station)
    station_trips_pandas = citibike_trips.filter(F.col("START_STATION_ID") == station_id).to_pandas()
    training_data_all = prepare_input_data(station_trips_pandas)

    # Compare against forecast_start rather than train_end (midnight on 2013-12-14),
    # so that trips later in the day on train_end are included
    in_window = ((station_trips_pandas["STARTTIME"] >= train_start)
                 & (station_trips_pandas["STARTTIME"] < forecast_start))
    train = training_data_all[in_window]
    X_train = train[_INPUT_COLS]
    y_train = train['TRIPDURATION']

    # Fit an XGBoost regression model on this station's trips
    model = xgboost.XGBRegressor(n_estimators=50, n_jobs=1)
    model.fit(X_train, y_train, verbose=False)
    models[station_id] = model

100%|██████████| 329/329 [02:53<00:00,  1.89it/s]
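The loop above issues one query per station (329 in the run shown). The sketch below is an alternative that assumes the whole training window fits in local memory: fetch it once, then partition it locally with `groupby`.

`train_per_station` and `MeanDurationModel` are hypothetical names. The stand-in model exists only so that the example runs without XGBoost.

```python
import pandas as pd


class MeanDurationModel:
    """Hypothetical stand-in for xgboost.XGBRegressor: ignores the features
    and predicts the mean training duration."""

    def fit(self, X, y):
        self.mean_ = float(y.mean())
        return self

    def predict(self, X):
        return [self.mean_] * len(X)


def train_per_station(features, feature_cols, make_model=MeanDurationModel):
    """Fit one model per START_STATION_ID from a frame holding all stations.

    `features` is assumed to be prepare_input_data output that is already
    restricted to the training window.
    """
    models = {}
    for station_id, group in features.groupby("START_STATION_ID"):
        model = make_model()
        model.fit(group[feature_cols], group["TRIPDURATION"])
        models[station_id] = model
    return models


# Toy rows standing in for the prepared Citibike features
toy = pd.DataFrame({
    "START_STATION_ID": [72, 72, 79],
    "HOUR_8": [True, False, True],
    "TRIPDURATION": [600, 900, 300],
})
toy_models = train_per_station(toy, ["HOUR_8"])
print(toy_models[72].predict(toy[["HOUR_8"]].iloc[:1]))  # [750.0]
```

In the notebook you would pass `make_model=lambda: xgboost.XGBRegressor(n_estimators=50, n_jobs=1)` and `_INPUT_COLS` as the feature columns. Whether a single `to_pandas()` call over the training window is practical depends on the data size.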


In [10]:
import pickle

# Serialize each fitted model to bytes, keyed by station
models_bytes_df = pd.DataFrame([
    {"START_STATION_ID": station_id, "MODEL_PICKLE_BYTES": pickle.dumps(model)}
    for station_id, model in models.items()
])

In [None]:
# Cast station IDs to strings before creating the Snowpark DataFrame
models_bytes_df["START_STATION_ID"] = models_bytes_df["START_STATION_ID"].astype(str)
models_df = session.create_dataframe(models_bytes_df)

In [None]:
models_df.write.mode("overwrite").save_as_table("MODELS_TABLE")
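`MODELS_TABLE` stores a string station ID next to the pickled bytes of each model. The local sketch below shows the write and read round trip with pandas only. It assumes that the reading side rebuilds a dictionary keyed by the string ID.

`SimpleNamespace` objects stand in for fitted models so that the example runs without XGBoost. Real `XGBRegressor` objects pickle the same way, but unpickling them generally requires a compatible `xgboost` version on the reading side.

```python
import pickle
from types import SimpleNamespace

import pandas as pd

# Hypothetical stand-ins for fitted per-station models
toy_models = {
    72: SimpleNamespace(mean_duration=750.0),
    79: SimpleNamespace(mean_duration=300.0),
}

# Writing side: same layout as models_bytes_df (string station ID, pickled bytes)
toy_bytes_df = pd.DataFrame([
    {"START_STATION_ID": str(sid), "MODEL_PICKLE_BYTES": pickle.dumps(m)}
    for sid, m in toy_models.items()
])

# Reading side: rebuild a lookup keyed by the string station ID
loaded = {
    row.START_STATION_ID: pickle.loads(row.MODEL_PICKLE_BYTES)
    for row in toy_bytes_df.itertuples(index=False)
}

# Keys are strings now, so look models up with str(station_id)
print(loaded[str(72)].mean_duration)  # 750.0
```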

### Create Example Data

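Finally, we prepare the input data for part 2 of this quickstart. It contains trips that start on or after 2013-12-15 (the forecast window). These trips go through the same `prepare_input_data` feature preparation as the training data, and the `TRIPDURATION` target is removed.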

In [23]:
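# Fetch trips in the forecast window (on or after 2013-12-15) and apply the training-time feature preparation
input_df_pandas = citibike_trips.filter(F.col("STARTTIME") >= '2013-12-15').to_pandas()
input_df_pandas = prepare_input_data(input_df_pandas)
# Drop the target; the inference input keeps START_STATION_ID and the one-hot features
input_df_pandas = input_df_pandas.drop(['TRIPDURATION'], axis=1)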

In [24]:
input_df_pandas

| | START_STATION_ID | WEEKDAY_0 | WEEKDAY_1 | WEEKDAY_2 | WEEKDAY_3 | WEEKDAY_4 | WEEKDAY_5 | WEEKDAY_6 | HOUR_0 | HOUR_1 | ... | HOUR_19 | HOUR_20 | HOUR_21 | HOUR_22 | HOUR_23 | USERTYPE_Subscriber | USERTYPE_Customer | GENDER_0 | GENDER_1 | GENDER_2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 382 | False | False | False | False | False | False | True | True | False | ... | False | False | False | False | False | True | False | False | True | False |
| 1 | 428 | False | False | False | False | False | False | True | True | False | ... | False | False | False | False | False | True | False | False | False | True |
| 2 | 432 | False | False | False | False | False | False | True | True | False | ... | False | False | False | False | False | True | False | False | True | False |
| 3 | 450 | False | False | False | False | False | False | True | True | False | ... | False | False | False | False | False | True | False | False | True | False |
| 4 | 509 | False | False | False | False | False | False | True | True | False | ... | False | False | False | False | False | True | False | False | False | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 201827 | 350 | False | True | False | False | False | False | False | False | False | ... | False | False | False | False | True | True | False | False | True | False |
| 201828 | 538 | False | True | False | False | False | False | False | False | False | ... | False | False | False | False | True | True | False | False | True | False |
| 201829 | 281 | False | True | False | False | False | False | False | False | False | ... | False | False | False | False | True | True | False | False | True | False |
| 201830 | 492 | False | True | False | False | False | False | False | False | False | ... | False | False | False | False | True | True | False | False | True | False |


In [25]:
# Write the prepared inference input to a Snowflake table for part 2
INPUT_TABLE_NAME = "INPUT_DATA"
input_df = session.create_dataframe(input_df_pandas)
input_df.write.mode("overwrite").save_as_table(INPUT_TABLE_NAME)
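Part 2 of the quickstart runs the actual inference, and its implementation is not part of this notebook. The local sketch below is only a rough illustration of the many-model idea. It assumes the models have already been unpickled into a dictionary keyed by string station ID, matching how `START_STATION_ID` is stored in `MODELS_TABLE`. The input is split by station, and each slice is scored by its own model.

`score_by_station` and `ConstantModel` are hypothetical names, and the constant model stands in for an `XGBRegressor`.

```python
import pandas as pd


class ConstantModel:
    """Hypothetical stand-in for an unpickled per-station XGBRegressor."""

    def __init__(self, value):
        self.value = value

    def predict(self, X):
        return [self.value] * len(X)


def score_by_station(input_df, models, feature_cols):
    """Score each station's rows with that station's model; NaN if no model exists."""
    predictions = pd.Series(float("nan"), index=input_df.index, name="PREDICTED_DURATION")
    for station_id, group in input_df.groupby("START_STATION_ID"):
        model = models.get(str(station_id))  # model keys are string station IDs
        if model is None:
            continue
        # reindex enforces the training column order; missing dummy columns become False
        X = group.reindex(columns=feature_cols, fill_value=False)
        predictions.loc[group.index] = model.predict(X)
    return predictions


toy_input = pd.DataFrame({
    "START_STATION_ID": [72, 79, 72, 500],
    "HOUR_8": [True, False, False, True],
})
toy_models = {"72": ConstantModel(750.0), "79": ConstantModel(300.0)}
preds = score_by_station(toy_input, toy_models, ["HOUR_8", "HOUR_9"])
print(preds.tolist())  # [750.0, 300.0, 750.0, nan]
```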