# Model training

In this notebook you will find the training process only, no explorations are done in this file.

The target variable is `quality`. We are considering each note is a category so the problem is not a regression but a multi-classification.

In [1]:
import joblib
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [3]:
df = pd.read_csv("API/src/winequality.csv")
pd.set_option("display.max_rows", None, "display.max_columns", None)
print(df.shape)

(6497, 13)


In [4]:
df.head(5)

Unnamed: 0,type,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
0,white,7.0,0.27,0.36,20.7,0.045,45.0,170.0,1.001,3.0,0.45,8.8,6
1,white,6.3,0.3,0.34,1.6,0.049,14.0,132.0,0.994,3.3,0.49,9.5,6
2,white,8.1,0.28,0.4,6.9,0.05,30.0,97.0,0.9951,3.26,0.44,10.1,6
3,white,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.9956,3.19,0.4,9.9,6
4,white,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.9956,3.19,0.4,9.9,6


In [5]:
y = df["quality"]
X = df.drop(["quality", "type"], axis=1)

In [6]:
type(X)

pandas.core.frame.DataFrame

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [8]:
type(X_train)

pandas.core.frame.DataFrame

In [9]:
model = Pipeline(steps=[("imputer", SimpleImputer(strategy="mean")),
                        ("scaler", StandardScaler()),
                        ("classifier", RandomForestClassifier())])

In [10]:
model.fit(X_train, y_train)

print("Accuracy: {:.2f}".format(model.score(X_test, y_test)))

Accuracy: 0.69


In [11]:
display(X_train.head(5))
display(X_test.head(5))
display(y_train.head(5))
display(y_test.head(5))

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol
1916,6.6,0.25,0.36,8.1,0.045,54.0,180.0,0.9958,3.08,0.42,9.2
947,8.5,0.16,0.35,1.6,0.039,24.0,147.0,0.9935,2.96,0.36,10.0
877,6.0,0.28,0.34,1.6,0.119,33.0,104.0,0.9921,3.19,0.38,10.2
2927,7.0,0.31,0.31,9.1,0.036,45.0,140.0,0.99216,2.98,0.31,12.0
6063,8.5,0.44,0.5,1.9,0.369,15.0,38.0,0.99634,3.01,1.1,9.4


Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol
3103,7.0,0.25,0.45,2.3,0.045,40.0,118.0,0.99064,3.16,0.48,11.9
1419,7.6,0.14,0.74,1.6,0.04,27.0,103.0,0.9916,3.07,0.4,10.8
4761,6.2,0.15,0.27,11.0,0.035,46.0,116.0,0.99602,3.12,0.38,9.1
4690,6.7,0.16,0.32,12.5,0.035,18.0,156.0,0.99666,2.88,0.36,9.0
4032,6.8,0.27,0.22,17.8,0.034,16.0,116.0,0.9989,3.07,0.53,9.2


1916    5
947     5
877     6
2927    7
6063    5
Name: quality, dtype: int64

3103    7
1419    7
4761    6
4690    6
4032    5
Name: quality, dtype: int64

In [12]:
y_test

3103    7
1419    7
4761    6
4690    6
4032    5
1297    7
1773    6
5584    5
561     5
5946    6
1891    5
2264    6
6485    6
217     5
230     4
2168    7
1400    7
4355    6
4697    7
4295    5
4660    6
5417    5
3270    5
6294    5
2996    8
4147    7
2876    5
2955    5
496     4
828     7
1397    7
2127    5
1263    6
1963    6
706     5
5464    6
1375    7
585     6
4787    8
3326    5
1608    5
96      6
4453    6
3946    6
31      6
2807    5
2104    5
491     7
401     5
5009    5
712     6
1406    8
2957    6
764     6
4161    7
3966    6
538     5
3258    6
3254    5
3100    7
3250    6
5990    6
239     6
2617    7
3413    6
4672    5
23      5
1616    6
6126    7
3733    6
3176    6
1330    5
3289    5
5061    5
5035    5
5972    5
4606    5
4921    5
3443    6
4562    7
503     5
5661    5
1261    6
4608    7
1501    5
3407    5
747     6
544     6
296     5
6020    6
1383    6
4165    7
5027    5
4420    6
435     7
3244    7
1321    6
2585    7
2981    7
132     5


In [9]:
joblib.dump(model, "model.joblib")

['model.joblib']