In [1]:
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import IntegerType, StringType, StructType, FloatType, StructField, DateType, Variant
from snowflake.snowpark.functions import udf, sum, col,array_construct,month,year,call_udf,lit,count
from snowflake.snowpark.version import VERSION
# Misc
import json
import pandas as pd
import numpy as np
import logging 
# Only let ERROR-level (and above) records through from the Snowpark session logger,
# keeping notebook cell output uncluttered.
logger = logging.getLogger(name="snowflake.snowpark.session")
logger.setLevel(logging.ERROR)

In [2]:
# Create Snowflake Session object.
# Parse the connection parameters inside a `with` block so the file handle is closed
# as soon as it has been read rather than left open until garbage collection.
with open('connection.json', encoding='utf-8') as connection_file:
    connection_parameters = json.load(connection_file)
session = Session.builder.configs(connection_parameters).create()
session.sql_simplifier_enabled = True

# One row holding the session context: user, role, database, schema, Snowflake version, warehouse.
snowflake_environment = session.sql('select current_user(), current_role(), current_database(), current_schema(), current_version(), current_warehouse()').collect()
snowpark_version = VERSION

# # Current Environment Details
# print('User                        : {}'.format(snowflake_environment[0][0]))
# print('Role                        : {}'.format(snowflake_environment[0][1]))
# print('Database                    : {}'.format(snowflake_environment[0][2]))
# print('Schema                      : {}'.format(snowflake_environment[0][3]))
# print('Warehouse                   : {}'.format(snowflake_environment[0][5]))
# print('Snowflake version           : {}'.format(snowflake_environment[0][4]))
# print('Snowpark for Python version : {}.{}.{}'.format(snowpark_version[0],snowpark_version[1],snowpark_version[2]))

In [6]:
# Zip-code lookup table. Presumably maps zip-code strings to CA_ZIP feature values
# (the Streamlit app looks it up with str(zip)) — confirm against the JSON file.
# `with` closes the file handle immediately after parsing.
with open('src/zip_json.json', encoding='utf-8') as zip_file:
    ca_zip = json.load(zip_file)

In [34]:
%%writefile pages/CLV_prediction.py
import pandas as pd
import streamlit as st
import json
import numpy as np
import sys
import cachetools
import joblib
# from streamlit.report_thread import get_report_ctx
# from streamlit.server.server import Server
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import PandasDataFrame,PandasSeries
from snowflake.snowpark.functions import udf, sum, col,array_construct,month,year,call_udf,lit,count
from snowflake.snowpark.version import VERSION

st.set_page_config(page_title="Customer Lifetime Value")

# Banner image at the top of the page.
st.image("src/1649251328-maximize-your-clv.webp", width=600)

# Input form: numeric features on the left, categorical features on the right.
numeric_col, categorical_col = st.columns(2, gap="medium")

with numeric_col:
    st.markdown('#### Numeric Features')
    Cus_by = st.number_input(
        'Customer Birth Year:',
        min_value=1924,
        max_value=2020,
        help='Please type VALID birth Year!!(Range: 1924~2020)',
    )
    Cs_zip = st.number_input(
        'Customer Zip Code:',
        min_value=601,
        max_value=99981,
        value=66668,
        step=1,
    )

with categorical_col:
    st.markdown('#### Categorical Features ')
    Cus_gender = st.selectbox('CD_Gender', ['M', 'F'], help='M: Male, F: Female')
    Cus_marital = st.selectbox('CD_MARITAL_STATUS', ['S', 'D', 'W', 'U', 'M'])
    Cus_credit = st.selectbox('CD_CREDIT_RATING', ['Low Risk', 'Unknown', 'Good', 'High Risk'])
    Cus_edu = st.selectbox(
        'CD_EDUCATION_STATUS',
        ['Advanced Degree', 'Secondary', '2 yr Degree', '4 yr Degree', 'Unknown', 'Primary', 'College'],
    )
    # Options are strings; the submit handler converts the choice with int().
    Cus_dep = st.selectbox('CD_DEP_COUNT', ['0', '1'])



# Model picker in a wide left column; action buttons in a narrow right column.
model_col, button_col = st.columns([8, 2], gap="medium")

with model_col:
    model_select = st.radio('Select the Model here:', ['XGBoost', 'Linear Regression'])

with button_col:
    submit = st.button('Submit')
    reset = st.button('Reset ')  # label intentionally keeps its trailing space
@st.cache_resource
def initialize_SF():
    """Create the Snowpark session and load the zip-code lookup table.

    Wrapped in ``st.cache_resource`` so the work runs once and is reused on later
    script reruns instead of reconnecting on every widget interaction.

    The session is pointed at the FE_AND_INFERENCE_WH warehouse and the
    tpcds_xgboost.demo schema. It is given the Python packages and the two model
    files (XGBoost and linear regression) that UDFs registered from it need.

    Returns:
        tuple: ``(session, ca_zip)`` — the configured Snowpark ``Session`` and the
        object parsed from ``src/zip_json.json``.
    """
    # `with` closes each JSON file as soon as it has been parsed.
    with open('connection.json', encoding='utf-8') as connection_file:
        connection_parameters = json.load(connection_file)
    session = Session.builder.configs(connection_parameters).create()
    session.sql_simplifier_enabled = True
    with open('src/zip_json.json', encoding='utf-8') as zip_file:
        ca_zip = json.load(zip_file)
    # set up feature engineering/inference warehouse
    session.use_warehouse('FE_AND_INFERENCE_WH')
    session.use_database('tpcds_xgboost')
    session.use_schema('demo')
    # Packages made available to UDFs registered from this session.
    session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools', 'xgboost')
    # Ship both serialized models with the UDF so either one can be loaded by read_file().
    session.add_import("@ml_models_10T/model.joblib")
    session.add_import("@ml_models_LR_10T/model_LR.joblib")
    return session, ca_zip


session, ca_zip = initialize_SF()

# Resolve the stage and file name of the selected model's serialized artifact.
# Any selection other than 'XGBoost' falls back to the linear-regression model.
_XGBOOST_ARTIFACT = ('ml_models_10T', 'model.joblib')
_LINEAR_ARTIFACT = ('ml_models_LR_10T', 'model_LR.joblib')
stage_name, model_name = _XGBOOST_ARTIFACT if model_select == 'XGBoost' else _LINEAR_ARTIFACT

# Ordered column names assigned to the input DataFrame and to each batch the UDF receives.
features = [
    'C_BIRTH_YEAR',
    'CA_ZIP',
    'CD_GENDER',
    'CD_MARITAL_STATUS',
    'CD_CREDIT_RATING',
    'CD_EDUCATION_STATUS',
    'CD_DEP_COUNT',
]

@cachetools.cached(cache={})
def read_file(filename):
    """Load a joblib-serialized model that was shipped to the UDF via ``session.add_import``.

    Intended to run inside a Snowflake UDF. There, ``sys._xoptions`` carries
    ``snowflake_import_directory``, the directory where imported stage files are
    placed. Results are cached by ``filename``, so repeated calls in the same
    process reuse the loaded model.

    Args:
        filename: Name of the imported file, e.g. ``'model.joblib'``.

    Returns:
        The deserialized model object.

    Raises:
        RuntimeError: If ``snowflake_import_directory`` is not set. Raising, rather
            than implicitly returning ``None``, keeps the failure out of the cache.
            It also gives a clear message instead of a later
            ``'NoneType' object has no attribute 'predict'``.
    """
    import os
    import joblib

    import_dir = sys._xoptions.get("snowflake_import_directory")
    if not import_dir:
        raise RuntimeError(
            "snowflake_import_directory is not set; read_file() must run inside a Snowflake UDF"
        )
    with open(os.path.join(import_dir, filename), 'rb') as file:
        return joblib.load(file)

# Each model gets its own permanent UDF name. With replace=True and a shared name,
# selecting Linear Regression would overwrite the permanent XGBoost UDF with the
# linear model. The XGBoost UDF keeps its existing name.
# NOTE(review): this registration runs on every Streamlit rerun, re-uploading the
# permanent UDF each time — consider caching the registration per model.
_udf_name = "clv_xgboost_udf" if model_select == 'XGBoost' else "clv_lr_udf"


@F.pandas_udf(session=session, max_batch_size=10000, is_permanent=True, stage_location=f'@{stage_name}', replace=True, name=_udf_name)
def predict(df: PandasDataFrame[int, str, str, str, str, str, int]) -> PandasSeries[float]:
    """Score one batch of customer rows with the selected model.

    Args:
        df: A batch of input rows whose seven columns are renamed to ``features``
            before prediction, so the model sees the names it expects.

    Returns:
        The model's predictions for the batch.
    """
    m = read_file(model_name)
    df.columns = features
    return m.predict(df)

# Submit: build a one-row Snowpark DataFrame from the form inputs, score it with the
# `predict` UDF, and show the inputs plus PREDICTION as a single column.
if submit:
    # Zip codes missing from the lookup fall back to '66668'.
    typed_input = [[Cus_by, ca_zip.get(str(Cs_zip), '66668'), Cus_gender, Cus_marital, Cus_credit, Cus_edu, int(Cus_dep)]]
    input_df = session.create_dataframe(typed_input, schema=features)
    typed_output = input_df.select(*input_df, predict(*input_df).alias('PREDICTION'))
    # Pass the result's column names explicitly so the transposed table is labelled
    # by feature/PREDICTION name, independent of how pandas interprets Row objects.
    output = pd.DataFrame(typed_output.collect(), columns=typed_output.columns).T
    output.columns = ['']
    st.write(output)

# Reset: placeholder only — no state is cleared. The message is pinyin for roughly
# "Need a reset? What a hassle."
# TODO: actually clear the form (or remove the button) and replace this message.
if reset:
    st.write('xuyao reset ma? hao ma fan')
    

    


Overwriting pages/CLV_prediction.py


In [30]:
%%writefile Customer_Lifetime_Value.py
import pandas as pd
import streamlit as st

# Section header for this page (rendered as Markdown by st.write).
st.write("## Introduction")


Overwriting Customer_Lifetime_Value.py
