# Customer churn analysis


## Machine Learning Pipeline

In the following notebooks, we will go through the implementation of each one of the steps in the Machine Learning Pipeline. 

We will discuss:

1. Data Preparation and Analysis
2. Model Training
3. Obtaining Predictions / Scoring

<img src="arch.jpg"/>

In [1]:
from snowflake.snowpark.session import Session
from snowflake.snowpark import functions as F
from snowflake.snowpark.types import *
import snowflake.snowpark

import pandas as pd


import matplotlib.pyplot as plt

%matplotlib inline
import datetime as dt
import numpy as np
import seaborn as sns

#Snowflake connection info is saved in config.py
from config import snowflake_conn_prop


# lets import some tranformations functions
from snowflake.snowpark.functions import udf, col, lit, translate, is_null, iff

## Session Creation

In [43]:
from snowflake.snowpark import version
print(version.VERSION)
#session.close()
session = Session.builder.configs(snowflake_conn_prop).create()

(0, 7, 0)


## Data Preparation and Analysis

In [44]:
dfR = session.table("RAW_PARQUET_DATA")

In [45]:
dfR.to_pandas().head(10)

Unnamed: 0,COUNTRY,CITY,PHONE SERVICE,MULTIPLE LINES,LATITUDE,ONLINE SECURITY,SENIOR CITIZEN,MONTHLY CHARGES,STREAMING MOVIES,PAYMENT METHOD,...,CHURN SCORE,GENDER,LONGITUDE,ONLINE BACKUP,TOTAL CHARGES,CLTV,CHURN REASON,DEVICE PROTECTION,STATE,ZIP CODE
0,United States,Los Angeles,Yes,No,34.059281,No,False,70.7,No,Electronic check,...,1,Female,-118.30742,No,151.65,2701,Moved,No,California,90005
1,United States,Los Angeles,Yes,Yes,34.048013,No,False,99.65,Yes,Electronic check,...,1,Female,-118.293953,No,820.5,5372,Moved,Yes,California,90006
2,United States,Los Angeles,Yes,Yes,34.108833,No,True,95.45,Yes,Electronic check,...,1,Male,-118.229715,No,1752.55,3179,Competitor made better offer,No,California,90065
3,United States,La Habra,Yes,Yes,33.940619,No,False,74.4,No,Electronic check,...,1,Male,-117.9513,No,229.55,4415,Product dissatisfaction,No,California,90631
4,United States,Glendale,Yes,No,34.162515,No,False,79.25,No,Electronic check,...,1,Female,-118.203869,Yes,1111.65,5142,Price too high,Yes,California,91206
5,United States,Burbank,Yes,No,34.213049,Yes,False,84.6,Yes,Mailed check,...,1,Male,-118.317651,No,84.6,2484,Poor expertise of phone support,No,California,91504
6,United States,Ontario,Yes,No,34.057256,No internet service,False,19.35,No internet service,Mailed check,...,1,Male,-117.667677,No internet service,1099.6,5084,Price too high,No internet service,California,91762
7,United States,Alpine,Yes,No,32.827184,No,True,74.5,No,Electronic check,...,1,Male,-116.703729,Yes,606.55,4345,Poor expertise of online support,No,California,91901
8,United States,Borrego Springs,Yes,No,33.200369,No,False,80.6,No,Electronic check,...,1,Male,-116.192313,No,415.55,5715,Network reliability,No,California,92004
9,United States,Del Mar,Yes,No,32.948262,No,True,93.15,Yes,Electronic check,...,1,Male,-117.256086,No,2231.05,2212,Lack of affordable download/upload speed,Yes,California,92014


In [46]:
##Easy way to fetch stats
dfR.describe().to_pandas()

Unnamed: 0,SUMMARY,COUNTRY,CITY,PHONE SERVICE,MULTIPLE LINES,LATITUDE,ONLINE SECURITY,MONTHLY CHARGES,STREAMING MOVIES,PAYMENT METHOD,...,CHURN SCORE,GENDER,LONGITUDE,ONLINE BACKUP,TOTAL CHARGES,CLTV,CHURN REASON,DEVICE PROTECTION,STATE,ZIP CODE
0,count,100000,100000,100000,100000,100000.0,100000,100000.0,100000,100000,...,100000.0,100000,100000.0,100000,100000.0,100000.0,100000,100000,100000,100000.0
1,mean,,,,,,,65.601117,,,...,0.32226,,,,2222.032082,4378.27414,,,,
2,stddev,,,,,,,29.808447,,,...,0.467345,,,,2248.711822,1186.175499,,,,
3,min,United States,Acampo,No,No,32.555828,No,18.25,No,Bank transfer (automatic),...,0.0,Female,-114.192901,No,0.0,2003.0,Attitude of service provider,No,California,90001.0
4,max,United States,Zenia,Yes,Yes,41.962127,Yes,118.75,Yes,Mailed check,...,1.0,Male,-124.301372,Yes,8684.8,6500.0,do not know,Yes,California,96161.0


In [59]:
## Compute the numbers of customer churn by city
dfR.filter((col("CHURN SCORE") == 1)).group_by("CITY").count().sort(col("COUNT"), ascending=False).show()

---------------------------
|"CITY"         |"COUNT"  |
---------------------------
|Los Angeles    |1593     |
|San Diego      |876      |
|San Francisco  |522      |
|San Jose       |494      |
|Sacramento     |439      |
|Fresno         |279      |
|Long Beach     |247      |
|Glendale       |243      |
|Oakland        |216      |
|Modesto        |203      |
---------------------------



In [6]:
dfDemographics = dfR.select(col("CUSTOMERID"),
                             col("COUNT").alias("COUNT"),
                             translate(col("GENDER"),lit("NULL"),lit("Male")).alias("GENDER"),
                             col("SENIOR CITIZEN").alias("SENIORCITIZEN"),
                             col("PARTNER"),
                             col("DEPENDENTS")          
                            )


dfDemographics.write.mode('overwrite').saveAsTable('DEMOGRAPHICS')
dfDemographics.show()

----------------------------------------------------------------------------------
|"CUSTOMERID"  |"COUNT"  |"GENDER"  |"SENIORCITIZEN"  |"PARTNER"  |"DEPENDENTS"  |
----------------------------------------------------------------------------------
|7090-ZyCMx    |1        |Female    |False            |False      |True          |
|1364-wJXMS    |1        |Female    |False            |False      |True          |
|6564-sLgIC    |1        |Male      |True             |False      |True          |
|7853-2xheR    |1        |Male      |False            |False      |True          |
|8457-E9FuW    |1        |Female    |False            |False      |True          |
|5718-ykxBT    |1        |Male      |False            |False      |True          |
|7092-gCJX5    |1        |Male      |False            |False      |False         |
|8249-GOs7s    |1        |Male      |True             |False      |False         |
|9445-kPPEc    |1        |Male      |False            |False      |False         |
|158

In [7]:
dfLocation = dfR.select(col("CUSTOMERID"),
                         col("COUNTRY").name("COUNTRY"),
                         col("STATE").name("STATE"),
                         col("CITY").name("CITY"),
                         translate(col("ZIP CODE"),lit("NULL"),lit(0)).name("ZIPCODE"),
                         col("LAT LONG").name("LATLONG"),
                         col("LATITUDE").name("LATITUDE"),
                         col("LONGITUDE").name("LONGITUDE")       
                        )

dfLocation.write.mode('overwrite').saveAsTable('LOCATION')
dfLocation.show()

-------------------------------------------------------------------------------------------------------------------------------
|"CUSTOMERID"  |"COUNTRY"      |"STATE"     |"CITY"           |"ZIPCODE"  |"LATLONG"               |"LATITUDE"  |"LONGITUDE"  |
-------------------------------------------------------------------------------------------------------------------------------
|7090-ZyCMx    |United States  |California  |Los Angeles      |90005      |34.059281, -118.30742   |34.059281   |-118.307420  |
|1364-wJXMS    |United States  |California  |Los Angeles      |90006      |34.048013, -118.293953  |34.048013   |-118.293953  |
|6564-sLgIC    |United States  |California  |Los Angeles      |90065      |34.108833, -118.229715  |34.108833   |-118.229715  |
|7853-2xheR    |United States  |California  |La Habra         |90631      |33.940619, -117.9513    |33.940619   |-117.951300  |
|8457-E9FuW    |United States  |California  |Glendale         |91206      |34.162515, -118.203869  |34.1

In [8]:
dfServices = dfR.select(col("CUSTOMERID"),
                       col("TENURE MONTHS").name("TENUREMONTHS"),
                       iff(is_null(col("PHONE SERVICE")),lit('N'),col("PHONE SERVICE")).name("PHONESERVICE"),
                       iff(is_null(col("MULTIPLE LINES")),lit("No"),col("MULTIPLE LINES")).name("MULTIPLELINES"),
                       iff(is_null(col("INTERNET SERVICE")),lit("No"),col("INTERNET SERVICE")).name("INTERNETSERVICE"),
                       iff(is_null(col("ONLINE SECURITY")),lit("No"),col("ONLINE SECURITY")).name("ONLINESECURITY"),
                       iff(is_null(col("ONLINE BACKUP")),lit("No"),col("ONLINE BACKUP")).name("ONLINEBACKUP"),
                       iff(is_null(col("DEVICE PROTECTION")),lit("No"),col("DEVICE PROTECTION")).name("DEVICEPROTECTION"),
                       iff(is_null(col("TECH SUPPORT")),lit('N'),col("TECH SUPPORT")).name("TECHSUPPORT"),
                       iff(is_null(col("STREAMING TV")),lit("No"),col("STREAMING TV")).name("STREAMINGTV"),
                       iff(is_null(col("STREAMING MOVIES")),lit("No"),col("STREAMING MOVIES")).name("STREAMINGMOVIES"),
                       iff(is_null(col("CONTRACT")),lit("Month-to-month"),col("CONTRACT")).name("CONTRACT"),
                       iff(is_null(col("PAPERLESS BILLING")),lit('Y'),col("PAPERLESS BILLING")).name("PAPERLESSBILLING"),
                       iff(is_null(col("PAYMENT METHOD")),lit("Mailed check"),col("PAYMENT METHOD")).name("PAYMENTMETHOD"),
                       col("MONTHLY CHARGES").name("MONTHLYCHARGES"),
                       col("TOTAL CHARGES").name("TOTALCHARGES"),
                       col("CHURN VALUE").name("CHURNVALUE")        

                      )

dfServices.write.mode('overwrite').saveAsTable('SERVICES')
dfServices.to_pandas().head(10)

Unnamed: 0,CUSTOMERID,TENUREMONTHS,PHONESERVICE,MULTIPLELINES,INTERNETSERVICE,ONLINESECURITY,ONLINEBACKUP,DEVICEPROTECTION,TECHSUPPORT,STREAMINGTV,STREAMINGMOVIES,CONTRACT,PAPERLESSBILLING,PAYMENTMETHOD,MONTHLYCHARGES,TOTALCHARGES,CHURNVALUE
0,7090-ZyCMx,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,True,Electronic check,70.7,151.65,1.0
1,1364-wJXMS,8,Yes,Yes,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,True,Electronic check,99.65,820.5,1.0
2,6564-sLgIC,18,Yes,Yes,Fiber optic,No,No,No,No,Yes,Yes,Month-to-month,True,Electronic check,95.45,1752.55,1.0
3,7853-2xheR,3,Yes,Yes,Fiber optic,No,No,No,No,No,No,Month-to-month,True,Electronic check,74.4,229.55,1.0
4,8457-E9FuW,13,Yes,No,Fiber optic,No,Yes,Yes,No,No,No,Month-to-month,True,Electronic check,79.25,1111.65,1.0
5,5718-ykxBT,1,Yes,No,Fiber optic,Yes,No,No,No,No,Yes,Month-to-month,True,Mailed check,84.6,84.6,1.0
6,7092-gCJX5,59,Yes,No,No,No internet service,No internet service,No internet service,No internet service,No internet service,No internet service,Two year,False,Mailed check,19.35,1099.6,1.0
7,8249-GOs7s,8,Yes,No,Fiber optic,No,Yes,No,No,No,No,Month-to-month,True,Electronic check,74.5,606.55,1.0
8,9445-kPPEc,5,Yes,No,Fiber optic,No,No,No,No,Yes,No,Month-to-month,True,Electronic check,80.6,415.55,1.0
9,1581-8yNji,24,Yes,No,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,True,Electronic check,93.15,2231.05,1.0


In [9]:
dfJ = dfDemographics.join(dfServices, using_columns='CUSTOMERID', join_type = 'left')
dfJ.select(col('GENDER'),
              col('SENIORCITIZEN'),
              col('PARTNER'),
              col('DEPENDENTS'),
              col('MULTIPLELINES'),
              col('INTERNETSERVICE'),
              col('ONLINESECURITY'),
              col('ONLINEBACKUP'),
              col('DEVICEPROTECTION'),
              col('TECHSUPPORT'),
              col('STREAMINGTV'),
              col('STREAMINGMOVIES'),
              col('CONTRACT'),
              col('PAPERLESSBILLING'),
              col('PAYMENTMETHOD'),
              col('TENUREMONTHS'),
              col('MONTHLYCHARGES'),
              col('TOTALCHARGES'),
              col('CHURNVALUE'))
dfJ.create_or_replace_view('TRAIN_DATASET')

[Row(status='View TRAIN_DATASET successfully created.')]

## ML training push down to Snowflake

In [21]:
#Let's first create a stage to store the artifacts
print(session.sql('create stage if not exists MODELSTAGE').collect())

[Row(status='Stage area MODELSTAGE successfully created.')]


In [22]:
##pachages needed for the training & inference 
session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy','cloudpickle')

def train_model(session: snowflake.snowpark.Session) -> float:
    
    #transformations
    from sklearn.preprocessing import OrdinalEncoder
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.preprocessing import FunctionTransformer

    #Classifier
    from sklearn.ensemble import RandomForestClassifier

    #Pipeline
    from sklearn.pipeline import make_pipeline
    from sklearn.model_selection import train_test_split

    #Model Accuracy
    from sklearn.metrics import balanced_accuracy_score
    
    #get training dataset
    raw = session.table('TRAIN_DATASET').sample(n = 20000)
    data = raw.toPandas()
    
    # split the train and test set
    X_train, X_test, y_train, y_test = train_test_split(
    data.drop(columns=['CHURNVALUE','CUSTOMERID'], axis=1), # predictive variables
    data['CHURNVALUE'], # target
    test_size=0.2, # portion of dataset to allocate to test set
    random_state=0, # we are setting the seed here
    )

    
        # Model Pipeline
    ord_pipe = make_pipeline(
        FunctionTransformer(lambda x: x.astype(str)) ,
        OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
        )

    num_pipe = make_pipeline(
        SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0),
        MinMaxScaler()
        )

    clf = make_pipeline(RandomForestClassifier(random_state=0, n_jobs=-1))

    model = make_pipeline(ord_pipe, num_pipe, clf)

    # fit the model
    model.fit(X_train, y_train)

    # save the full pipeline including the model inside a stage 
    filename = "model.pkl"
    model_file_path = "tmp/" + filename
    cloudpickle.dump(model, open(model_file_path,"wb")) ## save the model locally 
    session.file.put(model_file_path , "@MODELSTAGE",overwrite=True, auto_compress = False) ##Push the file on our stage
    
    #check accuracy 
    y_pred = model.predict_proba(X_test)[:,1]
    predictions = [round(value) for value in y_pred]
    balanced_accuracy = balanced_accuracy_score(y_test, predictions)
    print("Model testing completed.\n   - Model Balanced Accuracy: %.2f%%" % (balanced_accuracy * 100.0))

    return balanced_accuracy


In [23]:
# Create an instance of StoredProcedure using the sproc() function
train_model_sp = F.sproc(train_model, replace=True)

In [24]:
train_model_sp()

0.9810823006847813

In [25]:
##Verification de l'enregistrement du model 
session.sql("list @MODELSTAGE").show()

-----------------------------------------------------------------------------------------------------
|"name"                |"size"    |"md5"                             |"last_modified"               |
-----------------------------------------------------------------------------------------------------
|modelstage/model.pkl  |26163200  |23bff3f5f180fe125830825161f1d83e  |Wed, 6 Jul 2022 13:32:14 GMT  |
-----------------------------------------------------------------------------------------------------



## Prepare for model deployment in Snowflake using Snowpark Python UDF

### We will define a snowpark python UDF to help us score the live data using the model that we built earlier. 

Since we have fitted the model with the scikit learn pipeline, our UDF will do the transformations and also score the new data

In [26]:
## Import the model for usage inside the UDF
session.add_import("@MODELSTAGE/model.pkl")

In [32]:
%%time

##keep only the features trained on to do the inference
features = list(dfJ.drop('CHURNVALUE','CUSTOMERID').columns)


@udf(name='predict_churn',is_permanent = True, stage_location = '@MODELSTAGE', replace=True)
def predict_churn(args: list ) -> float:
    #fetch the immported model
    import sys
    IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
    import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
    
    #load the model in memory 
    model = cloudpickle.load(open(import_dir +"model.pkl" , 'rb'))
    row = pd.DataFrame([args], columns=features)
    return model.predict(row)


CPU times: user 22.3 ms, sys: 4.77 ms, total: 27 ms
Wall time: 23.3 s


## Use the UDF for inference

In [33]:
new_df = session.table('TRAIN_DATASET').sample(n = 400)

In [34]:
##call the udf 
new_df.select(new_df.CUSTOMERID,new_df.CHURNVALUE, \
              F.call_udf("predict_churn", F.array_construct(*features)).alias('PREDICTED_CHURN')) \
        .write.mode('overwrite').saveAsTable('churn_detection')

In [35]:
session.table('churn_detection').toPandas()

Unnamed: 0,CUSTOMERID,CHURNVALUE,PREDICTED_CHURN
0,2368-m9jM6,0.0,0.0
1,3758-MDgfC,0.0,0.0
2,1669-n5QY6,0.0,0.0
3,4743-ZFAVL,1.0,1.0
4,2781-7nMyb,0.0,0.0
...,...,...,...
395,9085-DgzRL,1.0,1.0
396,4525-rm2Lv,0.0,0.0
397,2608-4cetv,0.0,0.0
398,6471-cgUPV,0.0,0.0


In [43]:
##another way to register an udf 
%%time
session.add_packages("scikit-learn==1.0.2", "pandas", "numpy")
features = list(X_train.columns)
predict_churn_model = session.udf.register(lambda *args: 
                                    model.predict(pd.DataFrame(args, columns=features)),
                                    name="predict_churn_model",
                                    stage_location="@MODELSTAGE",
                                    return_type=FloatType(),
                                    is_permanent=True,
                                    replace=True,
                                    input_types=[ArrayType()])

CPU times: user 1.27 s, sys: 1.1 s, total: 2.38 s
Wall time: 25.9 s


In [None]:
## Execution from SQL
%%time
session.sql(' select customerid,churnvalue, \
            predict_churn(ARRAY_CONSTRUCT( \
                                    GENDER, \
                                    COUNT, \
                                    SENIORCITIZEN, \
                                    PARTNER, \
                                    DEPENDENTS, \
                                    PHONESERVICE, \
                                    MULTIPLELINES,  \
                                    INTERNETSERVICE,  \
                                    ONLINESECURITY,  \
                                    ONLINEBACKUP, \
                                    DEVICEPROTECTION,  \
                                    TECHSUPPORT,  \
                                    STREAMINGTV,  \
                                    STREAMINGMOVIES, \
                                    CONTRACT,  \
                                    PAPERLESSBILLING,  \
                                    PAYMENTMETHOD,  \
                                    TENUREMONTHS, \
                                    MONTHLYCHARGES,  \
                                    TOTALCHARGES)) as Churn_prediction \
                                    from train_dataset sample (10 rows)').show()

In [36]:
session.close()