# 1 Lineare Regression vs. Neuronales Netz für die vorhersage der Qualität von Wein

## 1.1 Lineare Regression

In [2]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

In [3]:
df = pd.read_csv("Data/winequality-white.csv", sep=";")
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4898 entries, 0 to 4897
Data columns (total 12 columns):
 #   Column                Non-Null Count  Dtype  
---  ------                --------------  -----  
 0   fixed acidity         4898 non-null   float64
 1   volatile acidity      4898 non-null   float64
 2   citric acid           4898 non-null   float64
 3   residual sugar        4898 non-null   float64
 4   chlorides             4898 non-null   float64
 5   free sulfur dioxide   4898 non-null   float64
 6   total sulfur dioxide  4898 non-null   float64
 7   density               4898 non-null   float64
 8   pH                    4898 non-null   float64
 9   sulphates             4898 non-null   float64
 10  alcohol               4898 non-null   float64
 11  quality               4898 non-null   int64  
dtypes: float64(11), int64(1)
memory usage: 459.3 KB


In [4]:
X = df.drop(columns="quality")
y = df["quality"]
X_train, X_test, y_train, y_test = train_test_split(X,y)
df.tail()

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
4893,6.2,0.21,0.29,1.6,0.039,24.0,92.0,0.99114,3.27,0.5,11.2,6
4894,6.6,0.32,0.36,8.0,0.047,57.0,168.0,0.9949,3.15,0.46,9.6,5
4895,6.5,0.24,0.19,1.2,0.041,30.0,111.0,0.99254,2.99,0.46,9.4,6
4896,5.5,0.29,0.3,1.1,0.022,20.0,110.0,0.98869,3.34,0.38,12.8,7
4897,6.0,0.21,0.38,0.8,0.02,22.0,98.0,0.98941,3.26,0.32,11.8,6


In [5]:
reg = LinearRegression().fit(X_train,y_train)
reg.coef_

array([ 2.50416070e-02, -1.88654873e+00,  9.37410912e-04,  6.98869162e-02,
       -4.70356190e-01,  3.70745986e-03, -2.71448467e-04, -1.17274746e+02,
        5.36442950e-01,  5.56680377e-01,  2.33691115e-01])

In [6]:
y_reg = reg.predict(X_test)
rmse_reg = round(mean_squared_error(y_reg,y_test,squared=False),4)
print("Root mean squared error:",rmse_reg)

Root mean squared error: 0.7593


## 1.2 Neuronales Netz

In [7]:
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler

In [8]:
X_train_scaled = StandardScaler().fit(X_train).transform(X_train)
X_test_scaled = StandardScaler().fit(X_test).transform(X_test)
X_scaled = StandardScaler().fit(X).transform(X)

nn = MLPRegressor(max_iter=1000,
                  hidden_layer_sizes=(10,6),
                  activation="relu",
                  solver="sgd",
                  learning_rate='adaptive',
                  learning_rate_init=0.01
                 ).fit(X_train_scaled,y_train)
                 
y_nn = nn.predict(X_test_scaled)
rmse_nn = round(mean_squared_error(y_nn,y_test,squared=False),4)
print("RMSE regression:\t",rmse_reg)
print("RMSE neural network:\t",rmse_nn)

RMSE regression:	 0.7593
RMSE neural network:	 0.7148


In [9]:
validation = pd.DataFrame({"quality": np.array(y_test),
              "prediction_nn": np.round(y_nn,4),
              "prediction_reg": np.round(y_reg,4)}, index=y_test.index)
pd.concat([X_test,validation],axis=1)

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,prediction_nn,prediction_reg
2622,6.6,0.26,0.21,2.9,0.026,48.0,126.0,0.99089,3.22,0.38,11.3,7,6.4149,6.2300
3277,6.4,0.16,0.25,1.4,0.057,21.0,125.0,0.99091,3.23,0.44,11.1,7,6.0121,6.1841
4422,6.2,0.22,0.28,2.2,0.040,24.0,125.0,0.99170,3.19,0.48,10.5,6,5.9460,5.9089
2752,7.1,0.43,0.30,6.6,0.025,15.0,138.0,0.99126,3.18,0.46,12.6,6,6.5702,6.3388
1066,6.4,0.30,0.51,5.5,0.048,62.0,172.0,0.99420,3.08,0.45,9.1,6,5.2487,5.4221
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
435,6.9,0.20,0.37,6.2,0.027,24.0,97.0,0.99200,3.38,0.49,12.2,7,6.8184,6.7271
116,6.0,0.31,0.24,3.3,0.041,25.0,143.0,0.99140,3.31,0.44,11.3,6,6.2420,6.0736
4777,5.9,0.27,0.32,2.0,0.034,31.0,102.0,0.98952,3.16,0.56,12.3,6,6.5954,6.5329
1366,7.0,0.14,0.41,0.9,0.037,22.0,95.0,0.99140,3.25,0.43,10.9,6,6.1810,6.1243
