In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
from sklearn.metrics import (accuracy_score, precision_score, classification_report,
                             recall_score, f1_score, log_loss)

In [2]:
from pattern_clf import *
from datasets import *

In [3]:
# Load the data set as a two-element result unpacked into X and y.
# NOTE(review): get_breast_cancer is not defined in this notebook; presumably it
# arrives via the star import from `datasets` -- confirm, and prefer an explicit import.
X, y = get_breast_cancer()

In [4]:
# Base estimator built with its default constructor arguments; it is handed to
# GridSearchCV below as the estimator to tune.
# NOTE(review): LazyPatternClassifier presumably comes from the star import of
# `pattern_clf` -- confirm, and prefer an explicit import.
clf = LazyPatternClassifier()

In [5]:
# Hyper-parameter space for the exhaustive search:
# 4 * 1 * 3 * 3 * 2 = 72 candidate settings.
param_grid = dict(
    tolerance=np.logspace(-8, -2, 4),  # 1e-8, 1e-6, 1e-4, 1e-2
    use_softmax=[True],                # single value: held fixed
    weights_strategy=['from_objects', 'from_classifiers', 'uniform'],
    weights_iters=[1, 5, 10],
    weight_classifiers=[False, True],
)

# Score every candidate by F1 on a shuffled 5-fold split with a fixed seed so
# the folds are reproducible. Train-fold scores are recorded as well, and
# refit=False skips the final refit on the full data set.
grid = GridSearchCV(
    estimator=clf,
    param_grid=param_grid,
    scoring='f1',
    cv=KFold(n_splits=5, shuffle=True, random_state=495),
    refit=False,
    return_train_score=True,
    n_jobs=-1,
    verbose=2,
)

In [6]:
# Run the full grid search. This is expensive: the captured log below reports
# 360 fits (72 candidates x 5 folds) finishing after ~1021 minutes on 2 workers.
# The trailing semicolon suppresses the repr of the returned object.
grid.fit(X, y);

Fitting 5 folds for each of 72 candidates, totalling 360 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed: 98.9min
[Parallel(n_jobs=-1)]: Done 158 tasks      | elapsed: 373.0min
[Parallel(n_jobs=-1)]: Done 360 out of 360 | elapsed: 1021.5min finished


In [12]:
# Tabulate the cross-validation results of the fitted search as a DataFrame.
df = pd.DataFrame(grid.cv_results_)

In [25]:
# Narrow the CV results to the mean fit time, the hyper-parameters that were
# actually varied, and the mean test score; rows are ordered from the lowest
# to the highest score (sort_values is ascending by default).
df2 = (
    df.loc[:, [
        'mean_fit_time',
        'param_tolerance',
        'param_weight_classifiers',
        'param_weights_iters',
        'param_weights_strategy',
        'mean_test_score',
    ]]
    .sort_values(by='mean_test_score')
)
df2

Unnamed: 0,mean_fit_time,param_tolerance,param_weight_classifiers,param_weights_iters,param_weights_strategy,mean_test_score
19,58.919238,1e-06,False,1,from_classifiers,0.000000
34,1251.327284,1e-06,True,10,from_classifiers,0.000000
43,629.323045,0.0001,False,10,from_classifiers,0.000000
31,295.045695,1e-06,True,5,from_classifiers,0.000000
46,110.825578,0.0001,True,1,from_classifiers,0.000000
28,57.898499,1e-06,True,1,from_classifiers,0.000000
49,1728.021700,0.0001,True,5,from_classifiers,0.000000
25,582.711264,1e-06,False,10,from_classifiers,0.000000
52,605.908435,0.0001,True,10,from_classifiers,0.000000
22,291.772818,1e-06,False,5,from_classifiers,0.000000
