In [1]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

In [2]:
def split_wine(df):
    '''
    This function performs split on WineQT data, stratify quality.
    Returns train, validate, and test dfs.
    '''
    train_validate, test = train_test_split(df, test_size=.2, 
                                        random_state=123, 
                                        stratify=df.quality)
    train, validate = train_test_split(train_validate, test_size=.3, 
                                   random_state=123, 
                                   stratify=train_validate.quality)
    return train, validate, test

In [3]:
df = pd.read_csv('WineQT.csv')

In [4]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1143 entries, 0 to 1142
Data columns (total 13 columns):
 #   Column                Non-Null Count  Dtype  
---  ------                --------------  -----  
 0   fixed acidity         1143 non-null   float64
 1   volatile acidity      1143 non-null   float64
 2   citric acid           1143 non-null   float64
 3   residual sugar        1143 non-null   float64
 4   chlorides             1143 non-null   float64
 5   free sulfur dioxide   1143 non-null   float64
 6   total sulfur dioxide  1143 non-null   float64
 7   density               1143 non-null   float64
 8   pH                    1143 non-null   float64
 9   sulphates             1143 non-null   float64
 10  alcohol               1143 non-null   float64
 11  quality               1143 non-null   int64  
 12  Id                    1143 non-null   int64  
dtypes: float64(11), int64(2)
memory usage: 116.2 KB


In [5]:
train, validate, test = split_wine(df)

In [6]:
train

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,Id
240,12.5,0.280,0.54,2.3,0.082,12.0,29.0,0.99970,3.11,1.36,9.800000,7,339
818,9.6,0.420,0.35,2.1,0.083,17.0,38.0,0.99622,3.23,0.66,11.100000,6,1153
1026,7.3,0.670,0.02,2.2,0.072,31.0,92.0,0.99566,3.32,0.68,11.066667,6,1439
692,7.4,0.580,0.00,2.0,0.064,7.0,11.0,0.99562,3.45,0.58,11.300000,6,985
159,8.6,0.645,0.25,2.0,0.083,8.0,28.0,0.99815,3.28,0.60,10.000000,6,223
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1011,7.7,0.640,0.21,2.2,0.077,32.0,133.0,0.99560,3.27,0.45,9.900000,5,1419
783,10.8,0.470,0.43,2.1,0.171,27.0,66.0,0.99820,3.17,0.76,10.800000,6,1109
558,8.3,0.430,0.30,3.4,0.079,7.0,34.0,0.99788,3.36,0.61,10.500000,5,778
976,8.2,0.885,0.20,1.4,0.086,7.0,31.0,0.99460,3.11,0.46,10.000000,5,1376


In [7]:
validate

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,Id
239,7.8,0.430,0.32,2.8,0.080,29.0,58.0,0.99740,3.31,0.64,10.3,5,337
800,7.2,0.480,0.07,5.5,0.089,10.0,18.0,0.99684,3.37,0.68,11.2,7,1133
184,7.9,0.330,0.23,1.7,0.077,18.0,45.0,0.99625,3.29,0.65,9.3,5,260
871,7.8,0.815,0.01,2.6,0.074,48.0,90.0,0.99621,3.38,0.62,10.8,5,1231
411,9.9,0.500,0.24,2.3,0.103,6.0,14.0,0.99780,3.34,0.52,10.0,4,576
...,...,...,...,...,...,...,...,...,...,...,...,...,...
78,8.4,0.620,0.09,2.2,0.084,11.0,108.0,0.99640,3.15,0.66,9.8,5,111
460,6.7,0.420,0.27,8.6,0.068,24.0,148.0,0.99480,3.16,0.57,11.3,6,649
1090,7.9,0.290,0.49,2.2,0.096,21.0,59.0,0.99714,3.31,0.67,10.1,6,1528
731,6.9,0.490,0.19,1.7,0.079,13.0,26.0,0.99547,3.38,0.64,9.8,6,1041


In [8]:
test

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,Id
905,7.0,0.690,0.00,1.9,0.114,3.0,10.0,0.99636,3.35,0.60,9.7,6,1277
88,8.1,1.330,0.00,1.8,0.082,3.0,12.0,0.99640,3.54,0.48,10.9,5,127
541,8.1,0.870,0.00,2.2,0.084,10.0,31.0,0.99656,3.25,0.50,9.8,5,757
714,8.0,0.180,0.37,0.9,0.049,36.0,109.0,0.99007,2.89,0.44,12.7,6,1018
259,12.8,0.615,0.66,5.8,0.083,7.0,42.0,1.00220,3.07,0.73,10.0,7,364
...,...,...,...,...,...,...,...,...,...,...,...,...,...
300,9.1,0.520,0.33,1.3,0.070,9.0,30.0,0.99780,3.24,0.60,9.3,5,428
656,8.4,0.670,0.19,2.2,0.093,11.0,75.0,0.99736,3.20,0.59,9.2,4,927
723,6.4,0.795,0.00,2.2,0.065,28.0,52.0,0.99378,3.49,0.52,11.6,5,1027
1138,6.3,0.510,0.13,2.3,0.076,29.0,40.0,0.99574,3.42,0.75,11.0,6,1592


In [11]:
df['is_viable'] = df['quality']>5

In [12]:
df

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality,Id,is_viable
0,7.4,0.700,0.00,1.9,0.076,11.0,34.0,0.99780,3.51,0.56,9.4,5,0,False
1,7.8,0.880,0.00,2.6,0.098,25.0,67.0,0.99680,3.20,0.68,9.8,5,1,False
2,7.8,0.760,0.04,2.3,0.092,15.0,54.0,0.99700,3.26,0.65,9.8,5,2,False
3,11.2,0.280,0.56,1.9,0.075,17.0,60.0,0.99800,3.16,0.58,9.8,6,3,True
4,7.4,0.700,0.00,1.9,0.076,11.0,34.0,0.99780,3.51,0.56,9.4,5,4,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1138,6.3,0.510,0.13,2.3,0.076,29.0,40.0,0.99574,3.42,0.75,11.0,6,1592,True
1139,6.8,0.620,0.08,1.9,0.068,28.0,38.0,0.99651,3.42,0.82,9.5,6,1593,True
1140,6.2,0.600,0.08,2.0,0.090,32.0,44.0,0.99490,3.45,0.58,10.5,5,1594,False
1141,5.9,0.550,0.10,2.2,0.062,39.0,51.0,0.99512,3.52,0.76,11.2,6,1595,True
