imports

In [1]:
from fonduer.supervision import Labeler
from fonduer.supervision.models import GoldLabel
from fonduer.features import Featurizer
from fonduer.candidates.models import Candidate

from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel
from fonduer.supervision.models import LabelKey

from MeMoKBC.pipeline.utils import get_session, load_candidates, match_label_matrix
from MeMoKBC.definitions.candidates import NameFullAbbr, NameAbbrTask
from MeMoKBC.pipeline.lfs.name_short_long_lfs import short_long_lfs
from MeMoKBC.pipeline.lfs.name_short_task_lfs import name_abbr_task_lfs
from MeMoKBC.gold_label_matcher import match_gold_label
from importlib import reload
import csv
import re

Get session object

In [2]:
# Open the session on the "pipeline6" database. The same session object is handed
# to the Labeler and to the candidate / label-matrix loaders further down.
# NOTE(review): the database name is repeated as a literal in the match_gold_label()
# call below; keep both in sync when switching databases.
session = get_session(db_name="pipeline6")

[2023-07-06 20:22:15,436][INFO] fonduer.meta:49 - Setting logging directory to: /tmp/2023-07-06_20-22-15


Define the candidate classes (the Labeler object is created after the gold labels are loaded)

In [3]:
# Candidate classes processed in this notebook.
# NOTE(review): later cells unpack per-class results as (NFA, NAT), which presumably
# relies on this list order (NameFullAbbr first, NameAbbrTask second) - confirm
# before reordering.
candidates = [NameFullAbbr, NameAbbrTask]

Load gold labels

In [4]:
# Load the gold labels from the JSON annotation file and compare them to the
# candidates of both candidate classes stored in the database.
# NOTE(review): "pipeline6" is passed as a literal here; presumably it has to be
# the same database the session was opened on - confirm when changing either.
gold_labels = match_gold_label(
    "pipeline6",
    "/data/Goldlabel_biomedRxiv/goldlabel1_docs801-840_laura/goldlabel_authorlong_short_task_medRxiv.json",
    [NameAbbrTask, NameFullAbbr]
)

# Split the matched gold candidates by candidate class and keep only their ids.
nat_cands = []
nfa_cands = []
for cand in gold_labels:
    if isinstance(cand, NameAbbrTask):
        # A NameAbbrTask gold candidate is only kept when both of its mentions
        # come from the same sentence.
        if cand[0].context.sentence.id == cand[1].context.sentence.id:
            nat_cands.append(cand.id)
    elif isinstance(cand, NameFullAbbr):
        # NameFullAbbr gold candidates are kept without any sentence filter.
        nfa_cands.append(cand.id)

[2023-07-06 20:22:15,662][INFO] fonduer.meta:134 - Connecting user:postgres to fonduer-postgres-dev:5432/pipeline6
[2023-07-06 20:22:15,664][INFO] fonduer.meta:162 - Initializing the storage schema
[2023-07-06 20:22:15,965][INFO] root:88 - Found relations for 22 documents
[2023-07-06 20:22:16,323][INFO] root:93 - Found 6820 candidates for <class 'fonduer.candidates.models.candidate.NameAbbrTask'>
[2023-07-06 20:22:16,886][INFO] root:93 - Found 28105 candidates for <class 'fonduer.candidates.models.candidate.NameFullAbbr'>
[2023-07-06 20:22:17,300][INFO] root:102 - Found candidates for 40 documents


In [5]:
# Create the Labeler for the candidate classes listed in `candidates`; it is used
# below both to store the gold labels and to read them back.
labeler = Labeler(session, candidates)

# Gold-label ids converted to sets once, so that the membership test inside
# gold() is a hash lookup instead of a scan over a list of ids for every candidate.
_NAT_GOLD_IDS = frozenset(nat_cands)
_NFA_GOLD_IDS = frozenset(nfa_cands)


def gold(c: Candidate) -> int:
    """Return the gold label of a candidate.

    Args:
        c: The candidate to label.

    Returns:
        1 if the candidate's id is among the gold-label ids collected for its
        class (``nat_cands`` for NameAbbrTask, ``nfa_cands`` for NameFullAbbr),
        otherwise 0. A candidate of any other class is labelled 0.
    """
    if isinstance(c, NameAbbrTask):
        return int(c.id in _NAT_GOLD_IDS)
    if isinstance(c, NameFullAbbr):
        return int(c.id in _NFA_GOLD_IDS)
    # Neither of the two known candidate classes: there is no gold label for it.
    return 0

# Store the gold labels of both candidate classes for splits 0, 1 and 2.
# The gold labels already stored for a split are cleared before it is rewritten.
for split_idx in (0, 1, 2):
    labeler.apply(lfs=[[gold], [gold]], table=GoldLabel, split=split_idx, clear=True)

  .filter(candidate_class.id.in_(sub_query))
[2023-07-06 20:22:31,295][INFO] fonduer.supervision.labeler:330 - Clearing Labels (split 0)
  query = self.session.query(table).filter(table.candidate_id.in_(sub_query))
[2023-07-06 20:22:31,335][INFO] fonduer.utils.udf:67 - Running UDF...


  0%|          | 0/14 [00:00<?, ?it/s]

[2023-07-06 20:22:36,396][INFO] fonduer.supervision.labeler:330 - Clearing Labels (split 1)
[2023-07-06 20:22:36,444][INFO] fonduer.utils.udf:67 - Running UDF...


  0%|          | 0/13 [00:00<?, ?it/s]

[2023-07-06 20:22:43,418][INFO] fonduer.supervision.labeler:330 - Clearing Labels (split 2)
[2023-07-06 20:22:43,482][INFO] fonduer.utils.udf:67 - Running UDF...


  0%|          | 0/13 [00:00<?, ?it/s]

In [6]:
# Load the candidates of split 0 for the candidate classes in `candidates`
train_cands = load_candidates(session, 0, candidates)

# Match the candidates with the outcome of the labeling functions for split 0;
# the result is used as the label-matrix input of the LF analysis below.
# NOTE(review): the (NFA, NAT) unpacking presumably follows the order of
# `candidates` - confirm against match_label_matrix.
L_train_NFA, L_train_NAT = match_label_matrix(session, candidates, 0) 

# Load the gold labels of the loaded candidates, unpacked in the same (NFA, NAT) order
L_gold_train_NFA, L_gold_train_NAT = labeler.get_gold_labels(train_cands)

## LF analysis

NameFull + Abbreviation

In [7]:
# Labeling functions for the NameFullAbbr candidates, ordered by their name.
nfa_lfs_by_name = sorted(short_long_lfs, key=lambda lf: lf.name)
nfa_analysis = LFAnalysis(L_train_NFA, lfs=nfa_lfs_by_name)
# Summary of every labeling function, scored against the flattened gold labels.
nfa_analysis.lf_summary(Y=L_gold_train_NFA.reshape(-1))

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
check_all_abbr_letters_in_long,0,"[0, 1]",1.0,1.0,0.140079,4511,579,0.886248
check_all_uppercase_letters,1,[0],0.145187,0.145187,0.140079,605,134,0.818674
check_horizont_abr_short,2,[0],0.001768,0.001768,0.0,9,0,1.0
check_long_name_not_upper,3,[0],0.092731,0.092731,0.021611,472,0,1.0
check_uppercase_letters,4,[0],0.132417,0.132417,0.131041,540,134,0.801187
check_uppercase_letters_short_in_long,5,[0],0.132417,0.132417,0.131041,540,134,0.801187
name_full_in_top_percentile_sentence_wise,6,[0],1.0,1.0,0.140079,4956,134,0.973674
name_short_outside_half_percentile_sentence_wise,7,[0],0.456974,0.456974,0.077996,2251,75,0.967756
small_letter_count,8,[0],0.893713,0.893713,0.140079,4415,134,0.970543
word_count,9,[0],0.97721,0.97721,0.140079,4840,134,0.97306


NameAbbr + Task

In [8]:
# Labeling functions for the NameAbbrTask candidates, ordered by their name.
nat_lfs_by_name = sorted(name_abbr_task_lfs, key=lambda lf: lf.name)
nat_analysis = LFAnalysis(L_train_NAT, lfs=nat_lfs_by_name)
# Summary of every labeling function, scored against the flattened gold labels.
nat_analysis.lf_summary(Y=L_gold_train_NAT.reshape(-1))

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
abbr_is_complete,0,[0],0.190668,0.190668,0.006336,331,0,1.0
common_verbs_following_abbr,1,[1],0.025922,0.025922,0.025922,13,32,0.288889
is_medical_abbreviation,2,[0],1.0,0.459101,0.253456,1627,109,0.937212
lf_length_more_than_three_words,3,[0],0.020737,0.020737,0.010369,36,0,1.0
lf_name_short_in_first_words,4,[1],0.152074,0.152074,0.152074,35,229,0.132576
sentence_beginning,5,[1],0.120392,0.120392,0.120392,44,165,0.210526
verbs_ending_with_past,6,[1],0.047811,0.047811,0.047811,25,58,0.301205
word_before_abbr,7,[0],0.013249,0.013249,0.000576,23,0,1.0
