Skip to content

Commit

Permalink
fixed bernoulli variable to automatically clip to bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed May 1, 2020
1 parent c637492 commit 2227e4e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
23 changes: 23 additions & 0 deletions docs/ref/ridge-pymc3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

basic_model = pm.Model()

xdata = np.arange(500)
ydata = 2 * xdata + np.random.randn(500) / 4 + 3

with basic_model:
w = pm.Normal('w', mu=0, sigma=5)
b = pm.Normal('b', mu=0, sigma=5)
y = pm.Normal('y', mu=w * xdata + b, sigma=1.0, observed=ydata)

with basic_model:
data = pm.sample(500)

pm.traceplot(data)
plt.show()

print("w: ", data['w'].mean())
print("b: ", data['b'].mean())
4 changes: 2 additions & 2 deletions include/autoppl/expression/distribution/bernoulli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct Bernoulli : util::DistExpr<Bernoulli<p_type>>
}

dist_value_t pdf(value_t x, size_t index=0) const
{
{
if (x == 1) return p(index);
else if (x == 0) return 1. - p();
else return 0.0;
Expand All @@ -43,7 +43,7 @@ struct Bernoulli : util::DistExpr<Bernoulli<p_type>>
else return std::numeric_limits<dist_value_t>::lowest();
}

param_value_t p(size_t index=0) const { return p_.get_value(index); }
param_value_t p(size_t index=0) const { return std::max(std::min(p_.get_value(index), static_cast<param_value_t>(max())), static_cast<param_value_t>(min())); }
value_t min() const { return 0; }
value_t max() const { return 1; }

Expand Down

0 comments on commit 2227e4e

Please sign in to comment.