diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 8afeb2c1187e..0f4c5f648352 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1220,19 +1220,21 @@ def get_timestep_embedding( [N x dim] Tensor of positional embeddings. """ dtype = get_default_dtype() - timesteps = _op.astype(x._expr, dtype) + + # Arithmetic should be done in float for precision. + timesteps = _op.astype(x._expr, "float32") half_dim = embedding_dim // 2 - exponent = rx.const(-math.log(max_period), dtype) * _op.arange( - start=0, end=half_dim, dtype=dtype + exponent = rx.const(-math.log(max_period), "float32") * _op.arange( + start=0, end=half_dim, dtype="float32" ) - exponent = exponent / (rx.const(half_dim - downscale_freq_shift, dtype)) + exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32")) emb = _op.exp(exponent) emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0) # Scale embeddings if scale != 1: - emb = rx.const(scale, dtype) * emb + emb = rx.const(scale, "float32") * emb # Concat sine and cosine embeddings. if flip_sin_to_cos: @@ -1243,6 +1245,9 @@ def get_timestep_embedding( # Zero pad if embedding_dim % 2 == 1: emb = _op.nn.pad(emb, (0, 1, 0, 0)) + + # Cast to proper output type + emb = _op.astype(emb, dtype) return _wrap_nested(emb, name) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index f3d248eab41b..cb207954e834 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -402,8 +402,9 @@ def forward( lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7) lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8) lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8) - get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat( - (lv9, lv10), axis=-1 + lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1) + get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype( + lv11, dtype="float32" ) gv1: R.Tuple( R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index f6cb29a87b1d..c7bef231243d 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -386,8 +386,9 @@ def test( lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7) lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8) lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8) - get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat( - (lv9, lv10), axis=-1 + lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1) + get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype( + lv11, dtype="float32" ) gv1: R.Tuple( R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)