# 🐭 Weakly supervised NER with `skweak`

### Important Note: This is a draft. The description below is outdated and will be simplified

This tutorial will walk you through the process of using Rubrix to improve weak supervision and data programming workflows with the [skweak library](https://github.com/NorskRegnesentral/skweak).

- Using _skweak_ and _spaCy_, we define heuristic labeling functions for the [CoNLL 2003](https://arxiv.org/abs/cs/0306050v1) dataset.
- We combine those with a pretrained NER model and use an aggregation model from skweak to obtain noisy NER annotations.
- We then log the documents to Rubrix and visualize the results via its web app.
- With the noisy labels, we fine-tune a spaCy NER model.
- Adding labeling functions from gazetteers to our aggregation model, we revise the updated noisy annotation with Rubrix, and retrain the spaCy model.
- Instead of a spaCy model, we fine-tune a transformers model with the help of the _simpletransformers_ library.


![Visualization of the skweak aggregation model in Rubrix](../_static/tutorials/skweak/skweak_1.png)

## Introduction

Our goal is to show you how you can incorporate Rubrix into data programming workflows to programatically build training data with a human-in-the-loop approach. We will use the [skweak](https://github.com/NorskRegnesentral/skweak) library.

### What is weak supervision? and skweak?
Weak supervision is a branch of machine learning based on getting lower quality labels more efficiently. We can achieve this by using [skweak](https://github.com/NorskRegnesentral/skweak), a library for programmatically building and managing training datasets without manual labeling.

### This tutorial

In this tutorial, we bring content from the [Quick Start Named-Entity Recognition](https://github.com/NorskRegnesentral/skweak/blob/main/examples/quick_start.ipynb) and the [Step-by-step NER](https://github.com/NorskRegnesentral/skweak/tree/main/examples/ner) tutorials from skweak’s documentation and show you how to extend weak supervision workflows with Rubrix.

We will take records from the CoNLL 2003 dataset and build our own annotations with `skweak`. Then we are going to evaluate NER models trained on our annotations on the standard development set of CoNLL 2003.

## Setup
Rubrix, is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ [Github repository](https://github.com/recognai/rubrix).

If you have not installed and launched Rubrix yet, check the [Setup and Installation guide](https://docs.rubrix.ml/en/latest/getting_started/setup%26installation.html).

For this tutorial we also need some third party libraries that can be installed via pip:

In [1]:
%pip install skweak spacy simpletransformers -qqq
!python -m spacy download en_core_web_sm -qqq
!python -m spacy download en_core_web_md -qqq

Note: you may need to restart the kernel to use updated packages.
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_md')


## Named Entity Recognition with skweak and Rubrix

Rubrix allows you to log and track data for different NLP tasks (such as `Token Classification` or `Text Classification`). 

In this tutorial, we will use the English portion of the [CoNLL 2003](https://arxiv.org/abs/cs/0306050v1) dataset, a standard Named Entity Recognition benchmark.

### The dataset

In this tutorial we'll be using skweak's data programming methods for programatically building a training set with the help of Rubrix for analizing and reviewing data. We'll then train a model with this training set.

Although the gold labels for the training set of CoNLL 2003 are already known, we will purposefully ignore them, as our goal in this tutorial is to build our own annotations and see how well they perform on the development set.

We will load the CoNLL 2003 dataset with the help of the `datasets` library.  

### Preprocessing

In [2]:
from datasets import load_dataset
import numpy as np

def convert_ner_tags(record, tag_set=None):
    record['ner_tags'] = [ tag_set[x] for x in record['ner_tags'] ]
    return record

In [3]:
# Remove all annotations other than ORG from the dataset.
# In the future, users won't have to run this code as the dataset will be already saved to the Rubrix repository.
# This section will become an appendix. 

from collections import defaultdict

dataset = load_dataset("conll2003")

tag_set = defaultdict(lambda: 'O')
tag_set[3] = 'B-ORG'
tag_set[4] = 'I-ORG'

entity_set = {}

dataset = dataset\
    .map(convert_ner_tags, fn_kwargs={"tag_set": tag_set})

Reusing dataset conll2003 (/home/user/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/user/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-5e92d269428fa335.arrow


  0%|          | 0/3250 [00:00<?, ?ex/s]

  0%|          | 0/3453 [00:00<?, ?ex/s]

In [4]:
# List of regex patterns to match entities on the dataset.

import pandas as pd
import re

def compile_expressions(word_list, patterns, case_sensitive=True):
    if word_list:
        if case_sensitive:
            word_str = "|".join(word_list)
        else:
            word_str = "|".join([ "(?i:{0})".format(x) for x in word_list ])
    
        formatted_patterns = [ x.format(word_str) for x in patterns ]
    else:
        formatted_patterns = patterns
        
    expressions = [
        re.compile(x) for x in formatted_patterns
    ]
    
    return expressions

expression_dict = {}

location_list = [
    "Finland", "France", "Israel", 
    "Japan", "Chicago", "China", 
    "Cleveland", "Colorado", "Hong Kong", 
    "Istanbul", "Japan", "Kansas",
    "Kurdistan", "Palestine", "Pakistan",
    "London", "New York", "New Zealand",
    "Washington", "India", "Colombia",
    "San Francisco"
]
location_patterns = [ 
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from))*", # "New York Stock Exchange"
]

expression_dict["org_location"] = compile_expressions(location_list, location_patterns)

demonym_list = [
        "International", "Basque", "Iraqi",
        "Islamic", "Japanese", "Jordanian",
        "Lebanese", "Palestinian", "Welsh",
        "French", "British"
]
demonym_patterns = [
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from|on))*" # "Islamic Revolutionary Court"
]
expression_dict["org_demonym"] = compile_expressions(demonym_list, demonym_patterns)

numbers_list = [
    "First", "Second", "Third", "Fourth", "Fifth",
    "One", "Two", "Three", "Four", "Five"
]
number_patterns = [
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from|on))*" # "Fourth World Conference on Women"
]
expression_dict["org_number"] = compile_expressions(numbers_list, number_patterns)

general_orgs = [
    "Office", "Department", "Comission",
    "Association", "Corporation", "Army",
    "Party", "Exchange", "Council", 
    "University", "Federation", "Bank",
    "Government", "Journal", "Newsroom",
    "Newsdesk", "Bureau", "Organisation",
    "Organization", "Comission", "Council",
    "Group", "House", "Reuters"
]
general_orgs_patterns = [
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from|on))*" # U.S. Department of Health and Human Services
]
expression_dict["org_general"] = compile_expressions(general_orgs, general_orgs_patterns)

org_adjectives = [
    "Civil", "Democractic", "Federal",
    "Republican", "National", "Revolutionary",
    "New", "State", "United"
]
org_adjectives_patterns = [
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from|on))*"
]
expression_dict["org_adjectives"] = compile_expressions(org_adjectives, org_adjectives_patterns)


still working on this one 
org_abbreviation = [
    "U\.N\.", "U\.S\.", "St", 
    "St\."
]
org_abbreviation_patterns = [
    "({0})(?:\s+([A-Z][a-z]*|and|of|the|from|on))*"
]
expression_dict['org_abbreviation'] = compile_expressions(org_abbreviation, org_abbreviation_patterns, case_sensitive=False)

generic_patterns = [
    "[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\s+\(\s*[A-Z]*\s*\)",
    "[A-Z]+(?:\s+[A-Z][a-z]+)*",
    "[A-Z][a-z]*shire"
]
expression_dict['org_generic'] = compile_expressions([], generic_patterns)

In [5]:
import skweak
from functools import partial
from skweak.base import CombinedAnnotator

## Here we transform the regexes into annotators encapsulated by a skweak CombinedAnnotator object. 

def regex_detector_fun(doc, expression=None, label=None):
    for match in re.finditer(expression, doc.text):
        char_start, char_end = match.span()
        span = doc.char_span(char_start, char_end)
        if span:
            yield span.start, span.end, label

functions_dict = defaultdict(list)

for key, iterable in expression_dict.items():
    for item in iterable:
        functions_dict[key].append(
            partial(regex_detector_fun, expression=item, label=key)
        )

regex_annotators = CombinedAnnotator()

for key, iterable in functions_dict.items():
    for idx, func in enumerate(iterable):
        annotator = skweak.heuristics.FunctionAnnotator(
            "{0}_{1}".format(key, idx), 
            func
        )
        regex_annotators.add_annotator(annotator)

We will now convert the training and validation splits of our dataset into spaCy Doc objects. 

spaCy demands strings to be given as inputs to a tokenizer. However, as our dataset is already tokenized, we bypass this restriction by using our own tokenizer and encapsulating our tokens in a class that inherits from `str`.

In [6]:
import spacy
from spacy.tokens import Doc
from dataclasses import dataclass

@dataclass
class Record(str):
    tokens: list

def custom_tokenizer(text):
    return Doc(nlp.vocab, text.tokens)

nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer"])
nlp.tokenizer = custom_tokenizer

training_set = [ Record(x) for x in dataset["train"]["tokens"] ]
dev_set = [ Record(x) for x in dataset["validation"]["tokens"] ]

train_docs = list(nlp.pipe(training_set))
dev_docs = list(nlp.pipe(dev_set))

In [7]:
from spacy.tokens import Span
from itertools import groupby
from dataclasses import dataclass
import copy

@dataclass
class IndexedLabel:
    index: int
    label: str

def annotate_labels_to_doc(doc, labels, null_label="O"):
    labels = [ IndexedLabel(idx, item) for idx, item in enumerate(labels) ]
    grouped_labels = [ list(group[1]) for group in groupby(labels) ]
    span_objects = [ Span(doc, item[0].index, item[-1].index + 1, item[0].label) for item in grouped_labels ]
    span_objects = [ span for span in span_objects if span.label_ != null_label ]
    doc.set_ents(span_objects)
    return doc

# I'm stripping the BI tags here as they are not needed for creating spacy Doc objects. 
dev_labels = [ [y.lstrip('B-').lstrip('I-') for y in x] for x in dataset["validation"]["ner_tags"] ]

for idx, label_sequence in enumerate(dev_labels):
    dev_docs[idx] = annotate_labels_to_doc(dev_docs[idx], label_sequence)

In [8]:
from skweak.spacy import ModelAnnotator
from typing import Iterable, Tuple
from spacy.tokens import Doc

# This is the NER model trained on OntoNotes that annotates ORGs as COMPANY, ORG or OTHER_ORG.
# Notice that its output tags will be org_ontonotes_COMPANY, org_ontonotes_ORG, and org_ontonotes_OTHER_ORG. 

class FilteredModelAnnotator(ModelAnnotator):
    def __init__(self, allowed_labels, prefix, *args, **kwargs):
        self.allowed_labels = allowed_labels
        self.prefix = prefix
        super().__init__(*args, **kwargs)
        
    def find_spans(self, doc: Doc) -> Iterable[Tuple[int, int, str]]:
        """Annotates one single document using the Spacy NER model"""
        for span_start, span_end, span_label in super().find_spans(doc):
            if span_label in self.allowed_labels:
                yield span_start, span_end, span_label

    def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
        """Annotates the stream of documents based on the Spacy model"""
        for doc in super().pipe(docs):
            filtered_spans = []
            for span in doc.spans[self.name]:
                if span.label_ in self.allowed_labels:
                    if self.prefix:
                        output_label = self.prefix + span.label_
                    else:
                        output_label = span.label_
                    filtered_spans.append(Span(doc, span.start, span.end, output_label))
            doc.spans[self.name] = filtered_spans
            
            yield doc

allowed_labels = ["COMPANY", "ORG", "OTHER_ORG"]
ner_annotator = FilteredModelAnnotator(allowed_labels, "org_ontonotes_", "spacy", "en_core_web_sm")

In [9]:
# Run the annotators. 
train_docs = list(regex_annotators.pipe(train_docs))
train_docs = list(ner_annotator.pipe(train_docs))

In [10]:
# Let's make a list of all soft labels that were produced by our annotators. 

def get_soft_labels(docs):
    entity_set = set()
    for doc in docs:
        for span in doc.spans.values():
            for ent in span:
                entity_set.update([ent.label_])
    return list(entity_set)

soft_labels = get_soft_labels(train_docs)
soft_labels

['org_demonym',
 'org_generic',
 'org_ontonotes_ORG',
 'org_general',
 'org_location',
 'org_number',
 'org_adjectives']

In [11]:
# Here we aggregate the soft labels into the final ORG labels with a majority voter. 

from skweak.aggregation import MajorityVoter

voter = MajorityVoter("maj_voter", labels=["ORG"] + soft_labels, sequence_labelling=True)


voter.add_underspecified_label("ORG", soft_labels)

voter.add_underspecified_label("ENT", ["ORG"])

train_docs = list(voter.pipe(train_docs))

In [12]:
# Although we already have only one final annotation per token, we still need to replace the soft label names by ORG,
# and assign these labels to the "ents" attribute of the spacy Doc object.
replacement_dict = {label:'ORG' for label in soft_labels}

In [13]:
replacement_dict

{'org_demonym': 'ORG',
 'org_generic': 'ORG',
 'org_ontonotes_ORG': 'ORG',
 'org_general': 'ORG',
 'org_location': 'ORG',
 'org_number': 'ORG',
 'org_adjectives': 'ORG'}

In [14]:
import copy

for doc in train_docs:
    voted_ents = []
    for span in doc.spans.get('maj_voter', []):
        tmp_span = Span(doc, span.start, span.end, label=replacement_dict[span.label_])
        voted_ents.append(tmp_span)
    doc.set_ents(voted_ents)

In [15]:
# Now we just train the model. 

from skweak.utils import docbin_writer

docbin_writer(train_docs, "/tmp/train.spacy")
docbin_writer(dev_docs, "/tmp/dev.spacy")

Write to /tmp/train.spacy...done
Write to /tmp/dev.spacy...done


In [16]:
!spacy init config - --lang en --pipeline ner --optimize accuracy | \
spacy train - \
--training.max_steps 200 \
--paths.train /tmp/train.spacy \
--paths.dev /tmp/dev.spacy \
--initialize.vectors en_core_web_sm \
--output /tmp/model

[38;5;4mℹ Saving to output directory: /tmp/model[0m
[38;5;4mℹ Using CPU[0m
[1m
[2022-01-23 10:08:04,423] [INFO] Set up nlp object from config
[2022-01-23 10:08:04,432] [INFO] Pipeline: ['tok2vec', 'ner']
[2022-01-23 10:08:04,435] [INFO] Created vocabulary
[2022-01-23 10:08:04,779] [INFO] Added vectors: en_core_web_sm
[2022-01-23 10:08:04,796] [INFO] Finished initializing nlp object
[2022-01-23 10:09:07,903] [INFO] Initialized pipeline components: ['tok2vec', 'ner']
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['tok2vec', 'ner'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TOK2VEC  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE 
---  ------  ------------  --------  ------  ------  ------  ------
  0       0          0.00     30.50    1.00    2.60    0.62    0.01
  0     200          9.93   1443.22   17.92   18.20   17.64    0.18
[38;5;2m✔ Saved pipeline to output directory[0m
/tmp/model/model-last


In [18]:
# The cells below have the code for visualizing the dataset with Rubrix.
# They will be integrated with the other sections in future versions of this draft.

# Here we just halt the notebook as we don't wish to run the cells below.
raise Exception

Exception: 

### Visualization with Rubrix

We can use Rubrix to visualize the outputs of our aggregation model. 

First we define a `doc_logger` function that will log the predictions produced by all our annotators to Rubrix.

In [None]:
from tqdm import tqdm
import rubrix as rb
from functools import partial

def doc_logger(
    texts, 
    docs, 
    rubrix_dataset="conll_2003", 
    log_spans=True, 
    log_ents=False, 
    log_blank=False, 
    annotation_agent="gold_standard"):
    
    def unroll_entities(entity_list):
        return [ (ent.label_, ent.start_char, ent.end_char) for ent in entity_list ]
    
    records = []
    
    for idx, doc in enumerate(tqdm(docs, total=len(docs))):
        
        tokens = [token.text for token in doc]
        
        record_kwargs = {
            "tokens": tokens,
            "text": texts[idx],
            "metadata": {
                "doc_index": idx
            },
        }
        
        if not tokens:
            continue
        
        entity_dict = dict()
        
        if log_ents and doc.ents:
            entity_dict[annotation_agent] = unroll_entities(doc.ents)
        if log_spans and doc.spans:
            for labelling_function, span_list in doc.spans.items():
                entity_dict[labelling_function] = unroll_entities(span_list)
        
        if any(entity_dict.keys()):
            for source, entities in entity_dict.items():
                
                if source == annotation_agent:
                    record_kwargs.update({
                        "annotation": entities,
                        "annotation_agent": source
                    })
                else:
                    record_kwargs.update({
                        "prediction": entities,
                        "prediction_agent": source
                    })
                    
                record = rb.TokenClassificationRecord(**record_kwargs)
                records.append(record)
        
        elif log_blank:
            entities = unroll_entities(doc.ents)
            record = rb.TokenClassificationRecord(**record_kwargs)
            records.append(record)
            
    if records:
        rb.log(records=records, name=rubrix_dataset)

In [None]:
def log_to_rubrix(tokens, docs, limit=None, **kwargs):
    texts = [ " ".join(x) for x in tokens ]
    if limit:
        text_sample = texts[:limit]
        doc_sample = docs[:limit]
    else:
        text_sample = texts
        doc_sample = docs
    doc_logger(
        text_sample,
        doc_sample,
        **kwargs
    )

log_to_rubrix(dataset["validation"]["tokens"], dev_docs, limit=3000, rubrix_dataset="conll_2003_dev", log_ents=True)
log_to_rubrix(dataset["train"]["tokens"], train_docs, limit=None, rubrix_dataset="conll_2003_train", log_blank=True)