# Investigating the number of trees in a random forest (RF) vs. train/test error

**Methods:**
>1. Load subset data
>2. Create a validation curve with the number of estimators as the x-axis
>3. Plot this

**Conclusions:**
* TODO: summarize the findings once the validation curve has been run.

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import time

import sklearn.svm as skl_svm
import sklearn.linear_model as skl_lm
import sklearn.ensemble as skl_ensemble
import sklearn.grid_search as skl_gs
import sklearn.model_selection as skl_model_selection

import data_processor as mdp

# Instantiate the project-local MNIST data processor; it is used below to load X and y.
data_processor = mdp.MNIST_data_processor()



## 1. Load subset data

In [3]:
# Load the inputs X and targets y that are passed to validation_curve below.
# NOTE(review): the section heading says "subset", but this calls
# load_full_data() — confirm which one is intended.
X, y = data_processor.load_full_data()

## 2. Create a validation curve with the number of estimators as the x-axis

In [None]:
# Tree counts to evaluate. Bound to a name because the same values are needed
# twice: as the parameter sweep for validation_curve and as the x-axis of the
# plot. (Previously the list was only passed inline, so the plotting code
# raised NameError on `param_range` after the expensive fit had finished.)
param_range = [10, 100, 1000, 5000]

# For every n_estimators value, fit a random forest under 5-fold
# cross-validation and record accuracy on the training folds and on the
# held-out fold. Each returned array has one row per parameter value and one
# column per fold.
# random_state is fixed so that re-running the cell reproduces the same curve.
train_scores, test_scores = skl_model_selection.validation_curve(
    skl_ensemble.RandomForestClassifier(random_state=0),
    X, y,
    param_name="n_estimators",
    param_range=param_range,
    cv=5, scoring="accuracy", n_jobs=5)

# Collapse the fold axis: mean gives the curve, std gives the shaded band.
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.title("Validation Curve with RF")
plt.xlabel("Number of Trees")
plt.ylabel("Accuracy")
plt.ylim(0.0, 1.1)
lw = 2

# Log-scaled x-axis because the tree counts span several orders of magnitude.
plt.semilogx(param_range, train_scores_mean, label="Training score",
             color="darkorange", lw=lw)
plt.fill_between(param_range, train_scores_mean - train_scores_std,
                 train_scores_mean + train_scores_std, alpha=0.2,
                 color="darkorange", lw=lw)
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
             color="navy", lw=lw)
plt.fill_between(param_range, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.2,
                 color="navy", lw=lw)
plt.legend(loc="best")

## 3. Plot this