Skip to content

Fix issue 3146 for multiply_log #3147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
18af903
Fix issue 3146 by using if constexpr for path deduction. Also remove …
Feb 13, 2025
7afd1a9
fix cpplint
Feb 13, 2025
dbf99e8
update multiply log to fix #2494
SteveBronder Feb 18, 2025
00c3f8d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 18, 2025
8e0e207
update multiply_log
SteveBronder Feb 18, 2025
5a9b392
update multiply_log
SteveBronder Feb 18, 2025
df6a972
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 18, 2025
4444ee9
fix accidental change to elt_multiply prim
SteveBronder Feb 18, 2025
f37df6f
update lmultiply to just call multiply_log
SteveBronder Mar 6, 2025
3267107
update to develop
SteveBronder Mar 6, 2025
8198639
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
b064d8a
update reverse mode multiply_log to have one signature to accept comb…
SteveBronder Mar 6, 2025
511b06a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
026583c
update reverse mode multiply_log to have one signature to accept comb…
SteveBronder Mar 6, 2025
1be4234
fix opencl template issue for lmultiply
SteveBronder Mar 7, 2025
90c4235
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 7, 2025
0b4a667
fix opencl lmultiply
SteveBronder Mar 14, 2025
b065359
Merge commit '8b8057ae220ba325e1ae38fb1adc48ceead52b0a' into HEAD
yashikno Mar 14, 2025
eb7db19
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 14, 2025
9a5d2bb
Merge remote-tracking branch 'origin/develop' into fix/issue-3146
SteveBronder Mar 19, 2025
d55f7a6
remove lmultiply from opencl and instead use opencl multiply_log
SteveBronder Mar 19, 2025
f1c9094
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 19, 2025
c6ad6b0
conditional sum for opencl multiply_log
SteveBronder Mar 20, 2025
cbc3af5
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions stan/math/fwd/fun/lmultiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,8 @@
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/fun/log.hpp>
#include <stan/math/fwd/fun/multiply_log.hpp>
#include <stan/math/prim/fun/lmultiply.hpp>
#include <cmath>

namespace stan {
namespace math {

template <typename T>
inline fvar<T> lmultiply(const fvar<T>& x1, const fvar<T>& x2) {
return fvar<T>(lmultiply(x1.val_, x2.val_),
x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_);
}

template <typename T>
inline fvar<T> lmultiply(double x1, const fvar<T>& x2) {
return fvar<T>(lmultiply(x1, x2.val_), x1 * x2.d_ / x2.val_);
}

template <typename T>
inline fvar<T> lmultiply(const fvar<T>& x1, double x2) {
return fvar<T>(lmultiply(x1.val_, x2), x1.d_ * log(x2));
}
} // namespace math
} // namespace stan
#endif
11 changes: 11 additions & 0 deletions stan/math/fwd/fun/multiply_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/fun/log.hpp>
#include <stan/math/fwd/fun/value_of_rec.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <cmath>

Expand All @@ -12,19 +13,29 @@ namespace math {

template <typename T>
inline fvar<T> multiply_log(const fvar<T>& x1, const fvar<T>& x2) {
if (value_of_rec(x1) == 0.0 && value_of_rec(x2) == 0.0) {
return fvar<T>(0.0);
}
return fvar<T>(multiply_log(x1.val_, x2.val_),
x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_);
}

template <typename T>
inline fvar<T> multiply_log(double x1, const fvar<T>& x2) {
if (x1 == 0.0 && value_of_rec(x2) == 0.0) {
return fvar<T>(0.0);
}
return fvar<T>(multiply_log(x1, x2.val_), x1 * x2.d_ / x2.val_);
}

template <typename T>
inline fvar<T> multiply_log(const fvar<T>& x1, double x2) {
if (value_of_rec(x1) == 0.0 && x2 == 0.0) {
return fvar<T>(0.0);
}
return fvar<T>(multiply_log(x1.val_, x2), x1.d_ * log(x2));
}

} // namespace math
} // namespace stan
#endif
4 changes: 2 additions & 2 deletions stan/math/opencl/kernel_generator/elt_function_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp,
opencl_kernels::log_diff_exp_device_function)
ADD_BINARY_FUNCTION_WITH_INCLUDES(
multiply_log, stan::math::opencl_kernels::multiply_log_device_function)
ADD_BINARY_FUNCTION_WITH_INCLUDES(
lmultiply, stan::math::opencl_kernels::lmultiply_device_function)
// ADD_BINARY_FUNCTION_WITH_INCLUDES(
// lmultiply, stan::math::opencl_kernels::lmultiply_device_function)

#undef ADD_BINARY_FUNCTION_WITH_INCLUDES
#undef ADD_UNARY_FUNCTION_WITH_INCLUDES
Expand Down
46 changes: 0 additions & 46 deletions stan/math/opencl/rev/lmultiply.hpp
Original file line number Diff line number Diff line change
@@ -1,46 +0,0 @@
#ifndef STAN_MATH_OPENCL_REV_LMULTIPLY_HPP
#define STAN_MATH_OPENCL_REV_LMULTIPLY_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/rev/adjoint_results.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/adjoint_of.hpp>
#include <stan/math/rev/fun/value_of.hpp>

namespace stan {
namespace math {

/**
* Returns the elementwise `lmultiply()` of the input.
*
* @tparam T_a type of first expression
* @tparam T_b type of second expression
* @param a first expression
* @param b second expression
*
* @return Elementwise `lmultiply()` of the input.
*/
template <typename T_a, typename T_b,
require_all_prim_or_rev_kernel_expression_t<T_a, T_b>* = nullptr,
require_any_var_t<T_a, T_b>* = nullptr,
require_any_not_stan_scalar_t<T_a, T_b>* = nullptr>
inline var_value<matrix_cl<double>> lmultiply(T_a&& a, T_b&& b) {
arena_t<T_a> a_arena = std::forward<T_a>(a);
arena_t<T_b> b_arena = std::forward<T_b>(b);

return make_callback_var(
lmultiply(value_of(a_arena), value_of(b_arena)),
[a_arena, b_arena](const vari_value<matrix_cl<double>>& res) mutable {
adjoint_results(a_arena, b_arena) += expressions(
elt_multiply(res.adj(), log(value_of(b_arena))),
elt_multiply(res.adj(),
elt_divide(value_of(a_arena), value_of(b_arena))));
});
}

} // namespace math
} // namespace stan

#endif
#endif
34 changes: 30 additions & 4 deletions stan/math/opencl/rev/multiply_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@

namespace stan {
namespace math {
namespace internalcl {
template <bool Cond, typename T>
inline decltype(auto) conditional_sum(T&& x) {
if constexpr (Cond) {
return sum(std::forward<T>(x));
} else {
return std::forward<T>(x);
}
}
} // namespace internalcl

/**
* Returns the elementwise `multiply_log()` of the input.
Expand All @@ -32,10 +42,26 @@ inline var_value<matrix_cl<double>> multiply_log(T_a&& a, T_b&& b) {
return make_callback_var(
multiply_log(value_of(a_arena), value_of(b_arena)),
[a_arena, b_arena](const vari_value<matrix_cl<double>>& res) mutable {
adjoint_results(a_arena, b_arena) += expressions(
elt_multiply(res.adj(), log(value_of(b_arena))),
elt_multiply(res.adj(),
elt_divide(value_of(a_arena), value_of(b_arena))));
constexpr bool is_scalar_a = !is_matrix_v<T_a>;
constexpr bool is_scalar_b = !is_matrix_v<T_b>;
using internalcl::conditional_sum;
auto is_zero = value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0;
if constexpr (is_var<T_a>::value && is_var<T_b>::value) {
a_arena.adj() += conditional_sum<is_scalar_a>(select(
is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))));
b_arena.adj() += conditional_sum<is_scalar_b>(select(
is_zero, 0.0,
elt_multiply(res.adj(),
elt_divide(value_of(a_arena), value_of(b_arena)))));
} else if constexpr (is_var<T_a>::value) {
a_arena.adj() += conditional_sum<is_scalar_a>(select(
is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))));
} else if constexpr (is_var<T_b>::value) {
b_arena.adj() += conditional_sum<is_scalar_b>(select(
is_zero, 0.0,
elt_multiply(res.adj(),
elt_divide(value_of(a_arena), value_of(b_arena)))));
}
});
}

Expand Down
Loading