Skip to content

Commit

Permalink
[Bugfix][Relay][Keras] Fix UpSampling2D about the wrong assertion abo…
Browse files Browse the repository at this point in the history
…ut size (#15082)

* fix wrong assertion about unsample in keras.py

* Update test_forward.py

* Update test_forward.py
  • Loading branch information
jikechao committed Jun 14, 2023
1 parent bd24133 commit 081cc2e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 1 addition & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,8 @@ def _convert_upsample(
params["scale_h"] = h
elif upsample_type == "UpSampling2D":
h, w = keras_layer.size
if h != w:
raise tvm.error.OpAttributeInvalid("Height must equal width for operator Upsample.")
params["scale_h"] = h
params["scale_w"] = h
params["scale_w"] = w

if hasattr(keras_layer, "interpolation"):
interpolation = keras_layer.interpolation
Expand Down
5 changes: 5 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,11 @@ def test_forward_upsample(self, keras_mod, interpolation="nearest"):
x = keras_mod.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras_mod.models.Model(data, x)
verify_keras_frontend(keras_model)
# Height and width are not equal for the attribute size
data = keras_mod.layers.Input(shape=(2, 1, 3))
x = keras_mod.layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(data)
keras_model = keras_mod.models.Model(data, x)
verify_keras_frontend(keras_model)

def test_forward_reshape(self, keras_mod):
"""test_forward_reshape"""
Expand Down

0 comments on commit 081cc2e

Please sign in to comment.