
Commit

[REVIEW] Update shuffle options and reorder partitions (#201)
* use new shuffle options

* add check for full shuffle

* fix criteo-example

* code review

* use Ben's enum suggestion

* isort

* isort
rjzamora committed Aug 7, 2020
1 parent 69bf27a commit 2826d0a
Showing 11 changed files with 105 additions and 50 deletions.
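
In short: the boolean/string `shuffle` argument is replaced by the `nvt.io.Shuffle` enum. A minimal before/after sketch (the workflow and dataset construction below are illustrative placeholders based on the examples in this diff, not code from the commit):

import nvtabular as nvt

proc = nvt.Workflow(cat_names=["C1"], cont_names=["I1"], label_name=["label"])
dataset = nvt.Dataset("input.parquet", part_mem_fraction=0.1)

# Old values map to the new enum as follows:
#   shuffle="partial" -> shuffle=nvt.io.Shuffle.PER_PARTITION
#   shuffle="full"    -> shuffle=nvt.io.Shuffle.PER_WORKER
#   shuffle=True      -> deprecated; coerced to PER_WORKER with a warning
#   shuffle=False     -> coerced to None (no shuffling)
proc.apply(
    dataset,
    shuffle=nvt.io.Shuffle.PER_PARTITION,
    output_path="./processed",
    out_files_per_proc=5,
)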
4 changes: 2 additions & 2 deletions examples/criteo-example.ipynb
@@ -204,7 +204,7 @@
"outputs": [],
"source": [
"%%time\n",
-"proc.apply(train_dataset, shuffle=True, output_path=output_train_dir, out_files_per_proc=5)"
+"proc.apply(train_dataset, shuffle=nvt.io.Shuffle.PER_PARTITION, output_path=output_train_dir, out_files_per_proc=5)"
]
},
{
@@ -214,7 +214,7 @@
"outputs": [],
"source": [
"%%time\n",
-"proc.apply(valid_dataset, record_stats=False, shuffle=True, output_path=output_valid_dir, out_files_per_proc=5)"
+"proc.apply(valid_dataset, record_stats=False, shuffle=nvt.io.Shuffle.PER_PARTITION, output_path=output_valid_dir, out_files_per_proc=5)"
]
},
{
14 changes: 9 additions & 5 deletions examples/dask-nvtabular-criteo-benchmark.py
@@ -8,9 +8,9 @@
from dask.distributed import Client, performance_report
from dask_cuda import LocalCUDACluster

-import nvtabular.ops as ops
 from nvtabular import Dataset, Workflow
-from nvtabular.io import device_mem_size
+from nvtabular import io as nvt_io
+from nvtabular import ops as ops


def setup_rmm_pool(client, pool_size):
@@ -79,7 +79,7 @@ def main(args):
cat_cache["C10"] = "host"

# Use total device size to calculate args.device_limit_frac
-device_size = device_mem_size(kind="total")
+device_size = nvt_io.device_mem_size(kind="total")
device_limit = int(args.device_limit_frac * device_size)
device_pool_size = int(args.device_pool_frac * device_size)
part_size = int(args.part_mem_frac * device_size)
@@ -134,14 +134,18 @@ def main(args):
with performance_report(filename=args.profile):
processor.apply(
dataset,
-shuffle="full" if args.worker_shuffle else "partial",
+shuffle=nvt_io.Shuffle.PER_WORKER
+if args.worker_shuffle
+else nvt_io.Shuffle.PER_PARTITION,
out_files_per_proc=out_files_per_proc,
output_path=out_path,
)
else:
processor.apply(
dataset,
-shuffle="full" if args.worker_shuffle else "partial",
+shuffle=nvt_io.Shuffle.PER_WORKER
+if args.worker_shuffle
+else nvt_io.Shuffle.PER_PARTITION,
out_files_per_proc=out_files_per_proc,
output_path=out_path,
)
4 changes: 2 additions & 2 deletions examples/gpu_benchmark.ipynb
@@ -168,7 +168,7 @@
"outputs": [],
"source": [
"%%time\n",
-"proc.apply(trains_itrs, apply_offline=True, record_stats=True, output_path=output_path_train, shuffle=False)"
+"proc.apply(trains_itrs, apply_offline=True, record_stats=True, output_path=output_path_train, shuffle=None)"
]
},
{
@@ -178,7 +178,7 @@
"outputs": [],
"source": [
"%%time\n",
-"proc.apply(valids_itrs, apply_offline=True, record_stats=False, output_path=output_path_valid, shuffle=False)"
+"proc.apply(valids_itrs, apply_offline=True, record_stats=False, output_path=output_path_valid, shuffle=None)"
]
},
{
4 changes: 2 additions & 2 deletions examples/rossmann-store-sales-example.ipynb
@@ -222,8 +222,8 @@
"metadata": {},
"outputs": [],
"source": [
-"proc.apply(train_dataset, record_stats=True, output_path=PREPROCESS_DIR_TRAIN, shuffle=\"full\", out_files_per_proc=2)\n",
-"proc.apply(valid_dataset, record_stats=False, output_path=PREPROCESS_DIR_VALID, shuffle=False)"
+"proc.apply(train_dataset, record_stats=True, output_path=PREPROCESS_DIR_TRAIN, shuffle=nvt.io.Shuffle.PER_WORKER, out_files_per_proc=2)\n",
+"proc.apply(valid_dataset, record_stats=False, output_path=PREPROCESS_DIR_VALID, shuffle=None)"
]
},
{
30 changes: 27 additions & 3 deletions nvtabular/io.py
@@ -15,6 +15,7 @@
#

import collections
+import enum
import functools
import json
import logging
@@ -53,11 +54,34 @@
LOG = logging.getLogger("nvtabular")


+class Shuffle(enum.Enum):
+    PER_PARTITION = 0
+    PER_WORKER = 1
+    FULL = 2


#
# Helper Function definitions
#


+def _check_shuffle_arg(shuffle):
+    if shuffle is None:
+        return shuffle
+
+    if isinstance(shuffle, Shuffle):
+        if shuffle == Shuffle.FULL:
+            raise ValueError("`shuffle=Shuffle.FULL` is not yet supported.")
+    elif shuffle is True:
+        shuffle = Shuffle.PER_WORKER
+        warnings.warn("`shuffle=True` is deprecated. Using `PER_WORKER`.", DeprecationWarning)
+    elif shuffle is False:
+        shuffle = None
+    else:
+        raise ValueError(f"`shuffle={shuffle}` not recognized.")
+    return shuffle


def _allowable_batch_size(gpu_memory_frac, row_size):
free_mem = device_mem_size(kind="free")
gpu_memory = free_mem * gpu_memory_frac
@@ -448,7 +472,7 @@ def _bytesio_to_disk(self):
for bio, path in zip(self.data_bios, self.data_paths):
gdf = cudf.io.read_parquet(bio, index=False)
bio.close()
-if self.shuffle == "full":
+if self.shuffle == Shuffle.PER_WORKER:
gdf = _shuffle_gdf(gdf)
gdf.to_parquet(path, compression=None, index=False)
return
@@ -509,7 +533,7 @@ def _close_writers(self):
return None

def _bytesio_to_disk(self):
-raise ValueError("hugectr binary format doesn't support shuffle=full yet")
+raise ValueError("hugectr binary format doesn't support PER_WORKER shuffle yet")


#
@@ -543,7 +567,7 @@ def _write_output_partition(
out_files_per_proc,
shuffle,
use_guid=True,
-bytes_io=(shuffle == "full"),
+bytes_io=(shuffle == Shuffle.PER_WORKER),
num_threads=num_threads,
)
writer.set_col_names(labels=label_names, cats=cat_names, conts=cont_names)
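
The new `_check_shuffle_arg` helper above normalizes legacy values as sketched here (importing the private helper directly is for illustration only):

import warnings

from nvtabular.io import Shuffle, _check_shuffle_arg

# Enum values and None pass through unchanged (except FULL, which is rejected).
assert _check_shuffle_arg(None) is None
assert _check_shuffle_arg(Shuffle.PER_PARTITION) is Shuffle.PER_PARTITION

# Legacy booleans are coerced rather than rejected.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    assert _check_shuffle_arg(True) is Shuffle.PER_WORKER  # emits DeprecationWarning
assert _check_shuffle_arg(False) is None

# FULL is not yet supported, and the old strings are no longer recognized.
for bad in (Shuffle.FULL, "partial", "full"):
    try:
        _check_shuffle_arg(bad)
        raise AssertionError("expected ValueError")
    except ValueError:
        pass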
66 changes: 44 additions & 22 deletions nvtabular/workflow.py
@@ -558,10 +558,13 @@ def __init__(self, client=None, **kwargs):
super().__init__(**kwargs)
self.ddf = None
self.client = client
+self._shuffle_parts = False

-def set_ddf(self, ddf):
+def set_ddf(self, ddf, shuffle=None):
if isinstance(ddf, (dask_cudf.DataFrame, nvt_io.Dataset)):
self.ddf = ddf
+if shuffle is not None:
+    self._shuffle_parts = shuffle
else:
raise TypeError("ddf type not supported.")

@@ -570,7 +573,7 @@ def get_ddf(self):
raise ValueError("No dask_cudf frame available.")
elif isinstance(self.ddf, nvt_io.Dataset):
columns = self.columns_ctx["all"]["base"]
-return self.ddf.to_ddf(columns=columns)
+return self.ddf.to_ddf(columns=columns, shuffle=self._shuffle_parts)
return self.ddf

@staticmethod
@@ -669,38 +672,47 @@ def apply(
):
"""
Runs all the preprocessing and feature engineering operators.
-Also, shuffles the data if shuffle is set to True.
+Also, shuffles the data if a `shuffle` option is specified.
Parameters
-----------
dataset : object
apply_offline : boolean
-runs operators in offline mode or not
+Runs operators in offline mode or not
record_stats : boolean
-record the stats in file or not. Only available
+Record the stats in file or not. Only available
for apply_offline=True
-shuffle : {"full", "partial", None}
-Whether to shuffle output dataset. "partial" means
-each worker will randomly shuffle data into a number
-(`out_files_per_proc`) of different output files as the data
-is processed. The output files are distinctly mapped to
-each worker process. "full" means the workers will perform the
-"partial" shuffle into BytesIO files, and then perform
-a full shuffle of each in-memory file before writing to
-disk. A "true" full shuffle is not yet implemented.
+shuffle : nvt.io.Shuffle enum
+How to shuffle the output dataset. Shuffling is only
+performed if the data is written to disk. For all options
+other than `None` (which means no shuffling), the partitions
+of the underlying dataset/ddf will be randomly ordered. If
+`PER_PARTITION` is specified, each worker/process will also
+shuffle the rows within each partition before splitting and
+appending the data to a number (`out_files_per_proc`) of output
+files. Output files are distinctly mapped to each worker process.
+If `PER_WORKER` is specified, each worker will follow the same
+procedure as `PER_PARTITION`, but will re-shuffle each file after
+all data is persisted. This results in a full shuffle of the
+data processed by each worker. To improve performance, this option
+currently uses host-memory `BytesIO` objects for the intermediate
+persist stage. The `FULL` option is not yet implemented.
output_path : string
-path to output data
+Path to write processed/shuffled output data
output_format : {"parquet", "hugectr", None}
-Output format for processed/shuffled dataset. If None,
-no output dataset will be written.
+Output format to write processed/shuffled data. If None,
+no output dataset will be written (and shuffling skipped).
out_files_per_proc : integer
-number of files to create (per process) after
+Number of files to create (per process) after
shuffling the data
num_io_threads : integer
Number of IO threads to use for writing the output dataset.
For `0` (default), no dedicated IO threads will be used.
"""

+# Check shuffle argument
+shuffle = nvt_io._check_shuffle_arg(shuffle)

# If no tasks have been loaded then we need to load internal config
if not self.phases:
self.finalize()
@@ -740,6 +752,9 @@ def iterate_online(
):
""" Iterate through dataset and (optionally) apply/shuffle/write.
"""
+# Check shuffle argument
+shuffle = nvt_io._check_shuffle_arg(shuffle)

# Check if we have a (supported) writer
output_path = output_path or "./"
output_path = str(output_path)
@@ -748,13 +763,13 @@
output_path,
out_files_per_proc,
shuffle,
-bytes_io=(shuffle == "full"),
+bytes_io=(shuffle == nvt_io.Shuffle.PER_WORKER),
num_threads=num_io_threads,
)

# Iterate through dataset, apply ops, and write out processed data
if apply_ops:
-for gdf in dataset.to_iter():
+for gdf in dataset.to_iter(shuffle=(shuffle is not None)):
self.apply_ops(gdf, output_path=output_path, writer=writer)

# Close writer and write general/specialized metadata
@@ -791,6 +806,9 @@ def build_and_process_graph(
Full graph is only executed if `output_format` is specified.
"""
+# Check shuffle argument
+shuffle = nvt_io._check_shuffle_arg(shuffle)

end = end_phase if end_phase else len(self.phases)

if output_format not in ("parquet", "hugectr", None):
@@ -806,7 +824,7 @@
else:
clean_worker_cache()

-self.set_ddf(dataset)
+self.set_ddf(dataset, shuffle=(shuffle is not None))
if apply_ops:
for idx, _ in enumerate(self.phases[:end]):
self.exec_phase(idx, record_stats=record_stats)
@@ -837,6 +855,9 @@ def write_to_dataset(
Assumes statistics are already gathered.
"""
+# Check shuffle argument
+shuffle = nvt_io._check_shuffle_arg(shuffle)

if nfiles:
warnings.warn("nfiles is deprecated. Use out_files_per_proc")
if out_files_per_proc is None:
@@ -905,7 +926,8 @@ def ddf_to_dataset(
)
return

-# Default (shuffle=False): Just use dask_cudf.to_parquet
+# Default (shuffle=None and out_files_per_proc=None)
+# Just use `dask_cudf.to_parquet`
fut = ddf.to_parquet(output_path, compression=None, write_index=False, compute=False)
if self.client is None:
fut.compute(scheduler="synchronous")
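
The `PER_WORKER` behavior described in the `apply` docstring above is essentially a two-stage shuffle (see `_bytesio_to_disk` in the io.py hunk). A simplified, self-contained sketch of the idea; list-based staging stands in for the library's host-memory `BytesIO` buffers, and `_shuffle` for nvtabular's internal `_shuffle_gdf`:

import cudf
import numpy as np

def _shuffle(gdf):
    # Random row permutation; stand-in for nvtabular's _shuffle_gdf.
    return gdf.take(np.random.permutation(len(gdf)))

def per_worker_shuffle_sketch(partitions, out_paths):
    staged = [[] for _ in out_paths]
    for gdf in partitions:
        # PER_PARTITION stage: shuffle rows within the partition, then
        # scatter them across this process's output files.
        gdf = _shuffle(gdf)
        splits = np.array_split(np.arange(len(gdf)), len(staged))
        for bucket, idx in zip(staged, splits):
            bucket.append(gdf.take(idx))
    for bucket, path in zip(staged, out_paths):
        # PER_WORKER stage: re-shuffle each staged file before the final
        # write, giving a full shuffle of the data this worker processed.
        out = _shuffle(cudf.concat(bucket)).reset_index(drop=True)
        out.to_parquet(path, index=False)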
3 changes: 2 additions & 1 deletion tests/unit/test_dask_nvt.py
@@ -26,6 +26,7 @@

import nvtabular.ops as ops
from nvtabular import Dataset, Workflow
+from nvtabular.io import Shuffle
from tests.conftest import allcols_csv, mycols_csv, mycols_pq


@@ -47,7 +48,7 @@ def _dummy_op_logic(gdf, target_columns, _id="dummy", **kwargs):
@pytest.mark.parametrize("freq_threshold", [0, 150])
@pytest.mark.parametrize("cat_cache", ["device", None])
@pytest.mark.parametrize("on_host", [True, False])
-@pytest.mark.parametrize("shuffle", ["full", None])
+@pytest.mark.parametrize("shuffle", [Shuffle.PER_WORKER, None])
def test_dask_workflow_api_dlrm(
client, tmpdir, datasets, freq_threshold, part_mem_fraction, engine, cat_cache, on_host, shuffle
):
4 changes: 2 additions & 2 deletions tests/unit/test_io.py
@@ -40,7 +40,7 @@ def test_shuffle_gpu(tmpdir, datasets, engine):
df1 = cudf.read_parquet(paths[0])[mycols_pq]
else:
df1 = cudf.read_csv(paths[0], header=False, names=allcols_csv)[mycols_csv]
-shuf = ParquetWriter(tmpdir, num_out_files=num_files, shuffle="partial")
+shuf = ParquetWriter(tmpdir, num_out_files=num_files, shuffle=nvt.io.Shuffle.PER_PARTITION)
shuf.add_data(df1)
writer_files = shuf.data_paths
shuf.close()
@@ -130,7 +130,7 @@ def test_hugectr(
output_path=outdir,
out_files_per_proc=nfiles,
output_format=output_format,
-shuffle=False,
+shuffle=None,
num_io_threads=num_io_threads,
)

10 changes: 5 additions & 5 deletions tests/unit/test_ops.py
@@ -399,7 +399,7 @@ def test_lambdaop(tmpdir, df, dataset, gpu_memory_frac, engine, client):
processor.update_stats(dataset)
outdir = tmpdir.mkdir("out1")
processor.write_to_dataset(
-outdir, dataset, out_files_per_proc=10, shuffle="partial", apply_ops=True
+outdir, dataset, out_files_per_proc=10, shuffle=nvt.io.Shuffle.PER_PARTITION, apply_ops=True
)

dataset_2 = nvtabular.io.Dataset(
@@ -422,7 +422,7 @@ def test_lambdaop(tmpdir, df, dataset, gpu_memory_frac, engine, client):
processor.update_stats(dataset)
outdir = tmpdir.mkdir("out2")
processor.write_to_dataset(
-outdir, dataset, out_files_per_proc=10, shuffle="partial", apply_ops=True
+outdir, dataset, out_files_per_proc=10, shuffle=nvt.io.Shuffle.PER_PARTITION, apply_ops=True
)

dataset_2 = nvtabular.io.Dataset(
@@ -452,7 +452,7 @@ def test_lambdaop(tmpdir, df, dataset, gpu_memory_frac, engine, client):
processor.update_stats(dataset)
outdir = tmpdir.mkdir("out3")
processor.write_to_dataset(
-outdir, dataset, out_files_per_proc=10, shuffle="partial", apply_ops=True
+outdir, dataset, out_files_per_proc=10, shuffle=nvt.io.Shuffle.PER_PARTITION, apply_ops=True
)
dataset_2 = nvtabular.io.Dataset(
glob.glob(str(outdir) + "/*.parquet"), part_mem_fraction=gpu_memory_frac
@@ -478,7 +478,7 @@ def test_lambdaop(tmpdir, df, dataset, gpu_memory_frac, engine, client):
processor.update_stats(dataset)
outdir = tmpdir.mkdir("out4")
processor.write_to_dataset(
-outdir, dataset, out_files_per_proc=10, shuffle="partial", apply_ops=True
+outdir, dataset, out_files_per_proc=10, shuffle=nvt.io.Shuffle.PER_PARTITION, apply_ops=True
)

dataset_2 = nvtabular.io.Dataset(
@@ -508,7 +508,7 @@ def test_lambdaop(tmpdir, df, dataset, gpu_memory_frac, engine, client):
processor.update_stats(dataset)
outdir = tmpdir.mkdir("out5")
processor.write_to_dataset(
-outdir, dataset, out_files_per_proc=10, shuffle="partial", apply_ops=True
+outdir, dataset, out_files_per_proc=10, shuffle=nvt.io.Shuffle.PER_PARTITION, apply_ops=True
)

dataset_2 = nvtabular.io.Dataset(
