# 4. Preliminary Modelling


In [1]:
import pandas as pd
import numpy as np
import xgboost as xgb
import sklearn

## Create Train-Test Splits

In [2]:
feat_df = pd.read_pickle('../data/features_df.pkl')
meta_df = pd.read_csv('../data/speechdetails.csv')

In [3]:
#Removes all the text that was in before
feat_df = feat_df.select_dtypes(exclude=['object'])

In [4]:
X = feat_df.values
y = meta_df['IC'].values

In [5]:
from sklearn.model_selection import train_test_split

In [6]:
X_train, X_test, y_train, y_test = train_test_split(X,y, 
                                                    test_size=0.2,
                                                    shuffle=True)

In [7]:
#Read in the dataframes
feat_df = pd.read_pickle('../data/features_df.pkl')
meta_df = pd.read_csv('../data/speechdetails.csv')

#Removes all the text that was in before
feat_df = feat_df.select_dtypes(exclude=['object'])

#Isolate predictive (X) and target (y) variables
X = feat_df.values
y = meta_df['IC'].values

#Implement train test split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y, 
                                                    test_size=0.2,
                                                    random_state=11)

## XGBOOST

In [8]:
xg_reg = xgb.XGBRegressor(objective ='reg:linear', colsample_bytree = 0.3, learning_rate = 0.1,
                max_depth = 5, alpha = 10, n_estimators = 10)

In [9]:
xg_reg.fit(X_train, y_train)

XGBRegressor(alpha=10, base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=0.3, gamma=0, importance_type='gain',
       learning_rate=0.1, max_delta_step=0, max_depth=5,
       min_child_weight=1, missing=None, n_estimators=10, n_jobs=1,
       nthread=None, objective='reg:linear', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=True,
       subsample=1)

In [10]:
preds = xg_reg.predict(X_test)

In [12]:
from sklearn import metrics

print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, preds)))

Mean Absolute Error: 0.4548896045049032
Mean Squared Error: 0.2840524020486156
Root Mean Squared Error: 0.532965666857271


### Try with K-fold Cross Validation

In [13]:
data_dmatrix = xgb.DMatrix(data=X,label=y)

In [14]:
params = {"objective":"reg:linear",'colsample_bytree': 0.3,'learning_rate': 0.1,
                'max_depth': 5, 'alpha': 10}

cv_results = xgb.cv(dtrain=data_dmatrix, params=params, nfold=3,
                    num_boost_round=50,early_stopping_rounds=10,metrics="rmse", as_pandas=True, seed=123)

In [15]:
cv_results

Unnamed: 0,train-rmse-mean,train-rmse-std,test-rmse-mean,test-rmse-std
0,1.222118,0.039675,1.220663,0.087276
1,1.124399,0.036571,1.124137,0.089193
2,1.037226,0.033902,1.037717,0.093476
3,0.9597,0.031666,0.96071,0.097018
4,0.89093,0.029815,0.892631,0.1
5,0.829778,0.028143,0.832865,0.100951
6,0.775398,0.027006,0.780085,0.102865
7,0.727083,0.025794,0.733415,0.104681
8,0.685052,0.025115,0.692368,0.10585
9,0.647245,0.02485,0.655065,0.105256


In [16]:
print((cv_results["test-rmse-mean"]).tail(1))

49    0.417169
Name: test-rmse-mean, dtype: float64


### Visualize Boosting Trees and Feature Importance
We can visualize individual trees from the fully boosted model that XGBoost creates using the entire dataset.

In [17]:
import matplotlib.pyplot as plt

In [18]:
model = xgb.XGBRegressor()
model.fit(X, y)

XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, importance_type='gain',
       learning_rate=0.1, max_delta_step=0, max_depth=3,
       min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,
       nthread=None, objective='reg:linear', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=True,
       subsample=1)

In [19]:
print(model.feature_importances_)

[0.0187495  0.05512465 0.04033196 0.03925942 0.         0.
 0.04000152 0.05356915 0.10524248 0.05023061 0.03193491 0.02034778
 0.06616806 0.03187322 0.03114282 0.03238795 0.         0.
 0.         0.         0.         0.         0.027601   0.11083944
 0.04850224 0.07841585 0.04581972 0.07245771]


## Random Forest

In [20]:
from sklearn.preprocessing import StandardScaler

sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

In [21]:
from sklearn.ensemble import RandomForestRegressor

regressor = RandomForestRegressor(n_estimators=20, random_state=0)
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)

In [22]:
from sklearn import metrics

print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_pred))
print('Mean Squared Error:', metrics.mean_squared_error(y_test, y_pred))
print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))

Mean Absolute Error: 0.30050277799999997
Mean Squared Error: 0.12828849560959255
Root Mean Squared Error: 0.3581738343452695


In [23]:
import matplotlib.pyplot as plt

## Learning Curve for Neural Networks

In [24]:
loss = [47249372.881355934, 38306795.59322034, 30496030.101694915, 23914019.050847456, 18367448.06779661, 13821742.508474575, 10158476.677966101, 7278298.016949153, 5052215.610169492, 3419284.194915254, 2233431.1016949154, 1389939.5, 807155.3262711865, 436049.52966101695, 212049.6435381356, 88569.68855932204, 29089.809586864405, 6906.328853283899, 3025.3364878509005, 6082.789608712924, 9978.95224774894, 12225.243991657839, 12408.116227489407, 11126.867915783898, 9153.776698225636, 7117.212360963983, 5422.773627846928, 4214.0295699814615, 3469.132022146451, 3099.2443516618114, 2955.493681309587, 2935.1299903998943, 2952.4478945974574, 2972.384749073093, 2972.9484242584745, 2971.580227092161, 2961.7901135460806, 2950.2542869438557, 2939.7535255561443, 2934.0513564287608, 2928.2417654263772, 2926.0384211136125, 2925.271095405191, 2924.6971414856994, 2924.844267247087, 2924.3919988082625, 2923.322307004767, 2922.009223550053, 2921.1656432070977, 2919.449545650159, 2918.35692117982, 2917.122463420286, 2915.9873833090573, 2914.9199715307204, 2914.1341904462392, 2913.5039145259534, 2912.1346456236756, 2911.264892578125, 2910.3287332825744, 2908.845558295816, 2907.8761172537074, 2906.9536339711335, 2905.7791272179556, 2904.850424556409, 2903.6636776681676, 2902.4249329647773, 2901.3584315413136, 2900.3485666048728, 2899.1073018736756, 2897.9701817399364, 2896.853354243909, 2895.7259873212392, 2895.1455285023835, 2893.8194766287074, 2892.2333653336864, 2891.1339587195444, 2889.922143968485, 2888.6610955707097, 2887.4551650225108, 2886.1924283302437, 2885.7139747748943, 2884.0647593352755, 2884.0575840836864, 2882.8662440413136, 2881.9705955376057, 2879.6670997748943, 2878.0169698424256, 2876.0599965240995, 2875.013485666049, 2874.6502995895125, 2872.7287680415784, 2871.9047065346927, 2870.430448887712, 2869.4886412539727, 2867.7016560182733, 2868.1429298530193, 2865.1268372616523, 2863.311391022246, 2861.6662556276483, 2860.353557004767, 2859.1530099642478, 2857.852108712924, 2856.4526822364937, 2854.7940818657307, 2853.2739175052966, 2852.1541437698625, 2850.4583388506358, 2850.2251473119704, 2848.92722126589, 2847.28126241393, 2846.395425052966, 2844.1924490201272, 2842.4428793697034, 2840.470417604608, 2838.7862197100108, 2838.107310149629, 2836.4529305150954, 2835.1059611692267, 2833.1013845670022, 2831.3252615201272, 2829.886147709216, 2828.2281307931676, 2826.61036397643, 2825.3891518802966, 2824.457738844015, 2823.3062806210273, 2821.6281779661017, 2819.391398801642, 2817.8105799788136, 2816.414811473782, 2814.453803628178, 2812.5274306475108, 2810.8874718617585, 2809.2938170352227, 2807.9369496490995, 2807.5009848384534, 2805.895780918962, 2802.7984308792375, 2801.1825385659427, 2799.378028998941, 2798.059897212659, 2796.82423116393, 2795.104521153337, 2792.85939982786, 2790.9799887447034, 2789.649145094015, 2787.55809305482, 2786.0087021649892, 2784.315818657309, 2782.4409965903073, 2781.2707850569386, 2778.9955599510063, 2777.305001986229, 2775.49022196107, 2773.68361857786, 2771.938865532309, 2770.1473574880824, 2768.522854045286, 2766.5631579382944, 2764.7797768802966, 2763.0565951072563, 2761.188691737288, 2759.4117617849574, 2757.8123013771187, 2755.83102985964, 2754.2630304886125, 2752.062483448093, 2750.3391982256358, 2748.6478995630296, 2748.6040949417375, 2745.174589512712, 2743.1366980601165, 2740.9369786149364, 2739.0353672868114, 2737.6360897775426, 2735.3613777807204, 2733.36618610964, 2732.0523329912608, 2730.0429232322563, 2728.3402409957625, 2726.392809851695, 2726.003078654661, 2722.14210225768, 2720.0756049721927, 2718.100192829714, 2717.6180978548728, 2714.3309181342693, 2712.2337708554023, 2710.3550839181676, 2708.2619297868114, 2707.2809313757944, 2704.93895242982, 2703.7861410884534, 2700.4201991194386, 2698.231838420286, 2698.6197571835273, 2697.2058560646187, 2693.0906796212926, 2691.7652443061443, 2688.240681276483]
[54778575.118644066, 44252988.779661015, 35667868.81355932, 27857307.627118643, 21748981.694915254, 16684007.661016949, 12526525.949152542, 9159840.847457627, 6626329.824152542, 4503556.813559322, 2972132.686440678, 1909849.309322034, 1196887.3559322034, 668439.1663135593, 349411.0995762712, 164143.29157838982, 67845.00986493644, 19401.686093087923, 5181.098448672537, 3668.745506157309, 7456.424953654661, 10905.938228283898, 12425.25627317267, 12066.577926377118, 10475.606660487289, 8446.452338784427, 6495.164426641949, 4947.04001423464, 3966.537614208157, 3315.4402765823625, 3099.3629998675847, 2957.2897576800847, 2936.9808825476694, 2967.433461334746, 2972.8956133309057, 2974.2733878442796, 2968.0256678694386, 2956.604305978549, 2946.327351198358, 2933.4033285884534, 2933.1307228217693, 2927.221108646716, 2926.891531216896, 2924.348777641684, 2925.479939088983, 2925.7107968087926, 2925.2032532772773, 2924.1513713254767, 2924.403576867055, 2922.100180415784, 2919.5367866128177, 2917.841267213983, 2916.809773073358, 2916.7017677436443, 2914.4097548662608, 2915.0591151350636, 2915.2972391419494, 2912.8686378608318, 2911.4070444915255, 2913.6356014962926, 2909.4253012447034, 2910.176368842691, 2913.3366492319915, 2907.623911712129, 2904.6826833951272, 2904.901296841896, 2903.171804654396, 2903.338825807733, 2906.545298430879, 2899.1236758474574, 2904.318980071504, 2896.7380040055614, 2900.733303264036, 2895.022556110964, 2892.9094196901483, 2895.4504601430085, 2891.0649331302966, 2890.3052957825744, 2888.441923497087, 2890.8206600900426, 2888.5954962261653, 2890.6817647643006, 2887.5025820974574, 2888.37790485964, 2885.100991459216, 2880.9495290982522, 2880.667927370233, 2877.4114390227755, 2874.7575600834216, 2876.3175193657307, 2878.231577727754, 2872.256662142479, 2872.872244107521, 2869.9652285818324, 2871.5891568458687, 2876.5531192068324, 2871.6969842425847, 2867.489204018803, 2862.5594482421875, 2865.90821553893, 2860.184516518803, 2860.8724799721927, 2858.035210043697, 2856.301530223782, 2862.8868097854875, 2855.82225321107, 2858.530331369174, 2854.0599799721927, 2850.41063211732, 2849.7594883805614, 2850.676199185646, 2846.096936242055, 2844.0413963188557, 2841.663133524232, 2842.5225271451272, 2842.89259053893, 2837.9969461731994, 2837.1630859375, 2834.4566629700744, 2834.2769175384005, 2836.283745199947, 2829.9900109242585, 2827.571615962659, 2827.418891518803, 2827.57615946107, 2825.4654354806676, 2825.253989009534, 2821.4355675648835, 2832.6900696835273, 2817.7743809586864, 2816.7789658368642, 2815.8832800913665, 2812.862416412871, 2810.4845156912074, 2810.9278378244176, 2810.9664616988875, 2808.3749503442796, 2809.2256893869176, 2803.3547032243114, 2801.1389077396716, 2798.829387082892, 2803.2093816207625, 2797.8240201271187, 2795.5670476363875, 2793.6638638771187, 2796.0808767545022, 2795.5929803363347, 2792.7794913599046, 2788.036614886785, 2783.8990292306676, 2793.3722689353813, 2780.834605071504, 2781.2209141618114, 2781.9483663268006, 2777.4311192399364, 2774.400895458157, 2775.0660545219807, 2771.8136420815677, 2772.5188319319386, 2768.067680746822, 2764.2325253244176, 2765.64451883607, 2762.1100163863875, 2760.5604972192796, 2758.3770027807204, 2760.020590572034, 2753.8865904727227, 2751.5130553164727, 2751.2973177635063, 2759.395114704714, 2752.5941472457625, 2745.6635080111228, 2743.9746714446505, 2740.8155331369176, 2743.9698920815677, 2736.7224990068858, 2736.504771087129, 2735.3884360103284, 2731.817651780985, 2735.9570643538136, 2745.5763104972193, 2738.3999768273306, 2727.548012943591, 2722.2658732786017, 2720.460192664195, 2730.908203125, 2718.9443028336864, 2713.7873990333687, 2712.9565802105403, 2715.9335027145125, 2711.438608977754, 2707.381736626059, 2710.509391138109, 2703.8245663400426, 2702.186453091896, 2705.2426012976694, 2698.958384368379, 2701.193702827066, 2703.3316836599574, 2696.681495795816]

[54778575.118644066,
 44252988.779661015,
 35667868.81355932,
 27857307.627118643,
 21748981.694915254,
 16684007.661016949,
 12526525.949152542,
 9159840.847457627,
 6626329.824152542,
 4503556.813559322,
 2972132.686440678,
 1909849.309322034,
 1196887.3559322034,
 668439.1663135593,
 349411.0995762712,
 164143.29157838982,
 67845.00986493644,
 19401.686093087923,
 5181.098448672537,
 3668.745506157309,
 7456.424953654661,
 10905.938228283898,
 12425.25627317267,
 12066.577926377118,
 10475.606660487289,
 8446.452338784427,
 6495.164426641949,
 4947.04001423464,
 3966.537614208157,
 3315.4402765823625,
 3099.3629998675847,
 2957.2897576800847,
 2936.9808825476694,
 2967.433461334746,
 2972.8956133309057,
 2974.2733878442796,
 2968.0256678694386,
 2956.604305978549,
 2946.327351198358,
 2933.4033285884534,
 2933.1307228217693,
 2927.221108646716,
 2926.891531216896,
 2924.348777641684,
 2925.479939088983,
 2925.7107968087926,
 2925.2032532772773,
 2924.1513713254767,
 2924.40357686705