# MODEL TRAINING
This notebook trains a Random Forest classification model with Spark MLlib for the customer churn prediction experiment.<br>
It uses BigQuery as the source, and writes test results, model metrics and
feature importance scores to BigQuery.

Copyright 2022 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.

In [None]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql.types import FloatType
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer
import pandas as pd
import sys, logging, argparse, random, tempfile, json
from pyspark.sql.functions import col, udf
from pyspark.sql.functions import round as spark_round
from pyspark.sql.types import StructType, DoubleType, StringType
from pyspark.sql.functions import lit
from pathlib import Path as path
from google.cloud import storage
from urllib.parse import urlparse, urljoin
from datetime import datetime
import random

In [None]:
# Display the active SparkSession (assumes the notebook kernel pre-creates `spark` - TODO confirm)
spark

In [None]:
# 1a. Arguments
# Random run identifier in [1, 10000]; it is also reused as the model version
pipelineID = random.randint(1, 10000)
# Placeholders - replace with the GCP project number and project ID before running
projectNbr = "YOUR_PROJECT_NBR"
projectID = "YOUR_PROJECT_ID"
# Toggles the verbose print statements in later cells
displayPrintStatements = True

In [None]:
# 1b. Variables
# Application and model naming
appBaseName = "customer-churn-model"
appNameSuffix = "training"
appName = f"{appBaseName}-{appNameSuffix}"
modelBaseNm = appBaseName
# The model version is tied to the pipeline run ID
modelVersion = pipelineID
# BigQuery dataset plus fully qualified source and result table names
bqDatasetNm = f"{projectID}.customer_churn_ds"
operation = appNameSuffix
bigQuerySourceTableFQN = f"{bqDatasetNm}.training_data"
bigQueryModelTestResultsTableFQN = f"{bqDatasetNm}.test_predictions"
bigQueryModelMetricsTableFQN = f"{bqDatasetNm}.model_metrics"
bigQueryFeatureImportanceTableFQN = f"{bqDatasetNm}.model_feature_importance_scores"
# GCS locations for the persisted model and metrics: <bucket>/<model>/<operation>/<version>
modelBucketUri = f"gs://s8s_model_bucket-{projectNbr}/{modelBaseNm}/{operation}/{modelVersion}"
metricsBucketUri = f"gs://s8s_metrics_bucket-{projectNbr}/{modelBaseNm}/{operation}/{modelVersion}"
# NOTE(review): unlike the two URIs above, this value has no gs:// scheme and includes a path;
# it is passed as-is to the "temporaryGcsBucket" Spark conf - confirm the connector accepts that
scratchBucketUri = f"s8s-spark-bucket-{projectNbr}/{appBaseName}/pipelineId-{pipelineID}/{appNameSuffix}/"
# Run timestamp formatted as YYYYMMDDHHMMSS (naive local time from datetime.now())
pipelineExecutionDt = datetime.now().strftime("%Y%m%d%H%M%S")

In [None]:
# Other variables, constants
# Seed passed to randomSplit when splitting the input dataframe
SPLIT_SEED = 6
# Split weights, unpacked as [train, test] by the randomSplit call
SPLIT_SPECS = [0.8, 0.2]

In [None]:
# 1c. Display input and output
# Echo the run parameters, the resources the run expects to exist, and where outputs will land
if displayPrintStatements:
    print("Starting model training for *Customer Churn* experiment")
    print(".....................................................")
    print(f"The datetime now is - {pipelineExecutionDt}")
    print(" ")
    print("INPUT PARAMETERS")
    print(f"....pipelineID={pipelineID}")
    print(f"....projectID={projectID}")
    print(f"....projectNbr={projectNbr}")
    print(f"....displayPrintStatements={displayPrintStatements}")
    print(" ")
    print("EXPECTED SETUP")
    print(f"....BQ Dataset={bqDatasetNm}")
    print(f"....Model Training Source Data in BigQuery={bigQuerySourceTableFQN}")
    print(f"....Scratch Bucket for BQ connector=gs://s8s-spark-bucket-{projectNbr}")
    # Bucket names must match the ones used to build modelBucketUri / metricsBucketUri
    # (underscores, not hyphens), otherwise the printed setup points at buckets the run never uses
    print(f"....Model Bucket=gs://s8s_model_bucket-{projectNbr}")
    print(f"....Metrics Bucket=gs://s8s_metrics_bucket-{projectNbr}")
    print(" ")
    print("OUTPUT")
    print(f"....Model in GCS={modelBucketUri}")
    print(f"....Model metrics in GCS={metricsBucketUri}")
    print(f"....Model metrics in BigQuery={bigQueryModelMetricsTableFQN}")
    print(f"....Model feature importance scores in BigQuery={bigQueryFeatureImportanceTableFQN}")
    print(f"....Model test results in BigQuery={bigQueryModelTestResultsTableFQN}")

In [None]:
# 2. Spark Session creation
print('....Initializing spark & spark configs')
spark = SparkSession.builder.appName(appName).getOrCreate()

# Spark configuration: BigQuery connector settings (billing project, scratch bucket for writes)
spark.conf.set("parentProject", projectID)
spark.conf.set("temporaryGcsBucket", scratchBucketUri)
# Enable Arrow for pandas <-> Spark dataframe conversion
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Add Python modules
# Go through the session's own SparkContext instead of a bare `sc`, which is never
# defined in this notebook and only exists if the kernel happens to inject it
spark.sparkContext.addPyFile(f"gs://s8s_code_bucket-{projectNbr}/pyspark/common_utils.py")
# Must come after addPyFile: the module only becomes importable once it has been shipped
import common_utils

### TRAINING DATA - READ, SPLIT

In [None]:
# 3. Read training data
print('....Read the training dataset into a dataframe')
# Load the source table through the 'bigquery' data source
inputDF = (
    spark.read
    .format('bigquery')
    .load(bigQuerySourceTableFQN)
)

inputDF.printSchema()

# The row count triggers a Spark job, so it only runs in verbose mode
if displayPrintStatements:
    print(f"inputDF count={inputDF.count()}")

In [None]:
# Typecast some columns to the right datatype
inputDF = (
    inputDF
    # Columns cast to string
    .withColumn("partner", inputDF["partner"].cast('string'))
    .withColumn("dependents", inputDF["dependents"].cast('string'))
    .withColumn("phone_service", inputDF["phone_service"].cast('string'))
    .withColumn("paperless_billing", inputDF["paperless_billing"].cast('string'))
    .withColumn("churn", inputDF["churn"].cast('string'))
    # Columns cast to float
    .withColumn("monthly_charges", inputDF["monthly_charges"].cast('float'))
    .withColumn("total_charges", inputDF["total_charges"].cast('float'))
)

In [None]:
# 4. Split to training and test datasets
print('....Split the dataset')
# SPLIT_SPECS supplies the [train, test] weights; SPLIT_SEED is passed as the sampling seed
trainDF, testDF = inputDF.randomSplit(SPLIT_SPECS, seed=SPLIT_SEED)

### PREPROCESSING & FEATURE ENGINEERING 

In [None]:
# 5. Pre-process training data
print('....Data pre-procesing')
dataPreprocessingStagesList = []

# 5a. For every categorical column, queue an indexer followed by a one-hot encoder
for eachCategoricalColumn in common_utils.CATEGORICAL_COLUMN_LIST:
    # StringIndexer maps the category strings to "<column>Index"
    stringIndexer = StringIndexer(
        inputCol=eachCategoricalColumn,
        outputCol=f"{eachCategoricalColumn}Index",
    )
    # OneHotEncoder turns that index into a binary SparseVector named "<column>classVec"
    encoder = OneHotEncoder(
        inputCols=[stringIndexer.getOutputCol()],
        outputCols=[f"{eachCategoricalColumn}classVec"],
    )
    # Only the stage definitions are collected here; nothing is computed yet
    dataPreprocessingStagesList.extend([stringIndexer, encoder])

# 5b. Index the "churn" column into the "label" column and queue that stage last
labelStringIndexer = StringIndexer(inputCol="churn", outputCol="label")
dataPreprocessingStagesList.append(labelStringIndexer)


In [None]:
# 6. Feature engineering
print('....Feature engineering')
# Numeric columns go in as-is; categorical columns contribute their one-hot "<column>classVec" output
assemblerInputs = [
    *common_utils.NUMERIC_COLUMN_LIST,
    *(f"{c}classVec" for c in common_utils.CATEGORICAL_COLUMN_LIST),
]
# Assemble all inputs into the single "features" vector column
featuresVectorAssembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
featureEngineeringStageList = [featuresVectorAssembler]

### MODEL TRAINING

In [None]:
# 7. Model training
print('....Model training')
# Random forest reading the "label" and "features" columns; all other params are left at defaults
rfClassifier = RandomForestClassifier(labelCol="label", featuresCol="features")
modelTrainingStageList = [rfClassifier]

In [None]:
# 8. Create a model training pipeline for stages defined
print('....Instantiating pipeline model')
# Stage order: preprocessing (indexers/encoders, label) -> feature assembly -> classifier
pipeline = Pipeline(stages=dataPreprocessingStagesList + featureEngineeringStageList + modelTrainingStageList) 


In [None]:
# 9. Fit the model
print('....Fit the model')
# Fits every pipeline stage on the training split only
pipelineModel = pipeline.fit(trainDF)

[Stage 3:>                                                          (0 + 1) / 1]

### MODEL TESTING

In [None]:
# 10. Test the model with the test dataset
print('....Test the model')
# Apply the fitted pipeline to the held-out test split
predictionsDF = pipelineModel.transform(testDF)
# Preview the first two rows
predictionsDF.show(2)

In [None]:
# 11. Persist model to GCS
print('....Persist the model to GCS')
# overwrite() replaces any model already saved at this URI
pipelineModel.write().overwrite().save(modelBucketUri)

[Stage 69:>                                                         (0 + 1) / 1]

In [None]:
# 12. Persist model testing results to BigQuery
# Tag every prediction row with the run metadata. model_version and operation are taken
# from their own variables (modelVersion / operation), the same ones used to build the
# GCS model path, so the BigQuery rows stay aligned with the persisted model if either changes.
persistPredictionsDF = (
    predictionsDF
    .withColumn("pipeline_id", lit(pipelineID).cast("string"))
    .withColumn("model_version", lit(modelVersion).cast("string"))
    .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt))
    .withColumn("operation", lit(operation))
)

# Append to the test results table
(
    persistPredictionsDF.write.format('bigquery')
    .mode("append")
    .option('table', bigQueryModelTestResultsTableFQN)
    .save()
)

### MODEL EXPLAINABILITY

In [None]:
# 13a. Model explainability - feature importance
# The last pipeline stage is the fitted random forest; display its importance vector
pipelineModel.stages[-1].featureImportances

In [None]:
# 13b. Function to parse feature importance
def fnExtractFeatureImportance(featureImportanceSparseVector, predictionsDataframe, featureColumnListing):
    """
    Pair each assembled feature with its importance score.

    Args:
        featureImportanceSparseVector: importance scores, indexable by feature position
        predictionsDataframe: dataframe whose schema carries "ml_attr" metadata for the feature column
        featureColumnListing: name of the assembled feature vector column
    Returns:
        pandas DataFrame of the feature attributes plus an 'importance_score' column,
        sorted by score in descending order
    """
    # The "attrs" metadata groups the attribute entries under several keys; flatten them into one list
    attributeGroups = predictionsDataframe.schema[featureColumnListing].metadata["ml_attr"]["attrs"]
    flattenedAttributes = [
        attribute
        for groupName in attributeGroups
        for attribute in attributeGroups[groupName]
    ]

    importancePDF = pd.DataFrame(flattenedAttributes)

    def lookupScore(featureIndex):
        # 'idx' is the feature's position in the vector, so it indexes the scores directly
        return featureImportanceSparseVector[featureIndex]

    importancePDF['importance_score'] = importancePDF['idx'].apply(lookupScore)
    return importancePDF.sort_values('importance_score', ascending=False)


In [None]:
# 13c. Print feature importance
# As the cell's last expression, the returned pandas DataFrame is rendered by the notebook
fnExtractFeatureImportance(pipelineModel.stages[-1].featureImportances, predictionsDF, "features")


In [None]:
# 13d. Capture into a Pandas DF
# Same call as 13c, kept in a variable so it can be converted to a Spark dataframe next
featureImportantcePDF = fnExtractFeatureImportance(pipelineModel.stages[-1].featureImportances, predictionsDF, "features")


In [None]:
# 13e. Persist feature importance scores to BigQuery
# Convert Pandas to Spark DF & use Spark to persist
# toDF renames the columns positionally, so this relies on the pandas frame having
# exactly three columns in the order (index, name, importance score)
featureImportantceDF = spark.createDataFrame(featureImportantcePDF).toDF("feature_index", "feature_nm", "importance_score")

# Tag every row with the run metadata; model_version comes from modelVersion (the value
# used in the GCS model path) rather than being re-derived from pipelineID
persistFeatureImportanceDF = (
    featureImportantceDF
    .withColumn("pipeline_id", lit(pipelineID).cast("string"))
    .withColumn("model_version", lit(modelVersion).cast("string"))
    .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt))
    .withColumn("operation", lit(operation))
)

persistFeatureImportanceDF.show(2)

# Append to the feature importance table
(
    persistFeatureImportanceDF.write.format('bigquery')
    .mode("append")
    .option('table', bigQueryFeatureImportanceTableFQN)
    .save()
)

[Stage 212:>                                                        (0 + 8) / 8]

### MODEL EVALUATION

In [None]:
# 14a. Metrics parsing function
def fnParseModelMetrics(predictionsDF, labelColumn, operation, boolSubsetOnly):
    """
    Compute evaluation metrics for a predictions dataframe.

    Args:
        predictionsDF: dataframe of model predictions handed to the Spark ML evaluators
        labelColumn: name of the label (target) column
        operation: prefix applied to the scalar metric keys (e.g. "train" or "test")
        boolSubsetOnly: True to return only the six scalar metrics; False to also
            include the 'true', 'score' and 'prediction' entries
    Returns:
        dict keyed by '<operation>_<metric>' for area_roc, area_prc, accuracy, f1,
        precision and recall (each rounded to 5 decimals), followed by 'true',
        'score' and 'prediction' when boolSubsetOnly is False

    Anagha TODO: This function fails if called from common_utils; need to research why
    """
    # Instantiate evaluators
    bcEvaluator = BinaryClassificationEvaluator(labelCol=labelColumn)
    mcEvaluator = MulticlassClassificationEvaluator(labelCol=labelColumn)

    # (output label, evaluator, Spark metric name), listed in output key order
    metricSpecs = [
        ('area_roc', bcEvaluator, 'areaUnderROC'),
        ('area_prc', bcEvaluator, 'areaUnderPR'),
        ('accuracy', mcEvaluator, 'accuracy'),
        ('f1', mcEvaluator, 'f1'),
        ('precision', mcEvaluator, 'weightedPrecision'),
        ('recall', mcEvaluator, 'weightedRecall'),
    ]

    # Each evaluate() call is a separate pass over predictionsDF
    metricsDictionary = {}
    for metricLabel, evaluator, sparkMetricName in metricSpecs:
        metricValue = evaluator.evaluate(predictionsDF, {evaluator.metricName: sparkMetricName})
        metricsDictionary[f'{operation}_{metricLabel}'] = round(metricValue, 5)

    # Only fetch true/score/prediction when they are actually returned; the subset
    # variant discards them, so calling the helper there is wasted work
    if not boolSubsetOnly:
        rocDictionary = common_utils.fnGetTrueScoreAndPrediction(predictionsDF, labelColumn)
        for metricColumn in ('true', 'score', 'prediction'):
            metricsDictionary[metricColumn] = rocDictionary[metricColumn]

    return metricsDictionary


In [None]:
# 14b. Capture & display metrics
# boolSubsetOnly=True: scalar metrics only; keys are prefixed with "test"
modelMetrics = fnParseModelMetrics(predictionsDF, "label", "test", True)
for m, v in modelMetrics.items():
    print(f'{m}: {v}')
    

In [None]:
# 14c. Persist metrics subset to GCS
# Blob path repeats the <model>/<operation>/<version> layout used in metricsBucketUri
blobName = f"{modelBaseNm}/{operation}/{modelVersion}/subset/metrics.json"
# urlparse(...).netloc extracts just the bucket name from the gs:// URI
common_utils.fnPersistMetrics(urlparse(metricsBucketUri).netloc, modelMetrics, blobName)


In [None]:
# 14d. Persist metrics in full to GCS
# (The version persisted to BQ does not have True, Score and Prediction needed for Confusion Matrix
# This version below has the True, Score and Prediction additionally) 

# 14d.1. Capture
# boolSubsetOnly=False: scalar metrics plus the 'true', 'score' and 'prediction' entries
modelMetricsWithTSP = fnParseModelMetrics(predictionsDF, "label", "test", False)

# 14d.2. Persist
blobName = f"{modelBaseNm}/{operation}/{modelVersion}/full/metrics.json"
print(blobName)
# urlparse(...).netloc extracts just the bucket name from the gs:// URI
common_utils.fnPersistMetrics(urlparse(metricsBucketUri).netloc, modelMetricsWithTSP, blobName)

# 14d.3. Print
for m, v in modelMetricsWithTSP.items():
    print(f'{m}: {v}')
    

In [None]:
# 14e. Persist metrics subset to BigQuery
# One row per metric: (metric_nm, metric_value)
metricsDF = spark.createDataFrame(modelMetrics.items(), ["metric_nm", "metric_value"])

# Tag every row with the run metadata; model_version comes from modelVersion (the value
# used in the GCS model and metrics paths) rather than being re-derived from pipelineID
metricsWithPipelineIdDF = (
    metricsDF
    .withColumn("pipeline_id", lit(pipelineID).cast("string"))
    .withColumn("model_version", lit(modelVersion).cast("string"))
    .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt))
    .withColumn("operation", lit(operation))
)

metricsWithPipelineIdDF.show()

# Append to the model metrics table
(
    metricsWithPipelineIdDF.write.format('bigquery')
    .mode("append")
    .option('table', bigQueryModelMetricsTableFQN)
    .save()
)


