In [2]:
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import IntegerType, StringType, StructType, FloatType, StructField, DateType, Variant
from snowflake.snowpark.functions import udf, sum, col,array_construct,month,year,call_udf,lit,count
from snowflake.snowpark.version import VERSION
# Misc
import json
import pandas as pd
import numpy as np
import logging 
# Restrict the Snowpark session logger to ERROR and above.
SESSION_LOGGER_NAME = "snowflake.snowpark.session"
logger = logging.getLogger(SESSION_LOGGER_NAME)
logger.setLevel(logging.ERROR)

In [3]:
# Create Snowflake Session object.
# The connection settings are read inside a `with` block so the file handle is
# closed as soon as the JSON has been parsed.
with open('connection.json') as connection_file:
    connection_parameters = json.load(connection_file)
session = Session.builder.configs(connection_parameters).create()
session.sql_simplifier_enabled = True

# Single-row result: user, role, database, schema, Snowflake version, warehouse (in that column order).
snowflake_environment = session.sql('select current_user(), current_role(), current_database(), current_schema(), current_version(), current_warehouse()').collect()
snowpark_version = VERSION

# # Current Environment Details
# print('User                        : {}'.format(snowflake_environment[0][0]))
# print('Role                        : {}'.format(snowflake_environment[0][1]))
# print('Database                    : {}'.format(snowflake_environment[0][2]))
# print('Schema                      : {}'.format(snowflake_environment[0][3]))
# print('Warehouse                   : {}'.format(snowflake_environment[0][5]))
# print('Snowflake version           : {}'.format(snowflake_environment[0][4]))
# print('Snowpark for Python version : {}.{}.{}'.format(snowpark_version[0],snowpark_version[1],snowpark_version[2]))

In [4]:
# Load the zip-code JSON (the same file the CLV prediction page reads) and
# close the file handle once it has been parsed.
with open('src/zip_json.json') as zip_file:
    ca_zip = json.load(zip_file)

In [5]:
%%writefile pages/CLV_prediction.py
import pandas as pd
import streamlit as st
import json
import numpy as np
import sys
import cachetools
import joblib
# from streamlit.report_thread import get_report_ctx
# from streamlit.server.server import Server
import snowflake.snowpark
from snowflake.snowpark import functions as F
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import PandasDataFrame,PandasSeries
from snowflake.snowpark.functions import udf, sum, col,array_construct,month,year,call_udf,lit,count
from snowflake.snowpark.version import VERSION

# Page title and banner image for the CLV prediction page.
PAGE_TITLE = "Customer Lifetime Value"
BANNER_IMAGE = "src/1649251328-maximize-your-clv.webp"

st.set_page_config(page_title=PAGE_TITLE)
st.image(BANNER_IMAGE, width=600)

#st.title("XGBoost model to predict CLV")

# Choices offered by the categorical selectors, in display order.
GENDER_CHOICES = ['M', 'F']
MARITAL_CHOICES = ['S', 'D', 'W', 'U', 'M']
CREDIT_CHOICES = ['Low Risk', 'Unknown', 'Good', 'High Risk']
EDUCATION_CHOICES = ['Advanced Degree', 'Secondary', '2 yr Degree', '4 yr Degree', 'Unknown', 'Primary', 'College']
DEPENDENT_CHOICES = ['0', '1']

col1, col2 = st.columns(2, gap="medium")

# Left column: numeric inputs.
with col1:
    st.markdown('#### Numeric Features')
    Cus_by = st.number_input(
        'Customer Birth Year:',
        min_value=1924,
        max_value=2020,
        help='Please type VALID birth Year!!(Range: 1924~2020)',
    )
    Cs_zip = st.number_input(
        'Customer Zip Code:',
        min_value=601,
        max_value=99981,
        value=66668,
        step=1,
    )

# Right column: categorical inputs (labels mirror the feature column names).
with col2:
    st.markdown('#### Categorical Features ')
    Cus_gender = st.selectbox('CD_Gender', GENDER_CHOICES, help='M: Male, F: Female')
    Cus_marital = st.selectbox('CD_MARITAL_STATUS', MARITAL_CHOICES)
    Cus_credit = st.selectbox('CD_CREDIT_RATING', CREDIT_CHOICES)
    Cus_edu = st.selectbox('CD_EDUCATION_STATUS', EDUCATION_CHOICES)
    # Dependent count is picked as a string; it is converted with int() when the input row is built.
    Cus_dep = st.selectbox('CD_DEP_COUNT', DEPENDENT_CHOICES)



col3, col4 = st.columns([8, 2], gap="medium")

# First column: which model to score with.
with col3:
    model_select = st.radio('Select the Model here:', ['XGBoost', 'Linear Regression'])

# Second column: action buttons.
with col4:
    submit = st.button('Submit')
    reset = st.button('Reset ')
@st.cache_resource
def initialize_SF():
    """Create the Snowpark session and load the zip-code lookup for this page.

    The function is wrapped in ``st.cache_resource``, so the returned objects
    are meant to be created once and reused.

    Returns:
        tuple: ``(session, ca_zip)`` - the configured Snowpark ``Session`` and
        the object parsed from ``src/zip_json.json``.
    """
    # Read both JSON files up front, inside `with` blocks, so the handles are
    # closed immediately and a missing file fails before a session is opened.
    with open('connection.json') as connection_file:
        connection_parameters = json.load(connection_file)
    with open('src/zip_json.json') as zip_file:
        ca_zip = json.load(zip_file)

    session = Session.builder.configs(connection_parameters).create()
    session.sql_simplifier_enabled = True

    # set up feature engineering/inference warehouse
    session.use_warehouse('FE_AND_INFERENCE_WH')
    session.use_database('tpcds_xgboost')
    session.use_schema('demo')

    # Packages and model files the UDF needs on the Snowflake side (each package listed once).
    session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools', 'xgboost==1.5.0')
    session.add_import("@ml_models_10T/model.joblib")
    session.add_import("@ml_models_LR_10T/model_LR.joblib")
    return session, ca_zip
# Snowpark session and zip-code lookup returned by initialize_SF.
session, ca_zip = initialize_SF()

# choose model here: XGBoost gets its own stage/file, anything else uses the linear-regression pair
if model_select == 'XGBoost':
    stage_name, model_name = 'ml_models_10T', 'model.joblib'
else:
    stage_name, model_name = 'ml_models_LR_10T', 'model_LR.joblib'

# Feature column names, in the order the input row is assembled.
features = [
    'C_BIRTH_YEAR',
    'CA_ZIP',
    'CD_GENDER',
    'CD_MARITAL_STATUS',
    'CD_CREDIT_RATING',
    'CD_EDUCATION_STATUS',
    'CD_DEP_COUNT',
]

@cachetools.cached(cache={})
def read_file(filename):
    """Load a joblib artifact from the Snowflake import directory.

    Results are memoised per ``filename`` by ``cachetools.cached``.

    Args:
        filename: File name to look up inside the import directory.

    Returns:
        Whatever ``joblib.load`` returns for the file, or ``None`` when the
        ``snowflake_import_directory`` x-option is not set.
    """
    import os
    import joblib

    import_dir = sys._xoptions.get("snowflake_import_directory")
    if not import_dir:
        return None

    artifact_path = os.path.join(import_dir, filename)
    with open(artifact_path, 'rb') as handle:
        return joblib.load(handle)

# Registers `predict` as a permanent pandas UDF named "clv_xgboost_udf", staged
# under the stage chosen by the model selector, replacing any existing UDF of
# that name.
# NOTE(review): the UDF name is the same for both model choices and the
# decorator runs every time this script executes - confirm both are intended.
@F.pandas_udf(session=session, max_batch_size=10000, is_permanent=True, stage_location=f'@{stage_name}', replace=True, name="clv_xgboost_udf")
def predict(df:  PandasDataFrame[int, str, str, str, str, str, int]) -> PandasSeries[float]:
    """Score a batch of customer rows with the selected model.

    Args:
        df: Batch of input rows; its columns are renamed to ``features``
            before scoring.

    Returns:
        The result of calling ``predict`` on the model loaded by ``read_file``.
    """
    # read_file is cached, so the model is loaded from disk once per file name.
    m = read_file(model_name)       
    df.columns = features
    return m.predict(df) 

# if click submit
if submit:
    # Single input row, ordered like `features`. The typed zip is looked up in
    # ca_zip by its string form; '66668' is used when the zip is not a key.
    typed_input = [[Cus_by, ca_zip.get(str(Cs_zip), '66668'),Cus_gender,Cus_marital,Cus_credit,Cus_edu, int(Cus_dep)]]
    #st.write(typed_input)
    input_df = session.create_dataframe(typed_input, schema=features)
    # Keep the input columns and append the UDF result as PREDICTION.
    # NOTE(review): `*input_df` presumably unpacks the DataFrame into its columns - confirm.
    typed_output = input_df.select(*input_df,
                    predict(*input_df).alias('PREDICTION'))
    # Transpose the one-row result so it is shown vertically under a blank column header.
    output = pd.DataFrame(typed_output.collect()).T
    output.columns = ['']
    st.write(output)
# if click reset
if reset:
    # NOTE(review): this only writes a message; no widget state is cleared.
    # The text appears to read "Give up reset?" - confirm the intended wording.
    st.write('放弃reset？')
    

    


Overwriting pages/CLV_prediction.py


In [50]:
%%writefile Customer_Lifetime_Value.py
import pandas as pd
import numpy as np
import streamlit as st
import json


#--Alchemy--
from sqlalchemy import create_engine

# Build the SQLAlchemy engine for Snowflake.
#
# * The password is read from the SNOWFLAKE_PASSWORD environment variable
#   instead of living in source control. The value that used to be committed
#   here should be treated as leaked and rotated.
# * Database, schema, warehouse and role are part of the URL, so every pooled
#   connection starts in SNOWFLAKE_SAMPLE_DATA.TPCDS_SF10TCL. The earlier
#   template ended in a literal "/d" and ignored those values, which is why a
#   separate `USE SCHEMA` statement was needed.
# * User, account, warehouse and role can be overridden through environment
#   variables; the defaults are the values previously written here.
import os
from urllib.parse import quote

from sqlalchemy import text

_sf_password = os.environ.get('SNOWFLAKE_PASSWORD')
if not _sf_password:
    raise RuntimeError('Set the SNOWFLAKE_PASSWORD environment variable before starting the app.')

engine = create_engine(
    'snowflake://{user}:{password}@{account}/{database}/{schema}?warehouse={warehouse}&role={role}'.format(
        # Percent-encode credentials so characters such as "!" or "@" cannot break the URL.
        user=quote(os.environ.get('SNOWFLAKE_USER', 'ESTPEGION'), safe=''),
        password=quote(_sf_password, safe=''),
        account=os.environ.get('SNOWFLAKE_ACCOUNT', 'mg61873.ca-central-1.aws'),
        database='SNOWFLAKE_SAMPLE_DATA',
        schema='TPCDS_SF10TCL',
        warehouse=quote(os.environ.get('SNOWFLAKE_WAREHOUSE', 'COMPUTE_WH'), safe=''),
        role=quote(os.environ.get('SNOWFLAKE_ROLE', 'accountadmin'), safe=''),
    )
)

# Connectivity check. The connection is released when the block exits, and the
# statement is wrapped in text() so it is accepted by SQLAlchemy 1.x and 2.x.
with engine.connect() as connection:
    results = connection.execute(text('select current_version()')).fetchone()


#--Alchemy End--

#import var_store
# Selector options for the query forms ('q1_state' and 'q1_agg' are read further down).
# The file is opened in a `with` block so its handle is closed after parsing.
with open('var_store.json') as var_store_file:
    var_json = json.load(var_store_file)

st.set_page_config(page_title="Sqlalchemy Query")

st.image(
    "src/data.jpeg",
    caption='Query the Data',
    width=600,
)

col1, col2 = st.columns(2, gap="medium")

# Left column: pick a query, then show the inputs that belong to it.
with col1:
    Query_selection = st.selectbox(
        'Select the Query here:',
        ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6', 'Q7', 'Q8'],
    )

    if Query_selection == 'Q1':
        st.markdown('#### Features')
        form = st.form(key='my-form')
        year_input = form.number_input(
            'year',
            min_value=1900,
            max_value=2100,
            help='Input value not in range.(Range: 1900~2100)',
        )
        state_input = form.selectbox('State', var_json['q1_state'])
        # NOTE(review): agg_input is collected but the Q1 query in this file never reads it.
        agg_input = form.selectbox('Aggreagation Column', var_json['q1_agg'])

        submit = form.form_submit_button('Submit')

        st.write('Press submit to have your name printed below')

        if submit:
            # Snapshot the submitted values for the query that runs in the right column.
            # A stray bare `1` used to follow these assignments; Streamlit's "magic"
            # renders bare expressions into the page, so it has been removed.
            q1_state = state_input
            q1_year = year_input

    # Q2-Q8 only render an input widget; no query is wired to them in this file.
    elif Query_selection == 'Q2':
        Cs_zip = st.number_input(
            'Customer Zip Code:',
            min_value=601,
            max_value=99981,
            value=66668,
            step=1,
        )
    elif Query_selection == 'Q3':
        st.markdown('#### Categorical Features ')
        Cus_gender = st.selectbox(
            'CD_Gender',
            ['M', 'F'],
            help='M: Male, F: Female',
        )
    elif Query_selection == 'Q4':
        st.markdown('#### Categorical Features ')
        Cus_dep = st.selectbox(
            'CD_DEP_COUNT',
            ['0', '1'],
        )
    elif Query_selection == 'Q5':
        st.markdown('#### Categorical Features ')
        Cus_edu = st.selectbox(
            'CD_EDUCATION_STATUS',
            ['Advanced Degree', 'Secondary', '2 yr Degree', '4 yr Degree', 'Unknown', 'Primary', 'College'],
        )
    elif Query_selection == 'Q6':
        st.markdown('#### Categorical Features ')
        Cus_credit = st.selectbox(
            'CD_CREDIT_RATING',
            ['Low Risk', 'Unknown', 'Good', 'High Risk'],
        )
    elif Query_selection == 'Q7':
        st.markdown('#### Categorical Features ')
        Cus_gender = st.selectbox(
            'CD_Gender',
            ['M', 'F'],
            help='M: Male, F: Female',
        )
    elif Query_selection == 'Q8':
        st.markdown('#### Categorical Features ')
        Cus_marital = st.selectbox(
            'CD_MARITAL_STATUS',
            ['S', 'D', 'W', 'U', 'M'],
        )
    else:
        st.markdown('Please Select your Query')

    
#     submit =  st.button('Submit')
#     reset = st.button('Reset ')
    
    
    
# Right column: run Q1 once its form has been submitted.
# `submit` only exists when Q1 is selected; the `and` short-circuits otherwise.
with col2:
    if Query_selection == 'Q1' and submit:
        # Imported here so this block carries its own dependency.
        from sqlalchemy import text

        # The year and state are sent as bound parameters (:year, :state) rather
        # than being formatted into the SQL text, so widget input can never
        # change the shape of the statement.
        q1 = text('''with customer_total_return as
                (select sr_customer_sk as ctr_customer_sk
                ,sr_store_sk as ctr_store_sk
                ,sum(SR_RETURN_AMT) as ctr_total_return
                from store_returns
                ,date_dim
                where sr_returned_date_sk = d_date_sk
                and d_year = :year
                group by sr_customer_sk
                ,sr_store_sk
                limit 500)
                 select  c_customer_id
                from customer_total_return ctr1
                ,store
                ,customer
                where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2
                from customer_total_return ctr2
                where ctr1.ctr_store_sk = ctr2.ctr_store_sk)
                and s_store_sk = ctr1.ctr_store_sk
                and s_state = :state
                and ctr1.ctr_customer_sk = c_customer_sk
                order by c_customer_id
                 limit 5;''')

        st.write(pd.read_sql_query(q1, engine, params={'year': q1_year, 'state': q1_state}))
    
    
    
    
#st.write('## Introduction')


Overwriting Customer_Lifetime_Value.py


In [25]:
# State codes stored under 'q1_state'.
# NOTE(review): 'None' is a string here, while the recorded cell output shows a
# bare None - confirm which one is intended.
var_store = {
    'q1_state': [
        'TN', 'NM', 'KY', 'IA', 'NC', 'OH', 'GA', 'IN', 'AL', 'SC',
        'MD', 'LA', 'SD', 'MI', 'FL', 'WV', 'VT', 'TX', 'MT', 'CA',
        'NJ', 'CO', 'OK', 'IL', 'None', 'NE', 'OR', 'NY', 'MO', 'WI',
        'VA', 'PA', 'WA', 'KS', 'MN',
    ],
}



{'q1_state': ['TN',
  'NM',
  'KY',
  'IA',
  'NC',
  'OH',
  'GA',
  'IN',
  'AL',
  'SC',
  'MD',
  'LA',
  'SD',
  'MI',
  'FL',
  'WV',
  'VT',
  'TX',
  'MT',
  'CA',
  'NJ',
  'CO',
  'OK',
  'IL',
  None,
  'NE',
  'OR',
  'NY',
  'MO',
  'WI',
  'VA',
  'PA',
  'WA',
  'KS',
  'MN']}

In [49]:
# Reload the stored selector options (closing the file after parsing) and
# display the Q1 aggregation columns.
with open('var_store.json') as var_store_file:
    var_json = json.load(var_store_file)
var_json['q1_agg']

['SR_RETURN_AMT',
 'SR_RETURN_AMT_INC_TAX',
 'SR_RETURN_QUANTITY',
 'SR_RETURN_TAX',
 'SR_RETURN_SHIP_COST']

In [41]:
# Scratch check: a double quote inside a single-quoted string needs no backslash.
print('"state')

"state
