
Commit

Fixed performance issue in LDLT with APP updates. Added prio queues to new scheduler
flipflapflop committed May 8, 2020
1 parent 6cdefd7 commit 2ca5fa6
Showing 5 changed files with 152 additions and 71 deletions.
31 changes: 29 additions & 2 deletions include/sylver/StarPU/hlws.hxx
@@ -5,15 +5,42 @@
#include <string>
#include <vector>
#include <list>
#include <set>
#include <iostream>

namespace sylver {
namespace starpu {

class HeteroLwsScheduler {
public:

struct TaskGreater {
bool operator()(
struct starpu_task const* lhs, struct starpu_task const* rhs) const {

bool is_less = false;

// std::cout << "[HeteroLwsScheduler::TaskLess] "
// << ", lhs task name = " << lhs->cl->name
// << ", lhs task prio = " << lhs->priority
// << ", rhs task name = " << rhs->cl->name
// << ", rhs task prio = " << rhs->priority
// << std::endl;

if (lhs->priority == rhs->priority) {
is_less = true;
}
else {
is_less = (lhs->priority > rhs->priority);
}

return is_less;
}
};

struct WorkerData {
std::list<struct starpu_task *> task_queue;
std::set<struct starpu_task *, TaskGreater> task_prio_queue;
// std::list<struct starpu_task *> task_queue;
bool running;
bool busy;
};
@@ -24,7 +51,7 @@ public:
unsigned last_pop_worker;
unsigned last_push_worker;
};

static void initialize(unsigned sched_ctx_id);

static void finalize(unsigned sched_ctx_id);
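Note (editor's sketch, not part of the commit): the task_prio_queue above is a std::set ordered by TaskGreater, i.e. a priority-ordered container whose begin() always holds a highest-priority task. The standalone C++ sketch below illustrates the idea with a mock Task type; unlike the committed comparator, which returns true for equal priorities so that equal-priority tasks are never treated as duplicates, this sketch breaks ties by address to keep a strict weak ordering.

#include <functional>
#include <iostream>
#include <set>

// Mock stand-in for struct starpu_task: only the fields the comparator needs.
struct Task {
  const char* name;
  int priority;
};

// Order tasks by descending priority; ties broken by address (sketch-only
// choice) so the comparator is a strict weak ordering.
struct TaskGreater {
  bool operator()(const Task* lhs, const Task* rhs) const {
    if (lhs->priority != rhs->priority) return lhs->priority > rhs->priority;
    return std::less<const Task*>{}(lhs, rhs);
  }
};

int main() {
  Task a{"updateN", 2}, b{"factor", 8}, c{"updateT", 2};
  std::set<Task*, TaskGreater> task_prio_queue{&a, &b, &c};
  // Highest priority first: prints "factor" before the two updates.
  for (Task* t : task_prio_queue)
    std::cout << t->name << " (prio " << t->priority << ")\n";
}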
102 changes: 61 additions & 41 deletions src/StarPU/hlws.cxx
@@ -96,9 +96,13 @@ struct starpu_task* HeteroLwsScheduler::pick_task(

starpu_task* task = nullptr;

auto& source_list = sched_data->worker_data[source].task_queue;
auto task_iterator = source_list.begin();
// auto& source_list = sched_data->worker_data[source].task_queue;
auto& source_list = sched_data->worker_data[source].task_prio_queue;

// auto task_iterator = source_list.begin();
auto task_iterator = source_list.begin();

// while (task_iterator != source_list.end()) {
while (task_iterator != source_list.end()) {

task = *task_iterator;
@@ -107,12 +111,12 @@ struct starpu_task* HeteroLwsScheduler::pick_task(

// Debug
// if (starpu_worker_get_type(target) == STARPU_CUDA_WORKER) {
// std::cout << "[HeteroLwsScheduler::pick_task] "
// << source << " -> " << target
// << ", task = " << task
// << ", task name = " << task->cl->name
// << ", task impl = " << starpu_task_get_implementation(task)
// << std::endl;
// std::cout << "[HeteroLwsScheduler::pick_task] "
// << source << " -> " << target
// << ", task = " << task
// << ", task name = " << task->cl->name
// << ", task impl = " << starpu_task_get_implementation(task)
// << std::endl;

// if (starpu_worker_get_type(target) == STARPU_CUDA_WORKER) {
// return NULL;
@@ -157,7 +161,8 @@ struct starpu_task* HeteroLwsScheduler::pop_task(unsigned sched_ctx_id) {

#ifdef STARPU_NON_BLOCKING_DRIVERS
// if (STARPU_RUNNING_ON_VALGRIND || !worker_data.task_queue.empty())
if (!worker_data.task_queue.empty())
// if (!worker_data.task_queue.empty())
if (!worker_data.task_prio_queue.empty())
#endif
{
task = HeteroLwsScheduler::pick_task(sched_data, workerid, workerid);
@@ -218,7 +223,8 @@ struct starpu_task* HeteroLwsScheduler::pop_task(unsigned sched_ctx_id) {

auto& victim_data = sched_data->worker_data[victim];

if (victim_data.running && !victim_data.task_queue.empty()) {
if (victim_data.running && !victim_data.task_prio_queue.empty()) {
// if (victim_data.running && !victim_data.task_queue.empty()) {
// Victim is running and has ready tasks available in its task
// queue

@@ -286,24 +292,34 @@ unsigned HeteroLwsScheduler::select_worker(

// std::cout << "[HeteroLwsScheduler::select_worker]" << std::endl;

// Round robin

unsigned worker;
unsigned nworkers;
int *workerids;
nworkers = starpu_sched_ctx_get_workers_list_raw(sched_ctx_id, &workerids);
// Get current worker id
int workerid = starpu_worker_get_id();

worker = sched_data->last_push_worker;
do {
worker = (worker + 1) % nworkers;
}
while (!sched_data->worker_data[workerids[worker]].running ||
!starpu_worker_can_execute_task_first_impl(workerids[worker], task, NULL));
// Check if current worker can execute task
if (workerid == -1 || // Not a worker
!starpu_sched_ctx_contains_worker(workerid, sched_ctx_id) || // Not part of current context
!starpu_worker_can_execute_task_first_impl(workerid, task, NULL)) {

// Round robin

unsigned worker;
unsigned nworkers;
int *workerids;
nworkers = starpu_sched_ctx_get_workers_list_raw(sched_ctx_id, &workerids);

sched_data->last_push_worker = worker;
worker = sched_data->last_push_worker;
do {
worker = (worker + 1) % nworkers;
}
while (!sched_data->worker_data[workerids[worker]].running ||
!starpu_worker_can_execute_task_first_impl(workerids[worker], task, NULL));

return workerids[worker];
sched_data->last_push_worker = worker;

workerid = workerids[worker];
}

return workerid;
}

int HeteroLwsScheduler::push_task(struct starpu_task *task) {
@@ -316,27 +332,26 @@ int HeteroLwsScheduler::push_task(struct starpu_task *task) {
auto *sched_data = reinterpret_cast<SchedulerData*>(starpu_sched_ctx_get_policy_data(sched_ctx_id));
int workerid;

// #ifdef USE_LOCALITY
// workerid = select_worker_locality(ws, task, sched_ctx_id);
// #else
// workerid = -1;
// #endif
// if (workerid == -1)
workerid = starpu_worker_get_id();
// workerid = starpu_worker_get_id();

/* If the current thread is not a worker but the main thread (-1)
* or the current worker is not in the target context, we find the
* better one to put task on its queue */
if (workerid == -1 || !starpu_sched_ctx_contains_worker(workerid, sched_ctx_id) ||
!starpu_worker_can_execute_task_first_impl(workerid, task, NULL)) {
workerid = select_worker(sched_data, task, sched_ctx_id);
}
// /* If the current thread is not a worker but the main thread (-1)
// * or the current worker is not in the target context, we find the
// * better one to put task on its queue */
// if (workerid == -1 || !starpu_sched_ctx_contains_worker(workerid, sched_ctx_id) ||
// !starpu_worker_can_execute_task_first_impl(workerid, task, NULL)) {
// workerid = HeteroLwsScheduler::select_worker(sched_data, task, sched_ctx_id);
// }

workerid = HeteroLwsScheduler::select_worker(sched_data, task, sched_ctx_id);

assert(workerid != -1);

starpu_worker_lock(workerid);
// STARPU_AYU_ADDTOTASKQUEUE(starpu_task_get_job_id(task), workerid);

// Task break for debugging purpose
starpu_sched_task_break(task);

// record_data_locality(task, workerid);
// STARPU_ASSERT_MSG(worker_data.running, "workerid=%d, ws=%p\n", workerid, sched_data);
// _starpu_prio_deque_push_back_task(&ws->per_worker[workerid].queue, task);
@@ -349,8 +364,12 @@ int HeteroLwsScheduler::push_task(struct starpu_task *task) {
// worker_data.task_queue.end(), task) ==
// worker_data.task_queue.end());

worker_data.task_queue.push_back(task);
// worker_data.task_queue.push_back(task);
auto res = worker_data.task_prio_queue.insert(task);

assert(res.first != worker_data.task_prio_queue.end());
assert(res.second);

// std::cout << "[HeteroLwsScheduler::push_task] task_queue::size() = " << worker_data.task_queue.size() << std::endl;

// locality_pushed_task(ws, task, workerid, sched_ctx_id);
@@ -368,8 +387,8 @@ int HeteroLwsScheduler::push_task(struct starpu_task *task) {
while(workers->has_next(workers, &it))
starpu_wake_worker_relax_light(workers->get_next(workers, &it));
#endif

return 0;

}

void HeteroLwsScheduler::add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworkers) {
@@ -433,7 +452,8 @@ int HeteroLwsScheduler::select_victim(
// ntasks = ws->per_worker[workerids[worker]].queue.ntasks;

auto& worker_data = sched_data->worker_data[workerids[worker]];
ntasks = worker_data.task_queue.size();
// ntasks = worker_data.task_queue.size();
ntasks = worker_data.task_prio_queue.size();

if (
(ntasks > 0) &&
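Editor's sketch (not from the repository): with the per-worker queue now a priority-ordered set, pick_task and the work-stealing path both reduce to the same scan, taking the first (highest-priority) task that the target worker can actually execute and erasing it from the owner's queue. The Task type and the can_execute predicate below are placeholders; the real code operates on struct starpu_task and uses starpu_worker_can_execute_task_first_impl.

#include <functional>
#include <set>

struct Task { int priority; };

// Descending-priority ordering, with an address tie-break for this sketch.
struct TaskGreater {
  bool operator()(const Task* a, const Task* b) const {
    if (a->priority != b->priority) return a->priority > b->priority;
    return std::less<const Task*>{}(a, b);
  }
};

using TaskPrioQueue = std::set<Task*, TaskGreater>;

// Return the highest-priority task that can_execute accepts, removing it from
// the queue, or nullptr if the worker can run none of the queued tasks.
Task* pick_task(TaskPrioQueue& queue,
                const std::function<bool(const Task*)>& can_execute) {
  for (auto it = queue.begin(); it != queue.end(); ++it) {
    if (can_execute(*it)) {
      Task* task = *it;
      queue.erase(it);
      return task;
    }
  }
  return nullptr;  // caller may then pick another victim to steal from
}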
1 change: 1 addition & 0 deletions src/StarPU/kernels_indef.hxx
@@ -593,6 +593,7 @@ namespace starpu {
// any failed rows and release resources storing backup
ublk.restore_if_required(*backup, blk);
// Perform actual update
// std::cout << "[updateT_block_app_cpu_func] cdata_[isrc].nelim = " << (*cdata)[isrc.get_row()].nelim << std::endl;
ublk.update(isrc, jsrc, (*workspaces)[id]);
}

9 changes: 9 additions & 0 deletions src/factor_indef.hxx
@@ -350,6 +350,15 @@ namespace spldlt {

#if defined(SPLDLT_USE_STARPU)

int prio = APPLYN_APP_PRIO;

if (iblk == blk+1) {
// If update operation is on the critical path, increase
// its priority
prio = MAX_APP_PRIO;
}


sylver::spldlt::starpu::insert_applyN_block_app(
dblk.get_hdl(), rblk.get_hdl(),
cdata[blk].get_hdl(),
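Editor's note: the boost above only has an effect because of the TaskGreater ordering introduced in hlws.hxx, where boosted updates sort ahead of default-priority ones in the per-worker set. A minimal restatement of the selection rule, with placeholder values standing in for the SyLVER priority constants:

// Placeholder values; the real APPLYN_APP_PRIO and MAX_APP_PRIO are defined
// elsewhere in SyLVER and are only assumed here to satisfy MAX > APPLYN.
constexpr int APPLYN_APP_PRIO = 2;
constexpr int MAX_APP_PRIO = 4;

// Updates feeding the next panel (iblk == blk + 1) sit on the critical path
// and get the top priority; all other applyN updates keep the default.
inline int applyN_priority(int iblk, int blk) {
  return (iblk == blk + 1) ? MAX_APP_PRIO : APPLYN_APP_PRIO;
}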
80 changes: 52 additions & 28 deletions src/kernels/ldlt_app.hxx
@@ -601,24 +601,32 @@ public:
double beta=1.0, T* upd=nullptr, int ldupd=0) {

if(isrc.i_ == i_ && isrc.j_ == jsrc.j_) {

// Update to right of elim column (UpdateN)
int elim_col = isrc.j_;
if(cdata_[elim_col].nelim == 0) return; // nothing to do
int rfrom = (i_ <= elim_col) ? cdata_[i_].nelim : 0;
int cfrom = (j_ <= elim_col) ? cdata_[j_].nelim : 0;
int ldld = align_lda<T>(block_size_);
T* ld = work.get_ptr<T>(block_size_*ldld);
// NB: we use ld[rfrom] below so alignment matches that of aval[rfrom]
calcLD<spral::ssids::cpu::OP_N>(
nrow()-rfrom, cdata_[elim_col].nelim, &isrc.aval_[rfrom],
lda_, cdata_[elim_col].d, &ld[rfrom], ldld
);
host_gemm(
OP_N, OP_T, nrow()-rfrom, ncol()-cfrom, cdata_[elim_col].nelim,
-1.0, &ld[rfrom], ldld, &jsrc.aval_[cfrom], lda_,
1.0, &aval_[cfrom*lda_+rfrom], lda_
);

// Perform update of trailing submatrix block
if (((ncol()-cfrom)>0) && ((nrow()-rfrom)>0)) {
// Make sure there is work to do before starting computing
// L*D

// NB: we use ld[rfrom] below so alignment matches that of aval[rfrom]
calcLD<spral::ssids::cpu::OP_N>(
nrow()-rfrom, cdata_[elim_col].nelim, &isrc.aval_[rfrom],
lda_, cdata_[elim_col].d, &ld[rfrom], ldld
);
host_gemm(
OP_N, OP_T, nrow()-rfrom, ncol()-cfrom, cdata_[elim_col].nelim,
-1.0, &ld[rfrom], ldld, &jsrc.aval_[cfrom], lda_,
1.0, &aval_[cfrom*lda_+rfrom], lda_
);
}

if(upd && j_==calc_nblk(n_,block_size_)-1) {
// Handle fractional part of upd that "belongs" to this block
int u_ncol = std::min(block_size_-ncol(), m_-n_); // ncol for upd
@@ -652,25 +660,41 @@ public:
int cfrom = (j_ <= elim_col) ? cdata_[j_].nelim : 0;
int ldld = align_lda<T>(block_size_);
T* ld = work.get_ptr<T>(block_size_*ldld);
// NB: we use ld[rfrom] below so alignment matches that of aval[rfrom]
if(isrc.j_==elim_col) {
calcLD<spral::ssids::cpu::OP_N>(
nrow()-rfrom, cdata_[elim_col].nelim,
&isrc.aval_[rfrom], lda_,
cdata_[elim_col].d, &ld[rfrom], ldld
);
} else {
calcLD<spral::ssids::cpu::OP_T>(
nrow()-rfrom, cdata_[elim_col].nelim, &
isrc.aval_[rfrom*lda_], lda_,
cdata_[elim_col].d, &ld[rfrom], ldld

// if (((ncol()-cfrom) == 0) && ((nrow()-rfrom) != 0)) {
// std::cout << "[Block::update]"
// << " elim_col = " << elim_col << ", i = " << i_ << ", j = " << j_
// << ", cdata_[elim_col].nelim = " << cdata_[elim_col].nelim
// << ", nrow()-rfrom = " << nrow()-rfrom << ", ncol()-cfrom = " << ncol()-cfrom
// << ", upd = " << upd
// << std::endl;
// }

// Perform update of left-diagonal block
if (((ncol()-cfrom)>0) && ((nrow()-rfrom)>0)) {
// Make sure there is work to do before starting computing
// L*D

// NB: we use ld[rfrom] below so alignment matches that of aval[rfrom]
if(isrc.j_==elim_col) {
calcLD<spral::ssids::cpu::OP_N>(
nrow()-rfrom, cdata_[elim_col].nelim,
&isrc.aval_[rfrom], lda_,
cdata_[elim_col].d, &ld[rfrom], ldld
);
} else {
calcLD<spral::ssids::cpu::OP_T>(
nrow()-rfrom, cdata_[elim_col].nelim, &
isrc.aval_[rfrom*lda_], lda_,
cdata_[elim_col].d, &ld[rfrom], ldld
);
}
host_gemm(
OP_N, OP_N, nrow()-rfrom, ncol()-cfrom, cdata_[elim_col].nelim,
-1.0, &ld[rfrom], ldld, &jsrc.aval_[cfrom*lda_], lda_,
1.0, &aval_[cfrom*lda_+rfrom], lda_
);
}
host_gemm(
OP_N, OP_N, nrow()-rfrom, ncol()-cfrom, cdata_[elim_col].nelim,
-1.0, &ld[rfrom], ldld, &jsrc.aval_[cfrom*lda_], lda_,
1.0, &aval_[cfrom*lda_+rfrom], lda_
);
}
}

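Editor's sketch (toy column-major loops, not the SyLVER/SSIDS kernels, and D treated as purely diagonal): the guard added in both update branches above only forms L*D and issues the trailing GEMM when the block still has rows and columns to update, matching the new "make sure there is work to do before starting computing L*D" comments.

#include <vector>

// Toy stand-ins for calcLD (SSIDS) and host_gemm (BLAS); simplified and
// illustrative only.
static void calc_ld(int m, int k, const double* l, int ldl,
                    const double* d, double* ld, int ldld) {
  for (int j = 0; j < k; ++j)
    for (int i = 0; i < m; ++i)
      ld[i + j * ldld] = l[i + j * ldl] * d[j];
}

static void apply_update(int m, int n, int k, const double* ld, int ldld,
                         const double* jsrc, int ldj, double* a, int lda) {
  // A <- A - LD * Jsrc^T
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      for (int p = 0; p < k; ++p)
        a[i + j * lda] -= ld[i + p * ldld] * jsrc[j + p * ldj];
}

// Guarded update in the spirit of the change above: skip the whole operation
// when there is no work, so L*D is never formed for an empty update.
void update_block(int m_rem, int n_rem, int nelim,
                  const double* l, int ldl, const double* d,
                  const double* jsrc, int ldj, double* a, int lda) {
  if (m_rem <= 0 || n_rem <= 0 || nelim <= 0) return;  // nothing to do
  std::vector<double> ld(static_cast<size_t>(m_rem) * nelim);
  calc_ld(m_rem, nelim, l, ldl, d, ld.data(), m_rem);
  apply_update(m_rem, n_rem, nelim, ld.data(), m_rem, jsrc, ldj, a, lda);
}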
