In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from pactools.grid_search import GridSearchCVProgressBar

import warnings
warnings.filterwarnings("ignore")

In [5]:
data = pd.read_pickle('data.pickle')

In [6]:
X = data['data']
y = data['labels']

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [14]:
param_grid = {
    'criterion': ['gini', 'entropy'],              # Function to measure the quality of a split
    'splitter': ['best', 'random'],                # Strategy used to split at each node
    'max_depth': [None, 10, 20, 30, 40],           # The maximum depth of the tree
    'min_samples_split': [2, 5, 10, 20],           # The minimum number of samples required to split an internal node
    'min_samples_leaf': [1, 2, 5, 10],             # The minimum number of samples required to be at a leaf node
    'max_features': [None, 'auto', 'sqrt', 'log2'], # The number of features to consider when looking for the best split
    'max_leaf_nodes': [None, 10, 20, 30],          # The maximum number of leaf nodes in the tree
    'class_weight': [None, 'balanced'],            # Weighting of classes (None or 'balanced')
    'ccp_alpha': [0.0, 0.01, 0.1, 1.0],           # Complexity parameter used for pruning
    'random_state': [42]                           # Set a fixed seed for reproducibility
}

# Create the MLPClassifier
mlp = DecisionTreeClassifier()

# Set up GridSearchCV with the MLP model and the parameter grid
grid_search = GridSearchCV(mlp, param_grid, cv=3, n_jobs=12, verbose=2)

grid_search.fit(X, y)
results = grid_search.cv_results_

Fitting 3 folds for each of 40960 candidates, totalling 122880 fits


In [15]:
df = pd.DataFrame(results)
df.sort_values(by=['rank_test_score'], axis=0, ascending=True, inplace=True)
df.to_csv('resultsDT.csv')