Skip to content

Commit

Permalink
Try #1691:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Jan 21, 2022
2 parents b2cb6ce + a6a2972 commit 99b9fcc
Show file tree
Hide file tree
Showing 17 changed files with 89 additions and 91 deletions.
3 changes: 0 additions & 3 deletions CMakeLists.txt
Expand Up @@ -395,9 +395,6 @@ set(ARB_MODCC_FLAGS)
if(ARB_VECTORIZE)
list(APPEND ARB_MODCC_FLAGS "--simd")
endif()
if(ARB_WITH_PROFILING)
list(APPEND ARB_MODCC_FLAGS "--profile")
endif()

#----------------------------------------------------------
# Set up install paths, permissions.
Expand Down
2 changes: 1 addition & 1 deletion arbor/benchmark_cell_group.cpp
Expand Up @@ -61,7 +61,7 @@ void benchmark_cell_group::advance(epoch ep,
using std::chrono::high_resolution_clock;
using duration_type = std::chrono::duration<double, std::micro>;

PE(advance_bench_cell);
PE(advance:bench:cell);
// Micro-seconds to advance in this epoch.
auto us = 1e3*(ep.duration());
for (auto i: util::make_span(0, gids_.size())) {
Expand Down
4 changes: 2 additions & 2 deletions arbor/communication/communicator.cpp
Expand Up @@ -139,12 +139,12 @@ time_type communicator::min_delay() {
}

gathered_vector<spike> communicator::exchange(std::vector<spike> local_spikes) {
PE(communication_exchange_sort);
PE(communication:exchange:sort);
// sort the spikes in ascending order of source gid
util::sort_by(local_spikes, [](spike s){return s.source;});
PL();

PE(communication_exchange_gather);
PE(communication:exchange:gather);
// global all-to-all to gather a local copy of the global spike list on each node.
auto global_spikes = distributed_->gather_spikes(local_spikes);
num_spikes_ += global_spikes.size();
Expand Down
26 changes: 13 additions & 13 deletions arbor/fvm_lowered_cell_impl.hpp
Expand Up @@ -198,7 +198,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(
set_gpu();

// Integration setup
PE(advance_integrate_setup);
PE(advance:integrate:setup);
threshold_watcher_.clear_crossings();

auto n_samples = staged_samples.size();
Expand Down Expand Up @@ -227,11 +227,11 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(

// Deliver events and accumulate mechanism current contributions.

PE(advance_integrate_events);
PE(advance:integrate:events);
state_->deliverable_events.mark_until_after(state_->time);
PL();

PE(advance_integrate_current_zero);
PE(advance:integrate:current:zero);
state_->zero_currents();
PL();
for (auto& m: mechanisms_) {
Expand All @@ -245,7 +245,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(
m->update_current();
}

PE(advance_integrate_events);
PE(advance:integrate:events);
state_->deliverable_events.drop_marked_events();

// Update event list and integration step times.
Expand All @@ -260,24 +260,24 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(
// want to use mean current contributions as opposed to point
// sample.)

PE(advance_integrate_stimuli)
PE(advance:integrate:stimuli)
state_->add_stimulus_current();
PL();

// Take samples at cell time if sample time in this step interval.

PE(advance_integrate_samples);
PE(advance:integrate:samples);
sample_events_.mark_until(state_->time_to);
state_->take_samples(sample_events_.marked_events(), sample_time_, sample_value_);
sample_events_.drop_marked_events();
PL();

// Integrate voltage by matrix solve.

PE(advance_integrate_matrix_build);
PE(advance:integrate:matrix:build);
matrix_.assemble(state_->dt_intdom, state_->voltage, state_->current_density, state_->conductivity);
PL();
PE(advance_integrate_matrix_solve);
PE(advance:integrate:matrix:solve);
matrix_.solve(state_->voltage);
PL();

Expand All @@ -289,17 +289,17 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(

// Update ion concentrations.

PE(advance_integrate_ionupdate);
PE(advance:integrate:ionupdate);
update_ion_state();
PL();

// Update time and test for spike threshold crossings.

PE(advance_integrate_threshold);
PE(advance:integrate:threshold);
threshold_watcher_.test(&state_->time_since_spike);
PL();

PE(advance_integrate_post)
PE(advance:integrate:post)
if (post_events_) {
for (auto& m: mechanisms_) {
m->post_event();
Expand All @@ -313,14 +313,14 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(
// Check for non-physical solutions:

if (check_voltage_mV_>0) {
PE(advance_integrate_physicalcheck);
PE(advance:integrate:physicalcheck);
assert_voltage_bounded(check_voltage_mV_);
PL();
}

// Check for end of integration.

PE(advance_integrate_stepsupdate);
PE(advance:integrate:stepsupdate);
if (!--remaining_steps) {
tmin_ = state_->time_bounds().first;
remaining_steps = dt_steps(tmin_, tfinal, dt_max);
Expand Down
23 changes: 21 additions & 2 deletions arbor/include/arbor/mechanism.hpp
Expand Up @@ -9,6 +9,8 @@
#include <arbor/fvm_types.hpp>
#include <arbor/mechanism_abi.h>
#include <arbor/mechinfo.hpp>
#include <arbor/profile/profiler.hpp>
#include <arbor/version.hpp>

namespace arb {

Expand All @@ -32,6 +34,8 @@ class mechanism {
mechanism(const arb_mechanism_type m,
const arb_mechanism_interface& i): mech_{m}, iface_{i}, ppack_{} {
if (mech_.abi_version != ARB_MECH_ABI_VERSION) throw unsupported_abi_error{mech_.abi_version};
state_prof_id = profile::profiler_region_id("advance:integrate:state:"+internal_name());
current_prof_id = profile::profiler_region_id("advance:integrate:current:"+internal_name());
}
mechanism() = default;
mechanism(const mechanism&) = delete;
Expand All @@ -55,8 +59,8 @@ class mechanism {

// Forward to interface methods
void initialize() { ppack_.vec_t = *time_ptr_ptr; iface_.init_mechanism(&ppack_); }
void update_current() { ppack_.vec_t = *time_ptr_ptr; iface_.compute_currents(&ppack_); }
void update_state() { ppack_.vec_t = *time_ptr_ptr; iface_.advance_state(&ppack_); }
void update_current() { prof_enter(current_prof_id); ppack_.vec_t = *time_ptr_ptr; iface_.compute_currents(&ppack_); prof_exit(); }
void update_state() { prof_enter(state_prof_id); ppack_.vec_t = *time_ptr_ptr; iface_.advance_state(&ppack_); prof_exit(); }
void update_ions() { ppack_.vec_t = *time_ptr_ptr; iface_.write_ions(&ppack_); }
void post_event() { ppack_.vec_t = *time_ptr_ptr; iface_.post_event(&ppack_); }
void deliver_events(arb_deliverable_event_stream& stream) { ppack_.vec_t = *time_ptr_ptr; iface_.apply_events(&ppack_, &stream); }
Expand All @@ -68,6 +72,21 @@ class mechanism {
arb_mechanism_interface iface_;
arb_mechanism_ppack ppack_;
arb_value_type** time_ptr_ptr = nullptr;

private:
#ifdef ARB_PROFILE_ENABLED
void prof_enter(profile::region_id_type id) {
profile::profiler_enter(id);
}
void prof_exit() {
profile::profiler_leave();
}
#else
void prof_enter(profile::region_id_type) {}
void prof_exit() {}
#endif
profile::region_id_type state_prof_id;
profile::region_id_type current_prof_id;
};

struct mechanism_layout {
Expand Down
2 changes: 1 addition & 1 deletion arbor/include/arbor/profile/profiler.hpp
Expand Up @@ -39,7 +39,7 @@ void profiler_enter(std::size_t region_id);
void profiler_leave();

profile profiler_summary();
std::size_t profiler_region_id(const char* name);
std::size_t profiler_region_id(const std::string& name);

std::ostream& operator<<(std::ostream&, const profile&);

Expand Down
2 changes: 1 addition & 1 deletion arbor/lif_cell_group.cpp
Expand Up @@ -40,7 +40,7 @@ cell_kind lif_cell_group::get_cell_kind() const {
}

void lif_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) {
PE(advance_lif);
PE(advance:lif);
if (event_lanes.size() > 0) {
for (auto lid: util::make_span(gids_.size())) {
// Advance each cell independently.
Expand Down
6 changes: 3 additions & 3 deletions arbor/mc_cell_group.cpp
Expand Up @@ -394,7 +394,7 @@ void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& e

// Bin and collate deliverable events from event lanes.

PE(advance_eventsetup);
PE(advance:eventsetup);
staged_events_.clear();

// Skip event handling if nothing to deliver.
Expand Down Expand Up @@ -452,7 +452,7 @@ void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& e
// value as defined below, grouping together all the samples of the
// same probe for this callback in this association.

PE(advance_samplesetup);
PE(advance:samplesetup);
std::vector<sampler_call_info> call_info;

std::vector<sample_event> sample_events;
Expand Down Expand Up @@ -533,7 +533,7 @@ void mc_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& e
// vector of sample entries from the lowered cell sample times and values
// and then call the callback.

PE(advance_sampledeliver);
PE(advance:sampledeliver);
std::vector<sample_record> sample_records;
sample_records.reserve(max_samples_per_call);

Expand Down
36 changes: 25 additions & 11 deletions arbor/profile/profiler.cpp
Expand Up @@ -20,24 +20,24 @@ using util::make_span;
namespace {
// Check whether a string describes a valid profiler region name.
bool is_valid_region_string(const std::string& s) {
if (s.size()==0u || s.front()=='_' || s.back()=='_') return false;
if (s.size()==0u || s.front()==':' || s.back()==':') return false;
return s.find("__") == s.npos;
}

//
// Return a list of the words in the string, using '_' as the delimiter
// Return a list of the words in the string, using ':' as the delimiter
// string, e.g.:
// "communicator" -> {"communicator"}
// "communicator_events" -> {"communicator", "events"}
// "communicator_events_sort" -> {"communicator", "events", "sort"}
std::vector<std::string> split(const std::string& str) {
std::vector<std::string> cont;
std::size_t first = 0;
std::size_t last = str.find('_');
std::size_t last = str.find(':');
while (last != std::string::npos) {
cont.push_back(str.substr(first, last - first));
first = last + 1;
last = str.find('_', first);
last = str.find(':', first);
}
cont.push_back(str.substr(first, last - first));
return cont;
Expand Down Expand Up @@ -91,7 +91,7 @@ class profiler {
// The regions are assigned consecutive indexes in the order that they are
// added to the profiler with calls to `region_index()`, with the first
// region numbered zero.
std::unordered_map<const char*, region_id_type> name_index_;
std::unordered_map<std::string, region_id_type> name_index_;

// The name of each region being recorded, with index stored in name_index_
// is used to index into region_names_.
Expand All @@ -108,10 +108,10 @@ class profiler {

void initialize(task_system_handle& ts);
void enter(region_id_type index);
void enter(const char* name);
void enter(const std::string& name);
void leave();
const std::vector<std::string>& regions() const;
region_id_type region_index(const char* name);
region_id_type region_index(const std::string& name);
profile results() const;

static profiler& get_global_profiler() {
Expand Down Expand Up @@ -186,7 +186,7 @@ void profiler::enter(region_id_type index) {
recorders_[thread_ids_.at(std::this_thread::get_id())].enter(index);
}

void profiler::enter(const char* name) {
void profiler::enter(const std::string& name) {
if (!init_) return;
const auto index = region_index(name);
recorders_[thread_ids_.at(std::this_thread::get_id())].enter(index);
Expand All @@ -197,7 +197,7 @@ void profiler::leave() {
recorders_[thread_ids_.at(std::this_thread::get_id())].leave();
}

region_id_type profiler::region_index(const char* name) {
region_id_type profiler::region_index(const std::string& name) {
// The name_index_ hash table is shared by all threads, so all access
// has to be protected by a mutex.
std::lock_guard<std::mutex> guard(mutex_);
Expand Down Expand Up @@ -249,6 +249,20 @@ profile profiler::results() const {

p.num_threads = recorders_.size();

// Remove elements with count == 0
for(unsigned i=0; i<p.counts.size();) {
if (p.counts[i] != 0) {
++i;
continue;
}
std::swap(p.counts[i], p.counts.back());
std::swap(p.times[i], p.times.back());
std::swap(p.names[i], p.names.back());
p.counts.pop_back();
p.times.pop_back();
p.names.pop_back();
}

return p;
}

Expand Down Expand Up @@ -332,7 +346,7 @@ void profiler_leave() {
profiler::get_global_profiler().leave();
}

region_id_type profiler_region_id(const char* name) {
region_id_type profiler_region_id(const std::string& name) {
if (!is_valid_region_string(name)) {
throw std::runtime_error(std::string("'")+name+"' is not a valid profiler region name.");
}
Expand Down Expand Up @@ -370,7 +384,7 @@ void profiler_enter(region_id_type) {}
profile profiler_summary();
void profiler_print(const profile& prof, float threshold) {};
profile profiler_summary() {return profile();}
region_id_type profiler_region_id(const char*) {return 0;}
region_id_type profiler_region_id(const std::string&) {return 0;}
std::ostream& operator<<(std::ostream& o, const profile&) {return o;}

#endif // ARB_HAVE_PROFILING
Expand Down

0 comments on commit 99b9fcc

Please sign in to comment.