Skip to content

Commit

Permalink
Clean up plasticity (#1985)
Browse files Browse the repository at this point in the history
1. Fix Python bindings for `recipe::update`
   - *drop* the GIL before handing off to C++
   - tighten exception safety
2. Run plasticity examples with threads; both C++ and Python.
   - C++: Guard against I/O interleaving.
   - Py: Drop spikes from source, prettify reporting.
   - C++: use decor chaining.
3. Modernise PYBIND11_OVERLOAD -> *RIDE (advised since 2.6).
4. No longer do we initialise connectivity twice.
   - Simplify communicator construction.
   - Fix unit tests that needed to two-phase init communicator.
  • Loading branch information
thorstenhater committed Oct 5, 2022
1 parent a60fdc3 commit fc85765
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 50 deletions.
11 changes: 4 additions & 7 deletions arbor/communication/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,12 @@ namespace arb {

communicator::communicator(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map,
execution_context& ctx): num_total_cells_{rec.num_cells()},
num_local_cells_{dom_dec.num_local_cells()},
num_local_groups_{dom_dec.num_groups()},
num_domains_{(cell_size_type) ctx.distributed->size()},
distributed_{ctx.distributed},
thread_pool_{ctx.thread_pool} {
update_connections(rec, dom_dec, source_resolution_map, target_resolution_map);
}
thread_pool_{ctx.thread_pool} {}

void communicator::update_connections(const connectivity& rec,
const domain_decomposition& dom_dec,
Expand Down Expand Up @@ -80,7 +76,6 @@ void communicator::update_connections(const connectivity& rec,
auto gid = gids[i];
gid_infos[i] = gid_info(gid, i, rec.connections_on(gid));
});

cell_local_size_type n_cons =
util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); });
std::vector<unsigned> src_domains;
Expand Down Expand Up @@ -129,7 +124,9 @@ void communicator::update_connections(const connectivity& rec,
// This is num_domains_ independent sorts, so it can be parallelized trivially.
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
[&](cell_size_type i) { util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); });
[&](cell_size_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
});
}

std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) {
Expand Down
2 changes: 0 additions & 2 deletions arbor/communication/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class ARB_ARBOR_API communicator {

explicit communicator(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolver,
const label_resolution_map& target_resolver,
execution_context& ctx);

/// The range of event queues that belong to cells in group i.
Expand Down
3 changes: 1 addition & 2 deletions arbor/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ simulation_state::simulation_state(

source_resolution_map_ = label_resolution_map(std::move(global_sources));
target_resolution_map_ = label_resolution_map(std::move(local_targets));
communicator_ = communicator(rec, ddc_, source_resolution_map_, target_resolution_map_, *ctx_);
communicator_ = communicator(rec, ddc_, *ctx_);
update(rec);
epoch_.reset();
}
Expand Down Expand Up @@ -268,7 +268,6 @@ void simulation_state::update(const connectivity& rec) {
event_lanes_[1].resize(num_local_cells);
}


void simulation_state::reset() {
epoch_ = epoch();

Expand Down
40 changes: 28 additions & 12 deletions example/plasticity/plasticity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct recipe: public arb::recipe {
arb::region all = "(all)"_reg; // Whole cell
arb::cell_size_type n_ = 0; // Cell count

mutable std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections
std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections
// Required but uninteresting methods
recipe(arb::cell_size_type n): n_{n} {}
arb::cell_size_type num_cells() const override { return n_; }
Expand All @@ -48,7 +48,12 @@ struct recipe: public arb::recipe {
return {arb::cable_probe_membrane_voltage{center}};
}
// Look up the (potential) connection to this cell
std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override { return connected[gid]; }
std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
if (auto it = connected.find(gid); it != connected.end()) {
return it->second;
}
return {};
}
// Connect cell `to` to the spike source
void add_connection(arb::cell_gid_type to) { assert(to > 0); connected[to] = {arb::cell_connection({0, src}, {syn}, weight, delay)}; }
// Return the cell at gid
Expand All @@ -57,29 +62,39 @@ struct recipe: public arb::recipe {
if (gid == 0) return arb::spike_source_cell{src, arb::regular_schedule(f_spike)};
// all others are receiving cable cells; single CV w/ HH
arb::segment_tree tree; tree.append(arb::mnpos, {-r_soma, 0, 0, r_soma}, {r_soma, 0, 0, r_soma}, 1);
auto decor = arb::decor{};
decor.paint(all, arb::density("hh", {{"gl", 5}}));
decor.place(center, arb::synapse("expsyn"), syn);
decor.place(center, arb::threshold_detector{-10.0}, det);
decor.set_default(arb::cv_policy_every_segment());
auto decor = arb::decor{}
.paint(all, arb::density("hh", {{"gl", 5}}))
.place(center, arb::synapse("expsyn"), syn)
.place(center, arb::threshold_detector{-10.0}, det)
.set_default(arb::cv_policy_every_segment());
return arb::cable_cell({tree}, {}, decor);
}
};

// For demonstration: Avoid interleaving std::cout in multi-threaded scenarios.
// NEVER do this in HPC!!!
std::mutex mtx;

void sampler(arb::probe_metadata pm, std::size_t n, const arb::sample_record* samples) {
auto* loc = arb::util::any_cast<const arb::mlocation*>(pm.meta);
std::cout << std::fixed << std::setprecision(4);

for (std::size_t i = 0; i<n; ++i) {
std::lock_guard<std::mutex> lock{mtx};
auto* value = arb::util::any_cast<const double*>(samples[i].data);
std::cout << "| " << samples[i].time << " | " << loc->pos << " | " << *value << " |\n";
std::cout << std::fixed << std::setprecision(4)
<< "| " << samples[i].time << " | " << loc->pos << " | " << *value << " |\n";
}
}

void spike_cb(const std::vector<arb::spike>& spikes) {
for(const auto& spike: spikes) std::cout << " * " << spike.source << "@" << spike.time << '\n';
for(const auto& spike: spikes) {
std::lock_guard<std::mutex> lock{mtx};
std::cout << " * " << spike.source << "@" << spike.time << '\n';
}
}

void print_header(double from, double to) {
std::lock_guard<std::mutex> lock{mtx};
std::cout << "\n"
<< "Running simulation from " << from << "ms to " << to << "ms\n"
<< "Spikes are marked: *\n"
Expand All @@ -93,9 +108,10 @@ const double dt = 0.05;
int main(int argc, char** argv) {
auto rec = recipe(3);
rec.add_connection(1);
auto sim = arb::simulation(rec);
auto ctx = arb::make_context(arb::proc_allocation{8, -1});
auto sim = arb::simulation(rec, ctx);
sim.add_sampler(arb::all_probes, arb::regular_schedule(dt), sampler, arb::sampling_policy::exact);
sim.set_local_spike_callback(spike_cb);
sim.set_global_spike_callback(spike_cb);
print_header(0, 1);
sim.run(1.0, dt);
rec.add_connection(2);
Expand Down
22 changes: 12 additions & 10 deletions python/example/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def add_connection_to_spike_source(self, to):
self.connected.add(to)


# Context for multi-threading
ctx = A.context(threads=2)
# Make an unconnected network with 2 cable cells and one spike source,
rec = recipe(3)

# but before setting up anything, connect cable cell gid=1 to spike source gid=0
# and make the simulation of the simple network
#
Expand All @@ -88,24 +89,25 @@ def add_connection_to_spike_source(self, to):
# Note that the connection is just _recorded_ in the recipe, the actual connectivity
# is set up in the simulation construction.
rec.add_connection_to_spike_source(1)
sim = A.simulation(rec)
sim = A.simulation(rec, ctx)
sim.record(A.spike_recording.all)

# then run the simulation for a bit
sim.run(0.25, 0.025)

# update the simulation to
#
# spike_source <gid=0> ----> cable_cell <gid=1>
# \
# ----> cable_cell <gid=2>
rec.add_connection_to_spike_source(2)
sim.update(rec)

# and run the simulation for another bit.
sim.run(0.5, 0.025)

# When finished, print spike times and locations.
print("spikes:")
for sp in sim.spikes():
print(" ", sp)
# when finished, print spike times and locations.
source_spikes = 0
print("Spikes:")
for (gid, lid), t in sim.spikes():
if gid == 0:
source_spikes += 1
else:
print(f" * {t:>8.4f}ms: gid={gid} detector={lid}")
print(f"Source spiked {source_spikes:>5d} times.")
16 changes: 8 additions & 8 deletions python/recipe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,35 @@ class py_recipe {
class py_recipe_trampoline: public py_recipe {
public:
arb::cell_size_type num_cells() const override {
PYBIND11_OVERLOAD_PURE(arb::cell_size_type, py_recipe, num_cells);
PYBIND11_OVERRIDE_PURE(arb::cell_size_type, py_recipe, num_cells);
}

pybind11::object cell_description(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, cell_description, gid);
PYBIND11_OVERRIDE_PURE(pybind11::object, py_recipe, cell_description, gid);
}

arb::cell_kind cell_kind(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD_PURE(arb::cell_kind, py_recipe, cell_kind, gid);
PYBIND11_OVERRIDE_PURE(arb::cell_kind, py_recipe, cell_kind, gid);
}

std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD(std::vector<pybind11::object>, py_recipe, event_generators, gid);
PYBIND11_OVERRIDE(std::vector<pybind11::object>, py_recipe, event_generators, gid);
}

std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD(std::vector<arb::cell_connection>, py_recipe, connections_on, gid);
PYBIND11_OVERRIDE(std::vector<arb::cell_connection>, py_recipe, connections_on, gid);
}

std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
PYBIND11_OVERRIDE(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
}

std::vector<arb::probe_info> probes(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD(std::vector<arb::probe_info>, py_recipe, probes, gid);
PYBIND11_OVERRIDE(std::vector<arb::probe_info>, py_recipe, probes, gid);
}

pybind11::object global_properties(arb::cell_kind kind) const override {
PYBIND11_OVERLOAD(pybind11::object, py_recipe, global_properties, kind);
PYBIND11_OVERRIDE(pybind11::object, py_recipe, global_properties, kind);
}
};

Expand Down
15 changes: 11 additions & 4 deletions python/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ class simulation_shim {
}
}

void update(std::shared_ptr<py_recipe>& rec) {
try {
sim_->update(py_recipe_shim(rec));
}
catch (...) {
py_reset_and_throw();
throw;
}
}

void reset() {
sim_->reset();
spike_record_.clear();
Expand Down Expand Up @@ -94,10 +104,6 @@ class simulation_shim {
sim_->set_binning_policy(policy, bin_interval);
}

void update(std::shared_ptr<py_recipe>& rec) {
sim_->update(py_recipe_shim(rec));
}

void record(spike_recording policy) {
auto spike_recorder = [this](const std::vector<arb::spike>& spikes) {
auto old_size = spike_record_.size();
Expand Down Expand Up @@ -222,6 +228,7 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
pybind11::arg_v("domains", pybind11::none(), "Domain decomposition"),
pybind11::arg_v("seed", 0u, "Random number generator seed"))
.def("update", &simulation_shim::update,
pybind11::call_guard<pybind11::gil_scoped_release>(),
"Rebuild the connection table from recipe::connections_on and the event"
"generators based on recipe::event_generators.",
"recipe"_a)
Expand Down
12 changes: 7 additions & 5 deletions test/unit-distributed/test_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,8 +531,8 @@ TEST(communicator, ring)
auto global_sources = g_context->distributed->gather_cell_labels_and_gids(local_sources);

// construct the communicator
auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets), *g_context);

auto C = communicator(R, D, *g_context);
C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets));
// every cell fires
EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;}));
// last cell in each domain fires
Expand Down Expand Up @@ -638,11 +638,12 @@ TEST(communicator, all2all)
auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, mc_gids});

// construct the communicator
auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids}), *g_context);
auto C = communicator(R, D, *g_context);
C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids}));
auto connections = C.connections();

for (auto i: util::make_span(0, n_global)) {
for (unsigned j = 0; j < n_local; ++j) {
for (auto j: util::make_span(0, n_local)) {
auto c = connections[i*n_local+j];
EXPECT_EQ(i, c.source.gid);
EXPECT_EQ(0u, c.source.index);
Expand Down Expand Up @@ -684,7 +685,8 @@ TEST(communicator, mini_network)
auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, gids});

// construct the communicator
auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids}), *g_context);
auto C = communicator(R, D, *g_context);
C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids}));

// sort connections by source then target
auto connections = C.connections();
Expand Down

0 comments on commit fc85765

Please sign in to comment.