# Doctor Right - Model building

In [1]:
import sys
sys.path.append("../modules")
from eda import EDAAnalyzer
from spark_session import SparkManager
from feature_engineering import FeatureEngineer
from ml_developer import XGBoostModelBuilder
from ml_developer import MLPModelBuilder

In [2]:
# Load autoreload extension
%load_ext autoreload
%autoreload 2

#### Constants and config

In [3]:
# mx_submits_path = "../data_sample/mx_submits_all/"
mx_submits_path = "../data_sample/mx_submits.parquet/"
mx_submits_line_path = "../data_sample/mx_submitsline.parquet/"
cohort_key="767ef4cac69e8a0c77384f6e1414364b"

sample_patient_id = "8aad41f612a7095449888c8050abaeb05fdee65643caa3033542610421d8bd1daaa2c4ce1757401003a1bbcd60948a7aa13eba507a676dea80e0cf76b77dbc95"
features_cols = [
'facility_provider_address_region',
'patient_gender',
'principal_diagnosis_body_part',
'principal_diagnosis_category',
'claim_all_diagnosis_codes',
'previous_diagnosis_ohe']
label_column = 'claim_total_charge_amount'
exclude_cols = ['patient_id']
most_repeated_diagnosis_list = [] 

In [4]:
mx_submits_spark_manager = SparkManager(mx_submits_path)

24/10/17 01:35:24 WARN Utils: Your hostname, Sureshs-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 172.20.9.214 instead (on interface en0)
24/10/17 01:35:24 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/17 01:35:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/10/17 01:35:55 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

### Feature Engineering

In [5]:
mx_submits_fe=FeatureEngineer(mx_submits_spark_manager)

In [6]:
mx_submits_fe.add_continuous_visit_years()
mx_submits_fe.get_unique_value_counts("continuous_visit_years")

                                                                                

+----------------------+------+--------------------+
|continuous_visit_years| count|          percentage|
+----------------------+------+--------------------+
|                     1|224597|   90.80716763566676|
|                     2| 19491|   7.880436979954232|
|                     3|  2506|  1.0132048161595253|
|                     4|   530| 0.21428513669774474|
|                     5|   115|0.046495831547623864|
|                     6|    35|0.014150905253624654|
|                     7|    16|0.006468985258799842|
|                     9|     8|0.003234492629399921|
|                     8|     8|0.003234492629399921|
|                    10|     6|0.002425869472049...|
|                    11|     5| 0.00202155789337495|
|                    12|     4|0.001617246314699...|
|                    13|     4|0.001617246314699...|
|                    14|     3|0.001212934736024...|
|                    15|     3|0.001212934736024...|
|                    18|     1|4.0431157867499

In [7]:
mx_submits_fe.filter_by_continuous_visit_years(1)



Dataframe post removing less than 1 continuous visits - Shape: 247334 rows, 131 columns


                                                                                

In [8]:
mx_submits_fe.add_comorbidities_with_exponential_decay_sparse_vector()

                                                                                

Unnamed: 0,previous_diagnosis_ohe
0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [9]:
mx_submits_fe.retain_columns(features_cols+[label_column])

In [10]:
mx_submits_fe.convert_columns_to_float(["claim_total_charge_amount"])
preprocess_data = mx_submits_fe.preprocess_data(exclude_cols=exclude_cols)
preprocess_data

Casted claim_total_charge_amount to float
One-Hot Encoding applied successfully to column: facility_provider_address_region
One-Hot Encoding applied successfully to column: patient_gender
One-Hot Encoding applied successfully to column: principal_diagnosis_body_part
One-Hot Encoding applied successfully to column: principal_diagnosis_category
Assembling all features into a vector with 6 columns.
Preprocessing complete. Feature vector created.


In [11]:
model_feature_col = [
    'principal_diagnosis_category_Factors influencing health status and contact with health services',
# 'patient_location_residential_region_Northeast',
'principal_diagnosis_category_Diseases of the circulatory system',
'principal_diagnosis_category_Diseases of the respiratory system',
'principal_diagnosis_category_Diseases of the musculoskeletal system and connective tissue',
# 'principal_diagnosis_category_unknown',
# 'patient_location_residential_region_West',
'patient_gender_F',
'principal_diagnosis_category_Endocrine nutritional and metabolic diseases',
# 'principal_diagnosis_body_part_unknown',
'principal_diagnosis_body_part_Spine',
'principal_diagnosis_category_Diseases of the eye and adnexa',
'principal_diagnosis_category_Diseases of the genitourinary system',
'principal_diagnosis_category_Injury poisoning and certain other consequences of external causes',
# 'patient_location_residential_region_South',
# 'patient_location_residential_region_unknown',
# 'facility_provider_address_region_unknown',
'patient_gender_M',
'principal_diagnosis_category_Symptoms signs and abnormal clinical laboratory findings not elsewhere classified',
'principal_diagnosis_category_Mental Behavioral and Neurodevelopmental disorders',
# 'patient_gender_U',
# 'secondary_payer_state_unknown',
'principal_diagnosis_category_Diseases of the nervous system',
'Diagnosis_I10',
# 'patient_location_residential_region_Midwest',
'principal_diagnosis_body_part_Knee',
'principal_diagnosis_category_Diseases of the skin and subcutaneous tissue',
'principal_diagnosis_category_Diseases of the ear and mastoid process',
'facility_provider_address_region_Northeast',
'principal_diagnosis_body_part_Eye',
'principal_diagnosis_category_Neoplasms',
'facility_provider_address_region_South',
'principal_diagnosis_body_part_Heart',
'principal_diagnosis_body_part_Ear',
'principal_diagnosis_body_part_Shoulder',
'Diagnosis_E785',
'principal_diagnosis_category_External causes of morbidity',
'principal_diagnosis_category_Diseases of the digestive system',
'facility_provider_address_region_West',
'principal_diagnosis_body_part_Lung',
'facility_provider_address_region_Midwest',
'Diagnosis_N179',
'Diagnosis_E119',
'Diagnosis_R079',
'Diagnosis_Z23',
'principal_diagnosis_body_part_Hip',
'Diagnosis_F200',
'principal_diagnosis_category_Pregnancy childbirth and puerperium',
'Diagnosis_Z87891',
'principal_diagnosis_body_part_Foot',
'Diagnosis_I129',
'Diagnosis_F331',
'Diagnosis_M109',
'principal_diagnosis_category_Certain infections and parasitic diseases',
'Diagnosis_J90',
'principal_diagnosis_body_part_Leg non-joint',
'Diagnosis_R000',
'Diagnosis_R739',
'Diagnosis_K219',
'Diagnosis_Z951',
'Diagnosis_R32',
'principal_diagnosis_body_part_Foot and ankle',
'Diagnosis_I509',
'Diagnosis_E875',
'Diagnosis_N281',
'Diagnosis_S2242XA',
'Diagnosis_I130',
'principal_diagnosis_category_Diseases of the blood and blood-forming organs and certain disorders involving the immune mechanism',
'principal_diagnosis_body_part_Hand',
'Diagnosis_F17210',
'Diagnosis_I214',
'Diagnosis_Z931',
'Diagnosis_Q909',
'Diagnosis_I739',
'Diagnosis_Z743',
'Diagnosis_F418',
'Diagnosis_F329',
# 'secondary_payer_state_UT',
'principal_diagnosis_body_part_Wrist',
'Diagnosis_F17200',
'Diagnosis_F209',
'Diagnosis_M545',
# 'secondary_payer_state_KY',
'Diagnosis_E872',
# 'secondary_payer_state_TX',
'Diagnosis_A419',
'principal_diagnosis_body_part_Finger',
'Diagnosis_J189',
'Diagnosis_Z794',
'Diagnosis_I252',
'Diagnosis_R262',
'Diagnosis_D631',
'Diagnosis_I82411',
'Diagnosis_D638',
'Diagnosis_R918',
# 'secondary_payer_state_MO',
'Diagnosis_N189',
'Diagnosis_N186',
'principal_diagnosis_category_Congenital malformations deformations and chromosomal abnormalities',
'principal_diagnosis_body_part_Elbow',
'Diagnosis_Z955',
'Diagnosis_J810',
'Diagnosis_I69322',
'Diagnosis_I69351',
'Diagnosis_R278',
'Diagnosis_M479',
# 'secondary_payer_state_GA',
'Diagnosis_R279',
'Diagnosis_S0101XA',
'Diagnosis_S130XXA',
'Diagnosis_F10229',
'Diagnosis_I69959',
'Diagnosis_D509',
'Diagnosis_I361',
'Diagnosis_N184',
'Diagnosis_I110',
'Diagnosis_M542',
'Diagnosis_E669',
'Diagnosis_G894',
'Diagnosis_R578',
'Diagnosis_S01112A',
'Diagnosis_F840',
'Diagnosis_Z00129',
'Diagnosis_G309',
'Diagnosis_G319',
'Diagnosis_J441',
'Diagnosis_S14125A',
'Diagnosis_I712',
'Diagnosis_S12500A',
'Diagnosis_S240XXA',
'Diagnosis_S12400A',
'Diagnosis_S14123A',
'Diagnosis_R202',
'Diagnosis_K222',
'Diagnosis_D649',
'Diagnosis_Z452',
'Diagnosis_V784XXA',
'principal_diagnosis_category_Certain conditions originating in the perinatal period',
'Diagnosis_K743',
'Diagnosis_G904',
'Diagnosis_J449',
'Diagnosis_S0990XA',
'Diagnosis_R620',
'Diagnosis_Z789',
'Diagnosis_S0191XA',
'Diagnosis_E861',
'Diagnosis_Z992',
'Diagnosis_M549',
'Diagnosis_I469',
'Diagnosis_S1093XA',
'Diagnosis_Z713',
'Diagnosis_D72829',
'Diagnosis_D62',
'Diagnosis_M341',
'Diagnosis_Z20822',
'Diagnosis_R569',
'Diagnosis_Z113',
'Diagnosis_I447',
'Diagnosis_E871',
'Diagnosis_I480',
'Diagnosis_N390',
'principal_diagnosis_body_part_Arm non-joint',
'principal_diagnosis_body_part_Ankle',
'principal_diagnosis_body_part_Head',
'Diagnosis_E440',
'Diagnosis_R579',
'Diagnosis_Z79899',
'Diagnosis_M25551',
'Diagnosis_R64',
'Diagnosis_F251',
'Diagnosis_H524',
'principal_diagnosis_body_part_Toe',
'Diagnosis_R55',
'Diagnosis_Z993',
'Diagnosis_Z95810',
'Diagnosis_R634',
'principal_diagnosis_body_part_Stomach',
'Diagnosis_D508',
'Diagnosis_R531',
'principal_diagnosis_body_part_Various',
'Diagnosis_H903',
'Diagnosis_F39',
'Diagnosis_S2191XA',
'Diagnosis_X58XXXA',
'Diagnosis_I120',
'Diagnosis_M329',
'Diagnosis_R54',
'Diagnosis_Z139',
'Diagnosis_J431',
'Diagnosis_F250',
'Diagnosis_C50511',
'Diagnosis_S1091XA',
'Diagnosis_L89310',
'Diagnosis_F419',
'Diagnosis_I959',
'principal_diagnosis_body_part_Leg',
'Diagnosis_S31119A',
'Diagnosis_J309',
'Diagnosis_E11621',
'Diagnosis_N529',
'Diagnosis_R402432',
'Diagnosis_M25571',
'Diagnosis_I253',
'Diagnosis_N939',
'Diagnosis_S31020A',
'Diagnosis_N401',
'Diagnosis_R69',
'Diagnosis_Z95828',
# 'secondary_payer_state_MA',
'Diagnosis_R410',
'Diagnosis_R600',
'Diagnosis_E782',
'Diagnosis_R52',
'Diagnosis_M546',
'Diagnosis_Z888'
]

In [12]:
mx_submits_fe.expand_features(model_feature_col)

Created OHE column: facility_provider_address_region_Northeast (index: 1)
Created OHE column: facility_provider_address_region_South (index: 2)
Created OHE column: facility_provider_address_region_Midwest (index: 3)
Created OHE column: facility_provider_address_region_West (index: 4)
Created OHE column: patient_gender_F (index: 0)
Created OHE column: patient_gender_M (index: 1)
Created OHE column: principal_diagnosis_body_part_Spine (index: 1)
Created OHE column: principal_diagnosis_body_part_Knee (index: 2)
Created OHE column: principal_diagnosis_body_part_Heart (index: 3)
Created OHE column: principal_diagnosis_body_part_Shoulder (index: 4)
Created OHE column: principal_diagnosis_body_part_Eye (index: 5)
Created OHE column: principal_diagnosis_body_part_Hip (index: 6)
Created OHE column: principal_diagnosis_body_part_Ear (index: 7)
Created OHE column: principal_diagnosis_body_part_Leg non-joint (index: 8)
Created OHE column: principal_diagnosis_body_part_Foot (index: 9)
Created OHE c

In [13]:
mx_submits_fe.retain_columns(model_feature_col+[label_column])

In [14]:
mx_submits_fe.dataframe.columns

['principal_diagnosis_category_Factors influencing health status and contact with health services',
 'principal_diagnosis_category_Diseases of the circulatory system',
 'principal_diagnosis_category_Diseases of the respiratory system',
 'principal_diagnosis_category_Diseases of the musculoskeletal system and connective tissue',
 'patient_gender_F',
 'principal_diagnosis_category_Endocrine nutritional and metabolic diseases',
 'principal_diagnosis_body_part_Spine',
 'principal_diagnosis_category_Diseases of the eye and adnexa',
 'principal_diagnosis_category_Diseases of the genitourinary system',
 'principal_diagnosis_category_Injury poisoning and certain other consequences of external causes',
 'patient_gender_M',
 'principal_diagnosis_category_Symptoms signs and abnormal clinical laboratory findings not elsewhere classified',
 'principal_diagnosis_category_Mental Behavioral and Neurodevelopmental disorders',
 'principal_diagnosis_category_Diseases of the nervous system',
 'Diagnosis_I

In [15]:
model_data = mx_submits_fe.preprocess_features(model_feature_col, label_column)
model_data

DataFrame[features: vector, claim_total_charge_amount: float]

# Model Training

## XGB Model

In [16]:
xgb_model = XGBoostModelBuilder(model_data, model_feature_col, label_column)

In [17]:
train_df, test_df = xgb_model.split_data()

In [18]:
xgb_model.train_model()
xgb_model.save_model("../output/model/XGB_model_cont_1")

24/10/17 01:36:29 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
2024-10-17 01:36:29,825 INFO XGBoost-PySpark: _fit Running xgboost-2.0.3 on 1 workers with
	booster params: {'objective': 'reg:squarederror', 'device': 'cpu', 'max_depth': 3, 'eta': 0.1, 'num_round': 100, 'nthread': 1}
	train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
	dmatrix_kwargs: {'nthread': 1, 'missing': nan}
24/10/17 01:36:31 WARN DAGScheduler: Broadcasting large task binary with size 1672.3 KiB
[02:00:41] task 0 got new rank 0                                    (0 + 1) / 1]
Parameters: { "num_round" } are not used.

2024-10-17 02:00:46,286 INFO XGBoost-PySpark: _fit Finished xgboost training!   


Model 'XGB_model' saved to ../output/model/XGB_model_cont_1


In [19]:
# xgb_model = xgb_model.load_model(model_data, model_feature_col, label_column, xgb_model.model_name,path="../output/model/XGB_model")

In [20]:
xgb_model.evaluate_model(type="Train")

24/10/17 02:00:49 WARN DAGScheduler: Broadcasting large task binary with size 1803.7 KiB
2024-10-17 02:00:57,814 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,049 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,288 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,416 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,612 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,862 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,921 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:00:58,993 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:04:10,599 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:04:39,279 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:06:51,181 INFO XGBoost-PySp

10090.130545448183

In [21]:
xgb_model.evaluate_model(type="Test")

24/10/17 02:24:28 WARN DAGScheduler: Broadcasting large task binary with size 1803.7 KiB
2024-10-17 02:24:29,547 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:37,712 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:37,914 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:37,972 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:38,085 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:38,121 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:38,128 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:24:38,188 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:27:59,888 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:28:35,008 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:30:50,194 INFO XGBoost-PySp

9199.769218021689

In [22]:
xgb_model.calculate_mape(type="Train")

24/10/17 02:46:07 WARN DAGScheduler: Broadcasting large task binary with size 1798.5 KiB
2024-10-17 02:46:16,245 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,270 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,358 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,564 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,732 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,735 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,876 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:46:16,883 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:49:37,983 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:50:08,069 INFO XGBoost-PySpark: predict_udf Do the inference on the CPUs
2024-10-17 02:52:11,595 INFO XGBoost-PySp

24259.49783369624

In [23]:
xgb_model.feature_importance()

[('principal_diagnosis_category_Certain conditions originating in the perinatal period',
  56.0),
 ('principal_diagnosis_category_Diseases of the genitourinary system', 53.0),
 ('principal_diagnosis_category_Certain infections and parasitic diseases',
  49.0),
 ('patient_gender_F', 46.0),
 ('Diagnosis_F418', 46.0),
 ('principal_diagnosis_body_part_Various', 43.0),
 ('facility_provider_address_region_Northeast', 33.0),
 ('principal_diagnosis_category_Neoplasms', 27.0),
 ('facility_provider_address_region_South', 26.0),
 ('facility_provider_address_region_Midwest', 25.0),
 ('Diagnosis_I130', 23.0),
 ('patient_gender_M', 20.0),
 ('principal_diagnosis_category_Factors influencing health status and contact with health services',
  17.0),
 ('principal_diagnosis_category_Injury poisoning and certain other consequences of external causes',
  17.0),
 ('principal_diagnosis_category_Diseases of the circulatory system', 15.0),
 ('facility_provider_address_region_West', 15.0),
 ('Diagnosis_K743', 1

## "?MLP Classifier

In [14]:
mlp_builder = MLPModelBuilder(model_data, model_feature_col, label_column)

In [15]:
model_feature_col

['principal_diagnosis_category_Factors influencing health status and contact with health services',
 'patient_location_residential_region_Northeast',
 'principal_diagnosis_category_Diseases of the circulatory system',
 'principal_diagnosis_category_Diseases of the respiratory system',
 'principal_diagnosis_category_Diseases of the musculoskeletal system and connective tissue',
 'principal_diagnosis_category_unknown',
 'patient_location_residential_region_West',
 'patient_gender_F',
 'principal_diagnosis_category_Endocrine nutritional and metabolic diseases',
 'principal_diagnosis_body_part_unknown',
 'principal_diagnosis_body_part_Spine',
 'principal_diagnosis_category_Diseases of the eye and adnexa',
 'principal_diagnosis_category_Diseases of the genitourinary system',
 'principal_diagnosis_category_Injury poisoning and certain other consequences of external causes',
 'patient_location_residential_region_South',
 'patient_location_residential_region_unknown',
 'facility_provider_addre

In [23]:
mlp_builder.train_df, mlp_builder.test_df = mlp_builder.split_data()

In [17]:
mlp_builder.bin_labels(num_bins=3)

                                                                                

In [18]:
# layers = [len(mlp_builder.feature_columns), 5, 4, 3] 
# mlp_builder.train_model(layers)

24/10/15 01:25:50 WARN DAGScheduler: Broadcasting large task binary with size 1733.4 KiB
24/10/15 01:50:03 WARN DAGScheduler: Broadcasting large task binary with size 1736.4 KiB
24/10/15 01:50:04 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/10/15 01:50:04 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/10/15 01:50:04 WARN DAGScheduler: Broadcasting large task binary with size 1737.4 KiB
24/10/15 01:50:05 WARN DAGScheduler: Broadcasting large task binary with size 1736.4 KiB
24/10/15 01:50:05 WARN DAGScheduler: Broadcasting large task binary with size 1737.4 KiB
24/10/15 01:50:05 WARN DAGScheduler: Broadcasting large task binary with size 1736.4 KiB
24/10/15 01:50:05 WARN DAGScheduler: Broadcasting large task binary with size 1737.4 KiB
24/10/15 01:50:06 WARN DAGScheduler: Broadcasting large task binary with size 1736.4 KiB
24/10/15 01:50:06 WARN DAGScheduler: Broadcasting large task binary wit

MultilayerPerceptronClassificationModel: uid=MultilayerPerceptronClassifier_1d369a285971, numLayers=4, numClasses=3, numFeatures=215

In [20]:
# mlp_builder.save_model("../output/model/MLPModel")

In [20]:
mlp_builder = mlp_builder.load_model(model_data, model_feature_col, label_column,path = "../output/model/MLPModel")

In [21]:
mlp_builder.evaluate_model(type="Train")

24/10/15 05:11:58 WARN DAGScheduler: Broadcasting large task binary with size 1732.4 KiB
                                                                                

0.4183189785586058

In [22]:
mlp_builder.evaluate_model(type="Test")

24/10/15 05:33:57 WARN DAGScheduler: Broadcasting large task binary with size 1732.4 KiB
                                                                                

0.41645249459470485

In [23]:
avg_claim_train = mlp_builder.average_claim_by_bin_train()
avg_claim_train.show()

24/10/15 05:57:04 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/10/15 05:57:05 WARN DAGScheduler: Broadcasting large task binary with size 1707.3 KiB
24/10/15 06:22:22 WARN DAGScheduler: Broadcasting large task binary with size 1675.8 KiB
                                                                                

+------------+--------------------+
|label_binned|average_claim_amount|
+------------+--------------------+
|         1.0|  242.79556904862184|
|         0.0|    83.5966601299861|
|         2.0|  3583.0215044044608|
+------------+--------------------+



In [25]:
avg_claim_test = mlp_builder.average_claim_by_bin_test()
avg_claim_test.show()

24/10/15 09:31:15 WARN DAGScheduler: Broadcasting large task binary with size 1707.3 KiB

+------------+--------------------+
|label_binned|average_claim_amount|
+------------+--------------------+
|         1.0|  242.14374536013332|
|         0.0|    84.0827847817373|
|         2.0|  3487.9150747244967|
+------------+--------------------+



24/10/15 09:57:26 WARN DAGScheduler: Broadcasting large task binary with size 1675.8 KiB
                                                                                

In [27]:
avg_claim_pred_train = mlp_builder.average_claim_by_predicted_bin_train(num_bins=3)
avg_claim_pred_train.show()

24/10/15 14:16:42 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/10/15 14:16:45 WARN DAGScheduler: Broadcasting large task binary with size 1757.2 KiB
24/10/15 14:22:45 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/10/15 14:22:46 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/10/15 14:45:45 WARN DAGScheduler: Broadcasting large task binary with size 1723.9 KiB
                                                                                

+-------------+--------------------+
|predicted_bin|average_claim_amount|
+-------------+--------------------+
|            0|  1307.2307887574923|
+-------------+--------------------+



In [28]:
avg_claim_pred_test = mlp_builder.average_claim_by_predicted_bin_test(num_bins=3)
avg_claim_pred_test.show()

24/10/15 14:45:51 WARN DAGScheduler: Broadcasting large task binary with size 1756.9 KiB
24/10/15 15:09:33 WARN DAGScheduler: Broadcasting large task binary with size 1722.1 KiB
                                                                                

+-------------+--------------------+
|predicted_bin|average_claim_amount|
+-------------+--------------------+
|            0|  1272.9569344277793|
+-------------+--------------------+

