Skip to content

Commit

Permalink
Merge pull request QMCPACK#4826 from ye-luo/query-MSD
Browse files Browse the repository at this point in the history
Change findMSD return type to RefVector.
  • Loading branch information
ye-luo committed Nov 20, 2023
2 parents d16179b + 7e45afc commit e843a30
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ const SPOSet& TrialWaveFunction::getSPOSet(const std::string& name) const
return *spoit->second;
}

std::optional<std::reference_wrapper<MultiSlaterDetTableMethod>> TrialWaveFunction::findMSD() const
RefVector<MultiSlaterDetTableMethod> TrialWaveFunction::findMSD() const
{
RefVector<MultiSlaterDetTableMethod> refs;
for (auto& component : Z)
if (auto* comp_ptr = dynamic_cast<MultiSlaterDetTableMethod*>(component.get()); comp_ptr)
return *comp_ptr;
return std::nullopt;
refs.push_back(*comp_ptr);
return refs;
}

/** return log(|psi|)
Expand Down
6 changes: 2 additions & 4 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,8 @@ class TrialWaveFunction
/// spomap_ reference accessor
const SPOMap& getSPOMap() const { return *spomap_; }

/** find the first MSD WFC if exists
* @return the first found MSD WFC
*/
std::optional<std::reference_wrapper<MultiSlaterDetTableMethod>> findMSD() const;
/// find MSD WFCs if exist
RefVector<MultiSlaterDetTableMethod> findMSD() const;

private:
static void debugOnlyCheckBuffer(WFBufferType& buffer);
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TEST_CASE("TrialWaveFunction_diamondC_1x1x1", "[wavefunction]")
psi.addComponent(jb.buildComponent(jas1));

// should not find MSD
CHECK(!psi.findMSD());
CHECK(psi.findMSD().empty());

// initialize distance tables.
elec_.update();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void test_LiH_msd(const std::string& spo_xml_string,
elec_.update();

auto& twf(*twf_ptr);
CHECK(twf.findMSD());
CHECK(twf.findMSD().size() == 1);
twf.setMassTerm(elec_);
twf.evaluateLog(elec_);

Expand Down

0 comments on commit e843a30

Please sign in to comment.