-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c5c06a1
commit 00f055e
Showing
6 changed files
with
132 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#include <chrono> | ||
#include <iostream> | ||
#include <fstream> | ||
#include <string> | ||
#include <vector> | ||
#include <array> | ||
#include <sstream> | ||
#include <unordered_map> | ||
|
||
#include <autoppl/variable.hpp> | ||
#include <autoppl/expr_builder.hpp> | ||
#include <autoppl/mcmc/hmc/nuts/nuts.hpp> | ||
|
||
#include "benchmark_utils.hpp" | ||
|
||
#include <benchmark/benchmark.h> | ||
|
||
namespace ppl { | ||
|
||
static void BM_Regression(benchmark::State& state) { | ||
constexpr size_t num_samples = 1000; | ||
constexpr size_t n_data = 10000; | ||
|
||
std::array<std::string, 4> headers = {"b", "x1", "x2", "x3"}; | ||
|
||
std::unordered_map<std::string, ppl::Data<double>> data; | ||
std::unordered_map<std::string, ppl::Param<double>> params; | ||
std::array<std::vector<double>, 4> storage; | ||
|
||
std::mt19937 gen; | ||
std::normal_distribution n1(-1.0, 1.4); | ||
std::normal_distribution n2(0.0, 1.4); | ||
std::normal_distribution n3(1.0, 1.4); | ||
std::normal_distribution eps(0.0, 1.0); | ||
|
||
for (size_t i = 0; i < n_data; ++i) { | ||
double x1 = n1(gen); | ||
double x2 = n2(gen); | ||
double x3 = n3(gen); | ||
data[headers[1]].observe(x1); | ||
data[headers[2]].observe(x2); | ||
data[headers[3]].observe(x3); | ||
data["y"].observe(x1 * 1.4 + x2 * 2. + x3 * 0.32 + eps(gen)); | ||
} | ||
|
||
// resize each storage and bind with param | ||
int i = 0; | ||
for (auto it = headers.begin(); it != headers.end(); ++it, ++i) { | ||
storage[i].resize(num_samples); | ||
params[*it].set_storage(storage[i].data()); | ||
} | ||
|
||
auto model = (params["b"] |= ppl::normal(0., 5.), | ||
params["x1"] |= ppl::normal(0., 5.), | ||
params["x2"] |= ppl::normal(0., 5.), | ||
params["x3"] |= ppl::normal(0., 5.), | ||
|
||
data["y"] |= ppl::normal( | ||
params["x1"] * data["x1"] + | ||
params["x2"] * data["x2"] + | ||
params["x3"] * data["x3"] + | ||
params["b"], 1.0)); | ||
|
||
for (auto _ : state) { | ||
ppl::nuts(model); | ||
} | ||
|
||
std::cout << "b: " << sample_average(storage[0]) << std::endl; | ||
std::cout << "w1: " << sample_average(storage[1]) << std::endl; | ||
std::cout << "w2: " << sample_average(storage[2]) << std::endl; | ||
std::cout << "w3: " << sample_average(storage[3]) << std::endl; | ||
} | ||
|
||
BENCHMARK(BM_Regression); | ||
|
||
} // namespace ppl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from cmdstanpy import CmdStanModel | ||
import pandas as pd | ||
import numpy as np | ||
|
||
N = 10000 | ||
X = np.random.normal(loc=[-1, 0, 1], scale=1.4, size=(N, 3)) | ||
w_true = np.array([1.4, 2., 0.32]) | ||
y = X.dot(w_true) + np.random.normal(loc=0., scale=1.0, size=N) | ||
|
||
cool_dat = { | ||
'N' : N, | ||
'x1' : list(X[:,0]), | ||
'x2' : list(X[:,1]), | ||
'x3' : list(X[:,2]), | ||
'y' : list(y) | ||
} | ||
|
||
stan_file = "regression_stan_2.stan" | ||
sm = CmdStanModel(stan_file=stan_file) | ||
fit = sm.sample(data=cool_dat, chains=1, cores=1, | ||
iter_warmup=1000, iter_sampling=1000, thin=1, | ||
max_treedepth=10, metric='diag', adapt_engaged=True, | ||
output_dir='.') | ||
print(fit.summary()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
data { | ||
int N; | ||
vector[N] x1; | ||
vector[N] x2; | ||
vector[N] x3; | ||
vector[N] y; | ||
} | ||
parameters { | ||
real w1; | ||
real w2; | ||
real w3; | ||
real b; | ||
} | ||
model { | ||
b ~ normal(0,5); | ||
w1 ~ normal(0,5); | ||
w2 ~ normal(0,5); | ||
w3 ~ normal(0,5); | ||
y ~ normal(b + w1 * x1 + w2 * x2 + w3 * x3, 1); | ||
} |