Add infer() function for end to end inference pipeline (#952)
* Inference and Functions

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* Update

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* delete function from PR

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* Update docs

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* Update

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* Update docs

Signed-off-by: elronbandel <elron.bandel@ibm.com>

* Fix

Signed-off-by: elronbandel <elron.bandel@ibm.com>

---------

Signed-off-by: elronbandel <elron.bandel@ibm.com>
elronbandel committed Jun 26, 2024
1 parent 029afd1 commit 1cef107
Showing 8 changed files with 181 additions and 39 deletions.
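In short, this commit adds a one-call pipeline that produces model-ready inputs from raw instances, runs an inference engine, and post-processes the predictions. A minimal sketch of the new API, assembled from the docs and tests in this diff (the instance, recipe, and engine strings are taken from the new test_infer test below):

    from unitxt import infer

    # One instance in the schema of the task referenced by the recipe.
    instances = [{"question": "How many days there are in a week", "answers": ["7"]}]

    predictions = infer(
        instances,
        recipe="card=cards.almost_evil,template=templates.qa.open.simple,demos_pool_size=0,num_demos=0",
        engine="engines.model.flan.t5_small.hf",
    )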
85 changes: 74 additions & 11 deletions docs/docs/production.rst
@@ -4,21 +4,44 @@

To use this tutorial, you need to :ref:`install unitxt <install_unitxt>`.

=====================================
Dynamic Data Processing For Inference
=====================================
========================
Inference and Production
========================

Unitxt can be used to process data dynamically and generate model-ready inputs on the fly, based on a given task recipe.
In this guide you will learn how to use unitxt data recipes in production.

First define a recipe:
For instance, you will learn how to build end-to-end functions like `paraphrase()`:

.. code-block:: python

    recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2"

    def paraphrase(text):
        return unitxt.infer(
            [{"input_text": text, "output_text": ""}],
            recipe="card=cards.coedit.paraphrase,template=templates.rewriting.paraphrase.default",
            engine="engines.model.flan.t5_small.hf"
        )
This function can then be used like:

.. code-block:: python

    paraphrase("So simple to paraphrase!")
In general, Unitxt is capable of:

- Producing processed data according to a given recipe.
- Post-processing predictions based on a recipe.
- Performing end-to-end inference using a recipe and a specified inference engine.

Produce Data
------------

First, define a recipe:

Second, prepare an python dictionary object in the exact schema of the task used in that recipe:
.. code-block:: python

    recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2"
Next, prepare a Python dictionary that matches the schema required by the recipe:

.. code-block:: python
@@ -32,19 +55,25 @@ Second, prepare an python dictionary object in the exact schema of the task used
        "text_b_type": "hypothesis",
    }
Then you can produce the model-ready input data with the `produce` function:
Then, produce the model-ready input data with the `produce` function:

.. code-block:: python

    from unitxt import produce

    result = produce([instance], recipe)
    result = produce(instance, recipe)
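`produce` accepts either a single instance or a list of instances; a single instance is returned unwrapped (see the implementation in src/unitxt/api.py below).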
To view the formatted instance, print the result:

.. code-block::

    print(result["source"])
Then you have the formatted instance in the result. If you `print(result[0]["source"])` you will get:
This will output instances like:

.. code-block::

    Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.
    Given a premise and a hypothesis, classify the entailment of the hypothesis as either 'entailment' or 'not entailment'.
    premise: When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth., hypothesis: mother was careful not to disturb her, undressing and climbing back into her berth.
    The entailment class is entailment

@@ -55,6 +84,40 @@ Then you have the formatted instance in the result. If you `print(result[0]["sou

    premise: It works perfectly, hypothesis: It works!
    The entailment class is
Post Process Data
-----------------

After obtaining predictions, they can be post-processed:

.. code-block:: python

    from unitxt import post_process

    # `model` here stands for any text generation model or client.
    prediction = model.generate(result["source"])
    processed_result = post_process(predictions=[prediction], data=[result])[0]
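Here `processed_result` holds the post-processed prediction for the single instance; `post_process` returns one processed prediction per input instance.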
End to End Inference Pipeline
-----------------------------

You can also implement an end-to-end inference pipeline using your preferred data and an inference engine:

.. code-block:: python

    from unitxt import infer
    from unitxt.inference import HFPipelineBasedInferenceEngine

    engine = HFPipelineBasedInferenceEngine(
        model_name="google/flan-t5-small", max_new_tokens=32
    )

    infer(instance, recipe, engine)
Alternatively, you can specify any inference engine from the catalog:

.. code-block:: python

    infer(
        instance,
        recipe="card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2",
        engine="engines.model.flan.t5_small.hf"
    )
8 changes: 8 additions & 0 deletions prepare/engines/model/flan.py
@@ -0,0 +1,8 @@
from unitxt.catalog import add_to_catalog
from unitxt.inference import HFPipelineBasedInferenceEngine

engine = HFPipelineBasedInferenceEngine(
    model_name="google/flan-t5-small", max_new_tokens=32
)

add_to_catalog(engine, "engines.model.flan.t5_small.hf", overwrite=True)
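Once registered, this entry can be resolved by its catalog name, which is what `infer()` does internally; a small sketch using `fetch_artifact`, whose updated signature appears in src/unitxt/artifact.py below:

    from unitxt.artifact import fetch_artifact

    # Resolve the engine registered above by its catalog name.
    engine, _ = fetch_artifact("engines.model.flan.t5_small.hf")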
2 changes: 1 addition & 1 deletion src/unitxt/__init__.py
@@ -1,6 +1,6 @@
import random

from .api import evaluate, load, load_dataset, produce
from .api import evaluate, infer, load, load_dataset, post_process, produce
from .catalog import add_to_catalog, get_from_catalog
from .logging_utils import get_logger
from .register import register_all_artifacts, register_local_catalog
13 changes: 12 additions & 1 deletion src/unitxt/api.py
@@ -6,7 +6,7 @@
from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute
from .metric_utils import _compute, _post_process
from .operator import SourceOperator
from .standard import StandardRecipe

@@ -91,6 +91,10 @@ def evaluate(predictions, data) -> List[Dict[str, Any]]:
    return _compute(predictions=predictions, references=data)


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(recipe_query):
    return get_dataset_artifact(recipe_query).produce
@@ -104,3 +108,10 @@ def produce(instance_or_instances, recipe_query):
    if not is_list:
        result = result[0]
    return result


def infer(instance_or_instances, recipe, engine):
    # Produce model-ready inputs from the raw instance(s) using the recipe.
    dataset = produce(instance_or_instances, recipe)
    # Accept either an engine object or the catalog name of one.
    engine, _ = fetch_artifact(engine)
    predictions = engine.infer(dataset)
    # Apply the recipe's postprocessors to the raw predictions.
    return post_process(predictions, dataset)
4 changes: 2 additions & 2 deletions src/unitxt/artifact.py
@@ -6,7 +6,7 @@
import re
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union, final
from typing import Any, Dict, List, Optional, Tuple, Union, final

from .dataclass import (
    AbstractField,
@@ -429,7 +429,7 @@ def __str__(self):
return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"


def fetch_artifact(artifact_rep):
def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
    if isinstance(artifact_rep, Artifact):
        return artifact_rep, None
    if Artifact.is_artifact_file(artifact_rep):
5 changes: 5 additions & 0 deletions src/unitxt/catalog/engines/model/flan/t5_small/hf.json
@@ -0,0 +1,5 @@
{
    "__type__": "hf_pipeline_based_inference_engine",
    "model_name": "google/flan-t5-small",
    "max_new_tokens": 32
}
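This JSON is the serialized catalog form of the engine defined in prepare/engines/model/flan.py above; the `__type__` value is the registered snake_case name of `HFPipelineBasedInferenceEngine`.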
78 changes: 55 additions & 23 deletions src/unitxt/metric_utils.py
@@ -9,6 +9,7 @@
from .dict_utils import dict_set
from .operator import (
    MultiStreamOperator,
    SequentialOperator,
    SequentialOperatorInitializer,
    StreamInitializerOperator,
)
@@ -146,6 +147,59 @@ def process(
# When receiving instances from this scheme, the keys and values are returned as two separate
# lists, and are converted to a dictionary.

_post_process_steps = SequentialOperator(
    steps=[
        Copy(
            field="prediction",
            to_field="raw_prediction",
        ),
        Copy(
            field="references",
            to_field="raw_references",
        ),
        Copy(
            field="source",
            to_field="task_data/source",
        ),
        ApplyOperatorsField(
            operators_field="postprocessors",
        ),
        Copy(
            field="prediction",
            to_field="processed_prediction",
        ),
        Copy(
            field="references",
            to_field="processed_references",
        ),
    ]
)


class PostProcessRecipe(SequentialOperatorInitializer):
    def prepare(self):
        register_all_artifacts()
        self.steps = [
            FromPredictionsAndOriginalData(),
            _post_process_steps,
        ]
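The same `_post_process_steps` operator is reused by `MetricRecipe` below, so standalone post-processing and metric evaluation apply identical transformations to predictions.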


def _post_process(
    predictions: List[str],
    references: Iterable,
    split_name: str = "all",
):
    _reset_env_local_catalogs()
    register_all_artifacts()
    recipe = PostProcessRecipe()

    multi_stream = recipe(
        predictions=predictions, references=references, split_name=split_name
    )

    return [instance["processed_prediction"] for instance in multi_stream[split_name]]


class MetricRecipe(SequentialOperatorInitializer):
    calc_confidence_intervals: bool = True
@@ -156,29 +210,7 @@ def prepare(self):
        self.steps = [
            FromPredictionsAndOriginalData(),
            LoadJson(field="task_data"),
            Copy(
                field="prediction",
                to_field="raw_prediction",
            ),
            Copy(
                field="references",
                to_field="raw_references",
            ),
            Copy(
                field="source",
                to_field="task_data/source",
            ),
            ApplyOperatorsField(
                operators_field="postprocessors",
            ),
            Copy(
                field="prediction",
                to_field="processed_prediction",
            ),
            Copy(
                field="references",
                to_field="processed_references",
            ),
            _post_process_steps,
            SplitByNestedGroup(
                field_name_of_group="group",
                number_of_fusion_generations=self.number_of_fusion_generations,
25 changes: 24 additions & 1 deletion tests/library/test_api.py
@@ -1,5 +1,5 @@
import numpy as np
from unitxt.api import evaluate, load_dataset, produce
from unitxt.api import evaluate, infer, load_dataset, post_process, produce
from unitxt.card import TaskCard
from unitxt.loaders import LoadHF
from unitxt.task import Task
@@ -88,6 +88,15 @@ def test_evaluate(self):
        del instance_with_results["postprocessors"]
        self.assertDictEqual(results[0], instance_with_results)

    def test_post_process(self):
        dataset = load_dataset(
            "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5,max_validation_instances=5,max_test_instances=5"
        )
        predictions = ["2.5", "2.5", "2.2", "3", "4"]
        targets = [2.5, 2.5, 2.2, 3.0, 4.0]
        results = post_process(predictions, dataset["train"])
        self.assertListEqual(results, targets)

    def test_evaluate_with_metrics_external_setup(self):
        dataset = load_dataset(
            "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5,max_validation_instances=5,max_test_instances=5,metrics=[metrics.accuracy],postprocessors=[processors.first_character]"
@@ -208,3 +217,17 @@ def test_load_dataset_from_dict(self):
"When I pulled the pin out, it had a hole.",
)
self.assertEqual(dataset["train"]["metrics"][0], ["metrics.accuracy"])

    def test_infer(self):
        engine = "engines.model.flan.t5_small.hf"
        recipe = "card=cards.almost_evil,template=templates.qa.open.simple,demos_pool_size=0,num_demos=0"
        instances = [
            {"question": "How many days there are in a week", "answers": ["7"]},
            {
                "question": "If a ate an apple in the morning, and one in the evening, how many apples did I eat?",
                "answers": ["2"],
            },
        ]
        predictions = infer(instances, recipe, engine)
        # Note: the expected predictions differ from the gold answers above;
        # this test checks the end-to-end pipeline, not model accuracy.
        targets = ["365", "1"]
        self.assertListEqual(predictions, targets)
