Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
db89cab
lay down foundation of transform capability in dataloader.
jperez999 Oct 8, 2022
736c8e8
working transforms for greater that host memory and greater
jperez999 Oct 13, 2022
c1572fa
add in memory versions of indexing
jperez999 Oct 14, 2022
0eb8d81
added docstrings to operators and made change to tf embedding test
jperez999 Oct 14, 2022
6334d91
remove base loader logic for gpu only
jperez999 Oct 14, 2022
f978c57
Merge branch 'main' into dl-transforms
jperez999 Oct 25, 2022
d4d4687
working id lookup and non lookup nmap torch
jperez999 Nov 3, 2022
91fe81b
all torch embeddings with id and without working
jperez999 Nov 3, 2022
0d59358
tf with and without lookups all green
jperez999 Nov 3, 2022
4cd2270
Merge branch 'dl-transforms' of https://github.com/jperez999/dataload…
jperez999 Nov 3, 2022
8b6fed0
add the npy_append_array package to allow for utility build embedding…
jperez999 Nov 3, 2022
da2c9dc
retry adding base.txt
jperez999 Nov 3, 2022
fdcbca4
Merge branch 'main' into dl-transforms
jperez999 Nov 3, 2022
c1dbe71
fix various errors in dataloader to allow dictarray
jperez999 Nov 3, 2022
a499d88
Merge branch 'dl-transforms' of https://github.com/jperez999/dataload…
jperez999 Nov 3, 2022
1867c4c
remove rmm in testing
jperez999 Nov 3, 2022
3f400f2
remove cudf based parameter row_group_size_bytes
jperez999 Nov 3, 2022
2378873
Merge branch 'main' into dl-transforms
jperez999 Nov 3, 2022
44cdfad
move npy-append to dev requirements for conda package mitigation
jperez999 Nov 3, 2022
6bb0cc0
tox update for requirements
jperez999 Nov 3, 2022
01ea009
adding support for still using CPU, overriding default behavior when …
jperez999 Nov 3, 2022
2213a24
refactor code to use base embedding operators
jperez999 Nov 8, 2022
9c247cc
remove comments from loader base
jperez999 Nov 8, 2022
2e47f87
replace torch double dtype with float64 to keep current convention
jperez999 Nov 8, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion merlin/loader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import threading
import warnings
from collections import OrderedDict
from typing import List

import numpy as np

Expand All @@ -36,6 +37,8 @@
make_df,
pull_apart_list,
)
from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node
from merlin.dag.executors import LocalExecutor
from merlin.io import shuffle_df
from merlin.schema import Tags

Expand All @@ -59,6 +62,8 @@ def __init__(
global_size=None,
global_rank=None,
drop_last=False,
transforms=None,
device=None,
):
self.dataset = dataset
self.batch_size = batch_size
Expand All @@ -70,7 +75,10 @@ def __init__(
self.drop_last = drop_last

self.indices = cp.arange(self.dataset.npartitions)
self.device = "cpu" if not HAS_GPU or dataset.cpu else 0
if device:
self.device = device
else:
self.device = "cpu" if not HAS_GPU or dataset.cpu else 0

if not dataset.schema:
warnings.warn(
Expand All @@ -79,6 +87,7 @@ def __init__(
)
dataset.schema = dataset.infer_schema()

self.schema = dataset.schema
self.sparse_names = []
self.sparse_max = {}
self.sparse_as_dense = set()
Expand Down Expand Up @@ -126,6 +135,28 @@ def __init__(
self._batch_itr = None
self._workers = None

if transforms is not None:

if isinstance(transforms, List):
carry_node = Node(ColumnSelector("*"))
for transform in transforms:
if not isinstance(transform, BaseOperator):
raise TypeError(
f"Detected invalid transform, {type(transform)},"
"we only support operators based on the merlin core"
"`BaseOperator`"
)
carry_node = carry_node >> transform
transform_graph = Graph(carry_node)
elif type(transforms, Graph):
transform_graph = transforms
self.transforms = transform_graph.construct_schema(self.schema).output_node
self.schema = self.transforms.output_schema
self.executor = LocalExecutor()
else:
self.transforms = None
self.executor = None

@property
def _buff(self):
if self.__buff is None:
Expand Down Expand Up @@ -553,8 +584,25 @@ def _handle_tensors(self, tensors, tensor_names):
labels = None
if len(self.label_names) > 0:
labels = X.pop(self.label_names[0])

if self.transforms:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should think about creating a comprehensive "column" class that can be sub-classed to ScalarColumn and ListColumn. This will hide the tuple format behind a df series type interface that will be more friendly to the other parts of merlin, i.e. the graph. The use case is what if I want to do some after dataloader inbatch processing to a list column. It will be easier to abstract that tuple representation (values, nnz) and allow the user to not have to worry about keeping track of all that.

X = self.executor.transform(DictArray(X), [self.transforms])

return X, labels

def _pack(self, gdf):
if isinstance(gdf, np.ndarray):
return gdf
# if self.device has value ('cpu') gdf should not be transferred to dlpack
elif hasattr(gdf, "to_dlpack") and callable(getattr(gdf, "to_dlpack")) and not self.device:
return gdf.to_dlpack()
elif hasattr(gdf, "to_numpy") and callable(getattr(gdf, "to_numpy")):
gdf = gdf.to_numpy()
if isinstance(gdf[0], list):
gdf = np.stack(gdf)
return gdf
return gdf.toDlpack()


class ChunkQueue:
"""This class takes partitions (parts) from an merlin.io.Dataset
Expand Down
15 changes: 15 additions & 0 deletions merlin/loader/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
27 changes: 27 additions & 0 deletions merlin/loader/ops/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# flake8: noqa
from merlin.loader.ops.embeddings.tf_embedding_op import (
TF_MmapNumpyTorchEmbedding,
TF_NumpyEmbeddingOperator,
TFEmbeddingOperator,
)
from merlin.loader.ops.embeddings.torch_embedding_op import (
Torch_MmapNumpyTorchEmbedding,
Torch_NumpyEmbeddingOperator,
TorchEmbeddingOperator,
)
235 changes: 235 additions & 0 deletions merlin/loader/ops/embeddings/embedding_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import numpy as np

from merlin.core.protocols import Transformable
from merlin.dag import BaseOperator
from merlin.dag.selector import ColumnSelector
from merlin.schema import ColumnSchema, Schema, Tags


class EmbeddingOperator(BaseOperator):
"""Create an operator that will apply a torch embedding table to supplied indices.
This operator allows the user to supply an id lookup table if the indices supplied
via the id_lookup_table.

Parameters
----------
embeddings : np.ndarray
numpy ndarray representing embedding values
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
id_lookup_table : np.array, optional
numpy array of values that represent embedding indices, by default None
"""

def __init__(
self,
embeddings: np.ndarray,
lookup_key: str = "id",
embedding_name: str = "embeddings",
id_lookup_table=None,
):
self.embeddings = self._load_embeddings(embeddings)
self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.id_lookup_table = id_lookup_table

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
keys = transformable[self.lookup_key]
indices = keys.cpu()
if self.id_lookup_table is not None:
indices = self._create_tensor(np.nonzero(np.in1d(self.id_lookup_table, indices)))
embeddings = self._embeddings_lookup(indices)
transformable[self.embedding_name] = self._format_embeddings(embeddings, keys)
return transformable

def _load_embeddings(self, embeddings):
raise NotImplementedError("No logic supplied to load embeddings.")

def _create_tensor(self, values):
raise NotImplementedError("No logic supplied to create tensor.")

def _embeddings_lookup(self, indices):
raise NotImplementedError("No logic to look up embeddings with indices.")

def _format_embeddings(self, embeddings, keys):
raise NotImplementedError("No logic to format embeddings.")

def _get_dtype(self, embeddings):
raise NotImplementedError("No logic to retrieve dtype from embeddings.")

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Creates the output schema for this operator.

Parameters
----------
input_schema : Schema
schema coming from ancestor nodes
col_selector : ColumnSelector
subselection of columns to apply to this operator
prev_output_schema : Schema, optional
the output schema of the previously executed operators, by default None

Returns
-------
Schema
Schema representing the correct output for this operator.
"""
col_schemas = []
for _, col_schema in input_schema.column_schemas.items():
col_schemas.append(col_schema)
col_schemas.append(
ColumnSchema(
name=self.embedding_name,
tags=[Tags.CONTINUOUS],
dtype=self._get_dtype(self.embeddings),
is_list=True,
is_ragged=False,
)
)

return Schema(col_schemas)


class NumpyEmbeddingOperator(BaseOperator):
"""Create an embedding table from supplied embeddings to add embedding entry
to records based on supplied indices. Support for indices lookup table is available.
Embedding table is stored in host memory.

Parameters
----------
embeddings : np.ndarray
numpy ndarray representing embedding values
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
id_lookup_table : np.array, optional
numpy array of values that represent embedding indices, by default None
"""

def __init__(
self,
embeddings: np.ndarray,
lookup_key: str = "id",
embedding_name: str = "embeddings",
id_lookup_table=None,
):
self.embeddings = embeddings
self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.id_lookup_table = id_lookup_table

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
keys = transformable[self.lookup_key]
indices = keys.cpu()
if self.id_lookup_table is not None:
indices = np.in1d(self.id_lookup_table, indices)
embeddings = self.embeddings[indices]
# numpy_to_tensor
transformable[self.embedding_name] = self._format_embeddings(embeddings, keys)
return transformable

def _format_embeddings(self, embeddings, keys):
raise NotImplementedError("No logic to format embeddings.")

def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Creates the output schema for this operator.

Parameters
----------
input_schema : Schema
schema coming from ancestor nodes
col_selector : ColumnSelector
subselection of columns to apply to this operator
prev_output_schema : Schema, optional
the output schema of the previously executed operators, by default None

Returns
-------
Schema
Schema representing the correct output for this operator.
"""
col_schemas = []
for _, col_schema in input_schema.column_schemas.items():
col_schemas.append(col_schema)
col_schemas.append(
ColumnSchema(
name=self.embedding_name,
tags=[Tags.CONTINUOUS],
dtype=self.embeddings.dtype,
is_list=True,
is_ragged=False,
)
)

return Schema(col_schemas)


class MmapNumpyTorchEmbedding(NumpyEmbeddingOperator):
"""Operator loads numpy embedding table from file using memory map to be used to create
torch embedding representations. This allows for larger than host memory embedding
tables to be used for embedding lookups. The only limit to the size is what fits in
storage, preferred storage device is SSD for faster lookups.

Parameters
----------
embedding_npz : numpy ndarray file
file holding numpy ndarray representing embedding table
ids_lookup_npz : numpy array file, optional
file holding numpy array of values that represent embedding indices, by default None
lookup_key : str, optional
the name of the column that will be used as indices, by default "id"
embedding_name : str, optional
name of new column of embeddings, added to output, by default "embeddings"
transform_function : _type_, optional
function that will transform embedding from numpy to torch, by default None
"""

def __init__(
self,
embedding_npz,
ids_lookup_npz=None,
lookup_key="id",
embedding_name="embeddings",
):
embeddings = np.load(embedding_npz, mmap_mode="r")
id_lookup = np.load(ids_lookup_npz) if ids_lookup_npz else None
super().__init__(
embeddings,
lookup_key=lookup_key,
embedding_name=embedding_name,
id_lookup_table=id_lookup,
)
Loading