Skip to content

Commit 52d5635

Browse files
author
Bob Carpenter
authored
Merge pull request stan-dev#510 from stan-dev/feature/issue-38-univariate-normal-distribution-on-sufficient-statistics
Feature/issue 38 univariate normal distribution on sufficient statistics
2 parents c21bf4a + 1578d9c commit 52d5635

6 files changed

Lines changed: 352 additions & 4 deletions

File tree

stan/math/prim/scal.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,8 @@
333333
#include <stan/math/prim/scal/prob/normal_log.hpp>
334334
#include <stan/math/prim/scal/prob/normal_lpdf.hpp>
335335
#include <stan/math/prim/scal/prob/normal_rng.hpp>
336+
#include <stan/math/prim/scal/prob/normal_sufficient_log.hpp>
337+
#include <stan/math/prim/scal/prob/normal_sufficient_lpdf.hpp>
336338
#include <stan/math/prim/scal/prob/pareto_ccdf_log.hpp>
337339
#include <stan/math/prim/scal/prob/pareto_cdf.hpp>
338340
#include <stan/math/prim/scal/prob/pareto_cdf_log.hpp>

stan/math/prim/scal/meta/max_size.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,26 @@ namespace stan {
2121
}
2222

2323
/**
 * Return the largest of the lengths of four arguments, where each
 * length is obtained by calling <code>length()</code> on the argument.
 *
 * @tparam T1 Type of the first argument.
 * @tparam T2 Type of the second argument.
 * @tparam T3 Type of the third argument.
 * @tparam T4 Type of the fourth argument.
 * @param x1 First argument.
 * @param x2 Second argument.
 * @param x3 Third argument.
 * @param x4 Fourth argument.
 * @return The maximum of the four lengths.
 */
template <typename T1, typename T2, typename T3, typename T4>
size_t max_size(const T1& x1, const T2& x2, const T3& x3,
                const T4& x4) {
  size_t longest = length(x1);
  // Each step keeps the running maximum unless the next length is
  // at least as large.
  if (!(longest > length(x2)))
    longest = length(x2);
  if (!(longest > length(x3)))
    longest = length(x3);
  if (!(longest > length(x4)))
    longest = length(x4);
  return longest;
}
3132

33+
/**
 * Return the largest of the lengths of five arguments, where each
 * length is obtained by calling <code>length()</code> on the argument.
 *
 * @tparam T1 Type of the first argument.
 * @tparam T2 Type of the second argument.
 * @tparam T3 Type of the third argument.
 * @tparam T4 Type of the fourth argument.
 * @tparam T5 Type of the fifth argument.
 * @param x1 First argument.
 * @param x2 Second argument.
 * @param x3 Third argument.
 * @param x4 Fourth argument.
 * @param x5 Fifth argument.
 * @return The maximum of the five lengths.
 */
template <typename T1, typename T2, typename T3, typename T4,
          typename T5>
size_t max_size(const T1& x1, const T2& x2, const T3& x3,
                const T4& x4, const T5& x5) {
  size_t longest = length(x1);
  // Each step keeps the running maximum unless the next length is
  // at least as large.
  if (!(longest > length(x2)))
    longest = length(x2);
  if (!(longest > length(x3)))
    longest = length(x3);
  if (!(longest > length(x4)))
    longest = length(x4);
  if (!(longest > length(x5)))
    longest = length(x5);
  return longest;
}
44+
3245
}
3346
#endif
34-

stan/math/prim/scal/prob/normal_lpdf.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,16 @@ namespace stan {
2929
*
3030
* <p>The result log probability is defined to be the sum of the
3131
* log probabilities for each observation/mean/deviation triple.
32+
* @tparam T_y Underlying type of scalar in sequence.
33+
* @tparam T_loc Type of location parameter.
34+
* @tparam T_scale Type of scale parameter.
3235
* @param y (Sequence of) scalar(s).
3336
* @param mu (Sequence of) location parameter(s)
3437
* for the normal distribution.
3538
* @param sigma (Sequence of) scale parameters for the normal
3639
* distribution.
3740
* @return The log of the product of the densities.
3841
* @throw std::domain_error if the scale is not positive.
39-
* @tparam T_y Underlying type of scalar in sequence.
40-
* @tparam T_loc Type of location parameter.
4142
*/
4243
template <bool propto,
4344
typename T_y, typename T_loc, typename T_scale>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef STAN_MATH_PRIM_SCAL_PROB_NORMAL_SUFFICIENT_LOG_HPP
2+
#define STAN_MATH_PRIM_SCAL_PROB_NORMAL_SUFFICIENT_LOG_HPP
3+
4+
#include <stan/math/prim/scal/meta/return_type.hpp>
5+
#include <stan/math/prim/scal/prob/normal_sufficient_lpdf.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* @deprecated use <code>normal_lpdf</code>
12+
*/
13+
template <bool propto,
14+
typename T_y, typename T_s, typename T_n,
15+
typename T_loc, typename T_scale>
16+
inline
17+
typename return_type<T_y, T_s, T_loc, T_scale>::type
18+
normal_sufficient_log(const T_y& y_bar, const T_s& s_squared,
19+
const T_n& n_obs, const T_loc& mu,
20+
const T_scale& sigma) {
21+
return normal_sufficient_lpdf<propto, T_y, T_s, T_n,
22+
T_loc, T_scale>(y_bar, s_squared,
23+
n_obs, mu, sigma);
24+
}
25+
26+
/**
27+
* @deprecated use <code>normal_lpdf</code>
28+
*/
29+
template <typename T_y, typename T_s, typename T_n,
30+
typename T_loc, typename T_scale>
31+
inline
32+
typename return_type<T_y, T_s, T_loc, T_scale>::type
33+
normal_sufficient_log(const T_y& y_bar, const T_s& s_squared,
34+
const T_n& n_obs, const T_loc& mu,
35+
const T_scale& sigma) {
36+
return normal_sufficient_lpdf<T_y, T_s, T_n,
37+
T_loc, T_scale>(y_bar, s_squared,
38+
n_obs, mu, sigma);
39+
}
40+
41+
}
42+
}
43+
#endif
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#ifndef STAN_MATH_PRIM_SCAL_PROB_NORMAL_SUFFICIENT_LPDF_HPP
2+
#define STAN_MATH_PRIM_SCAL_PROB_NORMAL_SUFFICIENT_LPDF_HPP
3+
4+
#include <stan/math/prim/scal/meta/return_type.hpp>
5+
#include <stan/math/prim/scal/prob/normal_lpdf.hpp>
6+
7+
#include <stan/math/prim/scal/meta/OperandsAndPartials.hpp>
8+
#include <stan/math/prim/scal/meta/scalar_seq_view.hpp>
9+
#include <stan/math/prim/scal/err/check_consistent_sizes.hpp>
10+
#include <stan/math/prim/scal/err/check_finite.hpp>
11+
#include <stan/math/prim/scal/err/check_positive.hpp>
12+
#include <stan/math/prim/scal/err/check_nonnegative.hpp>
13+
#include <stan/math/prim/scal/fun/constants.hpp>
14+
#include <stan/math/prim/scal/fun/value_of.hpp>
15+
#include <stan/math/prim/scal/meta/include_summand.hpp>
16+
#include <stan/math/prim/scal/meta/VectorBuilder.hpp>
17+
#include <stan/math/prim/scal/meta/max_size.hpp>
18+
19+
namespace stan {
20+
21+
namespace math {
22+
23+
/**
24+
* The log of the normal density for the specified scalar(s) given
25+
* the specified mean(s) and deviation(s).
26+
* y, s_quared, mu, or sigma can each be either
27+
* a scalar, a std vector or Eigen vector.
28+
* n can be either a single int or an std vector of ints.
29+
* Any vector inputs must be the same length.
30+
*
31+
* <p>The result log probability is defined to be the sum of the
32+
* log probabilities for each observation/mean/deviation triple.
33+
*
34+
* @tparam T_y Type of sample average parameter.
35+
* @tparam T_s Type of sample squared errors parameter.
36+
* @tparam T_n Type of sample size parameter.
37+
* @tparam T_loc Type of location parameter.
38+
* @tparam T_scale Type of scale parameter.
39+
* @param y_bar (Sequence of) scalar(s) (sample average(s)).
40+
* @param s_squared (Sequence of) sum(s) of sample squared errors
41+
* @param n_obs (Sequence of) sample size(s)
42+
* @param mu (Sequence of) location parameter(s)
43+
* for the normal distribution.
44+
* @param sigma (Sequence of) scale parameters for the normal
45+
* distribution.
46+
* @return The log of the product of the densities.
47+
* @throw std::domain_error if either n or sigma are not positive,
48+
* if s_squared is negative or if any parameter is not finite.
49+
*/
50+
template <bool propto,
51+
typename T_y, typename T_s, typename T_n, typename T_loc,
52+
typename T_scale>
53+
typename return_type<T_y, T_s, T_loc, T_scale>::type
54+
normal_sufficient_lpdf(const T_y& y_bar, const T_s& s_squared,
55+
const T_n& n_obs, const T_loc& mu,
56+
const T_scale& sigma) {
57+
static const char*
58+
function = "stan::math::normal_sufficient_lpdf(%1%)";
59+
typedef typename
60+
stan::partials_return_type<T_y, T_s, T_n, T_loc, T_scale>::type
61+
T_partials_return;
62+
63+
using std::log;
64+
using stan::is_constant_struct;
65+
using stan::math::check_positive;
66+
using stan::math::check_finite;
67+
using stan::math::check_not_nan;
68+
using stan::math::check_consistent_sizes;
69+
using stan::math::value_of;
70+
using stan::math::include_summand;
71+
72+
// check if any vectors are zero length
73+
if (!(stan::length(y_bar)
74+
&& stan::length(s_squared)
75+
&& stan::length(n_obs)
76+
&& stan::length(mu)
77+
&& stan::length(sigma)))
78+
return 0.0;
79+
80+
// set up return value accumulator
81+
T_partials_return logp(0.0);
82+
83+
// validate args (here done over var, which should be OK)
84+
check_finite(function,
85+
"Location parameter sufficient statistic", y_bar);
86+
check_finite(function,
87+
"Scale parameter sufficient statistic", s_squared);
88+
check_nonnegative(function,
89+
"Scale parameter sufficient statistic", s_squared);
90+
check_finite(function,
91+
"Number of observations", n_obs);
92+
check_positive(function,
93+
"Number of observations", n_obs);
94+
check_finite(function,
95+
"Location parameter", mu);
96+
check_finite(function, "Scale parameter", sigma);
97+
check_positive(function, "Scale parameter", sigma);
98+
check_consistent_sizes(function,
99+
"Location parameter sufficient statistic",
100+
y_bar,
101+
"Scale parameter sufficient statistic",
102+
s_squared,
103+
"Number of observations", n_obs,
104+
"Location parameter", mu,
105+
"Scale parameter", sigma);
106+
// check if no variables are involved and prop-to
107+
if (!include_summand<propto, T_y, T_s, T_loc, T_scale>::value)
108+
return 0.0;
109+
110+
// set up template expressions wrapping scalars into vector views
111+
OperandsAndPartials<T_y, T_s, T_loc, T_scale>
112+
operands_and_partials(y_bar, s_squared, mu, sigma);
113+
114+
scalar_seq_view<const T_y> y_bar_vec(y_bar);
115+
scalar_seq_view<const T_s> s_squared_vec(s_squared);
116+
scalar_seq_view<const T_n> n_obs_vec(n_obs);
117+
scalar_seq_view<const T_loc> mu_vec(mu);
118+
scalar_seq_view<const T_scale> sigma_vec(sigma);
119+
size_t N = max_size(y_bar, s_squared, n_obs, mu, sigma);
120+
121+
for (size_t i = 0; i < N; i++) {
122+
const T_partials_return y_bar_dbl = value_of(y_bar_vec[i]);
123+
const T_partials_return s_squared_dbl =
124+
value_of(s_squared_vec[i]);
125+
const T_partials_return n_obs_dbl = n_obs_vec[i];
126+
const T_partials_return mu_dbl = value_of(mu_vec[i]);
127+
const T_partials_return sigma_dbl = value_of(sigma_vec[i]);
128+
const T_partials_return sigma_squared = pow(sigma_dbl, 2);
129+
130+
if (include_summand<propto>::value)
131+
logp += NEG_LOG_SQRT_TWO_PI * n_obs_dbl;
132+
133+
if (include_summand<propto, T_scale>::value)
134+
logp -= n_obs_dbl * log(sigma_dbl);
135+
136+
const T_partials_return cons_expr =
137+
(s_squared_dbl
138+
+ n_obs_dbl * pow(y_bar_dbl - mu_dbl, 2));
139+
140+
logp -= cons_expr / (2 * sigma_squared);
141+
142+
// gradients
143+
if (!is_constant_struct<T_y>::value ||
144+
!is_constant_struct<T_loc>::value) {
145+
const T_partials_return common_derivative =
146+
n_obs_dbl * (mu_dbl - y_bar_dbl) / sigma_squared;
147+
if (!is_constant_struct<T_y>::value)
148+
operands_and_partials.d_x1[i] += common_derivative;
149+
if (!is_constant_struct<T_loc>::value)
150+
operands_and_partials.d_x3[i] -= common_derivative;
151+
}
152+
if (!is_constant_struct<T_s>::value)
153+
operands_and_partials.d_x2[i] -=
154+
0.5 / sigma_squared;
155+
if (!is_constant_struct<T_scale>::value)
156+
operands_and_partials.d_x4[i]
157+
+= cons_expr / pow(sigma_dbl, 3) - n_obs_dbl / sigma_dbl;
158+
}
159+
return operands_and_partials.value(logp);
160+
}
161+
162+
template <typename T_y, typename T_s, typename T_n,
163+
typename T_loc, typename T_scale>
164+
inline
165+
typename return_type<T_y, T_s, T_loc, T_scale>::type
166+
normal_sufficient_lpdf(const T_y& y_bar, const T_s& s_squared,
167+
const T_n& n_obs, const T_loc& mu,
168+
const T_scale& sigma) {
169+
return
170+
normal_sufficient_lpdf<false>(y_bar, s_squared,
171+
n_obs, mu, sigma);
172+
}
173+
174+
}
175+
}
176+
#endif
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Arguments: Doubles, Doubles, Ints, Doubles, Doubles
2+
#include <stan/math/prim/scal.hpp>
3+
4+
using std::vector;
5+
using std::numeric_limits;
6+
using stan::math::var;
7+
8+
class AgradDistributionNormalSufficient : public AgradDistributionTest {
9+
public:
10+
void valid_values(vector<vector<double> >& parameters,
11+
vector<double>& log_prob) {
12+
vector<double> param(5);
13+
14+
//observed values: -0.1, 0.1
15+
param[0] = 0; // y_bar
16+
param[1] = 0.02; // s_squared
17+
param[2] = 2; // n_obs
18+
param[3] = -2.3; // mu
19+
param[4] = 2.1; // sigma
20+
parameters.push_back(param);
21+
log_prob.push_back(-4.52356581482502); // expected log_prob
22+
23+
//observed values: 0, 1
24+
param[0] = 0.5; // y_bar
25+
param[1] = 0.5; // s_squared
26+
param[2] = 2; // n_obs
27+
param[3] = 0; // mu
28+
param[4] = 1; // sigma
29+
parameters.push_back(param);
30+
log_prob.push_back(-2.33787706640935); // expected log_prob
31+
32+
33+
//observed values: 0, 2
34+
param[0] = 1; // y_bar
35+
param[1] = 2; // s_squared
36+
param[2] = 2; // n_obs
37+
param[3] = 1; // mu
38+
param[4] = 1; // sigma
39+
parameters.push_back(param);
40+
log_prob.push_back(-2.83787706640935); // expected log_prob
41+
42+
//observed values: 1, 2
43+
param[0] = 1.5; // y_bar
44+
param[1] = 0.5; // s_squared
45+
param[2] = 2; // n_obs
46+
param[3] = -1; // mu
47+
param[4] = 3; // sigma
48+
parameters.push_back(param);
49+
log_prob.push_back(-4.75732386596779); // expected log_prob
50+
}
51+
52+
void invalid_values(vector<size_t>& index,
53+
vector<double>& value) {
54+
// y
55+
56+
// mu
57+
index.push_back(3U);
58+
value.push_back(numeric_limits<double>::infinity());
59+
60+
index.push_back(3U);
61+
value.push_back(-numeric_limits<double>::infinity());
62+
63+
// sigma
64+
index.push_back(4U);
65+
value.push_back(0.0);
66+
67+
index.push_back(4U);
68+
value.push_back(-1.0);
69+
70+
index.push_back(4U);
71+
value.push_back(-numeric_limits<double>::infinity());
72+
}
73+
74+
template <typename T_y, typename T_s, typename T_n,
75+
typename T_loc, typename T_scale, typename T5>
76+
typename stan::return_type<T_y, T_s, T_n, T_loc, T_scale>::type
77+
log_prob(const T_y& y_bar, const T_s& s_squared, const T_n& n_obs,
78+
const T_loc& mu, const T_scale& sigma,
79+
const T5&) {
80+
return stan::math::normal_sufficient_lpdf(y_bar, s_squared, n_obs, mu, sigma);
81+
}
82+
83+
template <bool propto,
84+
typename T_y, typename T_s, typename T_n,
85+
typename T_loc, typename T_scale, typename T5>
86+
typename stan::return_type<T_y, T_s, T_n, T_loc, T_scale>::type
87+
log_prob(const T_y& y_bar, const T_s& s_squared, const T_n& n_obs,
88+
const T_loc& mu, const T_scale& sigma,
89+
const T5&) {
90+
return stan::math::normal_sufficient_lpdf<propto>(y_bar, s_squared, n_obs, mu, sigma);
91+
}
92+
93+
94+
template <typename T_y, typename T_s, typename T_n,
95+
typename T_loc, typename T_scale, typename T5>
96+
typename stan::return_type<T_y, T_s, T_n, T_loc, T_scale>::type
97+
log_prob_function(const T_y& y_bar, const T_s& s_squared, const T_n& n_obs,
98+
const T_loc& mu, const T_scale& sigma,
99+
const T5&) {
100+
using stan::math::include_summand;
101+
using stan::math::pi;
102+
using stan::math::square;
103+
typename stan::return_type<T_y, T_s, T_n, T_loc, T_scale>::type lp(0.0);
104+
if (include_summand<true,T_scale>::value)
105+
lp -= n_obs * log(sigma);
106+
107+
lp -= (s_squared + n_obs * pow(y_bar - mu, 2)) / (2 * pow(sigma, 2));
108+
109+
if (include_summand<true>::value)
110+
lp -= log(sqrt(2.0 * pi()));
111+
return lp;
112+
}
113+
};
114+

0 commit comments

Comments
 (0)