Skip to content

Commit

Permalink
LIT blank demo.
Browse files Browse the repository at this point in the history
This allows launching a server using the binary `examples:blank_slate_demo` with no model or dataset pre-loaded. You can then load models and datasets through the UI configuration. The supported models/datasets include GLUE, Penguin, T5, LM, Image, and CoRef.

Also added the ability to configure the maximum number of examples when loading the dataset for different types of datasets.

PiperOrigin-RevId: 550916451
  • Loading branch information
bdu91 authored and LIT team committed Jul 25, 2023
1 parent 8f6a93a commit 22b0dea
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 17 deletions.
175 changes: 175 additions & 0 deletions lit_nlp/examples/blank_slate_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
r"""A blank demo ready to load models and datasets.
The currently supported models and datasets are:
- classification model on SST-2, with the Stanford Sentiment Treebank dataset.
- regression model on STS-B, with Semantic Textual Similarity Benchmark dataset.
- classification model on MultiNLI, with the MultiNLI dataset.
- TensorFlow Keras model for penguin classification, with the Penguin tabular
dataset from TFDS.
- T5 models using HuggingFace Transformers and Keras, with the English CNNDM
summarization dataset and the WMT '14 machine-translation dataset.
- BERT (bert-*) as a masked language model and GPT-2 (gpt2* or distilgpt2) as a
left-to-right language model, with the Stanford Sentiment Treebank dataset,
the IMDB reviews dataset, Billion Word Benchmark (lm1b) dataset and the option
to load sentences from a flat text file.
- MobileNet model, with the Imagenette TFDS dataset.
To run:
python -m lit_nlp.examples.blank_slate_demo --port=5432
Then navigate to localhost:5432 to access the demo UI.
"""
import sys
from typing import Optional, Sequence

from absl import app
from absl import flags
from absl import logging
from lit_nlp import app as lit_app
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.datasets import classification
from lit_nlp.examples.datasets import glue
from lit_nlp.examples.datasets import imagenette
from lit_nlp.examples.datasets import lm
from lit_nlp.examples.datasets import mt
from lit_nlp.examples.datasets import penguin_data
from lit_nlp.examples.datasets import summarization
from lit_nlp.examples.models import glue_models
from lit_nlp.examples.models import mobilenet
from lit_nlp.examples.models import penguin_model
from lit_nlp.examples.models import pretrained_lms
from lit_nlp.examples.models import t5

# NOTE: additional flags defined in server_flags.py

FLAGS = flags.FLAGS

FLAGS.set_default("development_demo", True)


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
  """Build the WSGI app for container-hosted demos.

  Overrides the defaults of the `server_type` and `demo_mode` flags, parses
  `sys.argv` directly, and then delegates to `main` with an empty argument
  list.

  Returns:
    Whatever `main([])` returns.
  """
  container_defaults = (("server_type", "external"), ("demo_mode", True))
  for flag_name, default_value in container_defaults:
    FLAGS.set_default(flag_name, default_value)

  # Flags are parsed here rather than through app.run(main) so that they do
  # not conflict with gunicorn's own command-line flags; known_only=True lets
  # unrecognized arguments pass through instead of raising.
  leftover_args = flags.FLAGS(sys.argv, known_only=True)
  if leftover_args:
    logging.info(
        "blank_slate_demo:get_wsgi_app() called with unused args: %s",
        leftover_args,
    )
  return main([])


def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
  """Start a LIT server with no models or datasets pre-loaded.

  Only loaders are registered: each entry maps a display name to a
  `(class, init_spec)` pair that is handed to `dev_server.Server`.

  Args:
    argv: Command-line arguments; at most one element is accepted.

  Returns:
    The result of calling `serve()` on the constructed server.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Nothing is loaded up front; everything comes from the loaders below.
  preloaded_models = {}

  model_loaders: lit_app.ModelLoadersMap = {
      # glue demo model loaders.
      "sst2": (
          glue_models.SST2Model,
          glue_models.GlueModelConfig.init_spec(),
      ),
      "stsb": (
          glue_models.STSBModel,
          glue_models.GlueModelConfig.init_spec(),
      ),
      "mnli": (
          glue_models.MNLIModel,
          glue_models.GlueModelConfig.init_spec(),
      ),
      # penguin demo model loaders.
      "penguin": (
          penguin_model.PenguinModel,
          penguin_model.PenguinModel.init_spec(),
      ),
      # t5 demo model loaders.
      "T5 summarization": (
          t5.T5Summarization,
          t5.T5Summarization.init_spec(),
      ),
      "T5 translation": (
          t5.T5Translation,
          t5.T5Translation.init_spec(),
      ),
      # lm demo model loaders.
      "bert": (
          pretrained_lms.BertMLM,
          pretrained_lms.BertMLM.init_spec(),
      ),
      "gpt2": (
          pretrained_lms.GPT2LanguageModel,
          pretrained_lms.GPT2LanguageModel.init_spec(),
      ),
      # image model loaders.
      "image": (
          mobilenet.MobileNet,
          mobilenet.MobileNet.init_spec(),
      ),
  }

  preloaded_datasets = {}

  dataset_loaders: lit_app.DatasetLoadersMap = {
      # glue demo dataset loaders.
      "sst2": (glue.SST2Data, glue.SST2Data.init_spec()),
      "stsb": (glue.STSBData, glue.STSBData.init_spec()),
      "mnli": (glue.MNLIData, glue.MNLIData.init_spec()),
      # penguin demo dataset loaders.
      "penguin": (
          penguin_data.PenguinDataset,
          penguin_data.PenguinDataset.init_spec(),
      ),
      # t5 demo dataset loaders.
      "CNN DailyMail (t5)": (
          summarization.CNNDMData,
          summarization.CNNDMData.init_spec(),
      ),
      "WMT 14 (t5)": (mt.WMT14Data, mt.WMT14Data.init_spec()),
      # lm demo dataset loaders.
      "sst (lm)": (
          glue.SST2DataForLM,
          glue.SST2DataForLM.init_spec(),
      ),
      "imdb (lm)": (
          classification.IMDBData,
          classification.IMDBData.init_spec(),
      ),
      "plain text sentences (lm)": (
          lm.PlaintextSents,
          lm.PlaintextSents.init_spec(),
      ),
      "bwb (lm)": (
          lm.BillionWordBenchmark,
          lm.BillionWordBenchmark.init_spec(),
      ),
      # image demo dataset loaders.
      "image": (
          imagenette.ImagenetteDataset,
          imagenette.ImagenetteDataset.init_spec(),
      ),
  }

  # Start the LIT server. See server_flags.py for server options.
  server = dev_server.Server(
      preloaded_models,
      preloaded_datasets,
      model_loaders=model_loaders,
      dataset_loaders=dataset_loaders,
      **server_flags.get_flags(),
  )
  return server.serve()


if __name__ == "__main__":
app.run(main)
21 changes: 21 additions & 0 deletions lit_nlp/examples/datasets/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils
import pandas as pd
import tensorflow_datasets as tfds

Expand Down Expand Up @@ -107,6 +108,26 @@ def spec(self):
}


class SST2DataForLM(SST2Data):
  """SST-2 (binary Stanford Sentiment Treebank) reformatted for LMs.

  See https://www.tensorflow.org/datasets/catalog/glue#gluesst2.

  The examples are those loaded by `SST2Data`, each passed through
  `utils.remap_dict` with the mapping 'sentence' -> 'text' so that the data
  can serve the language models.
  """

  def __init__(self, path_or_splitname: str, max_examples: int = -1):
    super().__init__(path_or_splitname, max_examples)
    # Rebuild the example list with the remapped field name.
    remapped_examples = []
    for example in self._examples:
      remapped_examples.append(utils.remap_dict(example, {'sentence': 'text'}))
    self._examples = remapped_examples

  def spec(self):
    """Returns the spec with a 'text' segment and a categorical 'label'."""
    text_field = lit_types.TextSegment()
    label_field = lit_types.CategoryLabel(vocab=self.LABELS)
    return {'text': text_field, 'label': label_field}


class MRPCData(lit_dataset.Dataset):
"""Microsoft Research Paraphrase Corpus.
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/lm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
}

dataset_loaders: lit_app.DatasetLoadersMap = {
"sst_dev": (glue.SST2Data, glue.SST2Data.init_spec()),
"sst_dev": (glue.SST2DataForLM, glue.SST2DataForLM.init_spec()),
"imdb_train": (
classification.IMDBData,
classification.IMDBData.init_spec(),
Expand Down
7 changes: 5 additions & 2 deletions lit_nlp/examples/models/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class MobileNet(model.Model):
"""MobileNet model trained on ImageNet dataset."""

def __init__(self) -> None:
def __init__(self, name='mobilenet_v2') -> None:
# Initialize imagenet labels.
self.labels = [''] * len(imagenet_labels.IMAGENET_2012_LABELS)
self.label_to_idx = {}
Expand All @@ -25,7 +25,10 @@ def __init__(self) -> None:
self.labels[i] = l
self.label_to_idx[l] = i

self.model = tf.keras.applications.mobilenet_v2.MobileNetV2()
if name == 'mobilenet_v2':
self.model = tf.keras.applications.mobilenet_v2.MobileNetV2()
elif name == 'mobilenet':
self.model = tf.keras.applications.mobilenet.MobileNet()

def predict_minibatch(
self, input_batch: List[lit_types.JsonDict]) -> List[lit_types.JsonDict]:
Expand Down
62 changes: 50 additions & 12 deletions lit_nlp/examples/models/mobilenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,83 @@

class MobileNetTest(parameterized.TestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = mobilenet.MobileNet()

@parameterized.named_parameters(
dict(
testcase_name='compatible',
testcase_name='compatible_spec_model_v2',
model_name='mobilenet_v2',
dataset_spec={'image': types.ImageBytes()},
expected=True,
),
dict(
testcase_name='empty',
testcase_name='empty_spec_model_v2',
model_name='mobilenet_v2',
dataset_spec={},
expected=False,
),
dict(
testcase_name='no_images',
testcase_name='no_images_spec_model_v2',
model_name='mobilenet_v2',
dataset_spec={'text': types.TextSegment()},
expected=False,
),
dict(
testcase_name='wrong_keys',
testcase_name='wrong_keys_spec_model_v2',
model_name='mobilenet_v2',
dataset_spec={'wrong_image_key': types.ImageBytes()},
expected=False,
),
dict(
testcase_name='compatible_spec_model_v1',
model_name='mobilenet',
dataset_spec={'image': types.ImageBytes()},
expected=True,
),
dict(
testcase_name='empty_spec_model_v1',
model_name='mobilenet',
dataset_spec={},
expected=False,
),
dict(
testcase_name='no_images_spec_model_v1',
model_name='mobilenet',
dataset_spec={'text': types.TextSegment()},
expected=False,
),
dict(
testcase_name='wrong_keys_spec_model_v1',
model_name='mobilenet',
dataset_spec={'wrong_image_key': types.ImageBytes()},
expected=False,
),
)
def test_compatibility(self, dataset_spec: types.Spec, expected: bool):
def test_compatibility(
self, model_name: str, dataset_spec: types.Spec, expected: bool
):
dataset = lit_dataset.Dataset(spec=dataset_spec)
self.assertEqual(self.model.is_compatible_with_dataset(dataset), expected)
model = mobilenet.MobileNet(model_name)
self.assertEqual(model.is_compatible_with_dataset(dataset), expected)

def test_model(self):
@parameterized.named_parameters(
dict(
testcase_name='model_v1',
model_name='mobilenet',
),
dict(
testcase_name='model_v2',
model_name='mobilenet_v2',
),
)
def test_model(self, model_name: str):
# Create an input with base64 encoded image.
input_1 = {
'image': np.zeros(shape=(mobilenet.IMAGE_SHAPE), dtype=np.float32)
}
# Create an input with image data in Numpy array.
pil_image = PILImage.new(mode='RGB', size=(300, 200))
input_2 = {'image': image_utils.convert_pil_to_image_str(pil_image)}
model_out = self.model.predict([input_1, input_2])
model = mobilenet.MobileNet(model_name)
model_out = model.predict([input_1, input_2])
model_out = list(model_out)
# Check first output.
self.assertIn('preds', model_out[0])
Expand Down
44 changes: 44 additions & 0 deletions lit_nlp/examples/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,47 @@ def output_spec(self):
spec = lit_types.remap_spec(self.wrapped.output_spec(), self.FIELD_RENAMES)
spec["rougeL"] = lit_types.Scalar()
return spec


class T5Translation(TranslationWrapper):
  """T5 translation model.

  TranslationWrapper has an input spec compatible with the corresponding
  dataset, but its init args are not supported by the front-end system, so
  this class adds a layer of init args on top that works with the front-end.
  """

  def __init__(
      self, model_name="t5-small", model=None, tokenizer=None, **config_kw
  ):
    # Construct the underlying T5HFModel, then hand it to the wrapper.
    wrapped_model = T5HFModel(model_name, model, tokenizer, **config_kw)
    super().__init__(wrapped_model)

  @classmethod
  def init_spec(cls) -> lit_types.Spec:
    """Returns the `model_name` field merged with T5ModelConfig's spec."""
    spec = {
        "model_name": lit_types.String(default="t5-small", required=False),
    }
    spec.update(T5ModelConfig.init_spec())
    return spec


class T5Summarization(SummarizationWrapper):
  """T5 summarization model.

  SummarizationWrapper has an input spec compatible with the corresponding
  dataset, but its init args are not supported by the front-end system, so
  this class adds a layer of init args on top that works with the front-end.
  """

  def __init__(
      self, model_name="t5-small", model=None, tokenizer=None, **config_kw
  ):
    # Construct the underlying T5HFModel, then hand it to the wrapper.
    wrapped_model = T5HFModel(model_name, model, tokenizer, **config_kw)
    super().__init__(wrapped_model)

  @classmethod
  def init_spec(cls) -> lit_types.Spec:
    """Returns the `model_name` field merged with T5ModelConfig's spec."""
    spec = {
        "model_name": lit_types.String(default="t5-small", required=False),
    }
    spec.update(T5ModelConfig.init_spec())
    return spec
4 changes: 2 additions & 2 deletions lit_nlp/examples/t5_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
)

model_loaders: lit_app.ModelLoadersMap = {
"T5 Saved Model": (t5.T5SavedModel, t5.T5SavedModel.init_spec()),
"T5 HF Model": (t5.T5HFModel, t5.T5HFModel.init_spec()),
"T5 summarization": (t5.T5Summarization, t5.T5Summarization.init_spec()),
"T5 translation": (t5.T5Translation, t5.T5Translation.init_spec()),
}

##
Expand Down

0 comments on commit 22b0dea

Please sign in to comment.