Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GenerateEmbeddings task #427

Merged
merged 8 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def _all_steps_loaded(self) -> bool:
`True` if all the steps have been loaded correctly, `False` otherwise.
"""
self._logger.info("⏳ Waiting for all the steps to load...")
previous_message = None
while True:
with self.shared_info[_STEPS_LOADED_LOCK_KEY]:
steps_loaded = self.shared_info[_STEPS_LOADED_KEY]
Expand All @@ -155,9 +156,12 @@ def _all_steps_loaded(self) -> bool:
self._logger.error("❌ Failed to load all the steps")
return False

self._logger.info(f"⏳ Steps loaded: {steps_loaded}/{len(self.dag)}")
message = f"⏳ Steps loaded: {steps_loaded}/{len(self.dag)}"
if message != previous_message:
self._logger.info(message)
previous_message = message

time.sleep(5)
time.sleep(2.5)

def _request_initial_batches(self) -> None:
"""Requests the initial batches to the generator steps."""
Expand Down
5 changes: 3 additions & 2 deletions src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ def process(self, inputs: *StepInput) -> StepOutput:
output_mappings: Dict[str, str] = {}

_runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)
_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
_built_from_decorator: bool = PrivateAttr(default=False)
_logger: logging.Logger = PrivateAttr(get_logger("steps"))

def model_post_init(self, _: Any) -> None:
def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)

if self.pipeline is None:
self.pipeline = _GlobalPipelineManager.get_pipeline()

Expand Down
16 changes: 10 additions & 6 deletions src/distilabel/steps/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import requests
from datasets import load_dataset
from pydantic import Field
from datasets import IterableDataset, load_dataset
from pydantic import Field, PrivateAttr

from distilabel.steps.base import GeneratorStep, RuntimeParameter

Expand Down Expand Up @@ -80,17 +80,20 @@ class LoadHubDataset(GeneratorStep):
description="The Hugging Face Hub repository ID of the dataset to load.",
)
split: RuntimeParameter[str] = Field(
default=None, description="The split of the dataset to load."
default="train",
description="The split of the dataset to load. Defaults to 'train'.",
)
config: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The configuration of the dataset to load. This is optional and only"
" needed if the dataset has multiple configurations.",
)

_dataset: Union[IterableDataset, None] = PrivateAttr(...)

def load(self) -> None:
"""Load the dataset from the Hugging Face Hub"""
self._values["dataset"] = load_dataset(
self._dataset = load_dataset(
self.repo_id, # type: ignore
self.config,
split=self.split,
Expand All @@ -108,10 +111,11 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput":
A tuple containing a batch of rows and a boolean indicating if the batch is
the last one.
"""
dataset = self._values["dataset"]
num_examples = self._get_dataset_num_examples()
num_returned_rows = 0
for batch_num, batch in enumerate(dataset.iter(batch_size=self.batch_size)):
for batch_num, batch in enumerate(
self._dataset.iter(batch_size=self.batch_size)
):
if batch_num * self.batch_size < offset:
continue
transformed_batch = self._transform_batch(batch)
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class _Task(ABC):
)

def load(self) -> None:
"""Loads the LLM via the `LLM.load()` method (done for safer serialization)."""
self.llm.load() # type: ignore
"""Loads the LLM via the `LLM.load()` method"""
self.llm.load()

@abstractmethod
def format_input(self, input: Dict[str, Any]) -> "ChatType":
Expand Down
98 changes: 98 additions & 0 deletions src/distilabel/steps/task/generate_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List

from distilabel.llm.base import LLM
from distilabel.steps.base import Step
from distilabel.utils.chat import is_openai_format

if TYPE_CHECKING:
from distilabel.steps.base import StepInput
from distilabel.steps.task.typing import ChatType
from distilabel.steps.typing import StepOutput


class GenerateEmbeddings(Step):
    """Generate embeddings for a text input using the last hidden state of an `LLM`, as
    described in the paper 'What Makes Good Data for Alignment? A Comprehensive Study of
    Automatic Data Selection in Instruction Tuning'.

    Input columns:
        text (`str`, `List[Dict[str, str]]`): The input text or conversation to generate
            embeddings for.

    Output columns:
        embedding (`List[float]`): The embedding of the input text or conversation.

    Reference:
        - [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685)
    """

    # LLM whose last hidden state is used to build the embedding.
    llm: LLM

    def load(self) -> None:
        """Loads the LLM via the `LLM.load()` method."""
        self.llm.load()

    @property
    def inputs(self) -> List[str]:
        """The inputs for the task is a `text` column containing either a string or a
        list of dictionaries in OpenAI chat-like format."""
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        """The outputs for the task is an `embedding` column containing the embedding of
        the `text` input."""
        return ["embedding"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """Formats the input to be used by the LLM to generate the embeddings. The input
        can be in `ChatType` format or a string. If a string, it will be converted to a
        list of dictionaries in OpenAI chat-like format.

        Args:
            input: The input to format.

        Returns:
            The OpenAI chat-like format of the input.

        Raises:
            ValueError: If the `text` column is neither a string nor a list of
                dictionaries in OpenAI chat-like format.
        """
        text = input["text"]

        # A raw string is wrapped as a single user turn in OpenAI chat-like format.
        if isinstance(text, str):
            return [{"role": "user", "content": text}]

        # Input is already in `ChatType` format.
        if is_openai_format(text):
            return text

        raise ValueError(
            f"Couldn't format input for step {self.name}. The `text` input column has to"
            " be a string or a list of dictionaries in OpenAI chat-like format."
        )

    def process(self, inputs: "StepInput") -> "StepOutput":  # type: ignore
        """Generates an embedding for each input using the last hidden state of the `LLM`.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list of Python dictionaries with the outputs of the task.
        """
        formatted_inputs = [self.format_input(input) for input in inputs]
        last_hidden_states = self.llm.get_last_hidden_states(formatted_inputs)
        for input, hidden_state in zip(inputs, last_hidden_states):
            # The hidden state of the last token is taken as the embedding vector.
            input["embedding"] = hidden_state[-1].tolist()
        yield inputs
39 changes: 39 additions & 0 deletions src/distilabel/utils/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any


def is_openai_format(input: Any) -> bool:
    """Checks if the input is in OpenAI chat-like format:

    ```python
    [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help you?"},
    ]
    ```

    Args:
        input: The input to check.

    Returns:
        A boolean indicating if the input is in OpenAI chat-like format.
    """
    if not isinstance(input, list):
        return False
    # Every message must be a dict carrying at least `role` and `content` keys
    # (membership is tested directly on the dict instead of via `.keys()`).
    return all(
        isinstance(message, dict) and "role" in message and "content" in message
        for message in input
    )
42 changes: 42 additions & 0 deletions tests/unit/steps/task/test_generate_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generator

import pytest
from distilabel.llm.huggingface.transformers import TransformersLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.task.generate_embeddings import GenerateEmbeddings


@pytest.fixture(scope="module")
def transformers_llm() -> Generator[TransformersLLM, None, None]:
    """Provide a loaded `TransformersLLM`, created once and shared across the module."""
    model_kwargs = {"is_decoder": True}
    llm = TransformersLLM(
        model="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs=model_kwargs,
    )
    llm.load()
    yield llm


class TestGenerateEmbeddings:
    def test_process(self, transformers_llm: TransformersLLM) -> None:
        """`process` should add a 384-dimensional `embedding` column to each row."""
        task = GenerateEmbeddings(
            name="task", llm=transformers_llm, pipeline=Pipeline()
        )

        batch = [{"text": "Hello, how are you?"}]
        output = next(task.process(batch))
        row = output[0]

        assert "embedding" in row
        assert len(row["embedding"]) == 384
4 changes: 2 additions & 2 deletions tests/unit/steps/task/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ def test_process(self) -> None:
pipeline = Pipeline()
llm = DummyLLM()
task = TextGeneration(name="task", llm=llm, pipeline=pipeline)
assert list(task.process([{"instruction": "test"}])) == [
[{"instruction": "test", "generation": "output", "model_name": "test"}]
assert next(task.process([{"instruction": "test"}])) == [
{"instruction": "test", "generation": "output", "model_name": "test"}
]
43 changes: 43 additions & 0 deletions tests/unit/utils/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import pytest
from distilabel.utils.chat import is_openai_format


@pytest.mark.parametrize(
    "input, expected",
    [
        # Non-list values are never in OpenAI chat-like format.
        (None, False),
        (1, False),
        ("Hello", False),
        # Single-turn conversation.
        ([{"role": "user", "content": "Hello!"}], True),
        # Multi-turn conversation.
        (
            [
                {"role": "user", "content": "Hello!"},
                {"role": "assistant", "content": "Hi! How can I help you?"},
            ],
            True,
        ),
    ],
)
def test_is_openai_format(input: Any, expected: bool) -> None:
    """`is_openai_format` accepts only lists of role/content dictionaries."""
    result = is_openai_format(input)
    assert result == expected
Loading