Skip to content

Commit

Permalink
Merge pull request #4988 from ye-luo/move-resource
Browse files Browse the repository at this point in the history
Move resource out of delayed update engines into DiracDeterminantBatched
  • Loading branch information
ye-luo committed May 15, 2024
2 parents 2a9196c + c02a170 commit ea046d3
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 206 deletions.
41 changes: 24 additions & 17 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct DiracDeterminantBatched<DET_ENGINE>::DiracDeterminantBatchedMultiWalkerRe
OffloadMatrix<ComplexType> mw_dspin;
/// references to the per-DDB psiMinv matrices in a crowd
RefVector<DualMatrix<Value>> psiMinv_refs;
/// multi-walker resource of the determinant engine (DET_ENGINE), passed to the DET_ENGINE::mw_* calls
typename DET_ENGINE::MultiWalkerResource engine_rsc;
};

/** constructor
Expand Down Expand Up @@ -116,7 +118,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_invertPsiM(const RefVectorWithLeade
mw_res.log_values[iw] = {0.0, 0.0};
}

wfc_leader.accel_inverter_.getResource().mw_invertTranspose(wfc_leader.det_engine_.getLAhandles(), logdetT_list,
wfc_leader.accel_inverter_.getResource().mw_invertTranspose(mw_res.engine_rsc.getLAhandles(), logdetT_list,
a_inv_list, mw_res.log_values);

for (int iw = 0; iw < nw; ++iw)
Expand Down Expand Up @@ -210,7 +212,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_evalGrad(const RefVectorWithLeader<
engine_list.push_back(det.det_engine_);
}

DET_ENGINE::mw_evalGrad(engine_list, mw_res.psiMinv_refs, dpsiM_row_list, WorkingIndex, grad_now);
DET_ENGINE::mw_evalGrad(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
dpsiM_row_list, WorkingIndex, grad_now);

#ifndef NDEBUG
for (int iw = 0; iw < nw; iw++)
Expand Down Expand Up @@ -285,8 +288,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_evalGradWithSpin(
engine_list.push_back(det.det_engine_);
}

DET_ENGINE::mw_evalGradWithSpin(engine_list, mw_res.psiMinv_refs, dpsiM_row_list, mw_dspin, WorkingIndex, grad_now,
spingrad_now);
DET_ENGINE::mw_evalGradWithSpin(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
dpsiM_row_list, mw_dspin, WorkingIndex, grad_now, spingrad_now);

#ifndef NDEBUG
for (int iw = 0; iw < nw; iw++)
Expand Down Expand Up @@ -344,7 +347,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_ratioGrad(const RefVectorWithLeader
}

auto psiMinv_row_dev_ptr_list =
DET_ENGINE::mw_getInvRow(engine_list, mw_res.psiMinv_refs, WorkingIndex, !Phi->isOMPoffload());
DET_ENGINE::mw_getInvRow(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
Expand Down Expand Up @@ -395,7 +399,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_ratioGradWithSpin(
}

auto psiMinv_row_dev_ptr_list =
DET_ENGINE::mw_getInvRow(engine_list, mw_res.psiMinv_refs, WorkingIndex, !Phi->isOMPoffload());
DET_ENGINE::mw_getInvRow(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
Expand Down Expand Up @@ -522,11 +527,12 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_accept_rejectMove(
det.curRatio = 1.0;
}

DET_ENGINE::mw_accept_rejectRow(engine_list, mw_res.psiMinv_refs, WorkingIndex, psiM_g_dev_ptr_list,
psiM_l_dev_ptr_list, isAccepted, phi_vgl_v, ratios_local);
DET_ENGINE::mw_accept_rejectRow(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
WorkingIndex, psiM_g_dev_ptr_list, psiM_l_dev_ptr_list, isAccepted, phi_vgl_v,
ratios_local);

if (!safe_to_delay)
DET_ENGINE::mw_updateInvMat(engine_list, mw_res.psiMinv_refs);
DET_ENGINE::mw_updateInvMat(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs);
}

/** move was rejected. copy the real container to the temporary to move on
Expand Down Expand Up @@ -568,14 +574,15 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_completeUpdates(

{
ScopedTimer update(UpdateTimer);
DET_ENGINE::mw_updateInvMat(engine_list, mw_res.psiMinv_refs);
DET_ENGINE::mw_updateInvMat(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs);
}

{ // transfer dpsiM, d2psiM, psiMinv to host
ScopedTimer d2h(D2HTimer);

// this call also completes all the device copying of dpsiM, d2psiM before the target update
DET_ENGINE::mw_transferAinv_D2H(engine_list, mw_res.psiMinv_refs);
DET_ENGINE::mw_transferAinv_D2H(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc,
mw_res.psiMinv_refs);

if (UpdateMode == ORB_PBYP_PARTIAL)
{
Expand All @@ -588,7 +595,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_completeUpdates(
}

// transfer device to host, total size 4, g(3) + l(1), skipping v
DET_ENGINE::mw_transferVGL_D2H(wfc_leader.det_engine_, psiM_vgl_list, 1, 4);
DET_ENGINE::mw_transferVGL_D2H(wfc_leader.det_engine_, wfc_leader.mw_res_handle_.getResource().engine_rsc,
psiM_vgl_list, 1, 4);
}
}
}
Expand Down Expand Up @@ -784,7 +792,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_calcRatio(const RefVectorWithLeader
}

auto psiMinv_row_dev_ptr_list =
DET_ENGINE::mw_getInvRow(engine_list, mw_res.psiMinv_refs, WorkingIndex, !Phi->isOMPoffload());
DET_ENGINE::mw_getInvRow(engine_list, wfc_leader.mw_res_handle_.getResource().engine_rsc, mw_res.psiMinv_refs,
WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
Expand Down Expand Up @@ -1159,7 +1168,8 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_recompute(const RefVectorWithLeader
}

// transfer host to device, total size 4, g(3) + l(1), skipping v
DET_ENGINE::mw_transferVGL_H2D(wfc_leader.det_engine_, psiM_vgl_list, 1, 4);
DET_ENGINE::mw_transferVGL_H2D(wfc_leader.det_engine_, wfc_leader.mw_res_handle_.getResource().engine_rsc,
psiM_vgl_list, 1, 4);
}
}

Expand Down Expand Up @@ -1199,7 +1209,6 @@ void DiracDeterminantBatched<DET_ENGINE>::createResource(ResourceCollection& col
{
collection.addResource(std::make_unique<DiracDeterminantBatchedMultiWalkerResource>());
Phi->createResource(collection);
det_engine_.createResource(collection);
collection.addResource(std::make_unique<typename DET_ENGINE::DetInverter>());
}

Expand All @@ -1221,7 +1230,6 @@ void DiracDeterminantBatched<DET_ENGINE>::acquireResource(
mw_res.psiMinv_refs.push_back(det.psiMinv_);
}
wfc_leader.Phi->acquireResource(collection, phi_list);
wfc_leader.det_engine_.acquireResource(collection);
wfc_leader.accel_inverter_ = collection.lendResource<typename DET_ENGINE::DetInverter>();
}

Expand All @@ -1240,7 +1248,6 @@ void DiracDeterminantBatched<DET_ENGINE>::releaseResource(
phi_list.push_back(*det.Phi);
}
wfc_leader.Phi->releaseResource(collection, phi_list);
wfc_leader.det_engine_.releaseResource(collection);
collection.takebackResource(wfc_leader.accel_inverter_);
}

Expand Down

0 comments on commit ea046d3

Please sign in to comment.