Skip to content

Commit

Permalink
Dataset save/load helpers and unit tests.
Browse files Browse the repository at this point in the history
Move input_hash definition to dataset.py and make it the default for IndexedDataset.

PiperOrigin-RevId: 538835316
  • Loading branch information
iftenney authored and LIT team committed Jun 8, 2023
1 parent 5f4f7ee commit b7ce560
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 46 deletions.
83 changes: 66 additions & 17 deletions lit_nlp/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.
# ==============================================================================
"""Base classes for LIT datasets."""
import hashlib
import glob
import inspect
import os
import random
from types import MappingProxyType # pylint: disable=g-importing-member
from typing import cast, Optional, Callable, Mapping, Sequence
from typing import Callable, Mapping, Optional, Sequence, cast

from absl import logging

from lit_nlp.api import types
from lit_nlp.lib import serialize
from lit_nlp.lib import utils
Expand All @@ -35,6 +35,29 @@
LIT_SPEC_EXTENSION = '.spec'


# This is used here and in caching.py, but we define here to avoid a circular
# dependency of dataset -> caching -> model -> dataset
def input_hash(example: types.Input) -> types.ExampleId:
  """Create stable hash of an input example."""
  # Canonical JSON (simple mode, sorted keys) makes the digest deterministic
  # across runs and across dict key orderings.
  encoded = serialize.to_json(example, simple=True, sort_keys=True)
  digest = hashlib.md5(encoded.encode('utf-8')).hexdigest()
  return types.ExampleId(digest)


def write_examples(examples: Sequence[JsonDict], path: str):
  """Write examples to disk as LIT JSONL format."""
  # One JSON object per line; writelines batches the small writes.
  with open(path, 'w') as fd:
    fd.writelines(serialize.to_json(ex) + '\n' for ex in examples)


def write_spec(spec: Spec, path: str):
  """Write spec to disk as LIT JSON format."""
  serialized = serialize.to_json(spec, indent=2)
  with open(path, 'w') as fd:
    fd.write(serialized)


class SliceWrapper(object):
"""Shim object to implement custom slicing via foo[a:b:c] rather than constructing a slice object explicitly."""

Expand Down Expand Up @@ -235,13 +258,14 @@ def index_inputs(self, examples: list[types.Input]) -> list[IndexedInput]:
]
# pylint: enable=g-complex-comprehension

def __init__(self,
*args,
id_fn: Optional[IdFnType] = None,
indexed_examples: Optional[list[IndexedInput]] = None,
**kw):
def __init__(
self,
*args,
id_fn: IdFnType = input_hash,
indexed_examples: Optional[list[IndexedInput]] = None,
**kw,
):
super().__init__(*args, **kw)
assert id_fn is not None, 'id_fn must be specified.'
self.id_fn = id_fn
if indexed_examples:
self._indexed_examples = indexed_examples
Expand Down Expand Up @@ -293,13 +317,8 @@ def save(self, examples: list[IndexedInput], path: str):
self._base.save(examples, path)
path += LIT_FILE_EXTENSION

with open(path, 'w') as fd:
for ex in examples:
fd.write(serialize.to_json(ex) + '\n')

spec_path = path + LIT_SPEC_EXTENSION
with open(spec_path, 'w') as fd:
fd.write(serialize.to_json(self.spec()))
write_examples(examples, path)
write_spec(self.spec(), path + LIT_SPEC_EXTENSION)

return path

Expand Down Expand Up @@ -338,8 +357,10 @@ def load(self, path: str):
with open(spec_path, 'r') as fd:
spec = serialize.from_json(fd.read())

description = (f'{len(examples)} examples from '
f'{path}\n{self._base.description()}')
description = f'{len(examples)} examples from {path}'
if self._base is not None:
description += '\n' + self._base.description()

return IndexedDataset(
base=self._base,
indexed_examples=examples,
Expand All @@ -356,6 +377,34 @@ def __eq__(self, other):
return self_ids == other_ids


def load_lit_format(
    path: str, *args, id_fn=input_hash, **kw
) -> Dataset | IndexedDataset:
  """Load data from LIT jsonl format.

  Reads examples from `path` (one JSON object per line) and the spec from
  `path + LIT_SPEC_EXTENSION`. If the examples carry LIT indexing metadata
  ({'id', 'data'} or {'id', 'data', 'meta'} keys), an IndexedDataset is
  returned; otherwise a plain Dataset.

  Args:
    path: Path to the .jsonl examples file.
    *args: Forwarded positionally to the Dataset/IndexedDataset constructor.
    id_fn: Hash function for example IDs (IndexedDataset only).
    **kw: Forwarded as keywords to the Dataset/IndexedDataset constructor.

  Returns:
    A Dataset, or an IndexedDataset if the file is in indexed format.
  """
  with open(path + LIT_SPEC_EXTENSION, 'r') as fd:
    spec = serialize.from_json(fd.read())

  with open(path, 'r') as fd:
    examples = [serialize.from_json(line) for line in fd]

  # An empty file has no first example to sniff the format from; return a
  # plain (empty) Dataset rather than raising IndexError.
  if not examples:
    return Dataset(*args, spec=spec, examples=[], **kw)

  first_example_keys = set(examples[0].keys())
  # TODO(b/171513556, b/204318513): remove this once input representations are
  # consolidated.
  if first_example_keys in ({'id', 'data', 'meta'}, {'id', 'data'}):
    return IndexedDataset(
        *args,
        spec=spec,
        indexed_examples=cast(list[types.IndexedInput], examples),
        id_fn=id_fn,
        **kw,
    )
  return Dataset(*args, spec=spec, examples=examples, **kw)


class NoneDataset(Dataset):
"""Empty dataset, with fields as the union of model specs."""

Expand Down
201 changes: 179 additions & 22 deletions lit_nlp/api/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@
# ==============================================================================
"""Tests for lit_nlp.api.dataset."""

import os

from absl.testing import absltest
from absl.testing import parameterized

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
import numpy as np


def get_testdata_path(fname):
  """Return the path of `fname` inside this package's testdata directory.

  Bug fix: the original ignored `fname` entirely and returned only the
  testdata directory, so every caller received a directory path instead of
  the requested file.
  """
  return os.path.join(os.path.dirname(__file__), 'testdata', fname)


class TestDatasetEmptyInit(lit_dataset.Dataset):
Expand All @@ -42,8 +48,8 @@ def __init__(self, path: str, max_examples: int = 200, max_qps: float = 1.0):
class DatasetTest(parameterized.TestCase):

@parameterized.named_parameters(
("empty_init", TestDatasetEmptyInit),
("pass_thru_init", TestDatasetPassThroughInit),
('empty_init', TestDatasetEmptyInit),
('pass_thru_init', TestDatasetPassThroughInit),
)
def test_init_spec_empty(self, dataset: lit_dataset.Dataset):
self.assertEmpty(dataset.init_spec())
Expand All @@ -52,45 +58,196 @@ def test_init_spec_populated(self):
self.assertEqual(
TestDatasetInitWithArgs.init_spec(),
{
"path": types.String(),
"max_examples": types.Integer(default=200, required=False),
"max_qps": types.Scalar(default=1.0, required=False),
'path': types.String(),
'max_examples': types.Integer(default=200, required=False),
'max_qps': types.Scalar(default=1.0, required=False),
},
)

@parameterized.named_parameters(
# All base Dataset classes are incompatible with automated spec inference
# due to the complexity of their arguments, thus return None.
("dataset", lit_dataset.Dataset),
("indexed_dataset", lit_dataset.IndexedDataset),
("none_dataset", lit_dataset.NoneDataset),
('dataset', lit_dataset.Dataset),
('indexed_dataset', lit_dataset.IndexedDataset),
('none_dataset', lit_dataset.NoneDataset),
)
def test_init_spec_none(self, dataset: lit_dataset.Dataset):
self.assertIsNone(dataset.init_spec())

def test_remap(self):
"""Test remap method."""
spec = {
"score": types.Scalar(),
"text": types.TextSegment(),
'score': types.Scalar(),
'text': types.TextSegment(),
}
datapoints = [
{'score': 0, 'text': 'a'},
{'score': 0, 'text': 'b'},
]
dset = lit_dataset.Dataset(spec, datapoints)
remap_dict = {'score': 'val', 'nothing': 'nada'}
remapped_dset = dset.remap(remap_dict)
self.assertIn('val', remapped_dset.spec())
self.assertNotIn('score', remapped_dset.spec())
self.assertEqual({'val': 0, 'text': 'a'}, remapped_dset.examples[0])


class DatasetLoadingTest(absltest.TestCase):
"""Test to read data from LIT JSONL format."""

def setUp(self):
super().setUp()

self.data_spec = {
'parity': types.CategoryLabel(vocab=['odd', 'even']),
'text': types.TextSegment(),
'value': types.Integer(),
'other_divisors': types.SparseMultilabel(),
'in_spanish': types.TextSegment(),
'embedding': types.Embeddings(),
}

self.sample_examples = [
{
'parity': 'odd',
'text': 'One',
'value': 1,
'other_divisors': [],
'in_spanish': 'Uno',
},
{
'parity': 'even',
'text': 'Two',
'value': 2,
'other_divisors': [],
'in_spanish': 'Dos',
},
{
'parity': 'odd',
'text': 'Three',
'value': 3,
'other_divisors': [],
'in_spanish': 'Tres',
},
{
'parity': 'even',
'text': 'Four',
'value': 4,
'other_divisors': ['Two'],
'in_spanish': 'Cuatro',
},
{
'parity': 'odd',
'text': 'Five',
'value': 5,
'other_divisors': [],
'in_spanish': 'Cinco',
},
{
'parity': 'even',
'text': 'Six',
'value': 6,
'other_divisors': ['Two', 'Three'],
'in_spanish': 'Seis',
},
{
"score": 0,
"text": "a"
'parity': 'odd',
'text': 'Seven',
'value': 7,
'other_divisors': [],
'in_spanish': 'Siete',
},
{
"score": 0,
"text": "b"
'parity': 'even',
'text': 'Eight',
'value': 8,
'other_divisors': ['Two', 'Four'],
'in_spanish': 'Ocho',
},
{
'parity': 'odd',
'text': 'Nine',
'value': 9,
'other_divisors': ['Three'],
'in_spanish': 'Nueve',
},
{
'parity': 'even',
'text': 'Ten',
'value': 10,
'other_divisors': ['Two', 'Five'],
'in_spanish': 'Diez',
},
]
dset = lit_dataset.Dataset(spec, datapoints)
remap_dict = {"score": "val", "nothing": "nada"}
remapped_dset = dset.remap(remap_dict)
self.assertIn("val", remapped_dset.spec())
self.assertNotIn("score", remapped_dset.spec())
self.assertEqual({"val": 0, "text": "a"}, remapped_dset.examples[0])
# Add embeddings
rand = np.random.RandomState(42)
for ex in self.sample_examples:
vec = rand.normal(0, 1, size=16)
# Scale such that norm = value, for testing
vec = ex['value'] * vec / np.linalg.norm(vec)
# Convert to regular list to avoid issues with assertEqual not correctly
# handling NumPy array equality.
ex['embedding'] = vec.tolist()

# Index data
self.indexed_dataset = lit_dataset.IndexedDataset(
spec=self.data_spec,
examples=self.sample_examples,
)

def test_load_lit_format_unindexed(self):
ds = lit_dataset.load_lit_format(
get_testdata_path('count_examples.lit.jsonl')
)
self.assertEqual(self.data_spec, ds.spec())
self.assertEqual(self.sample_examples, ds.examples)

def test_load_lit_format_indexed(self):
ds = lit_dataset.load_lit_format(
get_testdata_path('count_examples.indexed.lit.jsonl'),
)
self.assertIsInstance(ds, lit_dataset.IndexedDataset)
self.assertEqual(self.data_spec, ds.spec())
self.assertEqual(self.sample_examples, ds.examples)
self.assertEqual(self.indexed_dataset.indexed_examples, ds.indexed_examples)

def test_indexed_dataset_load(self):
ds = self.indexed_dataset.load(
get_testdata_path('count_examples.indexed.lit.jsonl')
)
self.assertIsInstance(ds, lit_dataset.IndexedDataset)
self.assertEqual(self.data_spec, ds.spec())
self.assertEqual(self.sample_examples, ds.examples)
self.assertEqual(self.indexed_dataset.indexed_examples, ds.indexed_examples)

def test_write_roundtrip(self):
tempdir = self.create_tempdir()
output_base = os.path.join(tempdir.full_path, 'test_dataset.lit.jsonl')
lit_dataset.write_examples(self.sample_examples, output_base)
lit_dataset.write_spec(self.data_spec, output_base + '.spec')

# Read back and compare contents
ds = lit_dataset.load_lit_format(output_base)
self.assertEqual(self.data_spec, ds.spec())
self.assertEqual(self.sample_examples, ds.examples)

def test_write_roundtrip_indexed(self):
tempdir = self.create_tempdir()
output_base = os.path.join(
tempdir.full_path, 'test_dataset.indexed.lit.jsonl'
)
lit_dataset.write_examples(
self.indexed_dataset.indexed_examples, output_base
)
lit_dataset.write_spec(self.data_spec, output_base + '.spec')

# Read back and compare contents
ds = lit_dataset.load_lit_format(output_base)
self.assertIsInstance(ds, lit_dataset.IndexedDataset)
self.assertEqual(self.data_spec, ds.spec())
self.assertEqual(self.sample_examples, ds.examples)
self.assertEqual(self.indexed_dataset.indexed_examples, ds.indexed_examples)


if __name__ == "__main__":
if __name__ == '__main__':
absltest.main()
10 changes: 10 additions & 0 deletions lit_nlp/api/testdata/count_examples.indexed.lit.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{"data": {"parity": "odd", "text": "One", "value": 1, "other_divisors": [], "in_spanish": "Uno", "embedding": [0.13202541280516308, -0.03675031469846249, 0.17215403687114747, 0.40481762858549075, -0.06223739704246233, -0.062233033238054764, 0.4197509223576569, 0.20398228297541593, -0.12478514902133428, 0.14421113892521678, -0.12317529473213541, -0.1237898348537082, 0.06431298281540655, -0.5085452318633588, -0.45847896121327225, -0.14945465659311677]}, "id": "557d1c5676d57adc05858fc860e90c0d", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "even", "text": "Two", "value": 2, "other_divisors": [], "in_spanish": "Dos", "embedding": [-0.5377313370378919, 0.16683989554078127, -0.4820872802875518, -0.7498189405671427, 0.7781408532153896, -0.11986893914950662, 0.03585201033575663, -0.7564259546664667, -0.28902315938179374, 0.05889091604205179, -0.6110844176124447, 0.19946523529095225, -0.3188905231309265, -0.15486576876720004, -0.31945750343040386, 0.9834097755539289]}, "id": "c92361a5f1cbecf7843c84e061e1dd6d", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "odd", "text": "Three", "value": 3, "other_divisors": [], "in_spanish": "Tres", "embedding": [-0.010945545265540037, -0.857748394616669, 0.6670410208894, -0.9900405225786026, 0.16937748150720353, -1.5891902567324316, -1.0770896092748838, 0.1596441942749694, 0.5988578789616861, 0.1389707377080942, -0.09378472495940747, -0.24417939196279864, -1.199004216487564, -0.5837561069111461, -0.37355401695525714, 0.8572709874029102]}, "id": "c60f8279359ca0a4624c3285b664fecd", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "even", "text": "Four", "value": 4, "other_divisors": ["Two"], "in_spanish": "Cuatro", "embedding": [0.41399411719030604, -2.124125155291056, 0.39045900895911695, -0.4639502714677827, -0.815561145698349, 0.7369525805012918, 1.242156631637323, 1.1220138813175107, -1.0110961150918898, -0.37254159179247054, 0.3991088833297577, 1.1753447239072536, -0.5773130290190065, -0.22368344897741224, -1.332921397913926, -1.4411994947515943]}, "id": "542cd86e4bb692de854432a32959ce03", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "odd", "text": "Five", "value": 5, "other_divisors": [], "in_spanish": "Cinco", "embedding": [0.8951305045189742, 1.4941209098391575, -0.07933096362493794, 1.1055561367201296, 0.3984014152304394, -0.7107052545275087, 0.39813655366234274, 1.6943996236976022, -0.03946826005961582, 1.723711698040927, -2.886079053863323, 0.9054604581537493, 0.09589662748112579, -0.32940565468889504, 0.10108955054898887, -2.189633259962756]}, "id": "bb164ff77c3ac345ac3ac3ba80fb2c56", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "even", "text": "Six", "value": 6, "other_divisors": ["Two", "Three"], "in_spanish": "Seis", "embedding": [-0.44260967832717035, 0.7195344017816173, 2.977760214035824, -1.0442456558953457, -1.629007229959361, -1.0109738020898464, 1.8444136882723055, 0.662389822969991, -1.0673964506250544, 1.0341657082409859, 0.19559836860649446, 1.9516910055879046, -1.4145437411613457, -0.6601957070305794, -0.7900458508175529, -2.948788244123915]}, "id": "a65ff9cc3a631bf34842e788b1764207", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "odd", "text": "Seven", "value": 7, "other_divisors": [], "in_spanish": "Siete", "embedding": [0.634163008209849, 0.55906876204407, 0.010950837541011992, -0.5023853268787337, -3.0311188967149296, -0.900842407411736, -0.7339479447751722, -1.7181348454466505, -0.34540502631326997, 0.8653041573758785, 4.03940364034519, 0.37387102311805503, 0.5515628047003093, -0.15943131743360547, -4.109187449176784, -0.05678138350015767]}, "id": "00fc1466e4e6c07b2420fd375b02279f", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "even", "text": "Eight", "value": 8, "other_divisors": ["Two", "Four"], "in_spanish": "Ocho", "embedding": [0.10448270187299902, 4.273041577172203, -0.3336928985455562, 0.5231009671141299, -0.06021528879192428, -2.0273321163861473, 1.9824804784766397, 1.3043951691285334, 1.372220936567237, -1.5775349018274538, 2.43345888918842, -2.43182261538663, 1.0180342195988585, 3.7998327142081547, -1.718305675081331, -0.9823694274248165]}, "id": "5e67dc379b6d591777e2c017aa5fc8f4", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "odd", "text": "Nine", "value": 9, "other_divisors": ["Three"], "in_spanish": "Nueve", "embedding": [0.2351581687891107, -1.1881062818129788, -3.659269218610787, 0.16179551166147463, -2.5068336575018373, 1.117587587913186, -2.169665403706205, 3.6575488565220597, -1.848328015996989, -0.7600036017746825, 1.9197450927723436, -2.9046044519902594, 0.5367619565206706, 3.0846069812758867, -3.793353091211151, 0.4357006051120228]}, "id": "8ca0ab896678a1d41faa13c8e374acf8", "meta": {"added": null, "parentId": null, "source": null}}
{"data": {"parity": "even", "text": "Ten", "value": 10, "other_divisors": ["Two", "Five"], "in_spanish": "Diez", "embedding": [0.7777890241097059, 2.3398749818160023, -3.7020022393864433, -3.951922494292855, 1.5620904113236476, 0.8888292117898868, 0.7496863737379172, 1.0368659285370994, -2.035208856733374, 0.6950994081126011, 0.8771206022300255, -2.137943352136647, 5.583974654531584, 1.4181086755390935, -3.565387186125148, 1.9649634444320656]}, "id": "03160c8ecd35e78125dce3758121a6d5", "meta": {"added": null, "parentId": null, "source": null}}

0 comments on commit b7ce560

Please sign in to comment.