First version of the spaCy usas rule based tagger #4
apmoore1 committed Nov 30, 2021
1 parent 53c0076 commit 75770b1
Showing 6 changed files with 439 additions and 166 deletions.
23 changes: 23 additions & 0 deletions docs/docs/api/file_utils.md
@@ -6,6 +6,29 @@

---

<a id="pymusas.file_utils.ensure_path"></a>

### ensure\_path

```python
def ensure_path(path: Union[str, Path]) -> Path
```

Ensure string is converted to a Path.

This is a more restrictive version of spaCy's [ensure_path](https://github.com/explosion/spaCy/blob/ac05de2c6c708e33ebad6c901e674e1e8bdc0688/spacy/util.py#L358)

<h4 id="ensure_path.parameters">Parameters<a className="headerlink" href="#ensure_path.parameters" title="Permanent link">&para;</a></h4>


- __path__ : `Union[str, Path]` <br/>
If string, it's converted to Path.

<h4 id="ensure_path.returns">Returns<a className="headerlink" href="#ensure_path.returns" title="Permanent link">&para;</a></h4>


- `Path` <br/>
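
For illustration, a minimal usage sketch of `ensure_path` (the lexicon file name below is purely hypothetical):

```python
from pathlib import Path

from pymusas.file_utils import ensure_path

# A plain string is converted to a pathlib.Path.
assert ensure_path('data/lexicon.tsv') == Path('data', 'lexicon.tsv')

# An existing Path instance is returned unchanged.
existing = Path('data') / 'lexicon.tsv'
assert ensure_path(existing) is existing
```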

<a id="pymusas.file_utils.download_url_file"></a>

### download\_url\_file
21 changes: 21 additions & 0 deletions pymusas/file_utils.py
@@ -1,6 +1,7 @@
from hashlib import sha256
import os
from pathlib import Path
from typing import Union

import requests
from requests.adapters import HTTPAdapter
@@ -10,6 +11,26 @@
from . import config


def ensure_path(path: Union[str, Path]) -> Path:
    """
    Ensure string is converted to a Path.
    This is a more restrictive version of spaCy's [ensure_path](https://github.com/explosion/spaCy/blob/ac05de2c6c708e33ebad6c901e674e1e8bdc0688/spacy/util.py#L358)
    # Parameters
    path : `Union[str, Path]`
        If string, it's converted to Path.
    # Returns
    `Path`
    """
    if isinstance(path, str):
        return Path(path)
    return path


def _session_with_backoff() -> requests.Session:
    """
    We ran into an issue where http requests to s3 were timing out,
169 changes: 138 additions & 31 deletions pymusas/spacy_api/taggers/rule_based.py
@@ -1,13 +1,16 @@
import copy
import logging
from typing import Callable, Dict, Iterable, List, Optional
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Union, cast

from spacy.language import Language
from spacy.pipe_analysis import validate_attrs
from spacy.tokens import Doc, Token
from spacy.training import Example
from spacy.util import SimpleFrozenList
import srsly

from ...config import LANG_LEXICON_RESOUCRE_MAPPER
from ...lexicon_collection import LexiconCollection
from ...file_utils import ensure_path
from ...taggers.rule_based import _tag_token


@@ -32,10 +35,10 @@ def __init__(self,
        if lemma_lexicon_lookup is not None:
            self.lemma_lexicon_lookup = lemma_lexicon_lookup

        self.usas_tags_token_attr = usas_tags_token_attr
        self.pos_mapper = pos_mapper
        self.pos_attribute = pos_attribute
        self.lemma_attribute = lemma_attribute
        self._usas_tags_token_attr = usas_tags_token_attr
        self._pos_attribute = pos_attribute
        self._lemma_attribute = lemma_attribute

        if Token.has_extension(self.usas_tags_token_attr):
            old_extension = Token.get_extension(self.usas_tags_token_attr)
@@ -66,6 +69,49 @@ def __init__(self,
        usas_factory_meta.requires = [f'token.{lemma_required}',
                                      f'token.{pos_required}']

    @property
    def usas_tags_token_attr(self) -> str:
        return self._usas_tags_token_attr

    @usas_tags_token_attr.setter
    def usas_tags_token_attr(self, value: str) -> None:
        usas_factory_meta = Language.get_factory_meta('usas_tagger')
        usas_factory_meta.assigns = validate_attrs([f'token._.{value}'])
        self._usas_tags_token_attr = value

    @property
    def pos_attribute(self) -> str:
        return self._pos_attribute

    @pos_attribute.setter
    def pos_attribute(self, value: str) -> None:
        value_required = value
        if value_required in ['pos_', 'tag_']:
            value_required = value_required[:-1]
        value_required_to_remove = f'token.{self._pos_attribute}'
        if value_required_to_remove in ['token.pos_', 'token.tag_']:
            value_required_to_remove = value_required_to_remove[:-1]

        self._update_factory_attributes(value_required, value_required_to_remove)

        self._pos_attribute = value

    @property
    def lemma_attribute(self) -> str:
        return self._lemma_attribute

    @lemma_attribute.setter
    def lemma_attribute(self, value: str) -> None:
        value_required = value
        if value_required == 'lemma_':
            value_required = value_required[:-1]
        value_required_to_remove = f'token.{self._lemma_attribute}'
        if value_required_to_remove == 'token.lemma_':
            value_required_to_remove = value_required_to_remove[:-1]
        self._update_factory_attributes(value_required, value_required_to_remove)

        self._lemma_attribute = value

    def __call__(self, doc: Doc) -> Doc:
        for token in doc:
            text = token.text
@@ -82,36 +128,97 @@ def __call__(self, doc: Doc) -> Doc:

    def initialize(self, get_examples: Optional[Callable[[], Iterable[Example]]] = None,
                   nlp: Optional[Language] = None,
                   lexicon_lookup_data: Optional[Dict[str, List[str]]] = None,
                   lexicon_lookup: Optional[Dict[str, List[str]]] = None,
                   lemma_lexicon_lookup: Optional[Dict[str, List[str]]] = None,
                   pos_mapper: Optional[Dict[str, List[str]]] = None,
                   usas_tags_token_attr: str = 'usas_tags',
                   pos_attribute: str = 'pos_',
                   lemma_attribute: str = 'lemma_'
                   ) -> None:

        def any_data(lexicon_data: List[Optional[Dict[str, List[str]]]]) -> bool:
            return any(lexicon_data)

        all_lexicon_data = [lexicon_lookup_data, lemma_lexicon_lookup]
        if not any_data(all_lexicon_data) and nlp is not None:
            nlp_language = nlp.lang
            if nlp_language in LANG_LEXICON_RESOUCRE_MAPPER:
                lang_lexicon_info = LANG_LEXICON_RESOUCRE_MAPPER[nlp_language]
                lexicon_lookup_data = LexiconCollection.from_tsv(lang_lexicon_info['lexicon'], include_pos=True)
                lemma_lexicon_lookup = LexiconCollection.from_tsv(lang_lexicon_info['lexicon_lemma'], include_pos=False)

        all_lexicon_data = [lexicon_lookup_data, lemma_lexicon_lookup]
        if lexicon_lookup_data is not None:
            self.lexicon_lookup = lexicon_lookup_data
        if lexicon_lookup is not None:
            self.lexicon_lookup = lexicon_lookup
        if lemma_lexicon_lookup is not None:
            self.lexicon_lemma_lookup = lemma_lexicon_lookup
            self.lemma_lexicon_lookup = lemma_lexicon_lookup
        self.pos_mapper = pos_mapper
        self.usas_tags_token_attr = usas_tags_token_attr
        self.pos_attribute = pos_attribute
        self.lemma_attribute = lemma_attribute

    def from_bytes(self, bytes_data: bytes, *,
                   exclude: Iterable[str] = SimpleFrozenList()
                   ) -> "USASRuleBasedTagger":
        serialise_data = srsly.msgpack_loads(bytes_data)
        self.lexicon_lookup = srsly.msgpack_loads(serialise_data['lexicon_lookup'])
        self.lemma_lexicon_lookup = srsly.msgpack_loads(serialise_data['lemma_lexicon_lookup'])
        self.pos_mapper = srsly.msgpack_loads(serialise_data['pos_mapper'])
        self.usas_tags_token_attr = srsly.msgpack_loads(serialise_data['usas_tags_token_attr'])
        self.pos_attribute = srsly.msgpack_loads(serialise_data['pos_attribute'])
        self.lemma_attribute = srsly.msgpack_loads(serialise_data['lemma_attribute'])
        return self

    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
        serialise = {}
        serialise["lexicon_lookup"] = srsly.msgpack_dumps(self.lexicon_lookup)
        serialise["lemma_lexicon_lookup"] = srsly.msgpack_dumps(self.lemma_lexicon_lookup)
        serialise["pos_mapper"] = srsly.msgpack_dumps(self.pos_mapper)
        serialise["usas_tags_token_attr"] = srsly.msgpack_dumps(self.usas_tags_token_attr)
        serialise["pos_attribute"] = srsly.msgpack_dumps(self.pos_attribute)
        serialise["lemma_attribute"] = srsly.msgpack_dumps(self.lemma_attribute)
        return cast(bytes, srsly.msgpack_dumps(serialise))

    def from_disk(self, path: Union[str, Path], *,
                  exclude: Iterable[str] = SimpleFrozenList()
                  ) -> "USASRuleBasedTagger":
        component_folder = ensure_path(path)
        lexicon_file = Path(component_folder, 'lexicon_lookup.json')
        if lexicon_file.exists():
            with lexicon_file.open('r', encoding='utf-8') as lexicon_data:
                self.lexicon_lookup = srsly.json_loads(lexicon_data.read())
        lemma_lexicon_file = Path(component_folder, 'lemma_lexicon_lookup.json')
        if lemma_lexicon_file.exists():
            with lemma_lexicon_file.open('r', encoding='utf-8') as lemma_lexicon_data:
                self.lemma_lexicon_lookup = srsly.json_loads(lemma_lexicon_data.read())
        pos_mapper_file = Path(component_folder, 'pos_mapper.json')
        if pos_mapper_file.exists():
            with pos_mapper_file.open('r', encoding='utf-8') as pos_mapper_data:
                self.pos_mapper = srsly.json_loads(pos_mapper_data.read())
        with Path(component_folder, 'attribute_data.json').open('r', encoding='utf-8') as attribute_file:
            attribute_data = srsly.json_loads(attribute_file.read())
            self.usas_tags_token_attr = attribute_data['usas_tags_token_attr']
            self.pos_attribute = attribute_data['pos_attribute']
            self.lemma_attribute = attribute_data['lemma_attribute']
        return self

    def to_disk(self, path: Union[str, Path], *,
                exclude: Iterable[str] = SimpleFrozenList()
                ) -> None:
        component_folder = ensure_path(path)
        component_folder.mkdir(exist_ok=True)
        if self.lexicon_lookup:
            with Path(component_folder, 'lexicon_lookup.json').open('w', encoding='utf-8') as lexicon_file:
                lexicon_file.write(srsly.json_dumps(self.lexicon_lookup))
        if self.lemma_lexicon_lookup:
            with Path(component_folder, 'lemma_lexicon_lookup.json').open('w', encoding='utf-8') as lexicon_file:
                lexicon_file.write(srsly.json_dumps(self.lemma_lexicon_lookup))
        if self.pos_mapper is not None:
            with Path(component_folder, 'pos_mapper.json').open('w', encoding='utf-8') as pos_mapper_file:
                pos_mapper_file.write(srsly.json_dumps(self.pos_mapper))
        attribute_data = {'usas_tags_token_attr': self.usas_tags_token_attr,
                          'pos_attribute': self.pos_attribute,
                          'lemma_attribute': self.lemma_attribute}
        with Path(component_folder, 'attribute_data.json').open('w', encoding='utf-8') as attribute_data_file:
            attribute_data_file.write(srsly.json_dumps(attribute_data))

    @classmethod
    def _update_factory_attributes(cls, new_attribute_name: str, old_attribute_name: str) -> None:
        usas_factory_meta = Language.get_factory_meta('usas_tagger')
        required_attributes = copy.deepcopy(usas_factory_meta.requires)
        updated_attributes = [attribute for attribute in required_attributes
                              if attribute != old_attribute_name]
        updated_attributes.append(f'token.{new_attribute_name}')

        if not any_data(all_lexicon_data):
            error_msg = ('Missing data for initialisation. No data has '
                         'been explicitly passed.')
            if nlp is not None:
                supported_languages = '\n'.join(LANG_LEXICON_RESOUCRE_MAPPER.keys())
                error_msg += (' In addition the Spacy language you are using '
                              'is not supported by our list of pre-complied '
                              f'lexicons:\n{supported_languages}')
            raise ValueError(error_msg)
        usas_factory_meta.requires = updated_attributes


@Language.factory("usas_tagger")
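
To make the new serialisation hooks and attribute properties concrete, here is a minimal usage sketch. It assumes the module path shown matches this commit's layout, that the `usas_tagger` factory works with its default configuration, and it uses a hand-made toy lexicon and an output folder name that are purely illustrative:

```python
import spacy

# Importing the tagger module registers the "usas_tagger" factory with spaCy
# (assumption: the module path matches this commit's layout).
import pymusas.spacy_api.taggers.rule_based  # noqa: F401

nlp = spacy.blank('en')
tagger = nlp.add_pipe('usas_tagger')  # assumes the factory defaults need no extra config

# Toy lookups, purely for illustration.
tagger.lexicon_lookup = {'bank|noun': ['I1.1']}
tagger.lemma_lexicon_lookup = {'bank': ['I1.1']}

# Byte-level round trip through the new to_bytes/from_bytes methods.
tagger.from_bytes(tagger.to_bytes())

# Disk round trip: each lookup is written to its own JSON file and the
# token-attribute settings go into attribute_data.json.
tagger.to_disk('usas_tagger_data')
tagger.from_disk('usas_tagger_data')

# The new property setters keep the factory meta in sync: switching the POS
# attribute updates the component's declared requires from token.pos to token.tag.
tagger.pos_attribute = 'tag_'
```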
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,7 +27,8 @@ module = [
    'urllib3.*',
    'tqdm.*',
    'pydoc_markdown.*',
    'spacy.vocab'
    'spacy.vocab',
    'srsly.*'
]
ignore_missing_imports = true
