This notebook was run in a Python 3.8.12 environment.

### Install Watchful

In [1]:
import sys
!{sys.executable} -m pip install watchful

### Create your custom enricher

In [2]:
"""
Your enricher should inherit from Watchful's `Enricher` interface and
implement the methods `__init__` and `enrich_row` with the same 
signatures.
"""

import os
import pprint
import sys
from typing import Iterable, List

from watchful import attributes
from watchful.enricher import Enricher

# Pretty-printer used to display intermediate enrichment output.
# `sort_dicts=False` (new in Python 3.8) keeps dict keys in insertion order.
# The former `pprint._sorted = lambda x: x` patch was a Python 2 trick that
# has no effect on Python 3.8, whose pprint does not use a `_sorted` hook.
# Note: this rebinds the name `pprint` from the module to a bound method.
pprint = pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint


class NEREnricher(Enricher):
    """
    Example custom enricher that tags named entities with a flair tagger.

    It inherits from Watchful's `Enricher` interface and implements
    `__init__` and `enrich_row` with the same signatures as the interface.
    Each cell of a row is enriched with a single "ENTS" attribute whose
    spans are the entity offsets and whose values are the entity labels and
    their confidence scores (integer percentages, as strings).
    """

    # Set to False to stop `enrich_row` from printing every enriched row,
    # e.g. when enriching a large dataset.
    VERBOSE = True

    def __init__(
        self,
    ) -> None:
        """
        Load the flair NER tagger and store what `enrich_row` needs in
        `self.enrichment_args`. Refer to Watchful documentation on creating
        attribute spans.
        """

        # Importing flair here (rather than at the top of the notebook)
        # defers loading it until the enricher is instantiated. `Sentence` is
        # published as a module-level global because `enrich_row` builds
        # `Sentence` objects directly.
        global Sentence
        from flair.data import Sentence
        from flair.models import SequenceTagger
        import logging
        import warnings

        # Hide flair log records below ERROR and warnings raised from
        # huggingface_hub so the notebook output stays readable.
        logging.getLogger("flair").setLevel(logging.ERROR)
        warnings.filterwarnings("ignore", module="huggingface_hub")

        tagger = SequenceTagger.load("ner")

        def predict(sent: Sentence) -> None:
            # The tagger annotates `sent` itself; `enrich_row` reads the
            # results via `sent.get_spans("ner")`. Any return is discarded.
            tagger.predict(sent)

        self.enrichment_args = (predict,)

    def enrich_row(
        self,
        row: Iterable[str],
    ) -> List[attributes.EnrichedCell]:
        """
        Enrich every cell of `row` with its named entities.

        Uses the `predict` function stored in `self.enrichment_args` by
        `__init__`. Each enriched cell is a list holding one
        `(spans, {"entity": [...], "score": [...]}, "ENTS")` attribute. The
        cell is then passed to
        `attributes.adjust_span_offsets_from_char_to_byte`, which (per its
        name) converts the span offsets from characters to bytes.

        Returns the list of enriched cells, one per cell of `row`.
        """

        predict, = self.enrichment_args

        enriched_row = []

        for raw_cell in row:
            sent = Sentence(raw_cell)
            predict(sent)

            ent_spans = []
            ent_values = []
            ent_scores = []
            for ent in sent.get_spans("ner"):
                label = ent.get_label("ner")
                ent_spans.append((ent.start_position, ent.end_position))
                ent_values.append(label.value)
                # Round the percentage itself. The previous
                # `int(round(score, 2) * 100)` truncated float error,
                # e.g. 0.57 -> 56.99999999999999 -> 56.
                ent_scores.append(str(int(round(label.score * 100))))

            enriched_cell = [
                (ent_spans, {"entity": ent_values, "score": ent_scores}, "ENTS")
            ]

            attributes.adjust_span_offsets_from_char_to_byte(
                raw_cell,
                enriched_cell
            )

            enriched_row.append(enriched_cell)

        if self.VERBOSE:
            # Show the intermediate output in this notebook.
            print("Enriched row:")
            pprint(enriched_row)
            print("*" * 80)

        return enriched_row


### Enrich your dataset

In [3]:
from watchful.enrich import enrich_dataset

# Hand Watchful the enricher class itself, not an instance; `enrich_dataset`
# presumably instantiates it and calls `enrich_row` per row — confirm in the
# Watchful docs.
enricher_class = NEREnricher
enrich_dataset(enricher_class)


Using your custom enricher...
Enriching:  /Users/runyan/watchful/datasets/raw/0e9987ff5118ceac918153270be44b1ccf132f69083dbc38935aa9d200296c42 ...
Enriched row:
[   [   (   [   (111, 121),
                (168, 176),
                (197, 207),
                (254, 267),
                (334, 339),
                (342, 353),
                (386, 394),
                (487, 492),
                (495, 504),
                (521, 534),
                (629, 638),
                (697, 704),
                (723, 736),
                (738, 745),
                (803, 811),
                (847, 855)],
            {   'entity': [   'MISC',
                              'LOC',
                              'MISC',
                              'PER',
                              'MISC',
                              'PER',
                              'PER',
                              'MISC',
                              'PER',
                              'PER',
                

### Query your enriched dataset in Watchful 

In [4]:
from watchful.client import query

def query_to_fields(query_str):
    """
    Run `query_str` through Watchful's `query` and collect the matched fields.

    Returns a list containing the "fields" value of each entry in the
    response's "candidates" list, in response order.
    """
    response = query(query_str)
    fields = []
    for candidate in response["candidates"]:
        fields.append(candidate["fields"])
    return fields

ENTITY_TYPES = ["LOC", "ORG", "PER", "MISC"]
SEPARATOR = "*" * 80

# For each entity label emitted by the enricher, list the fields of every
# token match returned by Watchful.
for entity_type in ENTITY_TYPES:
    print(entity_type)
    matched_fields = query_to_fields(f"TOKS: [entity {entity_type}]")
    print(matched_fields, SEPARATOR, sep="\n")

LOC
[['Scotland', ''], ['USA', ''], ['', 'Haddonfield'], ['', 'Illinois']]
********************************************************************************
ORG
[['Montrose', '']]
********************************************************************************
PER
[['Jessica', ''], ['Lange', ''], ['Liam', ''], ['Neeson', ''], ['Tim', ''], ['Roth', ''], ['John', ''], ['Hurt', '']]
********************************************************************************
MISC
[['Braveheart', ''], ['Braveheart', ''], ['Oscar', ''], ['Oscar', ''], ['', 'Lady'], ['', 'and'], ['', 'the'], ['', 'Tramp']]
********************************************************************************
