In [5]:
from pathlib import Path

from matplotlib.colors import ListedColormap
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LinearRegression, LogisticRegression
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline

In [6]:
# Folder the figures are saved to; it is a relative path, so it resolves
# against the notebook's current working directory.
IMG_FOLDER = Path('img')

# Global matplotlib configuration applied to every figure created afterwards
plt.style.use('dark_background')
# plt.ion() returns a context manager; the trailing ';' stops Jupyter from
# echoing its repr (<contextlib.ExitStack ...>) as cell output.
plt.ion();

<contextlib.ExitStack at 0x75b802baa0d0>

In [7]:
# Load the wine dataset as pandas objects (as_frame=True): X is the feature
# DataFrame and labels is the matching target Series.
X, labels = datasets.load_wine(as_frame=True, return_X_y=True)

# The two feature columns compared against each other in this notebook
alcohol, color_intensity = X['alcohol'], X['color_intensity']

In [23]:
# Grid of hand-picked lines y = a * color_intensity + b to visualise.
# NOTE(review): the slope grid steps by 0.03 but skips 0.09 — confirm this gap is intentional.
SLOPES = (0, 0.03, 0.06, 0.12, 0.15, 0.18, 0.21, 0.24, 0.27, 0.3)
INTERCEPTS = (0, 3, 6, 9, 12, 15)
REGRESSION_IMG_FOLDER = IMG_FOLDER / 'regression_examples'


def plot_line_fit(x, y, a, b, out_path):
    """Plot the samples, the line y = a * x + b and its residuals, then save the figure.

    The figure title (suptitle) shows the L1 loss: the sum of absolute residuals.

    Parameters
    ----------
    x, y : pandas.Series
        Input feature (plotted on the x axis) and observed values (y axis);
        assumed to share the same index so residuals align sample by sample.
    a, b : float
        Slope and intercept of the candidate line.
    out_path : Path
        Destination image file; its parent folder must already exist.

    Returns
    -------
    float
        The L1 loss shown in the figure title.
    """
    y_pred = a * x + b
    loss = np.abs(y - y_pred).sum()

    fig, ax = plt.subplots(figsize=(16, 9))

    # Samples plus the candidate line
    ax.set_title('Wine samples - Alcohol vs Color Intensity')
    ax.set_xlabel('Color Intensity')
    ax.set_ylabel('Alcohol (Vol. %)')
    ax.scatter(x, y, label='Samples')
    ax.plot(x, y_pred, color='red',
            label=f'y = {a} * color_intensity + {b}')

    # One vertical segment per sample from the observed value to the line (the
    # residual), drawn in a single vectorised call instead of one plot per point.
    ax.vlines(x, y, y_pred, colors='gray', linestyles='-')

    fig.suptitle(f'Loss: {loss:.2f}', fontsize=18)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    # Close this specific figure so dozens of figures do not pile up in memory
    plt.close(fig)
    return loss


# savefig does not create folders; make sure the output folder exists first
REGRESSION_IMG_FOLDER.mkdir(parents=True, exist_ok=True)

for a in SLOPES:
    for b in INTERCEPTS:
        print(f'{a};{b}')
        plot_line_fit(color_intensity, alcohol, a, b,
                      REGRESSION_IMG_FOLDER / f'wine_linear_regression__a_{a}_b_{b}.png')

0;0
0;3
0;6
0;9
0;12
0;15
0.03;0
0.03;3
0.03;6
0.03;9
0.03;12
0.03;15
0.06;0
0.06;3
0.06;6
0.06;9
0.06;12
0.06;15
0.12;0
0.12;3
0.12;6
0.12;9
0.12;12
0.12;15
0.15;0
0.15;3
0.15;6
0.15;9
0.15;12
0.15;15
0.18;0
0.18;3
0.18;6
0.18;9
0.18;12
0.18;15
0.21;0
0.21;3
0.21;6
0.21;9
0.21;12
0.21;15
0.24;0
0.24;3
0.24;6
0.24;9
0.24;12
0.24;15
0.27;0
0.27;3
0.27;6
0.27;9
0.27;12
0.27;15
0.3;0
0.3;3
0.3;6
0.3;9
0.3;12
0.3;15
