diff --git a/merlin/loader/loader_base.py b/merlin/loader/loader_base.py
index f7a0e33c..8dc69690 100644
--- a/merlin/loader/loader_base.py
+++ b/merlin/loader/loader_base.py
@@ -19,6 +19,7 @@
 import threading
 import warnings
 from collections import OrderedDict
+from typing import List
 
 import numpy as np
 
@@ -36,6 +37,8 @@
     make_df,
     pull_apart_list,
 )
+from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node
+from merlin.dag.executors import LocalExecutor
 from merlin.io import shuffle_df
 from merlin.schema import Tags
 
@@ -59,6 +62,8 @@ def __init__(
         global_size=None,
         global_rank=None,
         drop_last=False,
+        transforms=None,
+        device=None,
     ):
         self.dataset = dataset
         self.batch_size = batch_size
@@ -70,7 +75,10 @@
         self.drop_last = drop_last
         self.indices = cp.arange(self.dataset.npartitions)
 
-        self.device = "cpu" if not HAS_GPU or dataset.cpu else 0
+        if device:
+            self.device = device
+        else:
+            self.device = "cpu" if not HAS_GPU or dataset.cpu else 0
 
         if not dataset.schema:
             warnings.warn(
@@ -79,6 +87,7 @@
             )
             dataset.schema = dataset.infer_schema()
 
+        self.schema = dataset.schema
         self.sparse_names = []
        self.sparse_max = {}
         self.sparse_as_dense = set()
@@ -126,6 +135,28 @@
         self._batch_itr = None
         self._workers = None
 
+        if transforms is not None:
+            if isinstance(transforms, List):
+                carry_node = Node(ColumnSelector("*"))
+                for transform in transforms:
+                    if not isinstance(transform, BaseOperator):
+                        raise TypeError(
+                            f"Detected invalid transform, {type(transform)}; "
+                            "only operators based on the merlin-core "
+                            "`BaseOperator` are supported"
+                        )
+                    carry_node = carry_node >> transform
+                transform_graph = Graph(carry_node)
+            elif isinstance(transforms, Graph):
+                transform_graph = transforms
+            else:
+                raise TypeError(
+                    f"Detected invalid transforms argument, {type(transforms)}; "
+                    "expected a list of operators or a merlin.dag.Graph"
+                )
+            self.transforms = transform_graph.construct_schema(self.schema).output_node
+            self.schema = self.transforms.output_schema
+            self.executor = LocalExecutor()
+        else:
+            self.transforms = None
+            self.executor = None
+
     @property
     def _buff(self):
         if self.__buff is None:
@@ -553,8 +584,25 @@ def _handle_tensors(self, tensors, tensor_names):
         labels = None
         if len(self.label_names) > 0:
             labels = X.pop(self.label_names[0])
+
+        if self.transforms:
+            X = self.executor.transform(DictArray(X), [self.transforms])
+
         return X, labels
 
+    def _pack(self, gdf):
+        if isinstance(gdf, np.ndarray):
+            return gdf
+        # self.device is "cpu" (truthy) on CPU and 0 (falsy) on GPU, so only
+        # hand the dataframe off as a dlpack capsule when running on GPU
+        elif hasattr(gdf, "to_dlpack") and callable(getattr(gdf, "to_dlpack")) and not self.device:
+            return gdf.to_dlpack()
+        elif hasattr(gdf, "to_numpy") and callable(getattr(gdf, "to_numpy")):
+            gdf = gdf.to_numpy()
+            if isinstance(gdf[0], list):
+                gdf = np.stack(gdf)
+            return gdf
+        return gdf.toDlpack()
+
 
 class ChunkQueue:
     """This class takes partitions (parts) from an merlin.io.Dataset
diff --git a/merlin/loader/ops/__init__.py b/merlin/loader/ops/__init__.py
new file mode 100644
index 00000000..0b8ff56d
--- /dev/null
+++ b/merlin/loader/ops/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/merlin/loader/ops/embeddings/__init__.py b/merlin/loader/ops/embeddings/__init__.py
new file mode 100644
index 00000000..a788b281
--- /dev/null
+++ b/merlin/loader/ops/embeddings/__init__.py
@@ -0,0 +1,27 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# flake8: noqa
+from merlin.loader.ops.embeddings.tf_embedding_op import (
+    TF_MmapNumpyTorchEmbedding,
+    TF_NumpyEmbeddingOperator,
+    TFEmbeddingOperator,
+)
+from merlin.loader.ops.embeddings.torch_embedding_op import (
+    Torch_MmapNumpyTorchEmbedding,
+    Torch_NumpyEmbeddingOperator,
+    TorchEmbeddingOperator,
+)
diff --git a/merlin/loader/ops/embeddings/embedding_op.py b/merlin/loader/ops/embeddings/embedding_op.py
new file mode 100644
index 00000000..76561a6f
--- /dev/null
+++ b/merlin/loader/ops/embeddings/embedding_op.py
@@ -0,0 +1,235 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import numpy as np
+
+from merlin.core.protocols import Transformable
+from merlin.dag import BaseOperator
+from merlin.dag.selector import ColumnSelector
+from merlin.schema import ColumnSchema, Schema, Tags
+
+
+class EmbeddingOperator(BaseOperator):
+    """Create an operator that applies an embedding table to supplied indices.
+    The user may optionally supply an id lookup table, via ``id_lookup_table``,
+    that maps the supplied ids to rows of the embedding table.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
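+
+    Examples
+    --------
+    A minimal sketch of the subclass contract (hypothetical; the TF and Torch
+    operators in this package provide the real implementations):
+
+    >>> class MyEmbeddingOperator(EmbeddingOperator):
+    ...     def _load_embeddings(self, embeddings):
+    ...         ...  # e.g. wrap the np.ndarray in a framework tensor
+    ...     def _create_tensor(self, values):
+    ...         ...  # convert lookup indices to a framework tensor
+    ...     def _embeddings_lookup(self, indices):
+    ...         ...  # gather rows of ``self.embeddings`` by ``indices``
+    ...     def _format_embeddings(self, embeddings, keys):
+    ...         ...  # squeeze / move the result onto the same device as ``keys``
+    ...     def _get_dtype(self, embeddings):
+    ...         ...  # numpy dtype reported in the output schema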
+    """
+
+    def __init__(
+        self,
+        embeddings: np.ndarray,
+        lookup_key: str = "id",
+        embedding_name: str = "embeddings",
+        id_lookup_table=None,
+    ):
+        self.embeddings = self._load_embeddings(embeddings)
+        self.lookup_key = lookup_key
+        self.embedding_name = embedding_name
+        self.id_lookup_table = id_lookup_table
+
+    def transform(
+        self, col_selector: ColumnSelector, transformable: Transformable
+    ) -> Transformable:
+        keys = transformable[self.lookup_key]
+        indices = keys.cpu()
+        if self.id_lookup_table is not None:
+            indices = self._create_tensor(np.nonzero(np.in1d(self.id_lookup_table, indices)))
+        embeddings = self._embeddings_lookup(indices)
+        transformable[self.embedding_name] = self._format_embeddings(embeddings, keys)
+        return transformable
+
+    def _load_embeddings(self, embeddings):
+        raise NotImplementedError("No logic supplied to load embeddings.")
+
+    def _create_tensor(self, values):
+        raise NotImplementedError("No logic supplied to create tensor.")
+
+    def _embeddings_lookup(self, indices):
+        raise NotImplementedError("No logic to look up embeddings with indices.")
+
+    def _format_embeddings(self, embeddings, keys):
+        raise NotImplementedError("No logic to format embeddings.")
+
+    def _get_dtype(self, embeddings):
+        raise NotImplementedError("No logic to retrieve dtype from embeddings.")
+
+    def compute_output_schema(
+        self,
+        input_schema: Schema,
+        col_selector: ColumnSelector,
+        prev_output_schema: Schema = None,
+    ) -> Schema:
+        """Creates the output schema for this operator.
+
+        Parameters
+        ----------
+        input_schema : Schema
+            schema coming from ancestor nodes
+        col_selector : ColumnSelector
+            subselection of columns to apply to this operator
+        prev_output_schema : Schema, optional
+            the output schema of the previously executed operators, by default None
+
+        Returns
+        -------
+        Schema
+            Schema representing the correct output for this operator.
+        """
+        col_schemas = []
+        for _, col_schema in input_schema.column_schemas.items():
+            col_schemas.append(col_schema)
+        col_schemas.append(
+            ColumnSchema(
+                name=self.embedding_name,
+                tags=[Tags.CONTINUOUS],
+                dtype=self._get_dtype(self.embeddings),
+                is_list=True,
+                is_ragged=False,
+            )
+        )
+
+        return Schema(col_schemas)
+
+
+class NumpyEmbeddingOperator(BaseOperator):
+    """Create an embedding table from supplied embeddings and add an embedding
+    column to records based on supplied indices. An indices lookup table is
+    supported. The embedding table is stored in host memory.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
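+
+    Examples
+    --------
+    When ``id_lookup_table`` is given, incoming ids are matched against it with
+    ``np.in1d`` and the resulting boolean mask selects the embedding rows; a toy
+    illustration of that masking step (values are hypothetical):
+
+    >>> import numpy as np
+    >>> id_lookup_table = np.array([30, 10, 20])
+    >>> np.in1d(id_lookup_table, np.array([10, 30]))
+    array([ True,  True, False])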
+    """
+
+    def __init__(
+        self,
+        embeddings: np.ndarray,
+        lookup_key: str = "id",
+        embedding_name: str = "embeddings",
+        id_lookup_table=None,
+    ):
+        self.embeddings = embeddings
+        self.lookup_key = lookup_key
+        self.embedding_name = embedding_name
+        self.id_lookup_table = id_lookup_table
+
+    def transform(
+        self, col_selector: ColumnSelector, transformable: Transformable
+    ) -> Transformable:
+        keys = transformable[self.lookup_key]
+        indices = keys.cpu()
+        if self.id_lookup_table is not None:
+            indices = np.in1d(self.id_lookup_table, indices)
+        embeddings = self.embeddings[indices]
+        # convert the numpy rows to the target framework's tensor type
+        transformable[self.embedding_name] = self._format_embeddings(embeddings, keys)
+        return transformable
+
+    def _format_embeddings(self, embeddings, keys):
+        raise NotImplementedError("No logic to format embeddings.")
+
+    def compute_output_schema(
+        self,
+        input_schema: Schema,
+        col_selector: ColumnSelector,
+        prev_output_schema: Schema = None,
+    ) -> Schema:
+        """Creates the output schema for this operator.
+
+        Parameters
+        ----------
+        input_schema : Schema
+            schema coming from ancestor nodes
+        col_selector : ColumnSelector
+            subselection of columns to apply to this operator
+        prev_output_schema : Schema, optional
+            the output schema of the previously executed operators, by default None
+
+        Returns
+        -------
+        Schema
+            Schema representing the correct output for this operator.
+        """
+        col_schemas = []
+        for _, col_schema in input_schema.column_schemas.items():
+            col_schemas.append(col_schema)
+        col_schemas.append(
+            ColumnSchema(
+                name=self.embedding_name,
+                tags=[Tags.CONTINUOUS],
+                dtype=self.embeddings.dtype,
+                is_list=True,
+                is_ragged=False,
+            )
+        )
+
+        return Schema(col_schemas)
+
+
+class MmapNumpyTorchEmbedding(NumpyEmbeddingOperator):
+    """Operator that loads a numpy embedding table from file using a memory map.
+    This allows larger than host memory embedding tables to be used for
+    embedding lookups. The only limit to the size is what fits in storage;
+    the preferred storage device is an SSD for faster lookups.
+
+    Parameters
+    ----------
+    embedding_npz : numpy ndarray file
+        file holding numpy ndarray representing embedding table
+    ids_lookup_npz : numpy array file, optional
+        file holding numpy array of values that represent embedding indices, by default None
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
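+
+    Examples
+    --------
+    A minimal sketch with hypothetical file names; the table must be saved as a
+    single ``.npy`` array so ``np.load(..., mmap_mode="r")`` can map it lazily,
+    and one of the framework subclasses supplies the final tensor conversion:
+
+    >>> import numpy as np
+    >>> np.save("embeddings.npy", np.random.rand(1000, 16).astype("float32"))
+    >>> op = MmapNumpyTorchEmbedding("embeddings.npy")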
+    """
+
+    def __init__(
+        self,
+        embedding_npz,
+        ids_lookup_npz=None,
+        lookup_key="id",
+        embedding_name="embeddings",
+    ):
+        embeddings = np.load(embedding_npz, mmap_mode="r")
+        id_lookup = np.load(ids_lookup_npz) if ids_lookup_npz else None
+        super().__init__(
+            embeddings,
+            lookup_key=lookup_key,
+            embedding_name=embedding_name,
+            id_lookup_table=id_lookup,
+        )
diff --git a/merlin/loader/ops/embeddings/tf_embedding_op.py b/merlin/loader/ops/embeddings/tf_embedding_op.py
new file mode 100644
index 00000000..64609ba7
--- /dev/null
+++ b/merlin/loader/ops/embeddings/tf_embedding_op.py
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+from merlin.loader.ops.embeddings.embedding_op import (
+    EmbeddingOperator,
+    MmapNumpyTorchEmbedding,
+    NumpyEmbeddingOperator,
+)
+
+
+class TFEmbeddingOperator(EmbeddingOperator):
+    """Create an operator that applies a tf embedding table to supplied indices.
+    The user may optionally supply an id lookup table, via ``id_lookup_table``,
+    that maps the supplied ids to rows of the embedding table. The embedding
+    table is stored in host memory.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
+    """
+
+    def _load_embeddings(self, embeddings):
+        return embeddings if isinstance(embeddings, tf.Tensor) else tf.convert_to_tensor(embeddings)
+
+    def _create_tensor(self, values):
+        return values
+
+    def _embeddings_lookup(self, indices):
+        return tf.nn.embedding_lookup(self.embeddings, indices)
+
+    def _format_embeddings(self, embeddings, keys):
+        return tf.squeeze(embeddings)
+
+    def _get_dtype(self, embeddings):
+        return embeddings.dtype.as_numpy_dtype
+
+
+class TF_NumpyEmbeddingOperator(NumpyEmbeddingOperator):
+    """Create an embedding table from supplied embeddings and add an embedding
+    column to records based on supplied indices. An indices lookup table is
+    supported. The embedding table is stored in host memory.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
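+
+    Examples
+    --------
+    A sketch of wiring the operator into the TensorFlow Loader (``dataset`` is
+    assumed to be a ``merlin.io.Dataset`` with an ``id`` column, mirroring the
+    tests in this repository):
+
+    >>> import numpy as np
+    >>> from merlin.loader.tensorflow import Loader
+    >>> embeddings = np.random.rand(1000, 16).astype("float32")
+    >>> loader = Loader(
+    ...     dataset,
+    ...     batch_size=1024,
+    ...     transforms=[TF_NumpyEmbeddingOperator(embeddings)],
+    ...     shuffle=False,
+    ... )
+    >>> inputs, labels = next(iter(loader))  # inputs["embeddings"] holds the looked-up rows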
+    """
+
+    def _format_embeddings(self, embeddings, keys):
+        return tf.squeeze(tf.convert_to_tensor(embeddings))
+
+
+class TF_MmapNumpyTorchEmbedding(MmapNumpyTorchEmbedding):
+    """Operator that loads a numpy embedding table from file using a memory map
+    and creates tensorflow embedding representations. This allows larger than
+    host memory embedding tables to be used for embedding lookups. The only
+    limit to the size is what fits in storage; the preferred storage device is
+    an SSD for faster lookups.
+
+    Parameters
+    ----------
+    embedding_npz : numpy ndarray file
+        file holding numpy ndarray representing embedding table
+    ids_lookup_npz : numpy array file, optional
+        file holding numpy array of values that represent embedding indices, by default None
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    """
+
+    def _format_embeddings(self, embeddings, keys):
+        return tf.squeeze(tf.convert_to_tensor(embeddings))
diff --git a/merlin/loader/ops/embeddings/torch_embedding_op.py b/merlin/loader/ops/embeddings/torch_embedding_op.py
new file mode 100644
index 00000000..92b3849a
--- /dev/null
+++ b/merlin/loader/ops/embeddings/torch_embedding_op.py
@@ -0,0 +1,106 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from torch.nn import Embedding
+
+from merlin.loader.ops.embeddings.embedding_op import (
+    EmbeddingOperator,
+    MmapNumpyTorchEmbedding,
+    NumpyEmbeddingOperator,
+)
+
+
+class TorchEmbeddingOperator(EmbeddingOperator):
+    """Create an operator that applies a torch embedding table to supplied indices.
+    The user may optionally supply an id lookup table, via ``id_lookup_table``,
+    that maps the supplied ids to rows of the embedding table.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
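+
+    Examples
+    --------
+    A sketch with the Torch Loader (``dataset`` is assumed to be a
+    ``merlin.io.Dataset`` with an ``id`` column, mirroring the tests in this
+    repository):
+
+    >>> from merlin.loader.torch import Loader
+    >>> table = torch.rand(1000, 16)
+    >>> loader = Loader(
+    ...     dataset,
+    ...     batch_size=1024,
+    ...     transforms=[TorchEmbeddingOperator(table)],
+    ...     shuffle=False,
+    ... )
+    >>> inputs, labels = next(iter(loader))  # inputs["embeddings"] holds the looked-up rows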
+    """
+
+    def _load_embeddings(self, embeddings):
+        return (
+            embeddings
+            if isinstance(embeddings, Embedding)
+            else Embedding.from_pretrained(torch.FloatTensor(embeddings))
+        )
+
+    def _create_tensor(self, values):
+        return torch.Tensor(values).to(torch.int32)
+
+    def _embeddings_lookup(self, indices):
+        return self.embeddings(indices)
+
+    def _format_embeddings(self, embeddings, keys):
+        return torch.squeeze(embeddings.to(keys.device))
+
+    def _get_dtype(self, embeddings):
+        return embeddings.weight.numpy().dtype
+
+
+class Torch_NumpyEmbeddingOperator(NumpyEmbeddingOperator):
+    """Create an embedding table from supplied embeddings and add an embedding
+    column to records based on supplied indices. An indices lookup table is
+    supported. The embedding table is stored in host memory.
+
+    Parameters
+    ----------
+    embeddings : np.ndarray
+        numpy ndarray representing embedding values
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    id_lookup_table : np.array, optional
+        numpy array of values that represent embedding indices, by default None
+    """
+
+    def _format_embeddings(self, embeddings, keys):
+        return torch.from_numpy(embeddings).to(keys.device)
+
+
+class Torch_MmapNumpyTorchEmbedding(MmapNumpyTorchEmbedding):
+    """Operator that loads a numpy embedding table from file using a memory map
+    and creates torch embedding representations. This allows larger than host
+    memory embedding tables to be used for embedding lookups. The only limit to
+    the size is what fits in storage; the preferred storage device is an SSD
+    for faster lookups.
+
+    Parameters
+    ----------
+    embedding_npz : numpy ndarray file
+        file holding numpy ndarray representing embedding table
+    ids_lookup_npz : numpy array file, optional
+        file holding numpy array of values that represent embedding indices, by default None
+    lookup_key : str, optional
+        the name of the column that will be used as indices, by default "id"
+    embedding_name : str, optional
+        name of new column of embeddings, added to output, by default "embeddings"
+    """
+
+    def _format_embeddings(self, embeddings, keys):
+        return torch.from_numpy(embeddings).to(keys.device)
diff --git a/merlin/loader/tensorflow.py b/merlin/loader/tensorflow.py
index 59a31da5..0a7b3d35 100644
--- a/merlin/loader/tensorflow.py
+++ b/merlin/loader/tensorflow.py
@@ -16,8 +16,6 @@
 import contextlib
 import logging
 
-import numpy as np
-
 from merlin.loader.loader_base import LoaderBase
 from merlin.loader.tf_utils import configure_tensorflow
@@ -112,6 +110,8 @@
         global_size=None,
         global_rank=None,
         drop_last=False,
+        transforms=None,
+        device=None,
     ):
         LoaderBase.__init__(
             self,
@@ -123,6 +123,8 @@
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
+            transforms=transforms,
+            device=device,
         )
         self._map_fns = []
@@ -179,17 +181,13 @@
         """
         return tf.split(tensor, idx, axis=axis)
 
-    def _pack(self, gdf):
-        if isinstance(gdf, np.ndarray):
-            return gdf
-        elif hasattr(gdf, "to_dlpack") and callable(getattr(gdf, "to_dlpack")):
-            return gdf.to_dlpack()
-        elif hasattr(gdf, "to_numpy") and callable(getattr(gdf, "to_numpy")):
-            gdf = gdf.to_numpy()
-            if isinstance(gdf[0], list):
-                gdf = np.stack(gdf)
-            return gdf
-        return gdf.toDlpack()
+    @property
+    def _LONG_DTYPE(self):
+        return tf.int64
+
+    @property
+    def _FLOAT32_DTYPE(self):
+        return tf.float32
 
     def _unpack(self, gdf):
         if hasattr(gdf, "shape"):
diff --git a/merlin/loader/torch.py b/merlin/loader/torch.py
index 777d9c97..e5c3aecd 100644
--- a/merlin/loader/torch.py
+++ b/merlin/loader/torch.py
@@ -77,6 +77,8 @@
         global_size=None,
         global_rank=None,
         drop_last=False,
+        transforms=None,
+        device=None,
     ):
         LoaderBase.__init__(
             self,
@@ -88,6 +90,8 @@
             global_size=global_size,
             global_rank=global_rank,
             drop_last=drop_last,
+            transforms=transforms,
+            device=device,
         )
 
     def __iter__(self):
@@ -96,22 +100,17 @@
     def _get_device_ctx(self, dev):
         if dev == "cpu":
             return torch.device("cpu")
-        return torch.cuda.device("cuda:{}".format(dev))
-
-    def _pack(self, gdf):
-        if self.device == "cpu":
-            return gdf
-        return gdf.to_dlpack()
+        return torch.cuda.device(f"cuda:{dev}")
 
     def _unpack(self, dlpack):
         if self.device == "cpu":
-            values = dlpack.values
+            values = dlpack.values if hasattr(dlpack, "values") else dlpack
             dtype = values.dtype
             dtype = numpy_to_torch_dtype_dict[dtype.type] if hasattr(dtype, "type") else dtype
             if (
-                len(dlpack.values.shape) == 2
-                and dlpack.values.shape[1] == 1
-                and isinstance(dlpack.values[0], np.ndarray)
+                len(values.shape) == 2
+                and values.shape[1] == 1
+                and isinstance(values[0], np.ndarray)
             ):
                 return torch.squeeze(torch.Tensor(values)).type(dtype)
             return torch.Tensor(values).type(dtype)
diff --git a/requirements/base.txt b/requirements/base.txt
index 86519bc5..9e012e61 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1 +1 @@
-merlin-core
\ No newline at end of file
+merlin-core
diff --git a/requirements/dev.txt b/requirements/dev.txt
index 5e441cd9..869d9815 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -1,5 +1,6 @@
 scipy
 scikit-learn>=0.20
+npy_append_array
 
 black==22.3.0
 click<8.1.0
diff --git a/tests/conftest.py b/tests/conftest.py
index f5e3bf5f..8cd22be9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import gc
+import glob
 import random
 
 import dask
 import numpy as np
 import pandas as pd
+from npy_append_array import NpyAppendArray
 
 try:
     import cudf
@@ -44,7 +47,7 @@
 
 import pytest
 
-from merlin.core.dispatch import make_df
+from merlin.core.dispatch import concat_columns, get_lib, make_df
 from merlin.io import Dataset
 from merlin.schema import Tags
 
@@ -159,3 +162,83 @@
     ds = Dataset(make_df(multihot_data))
     ds.schema["Post"] = ds.schema["Post"].with_tags(Tags.TARGET)
     return ds
+
+
+@pytest.fixture(scope="session")
+def num_embedding_ids():
+    return 1
+
+
+@pytest.fixture(scope="session")
+def embeddings_part_size():
+    return 1e5
+
+
+@pytest.fixture(scope="session")
+def embedding_ids(num_embedding_ids, embeddings_part_size):
+    df = make_df({"id": np.arange(num_embedding_ids * embeddings_part_size).astype("int32")})
+    return df
+
+
+@pytest.fixture(scope="session")
+def rev_embedding_ids(embedding_ids, tmpdir_factory):
+    df_rev = embedding_ids["id"][::-1]
+    df_rev.reset_index(inplace=True, drop=True)
+    return make_df(df_rev)
+
+
+@pytest.fixture(scope="session")
+def embeddings_from_dataframe(embedding_ids, num_embedding_ids, tmpdir_factory):
+    embed_dir = tmpdir_factory.mktemp("embeds")
+    for idx, splt in enumerate(np.array_split(embedding_ids.to_numpy(), num_embedding_ids)):
+        vals = make_df(np.random.rand(splt.shape[0], 1024))
+        ids = make_df({"id": np.squeeze(splt)})
+        full = concat_columns([ids, vals])
+        full.columns = [str(col) for col in full.columns]
+        full.to_parquet(f"{embed_dir}/{idx}.parquet")
+    return embed_dir
+
+
+@pytest.fixture(scope="session")
+def rev_embeddings_from_dataframe(rev_embedding_ids, num_embedding_ids, tmpdir_factory):
+    embed_dir = tmpdir_factory.mktemp("rev_embeds")
+    for idx, splt in enumerate(np.array_split(rev_embedding_ids.to_numpy(), num_embedding_ids)):
+        vals = make_df(np.random.rand(splt.shape[0], 1024))
+        ids = make_df({"id": np.squeeze(splt)})
+        full = concat_columns([ids, vals])
+        full.columns = [str(col) for col in full.columns]
+        full.to_parquet(f"{embed_dir}/{idx}.parquet")
+    return embed_dir
+
+
+def build_embeddings_from_pq(
+    df_paths, embedding_filename="embeddings.npy", lookup_filename="lookup_ids.npy"
+):
+    df_lib = get_lib()
+    with NpyAppendArray(embedding_filename) as nf:
+        with NpyAppendArray(lookup_filename) as lf:
+            for path in df_paths:
+                rows = df_lib.read_parquet(path)
+                numpy_rows = rows.to_numpy()
+                indices = np.ascontiguousarray(numpy_rows[:, 0])
+                vectors = np.ascontiguousarray(numpy_rows[:, 1:])
+                lf.append(indices)
+                nf.append(vectors)
+                del rows
+                del numpy_rows
+                del indices
+                del vectors
+                gc.collect()
+    return embedding_filename, lookup_filename
+
+
+@pytest.fixture(scope="session")
+def np_embeddings_from_pq(rev_embeddings_from_dataframe, tmpdir_factory):
+    paths = sorted(glob.glob(f"{rev_embeddings_from_dataframe}/*"))
+    embed_dir = tmpdir_factory.mktemp("np_embeds")
+    embeddings_file = f"{embed_dir}/embeddings.npy"
+    lookup_ids_file = f"{embed_dir}/ids_lookup.npy"
+    npy_filename, lookup_filename = build_embeddings_from_pq(
+        paths, embeddings_file, lookup_ids_file
+    )
+    return npy_filename, lookup_filename
diff --git a/tests/unit/loader/test_tf_embeddings.py b/tests/unit/loader/test_tf_embeddings.py
new file mode 100644
index 00000000..1ecd5482
--- /dev/null
+++ b/tests/unit/loader/test_tf_embeddings.py
@@ -0,0 +1,259 @@
+#
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import glob
+
+import numpy as np
+import pytest
+
+from merlin.core.dispatch import HAS_GPU
+from merlin.io import Dataset
+from merlin.loader.tensorflow import Loader
+from merlin.schema import Tags
+
+tf = pytest.importorskip("tensorflow")
+
+from merlin.loader.ops.embeddings import (  # noqa
+    TF_MmapNumpyTorchEmbedding,
+    TF_NumpyEmbeddingOperator,
+    TFEmbeddingOperator,
+)
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_from_pq, cpu):
+    batch_size = 10000
+    embeddings_file, _ = np_embeddings_from_pq
+    cat_names = ["id"]
+    embeddings = np.load(embeddings_file)
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TF_MmapNumpyTorchEmbedding(embeddings_file)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embeddings_from_pq, cpu):
+    batch_size = 10000
+    embeddings_file, id_lookup_file = np_embeddings_from_pq
+    cat_names = ["id"]
+    embedding_ids = rev_embedding_ids
+    embeddings = np.load(embeddings_file)
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TF_MmapNumpyTorchEmbedding(embeddings_file, ids_lookup_npz=id_lookup_file)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    embeddings_np = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TF_NumpyEmbeddingOperator(embeddings_np)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings_np[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids = rev_embedding_ids
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    embeddings_np = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[
+            TF_NumpyEmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())
+        ],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings_np[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    tf_tensor = tf.convert_to_tensor(np_tensor)
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TFEmbeddingOperator(tf_tensor)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == np_tensor[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_tf_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids = rev_embedding_ids
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    tf_tensor = tf.convert_to_tensor(np_tensor)
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TFEmbeddingOperator(tf_tensor, id_lookup_table=embedding_ids.to_numpy())],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == np_tensor[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
diff --git a/tests/unit/loader/test_torch_embeddings.py b/tests/unit/loader/test_torch_embeddings.py
new file mode 100644
index 00000000..47346bb9
--- /dev/null
+++ b/tests/unit/loader/test_torch_embeddings.py
@@ -0,0 +1,240 @@
+#
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import glob
+
+import numpy as np
+import pytest
+
+from merlin.core.dispatch import HAS_GPU
+from merlin.io import Dataset
+from merlin.loader.torch import Loader
+from merlin.schema import Tags
+
+torch = pytest.importorskip("torch")
+
+from merlin.loader.ops.embeddings import (  # noqa
+    Torch_MmapNumpyTorchEmbedding,
+    Torch_NumpyEmbeddingOperator,
+    TorchEmbeddingOperator,
+)
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_np_mmap_dl_with_lookup(
+    tmpdir, rev_embedding_ids, np_embeddings_from_pq, cpu
+):
+    batch_size = 10000
+    embeddings_file, lookup_file = np_embeddings_from_pq
+    cat_names = ["id"]
+    embeddings = np.load(embeddings_file)
+    pq_path = tmpdir / "id.parquet"
+    rev_embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[Torch_MmapNumpyTorchEmbedding(embeddings_file, ids_lookup_npz=lookup_file)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == rev_embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_from_pq, cpu):
+    batch_size = 10000
+    embeddings_file, lookup_file = np_embeddings_from_pq
+    cat_names = ["id"]
+    embeddings = np.load(embeddings_file)
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[Torch_MmapNumpyTorchEmbedding(embeddings_file)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_np_dl_with_lookup(
+    tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu
+):
+    cat_names = ["id"]
+    batch_size = 10000
+    embedding_ids = rev_embedding_ids
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    embeddings_df = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[
+            Torch_NumpyEmbeddingOperator(embeddings_df, id_lookup_table=embedding_ids.to_numpy())
+        ],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings_df[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    embeddings_df = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[Torch_NumpyEmbeddingOperator(embeddings_df)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == embeddings_df[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids = rev_embedding_ids
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:].astype("float32")
+    torch_tensor = torch.from_numpy(np_tensor)
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TorchEmbeddingOperator(torch_tensor, id_lookup_table=embedding_ids.to_numpy())],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == np_tensor[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
+
+
+@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
+def test_embedding_torch_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dataframe, cpu):
+    cat_names = ["id"]
+    batch_size = 10000
+    pq_path = tmpdir / "id.parquet"
+    embedding_ids.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path))
+    dataset = dataset.repartition(10)
+    schema = dataset.schema
+    for col_name in cat_names:
+        schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL)
+    dataset.schema = schema
+    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
+    embeddings_ds = Dataset(paths)
+    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:].astype("float32")
+    torch_tensor = torch.from_numpy(np_tensor)
+    data_loader = Loader(
+        dataset,
+        batch_size=batch_size,
+        transforms=[TorchEmbeddingOperator(torch_tensor)],
+        shuffle=False,
+        device=cpu,
+    )
+    full_len = 0
+    for idx, batch in enumerate(data_loader):
+        assert "embeddings" in batch[0]
+        assert "id" in batch[0]
+        start = idx * batch_size
+        end = start + batch[0]["id"].shape[0]
+        assert (batch[0]["embeddings"].cpu().numpy() == np_tensor[start:end]).all()
+        full_len += batch[0]["embeddings"].shape[0]
+    assert full_len == embedding_ids.shape[0]
diff --git a/tox.ini b/tox.ini
index 97d89ea6..47d5732f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,7 +12,9 @@ commands =
 [testenv:test-cpu]
 ; Runs in: Github Actions
 ; Runs all CPU-based tests.
-deps = -rrequirements/base.txt
+deps =
+    -rrequirements/base.txt
+    -rrequirements/dev.txt
 commands =
     python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/nvtabular.git
@@ -29,6 +31,7 @@ sitepackages=true
 ; to install requirements.txt yet. As we get better at python environment isolation, we will
 ; need to add some back.
 deps =
+    -rrequirements/dev.txt
     pytest
     pytest-cov
 commands =
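For reference, a minimal sketch of driving the new `transforms=` argument introduced in loader_base.py above; `op_a`, `op_b`, and `data.parquet` are hypothetical. A list of `BaseOperator` instances is chained onto a `ColumnSelector("*")` node and compiled into a `Graph`, and a prebuilt `Graph` is accepted directly:

    from merlin.dag import ColumnSelector, Graph, Node
    from merlin.io import Dataset
    from merlin.loader.torch import Loader

    dataset = Dataset("data.parquet")  # hypothetical input

    # a list of operators is chained in order behind the scenes ...
    loader = Loader(dataset, batch_size=1024, transforms=[op_a, op_b])

    # ... which is equivalent to passing an explicitly built Graph
    graph = Graph(Node(ColumnSelector("*")) >> op_a >> op_b)
    loader = Loader(dataset, batch_size=1024, transforms=graph)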