<img align="left" src="imgs/fonduer-logo.png" width="100px" style="margin-right:20px">

# Tutorial: Providing Supervision using Labeling Functions

## Running locally?

If you're running this tutorial interactively on your own machine, you'll need to create a new PostgreSQL database named `intro_supervision`.

If you already have the database `intro_supervision` in your PostgreSQL instance, uncomment the first line of the next cell to drop it. Otherwise, download our database snapshots by executing `./download_data.sh` in the intro tutorial directory.

In [None]:
#! dropdb --if-exists intro_supervision
! createdb intro_supervision
! psql intro_supervision < data/intro_supervision.sql > /dev/null

## Providing Supervision by Writing Labeling Functions

In this tutorial, you will learn what a labeling function (LF) is and how to write one by leveraging Fonduer's [data model utilities](https://fonduer.readthedocs.io/en/stable/user/data_model_utils.html).

At a high level, a labeling function is a simple Python function that takes a candidate as input (in these intro tutorials, a candidate is a part number paired with a numerical value) and returns a label for that candidate. A label is one of {-1, 0, 1}: -1 means the LF abstains from voting, 0 labels the candidate False, and 1 labels it True.
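To make the label convention concrete, here is a minimal, framework-free sketch. `ToyCandidate` and its `row_words` field are hypothetical stand-ins for a Fonduer candidate, used only so the example runs without a database; the real LFs later in this notebook receive `PartAttr` candidates and use Fonduer's data model utilities instead.

```python
from dataclasses import dataclass
from typing import List

# Label convention used throughout this tutorial.
ABSTAIN = -1
FALSE = 0
TRUE = 1


@dataclass
class ToyCandidate:
    """Hypothetical stand-in for a (part, value) candidate; not a Fonduer class."""
    part: str
    value: str
    row_words: List[str]  # hypothetical: words from the table row containing the value


def lf_storage_keyword(c: ToyCandidate) -> int:
    """Vote TRUE if 'storage' appears in the value's row; otherwise abstain."""
    return TRUE if "storage" in c.row_words else ABSTAIN


def lf_collector_keyword(c: ToyCandidate) -> int:
    """Vote FALSE if 'collector' appears in the value's row; otherwise abstain."""
    return FALSE if "collector" in c.row_words else ABSTAIN


cand = ToyCandidate(part="PART-001", value="150", row_words=["storage", "temperature"])
print(lf_storage_keyword(cand))    # 1  (votes True)
print(lf_collector_keyword(cand))  # -1 (abstains)
```

Each LF only needs to be confident about the cases it votes on; abstaining on everything else is expected, since other LFs can cover those cases.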

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import logging

ATTRIBUTE = "intro_supervision"
conn_string = f'postgresql://localhost:5432/{ATTRIBUTE}'

from fonduer import Meta, init_logging

# Configure logging for Fonduer
init_logging(log_dir="logs")

session = Meta.init(conn_string).Session()

from fonduer.candidates.models import candidate_subclass, mention_subclass

Part = mention_subclass("Part")
Attr = mention_subclass("Attr")
PartAttr = candidate_subclass("PartAttr", [Part, Attr])

## I. Background

### Using a Development Set to Evaluate our Supervision
For convenience in error analysis and evaluation, we have already annotated the dev and test sets for this tutorial, and we now load these annotations as gold labels using an externally defined helper function, `gold`, from `hardware_utils`. If you're interested in the implementation details, see that script:
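Conceptually, a gold-label function is just an LF backed by human annotations: it checks whether a candidate matches a known-correct tuple. The sketch below is hypothetical and is not the actual `hardware_utils.gold` implementation; the `GOLD` set, its (document, part, value) key format, the upper-casing of part numbers, and `gold_lf` are all illustrative assumptions.

```python
FALSE, TRUE = 0, 1

# Hypothetical annotations: (document name, part number, value) triples known to be correct.
GOLD = {
    ("doc_a", "PART-001", "150"),
    ("doc_b", "PART-002", "175"),
}


def gold_lf(doc_name: str, part: str, value: str) -> int:
    """Return TRUE if the (doc, part, value) triple is annotated as correct, else FALSE.

    Unlike an ordinary LF, this sketch never abstains: every candidate
    receives a definitive label. Part numbers are upper-cased before lookup
    (an assumption about how annotations are normalized).
    """
    key = (doc_name, part.upper(), value)
    return TRUE if key in GOLD else FALSE


print(gold_lf("doc_a", "part-001", "150"))  # 1
print(gold_lf("doc_a", "PART-001", "175"))  # 0
```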

In [None]:
from fonduer.parser.models import Document
from fonduer.supervision.models import GoldLabel
from hardware_utils import gold

from fonduer.supervision import Labeler

docs = session.query(Document).order_by(Document.name).all()
labeler = Labeler(session, [PartAttr])
%time labeler.apply(docs=docs, lfs=[[gold]], table=GoldLabel, train=True)

### Loading Candidates

Next, we get our training candidates by issuing a SQLAlchemy query for the `PartAttr` candidate class we defined during candidate generation.

In [None]:
train_cands = sorted(session.query(PartAttr).all())

print(f"Number of training candidates: {len(train_cands)}")

## II. Writing Labeling Functions

Supervision can come from different sources, such as patterns or heuristics. Fonduer uses labeling functions to encode this supervision as rules that indicate whether a candidate is true or false. In this notebook, we describe how to use the Fonduer API to express supervision using signals from different modalities (textual, structural, tabular, and visual).

The full list of functions you can use is documented here:

https://fonduer.readthedocs.io/en/stable/user/data_model_utils.html

In [None]:
from fonduer.utils.data_model_utils import *

### Recall: what's in a candidate?

In [None]:
cand = train_cands[0]

Let's look at the part number first:

In [None]:
print(f"part object:                      {cand.part}")
print(f"part text:                        {cand.part.context.get_span()}")
print(f"part sentence object:             {cand.part.context.sentence}")
print(f"part sentence text:               {cand.part.context.sentence.text}")
print(f"check if part is in a table:      {cand.part.context.sentence.is_tabular()}")
print(f"check if part has visual info:    {cand.part.context.sentence.is_visual()}")

Then, we can look at the `attr`, which is the numerical value representing the maximum storage temperature:

In [None]:
print(f"attr object:                      {cand.attr}")
print(f"attr text:                        {cand.attr.context.get_span()}")
print(f"attr sentence object:             {cand.attr.context.sentence}")
print(f"attr sentence text:               {cand.attr.context.sentence.text}")
print(f"check if attr is in a table:      {cand.attr.context.sentence.is_tabular()}")
print(f"check if attr has visual info:    {cand.attr.context.sentence.is_visual()}")

### Example 1: Write a labeling function to check whether the two mentions in a candidate are on the same page.
If they are, label the candidate True; otherwise, label it False.

In [None]:
ABSTAIN = -1
FALSE = 0
TRUE = 1

In [None]:
from snorkel.labeling import labeling_function

@labeling_function()
def LF_same_page(c):
    return TRUE if same_page(c) else FALSE

In [None]:
# Sanity check: the previous labeling function should pass the following test.
true_candidate = train_cands[81]
false_candidate = train_cands[10]

if (LF_same_page(true_candidate) == TRUE and LF_same_page(false_candidate) == FALSE):
    print("You passed!")
else:
    print("Try again.")

### Example 2: Write a labeling function based on your insights into the data.

For example, inspecting several documents may reveal that storage temperatures are typically listed inside a table where the row header contains the word "storage". This intuitive pattern can be directly expressed as a labeling function. Similarly, the word "temperature" is an obvious positive signal.


In [None]:
@labeling_function()
def LF_storage_row(c):
    return TRUE if 'storage' in get_row_ngrams(c.attr) else ABSTAIN

@labeling_function()
def LF_temperature_row(c):
    return TRUE if 'temperature' in get_row_ngrams(c.attr) else ABSTAIN

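### Example 3: Write a labeling function based on alignment information.

Table alignment is another useful signal: `get_aligned_ngrams` returns the n-grams from cells in the same row or column as a mention. If the value is aligned with collector- or current-related words, it probably describes a collector or current rating rather than our target attribute, so these LFs vote False; otherwise, they abstain.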

In [None]:
@labeling_function()
def LF_collector_aligned(c):
    return FALSE if overlap(
        ['collector', 'collector-current', 'collector-base', 'collector-emitter'],
        list(get_aligned_ngrams(c.attr))) else ABSTAIN

@labeling_function()
def LF_current_aligned(c):
    ngrams = list(get_aligned_ngrams(c.attr))
    return FALSE if overlap(['current', 'dc', 'ic'], ngrams) else ABSTAIN
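Both LFs rely on `overlap`, which, as we understand it, returns True when its two arguments share at least one element. The plain-Python function `shares_any` below is our own sketch of that check, not Fonduer's source, and `collector_vote` mimics `LF_collector_aligned` on a precomputed list of aligned n-grams. The n-gram lists are invented for illustration.

```python
from typing import Iterable

FALSE, ABSTAIN = 0, -1


def shares_any(keywords: Iterable[str], ngrams: Iterable[str]) -> bool:
    """Return True if at least one keyword also appears among the n-grams."""
    return not set(keywords).isdisjoint(ngrams)


def collector_vote(aligned_ngrams: Iterable[str]) -> int:
    """Mimic LF_collector_aligned on a precomputed list of aligned n-grams."""
    keywords = ["collector", "collector-current", "collector-base", "collector-emitter"]
    return FALSE if shares_any(keywords, aligned_ngrams) else ABSTAIN


# Hypothetical n-grams from the same row/column as two different values.
print(collector_vote(["collector-emitter", "voltage", "vceo"]))  # 0
print(collector_vote(["storage", "temperature", "tstg"]))        # -1
```

Using set disjointness keeps the check linear in the input sizes, which matters little here but is convenient when an LF is applied to many candidates.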

We can then collect all of these labeling functions in a list which we will provide to Fonduer as supervision signals.

In [None]:
LFs = [
    LF_same_page,
    LF_storage_row,
    LF_temperature_row,
    LF_collector_aligned,
    LF_current_aligned
]

### Applying the Labeling Functions

Next, we need to actually run the LFs over all of our training candidates, producing a set of `Labels` and `LabelKeys` (just the names of the LFs) in the database. We'll do this using the `Labeler`. Note that this will delete any existing `Labels` and `LabelKeys` for this candidate set.

View the API provided by the `Labeler` on [ReadTheDocs](https://fonduer.readthedocs.io/en/stable/user/supervision.html#fonduer.supervision.Labeler).
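Conceptually, the result is a label matrix with one row per candidate and one column per LF, where entry (i, j) is LF j's vote on candidate i. The NumPy sketch below builds such a matrix from toy candidates and toy LFs (all hypothetical) purely to illustrate the shape and meaning of `L_train`; Fonduer itself builds the matrix from the labels it stores in the database.

```python
import numpy as np

ABSTAIN, FALSE, TRUE = -1, 0, 1

# Hypothetical candidates: each is just the set of words aligned with its value.
candidates = [
    {"storage", "temperature"},
    {"collector", "current"},
    {"temperature", "junction"},
]


# Hypothetical LFs, each mapping a candidate to TRUE, FALSE, or ABSTAIN.
def lf_storage(c):
    return TRUE if "storage" in c else ABSTAIN


def lf_temperature(c):
    return TRUE if "temperature" in c else ABSTAIN


def lf_current(c):
    return FALSE if "current" in c else ABSTAIN


lfs = [lf_storage, lf_temperature, lf_current]

# Row i = candidate i, column j = vote of lfs[j] on that candidate.
L = np.array([[lf(c) for lf in lfs] for c in candidates])
print(L)
```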

In [None]:
from fonduer.supervision import Labeler

labeler = Labeler(session, [PartAttr])

%time labeler.apply(split=0, lfs=[LFs], train=True)
%time L_train = labeler.get_label_matrices([train_cands])

### Labeling Function Metrics

Next, we can view insights provided by Fonduer to better understand the quality and coverage of our labeling functions.

To summarize the resulting label matrix, we use several metrics to evaluate labeling functions:
* **Polarity** is the set of distinct non-abstain labels the labeling function emits.
* **Coverage** is the fraction of candidates for which the labeling function emits a non-abstain label.
* **Overlaps** is the fraction of candidates for which the labeling function emits a non-abstain label and at least one other labeling function also emits a non-abstain label.
* **Conflicts** is the fraction of candidates for which the labeling function emits a non-abstain label and at least one other labeling function emits a different non-abstain label.
* **Correct** is the number of candidates that the labeling function labels correctly.
* **Incorrect** is the number of candidates that the labeling function labels incorrectly.
* **Empirical Accuracy** is the fraction of the labeling function's non-abstain labels that are correct, i.e., Correct / (Correct + Incorrect).
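To make these definitions concrete, the NumPy sketch below computes coverage, overlaps, conflicts, correct/incorrect counts, and empirical accuracy for a small, hypothetical label matrix `L` (rows are candidates, columns are LFs, `-1` means abstain) and hypothetical gold labels `Y`. It follows the definitions above; Snorkel's `LFAnalysis` computes these for you, so treat `lf_metrics` only as an illustration.

```python
import numpy as np

ABSTAIN = -1

# Hypothetical label matrix: 4 candidates x 3 LFs.
L = np.array([
    [1, 1, -1],
    [-1, -1, 0],
    [-1, 1, 0],
    [0, -1, -1],
])
# Hypothetical gold labels, one per candidate.
Y = np.array([1, 0, 1, 0])


def lf_metrics(L, Y):
    labeled = L != ABSTAIN                       # True where an LF casts a vote
    results = []
    for j in range(L.shape[1]):
        lj = labeled[:, j]
        others = np.delete(labeled, j, axis=1)   # vote masks of the other LFs
        other_votes = np.delete(L, j, axis=1)
        overlap = lj & others.any(axis=1)
        # Conflict: another LF casts a vote that differs from LF j's vote.
        disagree = others & (other_votes != L[:, [j]])
        conflict = lj & disagree.any(axis=1)
        correct = int(np.sum(lj & (L[:, j] == Y)))
        incorrect = int(np.sum(lj & (L[:, j] != Y)))
        total = correct + incorrect
        results.append({
            "coverage": float(lj.mean()),
            "overlaps": float(overlap.mean()),
            "conflicts": float(conflict.mean()),
            "correct": correct,
            "incorrect": incorrect,
            "emp_acc": correct / total if total else float("nan"),
        })
    return results


for j, r in enumerate(lf_metrics(L, Y)):
    print(j, r)
```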

Because we have already loaded the gold labels, we can also view the empirical accuracy of these labeling functions against our gold labels:

In [None]:
L_gold_dev = labeler.get_gold_labels([train_cands], annotator='gold')

In [None]:
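from snorkel.labeling import LFAnalysis

# Fonduer orders the label-matrix columns by LF name, so sort the LFs the same way
# to align each column with its LF; flatten the gold labels into a 1-D array.
LFAnalysis(L=L_train[0], lfs=sorted(LFs, key=lambda lf: lf.name)).lf_summary(Y=L_gold_dev[0].reshape(-1))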