In [1]:
# Libraries
%matplotlib inline

from sklearn import datasets
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import metrics
from snorkel.labeling import labeling_function
from snorkel.labeling import PandasLFApplier
from snorkel.labeling import LFAnalysis

In [2]:
wine = datasets.load_wine()
x = wine.data
y = wine.target
col_names = wine.feature_names
class_names = wine.target_names
print(class_names)
print(col_names)

['class_0' 'class_1' 'class_2']
['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']


In [3]:
print(wine.DESCR)

.. _wine_dataset:

Wine recognition dataset
------------------------

**Data Set Characteristics:**

    :Number of Instances: 178 (50 in each of three classes)
    :Number of Attributes: 13 numeric, predictive attributes and the class
    :Attribute Information:
 		- Alcohol
 		- Malic acid
 		- Ash
		- Alcalinity of ash  
 		- Magnesium
		- Total phenols
 		- Flavanoids
 		- Nonflavanoid phenols
 		- Proanthocyanins
		- Color intensity
 		- Hue
 		- OD280/OD315 of diluted wines
 		- Proline

    - class:
            - class_0
            - class_1
            - class_2
		
    :Summary Statistics:
    
                                   Min   Max   Mean     SD
    Alcohol:                      11.0  14.8    13.0   0.8
    Malic Acid:                   0.74  5.80    2.34  1.12
    Ash:                          1.36  3.23    2.36  0.27
    Alcalinity of Ash:            10.6  30.0    19.5   3.3
    Magnesium:                    70.0 162.0    99.7  14.3
    Total Phenols:                0

In [4]:
wine_df_nolbl = pd.DataFrame(data=x,columns=col_names)
wine_df_lbl = wine_df_nolbl.copy()
wine_df_lbl['label'] = y
print(wine_df_lbl.sample(7))

     alcohol  malic_acid   ash  alcalinity_of_ash  magnesium  total_phenols  \
28     13.87        1.90  2.80               19.4      107.0           2.95   
145    13.16        3.57  2.15               21.0      102.0           1.50   
37     13.05        1.65  2.55               18.0       98.0           2.45   
80     12.00        0.92  2.00               19.0       86.0           2.42   
141    13.36        2.56  2.35               20.0       89.0           1.40   
10     14.10        2.16  2.30               18.0      105.0           2.95   
142    13.52        3.17  2.72               23.5       97.0           1.55   

     flavanoids  nonflavanoid_phenols  proanthocyanins  color_intensity   hue  \
28         2.97                  0.37             1.76             4.50  1.25   
145        0.55                  0.43             1.30             4.00  0.60   
37         2.43                  0.29             1.44             4.25  1.12   
80         2.26                  0.30      

In [23]:
train_df, _ = train_test_split(wine_df_nolbl, test_size = 0.3, random_state = 42)
_, test_df = train_test_split(wine_df_lbl, test_size = 0.3, random_state = 42)
test_df, valid_df = train_test_split(test_df, test_size = 0.3, random_state = 42)
train_df.sample(10)

Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline
6,14.39,1.87,2.45,14.6,96.0,2.5,2.52,0.3,1.98,5.25,1.02,3.58,1290.0
161,13.69,3.26,2.54,20.0,107.0,1.83,0.56,0.5,0.8,5.88,0.96,1.82,680.0
92,12.69,1.53,2.26,20.7,80.0,1.38,1.46,0.58,1.62,3.05,0.96,2.06,495.0
144,12.25,3.88,2.2,18.5,112.0,1.38,0.78,0.29,1.14,8.21,0.65,2.0,855.0
170,12.2,3.03,2.32,19.0,96.0,1.25,0.49,0.4,0.73,5.5,0.66,1.83,510.0
14,14.38,1.87,2.38,12.0,102.0,3.3,3.64,0.29,2.96,7.5,1.2,3.0,1547.0
95,12.47,1.52,2.2,19.0,162.0,2.5,2.27,0.32,3.28,2.6,1.16,2.63,937.0
93,12.29,2.83,2.22,18.0,88.0,2.45,2.25,0.25,1.99,2.15,1.15,3.3,290.0
77,11.84,2.89,2.23,18.0,112.0,1.72,1.32,0.43,0.95,2.65,0.96,2.52,500.0
4,13.24,2.59,2.87,21.0,118.0,2.8,2.69,0.39,1.82,4.32,1.04,2.93,735.0


In [6]:
test_df.sample(5)

Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline,label
24,13.5,1.81,2.61,20.0,96.0,2.53,2.61,0.28,1.66,3.52,1.12,3.82,845.0,0
60,12.33,1.1,2.28,16.0,101.0,2.05,1.09,0.63,0.41,3.27,1.25,1.67,680.0,1
128,12.37,1.63,2.3,24.5,88.0,2.22,2.45,0.4,1.9,2.12,0.89,2.78,342.0,1
56,14.22,1.7,2.3,16.3,118.0,3.2,3.0,0.26,2.03,6.38,0.94,3.31,970.0,0
2,13.16,2.36,2.67,18.6,101.0,2.8,3.24,0.3,2.81,5.68,1.03,3.17,1185.0,0


In [7]:
valid_df.sample(5)

Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline,label
12,13.75,1.73,2.41,16.0,89.0,2.6,2.76,0.29,1.81,5.6,1.15,2.9,1320.0,0
26,13.39,1.77,2.62,16.1,93.0,2.85,2.94,0.34,1.45,4.8,0.92,3.22,1195.0,0
109,11.61,1.35,2.7,20.0,94.0,2.74,2.92,0.29,2.49,2.65,0.96,3.26,680.0,1
143,13.62,4.95,2.35,20.0,92.0,2.0,0.8,0.47,1.02,4.4,0.91,2.05,550.0,2
154,12.58,1.29,2.1,20.0,103.0,1.48,0.58,0.53,1.4,7.6,0.58,1.55,640.0,2


In [8]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 124 entries, 138 to 102
Data columns (total 13 columns):
alcohol                         124 non-null float64
malic_acid                      124 non-null float64
ash                             124 non-null float64
alcalinity_of_ash               124 non-null float64
magnesium                       124 non-null float64
total_phenols                   124 non-null float64
flavanoids                      124 non-null float64
nonflavanoid_phenols            124 non-null float64
proanthocyanins                 124 non-null float64
color_intensity                 124 non-null float64
hue                             124 non-null float64
od280/od315_of_diluted_wines    124 non-null float64
proline                         124 non-null float64
dtypes: float64(13)
memory usage: 13.6 KB


In [9]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 37 entries, 171 to 98
Data columns (total 14 columns):
alcohol                         37 non-null float64
malic_acid                      37 non-null float64
ash                             37 non-null float64
alcalinity_of_ash               37 non-null float64
magnesium                       37 non-null float64
total_phenols                   37 non-null float64
flavanoids                      37 non-null float64
nonflavanoid_phenols            37 non-null float64
proanthocyanins                 37 non-null float64
color_intensity                 37 non-null float64
hue                             37 non-null float64
od280/od315_of_diluted_wines    37 non-null float64
proline                         37 non-null float64
label                           37 non-null int32
dtypes: float64(13), int32(1)
memory usage: 4.2 KB


In [10]:
valid_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 17 entries, 114 to 137
Data columns (total 14 columns):
alcohol                         17 non-null float64
malic_acid                      17 non-null float64
ash                             17 non-null float64
alcalinity_of_ash               17 non-null float64
magnesium                       17 non-null float64
total_phenols                   17 non-null float64
flavanoids                      17 non-null float64
nonflavanoid_phenols            17 non-null float64
proanthocyanins                 17 non-null float64
color_intensity                 17 non-null float64
hue                             17 non-null float64
od280/od315_of_diluted_wines    17 non-null float64
proline                         17 non-null float64
label                           17 non-null int32
dtypes: float64(13), int32(1)
memory usage: 1.9 KB


In [11]:
# For clarity, we define constants to represent the class labels for spam, ham, and abstaining.
ABSTAIN = -1
WINE_0 = 0
WINE_1 = 1
WINE_2 = 2 

In [12]:
@labeling_function()
def dt_rules(x):
    if x["color_intensity"] <= 3.63:
        return WINE_1
    else:
        if x["flavanoids"] < 1.58:
            if x["ash"] <= 2.06:
                return WINE_1
            else:
                return WINE_2
        else:
            if x["proline"] <= 697.5:
                return WINE_1
            else:
                return WINE_0

In [13]:
lfs = [dt_rules]

applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=train_df)
L_test = applier.apply(df=test_df)

  from pandas import Panel
100%|█████████████████████████████████████████████████████████████████████████████| 124/124 [00:00<00:00, 15541.42it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 37/37 [00:00<00:00, 12366.66it/s]


In [14]:
L_train

array([[2],
       [1],
       [1],
       [0],
       [1],
       [0],
       [2],
       [1],
       [1],
       [2],
       [0],
       [0],
       [0],
       [2],
       [0],
       [0],
       [1],
       [2],
       [1],
       [0],
       [2],
       [1],
       [0],
       [2],
       [1],
       [1],
       [0],
       [1],
       [0],
       [0],
       [1],
       [1],
       [0],
       [2],
       [0],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [2],
       [2],
       [0],
       [1],
       [2],
       [2],
       [1],
       [1],
       [0],
       [1],
       [2],
       [2],
       [1],
       [2],
       [1],
       [1],
       [1],
       [0],
       [0],
       [2],
       [0],
       [2],
       [0],
       [0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [1],
       [1],
       [1],
       [2],
       [1],
       [1],
       [1],
       [2],
       [2],
       [1],
       [0],
       [0],
       [1],
    

In [15]:
L_test

array([[2],
       [0],
       [0],
       [1],
       [1],
       [2],
       [1],
       [0],
       [2],
       [2],
       [2],
       [0],
       [1],
       [1],
       [1],
       [0],
       [0],
       [1],
       [0],
       [1],
       [2],
       [1],
       [2],
       [0],
       [2],
       [1],
       [1],
       [1],
       [0],
       [2],
       [0],
       [2],
       [1],
       [0],
       [1],
       [0],
       [1]])

In [16]:
LFAnalysis(L=L_train, lfs=lfs).lf_summary()

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts
dt_rules,0,"[0, 1, 2]",1.0,0.0,0.0


In [17]:
test_df["label"]

171    2
2      0
31     0
76     1
111    1
141    2
113    1
29     0
158    2
164    2
150    2
19     0
122    1
65     1
128    1
38     0
55     0
100    1
45     0
66     1
140    2
108    1
159    2
42     0
169    2
68     1
24     0
60     1
9      0
153    2
18     0
174    2
85     1
15     0
90     1
56     0
98     1
Name: label, dtype: int32

In [18]:
print(L_test)

[[2]
 [0]
 [0]
 [1]
 [1]
 [2]
 [1]
 [0]
 [2]
 [2]
 [2]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [2]
 [1]
 [2]
 [0]
 [2]
 [1]
 [1]
 [1]
 [0]
 [2]
 [0]
 [2]
 [1]
 [0]
 [1]
 [0]
 [1]]


In [19]:
metrics.accuracy_score(L_test,test_df["label"])

0.972972972972973

In [20]:
@labeling_function()
def dt_rules2(x):
    if x["color_intensity"] <= 3.63:
        return WINE_1
    else:
        if x["flavanoids"] < 1.58:
            if x["ash"] <= 2.06:
                return WINE_1
            else:
                return WINE_2
        else:
            if x["proline"] <= 697.5:
                return WINE_1
            else:
                return WINE_0

In [21]:
lfs = [dt_rules, dt_rules2]

applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=train_df)
L_test = applier.apply(df=test_df)

100%|█████████████████████████████████████████████████████████████████████████████| 124/124 [00:00<00:00, 17770.05it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 37/37 [00:00<00:00, 9271.12it/s]


In [22]:
LFAnalysis(L=L_train, lfs=lfs).lf_summary()

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts
dt_rules,0,"[0, 1, 2]",1.0,1.0,0.0
dt_rules2,1,"[0, 1, 2]",1.0,1.0,0.0
