In [47]:
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import numpy as np
import seaborn as sns
from typing import List

In [27]:
# The raw files are semicolon-delimited, hence the explicit separator.
red_wines = pd.read_csv(
    "dataset/winequality-red.csv",
    sep=";",
)
white_wines = pd.read_csv(
    "dataset/winequality-white.csv",
    sep=";",
)

In [28]:
# Tag each frame with its origin so rows remain distinguishable once the two
# frames are concatenated below.
red_wines["wine_type"], white_wines["wine_type"] = "red", "white"

In [29]:
def encode_quality(quality, low_cutoff=5, high_cutoff=7):
    """Map a numeric quality score to an ordinal "low" / "medium" / "high" label.

    Scores below ``low_cutoff`` are "low", scores below ``high_cutoff`` are
    "medium", and everything else is "high". With the defaults, integer scores
    map as: < 5 -> "low", 5-6 -> "medium", >= 7 -> "high".

    Args:
        quality: A single numeric score (intended for ``Series.apply``).
        low_cutoff: Smallest score that is no longer "low" (default 5).
        high_cutoff: Smallest score that counts as "high" (default 7).

    Returns:
        The label string, or ``None`` when ``quality`` is missing
        (``None`` / ``NaN`` / ``pd.NA``), so it becomes a missing value in a
        ``Categorical`` instead of a real label.
    """
    # Every comparison with NaN is False, so without this guard a missing
    # score would fall through both checks and be labelled "high".
    if pd.isna(quality):
        return None
    if quality < low_cutoff:
        return "low"
    if quality < high_cutoff:
        return "medium"
    return "high"


# One shared ordered dtype ("low" < "medium" < "high") for both frames: the label
# set cannot drift between them, and pd.concat keeps the category dtype only when
# the inputs' categorical dtypes match.
QUALITY_LABEL_DTYPE = pd.CategoricalDtype(categories=["low", "medium", "high"], ordered=True)
red_wines["quality_label"] = red_wines["quality"].apply(encode_quality).astype(QUALITY_LABEL_DTYPE)
white_wines["quality_label"] = white_wines["quality"].apply(encode_quality).astype(QUALITY_LABEL_DTYPE)

In [30]:
# Stack both colours into a single frame with a fresh 0..n-1 index, then
# shuffle all rows with a fixed seed so the ordering is reproducible.
wines = (
    pd.concat([red_wines, white_wines], ignore_index=True)
    .sample(frac=1, random_state=42)
)
wines.head()

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,wine_type,quality_label
3103,7.0,0.17,0.74,12.8,0.045,24.0,126.0,0.9942,3.26,0.38,12.2,8,white,high
1419,7.7,0.64,0.21,2.2,0.077,32.0,133.0,0.9956,3.27,0.45,9.9,5,red,medium
4761,6.8,0.39,0.34,7.4,0.02,38.0,133.0,0.99212,3.18,0.44,12.0,7,white,high
4690,6.3,0.28,0.47,11.2,0.04,61.0,183.0,0.99592,3.12,0.51,9.5,6,white,medium
4032,7.4,0.35,0.2,13.9,0.054,63.0,229.0,0.99888,3.11,0.5,8.9,6,white,medium


In [132]:
def concat_described(data_frames: List[pd.DataFrame], subset_attributes: List[str], keys: List[str], drop_count: bool = True) -> pd.DataFrame:
    """Summarise several frames side by side with ``DataFrame.describe``.

    Each frame is restricted to ``subset_attributes`` and described. The
    summaries are then laid out with two-level columns ``(attribute, key)``,
    so the statistics of one attribute across all frames sit next to each
    other.

    Args:
        data_frames: Frames to summarise, one per entry in ``keys``.
        subset_attributes: Column names selected from every frame and described.
        keys: Label for each frame, used as the inner column level. Must have
            the same length as ``data_frames`` and contain no duplicates.
        drop_count: If True (default), drop the ``count`` row from the summary.

    Returns:
        A DataFrame indexed by the ``describe`` statistics, with MultiIndex
        columns ``(attribute, key)``. Attributes follow ``subset_attributes``
        order; keys follow ``keys`` order.

    Raises:
        AssertionError: If ``data_frames`` and ``keys`` differ in length.
        ValueError: If ``keys`` contains duplicates. Duplicate keys would
            otherwise silently overwrite each other's columns.
    """
    assert len(data_frames) == len(keys), "Number of data frames and keys should be equal"
    if len(set(keys)) != len(keys):
        raise ValueError("Keys should be unique")

    described_dfs = []
    for df in data_frames:
        described = df[subset_attributes].describe()
        described_dfs.append(described.drop("count") if drop_count else described)

    # Outer dict -> top column level (attribute); inner dict -> per-frame columns (key).
    return pd.concat(
        {
            attr: pd.DataFrame({key: described[attr] for key, described in zip(keys, described_dfs)})
            for attr in subset_attributes
        },
        axis=1,
    )

In [130]:
# Select each subset with a boolean mask rather than DataFrame.where. `where` keeps
# every row and NaN-fills the non-matching ones; that only gave correct numbers
# because `describe` skips NaN, and it upcasts integer columns such as `quality`
# to float. Deriving frames and keys from one list keeps them aligned.
wine_types = ["red", "white"]
concat_described(data_frames=[wines[wines["wine_type"] == wine_type] for wine_type in wine_types],
                 subset_attributes=["residual sugar", "total sulfur dioxide", "sulphates", "alcohol", "volatile acidity", "quality"],
                 keys=wine_types,
                 drop_count=True).style.set_sticky()

Unnamed: 0_level_0,residual sugar,residual sugar,total sulfur dioxide,total sulfur dioxide,sulphates,sulphates,alcohol,alcohol,volatile acidity,volatile acidity,quality,quality
Unnamed: 0_level_1,red,white,red,white,red,white,red,white,red,white,red,white
mean,2.538806,6.391415,46.467792,138.360657,0.658149,0.489847,10.422983,10.514267,0.527821,0.278241,5.636023,5.877909
std,1.409928,5.072058,32.895324,42.498065,0.169507,0.114126,1.065668,1.230621,0.17906,0.100795,0.807569,0.885639
min,0.9,0.6,6.0,9.0,0.33,0.22,8.4,8.0,0.12,0.08,3.0,3.0
25%,1.9,1.7,22.0,108.0,0.55,0.41,9.5,9.5,0.39,0.21,5.0,5.0
50%,2.2,5.2,38.0,134.0,0.62,0.47,10.2,10.4,0.52,0.26,6.0,6.0
75%,2.6,9.9,62.0,167.0,0.73,0.55,11.1,11.4,0.64,0.32,6.0,6.0
max,15.5,65.8,289.0,440.0,2.0,1.08,14.9,14.2,1.58,1.1,8.0,9.0


In [131]:
# One boolean-masked frame per quality band. Unlike DataFrame.where, this drops
# non-matching rows instead of NaN-filling them. Frames and keys are generated
# from the same list so they cannot fall out of step.
quality_levels = ["low", "medium", "high"]
concat_described(data_frames=[wines[wines["quality_label"] == level] for level in quality_levels],
                 subset_attributes=["alcohol", "volatile acidity", "pH", "quality"],
                 keys=[f"quality {level}" for level in quality_levels],
                 drop_count=False)

Unnamed: 0_level_0,alcohol,alcohol,alcohol,volatile acidity,volatile acidity,volatile acidity,pH,pH,pH,quality,quality,quality
Unnamed: 0_level_1,quality low,quality medium,quality high,quality low,quality medium,quality high,quality low,quality medium,quality high,quality low,quality medium,quality high
count,246.0,4974.0,1277.0,246.0,4974.0,1277.0,246.0,4974.0,1277.0,246.0,4974.0,1277.0
mean,10.1843,10.2653,11.4334,0.4652,0.3464,0.2892,3.2348,3.2153,3.2277,3.878,5.5702,7.159
std,0.999,1.0706,1.2156,0.2457,0.1657,0.117,0.1913,0.1595,0.1591,0.3279,0.4951,0.3763
min,8.0,8.0,8.5,0.11,0.08,0.08,2.74,2.72,2.84,3.0,5.0,7.0
25%,9.4,9.4,10.7,0.28,0.23,0.2,3.09,3.11,3.12,4.0,5.0,7.0
50%,10.05,10.0,11.5,0.38,0.3,0.27,3.225,3.2,3.22,4.0,6.0,7.0
75%,10.9,11.0,12.4,0.61,0.42,0.34,3.36,3.32,3.34,4.0,6.0,7.0
max,13.5,14.9,14.2,1.58,1.33,0.915,3.9,4.01,3.82,4.0,6.0,9.0
