Adds init spec support to LM demo
PiperOrigin-RevId: 506287454
RyanMullins authored and LIT team committed Feb 1, 2023
1 parent 0f133cf commit 9eebe57
Showing 3 changed files with 61 additions and 12 deletions.
16 changes: 13 additions & 3 deletions lit_nlp/examples/datasets/classification.py
@@ -149,6 +149,7 @@ class IMDBData(lit_dataset.Dataset):
   """IMDB reviews dataset; see http://ai.stanford.edu/~amaas/data/sentiment/."""
 
   LABELS = ["0", "1"]
+  AVAILABLE_SPLITS = ["test", "train", "unsupervised"]
 
   def __init__(self, split="test", max_seq_len=500):
     """Dataset constructor, loads the data into memory."""
@@ -157,14 +158,23 @@ def __init__(self, split="test", max_seq_len=500):
     for record in raw_examples:
       # format and truncate from the end to max_seq_len tokens.
       truncated_text = " ".join(
-          record["text"].decode("utf-8")\
-          .replace("<br />", "")\
-          .split()[-max_seq_len:])
+          record["text"]
+          .decode("utf-8")
+          .replace("<br />", "")
+          .split()[-max_seq_len:]
+      )
       self._examples.append({
           "text": truncated_text,
           "label": self.LABELS[record["label"]],
       })
 
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {
+        "split": lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS),
+        "max_seq_len": lit_types.Integer(default=500),
+    }
+
   def spec(self) -> lit_types.Spec:
     """Dataset spec, which should match the model's input_spec()."""
     return {
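The new init_spec() classmethod advertises IMDBData's constructor arguments as a LIT spec, so a caller can validate user-supplied arguments before building the dataset. A minimal sketch of that pattern (not part of this commit; the chosen split is hypothetical user input):

from lit_nlp.examples.datasets import classification

# Read the constructor spec added above.
spec = classification.IMDBData.init_spec()

# The CategoryLabel lists the valid TFDS splits; the Integer carries a default.
desired_split = "train"  # hypothetical user input
assert desired_split in spec["split"].vocab

# Construct the dataset with the validated arguments.
ds = classification.IMDBData(
    split=desired_split,
    max_seq_len=spec["max_seq_len"].default,
)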
15 changes: 12 additions & 3 deletions lit_nlp/examples/datasets/lm.py
@@ -10,10 +10,10 @@
 class PlaintextSents(lit_dataset.Dataset):
   """Load sentences from a flat text file."""
 
-  def __init__(self, path_or_glob, skiplines=0):
+  def __init__(self, path_or_glob: str, skiplines: int = 0):
     self._examples = self.load_datapoints(path_or_glob, skiplines=skiplines)
 
-  def load_datapoints(self, path_or_glob: str, skiplines=0):
+  def load_datapoints(self, path_or_glob: str, skiplines: int = 0):
     examples = []
     for path in glob.glob(path_or_glob):
       with open(path) as fd:
@@ -36,7 +36,9 @@ def spec(self) -> lit_types.Spec:
 class BillionWordBenchmark(lit_dataset.Dataset):
   """Billion Word Benchmark (lm1b); see http://www.statmt.org/lm-benchmark/."""
 
-  def __init__(self, split='train', max_examples=1000):
+  AVAILABLE_SPLITS = ['test', 'train']
+
+  def __init__(self, split: str = 'train', max_examples: int = 1000):
     ds = tfds.load('lm1b', split=split)
     if max_examples is not None:
       # Normally we can just slice the resulting dataset, but lm1b is very large
@@ -47,5 +49,12 @@ def __init__(self, split='train', max_examples=1000):
         'text': ex['text'].decode('utf-8')
     } for ex in raw_examples]
 
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {
+        'split': lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS),
+        'max_examples': lit_types.Integer(default=1000),
+    }
+
   def spec(self) -> lit_types.Spec:
     return {'text': lit_types.TextSegment()}
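BillionWordBenchmark gets an explicit init_spec() mirroring its constructor defaults. PlaintextSents only gains type annotations here, yet lm_demo.py below calls lm.PlaintextSents.init_spec(); this suggests the base Dataset class can derive an init spec from an annotated constructor signature, which would explain why the annotations were added. A quick sanity check of the explicit spec (a sketch, not part of the commit; init_spec() only builds the dict, so nothing is downloaded):

from lit_nlp.examples.datasets import lm

spec = lm.BillionWordBenchmark.init_spec()
assert list(spec['split'].vocab) == ['test', 'train']
assert spec['max_examples'].default == 1000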
42 changes: 36 additions & 6 deletions lit_nlp/examples/lm_demo.py
@@ -18,6 +18,7 @@
 from absl import flags
 from absl import logging
 
+from lit_nlp import app as lit_app
 from lit_nlp import dev_server
 from lit_nlp import server_flags
 from lit_nlp.api import layout
@@ -41,13 +42,22 @@
     "top_k", 10, "Rank to which the output distribution is pruned.")
 
 _MAX_EXAMPLES = flags.DEFINE_integer(
-    "max_examples", 1000,
-    "Maximum number of examples to load from each evaluation set. Set to None to load the full set."
+    "max_examples",
+    1000,
+    (
+        "Maximum number of examples to load from each evaluation set. Set to"
+        " None to load the full set."
+    ),
 )
 
 _LOAD_BWB = flags.DEFINE_bool(
-    "load_bwb", False,
-    "If true, will load examples from the Billion Word Benchmark dataset. This may download a lot of data the first time you run it, so disable by default for the quick-start example."
+    "load_bwb",
+    False,
+    (
+        "If true, will load examples from the Billion Word Benchmark dataset."
+        " This may download a lot of data the first time you run it, so disable"
+        " by default for the quick-start example."
+    ),
 )
 
 # Custom frontend layout; see api/layout.py
@@ -106,7 +116,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
         model_name_or_path, top_k=_TOP_K.value)
   else:
     raise ValueError(
-        f"Unsupported model name '{model_name}' from path '{model_name_or_path}'"
+        f"Unsupported model name '{model_name}' from path '"
+        f"{model_name_or_path}'"
     )
 
   datasets = {
@@ -117,12 +128,29 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       # Empty dataset, if you just want to type sentences into the UI.
       "blank": lm.PlaintextSents(""),
   }
 
+  dataset_loaders: lit_app.DatasetLoadersMap = {
+      "sst_dev": (glue.SST2Data, glue.SST2Data.init_spec()),
+      "imdb_train": (
+          classification.IMDBData,
+          classification.IMDBData.init_spec(),
+      ),
+      "plain_text_sentences": (
+          lm.PlaintextSents,
+          lm.PlaintextSents.init_spec(),
+      ),
+  }
+
   # Guard this with a flag, because TFDS will download and process 1.67 GB
   # of data if you haven't loaded `lm1b` before.
   if _LOAD_BWB.value:
     # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
     datasets["bwb"] = lm.BillionWordBenchmark(
         "train", max_examples=_MAX_EXAMPLES.value)
+    dataset_loaders["bwb"] = (
+        lm.BillionWordBenchmark,
+        lm.BillionWordBenchmark.init_spec(),
+    )
 
   for name in datasets:
     datasets[name] = datasets[name].slice[:_MAX_EXAMPLES.value]
@@ -135,7 +163,9 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       datasets,
       generators=generators,
       layouts=CUSTOM_LAYOUTS,
-      **server_flags.get_flags())
+      dataset_loaders=dataset_loaders,
+      **server_flags.get_flags(),
+  )
   return lit_demo.serve()


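Each dataset_loaders entry pairs a dataset constructor with its init spec, which is what lets the LIT frontend offer a form for loading new datasets at runtime. A minimal, self-contained sketch of how such a (constructor, spec) pair can be consumed generically (not part of this commit; user_args and the file path are hypothetical, and PlaintextSents's spec is assumed to be derived from its annotated constructor as noted earlier):

from lit_nlp.examples.datasets import lm

# Rebuild one (constructor, init spec) pair as in main() above.
dataset_loaders = {
    "plain_text_sentences": (lm.PlaintextSents, lm.PlaintextSents.init_spec()),
}

# Hypothetical arguments collected from a user, keyed by spec field names.
user_args = {"path_or_glob": "/tmp/sentences.txt", "skiplines": 0}

loader_cls, init_spec = dataset_loaders["plain_text_sentences"]

# Keep only the arguments the spec declares, then build a fresh dataset.
kwargs = {k: v for k, v in user_args.items() if k in init_spec}
new_dataset = loader_cls(**kwargs)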
