Permalink
Browse files

ENH: Submit parameter "load"

  • Loading branch information...
spirali committed Apr 20, 2018
1 parent a16f591 commit ada5cfb0eefb4ef6569cba1eba19c3826b143a71
@@ -164,6 +164,7 @@ message ClientRequest {

// PLAN
optional Plan plan = 2;
optional bool load_checkpoints = 4;

// FETCH + RELEASE
optional int32 id = 3;
@@ -241,7 +241,7 @@ def _process_events(self, on_finished=None, on_data=None):
print(t)
assert 0

def submit_one(self, task):
def submit_one(self, task, load=False):
"""Submits a task to the server and returns a future
Args:
@@ -256,9 +256,9 @@ def submit_one(self, task):
>>> result = client.submit(task3)
>>> print(result.gather())
"""
return self.submit((task,))[0]
return self.submit((task,), load=load)[0]

def submit(self, tasks):
def submit(self, tasks, load=False):
"""Submits tasks to the server and returns list of futures
Args:
@@ -294,7 +294,7 @@ def submit(self, tasks):

msg = ClientRequest()
msg.type = ClientRequest.PLAN

msg.load_checkpoints = load
include_metadata = self.trace_path is not None
msg.plan.id_base = id_base
plan.set_message(
@@ -56,8 +56,8 @@ void ClientConnection::on_message(const char *buffer, size_t size)
case ClientRequest_Type_PLAN: {
logger->debug("Plan received");
const Plan &plan = request.plan();
loom::base::Id id_base = task_manager.add_plan(plan);
logger->info("Plan submitted tasks={}", plan.tasks_size());
loom::base::Id id_base = task_manager.add_plan(plan, request.load_checkpoints());
logger->info("Plan submitted tasks={}, load_checkpoints={}", plan.tasks_size(), request.load_checkpoints());

if (server.get_trace()) {
server.create_file_in_trace_dir(std::to_string(id_base) + ".plan", buffer, size);
@@ -23,34 +23,25 @@ ComputationState::ComputationState(Server &server) : server(server)
void ComputationState::add_node(std::unique_ptr<TaskNode> &&node) {
auto id = node->get_id();

/*
for (TaskNode* input_node : node->get_inputs()) {
input_node->add_next(node.get());
}
if (node->is_ready()) {
pending_nodes.insert(node.get());
}*/

auto result = nodes.insert(std::make_pair(id, std::move(node)));
assert(result.second); // Check that ID is fresh
}

void ComputationState::plan_node(TaskNode &node, std::vector<TaskNode *> &to_load) {
void ComputationState::plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode *> &to_load) {
if (node.is_planned()) {
return;
}
node.set_planned();

if (!node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) {
if (load_checkpoints && !node.get_task_def().checkpoint_path.empty() && loom::base::file_exists(node.get_task_def().checkpoint_path.c_str())) {
node.set_checkpoint();
to_load.push_back(&node);
return;
}

int remaining_inputs = 0;
for (TaskNode *input_node : node.get_inputs()) {
plan_node(*input_node, to_load);
plan_node(*input_node, load_checkpoints, to_load);
if (!input_node->is_computed()) {
remaining_inputs += 1;
input_node->add_next(&node);
@@ -269,7 +260,7 @@ void ComputationState::make_expansion(std::vector<std::string> &configs,
}
}*/

loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode*> &to_load)
loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode*> &to_load)
{
auto task_size = plan.tasks_size();
assert(plan.has_id_base());
@@ -317,7 +308,7 @@ loom::base::Id ComputationState::add_plan(const loom::pb::comm::Plan &plan, std:

auto new_node = std::make_unique<TaskNode>(id, std::move(def));
if (is_result) {
plan_node(*new_node.get(), to_load);
plan_node(*new_node.get(), load_checkpoints, to_load);
}
add_node(std::move(new_node));
}
@@ -43,7 +43,7 @@ class ComputationState {

int get_n_data_objects() const;

loom::base::Id add_plan(const loom::pb::comm::Plan &plan, std::vector<TaskNode *> &to_load);
loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints, std::vector<TaskNode *> &to_load);
void test_ready_nodes(std::vector<loom::base::Id> ids);

loom::base::Id pop_result_client_id(loom::base::Id id);
@@ -61,7 +61,7 @@ class ComputationState {
std::unique_ptr<TaskNode> pop_node(loom::base::Id id);
void clear_all();
void add_pending_node(TaskNode &node);
void plan_node(TaskNode &node, std::vector<TaskNode*> &to_load);
void plan_node(TaskNode &node, bool load_checkpoints, std::vector<TaskNode*> &to_load);

private:
std::unordered_map<loom::base::Id, std::unique_ptr<TaskNode>> nodes;
@@ -20,10 +20,10 @@ TaskManager::TaskManager(Server &server)
{
}

loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan)
loom::base::Id TaskManager::add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints)
{
std::vector<TaskNode*> to_load;
loom::base::Id id_base = cstate.add_plan(plan, to_load);
loom::base::Id id_base = cstate.add_plan(plan, load_checkpoints, to_load);
for (TaskNode *node : to_load) {
WorkerConnection *wc = random_worker();
node->set_as_loading(wc);
@@ -23,7 +23,7 @@ class TaskManager
cstate.add_node(std::move(node));
}*/

loom::base::Id add_plan(const loom::pb::comm::Plan &plan);
loom::base::Id add_plan(const loom::pb::comm::Plan &plan, bool load_checkpoints);

void on_task_finished(loom::base::Id id, size_t size, size_t length, WorkerConnection *wc, bool checkpointing);
void on_data_transferred(loom::base::Id id, WorkerConnection *wc);
@@ -140,12 +140,12 @@ def client(self):
self.check_stats()
return self._client

def submit_and_gather(self, tasks, check=True):
def submit_and_gather(self, tasks, check=True, load=False):
if isinstance(tasks, Task):
future = self.client.submit_one(tasks)
future = self.client.submit_one(tasks, load=load)
return self.client.gather_one(future)
else:
futures = self.client.submit(tasks)
futures = self.client.submit(tasks, load=load)
return self.client.gather(futures)
if check:
self.check_final_state()
@@ -51,4 +51,4 @@ def test_checkpoint_load(loom_env):
x4 = tasks.merge((x3, x1, x2, t1, t2, t3))
x4.checkpoint_path = path5

assert loom_env.submit_and_gather(x4) == b'[4][1][2]$t3$[3][1][2]$t3$'
assert loom_env.submit_and_gather(x4, load=True) == b'[4][1][2]$t3$[3][1][2]$t3$'
@@ -319,7 +319,7 @@ static std::vector<TaskNode*> nodes(ComputationState &s, std::vector<loom::base:

static void add_plan(ComputationState &s, const loom::pb::comm::Plan &plan) {
std::vector<TaskNode*> to_load;
s.add_plan(plan, to_load);
s.add_plan(plan, false, to_load);
assert(to_load.empty());
}

0 comments on commit ada5cfb

Please sign in to comment.