In [None]:
import os
import shutil
import time

# Quiet loguru to ERROR and above via its environment variable. This is set here, before
# nemo_curator is imported in later cells, presumably so it applies when loguru is first loaded.
os.environ.update(LOGURU_LEVEL="ERROR")

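## Modify `Resources(gpus=...)` in `EmbeddingCreatorStage` to speed up embedding generation

Please note that the benefit depends on the model used and on the GPU SKU.

If a single model replica uses less than 50% of a GPU, requesting `Resources(gpus=0.5)` for the embedding model stage lets Ray place two replicas on each GPU, so we can expect a speedup.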

In [2]:
# Model identifier passed to EmbeddingCreatorStage below (presumably a Hugging Face Hub id).
model_name = "sentence-transformers/all-MiniLM-L6-v2"

In [3]:
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.embedders import EmbeddingCreatorStage
from nemo_curator.stages.text.models.utils import format_name_with_suffix

# Inference batch size shared by both configurations. The two benchmark runs should differ
# only in the GPU fraction requested by the embedding model stage.
model_inference_batch_size = 256

# Baseline: the embedding stage with whatever resources it requests by default.
embedding_creator = EmbeddingCreatorStage(
    model_identifier=model_name,
    model_inference_batch_size=model_inference_batch_size,
)

# Same stage, but the model sub-stage requests half a GPU, so Ray can pack up to two model
# replicas onto each GPU.
# NOTE(review): `format_name_with_suffix(model_name, "_model")` presumably yields the name of
# the model sub-stage inside EmbeddingCreatorStage. The execution log shows separate Tokenizer
# and EmbeddingModel stages; confirm the naming against nemo_curator before changing it.
embedding_creator_half_gpu = EmbeddingCreatorStage(
    model_identifier=model_name,
    model_inference_batch_size=model_inference_batch_size,
).with_({format_name_with_suffix(model_name, "_model"): {"resources": Resources(gpus=0.5)}})

In [None]:
import uuid

from nemo_curator.backends.experimental.ray_data import RayDataExecutor
from nemo_curator.core.client import RayClient
from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.text.io.reader.parquet import ParquetReader
from nemo_curator.stages.text.io.writer.parquet import ParquetWriter

input_path = "./tinystories_train_parquet"
output_path = "./output_path/"
ray_temp_dir = "/tmp"  # noqa: S108

# Size of the local Ray cluster started for each run; adjust to match the machine.
num_cpus = 64
num_gpus = 4
# Seconds to wait after a cluster shuts down before starting the next one.
cluster_cleanup_delay_s = 2

# Remove any previous output so results from an earlier execution are not mixed in.
shutil.rmtree(output_path, ignore_errors=True)

# Seconds spent in `pipeline.run` for each configuration, keyed by run name.
# NOTE(review): runs execute sequentially in a fixed order, so the first run may absorb
# one-off warm-up costs (e.g. filesystem caching of the input). Consider repeating or
# alternating runs for a more robust comparison.
time_taken = {}
for _embedding_creator, _name in [(embedding_creator, "full_gpu"), (embedding_creator_half_gpu, "half_gpu")]:
    # Ensure RAY_ADDRESS is cleared before starting a new cluster. The log shows the client
    # connecting through RAY_ADDRESS, so a stale value could point at the previous cluster.
    os.environ["RAY_ADDRESS"] = ""
    # Give every run its own Ray temp directory so session files and logs do not collide.
    ray_temp_dir_path = os.path.join(ray_temp_dir, str(uuid.uuid4())[:6])
    # Same read -> embed -> write pipeline for every run; only the embedding stage differs.
    pipeline = Pipeline(
        name=_name,
        stages=[
            ParquetReader(
                file_paths=input_path,
                fields=["text"],
            ),
            _embedding_creator,
            ParquetWriter(os.path.join(output_path, f"{_name}_tinystories")),
        ],
    )
    with RayClient(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        ray_temp_dir=ray_temp_dir_path,
        ray_dashboard_host="0.0.0.0",  # noqa: S104
    ):
        # perf_counter is monotonic and high-resolution, unlike time.time(), which can jump
        # if the system clock is adjusted. It is the right clock for measuring durations.
        t0 = time.perf_counter()
        pipeline.run(RayDataExecutor())
        elapsed = time.perf_counter() - t0
        time_taken[_name] = elapsed
        print(f"Time taken for {_name}: {elapsed} seconds")
    # Small delay to ensure Ray cluster is fully cleaned up before next iteration
    time.sleep(cluster_cleanup_delay_s)

2025-11-20 23:44:39,539	INFO worker.py:1691 -- Using address 127.0.1.1:6380 set in the environment variable RAY_ADDRESS
2025-11-20 23:44:39,548	INFO worker.py:1832 -- Connecting to existing Ray cluster at address: 127.0.1.1:6380...


2025-11-20 23:44:40,950	INFO usage_lib.py:447 -- Usage stats collection is disabled.
2025-11-20 23:44:40,950	INFO scripts.py:914 -- [37mLocal node IP[39m: [1m127.0.1.1[22m


[2025-11-20 23:44:45,799 W 216844 216844] global_state_accessor.cc:505: Some processes that the driver needs to connect to have not registered with GCS, so retrying. Have you run 'ray start' on this node?
[2025-11-20 23:44:46,801 W 216844 216844] global_state_accessor.cc:505: Some processes that the driver needs to connect to have not registered with GCS, so retrying. Have you run 'ray start' on this node?


2025-11-20 23:44:47,452	SUCC scripts.py:950 -- [32m--------------------[39m
2025-11-20 23:44:47,452	SUCC scripts.py:951 -- [32mRay runtime started.[39m
2025-11-20 23:44:47,452	SUCC scripts.py:952 -- [32m--------------------[39m
2025-11-20 23:44:47,452	INFO scripts.py:954 -- [36mNext steps[39m
2025-11-20 23:44:47,452	INFO scripts.py:957 -- To add another node to this Ray cluster, run
2025-11-20 23:44:47,452	INFO scripts.py:960 -- [1m  ray start --address='127.0.1.1:6380'[22m
2025-11-20 23:44:47,452	INFO scripts.py:969 -- To connect to this Ray cluster:
2025-11-20 23:44:47,452	INFO scripts.py:971 -- [35mimport[39m[26m ray
2025-11-20 23:44:47,453	INFO scripts.py:972 -- ray[35m.[39m[26minit(_node_ip_address[35m=[39m[26m[33m'127.0.1.1'[39m[26m)
2025-11-20 23:44:47,453	INFO scripts.py:984 -- To submit a Ray job using the Ray Jobs CLI:
2025-11-20 23:44:47,453	INFO scripts.py:985 -- [1m  RAY_API_SERVER_ADDRESS='http://127.0.1.1:8266' ray job submit --working-dir . -- pyt

[2025-11-20 23:44:47,802 I 216844 216844] global_state_accessor.cc:487: This node has an IP address of 127.0.1.1, but we cannot find a local Raylet with the same address. This can happen when you connect to the Ray cluster with a different IP address or when connecting to a container.
2025-11-20 23:44:47,805	INFO worker.py:2012 -- Connected to Ray cluster.
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 9398.65it/s]
2025-11-20 23:45:00,094	INFO streaming_executor.py:85 -- A new progress UI is available. To enable, set `ray.data.DataContext.get_current().enable_rich_progress_bars = True`.
2025-11-20 23:45:00,095	INFO logging.py:397 -- Registered dataset logger for dataset dataset_6_0
2025-11-20 23:45:00,105	INFO streaming_executor.py:170 -- Starting execution of Dataset dataset_6_0. Full logs are in /raid/praateekm/ray_temp/b49a36/session_2025-11-20_23-44-40_951218_217438/logs/ray-data
2025-11-20 23:45:00,106	INFO streaming_executor.py:171 -- Execution plan of Dataset dataset_6_

Running 0: 0.00 row [00:00, ? row/s]

- MapBatches(FilePartitioningStageTask) 1: 0.00 row [00:00, ? row/s]

- StreamingRepartition 2: 0.00 row [00:00, ? row/s]

- MapBatches(ParquetReaderStageTask)->MapBatches(TokenizerStageActor) 3: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingModelStageActor) 4: 0.00 row [00:00, ? row/s]

- MapBatches(ParquetWriterTask) 5: 0.00 row [00:00, ? row/s]

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 9583.33it/s][32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 10450.06it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 11375.93it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 11546.07it/s] 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12031.85it/s]zerStageActor)) pid=230200)[0m 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 138731.11it/s]
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12957.38it/s]zerStageActor)) pid=230988)[0m 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 14770.41it/s]zerStageActor)) pid=231927)[0m 
Fetching 30 file

Time taken for full_gpu: 213.12328267097473 seconds


2025-11-20 23:48:14,802	INFO worker.py:1691 -- Using address 127.0.1.1:6380 set in the environment variable RAY_ADDRESS
2025-11-20 23:48:14,810	INFO worker.py:1832 -- Connecting to existing Ray cluster at address: 127.0.1.1:6380...


2025-11-20 23:48:16,283	INFO usage_lib.py:447 -- Usage stats collection is disabled.
2025-11-20 23:48:16,284	INFO scripts.py:914 -- [37mLocal node IP[39m: [1m127.0.1.1[22m
2025-11-20 23:48:22,744	SUCC scripts.py:950 -- [32m--------------------[39m
2025-11-20 23:48:22,744	SUCC scripts.py:951 -- [32mRay runtime started.[39m
2025-11-20 23:48:22,744	SUCC scripts.py:952 -- [32m--------------------[39m
2025-11-20 23:48:22,744	INFO scripts.py:954 -- [36mNext steps[39m
2025-11-20 23:48:22,744	INFO scripts.py:957 -- To add another node to this Ray cluster, run
2025-11-20 23:48:22,744	INFO scripts.py:960 -- [1m  ray start --address='127.0.1.1:6380'[22m
2025-11-20 23:48:22,744	INFO scripts.py:969 -- To connect to this Ray cluster:
2025-11-20 23:48:22,745	INFO scripts.py:971 -- [35mimport[39m[26m ray
2025-11-20 23:48:22,745	INFO scripts.py:972 -- ray[35m.[39m[26minit(_node_ip_address[35m=[39m[26m[33m'127.0.1.1'[39m[26m)
2025-11-20 23:48:22,745	INFO scripts.py:984 -- To su

2025-11-20 23:48:23,124	INFO worker.py:2012 -- Connected to Ray cluster.
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 13438.97it/s]
2025-11-20 23:48:34,866	INFO logging.py:397 -- Registered dataset logger for dataset dataset_6_0
2025-11-20 23:48:34,869	INFO streaming_executor.py:170 -- Starting execution of Dataset dataset_6_0. Full logs are in /raid/praateekm/ray_temp/a0b8be/session_2025-11-20_23-48-16_284531_239854/logs/ray-data
2025-11-20 23:48:34,869	INFO streaming_executor.py:171 -- Execution plan of Dataset dataset_6_0: InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(FilePartitioningStageTask)] -> TaskPoolMapOperator[StreamingRepartition] -> ActorPoolMapOperator[MapBatches(ParquetReaderStageTask)->MapBatches(TokenizerStageActor)] -> ActorPoolMapOperator[MapBatches(EmbeddingModelStageActor)] -> TaskPoolMapOperator[MapBatches(ParquetWriterTask)]


Running 0: 0.00 row [00:00, ? row/s]

- MapBatches(FilePartitioningStageTask) 1: 0.00 row [00:00, ? row/s]

- StreamingRepartition 2: 0.00 row [00:00, ? row/s]

- MapBatches(ParquetReaderStageTask)->MapBatches(TokenizerStageActor) 3: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingModelStageActor) 4: 0.00 row [00:00, ? row/s]

- MapBatches(ParquetWriterTask) 5: 0.00 row [00:00, ? row/s]

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 10266.74it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 10565.00it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 26915.32it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 8056.16it/s][32m [repeated 2x across cluster][0m
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12475.62it/s] 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12846.26it/s]zerStageActor)) pid=252947)[0m 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 121691.61it/s]
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 13665.20it/s]zerStageActor)) pid=253845)[0m 
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 138425.87it/s]
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12222.35it/s]zerStageActor)) pid=254674)[0m 
Fetching 30 files: 100%|██████████| 30/30 [00:00<0

Time taken for half_gpu: 153.33278346061707 seconds


In [5]:
# Print the wall-clock time recorded for each configuration, rounded to two decimals.
summary_lines = [f"Time taken for {run_name}: {run_time:.2f} seconds" for run_name, run_time in time_taken.items()]
for line in summary_lines:
    print(line)

Time taken for full_gpu: 213.12 seconds
Time taken for half_gpu: 153.33 seconds
