Skip to content

Commit

Permalink
Interface to fetch entries in primitive types from DataPack (#900)
Browse files Browse the repository at this point in the history
* Allow raw output in get and multipack group bug fix

* type_fix

* modified interface

* Title Bug fix

* is_created_by bug fix

* docstring changes

* Update setup.py

* Update top.py

* Update top.py

* Update top.py

* Update top.py

Co-authored-by: mylibrar <54747962+mylibrar@users.noreply.github.com>
  • Loading branch information
Pushkar-Bhuse and mylibrar committed Jan 3, 2023
1 parent 72e8bce commit fd717ff
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 57 deletions.
8 changes: 5 additions & 3 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,13 +803,13 @@ def get_ids_by_creator(self, component: str) -> Set[int]:
return entry_set

def is_created_by(
self, entry: Entry, components: Union[str, Iterable[str]]
self, entry: Union[Entry, int], components: Union[str, Iterable[str]]
) -> bool:
"""
Check if the entry is created by any of the provided components.
Args:
entry: The entry to check.
entry: `tid` of the entry or the entry object to check
components: The list of component names.
Returns:
Expand All @@ -818,8 +818,10 @@ def is_created_by(
if isinstance(components, str):
components = [components]

entry_tid = entry.tid if isinstance(entry, Entry) else entry

for c in components:
if entry.tid in self._creation_records[c]:
if entry_tid in self._creation_records[c]:
break
else:
# The entry not created by any of these components.
Expand Down
101 changes: 77 additions & 24 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ProcessExecutionException,
UnknownOntologyClassException,
)
from forte.common.constants import TID_INDEX
from forte.common.constants import TID_INDEX, BEGIN_ATTR_NAME, END_ATTR_NAME
from forte.data import data_utils_io
from forte.data.data_store import DataStore
from forte.data.entry_converter import EntryConverter
Expand All @@ -49,10 +49,11 @@
Annotation,
Link,
Group,
SinglePackEntries,
Generics,
AudioAnnotation,
Payload,
SinglePackEntries,
AnnotationLikeEntries,
)

from forte.data.modality import Modality
Expand Down Expand Up @@ -1503,9 +1504,12 @@ def covers(
def get( # type: ignore
self,
entry_type: Union[str, Type[EntryType]],
range_annotation: Optional[Union[Annotation, AudioAnnotation]] = None,
range_annotation: Optional[
Union[Annotation, AudioAnnotation, int]
] = None,
components: Optional[Union[str, Iterable[str]]] = None,
include_sub_type: bool = True,
get_raw: bool = False,
) -> Iterable[EntryType]:
r"""This function is used to get data from a data pack with various
methods.
Expand Down Expand Up @@ -1571,28 +1575,47 @@ def get( # type: ignore
Args:
entry_type: The type of entries requested.
range_annotation: The
range of entries requested. If `None`, will return valid
entries in the range of whole data pack.
range of entries requested. This value can be given by an
entry object or the ``tid`` of that entry. If `None`, will
return valid entries in the range of whole data pack.
components: The component (creator)
generating the entries requested. If `None`, will return valid
entries generated by any component.
include_sub_type: whether to consider the sub types of
the provided entry type. Default `True`.
get_raw: boolean to indicate if the entry should be returned in
its primitive form as opposed to an object. False by default
Yields:
Each `Entry` found using this method.
"""
entry_type_: Type[EntryType] = as_entry_type(entry_type)
# Convert entry_type to str
entry_type_ = (
get_full_module_name(entry_type)
if not isinstance(entry_type, str)
else entry_type
)

# pylint: disable=protected-access
# Check if entry_type_ represents a valid entry
if not self._data_store._is_subclass(entry_type_, Entry):
raise ValueError(
f"The specified entry type [{entry_type}] "
f"does not correspond to a "
f"`forte.data.ontology.core.Entry` class"
)

def require_annotations(entry_class=Annotation) -> bool:
if issubclass(entry_type_, entry_class):
if self._data_store._is_subclass(entry_type_, entry_class):
return True
if issubclass(entry_type_, Link):

curr_class: Type[EntryType] = as_entry_type(entry_type_)
if issubclass(curr_class, Link):
return issubclass(
entry_type_.ParentType, entry_class
) and issubclass(entry_type_.ChildType, entry_class)
if issubclass(entry_type_, Group):
return issubclass(entry_type_.MemberType, entry_class)
curr_class.ParentType, entry_class
) and issubclass(curr_class.ChildType, entry_class)
if issubclass(curr_class, Group):
return issubclass(curr_class.MemberType, entry_class)
return False

# If we don't have any annotations but the items to check requires them,
Expand Down Expand Up @@ -1631,28 +1654,58 @@ def require_annotations(entry_class=Annotation) -> bool:
yield from []
return

# If range_annotation is specified, we record its begin and
# end index
range_begin: int
range_end: int

if range_annotation is not None:
if isinstance(range_annotation, AnnotationLikeEntries):
range_begin = range_annotation.begin
range_end = range_annotation.end
else:
# range_annotation is given by the tid of the entry it
# represents
range_raw = self._data_store.transform_data_store_entry(
self.get_entry_raw(range_annotation)
)
range_begin = range_raw[BEGIN_ATTR_NAME]
range_end = range_raw[END_ATTR_NAME]

try:
for entry_data in self._data_store.get(
type_name=get_full_module_name(entry_type_),
type_name=entry_type_,
include_sub_type=include_sub_type,
range_span=range_annotation # type: ignore
and (range_annotation.begin, range_annotation.end),
and (range_begin, range_end),
):
entry: Entry = self.get_entry(tid=entry_data[TID_INDEX])

# Filter by components
if components is not None:
if not self.is_created_by(entry, components):
if not self.is_created_by(
entry_data[TID_INDEX], components
):
continue

# Filter out incompatible audio span comparison for Links and Groups
if (
issubclass(entry_type_, (Link, Group))
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
entry: Union[Entry, Dict[str, Any]]
if get_raw:
entry = self._data_store.transform_data_store_entry(
entry_data
)
):
continue
else:
entry = self.get_entry(tid=entry_data[TID_INDEX])

# Filter out incompatible audio span comparison for Links and Groups
if (
self._data_store._is_subclass(
entry_type_, (Link, Group)
)
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
)
):
continue

yield entry # type: ignore
except ValueError:
Expand Down
102 changes: 100 additions & 2 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,22 @@ def fetch_entry_type_data(
# ie. NoneType.
if attr_class is None:
attr_class = type(None)
attr_args = get_args(attr_info.type)
if len(attr_args) == 0:
raw_attr_args = get_args(attr_info.type)
if len(raw_attr_args) == 0:
attr_args = tuple([attr_info.type])
else:
attr_args = ()
for args in raw_attr_args:
# This is the case when we have a multidimensional
# type attribute like List[Tuple[int, int]]. In this
# case get_args will return a tuple of tuples that
# looks like ((Tuple, int, int),). We thus convert
# this into a single dimensional tuple -
# (Tuple, int, int).
if isinstance(args, tuple):
attr_args += args
else:
attr_args += (args,)

# Prior to Python 3.7, fetching generic type
# aliases resulted in actual type objects whereas from
Expand Down Expand Up @@ -1321,6 +1334,91 @@ def _get_existing_ann_entry_tid(self, entry: List[Any]):
"getting entry id for annotation-like entry."
)

def get_attribute_positions(self, type_name: str) -> Dict[str, int]:
r"""This function returns a dictionary where the key represents
the attributes of the entry of type ``type_name`` and value
represents the index of the position where this attribute is
stored in the data store entry of this type.
For example:
.. code-block:: python
positions = data_store.get_attribute_positions(
"ft.onto.base_ontology.Document"
)
# positions = {
# "begin": 2,
# "end": 3,
# "payload_idx": 4,
# "document_class": 5,
# "sentiment": 6,
# "classifications": 7
# }
Args:
type_name (str): The fully qualified type name of a type.
Returns:
A dictionary indicating the attributes of an entry of type
``type_name`` and their respective positions in a data store
entry.
"""
type_data = self._get_type_info(type_name)

positions: Dict[str, int] = {}
for attr, val in type_data[constants.ATTR_INFO_KEY].items():
positions[attr] = val[constants.ATTR_INDEX_KEY]

return positions

def transform_data_store_entry(self, entry: List[Any]) -> Dict:
r"""
This method converts a raw data store entry into a format more easily
understandable to users. Data Store entries are stored as lists and
are not very easily understandable. This method converts ``DataStore``
entries from a list format to a dictionary based format where the key
is the names of the attributes of an entry and the value is the values
corresponding attributes in the data store entry.
For example:
.. code-block:: python
>>> data_store = DataStore()
>>> tid = data_store.add_entry_raw(
... type_name = 'ft.onto.base_ontology.Sentence',
... tid = 101, attribute_data = [0,10])
>>> entry = data_store.get_entry(tid)[0]
>>> transformed_entry = data_store.transform_data_store_entry(entry)
>>> transformed_entry == { 'begin': 0, 'end': 10, 'payload_idx': 0,
... 'speaker': None, 'part_id': None, 'sentiment': {},
... 'classification': {}, 'classifications': {}, 'tid': 101,
... 'type': 'ft.onto.base_ontology.Sentence'}
True
Args:
entry: A list representing a valid data store entry
Returns:
a dictionary representing the the input data store entry
"""

attribute_positions = self.get_attribute_positions(
entry[constants.ENTRY_TYPE_INDEX]
)

# We now convert the entry from data store format (list) to user
# representation format (dict) to make the contents of the entry more
# understandable.
user_rep: Dict[str, Any] = {}
for attr, pos in attribute_positions.items():
user_rep[attr] = entry[pos]

user_rep["tid"] = entry[constants.TID_INDEX]
user_rep["type"] = entry[constants.ENTRY_TYPE_INDEX]

return user_rep

def set_attribute(self, tid: int, attr_name: str, attr_value: Any):
r"""This function locates the entry data with ``tid`` and sets its
``attr_name`` with `attr_value`. It first finds ``attr_id`` according
Expand Down

0 comments on commit fd717ff

Please sign in to comment.