In [None]:
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark import version as v
import json 

# Read the Snowflake credentials from a local JSON file so that no secrets
# are hard-coded in the notebook. The file must provide the keys
# 'user', 'password', 'account' and 'warehouse'.
with open('creds.json') as f:
    data = json.load(f)

USERNAME = data['user']
PASSWORD = data['password']
SF_ACCOUNT = data['account']
SF_WH = data['warehouse']

CONNECTION_PARAMETERS = {
    "account": SF_ACCOUNT,
    "user": USERNAME,
    "password": PASSWORD,
    # Use the warehouse named in the credentials file as the session default
    # so the value loaded above is actually applied to the connection.
    "warehouse": SF_WH,
}

# Open the Snowpark session that every later cell uses.
session = Session.builder.configs(CONNECTION_PARAMETERS).create()

## Environment Setup

In [None]:
# Provision the demo database and schema plus the two warehouses used in this
# notebook: a 3X-LARGE warehouse for feature engineering / inference and a
# MEDIUM Snowpark-optimized warehouse for model training.
setup_statements = (
    'CREATE DATABASE IF NOT EXISTS tpcds_xgboost',
    'CREATE SCHEMA IF NOT EXISTS tpcds_xgboost.demo',
    "create or replace warehouse FE_AND_INFERENCE_WH with warehouse_size='3X-LARGE'",
    "create or replace warehouse snowpark_opt_wh with warehouse_size = 'MEDIUM' warehouse_type = 'SNOWPARK-OPTIMIZED'",
)
for statement in setup_statements:
    session.sql(statement).collect()

# Feature engineering runs on the large standard warehouse.
session.use_warehouse('FE_AND_INFERENCE_WH')

Select either 100 or 10 for the TPC-DS dataset size to use below. See [here](https://docs.snowflake.com/en/user-guide/sample-data-tpcds.html) for more information. If you choose 100, I recommend a warehouse of size 3XL or larger.

In [None]:
# TPC-DS scale factor to read from the sample-data database (10 or 100).
TPCDS_SIZE_PARAM = 100

# Supported scale factors mapped to the schema that holds that data set.
# Keeping the schema name in one place avoids maintaining two copies of the
# table list that can drift apart.
TPCDS_SCHEMAS = {
    100: 'sfc_samples_sample_data.TPCDS_SF100TCL',
    10: 'sfc_samples_sample_data.TPCDS_SF10TCL',
}

if TPCDS_SIZE_PARAM not in TPCDS_SCHEMAS:
    raise ValueError("Invalid TPCDS_SIZE_PARAM selection")

tpcds_schema = TPCDS_SCHEMAS[TPCDS_SIZE_PARAM]

# Lazy Snowpark references to the source tables; no data is read here.
store_sales = session.table(f'{tpcds_schema}.store_sales')
catalog_sales = session.table(f'{tpcds_schema}.catalog_sales')
web_sales = session.table(f'{tpcds_schema}.web_sales')
date = session.table(f'{tpcds_schema}.date_dim')
dim_stores = session.table(f'{tpcds_schema}.store')
customer = session.table(f'{tpcds_schema}.customer')
address = session.table(f'{tpcds_schema}.customer_address')
demo = session.table(f'{tpcds_schema}.customer_demographics')

## Feature Engineering
We will aggregate sales by customer across all channels (web, store, catalog) and join that to customer demographic data.

In [None]:
# Sum the sales price per customer within each channel and give the customer
# key the same name ('customer_sk') in all three results so they can be
# stacked afterwards.
store_sales_agged = (
    store_sales
    .group_by('ss_customer_sk')
    .agg(F.sum('ss_sales_price').as_('total_sales'))
    .rename('ss_customer_sk', 'customer_sk')
)
web_sales_agged = (
    web_sales
    .group_by('ws_bill_customer_sk')
    .agg(F.sum('ws_sales_price').as_('total_sales'))
    .rename('ws_bill_customer_sk', 'customer_sk')
)
catalog_sales_agged = (
    catalog_sales
    .group_by('cs_bill_customer_sk')
    .agg(F.sum('cs_sales_price').as_('total_sales'))
    .rename('cs_bill_customer_sk', 'customer_sk')
)

In [None]:
# Stack the three per-channel aggregates into one frame. A customer who bought
# through several channels appears once per channel at this point.
total_sales = (
    store_sales_agged
    .union_all(web_sales_agged)
    .union_all(catalog_sales_agged)
)

In [None]:
# Collapse the per-channel rows into a single total per customer.
total_sales = (
    total_sales
    .group_by('customer_sk')
    .agg(F.sum('total_sales').as_('total_sales'))
)

In [None]:
# Keep only the customer key, the two foreign keys used by the joins that
# follow, the customer id and the birth year.
customer = customer.select(
    'c_customer_sk',
    'c_current_hdemo_sk',
    'c_current_addr_sk',
    'c_customer_id',
    'c_birth_year',
)

In [None]:
# Attach the ZIP code from the customer's current address.
address_zip = address.select('ca_address_sk', 'ca_zip')
customer = customer.join(address_zip, customer['c_current_addr_sk'] == address['ca_address_sk'])

# Attach the demographic attributes used as model features.
# NOTE(review): by its name, c_current_hdemo_sk looks like the
# household-demographics key, while `demo` is customer_demographics -- confirm
# that this is the intended join key (c_current_cdemo_sk may have been meant).
demo_attrs = demo.select(
    'cd_demo_sk',
    'cd_gender',
    'cd_marital_status',
    'cd_credit_rating',
    'cd_education_status',
    'cd_dep_count',
)
customer = customer.join(demo_attrs, customer['c_current_hdemo_sk'] == demo['cd_demo_sk'])

# Align the key name with the aggregated sales frame, then preview the result.
customer = customer.rename('c_customer_sk', 'customer_sk')
customer.show()

In [None]:
# Combine each customer's total sales with their attributes. No `how` is
# given, so the join type is Snowpark's default (inner): customers without
# sales, and sales without a matching customer row, are left out.
final_df = total_sales.join(customer, on='customer_sk')

In [None]:
# Make tpcds_xgboost.demo the session's current database/schema; the
# unqualified table and stage names used from here on resolve against it.
session.use_database('tpcds_xgboost')
session.use_schema('demo')

# Materialize the engineered features, replacing any earlier version.
feature_writer = final_df.write.mode('overwrite')
feature_writer.save_as_table('feature_store')

In [None]:
# Register the third-party packages with the session so they are available
# to the stored procedure and UDF created later in this notebook.
session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools', 'xgboost')

In [None]:
# Stage that receives the serialized model written by train_model and from
# which the prediction UDF imports it. OR REPLACE recreates the stage, so
# anything stored in it previously is discarded.
session.sql('CREATE OR REPLACE STAGE ml_models ').collect()

In [None]:
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
from sklearn.metrics import mean_squared_error
from sklearn.compose import ColumnTransformer
from xgboost import XGBRegressor
import joblib
import os

def train_model(session: snowflake.snowpark.Session) -> float:
    """Train an XGBoost regressor on ``feature_store`` and stage the fitted pipeline.

    The feature table is split 80/20 with a fixed seed, both splits are saved
    as tables, a preprocessing + XGBoost pipeline is fit on the training split,
    and the fitted pipeline is uploaded to the ``@ml_models`` stage as
    ``model.joblib``.

    Args:
        session: Snowpark session used to read the features, write the split
            tables and upload the model file.

    Returns:
        Root mean squared error of the predictions on the test split.
    """
    snowdf = session.table("feature_store")
    # Drop keys/identifiers that are not model inputs.
    snowdf = snowdf.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
    snowdf_train, snowdf_test = snowdf.random_split([0.8, 0.2], seed=82)

    # Save the train and test sets as tables in Snowflake (overwritten on
    # every run) so the exact split can be inspected afterwards.
    snowdf_train.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TRAIN")
    snowdf_test.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TEST")

    # Fetch each split once and separate the label locally. Fetching features
    # and labels with two separate queries gives no guarantee that the rows
    # come back in the same order, which could misalign X and y.
    train_pd = snowdf_train.to_pandas()
    test_pd = snowdf_test.to_pandas()
    train_x = train_pd.drop(columns=["TOTAL_SALES"])
    train_y = train_pd["TOTAL_SALES"]
    test_x = test_pd.drop(columns=["TOTAL_SALES"])
    test_y = test_pd["TOTAL_SALES"]

    cat_cols = ['CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS']
    num_cols = ['C_BIRTH_YEAR', 'CD_DEP_COUNT']

    # Numeric features: fill missing values with the median, then standardize.
    num_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy="median")),
        ('std_scaler', StandardScaler()),
    ])

    # Categorical features: one-hot encode; categories unseen during fit are
    # ignored at prediction time instead of raising.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', num_pipeline, num_cols),
            ('encoder', OneHotEncoder(handle_unknown="ignore"), cat_cols),
        ]
    )

    pipe = Pipeline([
        ('preprocessor', preprocessor),
        ('xgboost', XGBRegressor()),
    ])
    pipe.fit(train_x, train_y)

    test_preds = pipe.predict(test_x)
    # mean_squared_error returns the MSE; take the square root to report RMSE.
    rmse = float(mean_squared_error(test_y, test_preds) ** 0.5)

    model_file = os.path.join('/tmp', 'model.joblib')
    joblib.dump(pipe, model_file)
    # Upload uncompressed so the staged file keeps the name 'model.joblib',
    # which is the name the inference code imports and opens.
    session.file.put(model_file, "@ml_models", overwrite=True, auto_compress=False)
    return rmse

In [None]:
# Register train_model as a stored procedure, replacing any earlier
# registration.
train_model_sp = F.sproc(train_model, session=session, replace=True)
# Switch to the Snowpark-optimized warehouse for training, then run the stored
# procedure; its return value (the error metric) is shown as the cell output.
session.use_warehouse('snowpark_opt_wh')
train_model_sp(session=session)

In [None]:
# Training is done: switch back to the feature engineering/inference
# warehouse for the prediction steps that follow.
session.use_warehouse('FE_AND_INFERENCE_WH')

In [None]:
import sys
import pandas as pd
import cachetools
import joblib
from snowflake.snowpark import types as T

# Make the staged model file available to UDFs registered from this session.
session.add_import("@ml_models/model.joblib")

# Feature column names in the order the prediction UDF receives them; the UDF
# assigns these names to its input batch before calling the model.
features = [
    'C_BIRTH_YEAR',
    'CA_ZIP',
    'CD_GENDER',
    'CD_MARITAL_STATUS',
    'CD_CREDIT_RATING',
    'CD_EDUCATION_STATUS',
    'CD_DEP_COUNT',
]

@cachetools.cached(cache={})
def read_file(filename):
    """Load a joblib-serialized object from the UDF import directory.

    The result is cached per filename, so the file is deserialized only once
    per Python process rather than once per batch.

    Args:
        filename: Name of the imported file, relative to the import directory.

    Returns:
        The deserialized object.

    Raises:
        RuntimeError: If the ``snowflake_import_directory`` option is not set,
            i.e. the function is not running where UDF imports are available.
    """
    import_dir = sys._xoptions.get("snowflake_import_directory")
    if not import_dir:
        # Fail loudly here instead of returning (and caching) None, which
        # would only show up later as an AttributeError on `.predict`.
        raise RuntimeError(
            "snowflake_import_directory is not set; cannot load '%s'" % filename
        )
    # NOTE: joblib.load unpickles the file and can execute arbitrary code;
    # only load model files from a trusted stage.
    with open(os.path.join(import_dir, filename), 'rb') as file:
        return joblib.load(file)

@F.pandas_udf(session=session, max_batch_size=10000)
def predict(df:  T.PandasDataFrame[int, str, str, str, str, str, int]) -> T.PandasSeries[float]:
    """Score one batch of feature rows with the staged model.

    Args:
        df: Batch of rows whose columns are positionally the entries of
            ``features``.

    Returns:
        One prediction per input row.
    """
    # The batch arrives without the feature names; restore them so the fitted
    # pipeline can select columns by name.
    df.columns = features
    model = read_file('model.joblib')
    return model.predict(df)

In [None]:
# Score every row of the feature table and store the features, the prediction
# and the actual value side by side.
inference_df = session.table('feature_store')
inference_df = inference_df.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
inputs = inference_df.drop("TOTAL_SALES")

# Feature columns in table order.
# NOTE(review): the UDF maps its arguments to `features` by position, so this
# relies on the table's column order matching that list -- confirm.
input_cols = [inputs[name] for name in inputs.columns]

prediction_col = predict(*input_cols).alias('PREDICTION')
actual_col = (F.col('TOTAL_SALES')).alias('ACTUAL_SALES')
snowdf_results = inference_df.select(*input_cols, prediction_col, actual_col)
snowdf_results.write.mode('overwrite').save_as_table('predictions')

In [None]:
# Number of rows in the frame that was scored above.
inference_df.count()