In [1]:
import pandas as pd
import altair as alt
import matplotlib.pyplot as plt

# Download and import the data

In [2]:
# Fetch the raw dataset by running the project's download script.
# The two arguments are passed through as-is: `data` (a folder relative to the
# notebook's working directory, not the filesystem root) and `raw.csv` (the
# file name) -- presumably producing data/raw.csv, which the next cell reads.
!python download_data.py data raw.csv

In [3]:
# Load the downloaded CSV; its first column carries the row labels, so it is
# used as the DataFrame index rather than as a data column.
df = pd.read_csv(
    filepath_or_buffer='data/raw.csv',
    index_col=0,
)

In [4]:
# Bare last expression: renders the frame with pandas' rich display, which
# truncates long frames to the first and last rows.
df

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
0,7.0,0.27,0.36,20.7,0.045,45.0,170.0,1.00100,3.00,0.45,8.8,6
1,6.3,0.30,0.34,1.6,0.049,14.0,132.0,0.99400,3.30,0.49,9.5,6
2,8.1,0.28,0.40,6.9,0.050,30.0,97.0,0.99510,3.26,0.44,10.1,6
3,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.99560,3.19,0.40,9.9,6
4,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.99560,3.19,0.40,9.9,6
...,...,...,...,...,...,...,...,...,...,...,...,...
4893,6.2,0.21,0.29,1.6,0.039,24.0,92.0,0.99114,3.27,0.50,11.2,6
4894,6.6,0.32,0.36,8.0,0.047,57.0,168.0,0.99490,3.15,0.46,9.6,5
4895,6.5,0.24,0.19,1.2,0.041,30.0,111.0,0.99254,2.99,0.46,9.4,6
4896,5.5,0.29,0.30,1.1,0.022,20.0,110.0,0.98869,3.34,0.38,12.8,7


# Check correlation between features

In [5]:
# Build a long-form ("tidy") correlation table: one row per ordered pair of
# columns, holding the numeric correlation plus a 2-decimal text label that
# the heatmap uses for its cell annotations.
cor_data = (
    df.corr()
      .stack()
      .reset_index()
      .rename(columns={0: 'correlation', 'level_0': 'variable', 'level_1': 'variable2'})
      .assign(
          # Round to 2 decimal places, kept as a string for display
          correlation_label=lambda pairs: pairs['correlation'].map('{:.2f}'.format)
      )
)
cor_data

Unnamed: 0,variable,variable2,correlation,correlation_label
0,fixed acidity,fixed acidity,1.000000,1.00
1,fixed acidity,volatile acidity,-0.022697,-0.02
2,fixed acidity,citric acid,0.289181,0.29
3,fixed acidity,residual sugar,0.089021,0.09
4,fixed acidity,chlorides,0.023086,0.02
...,...,...,...,...
139,quality,density,-0.307123,-0.31
140,quality,pH,0.099427,0.10
141,quality,sulphates,0.053678,0.05
142,quality,alcohol,0.435575,0.44


In [6]:
base = alt.Chart(cor_data).encode(
    x='variable2:O',
    y='variable:O'    
)

# Text layer with correlation labels
# Colors are for easier readability
text = base.mark_text().encode(
    text='correlation_label',
    color=alt.condition(
        alt.datum.correlation > 0.5, 
        alt.value('white'),
        alt.value('black')
    )
)

# The correlation heatmap itself
cor_plot = base.mark_rect().encode(
    color='correlation:Q'
)

(cor_plot + text).properties(height=600, width = 600)

From the chart we can see that the feature most strongly correlated with wine quality is "alcohol", with a positive correlation of about 0.44. Other features, including "density" (about -0.31), "chlorides" and "volatile acidity", are weakly negatively correlated with quality.

# Look at distributions of features

In [7]:
# Names of every float64/int64 column except the target ('quality'); each one
# gets its own histogram panel below.
numeric_cols = (
    df.select_dtypes(include=['float64', 'int64'])
      .drop(columns=['quality'])
      .columns
      .tolist()
)

# One binned histogram per feature, laid out in a two-column grid.
# alt.repeat("repeat") is the placeholder that .repeat() fills with each
# column name in turn.
(
    alt.Chart(df)
    .mark_bar()
    .encode(
        x=alt.X(alt.repeat("repeat"), type='quantitative', bin=alt.Bin(maxbins=100)),
        y=alt.Y('count()'),
    )
    .properties(height=100)
    .repeat(repeat=numeric_cols, columns=2)
)