
Implement caching mechanism for the pipelines #370

Merged · 37 commits · Mar 18, 2024

Conversation

plaguss (Contributor) commented Mar 1, 2024

Description

This PR implements the first version of caching for Pipeline objects. It works by serializing and saving the pipeline content after each Step finishes (run as a callback after each process call) to a folder under ~/.cache/distilabel/pipelines:

  • batch_manager.json: Contains the serialized _BatchManager, the object in charge of managing the batches internally in the Pipeline.
  • pipeline.yaml: Contains the serialized Pipeline in YAML format, which can be reused within the CLI (CLI with run command #403).
  • data.jsonl: Generations saved as a jsonl file (this file may be modified or removed).
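
The cache layout above can be sketched as follows. This is only an illustration of the three files and their formats; the function name and the serialized contents are hypothetical, not distilabel's actual serialization logic (the real state comes from `_BatchManager.dump()` and the pipeline's YAML serialization):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sketch of the cache folder described above. File names match
# the PR; the payloads here are placeholders, not distilabel's real state.
def cache_pipeline_state(cache_dir: Path, batch_manager_state: dict, generations: list) -> None:
    cache_dir.mkdir(parents=True, exist_ok=True)
    # batch_manager.json: the _BatchManager's internal batch bookkeeping
    (cache_dir / "batch_manager.json").write_text(json.dumps(batch_manager_state))
    # data.jsonl: one generation per line
    with (cache_dir / "data.jsonl").open("w") as f:
        for row in generations:
            f.write(json.dumps(row) + "\n")

cache_dir = Path(tempfile.mkdtemp()) / "pipelines" / "demo"
cache_pipeline_state(
    cache_dir,
    {"last_batch": 3},  # placeholder state
    [{"instruction": "hi", "response": "I don't know"}],
)
print(sorted(p.name for p in cache_dir.iterdir()))  # → ['batch_manager.json', 'data.jsonl']
```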

The following example, written under the tests/integrations folder, can be run as a demonstration:

from typing import Any, Dict, Generator, List

from distilabel.pipeline.local import Pipeline
from distilabel.steps.base import RuntimeParameter, Step, StepInput
from distilabel.steps.generators.huggingface import LoadHubDataset


class RenameColumns(Step):
    rename_mappings: RuntimeParameter[Dict[str, str]]

    @property
    def inputs(self) -> List[str]:
        return []

    @property
    def outputs(self) -> List[str]:
        return list(self.rename_mappings.values())  # type: ignore

    def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]:
        outputs = []
        for input in inputs:
            outputs.append(
                {self.rename_mappings.get(k, k): v for k, v in input.items()}  # type: ignore
            )
        yield outputs


class GenerateResponse(Step):
    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]:
        import time

        time.sleep(0.8)

        print("***** NOT CACHED ******", len(inputs))
        for input in inputs:
            input["response"] = "I don't know"

        # NOTE: Caching here to save the evolution of the _BatchManager
        self.pipeline._cache()
        yield inputs

    @property
    def outputs(self) -> List[str]:
        return ["response"]


def test_pipeline_cached():
    def run_pipeline():
        with Pipeline() as pipeline:
            load_hub_dataset = LoadHubDataset(name="load_dataset", batch_size=8)
            rename_columns = RenameColumns(name="rename_columns", input_batch_size=12)
            generate_response = GenerateResponse(
                name="generate_response", input_batch_size=16
            )

            load_hub_dataset.connect(rename_columns)
            rename_columns.connect(generate_response)

            pipeline.run(
                parameters={
                    "load_dataset": {
                        "repo_id": "plaguss/test",
                        "split": "train",
                    },
                    "rename_columns": {
                        "rename_mappings": {
                            "prompt": "instruction",
                        },
                    },
                }
            )

    run_pipeline()
    print()
    print("----- RUNNING PIPELINE AGAIN -----")
    print()
    run_pipeline()

if __name__ == "__main__":
    test_pipeline_cached()

The script runs the pipeline twice; looking at the logs should show the effect of passing through the cached batches. Currently, we cache only after a step finishes, but this can be managed by the user. See for example the GenerateResponse step above, which calls the pipeline._cache method before yielding its results.
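
The skip-on-rerun behaviour the logs show can be sketched in isolation. This is a toy model of the idea, not distilabel's implementation: `load_cached_batch_ids`, the `batch_id` field, and the append-to-jsonl scheme are all illustrative assumptions here:

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sketch: on a re-run, batches already recorded in data.jsonl
# are skipped instead of being recomputed. Names and fields are illustrative.
def load_cached_batch_ids(cache_file: Path) -> set:
    if not cache_file.exists():
        return set()
    return {json.loads(line)["batch_id"] for line in cache_file.read_text().splitlines()}

def run_step(cache_file: Path, batches):
    done = load_cached_batch_ids(cache_file)
    processed = []
    for batch_id, rows in batches:
        if batch_id in done:
            continue  # cached on a previous run: skip recomputation
        processed.append(batch_id)
        with cache_file.open("a") as f:
            f.write(json.dumps({"batch_id": batch_id}) + "\n")
    return processed

cache_file = Path(tempfile.mkdtemp()) / "data.jsonl"
first = run_step(cache_file, [(0, ["a"]), (1, ["b"])])
second = run_step(cache_file, [(0, ["a"]), (1, ["b"]), (2, ["c"])])
print(first, second)  # → [0, 1] [2]: the second run only touches the new batch
```

The second call processes only batch 2, mirroring what the PR's example prints on its second pipeline run.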

This PR would close #389

@plaguss plaguss added the enhancement New feature or request label Mar 1, 2024
@plaguss plaguss self-assigned this Mar 1, 2024
@plaguss plaguss marked this pull request as ready for review March 11, 2024 08:52
@plaguss plaguss merged commit eecdab4 into core-refactor Mar 18, 2024
4 checks passed
@plaguss plaguss deleted the caching branch March 18, 2024 09:02
@alvarobartt alvarobartt linked an issue Mar 18, 2024 that may be closed by this pull request