## 17. Лабораторная работа «Разработка модели машинного обучения для выбранной предметной области» 

Dataset: [Car Evaluation](https://archive.ics.uci.edu/ml/datasets/Car+Evaluation)

In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, confusion_matrix
import os
import requests

%matplotlib inline
pd.options.display.max_columns = None

In [18]:
header = ["buying", "maint", "doors", "persons", "lug_boot", "safety", "class"]
data = pd.read_csv("car.data", names=header)
data = data.astype({"buying": "category", "maint": "category", "doors": "category", "persons": "category", "lug_boot": "category", "safety": "category", "class": "category"})

display(data.dtypes)

buying      category
maint       category
doors       category
persons     category
lug_boot    category
safety      category
class       category
dtype: object

По данным UCI пропуски в данных отсутствуют - охотно верим.

In [19]:
data.describe()

Unnamed: 0,buying,maint,doors,persons,lug_boot,safety,class
count,1728,1728,1728,1728,1728,1728,1728
unique,4,4,4,3,3,3,4
top,vhigh,vhigh,5more,more,small,med,unacc
freq,432,432,432,576,576,576,1210


In [20]:
data.sample(5)

Unnamed: 0,buying,maint,doors,persons,lug_boot,safety,class
1570,low,med,4,2,med,med,unacc
770,high,low,2,4,med,high,acc
509,high,vhigh,4,more,med,high,unacc
11,vhigh,vhigh,2,4,small,high,unacc
183,vhigh,high,4,more,med,low,unacc


In [21]:
data.duplicated().any()

False

Дублей нет

In [22]:
le = LabelEncoder()
for col in data.columns:
    data[col] = le.fit_transform(data[col])

data.sample(5)


Unnamed: 0,buying,maint,doors,persons,lug_boot,safety,class
638,0,0,3,1,0,0,0
1267,2,1,2,2,0,2,1
1377,1,3,3,0,2,1,2
501,0,3,2,1,0,1,2
968,2,3,3,2,1,0,0


In [23]:
X = data.drop(columns=["class"]).copy()
y = data["class"].copy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=1337)

In [26]:
mlp_clf = MLPClassifier(hidden_layer_sizes=(5),max_iter=10000, random_state=1337, shuffle=True, verbose=False)

mlp_clf.fit(X_train, y_train)
y_pred = mlp_clf.predict(X_test)

In [27]:
y_score = cross_val_score(mlp_clf, X_train, y_train, cv=10)

metrics = [
    ["Mean squared error (MSE)", mean_squared_error(y_test, y_pred)],
    ["Mean absolute error (MAE)", mean_absolute_error(y_test, y_pred)],
    ["Accuracy", mlp_clf.score(X_test, y_test)],
    ["Cross validation Accuracy", y_score.mean()]
]
pd.DataFrame(data=metrics, columns=["Metric", "Score"])

Unnamed: 0,Metric,Score
0,Mean squared error (MSE),0.973988
1,Mean absolute error (MAE),0.482659
2,Accuracy,0.751445
3,Cross validation Accuracy,0.725764


In [28]:
params = {
    'activation': ['logistic', 'tanh', 'relu'],
    'solver': ['lbfgs', 'adam', 'sgd'],
    'alpha': 10.0 ** -np.arange(1, 3),
    'hidden_layer_sizes': [(3), (4), (5), (12), (18), (24) ]
    }

mlp_clf_cv = MLPClassifier(random_state=1337)
gscv = GridSearchCV(mlp_clf_cv, params, cv=10, n_jobs=10)
gscv_pred = gscv.fit(X_train, y_train).predict(X_test)
gscv.best_params_

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)


{'activation': 'logistic',
 'alpha': 0.1,
 'hidden_layer_sizes': 18,
 'solver': 'lbfgs'}

In [29]:
y_score = cross_val_score(mlp_clf, X_train, y_train, cv=10)

metrics = [
    ["Mean squared error (MSE)", mean_squared_error(y_test, gscv_pred)],
    ["Mean absolute error (MAE)", mean_absolute_error(y_test, gscv_pred)],
    ["Accuracy", gscv.score(X_test, y_test)],
    ["Cross validation Accuracy", gscv.best_score_]
]
pd.DataFrame(data=metrics, columns=["Metric", "Score"])

Unnamed: 0,Metric,Score
0,Mean squared error (MSE),0.130058
1,Mean absolute error (MAE),0.066474
2,Accuracy,0.962428
3,Cross validation Accuracy,0.989157


Стало получше