In [33]:
import numpy as np
import pandas as pd

In [34]:
""" A function that can read MNIST's idx file format into numpy arrays.
    The MNIST data files can be downloaded from here:
    
    http://yann.lecun.com/exdb/mnist/
    This relies on the fact that the MNIST dataset consistently uses
    unsigned char types with their data segments.
"""
import gzip

import struct

import numpy as np

def read_idx(filename):
    with gzip.open(filename) as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.fromstring(f.read(), dtype=np.uint8).reshape(shape)


In [35]:
X_train = read_idx('X_train.gz')
y_train = read_idx('y_train.gz')
X_test = read_idx('X_test.gz')
y_test = read_idx('y_test.gz')



In [36]:
num_split = 60000

X_train, X_test, y_train, y_test = X_train[:num_split], X_test[num_split:], y_train[:num_split], y_test[num_split:]

shuffle_index = np.random.permutation(num_split)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

X_train = X_train.reshape(60000,-1)

In [37]:
y_train_0 = (y_train == 0)
y_test_0 = (y_test == 0)


In [41]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, confusion_matrix
from sklearn.model_selection import cross_val_predict
f_clf = RandomForestClassifier(random_state=0)
y_probas_forest = cross_val_predict(f_clf, X_train, y_train_0, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, threshold_forest = roc_curve(y_train_0, y_scores_forest)




In [42]:
y_train_pred = cross_val_predict(f_clf, X_train, y_train_0, cv=3)
y_scores = cross_val_predict(f_clf, X_train, y_train_0, cv=3)
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
print(f1_score(y_train_0, y_train_pred))
print(roc_auc_score(y_train_0, y_scores))



0.9591498998345092
0.9643174482309012


In [None]:
from sklearn.model_selection import RandomizedSearchCV

from sklearn.model_selection import RandomizedSearchCV
# Number of trees in random forest
n_estimators = [int(x) for x in np.linspace(start = 200, stop = 2000, num = 10)]
# Number of features to consider at every split
max_features = ['auto', 'sqrt']
# Maximum number of levels in tree
max_depth = [int(x) for x in np.linspace(10, 110, num = 11)]
max_depth.append(None)
# Minimum number of samples required to split a node
min_samples_split = [2, 5, 10]
# Minimum number of samples required at each leaf node
min_samples_leaf = [1, 2, 4]
# Method of selecting samples for training each tree
bootstrap = [True, False]
# Create the random grid
random_grid = {'n_estimators': n_estimators,
               'max_features': max_features,
               'max_depth': max_depth,
               'min_samples_split': min_samples_split,
               'min_samples_leaf': min_samples_leaf,
               'bootstrap': bootstrap}

rf = RandomForestClassifier()
rf_random = RandomizedSearchCV(estimator = rf, param_distributions = random_grid, n_iter = 100, cv = 3, verbose=2, random_state=42, n_jobs = -1)
# Fit the random search model
rf_random.fit(X_train, y_train)
rf_random.best_params_

Fitting 3 folds for each of 100 candidates, totalling 300 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.


In [43]:
f_clf = RandomForestClassifier(random_state=0)
y_probas_forest = cross_val_predict(f_clf, X_train, y_train, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]



In [44]:
y_train_pred = cross_val_predict(f_clf, X_train, y_train, cv=3)
y_scores = cross_val_predict(f_clf, X_train, y_train, cv=3)
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
print(f1_score(y_train, y_train_pred, average='weighted'))
# we usually don't use roc for mutliclass problems print(roc_auc_score(y_train, y_scores))



0.9405416368372377


In [46]:
confusion_matrix(y_train, y_train_pred, labels=None, sample_weight=None)

array([[5813,    1,   17,    6,    7,   16,   32,    4,   23,    4],
       [   2, 6626,   41,   15,    9,    7,    9,   11,   13,    9],
       [  41,   33, 5644,   54,   36,    9,   29,   44,   59,    9],
       [  28,   19,  158, 5633,   15,  112,    8,   49,   76,   33],
       [  21,   15,   28,   13, 5540,    8,   29,   13,   26,  149],
       [  48,   17,   24,  207,   24, 4947,   61,    7,   54,   32],
       [  54,   14,   27,    6,   38,   57, 5694,    0,   24,    4],
       [  15,   38,   90,   38,   61,    5,    4, 5909,   25,   80],
       [  33,   65,  100,  134,   53,  102,   40,   20, 5225,   79],
       [  36,   16,   47,   83,  162,   42,    9,   90,   55, 5409]],
      dtype=int64)