Ensuring _id and _meta are present in IndexedDataset.examples
PiperOrigin-RevId: 553525287
RyanMullins authored and LIT team committed Aug 3, 2023
1 parent 0146d5f commit 50fc3a4
Showing 3 changed files with 468 additions and 122 deletions.
101 changes: 62 additions & 39 deletions lit_nlp/api/dataset.py
@@ -13,37 +13,44 @@
# limitations under the License.
# ==============================================================================
"""Base classes for LIT models."""
from collections.abc import Callable, Mapping, Sequence
import hashlib
import glob
import inspect
import os
import random
from types import MappingProxyType # pylint: disable=g-importing-member
from typing import Callable, Mapping, Optional, Sequence, Union, cast
import types
from typing import Optional, Union, cast

from absl import logging
from lit_nlp.api import types
from lit_nlp.api import types as lit_types
from lit_nlp.lib import serialize
from lit_nlp.lib import utils

JsonDict = types.JsonDict
IndexedInput = types.IndexedInput
ExampleId = types.ExampleId
Spec = types.Spec
ExampleId = lit_types.ExampleId
IdFnType = Callable[[lit_types.JsonDict], lit_types.ExampleId]
IndexedInput = lit_types.IndexedInput
JsonDict = lit_types.JsonDict
Spec = lit_types.Spec

LIT_FILE_EXTENSION = '.lit.jsonl'
LIT_SPEC_EXTENSION = '.spec'

INPUT_ID_FIELD = '_id'
INPUT_META_FIELD = '_meta'
INPUT_INTERNAL_FIELDS = (INPUT_ID_FIELD, INPUT_META_FIELD)


# This is used here and in caching.py, but we define here to avoid a circular
# dependency of dataset -> caching -> model -> dataset
def input_hash(example: types.Input) -> types.ExampleId:
def input_hash(example: lit_types.JsonDict) -> lit_types.ExampleId:
"""Create stable hash of an input example."""
raw_example = {k: v for k, v in example.items() if k not in ('_id', '_meta')}
json_str = serialize.to_json(raw_example, simple=True, sort_keys=True).encode(
'utf-8'
)
return types.ExampleId(hashlib.md5(json_str).hexdigest())
raw_example = {
k: v for k, v in example.items()
if k not in INPUT_INTERNAL_FIELDS
}
json_str = serialize.to_json(raw_example, simple=True, sort_keys=True)
return lit_types.ExampleId(hashlib.md5(json_str.encode('utf-8')).hexdigest())
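
A minimal sketch (not part of the diff) of why this hash is stable: input_hash drops the internal _id and _meta fields before serializing, so an example maps to the same id whether or not it has already been indexed. The example fields below are made up for illustration.

example = {'sentence': 'It was the best of times.', 'label': 1}
id_first = input_hash(example)
# Re-hashing after the internal fields are attached yields the same id,
# because both fields are stripped before the sorted-key JSON is MD5'd.
id_again = input_hash(dict(example, _id=id_first, _meta={'added': None}))
assert id_first == id_again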


def write_examples(examples: Sequence[JsonDict], path: str):
@@ -123,7 +130,7 @@ def description(self) -> str:
return self._description or inspect.getdoc(self) or '' # pytype: disable=bad-return-type

@classmethod
def init_spec(cls) -> Optional[types.Spec]:
def init_spec(cls) -> Optional[lit_types.Spec]:
"""Attempts to infer a Spec describing a Dataset's constructor parameters.
The Dataset base class attempts to infer a Spec for the constructor using
@@ -141,7 +148,7 @@ def init_spec(cls) -> Optional[types.Spec]:
could not be inferred.
"""
try:
spec = types.infer_spec_for_func(cls.__init__)
spec = lit_types.infer_spec_for_func(cls.__init__)
except TypeError as e:
spec = None
logging.warning(
@@ -242,25 +249,32 @@ def bytes_from_lit_example(lit_example: JsonDict) -> bytes:
return serialize.to_json(lit_example).encode('utf-8')


IdFnType = Callable[[types.Input], ExampleId]


class IndexedDataset(Dataset):
"""Dataset with additional indexing information."""

_index: dict[ExampleId, IndexedInput] = {}

def _normalize_example(
self, data: JsonDict, ex_id: ExampleId, meta: lit_types.InputMetadata
):
return types.MappingProxyType(dict(data, _id=ex_id, _meta=meta))

def index_inputs(
self, examples: list[types.Input]
self, examples: list[lit_types.JsonDict]
) -> list[IndexedInput]:
"""Create indexed versions of inputs."""
indexed = []
for example in examples:
ex_id = self.id_fn(example)
ex_meta = types.InputMetadata(added=None, parentId=None, source=None)
ex_id = example.get(INPUT_ID_FIELD, self.id_fn(example))
ex_meta = example.get(
INPUT_META_FIELD,
lit_types.InputMetadata(added=None, parentId=None, source=None),
)
indexed.append(
IndexedInput(
data=MappingProxyType(example | {'_id': ex_id, '_meta': ex_meta}),
data=types.MappingProxyType(
example | {INPUT_ID_FIELD: ex_id, INPUT_META_FIELD: ex_meta}
),
id=ex_id,
meta=ex_meta,
)
@@ -270,19 +284,28 @@ def index_inputs(
def __init__(
self,
*args,
id_fn: IdFnType = input_hash,
id_fn: Optional[IdFnType] = None,
indexed_examples: Optional[list[IndexedInput]] = None,
**kw,
):
# The base Dataset class will initialize self._examples in this call to
# super().__init__(), which may or may not include the _id and _meta fields.
super().__init__(*args, **kw)
self.id_fn = id_fn
self.id_fn = id_fn if id_fn is not None else input_hash

if indexed_examples:
self._indexed_examples = indexed_examples
self._examples = [MappingProxyType(
{k: v for k, v in ex['data'].items() if k not in ('_id', '_meta')}
) for ex in indexed_examples]
# Ensure that all indexed examples provide a read-only view of their data.
for ie in self._indexed_examples:
if not isinstance((ie_data := ie['data']), types.MappingProxyType):
ie['data'] = self._normalize_example(ie_data, ie['id'], ie['meta'])
else:
self._indexed_examples = self.index_inputs(self._examples)

self._examples = [
self._normalize_example(ex['data'], ex['id'], ex.get('meta', {}))
for ex in self._indexed_examples
]
self._index = {ex['id']: ex for ex in self._indexed_examples}
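
A rough usage sketch of what the reworked constructor guarantees; my_dataset is hypothetical and stands for any concrete Dataset with at least one example. After construction, every entry in .examples carries _id and _meta, and each entry is a read-only mapping.

# Hypothetical sketch of the post-construction invariants.
indexed = IndexedDataset(base=my_dataset, id_fn=input_hash)
first = indexed.examples[0]
assert INPUT_ID_FIELD in first and INPUT_META_FIELD in first
try:
  first['label'] = 0
except TypeError:
  pass  # normalized examples are mappingproxy views, so mutation is rejected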

@property
@@ -293,7 +316,8 @@ def _slicer(slice_obj):
return IndexedDataset(
indexed_examples=self.indexed_examples[slice_obj],
id_fn=self.id_fn,
base=self)
base=self
)

return SliceWrapper(_slicer)
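
Continuing that sketch, the _slicer override above means a slice is itself an IndexedDataset and keeps the original ids (assuming the slice property exposed by the Dataset base class):

# Slicing preserves ExampleIds, so anything keyed on them (for example a
# prediction cache) stays valid for the subset.
subset = indexed.slice[:10]
assert isinstance(subset, IndexedDataset)
subset_ids = [ex['id'] for ex in subset.indexed_examples]
assert subset_ids == [ex['id'] for ex in indexed.indexed_examples[:10]]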

@@ -309,7 +333,7 @@ def indexed_examples(self) -> Sequence[IndexedInput]:
@property
def index(self) -> Mapping[ExampleId, IndexedInput]:
"""Return a read-only view of the index."""
return MappingProxyType(self._index)
return types.MappingProxyType(self._index)
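
And the read-only index gives constant-time lookup from an ExampleId back to its IndexedInput without letting callers mutate the mapping (again continuing the hypothetical sketch above):

some_id = indexed.indexed_examples[0]['id']
assert indexed.index[some_id]['data'][INPUT_ID_FIELD] == some_id
try:
  indexed.index[some_id] = None
except TypeError:
  pass  # the mappingproxy wrapper rejects item assignment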

def save(self, examples: list[IndexedInput], path: str):
"""Save newly-created datapoints to disk.
@@ -325,7 +349,8 @@ def save(self, examples: list[IndexedInput], path: str):
# datasets can override. Then also save in the lit json format and save
# the spec as well.
if not path.endswith(LIT_FILE_EXTENSION):
self._base.save(examples, path)
if (base_dataset := self._base) is not None:
base_dataset.save(examples, path)
path += LIT_FILE_EXTENSION

write_examples(examples, path)
@@ -347,7 +372,9 @@ def load(self, path: str):
# Try to load data using the base load method. If any data is
# returned, then use that. Otherwise try loading the lit json extension
# data format.
new_dataset = self._base.load(path)
base_dataset = self._base
new_dataset = base_dataset.load(path) if base_dataset else None

if new_dataset is not None:
description = (f'{len(new_dataset)} examples from '
f'{path}\n{self._base.description()}')
@@ -400,16 +427,12 @@ def load_lit_format(
with open(path, 'r') as fd:
examples = [serialize.from_json(line) for line in fd.readlines()]

first_example_keys = set(examples[0].keys())
# TODO(b/171513556, b/204318513): remove this once input representations are
# consolidated.
if first_example_keys == {'id', 'data', 'meta'} or first_example_keys == {
'id',
'data',
}:
first_example_keys = set(ex.keys() if (ex := examples[0]) else [])
# TODO(b/294233896): remove this once input representations are consolidated.
if first_example_keys.issuperset({'id', 'data'}):
return IndexedDataset(
spec=spec,
indexed_examples=cast(list[types.IndexedInput], examples),
indexed_examples=cast(list[lit_types.IndexedInput], examples),
id_fn=id_fn,
*args,
**kw,
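
For context, a hedged illustration of the key-based check above: each line of a .lit.jsonl file is one JSON record, and the loader treats the file as already indexed when the first record has at least the 'id' and 'data' keys. The helper below is hypothetical and uses the json module for brevity; the real loader goes through serialize.from_json.

import json

def looks_indexed(jsonl_path: str) -> bool:
  """True if the first record of a .lit.jsonl file resembles an IndexedInput."""
  with open(jsonl_path, 'r') as fd:
    first_line = fd.readline()
  if not first_line.strip():
    return False
  first_example_keys = set(json.loads(first_line).keys())
  return first_example_keys.issuperset({'id', 'data'})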
