From 1cef1078f36c029d58bd9f5f0b98c611bca3395b Mon Sep 17 00:00:00 2001
From: Elron Bandel
Date: Wed, 26 Jun 2024 16:02:14 +0300
Subject: [PATCH] Add infer() function for end to end inference pipeline
 (#952)

* Inference and Functions

Signed-off-by: elronbandel

* Update

Signed-off-by: elronbandel

* delete function from PR

Signed-off-by: elronbandel

* Update docs

Signed-off-by: elronbandel

* Update

Signed-off-by: elronbandel

* Update docs

Signed-off-by: elronbandel

* FIx

Signed-off-by: elronbandel

---------

Signed-off-by: elronbandel
---
 docs/docs/production.rst                      | 85 ++++++++++++++++---
 prepare/engines/model/flan.py                 |  8 ++
 src/unitxt/__init__.py                        |  2 +-
 src/unitxt/api.py                             | 13 ++-
 src/unitxt/artifact.py                        |  4 +-
 .../engines/model/flan/t5_small/hf.json       |  5 ++
 src/unitxt/metric_utils.py                    | 78 ++++++++++++-----
 tests/library/test_api.py                     | 25 +++++-
 8 files changed, 181 insertions(+), 39 deletions(-)
 create mode 100644 prepare/engines/model/flan.py
 create mode 100644 src/unitxt/catalog/engines/model/flan/t5_small/hf.json

diff --git a/docs/docs/production.rst b/docs/docs/production.rst
index 38e662f56..a5d5e1aa6 100644
--- a/docs/docs/production.rst
+++ b/docs/docs/production.rst
@@ -4,21 +4,44 @@
 
 To use this tutorial, you need to :ref:`install unitxt `.
 
-=====================================
-Dynamic Data Processing For Inference
-=====================================
+========================
+Inference and Production
+========================
 
-Unitxt can be used to process data dynamically and generate model-ready inputs on the fly, based on a given task recipe.
+In this guide, you will learn how to use unitxt data recipes in production.
 
-First define a recipe:
+For instance, you will learn how to build end-to-end functions like `paraphrase()`:
 
 .. code-block:: python
 
-    recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2"
+    def paraphrase(text):
+        return unitxt.infer(
+            [{"input_text": text, "output_text": ""}],
+            recipe="card=cards.coedit.paraphrase,template=templates.rewriting.paraphrase.default",
+            engine="engines.model.flan.t5_small.hf"
+        )
+
+The function can then be used like this:
+
+.. code-block:: python
+
+    paraphrase("So simple to paraphrase!")
+
+In general, Unitxt is capable of:
+ - Producing processed data according to a given recipe.
+ - Post-processing predictions based on a recipe.
+ - Performing end-to-end inference using a recipe and a specified inference engine.
+
+Produce Data
+------------
+First, define a recipe:
+
+.. code-block:: python
+
+    recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2"
 
-Second, prepare an python dictionary object in the exact schema of the task used in that recipe:
+Next, prepare a Python dictionary that matches the schema required by the recipe:
 
 .. code-block:: python
 
     {
@@ -32,19 +55,25 @@ Second, prepare an python dictionary object in the exact schema of the task used
         "text_a_type": "premise",
         "text_b_type": "hypothesis",
     }
 
-Then you can produce the model-ready input data with the `produce` function:
+Then, produce the model-ready input data with the `produce` function:
 
 .. code-block:: python
 
     from unitxt import produce
 
-    result = produce([instance], recipe)
+    result = produce(instance, recipe)
+
+To view the formatted instance, print the result:
+
+.. code-block:: python
+
+    print(result["source"])
 
-Then you have the formatted instance in the result. If you `print(result[0]["source"])` you will get:
+This will output instances like:
 
 .. code-block::
 
-    Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.
+    Given a premise and a hypothesis, classify the entailment of the hypothesis as either 'entailment' or 'not entailment'.
 
     premise: When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth., hypothesis: mother was careful not to disturb her, undressing and climbing back into her berth.
     The entailment class is entailment
 
@@ -55,6 +84,42 @@ Then you have the formatted instance in the result. If you `print(result[0]["sou
     premise: It works perfectly, hypothesis: It works!
     The entailment class is
 
+Post Process Data
+-----------------
+After obtaining predictions from a model, post-process them with the `post_process` function:
+
+.. code-block:: python
+
+    from unitxt import post_process
+
+    prediction = model.generate(result["source"])
+    processed_result = post_process(predictions=[prediction], data=[result])[0]
+
+End-to-End Inference Pipeline
+-----------------------------
+
+You can also implement an end-to-end inference pipeline with your own data, a recipe, and an inference engine:
+
+.. code-block:: python
+
+    from unitxt import infer
+    from unitxt.inference import HFPipelineBasedInferenceEngine
+
+    engine = HFPipelineBasedInferenceEngine(
+        model_name="google/flan-t5-small", max_new_tokens=32
+    )
+
+    infer(instance, recipe, engine)
+
+Alternatively, you can specify any inference engine from the catalog:
+
+.. code-block:: python
+
+    infer(
+        instance,
+        recipe="card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2",
+        engine="engines.model.flan.t5_small.hf"
+    )
diff --git a/prepare/engines/model/flan.py b/prepare/engines/model/flan.py
new file mode 100644
index 000000000..b5e348e85
--- /dev/null
+++ b/prepare/engines/model/flan.py
@@ -0,0 +1,8 @@
+from unitxt.catalog import add_to_catalog
+from unitxt.inference import HFPipelineBasedInferenceEngine
+
+engine = HFPipelineBasedInferenceEngine(
+    model_name="google/flan-t5-small", max_new_tokens=32
+)
+
+add_to_catalog(engine, "engines.model.flan.t5_small.hf", overwrite=True)
diff --git a/src/unitxt/__init__.py b/src/unitxt/__init__.py
index e80573be9..70c4e999b 100644
--- a/src/unitxt/__init__.py
+++ b/src/unitxt/__init__.py
@@ -1,6 +1,6 @@
 import random
 
-from .api import evaluate, load, load_dataset, produce
+from .api import evaluate, infer, load, load_dataset, post_process, produce
 from .catalog import add_to_catalog, get_from_catalog
 from .logging_utils import get_logger
 from .register import register_all_artifacts, register_local_catalog
diff --git a/src/unitxt/api.py b/src/unitxt/api.py
index bae78b32b..7fbc92f41 100644
--- a/src/unitxt/api.py
+++ b/src/unitxt/api.py
@@ -6,7 +6,7 @@
 from .artifact import fetch_artifact
 from .dataset_utils import get_dataset_artifact
 from .logging_utils import get_logger
-from .metric_utils import _compute
+from .metric_utils import _compute, _post_process
 from .operator import SourceOperator
 from .standard import StandardRecipe
 
@@ -91,6 +91,10 @@ def evaluate(predictions, data) -> List[Dict[str, Any]]:
     return _compute(predictions=predictions, references=data)
 
 
+def post_process(predictions, data) -> List[Any]:
+    return _post_process(predictions=predictions, references=data)
+
+
 @lru_cache
 def _get_produce_with_cache(recipe_query):
     return get_dataset_artifact(recipe_query).produce
@@ -104,3 +108,10 @@ def produce(instance_or_instances, recipe_query):
     if not is_list:
         result = result[0]
     return result
+
+
+def infer(instance_or_instances, recipe, engine):
+    dataset = produce(instance_or_instances, recipe)
+    engine, _ = fetch_artifact(engine)
+    predictions = engine.infer(dataset)
+    return post_process(predictions, dataset)
diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py
index 423d7ab93..2c4d0222c 100644
--- a/src/unitxt/artifact.py
+++ b/src/unitxt/artifact.py
@@ -6,7 +6,7 @@
 import re
 from abc import abstractmethod
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Union, final
+from typing import Any, Dict, List, Optional, Tuple, Union, final
 
 from .dataclass import (
     AbstractField,
@@ -429,7 +429,7 @@ def __str__(self):
         return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
 
 
-def fetch_artifact(artifact_rep):
+def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
     if isinstance(artifact_rep, Artifact):
         return artifact_rep, None
     if Artifact.is_artifact_file(artifact_rep):
diff --git a/src/unitxt/catalog/engines/model/flan/t5_small/hf.json b/src/unitxt/catalog/engines/model/flan/t5_small/hf.json
new file mode 100644
index 000000000..0a177f126
--- /dev/null
+++ b/src/unitxt/catalog/engines/model/flan/t5_small/hf.json
@@ -0,0 +1,5 @@
+{
+    "__type__": "hf_pipeline_based_inference_engine",
+    "model_name": "google/flan-t5-small",
+    "max_new_tokens": 32
+}
diff --git a/src/unitxt/metric_utils.py b/src/unitxt/metric_utils.py
index 5bee6c8a8..a41cc3b2c 100644
--- a/src/unitxt/metric_utils.py
+++ b/src/unitxt/metric_utils.py
@@ -9,6 +9,7 @@
 from .dict_utils import dict_set
 from .operator import (
     MultiStreamOperator,
+    SequentialOperator,
     SequentialOperatorInitializer,
     StreamInitializerOperator,
 )
@@ -146,6 +147,59 @@ def process(
 # When receiving instances from this scheme, the keys and values are returned as two separate
 # lists, and are converted to a dictionary.
 
+_post_process_steps = SequentialOperator( + steps=[ + Copy( + field="prediction", + to_field="raw_prediction", + ), + Copy( + field="references", + to_field="raw_references", + ), + Copy( + field="source", + to_field="task_data/source", + ), + ApplyOperatorsField( + operators_field="postprocessors", + ), + Copy( + field="prediction", + to_field="processed_prediction", + ), + Copy( + field="references", + to_field="processed_references", + ), + ] +) + + +class PostProcessRecipe(SequentialOperatorInitializer): + def prepare(self): + register_all_artifacts() + self.steps = [ + FromPredictionsAndOriginalData(), + _post_process_steps, + ] + + +def _post_process( + predictions: List[str], + references: Iterable, + split_name: str = "all", +): + _reset_env_local_catalogs() + register_all_artifacts() + recipe = PostProcessRecipe() + + multi_stream = recipe( + predictions=predictions, references=references, split_name=split_name + ) + + return [instance["processed_prediction"] for instance in multi_stream[split_name]] + class MetricRecipe(SequentialOperatorInitializer): calc_confidence_intervals: bool = True @@ -156,29 +210,7 @@ def prepare(self): self.steps = [ FromPredictionsAndOriginalData(), LoadJson(field="task_data"), - Copy( - field="prediction", - to_field="raw_prediction", - ), - Copy( - field="references", - to_field="raw_references", - ), - Copy( - field="source", - to_field="task_data/source", - ), - ApplyOperatorsField( - operators_field="postprocessors", - ), - Copy( - field="prediction", - to_field="processed_prediction", - ), - Copy( - field="references", - to_field="processed_references", - ), + _post_process_steps, SplitByNestedGroup( field_name_of_group="group", number_of_fusion_generations=self.number_of_fusion_generations, diff --git a/tests/library/test_api.py b/tests/library/test_api.py index 437bdd992..5904d5cda 100644 --- a/tests/library/test_api.py +++ b/tests/library/test_api.py @@ -1,5 +1,5 @@ import numpy as np -from unitxt.api import evaluate, load_dataset, produce +from unitxt.api import evaluate, infer, load_dataset, post_process, produce from unitxt.card import TaskCard from unitxt.loaders import LoadHF from unitxt.task import Task @@ -88,6 +88,15 @@ def test_evaluate(self): del instance_with_results["postprocessors"] self.assertDictEqual(results[0], instance_with_results) + def test_post_process(self): + dataset = load_dataset( + "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5,max_validation_instances=5,max_test_instances=5" + ) + predictions = ["2.5", "2.5", "2.2", "3", "4"] + targets = [2.5, 2.5, 2.2, 3.0, 4.0] + results = post_process(predictions, dataset["train"]) + self.assertListEqual(results, targets) + def test_evaluate_with_metrics_external_setup(self): dataset = load_dataset( "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5,max_validation_instances=5,max_test_instances=5,metrics=[metrics.accuracy],postprocessors=[processors.first_character]" @@ -208,3 +217,17 @@ def test_load_dataset_from_dict(self): "When I pulled the pin out, it had a hole.", ) self.assertEqual(dataset["train"]["metrics"][0], ["metrics.accuracy"]) + + def test_infer(self): + engine = "engines.model.flan.t5_small.hf" + recipe = "card=cards.almost_evil,template=templates.qa.open.simple,demos_pool_size=0,num_demos=0" + instances = [ + {"question": "How many days there are in a week", "answers": ["7"]}, + { + "question": "If a ate an apple in the morning, and one in the evening, how many apples did I eat?", + 
"answers": ["2"], + }, + ] + predictions = infer(instances, recipe, engine) + targets = ["365", "1"] + self.assertListEqual(predictions, targets)