Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add step decorator #387

Merged
merged 8 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Optional, Set

from distilabel.steps.typing import StepInput
from distilabel.steps.base import StepInput


def combine_dicts(
Expand Down
61 changes: 38 additions & 23 deletions src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from distilabel.pipeline.base import BasePipeline, _GlobalPipelineManager
from distilabel.pipeline.logging import get_logger
from distilabel.steps.typing import StepInput
from distilabel.utils.serialization import TYPE_INFO_KEY, _Serializable
from distilabel.utils.typing import is_parameter_annotated_with

if TYPE_CHECKING:
from pydantic.fields import FieldInfo
Expand All @@ -35,12 +35,17 @@

_T = TypeVar("_T")
_RUNTIME_PARAMETER_ANNOTATION = "distilabel_step_runtime_parameter"

RuntimeParameter = Annotated[
Union[_T, None], Field(default=None), _RUNTIME_PARAMETER_ANNOTATION
]
"""Used to mark the attributes of a `Step` as a runtime parameter."""

_STEP_INPUT_ANNOTATION = "distilabel_step_input"
StepInput = Annotated[List[Dict[str, Any]], _STEP_INPUT_ANNOTATION]
"""StepInput is just an `Annotated` alias of the typing `List[Dict[str, Any]]` with
extra metadata that allows `distilabel` to perform validations over the `process` step
method defined in each `Step`"""


class _Step(BaseModel, _Serializable, ABC):
"""Base class for the steps that can be included in a `Pipeline`.
Expand Down Expand Up @@ -101,6 +106,7 @@ def process(self, inputs: *StepInput) -> StepOutput:

_runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)
_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
_built_from_decorator: bool = PrivateAttr(default=False)
_logger: logging.Logger = PrivateAttr(get_logger("step"))

def model_post_init(self, _: Any) -> None:
Expand Down Expand Up @@ -146,9 +152,10 @@ def _set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
Args:
runtime_parameters: A dictionary with the runtime parameters for the step.
"""
self._runtime_parameters = runtime_parameters
for name, value in runtime_parameters.items():
setattr(self, name, value)
if name in self.runtime_parameters_names:
setattr(self, name, value)
self._runtime_parameters[name] = value

@property
def is_generator(self) -> bool:
Expand Down Expand Up @@ -241,7 +248,10 @@ def get_process_step_input(self) -> Union[inspect.Parameter, None]:
"""
step_input_parameter = None
for parameter in self.process_parameters:
if _is_step_input(parameter) and step_input_parameter is not None:
if (
is_parameter_annotated_with(parameter, _STEP_INPUT_ANNOTATION)
and step_input_parameter is not None
):
raise TypeError(
f"Step '{self.name}' should have only one parameter with type hint `StepInput`."
)
Expand Down Expand Up @@ -362,6 +372,7 @@ def _get_runtime_parameters_info(self) -> List[Dict[str, Any]]:


class Step(_Step, ABC):
# TODO: this should be a `RuntimeParameter`
input_batch_size: PositiveInt = DEFAULT_INPUT_BATCH_SIZE

@abstractmethod
Expand All @@ -379,7 +390,16 @@ def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput"

inputs = self._apply_input_mappings(args) if self.input_mappings else args

for output_rows in self.process(*inputs):
# If the `Step` was built using the `@step` decorator, then we need to pass
# the runtime parameters as kwargs, so they can be used within the processing
# function
generator = (
self.process(*inputs)
if not self._built_from_decorator
else self.process(*inputs, **self._runtime_parameters)
)

for output_rows in generator:
yield [
{
# Apply output mapping and revert input mapping
Expand Down Expand Up @@ -427,6 +447,7 @@ class GeneratorStep(_Step, ABC):
any input from the previous steps.
"""

# TODO: this should be a `RuntimeParameter` and maybe be called `output_batch_size`
batch_size: int = 50

@abstractmethod
Expand All @@ -435,7 +456,7 @@ def process(self) -> "GeneratorStepOutput":
output rows and a boolean indicating if it's the last batch or not."""
pass

def process_applying_mappings(self, *args: "StepInput") -> "GeneratorStepOutput":
def process_applying_mappings(self) -> "GeneratorStepOutput":
"""Runs the `process` method of the step applying the `outputs_mappings` to the
output rows. This is the function that should be used to run the generation logic
of the step.
Expand All @@ -444,7 +465,16 @@ def process_applying_mappings(self, *args: "StepInput") -> "GeneratorStepOutput"
The output rows and a boolean indicating if it's the last batch or not.
"""

for output_rows, last_batch in self.process(*args):
# If the `Step` was built using the `@step` decorator, then we need to pass
# the runtime parameters as `kwargs`, so they can be used within the processing
# function
generator = (
self.process()
if not self._built_from_decorator
else self.process(**self._runtime_parameters)
)

for output_rows, last_batch in generator:
yield (
[
{self.output_mappings.get(k, k): v for k, v in row.items()}
Expand All @@ -470,21 +500,6 @@ def outputs(self) -> List[str]:
return []


def _is_step_input(parameter: inspect.Parameter) -> bool:
    """Check if the parameter has type hint `StepInput`.

    A parameter is considered a step input when its annotation is an `Annotated`
    alias carrying the `"StepInput"` marker among its metadata.

    Args:
        parameter: The parameter to check.

    Returns:
        `True` if the parameter has type hint `StepInput`, `False` otherwise.
    """
    annotation = parameter.annotation
    if get_origin(annotation) is not Annotated:
        return False

    # `get_args` returns `(wrapped_type, *metadata)`. Nested `Annotated` aliases are
    # flattened, so the marker is not guaranteed to be the last metadata item (e.g.
    # `Annotated[StepInput, "extra"]`); every metadata item has to be inspected.
    return any(
        isinstance(metadata, str) and metadata == "StepInput"
        for metadata in get_args(annotation)[1:]
    )


def _is_runtime_parameter(field: "FieldInfo") -> Tuple[bool, bool]:
"""Check if a `pydantic.BaseModel` field is a `RuntimeParameter` and if it's optional
i.e. providing a value for the field in `Pipeline.run` is optional.
Expand Down
10 changes: 6 additions & 4 deletions src/distilabel/steps/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from distilabel.pipeline.utils import combine_dicts
from distilabel.steps.base import Step
from distilabel.steps.typing import StepInput, StepOutput
from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class CombineColumns(Step):
Expand All @@ -41,7 +43,7 @@ def outputs(self) -> List[str]:
else [f"merged_{column}" for column in self.merge_columns]
)

def process(self, *args: StepInput) -> StepOutput:
def process(self, *args: StepInput) -> "StepOutput":
yield combine_dicts(
*args,
merge_keys=set(self.inputs),
Expand Down
152 changes: 152 additions & 0 deletions src/distilabel/steps/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Type, TypeVar, Union

from pydantic import create_model

from distilabel.steps.base import (
_RUNTIME_PARAMETER_ANNOTATION,
GeneratorStep,
GlobalStep,
Step,
)
from distilabel.utils.typing import is_parameter_annotated_with

if TYPE_CHECKING:
from distilabel.steps.base import _Step
from distilabel.steps.typing import GeneratorStepOutput, StepOutput

# Maps every value accepted by the `step_type` argument of the `step` decorator to
# the base class that the generated step class will inherit from.
_step_mapping = {
"normal": Step,
"global": GlobalStep,
"generator": GeneratorStep,
}

# Type variable bound to callables returning either a `StepOutput` or a
# `GeneratorStepOutput`, i.e. the shape of a processing function.
# NOTE(review): not referenced anywhere in the visible part of this module - confirm
# whether it is used elsewhere or can be dropped.
ProcessingFunc = TypeVar(
"ProcessingFunc", bound=Callable[..., Union["StepOutput", "GeneratorStepOutput"]]
)


def step(
    inputs: Union[List[str], None] = None,
    outputs: Union[List[str], None] = None,
    step_type: Literal["normal", "global", "generator"] = "normal",
) -> Callable[..., Type["_Step"]]:
    """Creates a `Step` from a processing function.

    The parameters of the processing function annotated as runtime parameters become
    fields of the generated class (a parameter without a default value gets `None` as
    default). The generated class is named after the function and keeps its module
    and docstring.

    Args:
        inputs: a list containing the name of the inputs columns/keys expected by this step.
            If not provided the default will be an empty list `[]` and it will be assumed
            that the step doesn't need any specific columns. Defaults to `None`.
        outputs: a list containing the name of the outputs columns/keys that the step
            will generate. If not provided the default will be an empty list `[]` and it
            will be assumed that the step doesn't generate any specific columns. Defaults
            to `None`.
        step_type: the kind of step to create. Valid choices are: "normal" (`Step`),
            "global" (`GlobalStep`) or "generator" (`GeneratorStep`). Defaults to
            `"normal"`.

    Returns:
        A callable that will generate the type given the processing function.

    Raises:
        ValueError: when the returned callable is applied to a function and
            `step_type` is not one of the valid choices.

    Example:

    ```python
    # Normal step
    @step(inputs=["instruction"], outputs=["generation"])
    def GenerationStep(inputs: StepInput, dummy_generation: RuntimeParameter[str]) -> StepOutput:
        for input in inputs:
            input["generation"] = dummy_generation
        yield inputs

    # Global step
    @step(inputs=["instruction"], step_type="global")
    def FilteringStep(inputs: StepInput, max_length: RuntimeParameter[int] = 256) -> StepOutput:
        yield [
            input
            for input in inputs
            if len(input["instruction"]) <= max_length
        ]

    # Generator step
    @step(outputs=["num"], step_type="generator")
    def RowGenerator(num_rows: RuntimeParameter[int] = 500) -> GeneratorStepOutput:
        data = list(range(num_rows))
        for i in range(0, len(data), 100):
            last_batch = i + 100 >= len(data)
            yield [{"num": num} for num in data[i : i + 100]], last_batch
    ```
    """

    inputs = inputs or []
    outputs = outputs or []

    def decorator(
        func: Callable[..., Union["StepOutput", "GeneratorStepOutput"]],
    ) -> Type["_Step"]:
        if step_type not in _step_mapping:
            raise ValueError(
                f"Invalid step type '{step_type}'. Please, review the '{func.__name__}'"
                " function decorated with the `@step` decorator and provide a valid"
                " `step_type`. Valid choices are: 'normal', 'global' or 'generator'."
            )

        BaseClass = _step_mapping[step_type]

        signature = inspect.signature(func)

        # Build the `(annotation, default)` field definitions expected by
        # `create_model`. `Parameter.empty` is a sentinel, so it is compared by
        # identity: `!=` would dispatch to the `__ne__` of an arbitrary default value.
        runtime_parameters = {
            name: (
                param.annotation,
                param.default if param.default is not param.empty else None,
            )
            for name, param in signature.parameters.items()
            if is_parameter_annotated_with(param, _RUNTIME_PARAMETER_ANNOTATION)
        }

        RuntimeParametersModel = create_model(  # type: ignore
            "RuntimeParametersModel",
            **runtime_parameters,  # type: ignore
        )

        # The properties return copies so that a caller mutating the returned list
        # cannot alter the columns declared for every instance of the generated class.
        def inputs_property(self) -> List[str]:
            return list(inputs)

        def outputs_property(self) -> List[str]:
            return list(outputs)

        # `self` is intentionally not forwarded: the processing function only
        # receives the positional and keyword arguments given to `process`.
        def process(
            self, *args: Any, **kwargs: Any
        ) -> Union["StepOutput", "GeneratorStepOutput"]:
            return func(*args, **kwargs)

        return type(  # type: ignore
            func.__name__,
            (
                BaseClass,
                RuntimeParametersModel,
            ),
            {
                "process": process,
                "inputs": property(inputs_property),
                "outputs": property(outputs_property),
                "__module__": func.__module__,
                "__doc__": func.__doc__,
                "_built_from_decorator": True,
            },
        )

    return decorator
18 changes: 10 additions & 8 deletions src/distilabel/steps/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from pydantic import Field

from distilabel.llm.base import LLM
from distilabel.steps.base import RuntimeParameter, Step
from distilabel.steps.base import RuntimeParameter, Step, StepInput
from distilabel.steps.task.typing import ChatType
from distilabel.steps.typing import StepInput, StepOutput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class Task(Step, ABC):
Expand Down Expand Up @@ -64,7 +66,7 @@ def format_output(self, output: str) -> Dict[str, Any]:
as a string, and generates a Python dictionary with the outputs of the task."""
pass

def process(self, inputs: StepInput) -> StepOutput:
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Processes the inputs of the task and generates the outputs using the LLM.

Args:
Expand All @@ -77,10 +79,10 @@ def process(self, inputs: StepInput) -> StepOutput:
outputs = self.llm.generate(formatted_inputs, **self.generation_kwargs) # type: ignore
formatted_outputs = [self.format_output(output) for output in outputs] # type: ignore

outputs: StepOutput = [] # type: ignore
outputs = []
for input, formatted_output in zip(inputs, formatted_outputs):
output = {k: v for k, v in input.items() if k in self.inputs}
output.update(formatted_output)
output["model_name"] = self.llm.model_name # type: ignore
outputs.append(output) # type: ignore
yield outputs # type: ignore
output["model_name"] = self.llm.model_name
outputs.append(output)
yield outputs
5 changes: 0 additions & 5 deletions src/distilabel/steps/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@

from typing_extensions import Annotated

StepInput = Annotated[List[Dict[str, Any]], "StepInput"]
"""StepInput is just an `Annotated` alias of the typing `List[Dict[str, Any]]` with
extra metadata that allows `distilabel` to perform validations over the `process` step
method defined in each `Step`"""

StepOutput = Annotated[Iterator[List[Dict[str, Any]]], "StepOutput"]
"""StepOutput is just an `Annotated` alias of the typing `Iterator[List[Dict[str, Any]]]`"""

Expand Down
Loading
Loading