feat: prepare_for_training supports spacy (#1635)
Refs #420
frascuchon committed Jul 22, 2022
1 parent 08af717 commit 00f9197
Showing 2 changed files with 119 additions and 4 deletions.
95 changes: 91 additions & 4 deletions src/rubrix/client/datasets.py
@@ -14,6 +14,7 @@
# limitations under the License.
import functools
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
@@ -38,7 +39,7 @@ def check_if_datasets_installed(*args, **kwargs):
import datasets
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'datasets' must be installed to use `to_datasets`! "
f"'datasets' must be installed to use `{func.__name__}`! "
"You can install 'datasets' with the command: `pip install datasets>1.17.0`"
)
if not (parse_version(datasets.__version__) > parse_version("1.17.0")):
@@ -51,6 +52,21 @@ def check_if_datasets_installed(*args, **kwargs):
return check_if_datasets_installed


def _requires_spacy(func):
@functools.wraps(func)
def check_if_spacy_installed(*args, **kwargs):
try:
import spacy
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"'spacy' must be installed to use `{func.__name__}`"
"You can install 'spacy' with the command: `pip install spacy`"
)
return func(*args, **kwargs)

return check_if_spacy_installed
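
As an editor's aside, here is a minimal sketch of how this guard behaves when applied; the decorated function `tag_text` is hypothetical and not part of the commit:

    @_requires_spacy
    def tag_text(text):  # hypothetical example function
        import spacy  # safe to import here: the decorator already checked
        return spacy.blank("en")(text)

    # Without spaCy installed, calling tag_text("hi") raises:
    #   ModuleNotFoundError: 'spacy' must be installed to use `tag_text`! ...
    # functools.wraps preserves the wrapped function's __name__, which is what
    # makes the error message point at `tag_text` rather than at the wrapper.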


class DatasetBase:
"""The Dataset classes are containers for Rubrix records.
@@ -640,6 +656,17 @@ def prepare_for_training(self) -> "datasets.Dataset":
)


class Framework(Enum):
TRANSFORMERS = "transformers"
SPACY = "spacy"

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
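
The `_missing_` hook replaces the enum's terse default error with a message that lists the valid options. A quick sketch of the resulting behaviour:

    Framework("spacy")         # -> Framework.SPACY
    Framework("transformers")  # -> Framework.TRANSFORMERS
    Framework("keras")
    # ValueError: keras is not a valid Framework, please select one of
    # ['transformers', 'spacy']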


@_prepend_docstring(TokenClassificationRecord)
class DatasetForTokenClassification(DatasetBase):
"""
@@ -762,8 +789,11 @@ def from_pandas(
) -> "DatasetForTokenClassification":
return super().from_pandas(dataframe)

@_requires_datasets
def prepare_for_training(self) -> "datasets.Dataset":
def prepare_for_training(
self,
framework: Union[Framework, str] = "transformers",
lang: Optional["spacy.Language"] = None,
) -> Union["datasets.Dataset", "spacy.tokens.DocBin"]:
"""Prepares the dataset for training.
This will return a ``datasets.Dataset`` with all columns returned by the ``to_datasets`` method
@@ -772,8 +802,15 @@ def prepare_for_training(self) -> "datasets.Dataset":
- The *ner_tags* column corresponds to the iob tags sequences for annotations of the records
- The iob tags are transformed to integers.
Args:
framework: A string or ``Framework`` enum value specifying the training framework.
    Currently "transformers" and "spacy" are supported. Default: "transformers".
lang: The spaCy ``Language`` pipeline used to process the dataset (only used with the "spacy" framework).
Returns:
A datasets Dataset with a *ner_tags* column and all columns returned by ``to_datasets``.
For the "transformers" framework, a datasets Dataset with a *ner_tags* column and all
columns returned by ``to_datasets``.
For the "spacy" framework, a spacy DocBin ready to be used for training a spaCy NER model.
Examples:
>>> import rubrix as rb
Expand Down Expand Up @@ -802,6 +839,22 @@ def prepare_for_training(self) -> "datasets.Dataset":
"""

# coerce the string into a Framework member; Framework._missing_ raises for invalid values
if isinstance(framework, str):
framework = Framework(framework)

if framework is Framework.TRANSFORMERS:
return self._prepare_for_training_with_transformers()
# any other valid value must be Framework.SPACY
if lang is None:
raise ValueError(
"Please provide a spacy language model to prepare the dataset for training with the spacy framework."
)
return self._prepare_for_training_with_spacy(nlp=lang)
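
Putting the dispatch together, calling code might look like the following illustrative sketch (editor's note: `my_records` is a placeholder for real `TokenClassificationRecord`s):

    import spacy
    from rubrix.client.datasets import DatasetForTokenClassification

    dataset = DatasetForTokenClassification(my_records)  # my_records: placeholder

    ds = dataset.prepare_for_training()  # defaults to "transformers" -> datasets.Dataset
    db = dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"))
    # db -> spacy.tokens.DocBin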

@_requires_datasets
def _prepare_for_training_with_transformers(self):
import datasets

has_annotations = False
Expand Down Expand Up @@ -841,6 +894,40 @@ def spans2iob(example):

return ds.cast(new_features)
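
The `spans2iob` helper is collapsed in this diff; conceptually it turns character-level span annotations into token-level IOB tags. A simplified, self-contained sketch of that transformation (editor's sketch, not the commit's exact implementation):

    def spans_to_iob(token_offsets, spans):
        """Map char-level (label, start, end) spans onto token-level IOB tags."""
        tags = ["O"] * len(token_offsets)
        for label, start, end in spans:
            prev_inside = False
            for i, (tok_start, tok_end) in enumerate(token_offsets):
                if start <= tok_start and tok_end <= end:
                    tags[i] = ("I-" if prev_inside else "B-") + label
                    prev_inside = True
                else:
                    prev_inside = False
        return tags

    # "Paris is nice" -> tokens at [(0, 5), (6, 8), (9, 13)]
    spans_to_iob([(0, 5), (6, 8), (9, 13)], [("LOC", 0, 5)])
    # -> ['B-LOC', 'O', 'O']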

@_requires_spacy
def _prepare_for_training_with_spacy(
self, nlp: "spacy.Language"
) -> "spacy.tokens.DocBin":

from spacy.tokens import DocBin

db = DocBin()

# Creating the DocBin object as in https://spacy.io/usage/training#training-data
for record in self._records:
if record.annotation is None:
continue

doc = nlp(record.text)
entities = []

for anno in record.annotation:
span = doc.char_span(anno[1], anno[2], label=anno[0])
# span is None when the character offsets do not align with spaCy's tokenization
if span is None:
# TODO(@dcfidalgo): Do we want to warn and continue or should we stop the training set generation?
raise ValueError(
"The following annotation does not align with the tokens produced "
f"by the provided spacy language model: {(anno[0], record.text[anno[1]:anno[2]])}, {list(doc)}"
)
else:
entities.append(span)

doc.ents = entities
db.add(doc)

return db
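
Two follow-up notes on this method, as an editor's sketch assuming standard spaCy APIs: the returned `DocBin` can be serialized for the `spacy train` CLI, and the strict `char_span` matching that triggers the `ValueError` above could be relaxed with `alignment_mode="expand"` if that behaviour were ever wanted:

    # Persist the DocBin for spaCy's CLI training workflow:
    db = dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"))
    db.to_disk("./train.spacy")
    # then: python -m spacy train config.cfg --paths.train ./train.spacy

    # char_span returns None on misaligned offsets by default; "expand" snaps
    # the span to the nearest covering token boundaries instead:
    import spacy
    nlp = spacy.blank("en")
    doc = nlp("The catfish swam")
    assert doc.char_span(4, 7, label="ANIMAL") is None  # "cat" ends mid-token
    print(doc.char_span(4, 7, label="ANIMAL", alignment_mode="expand").text)
    # -> "catfish"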

def __all_labels__(self):
all_labels = set()
for record in self._records:
28 changes: 28 additions & 0 deletions tests/client/test_dataset.py
@@ -19,6 +19,7 @@
import datasets
import pandas as pd
import pytest
import spacy

import rubrix as rb
from rubrix.client.datasets import (
@@ -557,6 +558,33 @@ def test_push_to_hub(self, tokenclassification_records):

assert isinstance(dataset_ds, datasets.Dataset)

@pytest.mark.skipif(
_HF_HUB_ACCESS_TOKEN is None,
reason="You need a HF Hub access token to test the push_to_hub feature",
)
def test_prepare_for_training_with_spacy(self):
ner_dataset = datasets.load_dataset(
"rubrix/gutenberg_spacy-ner",
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
rb_dataset: DatasetForTokenClassification = rb.read_datasets(
ner_dataset, task="TokenClassification"
)
for r in rb_dataset:
r.annotation = [
(label, start, end) for label, start, end, _ in r.prediction
]

# the spacy framework requires a `lang` pipeline; omitting it must raise
with pytest.raises(ValueError):
    train = rb_dataset.prepare_for_training(framework="spacy")

train = rb_dataset.prepare_for_training(
framework="spacy", lang=spacy.blank("en")
)
assert isinstance(train, spacy.tokens.DocBin)
assert len(train) == 100

@pytest.mark.skipif(
_HF_HUB_ACCESS_TOKEN is None,
reason="You need a HF Hub access token to test the push_to_hub feature",