# Sklearn

## sklearn.grid_search

документация: http://scikit-learn.org/stable/modules/grid_search.html

In [2]:
from sklearn import cross_validation, datasets, grid_search, linear_model, metrics

import numpy as np
import pandas as pd



### Генерация датасета

In [3]:
iris = datasets.load_iris()

In [4]:
train_data, test_data, train_labels, test_labels = cross_validation.train_test_split(iris.data, iris.target, 
                                                                                     test_size = 0.3,random_state = 0)

### Задание модели

In [5]:
classifier = linear_model.SGDClassifier(random_state = 0)

### Генерация сетки

In [6]:
classifier.get_params().keys()

['warm_start',
 'loss',
 'n_jobs',
 'eta0',
 'verbose',
 'shuffle',
 'fit_intercept',
 'epsilon',
 'average',
 'max_iter',
 'penalty',
 'power_t',
 'random_state',
 'tol',
 'l1_ratio',
 'n_iter',
 'alpha',
 'learning_rate',
 'class_weight']

In [7]:
parameters_grid = {
    'loss' : ['hinge', 'log', 'squared_hinge', 'squared_loss'],
    'penalty' : ['l1', 'l2'],
    'n_iter' : range(5,10),
    'alpha' : np.linspace(0.0001, 0.001, num = 5),
}

In [8]:
cv = cross_validation.StratifiedShuffleSplit(train_labels, n_iter = 10, test_size = 0.2, random_state = 0)

### Подбор параметров и оценка качества

#### Grid search

In [12]:
grid_cv = grid_search.GridSearchCV(classifier, parameters_grid, scoring = 'accuracy', cv = cv)

In [19]:
import warnings
?warnings.simplefilter

[0;31mDocstring:[0m

A simple filter matches all modules and messages.
'action' -- one of "error", "ignore", "always", "default", "module",
            or "once"
'append' -- if true, append to the list of filters
[0;31mType:[0m      function


In [None]:
%%time
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    grid_cv.fit(train_data, train_labels)

In [None]:
grid_cv.best_estimator_

In [16]:
print grid_cv.best_score_
print grid_cv.best_params_

0.895238095238
{'penalty': 'l1', 'alpha': 0.000325, 'n_iter': 9, 'loss': 'hinge'}


In [None]:
grid_cv.grid_scores_[:10]

#### Randomized grid search

In [None]:
randomized_grid_cv = grid_search.RandomizedSearchCV(classifier, parameters_grid, scoring = 'accuracy', cv = cv, n_iter = 20, 
                                                   random_state = 0)

In [None]:
%%time
randomized_grid_cv.fit(train_data, train_labels)

In [None]:
print randomized_grid_cv.best_score_
print randomized_grid_cv.best_params_