Skip to content

Commit

Permalink
Adds init spec support to GLUE demo
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 506287107
  • Loading branch information
RyanMullins authored and LIT team committed Feb 1, 2023
1 parent f3b0d6e commit 0f133cf
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 6 deletions.
40 changes: 40 additions & 0 deletions lit_nlp/examples/datasets/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,26 @@ class SST2Data(lit_dataset.Dataset):
"""

LABELS = ['0', '1']
AVAILABLE_SPLITS = ['test', 'train', 'validation']

def __init__(self, split: str):
  """Loads the requested 'glue/sst2' split into memory.

  Args:
    split: name of the split to load; must be one of AVAILABLE_SPLITS.

  Raises:
    ValueError: if `split` is not listed in AVAILABLE_SPLITS.
  """
  allowed_splits = self.AVAILABLE_SPLITS
  if split not in allowed_splits:
    message = (
        f"Unsupported split '{split}'. Allowed values: "
        f'{allowed_splits}'
    )
    raise ValueError(message)

  # Each record's sentence arrives as bytes and its label as an index into
  # LABELS; convert both to strings.
  self._examples = [
      {
          'sentence': record['sentence'].decode('utf-8'),
          'label': self.LABELS[record['label']],
      }
      for record in load_tfds('glue/sst2', split=split)
  ]

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Describes the `split` argument accepted by `__init__`.

  Returns:
    A spec with a single 'split' entry: a CategoryLabel whose vocab is
    AVAILABLE_SPLITS.
  """
  split_field = lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS)
  return {'split': split_field}

def spec(self):
return {
'sentence': lit_types.TextSegment(),
Expand Down Expand Up @@ -129,7 +140,15 @@ class STSBData(lit_dataset.Dataset):
See https://www.tensorflow.org/datasets/catalog/glue#gluestsb.
"""

AVAILABLE_SPLITS = ['test', 'train', 'validation']

def __init__(self, split: str):
if split not in self.AVAILABLE_SPLITS:
raise ValueError(
f"Unsupported split '{split}'. Allowed values: "
f'{self.AVAILABLE_SPLITS}'
)

self._examples = []
for ex in load_tfds('glue/stsb', split=split):
self._examples.append({
Expand All @@ -138,6 +157,10 @@ def __init__(self, split: str):
'label': ex['label'],
})

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Describes the `split` argument accepted by `__init__`.

  Returns:
    A spec with a single 'split' entry: a CategoryLabel whose vocab is
    AVAILABLE_SPLITS.
  """
  split_field = lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS)
  return {'split': split_field}

def spec(self):
return {
'sentence1': lit_types.TextSegment(),
Expand All @@ -153,8 +176,21 @@ class MNLIData(lit_dataset.Dataset):
"""

LABELS = ['entailment', 'neutral', 'contradiction']
AVAILABLE_SPLITS = [
'test_matched',
'test_mismatched',
'train',
'validation_matched',
'validation_mismatched',
]

def __init__(self, split: str):
if split not in self.AVAILABLE_SPLITS:
raise ValueError(
f"Unsupported split '{split}'. Allowed values: "
f'{self.AVAILABLE_SPLITS}'
)

self._examples = []
for ex in load_tfds('glue/mnli', split=split):
self._examples.append({
Expand All @@ -163,6 +199,10 @@ def __init__(self, split: str):
'label': self.LABELS[ex['label']],
})

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Describes the `split` argument accepted by `__init__`.

  Returns:
    A spec with a single 'split' entry: a CategoryLabel whose vocab is
    AVAILABLE_SPLITS (the matched/mismatched test and validation splits,
    plus train).
  """
  split_field = lit_types.CategoryLabel(vocab=cls.AVAILABLE_SPLITS)
  return {'split': split_field}

def spec(self):
return {
'premise': lit_types.TextSegment(),
Expand Down
33 changes: 28 additions & 5 deletions lit_nlp/examples/glue_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from absl import app
from absl import flags
from absl import logging
from lit_nlp import app as lit_app
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.datasets import glue
Expand Down Expand Up @@ -87,7 +88,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
logging.info("Quick-start mode; overriding --models and --max_examples.")

models = {}
model_loaders: lit_app.ModelLoadersMap = {}

datasets = {}
dataset_loaders: lit_app.DatasetLoadersMap = {}

tasks_to_load = set()
for model_string in _MODELS.value:
Expand All @@ -103,30 +107,49 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# Load the model from disk.
models[name] = MODELS_BY_TASK[task](path)
tasks_to_load.add(task)
if task not in model_loaders:
# Adds the model loader info. Since task-specific GLUE models set specific
# __init__() values, we use the GlueModelConfig.init_spec() here because
# it is limited to only those parameters that will not override or
# interfere with the parameters set by task-specific model subclasses.
model_loaders[task] = (
MODELS_BY_TASK[task],
glue_models.GlueModelConfig.init_spec(),
)

##
# Load datasets for each task that we have a model for
if "sst2" in tasks_to_load:
logging.info("Loading data for SST-2 task.")
datasets["sst_dev"] = glue.SST2Data("validation")
dataset_loaders["sst2"] = (glue.SST2Data, glue.SST2Data.init_spec())

if "stsb" in tasks_to_load:
logging.info("Loading data for STS-B task.")
datasets["stsb_dev"] = glue.STSBData("validation")
dataset_loaders["stsb"] = (glue.STSBData, glue.STSBData.init_spec())

if "mnli" in tasks_to_load:
logging.info("Loading data for MultiNLI task.")
datasets["mnli_dev"] = glue.MNLIData("validation_matched")
datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
dataset_loaders["mnli"] = (glue.MNLIData, glue.MNLIData.init_spec())

# Truncate datasets if --max_examples is set.
for name in datasets:
logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
datasets[name] = datasets[name].slice[:_MAX_EXAMPLES.value]
logging.info(" truncated to %d examples", len(datasets[name]))
if _MAX_EXAMPLES.value is not None:
for name in datasets:
logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
datasets[name] = datasets[name].slice[: _MAX_EXAMPLES.value]
logging.info(" truncated to %d examples", len(datasets[name]))

# Start the LIT server. See server_flags.py for server options.
lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
lit_demo = dev_server.Server(
models,
datasets,
model_loaders=model_loaders,
dataset_loaders=dataset_loaders,
**server_flags.get_flags(),
)
return lit_demo.serve()


Expand Down
17 changes: 16 additions & 1 deletion lit_nlp/examples/models/glue_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

JsonDict = lit_types.JsonDict
Spec = lit_types.Spec
TFSequenceClassifierOutput = transformers.modeling_tf_outputs.TFSequenceClassifierOutput
TFSequenceClassifierOutput = (
transformers.modeling_tf_outputs.TFSequenceClassifierOutput
)


@attr.s(auto_attribs=True, kw_only=True)
Expand All @@ -38,6 +40,19 @@ class GlueModelConfig(object):
output_attention: bool = True
output_embeddings: bool = True

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Returns the spec of init parameters exposed for GLUE models.

  Every entry is optional (`required=False`) and carries a default value.

  Returns:
    A spec keyed by parameter name, in declaration order.
  """
  init_params = {}
  init_params["model_name_or_path"] = lit_types.String(
      default="bert-base-uncased", required=False
  )
  init_params["max_seq_length"] = lit_types.Integer(
      default=128, required=False
  )
  init_params["inference_batch_size"] = lit_types.Integer(
      default=32, required=False
  )
  # The three output toggles share the same shape: optional, on by default.
  for toggle in ("compute_grads", "output_attention", "output_embeddings"):
    init_params[toggle] = lit_types.Boolean(default=True, required=False)
  return init_params


class GlueModel(lit_model.Model):
"""GLUE benchmark model, using Keras/TF2 and Huggingface Transformers.
Expand Down

0 comments on commit 0f133cf

Please sign in to comment.