Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tuple return type from model pre and update test to use this #890

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b698224
Support `tuple` return type from `pre` arg to `evaluate`, `predict`
oliverholworthy Nov 16, 2022
94b2855
Update CLM transformer test to use `pre` instead of Loader `transform`
oliverholworthy Nov 16, 2022
03a063d
Merge branch 'main' into clm-test-use-model-pre
oliverholworthy Nov 18, 2022
66fb21c
Merge branch 'main' into clm-test-use-model-pre
oliverholworthy Nov 21, 2022
7cf0f5f
Merge branch 'main' into clm-test-use-model-pre
rnyak Nov 23, 2022
89869a3
Merge branch 'main' into clm-test-use-model-pre
rnyak Nov 23, 2022
5e3ea7a
Merge branch 'main' into clm-test-use-model-pre
oliverholworthy Nov 25, 2022
4ee01cf
Update youtube dnn tests to use transform as model fit pre
oliverholworthy Dec 1, 2022
1abd8d5
Merge branch 'main' into clm-test-use-model-pre
oliverholworthy Dec 1, 2022
1eef7b8
Add `pre` to ModelBlock fit/evaluate
oliverholworthy Dec 2, 2022
f58e075
Revert "Add `pre` to ModelBlock fit/evaluate"
oliverholworthy Dec 2, 2022
c499616
Raise exception if ragged/sparse tensors are passed at training time.
oliverholworthy Dec 2, 2022
b95d000
Update model_test helper to avoid passing ragged tensors to `fit`
oliverholworthy Dec 2, 2022
f617a50
Handle x and y in model_test
oliverholworthy Dec 2, 2022
19eefb4
Change process_lists param to False by default
oliverholworthy Dec 2, 2022
a4b26b3
Convert to tuple in test loader
oliverholworthy Dec 2, 2022
1f42cfe
Move order of ragged tensor assertion to before train_pre call
oliverholworthy Dec 2, 2022
a332e25
expand dims in test_classification
oliverholworthy Dec 2, 2022
37cc84d
pass transform as pre in test in batch negatives
oliverholworthy Dec 2, 2022
17ef3d5
Update continuous and retrieval tests
oliverholworthy Dec 2, 2022
03f6885
Remove test of sequence predict functions with loader
oliverholworthy Dec 7, 2022
50921b8
Update error message about ragged tensors for clarity
oliverholworthy Dec 7, 2022
deef749
Add explanation about why the input types are restricted
oliverholworthy Dec 7, 2022
d68f218
Merge branch 'main' into clm-test-use-model-pre
oliverholworthy Dec 12, 2022
72e96da
Rename dataset to dataloader in model_test
oliverholworthy Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion merlin/models/tf/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def sample_batch(
include_targets: bool = True,
to_ragged: bool = False,
to_dense: bool = False,
process_lists=True,
process_lists=False,
):
"""Util function to generate a batch of input tensors from a merlin.io.Dataset instance

Expand Down
26 changes: 26 additions & 0 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,22 @@ def train_step(self, data):
with tf.GradientTape() as tape:
x, y, sample_weight = unpack_x_y_sample_weight(data)

# Ensure that we don't have any ragged or sparse tensors passed at training time.
if isinstance(x, dict):
for k in x:
if isinstance(x[k], (tf.RaggedTensor, tf.SparseTensor)):
raise ValueError(
"Training with RaggedTensor or SparseTensor input features is "
"not supported. Please update your dataloader to pass a tuple "
"of dense tensors instead, (corresponding to the values and "
"row lengths of the ragged input feature). This will ensure that "
"the model can be saved with the correct input signature, "
"and served correctly. "
"This is because when ragged or sparse tensors are fed as inputs "
"the input feature names are currently lost in the saved model "
"input signature."
)

if getattr(self, "train_pre", None):
out = call_layer(self.train_pre, x, targets=y, features=x, training=True)
if isinstance(out, Prediction):
Expand Down Expand Up @@ -775,6 +791,11 @@ def test_step(self, data):
out = call_layer(self.test_pre, x, targets=y, features=x, testing=True)
if isinstance(out, Prediction):
x, y = out.outputs, out.targets
elif isinstance(out, tuple):
assert (
len(out) == 2
), "output of `pre` must be a 2-tuple of x, y or `Prediction` tuple"
x, y = out
else:
x = out

Expand Down Expand Up @@ -804,6 +825,11 @@ def predict_step(self, data):
out = call_layer(self.predict_pre, x, features=x, training=False)
if isinstance(out, Prediction):
x = out.outputs
elif isinstance(out, tuple):
assert (
len(out) == 2
), "output of `pre` must be a 2-tuple of x, y or `Prediction` tuple"
x, y = out
else:
x = out

Expand Down
9 changes: 6 additions & 3 deletions merlin/models/tf/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ def model_test(

assert isinstance(loaded_model, type(model))

x, y = sample_batch(dataloader, batch_size=50, to_ragged=False, process_lists=False)
batch = [(x, y)]

np.testing.assert_array_almost_equal(
model.predict(batch[0]),
loaded_model.predict(batch[0]),
model.predict(iter(batch)),
loaded_model.predict(iter(batch)),
)

loaded_model.compile(run_eagerly=run_eagerly, optimizer=optimizer, **kwargs)
loaded_model.train_step(batch)
loaded_model.fit(iter(batch))

return loaded_model, losses

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tf/inputs/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_continuous_features_ragged(sequence_testing_data: Dataset):
inputs = ml.ContinuousFeatures.from_schema(
schema, post=ml.BroadcastToSequence(context_schema, seq_schema), aggregation="concat"
)
features, _ = ml.sample_batch(sequence_testing_data, batch_size=100)
features, _ = ml.sample_batch(sequence_testing_data, batch_size=100, process_lists=True)
outputs = inputs(features)

assert outputs.to_tensor().shape == (100, 4, 6)
41 changes: 25 additions & 16 deletions tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,19 +819,20 @@ def test_youtube_dnn_retrieval(sequence_testing_data: Dataset):

as_ragged = mm.ListToRagged()

def last_interaction_as_target(inputs, targets):
inputs = as_ragged(inputs)
items = inputs["item_id_seq"]
_items = items[:, :-1]
targets = items[:, -1:].flat_values
class LastInteractionAsTarget(tf.keras.layers.Layer):
def call(self, inputs, **kwargs):
inputs = as_ragged(inputs)
items = inputs["item_id_seq"]
_items = items[:, :-1]
targets = items[:, -1:].flat_values

inputs["item_id_seq"] = _items
inputs["item_id_seq"] = _items

return inputs, targets
return inputs, targets

dataloader = mm.Loader(sequence_testing_data, batch_size=50).map(last_interaction_as_target)
dataloader = mm.Loader(sequence_testing_data, batch_size=50)

losses = model.fit(dataloader, epochs=1)
losses = model.fit(dataloader, epochs=1, pre=LastInteractionAsTarget())

assert losses is not None

Expand All @@ -856,10 +857,14 @@ def test_youtube_dnn_retrieval_v2(sequence_testing_data: Dataset, run_eagerly, t
schema=sequence_testing_data.schema, top_block=mm.MLPBlock([32]), num_sampled=1000
)

dataloader = mm.Loader(sequence_testing_data, batch_size=50).map(target_augmentation)
dataloader = mm.Loader(sequence_testing_data, batch_size=50)

_, losses = testing_utils.model_test(
model, dataloader, reload_model=True, run_eagerly=run_eagerly
model,
dataloader,
reload_model=True,
run_eagerly=run_eagerly,
fit_kwargs=dict(pre=target_augmentation),
)

assert losses is not None
Expand Down Expand Up @@ -936,8 +941,10 @@ def test_youtube_dnn_v2_export_embeddings(sequence_testing_data: Dataset):
schema=sequence_testing_data.schema, top_block=mm.MLPBlock([32]), num_sampled=1000
)

dataloader = mm.Loader(sequence_testing_data, batch_size=50).map(predict_next)
model, _ = testing_utils.model_test(model, dataloader, reload_model=False)
dataloader = mm.Loader(sequence_testing_data, batch_size=50)
model, _ = testing_utils.model_test(
model, dataloader, reload_model=False, fit_kwargs=dict(pre=predict_next)
)

candidates = model.candidate_embeddings().compute()
assert list(candidates.columns) == [str(i) for i in range(32)]
Expand Down Expand Up @@ -969,13 +976,15 @@ def test_youtube_dnn_topk_evaluation(sequence_testing_data: Dataset, run_eagerly
schema=sequence_testing_data.schema, top_block=mm.MLPBlock([32]), num_sampled=1000
)

dataloader = mm.Loader(sequence_testing_data, batch_size=50).map(predict_next)
dataloader = mm.Loader(sequence_testing_data, batch_size=50)

model, _ = testing_utils.model_test(model, dataloader, reload_model=False)
model, _ = testing_utils.model_test(
model, dataloader, reload_model=False, fit_kwargs=dict(pre=predict_next)
)

# Top-K evaluation
topk_model = model.to_top_k_encoder(k=20)
topk_model.compile(run_eagerly=run_eagerly)

metrics = topk_model.evaluate(dataloader, return_dict=True)
metrics = topk_model.evaluate(dataloader, return_dict=True, pre=predict_next)
assert all([metric >= 0 for metric in metrics.values()])
6 changes: 6 additions & 0 deletions tests/unit/tf/outputs/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def _last_interaction_as_target(inputs, targets):
_items = items[:, :-1]
targets = tf.one_hot(items[:, -1:].flat_values, 51997)
inputs["item_id_seq"] = _items
for k in inputs:
if isinstance(inputs[k], tf.RaggedTensor):
inputs[k] = (
tf.expand_dims(inputs[k].values, 1),
tf.expand_dims(inputs[k].row_lengths(), 1),
)
return inputs, targets

schema = sequence_testing_data.schema.select_by_tag(Tags.CATEGORICAL)
Expand Down
16 changes: 7 additions & 9 deletions tests/unit/tf/transformers/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,7 @@ def test_transformer_with_causal_language_modeling(sequence_testing_data: Datase
target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
predict_next = mm.SequencePredictNext(schema=seq_schema, target=target)

loader = Loader(
sequence_testing_data,
batch_size=8,
shuffle=False,
).map(predict_next)
loader = Loader(sequence_testing_data, batch_size=8, shuffle=False)

model = mm.Model(
mm.InputBlockV2(
Expand All @@ -260,14 +256,16 @@ def test_transformer_with_causal_language_modeling(sequence_testing_data: Datase

batch = next(iter(loader))[0]
outputs = model(batch)
assert list(outputs.shape) == [8, 3, 51997]
testing_utils.model_test(model, loader, run_eagerly=run_eagerly, reload_model=True)
assert list(outputs.shape) == [8, 4, 51997]
testing_utils.model_test(
model, loader, run_eagerly=run_eagerly, reload_model=True, fit_kwargs={"pre": predict_next}
)

metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True)
metrics = model.evaluate(loader, batch_size=8, steps=1, return_dict=True, pre=predict_next)
assert len(metrics) > 0

predictions = model.predict(loader, batch_size=8, steps=1)
assert predictions.shape == (8, 3, 51997)
assert predictions.shape == (8, 4, 51997)


@pytest.mark.parametrize("run_eagerly", [True, False])
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/tf/transforms/test_negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,10 @@ def test_model_with_dataloader(self, music_streaming_data: Dataset, tf_random_se
add_negatives = InBatchNegatives(schema, 5, seed=tf_random_seed)

batch_size, n_per_positive = 10, 5
loader = mm.Loader(dataset, batch_size=batch_size).map(add_negatives)
loader = mm.Loader(dataset, batch_size=batch_size)

features, targets = next(iter(loader))
features, targets = next(loader)
features, targets = add_negatives(features, targets)

expected_batch_size = batch_size + batch_size * n_per_positive

Expand All @@ -226,4 +227,4 @@ def test_model_with_dataloader(self, music_streaming_data: Dataset, tf_random_se
assert model(features).shape[0] > batch_size
assert model(features).shape[0] <= expected_batch_size

testing_utils.model_test(model, loader)
testing_utils.model_test(model, loader, fit_kwargs=dict(pre=add_negatives))
34 changes: 6 additions & 28 deletions tests/unit/tf/transforms/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,17 @@

import merlin.models.tf as mm
from merlin.io import Dataset
from merlin.models.tf.loader import Loader
from merlin.models.tf.utils.testing_utils import assert_output_shape
from merlin.schema import Tags


@pytest.mark.parametrize("use_loader", [False, True])
def test_seq_predict_next(sequence_testing_data: Dataset, use_loader: bool):
def test_seq_predict_next(sequence_testing_data: Dataset):
seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE)
target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
predict_next = mm.SequencePredictNext(schema=seq_schema, target=target, pre=mm.ListToRagged())

batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False)
if use_loader:
dataset_transformed = Loader(sequence_testing_data, batch_size=8, shuffle=False).map(
predict_next
)
output = next(iter(dataset_transformed))
else:
output = predict_next(batch)
output = predict_next(batch)
output_x, output_y = output
output_y = output_y[target]

Expand All @@ -58,20 +50,13 @@ def test_seq_predict_next(sequence_testing_data: Dataset, use_loader: bool):
)


@pytest.mark.parametrize("use_loader", [False, True])
def test_seq_predict_last(sequence_testing_data: Dataset, use_loader: bool):
def test_seq_predict_last(sequence_testing_data: Dataset):
seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE)
target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
predict_last = mm.SequencePredictLast(schema=seq_schema, target=target)

batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False)
if use_loader:
dataset_transformed = Loader(sequence_testing_data, batch_size=8, shuffle=False).map(
predict_last
)
output = next(iter(dataset_transformed))
else:
output = predict_last(batch)
output = predict_last(batch)
output_x, output_y = output
output_y = output_y[target]

Expand All @@ -93,20 +78,13 @@ def test_seq_predict_last(sequence_testing_data: Dataset, use_loader: bool):
)


@pytest.mark.parametrize("use_loader", [False, True])
def test_seq_predict_random(sequence_testing_data: Dataset, use_loader: bool):
def test_seq_predict_random(sequence_testing_data: Dataset):
seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE)
target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
predict_random = mm.SequencePredictRandom(schema=seq_schema, target=target)

batch, _ = mm.sample_batch(sequence_testing_data, batch_size=8, process_lists=False)
if use_loader:
dataset_transformed = Loader(sequence_testing_data, batch_size=8, shuffle=False).map(
predict_random
)
output = next(iter(dataset_transformed))
else:
output = predict_random(batch)
output = predict_random(batch)
output_x, output_y = output
output_y = output_y[target]

Expand Down