In this notebook, we will be running a parameter sweep for the Random Forest model.

## Style Fix

In [1]:
%%html
<style>
table {float:left}
</style>

## Imports

In [2]:
import loadAndClean
import random
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor

## Load Data

In [3]:
X = loadAndClean.loadAndClean()
X.head(3)

Unnamed: 0,DRG Definition,Provider Id,Provider Name,Provider Street Address,Provider City,Provider State,Provider Zip Code,Provider HRR,Total Discharges,Average Covered Charges,Average Total Payments,Average Medicare Payments,Average Covered Charges Num,Average Total Payments Num,Average Medicare Payments Num,DRG Code,Provider HRR Num,Latitude,Longitude
0,039 - EXTRACRANIAL PROCEDURES W/O CC/MCC,10001,SOUTHEAST ALABAMA MEDICAL CENTER,1108 ROSS CLARK CIRCLE,DOTHAN,AL,36301,AL - Dothan,98,$37988.98,$5872.19,$4838.32,37988.98,5872.19,4838.32,39,2,31.216725,-85.363068
1,039 - EXTRACRANIAL PROCEDURES W/O CC/MCC,10005,MARSHALL MEDICAL CENTER SOUTH,2505 U S HIGHWAY 431 NORTH,BOAZ,AL,35957,AL - Birmingham,34,$15554.88,$6053.29,$5255.65,15554.88,6053.29,5255.65,39,1,34.204811,-86.10186
2,039 - EXTRACRANIAL PROCEDURES W/O CC/MCC,10006,ELIZA COFFEE MEMORIAL HOSPITAL,205 MARENGO STREET,FLORENCE,AL,35631,AL - Birmingham,30,$40380.10,$5319.83,$4150.87,40380.1,5319.83,4150.87,39,1,34.793845,-87.683155


## Cross Val function

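This function cross-validates the model and returns the average RMSE. It splits the data by hospital, so every hospital's rows fall entirely in either the training set or the test set. This ensures that the model isn't biased by the other procedures a hospital performs. Each round holds out 500 randomly sampled hospitals and resamples until both sets contain every DRG code.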

In [4]:
def crossVal(clf, X, predictors, cv=3):
    # Fixed seed so every model is scored on the same sequence of splits
    random.seed(6)
    scores = []
    nCodes = len(X['DRG Code'].unique())
    for i in range(cv):
        # Resample until both the train and test sets contain every DRG code
        while True:
            # Hold out 500 whole hospitals; list() keeps this working on Python 3,
            # where random.sample requires a sequence
            testIds = random.sample(list(X['Provider Id'].unique()), 500)
            testData = X[X['Provider Id'].isin(testIds)]
            trainData = X[~X['Provider Id'].isin(testIds)]
            if len(testData['DRG Code'].unique()) == nCodes and len(trainData['DRG Code'].unique()) == nCodes:
                break
        X_train = trainData[predictors]
        y_train = trainData['Average Medicare Payments Num']
        X_test = testData[predictors]
        y_test = testData['Average Medicare Payments Num']
        clf.fit(X_train, y_train)
        predictions = clf.predict(X_test)
        # RMSE on the held-out hospitals
        scores.append(mean_squared_error(y_test, predictions)**0.5)

    return np.mean(scores)
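
To illustrate the idea behind the hospital-wise split, here is a minimal standard-library sketch on made-up records. The names `group_holdout` and `rmse` and all of the toy values are assumptions for illustration only; they are not part of the notebook's pipeline, and the "model" is just a training-mean baseline.

```python
import math
import random

# Toy records: (provider_id, drg_code, payment). All values are made up.
records = [
    (1, 39, 100.0), (1, 57, 120.0),
    (2, 39, 110.0), (2, 57, 125.0),
    (3, 39, 90.0), (3, 57, 130.0),
    (4, 39, 105.0), (4, 57, 115.0),
]

def group_holdout(records, n_test_groups, rng):
    """Hold out whole providers so no provider is in both train and test."""
    provider_ids = sorted({r[0] for r in records})
    test_ids = set(rng.sample(provider_ids, n_test_groups))
    train = [r for r in records if r[0] not in test_ids]
    test = [r for r in records if r[0] in test_ids]
    return train, test

def rmse(y_true, y_pred):
    """Root mean squared error of two equal-length sequences."""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))

rng = random.Random(6)
train, test = group_holdout(records, n_test_groups=1, rng=rng)

# Baseline "model": predict the training mean for every held-out row.
train_mean = sum(r[2] for r in train) / len(train)
score = rmse([r[2] for r in test], [train_mean] * len(test))
```

Because the split is made on provider ids rather than on rows, a provider's other procedures can never leak from the training set into the test set.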

## Grid Search

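This function accepts lists of values for n_estimators, max_depth, min_samples_split and oob_score. It then performs a grid search over every combination, scoring each one with crossVal (cv=10) on the predictors Latitude, Longitude and DRG Code. It prints each result as a table row and, at the end, the best combination found.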

In [5]:
def gridsearch(X, n_estimators, max_depths, min_samples_splits, oob_scores):
    predictors = ['Latitude','Longitude','DRG Code']
    # Best result so far: [n_estimators, max_depth, min_samples_split, oob_score, RMSE]
    best = [None, None, None, None, np.inf]
    line_str = '|  {: <8}  |  {: <3} |  {: <3} |  {: <3}  |  {}'
    # Header and alignment rows of a Markdown table
    print(line_str.format('n_estimators', 'max_depth','min_samples_split','oob_score', 'RMSE'))
    print(line_str.format(':--', ':--', ':--',':--',':--'))
    for n_estimator in n_estimators:
        for max_depth in max_depths:
            for min_samples_split in min_samples_splits:
                for oob_score in oob_scores:
                    alg = RandomForestRegressor(n_estimators=n_estimator, max_depth=max_depth,
                                                min_samples_split=min_samples_split, oob_score=oob_score, n_jobs=8)
                    score = crossVal(alg, X, predictors, cv=10)
                    # str() lets max_depth=None be formatted with a width on Python 3 as well
                    row = [n_estimator, str(max_depth), min_samples_split, oob_score, score]
                    if score < best[4]:
                        best = row
                    print(line_str.format(*row))

    print('\nBest:')
    print('{: <10} {: <5} {: <5} {: <5} ${:,.2f}'.format(*best))
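
As a rough sketch of how large such a search gets, the four nested loops can be written with `itertools.product`. The lists below are the values used in the first, broad search in this notebook; the variable names `combos` and `n_fits` are only illustrative.

```python
from itertools import product

# Values from the first, broad search in this notebook.
n_estimators = [1, 10, 100, 500, 1000]
max_depths = [1, 2, 4, 8, 16]
min_samples_splits = [2, 4, 8, 16]
oob_scores = [True, False]

# product() yields tuples in the same order as the four nested loops:
# the last list varies fastest.
combos = list(product(n_estimators, max_depths, min_samples_splits, oob_scores))

# gridsearch calls crossVal with cv=10, so each combination is fitted 10 times.
n_fits = len(combos) * 10
```

The number of fits grows multiplicatively with every value added to any list, which is one reason to start broad and coarse before narrowing.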

This is our first grid search. The idea is to start broad and then narrow the range in later searches.

In [None]:
n_estimators = [1, 10, 100, 500 ,1000]
max_depth = [1,2,4,8,16]
min_samples_split = [2,4,8,16]
oob_score = [True,False]
gridsearch(X, n_estimators, max_depth, min_samples_split, oob_score)

|  n_estimators  |  max_depth |  min_samples_split |  oob_score  |  RMSE
|  :--       |  :-- |  :-- |  :--  |  :--
|  1         |  1   |  2   |  1    |  7294.56778695
|  1         |  1   |  2   |  0    |  7294.73058545
|  1         |  1   |  4   |  1    |  7294.86541043
|  1         |  1   |  4   |  0    |  7295.14352671
|  1         |  1   |  8   |  1    |  7294.63353342
|  1         |  1   |  8   |  0    |  7294.7928194
|  1         |  1   |  16  |  1    |  7295.03235718
|  1         |  1   |  16  |  0    |  7294.75787242
|  1         |  2   |  2   |  1    |  6190.92170565
|  1         |  2   |  2   |  0    |  6190.87618149
|  1         |  2   |  4   |  1    |  6190.97712762
|  1         |  2   |  4   |  0    |  6190.85946699
|  1         |  2   |  8   |  1    |  6190.79227032
|  1         |  2   |  8   |  0    |  6190.9079704
|  1         |  2   |  16  |  1    |  6190.76542442
|  1         |  2   |  16  |  0    |  6190.7208444
|  1         |  4   |  2   |  1    |  5606.31133987
|  1         |  4   |  2   |  0    |  5601.89114103
|  1         |  4   |  4   |  1    |  5606.44176277
|  1         |  4   |  4   |  0    |  5608.20052036
|  1         |  4   |  8   |  1    |  5604.92284085
|  1         |  4   |  8   |  0    |  5605.6372499
|  1         |  4   |  16  |  1    |  5601.5562951
|  1         |  4   |  16  |  0    |  5606.07113797
|  1         |  8   |  2   |  1    |  3282.93749047
|  1         |  8   |  2   |  0    |  3273.44465419
|  1         |  8   |  4   |  1    |  3251.7401755
|  1         |  8   |  4   |  0    |  3312.60152948
|  1         |  8   |  8   |  1    |  3260.87044697
|  1         |  8   |  8   |  0    |  3291.37408416
|  1         |  8   |  16  |  1    |  3245.08230924
|  1         |  8   |  16  |  0    |  3242.38554034
|  1         |  16  |  2   |  1    |  2790.49283356
|  1         |  16  |  2   |  0    |  2756.0425743
|  1         |  16  |  4   |  1    |  2691.22157796
|  1         |  16  |  4   |  0    |  2658.17920559
|  1         |  16  |  8   |  1    |  2601.40401238
|  1         |  16  |  8   |  0    |  2597.23291154
|  1         |  16  |  16  |  1    |  2535.1921768
|  1         |  16  |  16  |  0    |  2509.91701011
|  10        |  1   |  2   |  1    |  7294.65265625
|  10        |  1   |  2   |  0    |  7294.55624734
|  10        |  1   |  4   |  1    |  7294.69562062
|  10        |  1   |  4   |  0    |  7294.64325994
|  10        |  1   |  8   |  1    |  7294.77125405
|  10        |  1   |  8   |  0    |  7294.57196966
|  10        |  1   |  16  |  1    |  7294.72286418
|  10        |  1   |  16  |  0    |  7294.65886418
|  10        |  2   |  2   |  1    |  6190.78543195
|  10        |  2   |  2   |  0    |  6190.782628
|  10        |  2   |  4   |  1    |  6190.7257814
|  10        |  2   |  4   |  0    |  6190.74201157
|  10        |  2   |  8   |  1    |  6190.78824212
|  10        |  2   |  8   |  0    |  6190.85155055
|  10        |  2   |  16  |  1    |  6190.74799024
|  10        |  2   |  16  |  0    |  6190.69197777
|  10        |  4   |  2   |  1    |  5592.96591593
|  10        |  4   |  2   |  0    |  5593.89623497
|  10        |  4   |  4   |  1    |  5595.36027626
|  10        |  4   |  4   |  0    |  5592.93708067
|  10        |  4   |  8   |  1    |  5593.32964418
|  10        |  4   |  8   |  0    |  5595.17531325
|  10        |  4   |  16  |  1    |  5594.81760658
|  10        |  4   |  16  |  0    |  5593.52094389
|  10        |  8   |  2   |  1    |  3163.7239878
|  10        |  8   |  2   |  0    |  3164.94729475
|  10        |  8   |  4   |  1    |  3170.03609828
|  10        |  8   |  4   |  0    |  3170.50626428
|  10        |  8   |  8   |  1    |  3173.54289922
|  10        |  8   |  8   |  0    |  3165.94260991
|  10        |  8   |  16  |  1    |  3170.1655599
|  10        |  8   |  16  |  0    |  3171.16585507
|  10        |  16  |  2   |  1    |  2283.31608229
|  10        |  16  |  2   |  0    |  2279.88443774
|  10        |  16  |  4   |  1    |  2256.32724671
|  10        |  16  |  4   |  0    |  2265.03290135
|  10        |  16  |  8   |  1    |  2240.54731069
|  10        |  16  |  8   |  0    |  2238.33601377
|  10        |  16  |  16  |  1    |  2222.99777649
|  10        |  16  |  16  |  0    |  2218.47085095
|  100       |  1   |  2   |  1    |  7294.63844841
|  100       |  1   |  2   |  0    |  7294.61115494
|  100       |  1   |  4   |  1    |  7294.63625999
|  100       |  1   |  4   |  0    |  7294.63641801
|  100       |  1   |  8   |  1    |  7294.61474149
|  100       |  1   |  8   |  0    |  7294.64691208
|  100       |  1   |  16  |  1    |  7294.65694572
|  100       |  1   |  16  |  0    |  7294.65927976
|  100       |  2   |  2   |  1    |  6190.81224069
|  100       |  2   |  2   |  0    |  6190.79557814
|  100       |  2   |  4   |  1    |  6190.79120535
|  100       |  2   |  4   |  0    |  6190.77996469
|  100       |  2   |  8   |  1    |  6190.74304852
|  100       |  2   |  8   |  0    |  6190.78378329
|  100       |  2   |  16  |  1    |  6190.78243485
|  100       |  2   |  16  |  0    |  6190.78039536
|  100       |  4   |  2   |  1    |  5594.90167058
|  100       |  4   |  2   |  0    |  5594.95014267
|  100       |  4   |  4   |  1    |  5594.4945627
|  100       |  4   |  4   |  0    |  5594.64089818
|  100       |  4   |  8   |  1    |  5594.42159931
|  100       |  4   |  8   |  0    |  5593.33782348
|  100       |  4   |  16  |  1    |  5593.87579515
|  100       |  4   |  16  |  0    |  5594.39863314
|  100       |  8   |  2   |  1    |  3159.36774604
|  100       |  8   |  2   |  0    |  3156.57289953
|  100       |  8   |  4   |  1    |  3158.03719858
|  100       |  8   |  4   |  0    |  3157.91614073
|  100       |  8   |  8   |  1    |  3157.08294284
|  100       |  8   |  8   |  0    |  3156.93468049
|  100       |  8   |  16  |  1    |  3157.99381201
|  100       |  8   |  16  |  0    |  3157.91088713
|  100       |  16  |  2   |  1    |  2226.09632067
|  100       |  16  |  2   |  0    |  2228.92880751
|  100       |  16  |  4   |  1    |  2214.41074009
|  100       |  16  |  4   |  0    |  2217.24925072
|  100       |  16  |  8   |  1    |  2199.05411541
|  100       |  16  |  8   |  0    |  2198.96254753
|  100       |  16  |  16  |  1    |  2189.5611402
|  100       |  16  |  16  |  0    |  2187.90212768
|  500       |  1   |  2   |  1    |  7294.64436675
|  500       |  1   |  2   |  0    |  7294.64063436
|  500       |  1   |  4   |  1    |  7294.6495678
|  500       |  1   |  4   |  0    |  7294.64525889
|  500       |  1   |  8   |  1    |  7294.6542858
|  500       |  1   |  8   |  0    |  7294.64919733
|  500       |  1   |  16  |  1    |  7294.65967917
|  500       |  1   |  16  |  0    |  7294.65388828
|  500       |  2   |  2   |  1    |  6190.77588032
|  500       |  2   |  2   |  0    |  6190.78875008
|  500       |  2   |  4   |  1    |  6190.77282069
|  500       |  2   |  4   |  0    |  6190.7764489
|  500       |  2   |  8   |  1    |  6190.78083427
|  500       |  2   |  8   |  0    |  6190.77588129
|  500       |  2   |  16  |  1    |  6190.76962268
|  500       |  2   |  16  |  0    |  6190.77193907
|  500       |  4   |  2   |  1    |  5594.188916
|  500       |  4   |  2   |  0    |  5594.1296475
|  500       |  4   |  4   |  1    |  5594.04666772
|  500       |  4   |  4   |  0    |  5594.13013365
|  500       |  4   |  8   |  1    |  5594.35602772
|  500       |  4   |  8   |  0    |  5594.41370767
|  500       |  4   |  16  |  1    |  5594.20155607
|  500       |  4   |  16  |  0    |  5594.10547991
|  500       |  8   |  2   |  1    |  3156.62691792
|  500       |  8   |  2   |  0    |  3157.58600206
|  500       |  8   |  4   |  1    |  3156.91971656
|  500       |  8   |  4   |  0    |  3157.07963417
|  500       |  8   |  8   |  1    |  3156.42383713
|  500       |  8   |  8   |  0    |  3157.34230044
|  500       |  8   |  16  |  1    |  3157.69680592
|  500       |  8   |  16  |  0    |  3157.33301579
|  500       |  16  |  2   |  1    |  2221.35552225
|  500       |  16  |  2   |  0    |  2223.12087011
|  500       |  16  |  4   |  1    |  2211.52226456
|  500       |  16  |  4   |  0    |  2211.36843591
|  500       |  16  |  8   |  1    |  2193.54280827
|  500       |  16  |  8   |  0    |  2194.41371636
|  500       |  16  |  16  |  1    |  2186.1105328
|  500       |  16  |  16  |  0    |  2185.45704271

In our broad search we noticed very little difference between 100 and 500 estimators (on average, an RMSE difference of about $2). Moving forward, we decided to stick with a lower number of trees (100 or 200) for performance reasons. We also noticed that higher min_samples_split values produced better results, most visibly at max_depth=16 (we'll investigate further soon). Similarly, a higher max_depth produced better results. oob_score had almost no effect.

In [None]:
n_estimators = [100, 200]
max_depth = [16,24,32,64,128]
min_samples_split = [16,32]
oob_score = [False]
gridsearch(X, n_estimators, max_depth, min_samples_split, oob_score)


|  n_estimators  |  max_depth |  min_samples_split |  oob_score  |  RMSE
|  :--       |  :-- |  :-- |  :--  |  :--
|  100       |  16  |  16  |  0    |  2191.59840446
|  100       |  16  |  32  |  0    |  2202.13914397
|  100       |  24  |  16  |  0    |  2180.34652932
|  100       |  24  |  32  |  0    |  2194.07113896
|  100       |  32  |  16  |  0    |  2181.9755508
|  100       |  32  |  32  |  0    |  2193.01493908
|  100       |  64  |  16  |  0    |  2180.96588913
|  100       |  64  |  32  |  0    |  2191.5557443
|  100       |  128 |  16  |  0    |  2182.31301262
|  100       |  128 |  32  |  0    |  2193.89428541
|  200       |  16  |  16  |  0    |  2185.9926411
|  200       |  16  |  32  |  0    |  2202.41665627
|  200       |  24  |  16  |  0    |  2176.74482435
|  200       |  24  |  32  |  0    |  2193.14700294
|  200       |  32  |  16  |  0    |  2179.98323207
|  200       |  32  |  32  |  0    |  2191.45995399
|  200       |  64  |  16  |  0    |  2178.52994012

Now let's explore what happens when we don't set a max_depth limit.

In [None]:
n_estimators = [200]
max_depth = [None]
min_samples_split = [8,12,16]
oob_score = [False]
gridsearch(X, n_estimators, max_depth, min_samples_split, oob_score)


|  n_estimators  |  max_depth |  min_samples_split |  oob_score  |  RMSE
|  :--       |  :-- |  :-- |  :--  |  :--
|  200       |  None |  8   |  0    |  2195.51468835
|  200       |  None |  12  |  0    |  2181.57165245
|  200       |  None |  16  |  0    |  2180.11692899

We didn't see a large change compared with setting a max_depth limit. However, the model did perform slightly worse. For example, with 200 estimators, a max_depth of 24 and a min_samples_split of 16, the Random Forest regression had an RMSE of ~\$2176. With the same parameters and no max_depth, the model had an RMSE of ~\$2180.
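
To put that gap in perspective, the following short calculation uses the two RMSE values taken from the result tables above (the variable names are only illustrative):

```python
# RMSE values copied from the result tables above.
rmse_depth_24 = 2176.74482435  # n_estimators=200, max_depth=24, min_samples_split=16
rmse_no_limit = 2180.11692899  # n_estimators=200, max_depth=None, min_samples_split=16

abs_diff = rmse_no_limit - rmse_depth_24          # absolute gap, in dollars
rel_diff_pct = 100.0 * abs_diff / rmse_depth_24   # gap as a percentage of the RMSE
```

The gap is a few dollars, well under 1% of the RMSE itself.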

In [None]:
n_estimators = [200]
max_depth = [20,22,24,26,28]
min_samples_split = [14,16,18]
oob_score = [False]
gridsearch(X, n_estimators, max_depth, min_samples_split, oob_score)


|  n_estimators  |  max_depth |  min_samples_split |  oob_score  |  RMSE
|  :--       |  :-- |  :-- |  :--  |  :--
|  200       |  20  |  14  |  0    |  2175.91813626
|  200       |  20  |  16  |  0    |  2176.18007223
|  200       |  20  |  18  |  0    |  2175.82631433
|  200       |  22  |  14  |  0    |  2176.20445214
|  200       |  22  |  16  |  0    |  2176.81794258
|  200       |  22  |  18  |  0    |  2174.96407534
|  200       |  24  |  14  |  0    |  2178.19860334
|  200       |  24  |  16  |  0    |  2177.83784662
|  200       |  24  |  18  |  0    |  2177.06219171
|  200       |  26  |  14  |  0    |  2180.00419949
|  200       |  26  |  16  |  0    |  2178.07664792
|  200       |  26  |  18  |  0    |  2178.74010201
|  200       |  28  |  14  |  0    |  2180.01389183
|  200       |  28  |  16  |  0    |  2177.03809102
|  200       |  28  |  18  |  0    |  2177.8590263

In the end, the best combination of parameters we found was (n_estimators=200, max_depth=22, min_samples_split=18). This produced an RMSE of 2174.96407534.
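
As a small check on that selection, the rows of the final table can be compared directly. This is only a sketch: the tuples below are copied from the last results table (all with n_estimators=200), and the names `rows`, `best` and `spread` are illustrative.

```python
# (max_depth, min_samples_split, RMSE) copied from the final results table.
rows = [
    (20, 14, 2175.91813626), (20, 16, 2176.18007223), (20, 18, 2175.82631433),
    (22, 14, 2176.20445214), (22, 16, 2176.81794258), (22, 18, 2174.96407534),
    (24, 14, 2178.19860334), (24, 16, 2177.83784662), (24, 18, 2177.06219171),
    (26, 14, 2180.00419949), (26, 16, 2178.07664792), (26, 18, 2178.74010201),
    (28, 14, 2180.01389183), (28, 16, 2177.03809102), (28, 18, 2177.8590263),
]

# Row with the lowest RMSE.
best = min(rows, key=lambda r: r[2])

# Gap between the worst and best RMSE in this final search.
spread = max(r[2] for r in rows) - best[2]
```

Here `best` is the (22, 18) row, and the whole final search spans only about \$5 of RMSE between its best and worst rows.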