Skip to content

Commit

Permalink
correct RAJA implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nychiang committed Mar 13, 2023
1 parent a1d8a15 commit eb2c094
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
31 changes: 16 additions & 15 deletions src/LinAlg/hiopVectorRajaImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,10 +2390,10 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
size_type n_in = dl_vec.get_local_size();
size_type n_cons = n_eq + n_in;

hiopVectorInt* idx_eq_cumsum = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
hiopVectorInt* idx_in_cumsum = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
index_type* find_eq = idx_eq_cumsum->local_data();
index_type* find_in = idx_in_cumsum->local_data();
hiopVectorInt* find_eq = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
hiopVectorInt* find_in = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
index_type* idx_eq_cumsum = find_eq->local_data();
index_type* idx_in_cumsum = find_in->local_data();

RAJA::ReduceSum< hiop_raja_reduce, int > sum_n_bnds_low(0);
RAJA::ReduceSum< hiop_raja_reduce, int > sum_n_bnds_upp(0);
Expand All @@ -2405,11 +2405,11 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
RAJA_LAMBDA(RAJA::Index_type i)
{
if(gl[i] == gu[i]) {
find_eq[i] = 1;
find_in[i] = 0;
idx_eq_cumsum[i] = 1;
idx_in_cumsum[i] = 0;
} else {
find_eq[i] = 0;
find_in[i] = 1;
idx_eq_cumsum[i] = 0;
idx_in_cumsum[i] = 1;
}
}
);
Expand All @@ -2421,8 +2421,6 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
// (0,1,1) -- (1,1,2) after scan
// map [1] [0,2]

index_type* nnz_cumsum = idx_cumsum_->local_data();
index_type v_n_local = v.n_local_;
RAJA::forall<hiop_raja_exec>(
RAJA::RangeSegment(0, n_cons),
RAJA_LAMBDA(RAJA::Index_type i)
Expand Down Expand Up @@ -2464,9 +2462,10 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
} else {
assert(idx_in_cumsum[i] == idx_in_cumsum[i-1] + 1);
int in_idx = idx_in_cumsum[i] - 1;
incon_map[in_idx] = cons_type[i];
incon_type[in_idx] = cons_type[i];
dl[in_idx] = gl[i];
du[in_idx] = gu[i];
incon_map[in_idx] = i;

if(gl[i]>-1e20) {
idl[in_idx] = 1.;
Expand All @@ -2488,11 +2487,13 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
}
);

n_bnds_low = sum_n_bnds_low.get();
n_bnds_upp = sum_n_bnds_upp.get();
n_bnds_lu = sum_n_bnds_lu.get();
n_ineq_low = sum_n_bnds_low.get();
n_ineq_upp = sum_n_bnds_upp.get();
n_ineq_lu = sum_n_bnds_lu.get();

delete find_eq;
delete find_in;

return true;
}


Expand Down
8 changes: 7 additions & 1 deletion src/Optimization/hiopNlpFormulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,13 @@ bool hiopNlpFormulation::process_constraints()
cons_ineq_type_,
cons_type);



hiopVectorIntSeq cons_eq_mapping_host(n_cons_eq_);
hiopVectorIntSeq cons_ineq_mapping_host(n_cons_ineq_);

cons_eq_mapping_->copy_to_vectorseq(cons_eq_mapping_host);
cons_ineq_mapping_->copy_to_vectorseq(cons_ineq_mapping_host);

/* delete the temporary buffers */
delete gl;
delete gu;
Expand Down

0 comments on commit eb2c094

Please sign in to comment.