diff --git a/.make.defaults b/.make.defaults index 9b4c7e1ab..8b14c533f 100644 --- a/.make.defaults +++ b/.make.defaults @@ -254,7 +254,7 @@ __check_defined = \ .PHONY: .defaults.check.installed .defaults.check.installed:: - if [ ! command -v $(CHECK_RUNNABLE) &>/dev/null ]; then \ + @if [ ! command -v $(CHECK_RUNNABLE) &>/dev/null ]; then \ echo $(CHECK_RUNNABLE) must be installed; \ exit 1; \ fi diff --git a/data-processing-lib/src/data_processing/launch/ray/transform_configuration.py b/data-processing-lib/src/data_processing/launch/ray/transform_configuration.py new file mode 100644 index 000000000..2fb68dd67 --- /dev/null +++ b/data-processing-lib/src/data_processing/launch/ray/transform_configuration.py @@ -0,0 +1,33 @@ +from data_processing.launch.ray import DefaultTableTransformRuntimeRay +from data_processing.transform.transform_configuration import ( + TransformConfiguration, + TransformConfigurationProxy, +) + + +class RayTransformConfiguration(TransformConfigurationProxy): + def __init__( + self, + transform_config: TransformConfiguration, + runtime_class: type[DefaultTableTransformRuntimeRay] = DefaultTableTransformRuntimeRay, + ): + """ + Initialization + :param transform_class: implementation of the transform + :param runtime_class: implementation of the transform runtime + :param base: base transform configuration class + :param name: transform name + :param remove_from_metadata: list of parameters to remove from metadata + :return: + """ + super().__init__( + proxied_transform_config=transform_config, + ) + self.runtime_class = runtime_class + + def create_transform_runtime(self) -> DefaultTableTransformRuntimeRay: + """ + Create transform runtime with the parameters captured during apply_input_params() + :return: transform runtime object + """ + return self.runtime_class(self.params) diff --git a/data-processing-lib/src/data_processing/launch/ray/transform_launcher.py b/data-processing-lib/src/data_processing/launch/ray/transform_launcher.py index 0cb86bca3..fcbaf5d40 100644 --- a/data-processing-lib/src/data_processing/launch/ray/transform_launcher.py +++ b/data-processing-lib/src/data_processing/launch/ray/transform_launcher.py @@ -16,13 +16,15 @@ import ray from data_processing.data_access import DataAccessFactory, DataAccessFactoryBase -from data_processing.transform import TransformConfiguration from data_processing.launch.ray import ( + DefaultTableTransformRuntimeRay, RayLauncherConfiguration, TransformOrchestratorConfiguration, - orchestrate, DefaultTableTransformRuntimeRay, + orchestrate, ) +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration from data_processing.launch.transform_launcher import AbstractTransformLauncher +from data_processing.transform import TransformConfiguration from data_processing.utils import get_logger, str2bool @@ -36,8 +38,8 @@ class RayTransformLauncher(AbstractTransformLauncher): def __init__( self, - transform_config: TransformConfiguration, - runtime_class: type[DefaultTableTransformRuntimeRay]=DefaultTableTransformRuntimeRay, + transform_config: RayTransformConfiguration, + # runtime_class: type[DefaultTableTransformRuntimeRay]=DefaultTableTransformRuntimeRay, data_access_factory: DataAccessFactoryBase = DataAccessFactory(), ): """ @@ -46,7 +48,7 @@ def __init__( :param data_access_factory: the factory to create DataAccess instances. """ super().__init__(transform_config, data_access_factory) - self.transform_runtime_config = RayLauncherConfiguration(transform_config, runtime_class) + self.transform_runtime_config = RayLauncherConfiguration(transform_config, transform_config.runtime_class) self.ray_orchestrator = TransformOrchestratorConfiguration(name=self.name) def __get_parameters(self) -> bool: diff --git a/data-processing-lib/src/data_processing/test_support/transform/noop_transform.py b/data-processing-lib/src/data_processing/test_support/transform/noop_transform.py index 88f36b04f..6c5e57329 100644 --- a/data-processing-lib/src/data_processing/test_support/transform/noop_transform.py +++ b/data-processing-lib/src/data_processing/test_support/transform/noop_transform.py @@ -15,8 +15,8 @@ from typing import Any import pyarrow as pa - from data_processing.launch.pure_python import PythonTransformLauncher +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import CLIArgumentProvider, get_logger @@ -65,6 +65,7 @@ def transform(self, table: pa.Table) -> tuple[list[pa.Table], dict[str, Any]]: metadata = {"nfiles": 1, "nrows": len(table)} return [table], metadata + class NOOPTransformConfiguration(TransformConfiguration): """ @@ -118,6 +119,10 @@ def apply_input_params(self, args: Namespace) -> bool: return True +class NOOPRayTransformConfiguration(RayTransformConfiguration): + def __init__(self): + super().__init__(NOOPTransformConfiguration()) + # # class NOOPTransformConfigurationRayLauncherConfiguration(RayLauncherConfiguration): @@ -141,7 +146,7 @@ def apply_input_params(self, args: Namespace) -> bool: # if __name__ == "__main__": - #launcher = PythonTransformLauncher(transform_runtime_config=NOOPPythonLauncherConfigurationPython()) + # launcher = PythonTransformLauncher(transform_runtime_config=NOOPPythonLauncherConfigurationPython()) launcher = PythonTransformLauncher(transform_runtime_config=NOOPTransformConfiguration()) logger.info("Launching noop transform") launcher.launch() diff --git a/data-processing-lib/src/data_processing/transform/transform_configuration.py b/data-processing-lib/src/data_processing/transform/transform_configuration.py index 6e95ab262..80f4a72ab 100644 --- a/data-processing-lib/src/data_processing/transform/transform_configuration.py +++ b/data-processing-lib/src/data_processing/transform/transform_configuration.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ - +import argparse from argparse import ArgumentParser +from typing import Any from data_processing.transform import AbstractTableTransform from data_processing.utils import CLIArgumentProvider @@ -21,7 +22,7 @@ class TransformConfiguration(CLIArgumentProvider): This is a base transform configuration class defining transform's input/output parameter """ - def __init__(self, name:str, transform_class:AbstractTableTransform, remove_from_metadata: list[str] = []): + def __init__(self, name: str, transform_class: AbstractTableTransform, remove_from_metadata: list[str] = []): """ Initialization """ @@ -30,6 +31,29 @@ def __init__(self, name:str, transform_class:AbstractTableTransform, remove_from self.remove_from_metadata = remove_from_metadata self.params = {} + +class TransformConfigurationProxy(TransformConfiguration): + def __init__(self, proxied_transform_config: TransformConfiguration): + self.proxied_transform_config = proxied_transform_config + # Python probably has a better way of doing this using the proxied transform config + self.name = proxied_transform_config.name + self.transform_class = proxied_transform_config.transform_class + self.remove_from_metadata = proxied_transform_config.remove_from_metadata + self.params = {} + + def add_input_params(self, parser: argparse.ArgumentParser) -> None: + self.proxied_transform_config.add_input_params(parser) + + def apply_input_params(self, args: argparse.Namespace) -> bool: + is_valid = self.proxied_transform_config.apply_input_params(args) + if is_valid: + self.params = self.proxied_transform_config.params + return is_valid + + def get_input_params(self) -> dict[str, Any]: + return self.params + + def get_transform_config( transform_configuration: TransformConfiguration, argv: list[str], parser: ArgumentParser = None ): diff --git a/data-processing-lib/test/data_processing_tests/launch/ray/launcher_test.py b/data-processing-lib/test/data_processing_tests/launch/ray/launcher_test.py index 01c790f05..addfbf936 100644 --- a/data-processing-lib/test/data_processing_tests/launch/ray/launcher_test.py +++ b/data-processing-lib/test/data_processing_tests/launch/ray/launcher_test.py @@ -13,11 +13,13 @@ import os import sys -from data_processing.transform import TransformConfiguration -from data_processing.launch.ray import ( - RayTransformLauncher, +from data_processing.launch.ray import RayTransformLauncher +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.test_support.transform import NOOPTransformConfiguration +from data_processing.test_support.transform.noop_transform import ( + NOOPRayTransformConfiguration, ) -from data_processing.transform import AbstractTableTransform +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import ParamsUtils @@ -43,15 +45,13 @@ code_location = {"github": "github", "commit_hash": "12345", "path": "path"} -class TestingTransformConfiguration(TransformConfiguration): - def __init__(self): - super().__init__("test", transform_class=AbstractTableTransform) class TestLauncherTransformLauncher(RayTransformLauncher): """ Test driver for validation of the functionality """ + def __init__(self): - super().__init__( TestingTransformConfiguration()) + super().__init__(NOOPRayTransformConfiguration()) def _submit_for_execution(self) -> int: """ @@ -89,8 +89,6 @@ def test_launcher(): sys.argv = ParamsUtils.dict_to_req(d=params) res = TestLauncherTransformLauncher().launch() - - assert 0 == res # Add local config, should fail because now three different configs exist params["data_local_config"] = ParamsUtils.convert_to_ast(local_conf) diff --git a/data-processing-lib/test/data_processing_tests/launch/ray/test_noop_launch.py b/data-processing-lib/test/data_processing_tests/launch/ray/test_noop_launch.py index ae728c1a6..f92cee0eb 100644 --- a/data-processing-lib/test/data_processing_tests/launch/ray/test_noop_launch.py +++ b/data-processing-lib/test/data_processing_tests/launch/ray/test_noop_launch.py @@ -13,10 +13,15 @@ import os import pyarrow as pa - from data_processing.launch.ray import RayTransformLauncher -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) from data_processing.test_support.transform import NOOPTransformConfiguration +from data_processing.test_support.transform.noop_transform import ( + NOOPRayTransformConfiguration, +) + table = pa.Table.from_pydict({"name": pa.array(["Tom"]), "age": pa.array([23])}) expected_table = table # We're a noop after all. @@ -32,6 +37,6 @@ class TestRayNOOPTransform(AbstractTransformLauncherTest): def get_test_transform_fixtures(self) -> list[tuple]: basedir = "../../../../test-data/data_processing/ray/noop/" basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) - launcher = RayTransformLauncher(NOOPTransformConfiguration()) + launcher = RayTransformLauncher(NOOPRayTransformConfiguration()) fixtures = [(launcher, {"noop_sleep_sec": 0}, basedir + "/input", basedir + "/expected")] return fixtures diff --git a/transforms/code/code_quality/src/code_quality_local_ray.py b/transforms/code/code_quality/src/code_quality_local_ray.py index f74b8e71b..48eed9ad8 100644 --- a/transforms/code/code_quality/src/code_quality_local_ray.py +++ b/transforms/code/code_quality/src/code_quality_local_ray.py @@ -14,7 +14,8 @@ import sys from pathlib import Path -from code_quality_transform import CodeQualityRayLauncher +from code_quality_transform import CodeQualityRayTransformConfiguration +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils @@ -26,7 +27,7 @@ } # create launcher -launcher = CodeQualityRayLauncher() +launcher = RayTransformLauncher(CodeQualityRayTransformConfiguration()) worker_options = {"num_cpus": 0.8} diff --git a/transforms/code/code_quality/src/code_quality_s3_ray.py b/transforms/code/code_quality/src/code_quality_s3_ray.py index 1885daf9a..1954a37c0 100644 --- a/transforms/code/code_quality/src/code_quality_s3_ray.py +++ b/transforms/code/code_quality/src/code_quality_s3_ray.py @@ -12,7 +12,8 @@ import sys -from code_quality_transform import CodeQualityRayLauncher +from code_quality_transform import CodeQualityRayTransformConfiguration +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils @@ -27,7 +28,7 @@ } # create launcher -launcher = CodeQualityRayLauncher() +launcher = RayTransformLauncher(CodeQualityRayTransformConfiguration()) worker_options = {"num_cpus": 0.8} diff --git a/transforms/code/code_quality/src/code_quality_transform.py b/transforms/code/code_quality/src/code_quality_transform.py index 064809cee..f7cc53e83 100644 --- a/transforms/code/code_quality/src/code_quality_transform.py +++ b/transforms/code/code_quality/src/code_quality_transform.py @@ -24,9 +24,9 @@ import numpy as np import pyarrow as pa from bs4 import BeautifulSoup -from data_processing.transform import TransformConfiguration from data_processing.launch.ray import RayTransformLauncher -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import TransformUtils from transformers import AutoTokenizer @@ -285,10 +285,7 @@ def transform(self, table: pa.Table) -> tuple[list[pa.Table], dict]: class CodeQualityTransformConfiguration(TransformConfiguration): def __init__(self): - super().__init__( - name="code_quality", - transform_class=CodeQualityTransform - ) + super().__init__(name="code_quality", transform_class=CodeQualityTransform) def add_input_params(self, parser: ArgumentParser) -> None: parser.add_argument( @@ -339,28 +336,11 @@ def apply_input_params(self, args: Namespace) -> bool: return True -# class CodeQualityRayLauncherConfiguration(RayLauncherConfiguration): -# def __init__(self): -# super().__init__( -# name="code_quality", -# transform_class=CodeQualityTransform, -# launcher_configuration=CodeQualityTransformConfiguration(), -# ) -# -# -# class CodeQualityPythonLauncherConfiguration(PythonLauncherConfiguration): -# def __init__(self): -# super().__init__( -# name="code_quality", -# transform_class=CodeQualityTransform, -# launcher_configuration=CodeQualityTransformConfiguration(), -# ) -# self.base = CodeQualityTransformConfiguration() - -class CodeQualityRayLauncher(RayTransformLauncher): +class CodeQualityRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=CodeQualityTransformConfiguration()) + if __name__ == "__main__": - launcher = CodeQualityRayLauncher() + launcher = RayTransformLauncher(CodeQualityRayTransformConfiguration()) launcher.launch() diff --git a/transforms/code/code_quality/test/test_code_quality_launcher.py b/transforms/code/code_quality/test/test_code_quality_launcher.py index 793bb4e4d..5a9e16361 100644 --- a/transforms/code/code_quality/test/test_code_quality_launcher.py +++ b/transforms/code/code_quality/test/test_code_quality_launcher.py @@ -12,8 +12,11 @@ import os -from code_quality_transform import CodeQualityRayLauncher -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from code_quality_transform import CodeQualityRayTransformConfiguration +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) class TestCodeQualityTransform(AbstractTransformLauncherTest): @@ -30,6 +33,6 @@ def get_test_transform_fixtures(self) -> list[tuple]: } basedir = "../test-data" basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) - launcher = CodeQualityRayLauncher() + launcher = RayTransformLauncher(CodeQualityRayTransformConfiguration()) fixtures = [(launcher, cli, basedir + "/input", basedir + "/expected")] return fixtures diff --git a/transforms/code/malware/src/malware_local_ray.py b/transforms/code/malware/src/malware_local_ray.py index 9149ef726..c2a16cd0b 100644 --- a/transforms/code/malware/src/malware_local_ray.py +++ b/transforms/code/malware/src/malware_local_ray.py @@ -14,8 +14,10 @@ import sys from pathlib import Path +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from malware_transform import check_clamd, MalwareRayLauncher +from malware_transform import MalwareRayTransformConfiguration, check_clamd + TEST_SOCKET = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".tmp", "clamd.ctl")) # create parameters @@ -51,6 +53,6 @@ # Set the simulated command line args sys.argv = ParamsUtils.dict_to_req(d=params | malware_params) # create launcher - launcher = MalwareRayLauncher() + launcher = RayTransformLauncher(MalwareRayTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/code/malware/src/malware_transform.py b/transforms/code/malware/src/malware_transform.py index 5d80c23f2..8802ffe42 100644 --- a/transforms/code/malware/src/malware_transform.py +++ b/transforms/code/malware/src/malware_transform.py @@ -18,10 +18,13 @@ import clamd import pyarrow as pa -from data_processing.transform import TransformConfiguration -from data_processing.launch.pure_python import PythonTransformLauncher, PythonLauncherConfiguration +from data_processing.launch.pure_python import ( + PythonLauncherConfiguration, + PythonTransformLauncher, +) from data_processing.launch.ray import RayTransformLauncher -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import get_logger from data_processing.utils.transform_utils import TransformUtils @@ -202,11 +205,12 @@ def apply_input_params(self, args: Namespace) -> bool: # name="Malware", transform_class=MalwareTransform, launcher_configuration=MalwareTransformConfiguration() # ) # -class MalwareRayLauncher(RayTransformLauncher): +class MalwareRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=MalwareTransformConfiguration()) + if __name__ == "__main__": - launcher = MalwareRayLauncher() + launcher = RayTransformLauncher(MalwareRayTransformConfiguration()) logger.info("Launching malware transform") launcher.launch() diff --git a/transforms/code/malware/test/test_malware_ray.py b/transforms/code/malware/test/test_malware_ray.py index f03c99797..903f2999f 100644 --- a/transforms/code/malware/test/test_malware_ray.py +++ b/transforms/code/malware/test/test_malware_ray.py @@ -15,11 +15,15 @@ import os -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from data_processing.launch.ray import RayTransformLauncher +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) from malware_transform import ( INPUT_COLUMN_KEY, OUTPUT_COLUMN_KEY, - MalwareRayLauncher, + MalwareRayTransformConfiguration, ) @@ -34,7 +38,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) fixtures = [ ( - MalwareRayLauncher(), + RayTransformLauncher(MalwareRayTransformConfiguration()), {INPUT_COLUMN_KEY: "contents", OUTPUT_COLUMN_KEY: "virus_detection"}, basedir + "/input", basedir + "/expected", diff --git a/transforms/code/proglang_select/src/proglang_select_local_ray.py b/transforms/code/proglang_select/src/proglang_select_local_ray.py index 9de0415de..37fc65a8f 100644 --- a/transforms/code/proglang_select/src/proglang_select_local_ray.py +++ b/transforms/code/proglang_select/src/proglang_select_local_ray.py @@ -13,15 +13,16 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils from proglang_select_transform import ( + ProgLangSelectRayConfiguration, lang_allowed_langs_file_key, lang_lang_column_key, - lang_output_column_key, ProgLangSelectRayLauncher, + lang_output_column_key, ) - # create parameters language_column_name = "language" annotated_column_name = "lang_selected" @@ -61,6 +62,6 @@ if __name__ == "__main__": sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher - launcher = ProgLangSelectRayLauncher() + launcher = RayTransformLauncher(ProgLangSelectRayConfiguration()) # launch launcher.launch() diff --git a/transforms/code/proglang_select/src/proglang_select_transform.py b/transforms/code/proglang_select/src/proglang_select_transform.py index f03aece94..b5dc565ca 100644 --- a/transforms/code/proglang_select/src/proglang_select_transform.py +++ b/transforms/code/proglang_select/src/proglang_select_transform.py @@ -20,13 +20,16 @@ DataAccessFactory, DataAccessFactoryBase, ) -from data_processing.transform import TransformConfiguration -from data_processing.launch.pure_python import PythonTransformLauncher, PythonLauncherConfiguration +from data_processing.launch.pure_python import ( + PythonLauncherConfiguration, + PythonTransformLauncher, +) from data_processing.launch.ray import ( DefaultTableTransformRuntimeRay, RayTransformLauncher, ) -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import TransformUtils, get_logger from ray.actor import ActorHandle @@ -237,10 +240,12 @@ def apply_input_params(self, args: Namespace) -> bool: # ) # -class ProgLangSelectRayLauncher(RayTransformLauncher): + +class ProgLangSelectRayConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(ProgLangSelectTransformConfiguration(), ProgLangSelectRuntime) + if __name__ == "__main__": - launcher = ProgLangSelectRayLauncher() + launcher = RayTransformLauncher(ProgLangSelectRayConfiguration()) launcher.launch() diff --git a/transforms/code/proglang_select/test/test_proglang_select_ray.py b/transforms/code/proglang_select/test/test_proglang_select_ray.py index 3d0716cfc..5f71aee90 100644 --- a/transforms/code/proglang_select/test/test_proglang_select_ray.py +++ b/transforms/code/proglang_select/test/test_proglang_select_ray.py @@ -12,12 +12,15 @@ import os - -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) from proglang_select_transform import ( + ProgLangSelectRayConfiguration, lang_allowed_langs_file_key, lang_lang_column_key, - lang_output_column_key, ProgLangSelectRayLauncher, + lang_output_column_key, ) @@ -44,7 +47,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: } fixtures = [ ( - ProgLangSelectRayLauncher(), + RayTransformLauncher(ProgLangSelectRayConfiguration()), config, basedir + "/input", basedir + "/expected", diff --git a/transforms/universal/doc_id/src/doc_id_local_ray.py b/transforms/universal/doc_id/src/doc_id_local_ray.py index 1d301821d..7bd66457c 100644 --- a/transforms/universal/doc_id/src/doc_id_local_ray.py +++ b/transforms/universal/doc_id/src/doc_id_local_ray.py @@ -13,8 +13,10 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from doc_id_transform import DocIDRayLauncher +from doc_id_transform import DocIDRayTransformConfiguration + # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../test-data/input")) @@ -45,6 +47,6 @@ sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher -launcher = DocIDRayLauncher() +launcher = RayTransformLauncher(DocIDRayTransformConfiguration()) # launch launcher.launch() diff --git a/transforms/universal/doc_id/src/doc_id_s3_ray.py b/transforms/universal/doc_id/src/doc_id_s3_ray.py index b017a8b6d..b3158fde7 100644 --- a/transforms/universal/doc_id/src/doc_id_s3_ray.py +++ b/transforms/universal/doc_id/src/doc_id_s3_ray.py @@ -12,13 +12,15 @@ import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from doc_id_transform import DocIDRayLauncher +from doc_id_transform import DocIDRayTransformConfiguration + # create launcher -launcher = DocIDRayLauncher() +launcher = RayTransformLauncher(DocIDRayTransformConfiguration()) # create parameters s3_cred = { "access_key": "localminioaccesskey", diff --git a/transforms/universal/doc_id/src/doc_id_transform.py b/transforms/universal/doc_id/src/doc_id_transform.py index ce0416b10..fb22d3aad 100644 --- a/transforms/universal/doc_id/src/doc_id_transform.py +++ b/transforms/universal/doc_id/src/doc_id_transform.py @@ -15,15 +15,17 @@ import pyarrow as pa import ray -from data_processing.transform import TransformConfiguration from data_processing.data_access import DataAccessFactoryBase -from data_processing.launch.pure_python import PythonTransformLauncher, PythonLauncherConfiguration +from data_processing.launch.pure_python import ( + PythonLauncherConfiguration, + PythonTransformLauncher, +) from data_processing.launch.ray import ( DefaultTableTransformRuntimeRay, RayTransformLauncher, ) -from data_processing.transform import AbstractTableTransform - +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import CLIArgumentProvider, TransformUtils, get_logger from ray.actor import ActorHandle @@ -196,30 +198,11 @@ def apply_input_params(self, args: Namespace) -> bool: return True -# class DocIDRayLauncherConfiguration(RayLauncherConfiguration): -# -# """ -# Provides support for configuring and using the associated Transform class include -# configuration with CLI args and combining of metadata. -# """ -# -# def __init__(self): -# super().__init__( -# transform_config = DocIDTransformConfiguration(), -# runtime_class=DocIDRuntime, -# ) -# - -# class DocIDPythonLauncherConfiguration(PythonLauncherConfiguration): -# def __init__(self): -# super().__init__( -# name=short_name, transform_class=DocIDTransform, launcher_configuration=DocIDTransformConfiguration() -# ) -# -class DocIDRayLauncher(RayTransformLauncher): +class DocIDRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=DocIDTransformConfiguration(), runtime_class=DocIDRuntime) + if __name__ == "__main__": - launcher = DocIDRayLauncher() + launcher = RayTransformLauncher(DocIDRayTransformConfiguration()) launcher.launch() diff --git a/transforms/universal/doc_id/test/test_doc_id_ray.py b/transforms/universal/doc_id/test/test_doc_id_ray.py index 1a3e125a5..ed6cace6a 100644 --- a/transforms/universal/doc_id/test/test_doc_id_ray.py +++ b/transforms/universal/doc_id/test/test_doc_id_ray.py @@ -12,11 +12,15 @@ import os -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) from doc_id_transform import ( + DocIDRayTransformConfiguration, doc_column_name_cli_param, hash_column_name_cli_param, - int_column_name_cli_param, DocIDRayLauncher, + int_column_name_cli_param, ) @@ -34,6 +38,6 @@ def get_test_transform_fixtures(self) -> list[tuple]: hash_column_name_cli_param: "doc_hash", int_column_name_cli_param: "doc_int", } - launcher = DocIDRayLauncher() + launcher = RayTransformLauncher(DocIDRayTransformConfiguration()) fixtures.append((launcher, transform_config, basedir + "/input", basedir + "/expected")) return fixtures diff --git a/transforms/universal/ededup/src/ededup_local_ray.py b/transforms/universal/ededup/src/ededup_local_ray.py index 95a40e8a7..a9b78b960 100644 --- a/transforms/universal/ededup/src/ededup_local_ray.py +++ b/transforms/universal/ededup/src/ededup_local_ray.py @@ -13,11 +13,13 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from ededup_transform import EdedupRayLauncher +from ededup_transform import EdedupRayTransformConfiguration + # create launcher -launcher = EdedupRayLauncher() +launcher = RayTransformLauncher(EdedupRayTransformConfiguration()) # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../test-data/input")) output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../output")) diff --git a/transforms/universal/ededup/src/ededup_s3_ray.py b/transforms/universal/ededup/src/ededup_s3_ray.py index 55519580b..a57f89c90 100644 --- a/transforms/universal/ededup/src/ededup_s3_ray.py +++ b/transforms/universal/ededup/src/ededup_s3_ray.py @@ -12,11 +12,13 @@ import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from ededup_transform import EdedupRayLauncher +from ededup_transform import EdedupRayTransformConfiguration + # create launcher -launcher = EdedupRayLauncher() +launcher = RayTransformLauncher(EdedupRayTransformConfiguration()) # create parameters s3_cred = { "access_key": "localminioaccesskey", diff --git a/transforms/universal/ededup/src/ededup_transform.py b/transforms/universal/ededup/src/ededup_transform.py index 5faebc81e..eded8151f 100644 --- a/transforms/universal/ededup/src/ededup_transform.py +++ b/transforms/universal/ededup/src/ededup_transform.py @@ -16,13 +16,13 @@ import pyarrow as pa import ray from data_processing.data_access import DataAccessFactoryBase -from data_processing.transform import TransformConfiguration from data_processing.launch.ray import ( DefaultTableTransformRuntimeRay, - RayUtils, RayTransformLauncher, + RayUtils, ) -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import GB, CLIArgumentProvider, TransformUtils, get_logger from ray.actor import ActorHandle @@ -223,6 +223,7 @@ class EdedupTableTransformConfiguration(TransformConfiguration): Provides support for configuring and using the associated Transform class include configuration with CLI args and combining of metadata. """ + def __init__(self): super().__init__( name=short_name, @@ -251,10 +252,12 @@ def apply_input_params(self, args: Namespace) -> bool: logger.info(f"exact dedup params are {self.params}") return True -class EdedupRayLauncher(RayTransformLauncher): + +class EdedupRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=EdedupTableTransformConfiguration(), runtime_class=EdedupRuntime) + if __name__ == "__main__": - launcher = EdedupRayLauncher() + launcher = RayTransformLauncher(EdedupRayTransformConfiguration()) launcher.launch() diff --git a/transforms/universal/ededup/test/test_ededup_ray.py b/transforms/universal/ededup/test/test_ededup_ray.py index 1958a55fc..cf4f2f46c 100644 --- a/transforms/universal/ededup/test/test_ededup_ray.py +++ b/transforms/universal/ededup/test/test_ededup_ray.py @@ -12,8 +12,11 @@ import os -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest -from ededup_transform import EdedupRayLauncher +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) +from ededup_transform import EdedupRayTransformConfiguration class TestRayBlocklistTransform(AbstractTransformLauncherTest): @@ -31,6 +34,6 @@ def get_test_transform_fixtures(self) -> list[tuple]: "ededup_num_hashes": 2, "ededup_doc_column": "contents", } - launcher = EdedupRayLauncher() + launcher = RayTransformLauncher(EdedupRayTransformConfiguration()) fixtures = [(launcher, config, basedir + "/input", basedir + "/expected")] return fixtures diff --git a/transforms/universal/fdedup/src/fdedup_local_ray.py b/transforms/universal/fdedup/src/fdedup_local_ray.py index 5df5d6981..d40d1a8c1 100644 --- a/transforms/universal/fdedup/src/fdedup_local_ray.py +++ b/transforms/universal/fdedup/src/fdedup_local_ray.py @@ -13,11 +13,13 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from fdedup_transform import FdedupRayLauncher +from fdedup_transform import FdedupRayTransformConfiguration + # create launcher -launcher = FdedupRayLauncher() +launcher = RayTransformLauncher(FdedupRayTransformConfiguration()) # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../test-data/input")) output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../output")) diff --git a/transforms/universal/fdedup/src/fdedup_s3_ray.py b/transforms/universal/fdedup/src/fdedup_s3_ray.py index 7c4a25b63..01bd8f0bf 100644 --- a/transforms/universal/fdedup/src/fdedup_s3_ray.py +++ b/transforms/universal/fdedup/src/fdedup_s3_ray.py @@ -12,11 +12,13 @@ import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from fdedup_transform import FdedupRayLauncher +from fdedup_transform import FdedupRayTransformConfiguration + # create launcher -launcher = FdedupRayLauncher() +launcher = RayTransformLauncher(FdedupRayTransformConfiguration()) # create parameters s3_cred = { "access_key": "localminioaccesskey", diff --git a/transforms/universal/fdedup/src/fdedup_transform.py b/transforms/universal/fdedup/src/fdedup_transform.py index 4b219c7ff..c0297ed36 100644 --- a/transforms/universal/fdedup/src/fdedup_transform.py +++ b/transforms/universal/fdedup/src/fdedup_transform.py @@ -20,14 +20,14 @@ import pyarrow as pa import ray from data_processing.data_access import DataAccessFactoryBase -from data_processing.transform import TransformConfiguration from data_processing.launch.ray import ( DefaultTableTransformRuntimeRay, - RayUtils, RayTransformLauncher, + RayUtils, TransformTableProcessorRay, ) -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import ( RANDOM_SEED, CLIArgumentProvider, @@ -794,10 +794,12 @@ def apply_input_params(self, args: Namespace) -> bool: logger.info(f"fuzzy dedup params are {self.params}") return True -class FdedupRayLauncher(RayTransformLauncher): + +class FdedupRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=FdedupTableTransformConfiguration(), runtime_class=FdedupRuntime) + if __name__ == "__main__": - launcher = FdedupRayLauncher() + launcher = RayTransformLauncher(FdedupRayTransformConfiguration()) launcher.launch() diff --git a/transforms/universal/fdedup/test/test_fdedup_ray.py b/transforms/universal/fdedup/test/test_fdedup_ray.py index 6c5574df7..3d910cc8b 100644 --- a/transforms/universal/fdedup/test/test_fdedup_ray.py +++ b/transforms/universal/fdedup/test/test_fdedup_ray.py @@ -12,9 +12,11 @@ import os - -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest -from fdedup_transform import FdedupRayLauncher +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) +from fdedup_transform import FdedupRayTransformConfiguration class TestRayBlocklistTransform(AbstractTransformLauncherTest): @@ -52,6 +54,6 @@ def get_test_transform_fixtures(self) -> list[tuple]: "fdedup_use_doc_snapshot": False, "fdedup_use_bucket_snapshot": False, } - launcher = FdedupRayLauncher() + launcher = RayTransformLauncher(FdedupRayTransformConfiguration()) fixtures = [(launcher, config, basedir + "/input", basedir + "/expected")] return fixtures diff --git a/transforms/universal/filter/src/filter_local_ray.py b/transforms/universal/filter/src/filter_local_ray.py index f775547ec..1c69801fe 100644 --- a/transforms/universal/filter/src/filter_local_ray.py +++ b/transforms/universal/filter/src/filter_local_ray.py @@ -13,11 +13,13 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils from filter_transform import ( + FilterRayTransformConfiguration, filter_columns_to_drop_cli_param, filter_criteria_cli_param, - filter_logical_operator_cli_param, FilterRayLauncher, + filter_logical_operator_cli_param, ) @@ -64,6 +66,6 @@ # Create the CLI args as will be parsed by the launcher sys.argv = ParamsUtils.dict_to_req(launcher_params | filter_params) # Create the longer to launch with the blocklist transform. - launcher = FilterRayLauncher() + launcher = RayTransformLauncher(FilterRayTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/filter/src/filter_s3_ray.py b/transforms/universal/filter/src/filter_s3_ray.py index d6afd7abf..29bd72b6a 100644 --- a/transforms/universal/filter/src/filter_s3_ray.py +++ b/transforms/universal/filter/src/filter_s3_ray.py @@ -12,11 +12,13 @@ import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils from filter_transform import ( + FilterRayTransformConfiguration, filter_columns_to_drop_cli_param, filter_criteria_cli_param, - filter_logical_operator_cli_param, FilterRayLauncher, + filter_logical_operator_cli_param, ) @@ -67,6 +69,6 @@ # Create the CLI args as will be parsed by the launcher sys.argv = ParamsUtils.dict_to_req(launcher_params | filter_params) # Create the longer to launch with the blocklist transform. - launcher = FilterRayLauncher() + launcher = RayTransformLauncher(FilterRayTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/filter/src/filter_transform.py b/transforms/universal/filter/src/filter_transform.py index a377979ed..efdd8ecc4 100644 --- a/transforms/universal/filter/src/filter_transform.py +++ b/transforms/universal/filter/src/filter_transform.py @@ -16,10 +16,13 @@ import duckdb import pyarrow as pa -from data_processing.transform import TransformConfiguration -from data_processing.launch.pure_python import PythonTransformLauncher, PythonLauncherConfiguration +from data_processing.launch.pure_python import ( + PythonLauncherConfiguration, + PythonTransformLauncher, +) from data_processing.launch.ray import RayTransformLauncher -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import CLIArgumentProvider, get_logger @@ -198,23 +201,12 @@ def apply_input_params(self, args: argparse.Namespace) -> bool: return True - -# class FilterPythonLauncherConfiguration(PythonLauncherConfiguration): -# """ -# Provides support for configuring and using the associated Transform class include -# configuration with CLI args and combining of metadata. -# """ -# -# def __init__(self): -# super().__init__( -# name=short_name, transform_class=FilterTransform, launcher_configuration=FilterTransformConfiguration() -# ) -# -class FilterRayLauncher(RayTransformLauncher): +class FilterRayTransformConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=FilterTransformConfiguration()) + if __name__ == "__main__": - launcher = FilterRayLauncher() + launcher = RayTransformLauncher(FilterRayTransformConfiguration()) logger.info("Launching filtering") launcher.launch() diff --git a/transforms/universal/filter/test/test_filter_ray.py b/transforms/universal/filter/test/test_filter_ray.py index f4fa29613..a176d1193 100644 --- a/transforms/universal/filter/test/test_filter_ray.py +++ b/transforms/universal/filter/test/test_filter_ray.py @@ -12,13 +12,16 @@ import os - -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) from filter_transform import ( + FilterRayTransformConfiguration, filter_columns_to_drop_cli_param, filter_criteria_cli_param, filter_logical_operator_cli_param, - filter_logical_operator_default, FilterRayLauncher, + filter_logical_operator_default, ) @@ -34,7 +37,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures.append( ( - FilterRayLauncher(), + RayTransformLauncher(FilterRayTransformConfiguration()), { filter_criteria_cli_param: [ "docq_total_words > 100 AND docq_total_words < 200", @@ -50,7 +53,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures.append( ( - FilterRayLauncher(), + RayTransformLauncher(FilterRayTransformConfiguration()), { filter_criteria_cli_param: [ "docq_total_words > 100 AND docq_total_words < 200", @@ -66,7 +69,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures.append( ( - FilterRayLauncher(), + RayTransformLauncher(FilterRayTransformConfiguration()), { filter_criteria_cli_param: [], filter_logical_operator_cli_param: filter_logical_operator_default, @@ -79,7 +82,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures.append( ( - FilterRayLauncher(), + RayTransformLauncher(FilterRayTransformConfiguration()), { filter_criteria_cli_param: [ "date_acquired BETWEEN '2023-07-04' AND '2023-07-08'", @@ -95,7 +98,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures.append( ( - FilterRayLauncher(), + RayTransformLauncher(FilterRayTransformConfiguration()), { filter_criteria_cli_param: [ "document IN ('CC-MAIN-20190221132217-20190221154217-00305.warc.gz', 'CC-MAIN-20200528232803-20200529022803-00154.warc.gz', 'CC-MAIN-20190617103006-20190617125006-00025.warc.gz')", diff --git a/transforms/universal/noop/src/noop_local_ray.py b/transforms/universal/noop/src/noop_local_ray.py index 10e69fa9a..bf2e024b7 100644 --- a/transforms/universal/noop/src/noop_local_ray.py +++ b/transforms/universal/noop/src/noop_local_ray.py @@ -13,8 +13,10 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from noop_transform import NOOPRayLauncher +from noop_transform import NOOPRayTransformConfiguration + # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test-data", "input")) @@ -44,6 +46,6 @@ # Set the simulated command line args sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher - launcher = NOOPRayLauncher() + launcher = RayTransformLauncher(NOOPRayTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/noop/src/noop_s3_ray.py b/transforms/universal/noop/src/noop_s3_ray.py index c797b326e..31196f25a 100644 --- a/transforms/universal/noop/src/noop_s3_ray.py +++ b/transforms/universal/noop/src/noop_s3_ray.py @@ -13,12 +13,14 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from noop_transform import NOOPRayLauncher +from noop_transform import NOOPRayTransformConfiguration + print(os.environ) # create launcher -launcher = NOOPRayLauncher() +launcher = RayTransformLauncher(NOOPRayTransformConfiguration()) # create parameters s3_cred = { "access_key": "localminioaccesskey", diff --git a/transforms/universal/noop/src/noop_transform.py b/transforms/universal/noop/src/noop_transform.py index 6ce6b80a3..720d02cb1 100644 --- a/transforms/universal/noop/src/noop_transform.py +++ b/transforms/universal/noop/src/noop_transform.py @@ -15,12 +15,13 @@ from typing import Any import pyarrow as pa - -from data_processing.transform import TransformConfiguration -from data_processing.launch.pure_python import PythonTransformLauncher, PythonLauncherConfiguration +from data_processing.launch.pure_python import ( + PythonLauncherConfiguration, + PythonTransformLauncher, +) from data_processing.launch.ray import RayTransformLauncher -from data_processing.transform import AbstractTableTransform - +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import CLIArgumentProvider, get_logger @@ -120,11 +121,21 @@ def apply_input_params(self, args: Namespace) -> bool: self.params = self.params | captured logger.info(f"noop parameters are : {self.params}") return True -class NOOPRayLauncher(RayTransformLauncher): + + +class NOOPRayTransformConfiguration(RayTransformConfiguration): + """ + Implements the RayTransformConfiguration for NOOP as required by the RayTransformLauncher. + NOOP does not use a RayRuntime class so the superclass only needs the base + python-only configuration. + """ + def __init__(self): - super().__init__(transform_config=NOOPTransformConfiguration()) + super().__init__(NOOPTransformConfiguration()) + if __name__ == "__main__": - launcher = NOOPRayLauncher() + # launcher = NOOPRayLauncher() + launcher = RayTransformLauncher(NOOPRayTransformConfiguration()) logger.info("Launching noop transform") launcher.launch() diff --git a/transforms/universal/noop/test/test_noop_ray.py b/transforms/universal/noop/test/test_noop_ray.py index 85d17dd1b..ee490bd14 100644 --- a/transforms/universal/noop/test/test_noop_ray.py +++ b/transforms/universal/noop/test/test_noop_ray.py @@ -13,8 +13,15 @@ import os from data_processing.launch.pure_python import PythonTransformLauncher -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest -from noop_transform import sleep_cli_param, NOOPRayLauncher, NOOPTransformConfiguration +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) +from noop_transform import ( + NOOPRayTransformConfiguration, + NOOPTransformConfiguration, + sleep_cli_param, +) class TestRayNOOPTransform(AbstractTransformLauncherTest): @@ -29,6 +36,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: fixtures = [] launcher = PythonTransformLauncher(NOOPTransformConfiguration()) fixtures.append((launcher, {sleep_cli_param: 0}, basedir + "/input", basedir + "/expected")) - launcher = NOOPRayLauncher() + # launcher = NOOPRayLauncher() + launcher = RayTransformLauncher(NOOPRayTransformConfiguration()) fixtures.append((launcher, {sleep_cli_param: 0}, basedir + "/input", basedir + "/expected")) return fixtures diff --git a/transforms/universal/tokenization/src/tokenization_local_ray.py b/transforms/universal/tokenization/src/tokenization_local_ray.py index efe10503b..7343a4567 100644 --- a/transforms/universal/tokenization/src/tokenization_local_ray.py +++ b/transforms/universal/tokenization/src/tokenization_local_ray.py @@ -13,8 +13,10 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from tokenization_transform import TokenizationRayLauncher +from tokenization_transform import TokenizationRayConfiguration + # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test-data", "ds01", "input")) @@ -42,6 +44,6 @@ sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher - launcher = TokenizationRayLauncher() + launcher = RayTransformLauncher(TokenizationRayConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/tokenization/src/tokenization_local_ray_long_doc.py b/transforms/universal/tokenization/src/tokenization_local_ray_long_doc.py index 317666ff5..7456cbc36 100644 --- a/transforms/universal/tokenization/src/tokenization_local_ray_long_doc.py +++ b/transforms/universal/tokenization/src/tokenization_local_ray_long_doc.py @@ -13,8 +13,10 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from tokenization_transform import TokenizationRayLauncher +from tokenization_transform import TokenizationRayConfiguration + # create parameters input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test-data", "ds02", "input")) @@ -50,6 +52,6 @@ sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher - launcher = TokenizationRayLauncher() + launcher = RayTransformLauncher(TokenizationRayConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/tokenization/src/tokenization_s3_long_doc.py b/transforms/universal/tokenization/src/tokenization_s3_long_doc.py index 4b8a8014f..81f851812 100644 --- a/transforms/universal/tokenization/src/tokenization_s3_long_doc.py +++ b/transforms/universal/tokenization/src/tokenization_s3_long_doc.py @@ -12,8 +12,10 @@ import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from tokenization_transform import TokenizationRayLauncher +from tokenization_transform import TokenizationRayConfiguration + # create parameters s3_cred = { @@ -53,6 +55,6 @@ sys.argv = ParamsUtils.dict_to_req(d=params) # create launcher - launcher = TokenizationRayLauncher() + launcher = RayTransformLauncher(TokenizationRayConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() diff --git a/transforms/universal/tokenization/src/tokenization_s3_ray.py b/transforms/universal/tokenization/src/tokenization_s3_ray.py index f773175db..ee7b9059b 100644 --- a/transforms/universal/tokenization/src/tokenization_s3_ray.py +++ b/transforms/universal/tokenization/src/tokenization_s3_ray.py @@ -13,12 +13,14 @@ import os import sys +from data_processing.launch.ray import RayTransformLauncher from data_processing.utils import ParamsUtils -from tokenization_transform import TokenizationRayLauncher +from tokenization_transform import TokenizationRayConfiguration + print(os.environ) # create launcher -launcher = TokenizationRayLauncher() +launcher = RayTransformLauncher(TokenizationRayConfiguration()) # create parameters s3_cred = { "access_key": "localminioaccesskey", diff --git a/transforms/universal/tokenization/src/tokenization_transform.py b/transforms/universal/tokenization/src/tokenization_transform.py index 2cfb79100..e6ceb6654 100644 --- a/transforms/universal/tokenization/src/tokenization_transform.py +++ b/transforms/universal/tokenization/src/tokenization_transform.py @@ -20,9 +20,9 @@ from typing import Any import pyarrow as pa -from data_processing.transform import TransformConfiguration from data_processing.launch.ray import RayTransformLauncher -from data_processing.transform import AbstractTableTransform +from data_processing.launch.ray.transform_configuration import RayTransformConfiguration +from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import get_logger from tokenization_utils import is_valid_argument_string, load_tokenizer, split_text @@ -285,11 +285,12 @@ def apply_input_params(self, args: Namespace) -> bool: # launcher_configuration=TokenizationTransformConfiguration(), # ) # -class TokenizationRayLauncher(RayTransformLauncher): +class TokenizationRayConfiguration(RayTransformConfiguration): def __init__(self): super().__init__(transform_config=TokenizationTransformConfiguration()) + if __name__ == "__main__": - launcher = TokenizationRayLauncher() + launcher = RayTransformLauncher(TokenizationRayConfiguration()) logger.info("Launching Tokenization transform") launcher.launch() diff --git a/transforms/universal/tokenization/test/test_tokenization_launch_long_doc.py b/transforms/universal/tokenization/test/test_tokenization_launch_long_doc.py index cac018e4e..2cf61e895 100644 --- a/transforms/universal/tokenization/test/test_tokenization_launch_long_doc.py +++ b/transforms/universal/tokenization/test/test_tokenization_launch_long_doc.py @@ -12,9 +12,12 @@ import os +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) +from tokenization_transform import TokenizationRayConfiguration -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest -from tokenization_transform import TokenizationRayLauncher tkn_params = { "tkn_tokenizer": "hf-internal-testing/llama-tokenizer", @@ -36,8 +39,6 @@ class TestRayTokenizationTransform(AbstractTransformLauncherTest): def get_test_transform_fixtures(self) -> list[tuple]: basedir = "../test-data" basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) - launcher =TokenizationRayLauncher() - fixtures = [ - (launcher, tkn_params, basedir + "/ds02/input", basedir + "/ds02/expected") - ] + launcher = RayTransformLauncher(TokenizationRayConfiguration()) + fixtures = [(launcher, tkn_params, basedir + "/ds02/input", basedir + "/ds02/expected")] return fixtures diff --git a/transforms/universal/tokenization/test/test_tokenization_ray.py b/transforms/universal/tokenization/test/test_tokenization_ray.py index 2363e706b..03383a836 100644 --- a/transforms/universal/tokenization/test/test_tokenization_ray.py +++ b/transforms/universal/tokenization/test/test_tokenization_ray.py @@ -12,9 +12,12 @@ import os +from data_processing.launch.ray import RayTransformLauncher +from data_processing.test_support.launch.transform_test import ( + AbstractTransformLauncherTest, +) +from tokenization_transform import TokenizationRayConfiguration -from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest -from tokenization_transform import TokenizationRayLauncher tkn_params = { "tkn_tokenizer": "hf-internal-testing/llama-tokenizer", @@ -34,8 +37,6 @@ class TestRayTokenizationTransform(AbstractTransformLauncherTest): def get_test_transform_fixtures(self) -> list[tuple]: basedir = "../test-data" basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) - launcher = TokenizationRayLauncher() - fixtures = [ - (launcher, tkn_params, basedir + "/ds01/input", basedir + "/ds01/expected") - ] + launcher = RayTransformLauncher(TokenizationRayConfiguration()) + fixtures = [(launcher, tkn_params, basedir + "/ds01/input", basedir + "/ds01/expected")] return fixtures