In [69]:
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import randint
from sklearn.metrics import accuracy_score
import numpy as np

In [70]:
wine = load_wine()
X, y = wine.data, wine.target

In [71]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [72]:
param_dist = {
    'criterion': ['gini', 'entropy'],
    'splitter': ['best', 'random'],
    'max_depth': randint(1, 20),
    'min_samples_split': randint(2, 20),
    'min_samples_leaf': randint(1, 20),
}

In [73]:
tree_classifier = DecisionTreeClassifier(random_state=42)
random_search = RandomizedSearchCV(tree_classifier, param_distributions=param_dist, n_iter=100, cv=5, scoring='accuracy', random_state=42)
random_search.fit(X_train, y_train)


In [74]:
best_tree_params = random_search.best_params_
best_tree_params

{'criterion': 'gini',
 'max_depth': 13,
 'min_samples_leaf': 1,
 'min_samples_split': 8,
 'splitter': 'best'}

In [75]:
best_tree_classifier = DecisionTreeClassifier(**best_tree_params, random_state=42)
best_tree_classifier.fit(X_train, y_train)

In [76]:
y_pred_tree = best_tree_classifier.predict(X_test)
accuracy_tree = accuracy_score(y_test, y_pred_tree)
print(f"Decision Tree Accuracy: {accuracy_tree * 100:.2f}%")

Decision Tree Accuracy: 94.44%


In [77]:
skf = StratifiedKFold(n_splits=10,shuffle=True, random_state=2)

In [78]:
idx = skf.split(X_train,y_train)


In [79]:
acc=[]
for train_idx,_ in idx:
    subset_tree = DecisionTreeClassifier(**best_tree_params, random_state=2)
    subset_tree.fit(X_train[train_idx],y_train[train_idx])
    pred = subset_tree.predict(X_test)
    accuracy = accuracy_score(y_test,pred)
    acc.append(accuracy)
    print(accuracy)


0.9444444444444444
0.9444444444444444
0.9444444444444444
0.9444444444444444
0.9166666666666666
0.9722222222222222
0.8888888888888888
0.9444444444444444
0.9166666666666666
0.9444444444444444


In [81]:

print(f'average accuracy is {np.mean(acc)}')

average accuracy is 0.9361111111111111
