In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
# Show every column when a DataFrame is displayed (None removes the column limit).
pd.options.display.max_columns = None

In [2]:
### Load CSV ###

# 'veil-type' is removed right after loading: according to df.describe() it holds
# a single value for every row, so it carries no information.
df = pd.read_csv("mushrooms.csv").drop(columns=["veil-type"])

In [3]:
### Prepare Data ###

# One-hot encode every column of the frame.
dfCategorical = pd.get_dummies(df)

# A feature with n categories only needs n-1 indicator columns: when all the
# kept indicators are zero the dropped one must be one, and when any kept
# indicator is one the dropped one must be zero. The extra column is therefore
# redundant, and redundant data can hurt some models.
# The list below names one indicator per original column to discard.
reference_columns = [
    'class_p',
    'cap-shape_b',
    'cap-surface_f',
    'cap-color_n',
    'bruises_t',
    'odor_a',
    'gill-attachment_a',
    'gill-spacing_c',
    'gill-size_b',
    'gill-color_k',
    'stalk-shape_e',
    'stalk-root_b',
    'stalk-surface-above-ring_f',
    'stalk-surface-below-ring_f',
    'stalk-color-above-ring_n',
    'stalk-color-below-ring_n',
    'veil-color_n',
    'ring-number_n',
    'ring-type_e',
    'spore-print-color_k',
    'population_a',
    'habitat_g',
]
dfCategorical = dfCategorical.drop(columns=reference_columns)

# Target: the remaining 'class_e' indicator (presumably 1 = edible -- confirm
# against the dataset description). Features: every other indicator column.
y = dfCategorical['class_e']
X = dfCategorical.drop(columns='class_e')

# Hold out 30% of the rows for testing; the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

In [4]:
### Model ###

from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(n_neighbors=2)

%time model.fit(X_train, y_train) # Fit model to data

y_pred = model.predict(X_test) # Predict unseen data

CPU times: user 52.5 ms, sys: 1.28 ms, total: 53.7 ms
Wall time: 54.2 ms


In [5]:
### Metrics ###

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score

# Hold-out evaluation: compare the test-split predictions with the true labels.
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 5-fold cross-validation on the training split; prints one score per fold.
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
print(cv_scores)

[[1181    0]
 [   0 1257]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      1181
           1       1.00      1.00      1.00      1257

    accuracy                           1.00      2438
   macro avg       1.00      1.00      1.00      2438
weighted avg       1.00      1.00      1.00      2438

[1. 1. 1. 1. 1.]
