In [1]:
import os, sys

import mlflow.spark
import pyspark.sql.functions as F
import pyspark.sql.types as T

from source.functions import SparkMethods, DataLoader

In [2]:
# Read the test split through the project loader, then keep only rows whose
# `age` is present.
# NOTE(review): presumably the null-age rows are malformed/header lines in the
# raw file -- confirm against DataLoader.load_data.
raw_df = DataLoader.load_data("data/adult.test")
df = raw_df.filter(F.col("age").isNotNull())
df.show()

+---+----------------+------+------------+-------------+------------------+-----------------+-------------+------------------+------+------------+------------+--------------+--------------+------+
|age|       workclass|fnlwgt|   education|education-num|    marital-status|       occupation| relationship|              race|   sex|capital-gain|capital-loss|hours-per-week|native-country|income|
+---+----------------+------+------------+-------------+------------------+-----------------+-------------+------------------+------+------------+------------+--------------+--------------+------+
| 25|         Private|226802|        11th|            7|     Never-married|Machine-op-inspct|    Own-child|             Black|  Male|         0.0|         0.0|          40.0| United-States|<=50K.|
| 38|         Private| 89814|     HS-grad|            9|Married-civ-spouse|  Farming-fishing|      Husband|             White|  Male|         0.0|         0.0|          50.0| United-States|<=50K.|
| 28|       Loc

In [3]:
# Load the saved feature vectorizer and the classifier from the local MLflow
# store. The two artifacts live under different run directories of `mlruns/1`.
VECTORIZER_URI = 'mlruns/1/9373860b1fcf4654b1e6799c0c49859c/artifacts/vectorizer'
MODEL_URI = 'mlruns/1/4e8b49a2efb04bf599488a944250e5f6/artifacts/best_GBTClassifier'

vectorizer = mlflow.spark.load_model(VECTORIZER_URI)
model = mlflow.spark.load_model(MODEL_URI)

2019/12/04 05:23:24 INFO mlflow.spark: File 'mlruns/1/9373860b1fcf4654b1e6799c0c49859c/artifacts/vectorizer/sparkml' is already on DFS, copy is not necessary.
2019/12/04 05:23:29 INFO mlflow.spark: File 'mlruns/1/4e8b49a2efb04bf599488a944250e5f6/artifacts/best_GBTClassifier/sparkml' is already on DFS, copy is not necessary.


In [4]:
# Run the rows through the loaded vectorizer first, then feed its output to
# the loaded model; `transformed_df` carries the columns added by both stages.
vectorized_df = vectorizer.transform(df)
transformed_df = model.transform(vectorized_df)
transformed_df.show(n=20)

+---+----------------+------+------------+-------------+------------------+-----------------+-------------+------------------+------+------------+------------+--------------+--------------+------+---------------+---------------+-------------------+--------------------+----------------+------------------+----------+---------+--------------------+--------------------+--------------------+----------------+-------------+---------------+------------------+-----------------+-------------+---------------+-------------+------------------+--------------------+--------------------+--------------------+---------------+
|age|       workclass|fnlwgt|   education|education-num|    marital-status|       occupation| relationship|              race|   sex|capital-gain|capital-loss|hours-per-week|native-country|income|workclass_index|education_index|education-num_index|marital-status_index|occupation_index|relationship_index|race_index|sex_index|native-country_index|    scaling_features|     scaled_feat

In [5]:
# The income values in the test data carry a period after the income
# (e.g. "<=50K."). Remove every literal "." from the column.
# NOTE(review): presumably needed so the values match what the label indexer
# was fitted on -- confirm.
income_without_period = F.regexp_replace(F.col('income'), r'\.', '')
transformed_df = transformed_df.withColumn('income', income_without_period)
transformed_df.show()

+---+----------------+------+------------+-------------+------------------+-----------------+-------------+------------------+------+------------+------------+--------------+--------------+------+---------------+---------------+-------------------+--------------------+----------------+------------------+----------+---------+--------------------+--------------------+--------------------+----------------+-------------+---------------+------------------+-----------------+-------------+---------------+-------------+------------------+--------------------+--------------------+--------------------+---------------+
|age|       workclass|fnlwgt|   education|education-num|    marital-status|       occupation| relationship|              race|   sex|capital-gain|capital-loss|hours-per-week|native-country|income|workclass_index|education_index|education-num_index|marital-status_index|occupation_index|relationship_index|race_index|sex_index|native-country_index|    scaling_features|     scaled_feat

In [7]:
# Index the label column (income) with the saved label vectorizer and append
# its output to the scored frame.
label_vectorizer_uri = 'mlruns/1/9373860b1fcf4654b1e6799c0c49859c/artifacts/label_vectorizer'
label_vectorizer = mlflow.spark.load_model(label_vectorizer_uri)
transformed_df = label_vectorizer.transform(transformed_df)
transformed_df.show()

2019/12/04 05:24:07 INFO mlflow.spark: File 'mlruns/1/9373860b1fcf4654b1e6799c0c49859c/artifacts/label_vectorizer/sparkml' is already on DFS, copy is not necessary.
+---+----------------+------+------------+-------------+------------------+-----------------+-------------+------------------+------+------------+------------+--------------+--------------+------+---------------+---------------+-------------------+--------------------+----------------+------------------+----------+---------+--------------------+--------------------+--------------------+----------------+-------------+---------------+------------------+-----------------+-------------+---------------+-------------+------------------+--------------------+--------------------+--------------------+---------------+-----+
|age|       workclass|fnlwgt|   education|education-num|    marital-status|       occupation| relationship|              race|   sex|capital-gain|capital-loss|hours-per-week|native-country|income|workclass_index|e

In [8]:
# Compute the classification metrics for the scored frame and display them.
# NOTE(review): the rows come from "data/adult.test" but the metrics are
# tagged with data_type 'val' -- confirm this labelling is intended.
metrics = SparkMethods.get_MultiClassMetrics(transformed_df, data_type="val")
metrics



{'valPrecision': 0.863276211534918,
 'valRecall': 0.863276211534918,
 'valF1Score': 0.863276211534918,
 'valWeightedRecall': 0.863276211534918,
 'valWeightedPrecision': 0.8576987984474875,
 'valWeightedF1Score': 0.8584499358859526,
 'valWeightedF0.5Score': 0.8574541310057604,
 'valWeightedFalsePositiveRate': 0.30341416445954883,
 'valRecall0.0': 0.9379171692802574,
 'valPrecision0.0': 0.889151482808569,
 'valF1Score0.0': 0.9128835316217909,
 'valRecall1.0': 0.6219448777951118,
 'valPrecision1.0': 0.7560050568900126,
 'valF1Score1.0': 0.682453637660485}

In [10]:
# Sample payload for testing the API: uncomment to print two rows as 'split'-oriented JSON.
# test_data = df.limit(2).toPandas().to_json(orient='split')
# print(test_data)