In [1]:
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark import version as v
import json 

with open('connection.json') as f:
    data = json.load(f)
    USERNAME = data['user']
    PASSWORD = data['password']
    SF_ACCOUNT = data['account']
    SF_WH = data['warehouse']

CONNECTION_PARAMETERS = {
   "account": SF_ACCOUNT,
   "user": USERNAME,
   "password": PASSWORD,
}

session = Session.builder.configs(CONNECTION_PARAMETERS).create()

## Environment Setup

In [2]:
session.sql('''create database if not exists snowflake_sample_data from share sfc_samples.sample_data''').collect()

[Row(status='SNOWFLAKE_SAMPLE_DATA already exists, statement succeeded.')]

In [3]:
session.sql('CREATE DATABASE IF NOT EXISTS tpcds_xgboost').collect()
session.sql('CREATE SCHEMA IF NOT EXISTS tpcds_xgboost.demo').collect()
session.sql("create or replace warehouse FE_AND_INFERENCE_WH with warehouse_size='3X-LARGE'").collect()
session.sql("create or replace warehouse snowpark_opt_wh with warehouse_size = 'MEDIUM' warehouse_type = 'SNOWPARK-OPTIMIZED'").collect()
session.sql("alter warehouse snowpark_opt_wh set max_concurrency_level = 1").collect()
session.use_warehouse('FE_AND_INFERENCE_WH')

Select either 100 or 10 for the TPC-DS Dataset size to use below. See (https://docs.snowflake.com/en/user-guide/sample-data-tpcds.html)[here] for more information If you choose 100, I recommend >= 3XL warehouse. 

In [4]:
TPCDS_SIZE_PARAM = 100
SNOWFLAKE_SAMPLE_DB = 'SNOWFLAKE_SAMPLE_DATA' # Name of Snowflake Sample Database might be different...

if TPCDS_SIZE_PARAM == 100: 
    TPCDS_SCHEMA = 'TPCDS_SF100TCL'
elif TPCDS_SIZE_PARAM == 10:
    TPCDS_SCHEMA = 'TPCDS_SF10TCL'
else:
    raise ValueError("Invalid TPCDS_SIZE_PARAM selection")
    
store_sales = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.store_sales')
catalog_sales = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.catalog_sales') 
web_sales = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.web_sales') 
date = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.date_dim')
dim_stores = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.store')
customer = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.customer')
address = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.customer_address')
demo = session.table(f'{SNOWFLAKE_SAMPLE_DB}.{TPCDS_SCHEMA}.customer_demographics')

## Feature Engineering
We will aggregate sales by customer across all channels(web, store, catalogue) and join that to customer demographic data. 

In [5]:
store_sales_agged = store_sales.group_by('ss_customer_sk').agg(F.sum('ss_sales_price').as_('total_sales'))
web_sales_agged = web_sales.group_by('ws_bill_customer_sk').agg(F.sum('ws_sales_price').as_('total_sales'))
catalog_sales_agged = catalog_sales.group_by('cs_bill_customer_sk').agg(F.sum('cs_sales_price').as_('total_sales'))
store_sales_agged = store_sales_agged.rename('ss_customer_sk', 'customer_sk')
web_sales_agged = web_sales_agged.rename('ws_bill_customer_sk', 'customer_sk')
catalog_sales_agged = catalog_sales_agged.rename('cs_bill_customer_sk', 'customer_sk')

In [6]:
total_sales = store_sales_agged.union_all(web_sales_agged)
total_sales = total_sales.union_all(catalog_sales_agged)

In [7]:
total_sales = total_sales.group_by('customer_sk').agg(F.sum('total_sales').as_('total_sales'))

In [8]:
customer = customer.select('c_customer_sk','c_current_hdemo_sk', 'c_current_addr_sk', 'c_customer_id', 'c_birth_year')

In [9]:
customer = customer.join(address.select('ca_address_sk', 'ca_zip'), customer['c_current_addr_sk'] == address['ca_address_sk'] )
customer = customer.join(demo.select('cd_demo_sk', 'cd_gender', 'cd_marital_status', 'cd_credit_rating', 'cd_education_status', 'cd_dep_count'),
                                customer['c_current_hdemo_sk'] == demo['cd_demo_sk'] )
customer = customer.rename('c_customer_sk', 'customer_sk')
customer.show()

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"CUSTOMER_SK"  |"C_CURRENT_HDEMO_SK"  |"C_CURRENT_ADDR_SK"  |"C_CUSTOMER_ID"   |"C_BIRTH_YEAR"  |"CA_ADDRESS_SK"  |"CA_ZIP"  |"CD_DEMO_SK"  |"CD_GENDER"  |"CD_MARITAL_STATUS"  |"CD_CREDIT_RATING"  |"CD_EDUCATION_STATUS"  |"CD_DEP_COUNT"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|84886068       |4705                  |859249               |AAAAAAAAEDCEPAFA  |1945            |859249           |48784     |4705          |M            |D                    |Unknown             |Secondary              |0               |
|84886701       |5003               

In [10]:
final_df = total_sales.join(customer, on='customer_sk')

In [11]:
session.use_database('tpcds_xgboost')
session.use_schema('demo')
final_df.write.mode('overwrite').save_as_table('feature_store')

In [12]:
session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools', 'xgboost', 'joblib')

The version of package xgboost in the local environment is 1.7.4, which does not fit the criteria for the requirement xgboost. Your UDF might not work when the package version is different between the server and your local environment


In [13]:
session.sql('CREATE OR REPLACE STAGE ml_models ').collect()

[Row(status='Stage area ML_MODELS successfully created.')]

In [14]:
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
from sklearn.metrics import mean_squared_error
from sklearn.compose import ColumnTransformer
from xgboost import XGBRegressor
import joblib
import os

def train_model(session: snowflake.snowpark.Session) -> float:
    snowdf = session.table("feature_store")
    snowdf = snowdf.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
    snowdf_train, snowdf_test = snowdf.random_split([0.8, 0.2], seed=82) 

    # save the train and test sets as time stamped tables in Snowflake 
    snowdf_train.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TRAIN")
    snowdf_test.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TEST")
    train_x = snowdf_train.drop("TOTAL_SALES").to_pandas() # drop labels for training set
    train_y = snowdf_train.select("TOTAL_SALES").to_pandas()
    test_x = snowdf_test.drop("TOTAL_SALES").to_pandas()
    test_y = snowdf_test.select("TOTAL_SALES").to_pandas()
    cat_cols = ['CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS']
    num_cols = ['C_BIRTH_YEAR', 'CD_DEP_COUNT']

    num_pipeline = Pipeline([
            ('imputer', SimpleImputer(strategy="median")),
            ('std_scaler', StandardScaler()),
        ])

    preprocessor = ColumnTransformer(
    transformers=[('num', num_pipeline, num_cols),
                  ('encoder', OneHotEncoder(handle_unknown="ignore"), cat_cols) ])

    pipe = Pipeline([('preprocessor', preprocessor), 
                        ('xgboost', XGBRegressor())])
    pipe.fit(train_x, train_y)

    test_preds = pipe.predict(test_x)
    rmse = mean_squared_error(test_y, test_preds)
    model_file = os.path.join('/tmp', 'model.joblib')
    joblib.dump(pipe, model_file)
    session.file.put(model_file, "@ml_models",overwrite=True)
    return rmse

In [15]:
session.use_warehouse('snowpark_opt_wh')
train_model_sp = F.sproc(train_model, session=session, replace=True, is_permanent=True, name="xgboost_sproc", stage_location="@ml_models")
# Switch to Snowpark Optimized Warehouse for training and to run the stored proc
train_model_sp(session=session)

77617378801.55711

In [16]:
# Switch back to feature engineering/inference warehouse
session.use_warehouse('FE_AND_INFERENCE_WH')

In [17]:
import sys
import pandas as pd
import cachetools
import joblib
from snowflake.snowpark import types as T

session.add_import("@ml_models/model.joblib")  

features = [ 'C_BIRTH_YEAR', 'CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS', 'CD_DEP_COUNT']

@cachetools.cached(cache={})
def read_file(filename):
       import_dir = sys._xoptions.get("snowflake_import_directory")
       if import_dir:
              with open(os.path.join(import_dir, filename), 'rb') as file:
                     m = joblib.load(file)
                     return m

@F.pandas_udf(session=session, max_batch_size=10000, is_permanent=True, stage_location='@ml_models', replace=True, name="clv_xgboost_udf")
def predict(df:  T.PandasDataFrame[int, str, str, str, str, str, int]) -> T.PandasSeries[float]:
       m = read_file('model.joblib')       
       df.columns = features
       return m.predict(df)

In [18]:
inference_df = session.table('feature_store')
inference_df = inference_df.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
inputs = inference_df.drop("TOTAL_SALES")
snowdf_results = inference_df.select(*inputs,
                    predict(*inputs).alias('PREDICTION'), 
                    (F.col('TOTAL_SALES')).alias('ACTUAL_SALES')
                    )
snowdf_results.write.mode('overwrite').save_as_table('predictions')

In [19]:
inference_df.count()

96500091

In [20]:
res=snowdf_results.to_pandas()

In [21]:
res.head()

Unnamed: 0,C_BIRTH_YEAR,CA_ZIP,CD_GENDER,CD_MARITAL_STATUS,CD_CREDIT_RATING,CD_EDUCATION_STATUS,CD_DEP_COUNT,PREDICTION,ACTUAL_SALES
0,1944.0,50150,M,W,Good,Advanced Degree,0,210252.609375,119471.31
1,1976.0,78883,M,W,Good,Advanced Degree,0,210252.609375,119043.41
2,1932.0,69310,M,W,Good,Advanced Degree,0,210187.984375,111926.28
3,1992.0,67683,M,W,Good,Advanced Degree,0,210252.609375,109641.6
4,1944.0,42293,M,W,Good,Advanced Degree,0,210252.609375,106267.27


In [38]:
res.CD_CREDIT_RATING.unique()
res.loc[res['C_BIRTH_YEAR']==1944.0]

Unnamed: 0,C_BIRTH_YEAR,CA_ZIP,CD_GENDER,CD_MARITAL_STATUS,CD_CREDIT_RATING,CD_EDUCATION_STATUS,CD_DEP_COUNT,PREDICTION,ACTUAL_SALES
0,1944.0,50150,M,W,Good,Advanced Degree,0,210252.609375,119471.31
4,1944.0,42293,M,W,Good,Advanced Degree,0,210252.609375,106267.27
68,1944.0,13394,M,M,Good,2 yr Degree,0,210197.687500,121289.28
134,1944.0,38371,F,U,Low Risk,Advanced Degree,0,210252.609375,113860.91
228,1944.0,59651,M,M,High Risk,Secondary,0,210252.609375,507302.42
...,...,...,...,...,...,...,...,...,...
96499805,1944.0,45258,F,M,Good,4 yr Degree,1,210252.609375,117063.28
96499871,1944.0,33683,M,M,Unknown,2 yr Degree,0,210197.687500,114533.88
96499913,1944.0,76614,F,U,Unknown,Secondary,0,210252.609375,117748.83
96499959,1944.0,78222,F,U,Unknown,Unknown,0,210252.609375,108388.53


In [39]:
import plotly.express as px

In [56]:
res.C_BIRTH_YEAR

0           1944.0
1           1976.0
2           1932.0
3           1992.0
4           1944.0
             ...  
96500086    1943.0
96500087    1951.0
96500088    1957.0
96500089    1965.0
96500090    1939.0
Name: C_BIRTH_YEAR, Length: 96500091, dtype: float64

In [57]:
res.C_BIRTH_YEAR.isnull().count()

96500091

In [49]:
res.astype({'C_BIRTH_YEAR':'int32'})

IntCastingNaNError: Cannot convert non-finite values (NA or inf) to integer

In [40]:
res.columns

Index(['C_BIRTH_YEAR', 'CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS',
       'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS', 'CD_DEP_COUNT', 'PREDICTION',
       'ACTUAL_SALES'],
      dtype='object')

In [44]:
res.groupby(['C_BRITH_YEAR'])['ACTUAL_SALES'].sum()

KeyError: 'C_BRITH_YEAR'

In [43]:
fig = px.sunburst(res,
                    path=['C_BIRTH_YEAR', 'CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS'],values='ACXTUAL_SALES',color_discrete_sequence=px.colors.qualitative.Pastel)
fig.update_layout(width=800, 
                    height=800,
                    )


ValueError: Value of 'values' is not the name of a column in 'data_frame'. Expected one of ['C_BIRTH_YEAR', 'CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS', 'CD_DEP_COUNT', 'PREDICTION', 'ACTUAL_SALES'] but received: ACXTUAL_SALES