Skip to content

Commit

Permalink
Consolidate validation test code (issue #41)
Browse files Browse the repository at this point in the history
* Simplify trace analysis and reporting code in
  `trace_analysis.hpp`
* Consolidate convergence test run procedures into
  new class `convergence_test_runner`.
  • Loading branch information
halfflat committed Oct 28, 2016
1 parent 5ade8d0 commit 1b929ff
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 265 deletions.
120 changes: 120 additions & 0 deletions tests/validation/convergence_test.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#pragma once

#include <util/filter.hpp>
#include <util/rangeutil.hpp>

#include <json/json.hpp>

#include "gtest.h"

#include "trace_analysis.hpp"
#include "validation_data.hpp"

namespace nest {
namespace mc {

struct sampler_info {
const char* label;
cell_member_type probe;
simple_sampler sampler;
};

/* Common functionality for testing convergence towards
* a reference data set as some parameter of the model
* is changed.
*
* Type parameter Param is the type of the parameter that
* is changed between runs.
*/

template <typename Param>
class convergence_test_runner {
private:
std::string param_name_;
bool run_validation_;
nlohmann::json meta_;
std::vector<sampler_info> cell_samplers_;
std::map<std::string, trace_data> ref_data_;
std::vector<conv_entry<Param>> conv_tbl_;

public:
template <typename SamplerInfoSeq>
convergence_test_runner(
const std::string& param_name,
const SamplerInfoSeq& samplers,
const nlohmann::json meta
):
param_name_(param_name),
run_validation_(false),
meta_(meta)
{
util::assign(cell_samplers_, samplers);
}

// allow free access to JSON meta data attached to saved traces
nlohmann::json& metadata() { return meta_; }

void load_reference_data(const util::path& ref_path) {
run_validation_ = false;
try {
ref_data_ = g_trace_io.load_traces(ref_path);

run_validation_ = util::all_of(cell_samplers_,
[&](const sampler_info& se) { return ref_data_.count(se.label)>0; });

EXPECT_TRUE(run_validation_);
}
catch (std::runtime_error&) {
ADD_FAILURE() << "failure loading reference data: " << ref_path;
}
}

template <typename Model>
void run(Model& m, Param p, float t_end, float dt) {
// reset samplers and attach to probe locations
for (auto& se: cell_samplers_) {
se.sampler.reset();
m.attach_sampler(se.probe, se.sampler.template sampler<>());
}

m.run(t_end, dt);

for (auto& se: cell_samplers_) {
std::string label = se.label;
const auto& trace = se.sampler.trace;

// save trace
nlohmann::json trace_meta{meta_};
trace_meta[param_name_] = p;

g_trace_io.save_trace(label, trace, trace_meta);

// compute metrics
if (run_validation_) {
double linf = linf_distance(trace, ref_data_[label]);
auto pd = peak_delta(trace, ref_data_[label]);

conv_tbl_.push_back({label, p, linf, pd});
}
}
}

void report() {
if (run_validation_ && g_trace_io.verbose()) {
// reorder to group by id
util::stable_sort_by(conv_tbl_, [](const conv_entry<Param>& e) { return e.id; });
report_conv_table(std::cout, conv_tbl_, param_name_);
}
}

void assert_all_convergence() const {
for (const sampler_info& se: cell_samplers_) {
SCOPED_TRACE(se.label);
assert_convergence(util::filter(conv_tbl_,
[&](const conv_entry<Param>& e) { return e.id==se.label; }));
}
}
};

} // namespace mc
} // namespace nest
60 changes: 30 additions & 30 deletions tests/validation/trace_analysis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "gtest.h"

#include <simple_sampler.hpp>
#include <math.hpp>
#include <util/optional.hpp>
#include <util/path.hpp>
#include <util/rangeutil.hpp>
Expand Down Expand Up @@ -42,62 +43,61 @@ std::vector<trace_peak> local_maxima(const trace_data& u);
util::optional<trace_peak> peak_delta(const trace_data& a, const trace_data& b);

// Record for error data for convergence testing.
// Only linf and peak_delta are used for convergence testing below;
// if and param are for record keeping in the validation test itself.

template <typename Param>
struct conv_entry {
std::string id;
Param param;
double linf;
util::optional<trace_peak> pd;
util::optional<trace_peak> peak_delta;
};

template <typename Param>
using conv_data = std::vector<conv_entry<Param>>;

// Assert error convergence (gtest).

template <typename Param>
void assert_convergence(const conv_data<Param>& cs) {
if (cs.size()<2) return;
template <typename ConvEntrySeq>
void assert_convergence(const ConvEntrySeq& cs) {
if (util::empty(cs)) return;

auto tbound = [](trace_peak p) { return std::abs(p.t)+p.t_err; };
auto smallest_pd = cs[0].pd;
float peak_dt_bound = math::infinity<>();

for (unsigned i = 1; i<cs.size(); ++i) {
const auto& p = cs[i-1];
const auto& c = cs[i];
for (auto pi = std::begin(cs); std::next(pi)!=std::end(cs); ++pi) {
const auto& p = *pi;
const auto& c = *std::next(pi);

EXPECT_LE(c.linf, p.linf) << "L∞ error increase";
EXPECT_TRUE(c.pd || (!c.pd && !p.pd)) << "divergence in peak count";

if (c.pd && smallest_pd) {
double t = std::abs(c.pd->t);
EXPECT_LE(t, c.pd->t_err+tbound(*smallest_pd)) << "divergence in max peak displacement";
if (!c.peak_delta) {
EXPECT_FALSE(p.peak_delta) << "divergence in peak count";
}
else {
double t = std::abs(c.peak_delta->t);
double t_limit = c.peak_delta->t_err+peak_dt_bound;

EXPECT_LE(t, t_limit) << "divergence in max peak displacement";

if (c.pd && (!smallest_pd || tbound(*c.pd)<tbound(*smallest_pd))) {
smallest_pd = c.pd;
peak_dt_bound = std::min(peak_dt_bound, tbound(*c.peak_delta));
}
}
}

// Report table of convergence results.
// (Takes collection with pair<string, conv_data>
// entries.)

template <typename Map>
void report_conv_table(std::ostream& out, const Map& tbl, const std::string& param_name) {
out << "location," << param_name << ",linf,peak_dt,peak_dt_err\n";
for (const auto& p: tbl) {
const auto& location = p.first;
for (const auto& c: p.second) {
out << location << "," << c.param << "," << c.linf << ",";
if (c.pd) {
out << c.pd->t << "," << c.pd->t_err << "\n";
}
else {
out << "NA,NA\n";
}

template <typename ConvEntrySeq>
void report_conv_table(std::ostream& out, const ConvEntrySeq& tbl, const std::string& param_name) {
out << "id," << param_name << ",linf,peak_dt,peak_dt_err\n";
for (const auto& c: tbl) {
out << c.id << "," << c.param << "," << c.linf << ",";
if (c.peak_delta) {
out << c.peak_delta->t << "," << c.peak_delta->t_err << "\n";
}
else {
out << "NA,NA\n";
}
}
}
Expand Down
Loading

0 comments on commit 1b929ff

Please sign in to comment.