Skip to content

Commit

Permalink
Merge pull request #4849 from ye-luo/hybridrep-builder
Browse files Browse the repository at this point in the history
Refactor SplineSetReader and HybridRepSetReader
  • Loading branch information
prckent committed Nov 30, 2023
2 parents dfacf43 + d9f3461 commit 4189f84
Show file tree
Hide file tree
Showing 13 changed files with 350 additions and 289 deletions.
31 changes: 7 additions & 24 deletions src/QMCWaveFunctions/BsplineFactory/BsplineReaderBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,12 @@
namespace qmcplusplus
{
BsplineReaderBase::BsplineReaderBase(EinsplineSetBuilder* e)
: mybuilder(e), MeshSize(0), checkNorm(true), saveSplineCoefs(false), rotate(true)
: mybuilder(e), checkNorm(true), saveSplineCoefs(false), rotate(true)
{
myComm = mybuilder->getCommunicator();
}

void BsplineReaderBase::get_psi_g(int ti, int spin, int ib, Vector<std::complex<double>>& cG)
{
int ncg = 0;
if (myComm->rank() == 0)
{
std::string path = psi_g_path(ti, spin, ib);
mybuilder->H5File.read(cG, path);
ncg = cG.size();
}
myComm->bcast(ncg);
if (ncg != mybuilder->MaxNumGvecs)
{
APP_ABORT("Failed : ncg != MaxNumGvecs");
}
myComm->bcast(cG);
}

BsplineReaderBase::~BsplineReaderBase() {}
BsplineReaderBase::~BsplineReaderBase() = default;

inline std::string make_bandinfo_filename(const std::string& root,
int spin,
Expand Down Expand Up @@ -216,11 +199,11 @@ void BsplineReaderBase::initialize_spo2band(int spin,
<< std::endl;
for (int i = 0; i < bigspace.size(); ++i)
{
int ti = bigspace[i].TwistIndex;
int bi = bigspace[i].BandIndex;
double e = bigspace[i].Energy;
int nd = (bigspace[i].MakeTwoCopies) ? 2 : 1;
PosType k = mybuilder->PrimCell.k_cart(mybuilder->primcell_kpoints[ti]);
int ti = bigspace[i].TwistIndex;
int bi = bigspace[i].BandIndex;
double e = bigspace[i].Energy;
int nd = (bigspace[i].MakeTwoCopies) ? 2 : 1;
PosType k = mybuilder->PrimCell.k_cart(mybuilder->primcell_kpoints[ti]);
int s_size = std::snprintf(s.data(), s.size(), "%8d %8d %8d %8d %12.6f %7.4f %7.4f %7.4f %7.4f %7.4f %7.4f %6d\n",
i, ns, ti, bi, e, k[0], k[1], k[2], mybuilder->primcell_kpoints[ti][0],
mybuilder->primcell_kpoints[ti][1], mybuilder->primcell_kpoints[ti][2], nd);
Expand Down
69 changes: 34 additions & 35 deletions src/QMCWaveFunctions/BsplineFactory/BsplineReaderBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ struct BsplineReaderBase
///communicator
Communicate* myComm;
///mesh size
TinyVector<int, 3> MeshSize;
///check the norm of orbitals
bool checkNorm;
///save spline coefficients to storage
Expand All @@ -56,25 +55,27 @@ struct BsplineReaderBase

virtual ~BsplineReaderBase();

std::string getSplineDumpFileName(const BandInfoGroup& bandgroup) const
{
auto& MeshSize = mybuilder->MeshSize;
std::ostringstream oo;
oo << bandgroup.myName << ".g" << MeshSize[0] << "x" << MeshSize[1] << "x" << MeshSize[2] << ".h5";
return oo.str();
}

/** read gvectors and set the mesh, and prepare for einspline
*/
template<typename GT, typename BCT>
inline bool set_grid(const TinyVector<int, 3>& halfg, GT* xyz_grid, BCT* xyz_bc)
inline bool set_grid(const TinyVector<int, 3>& halfg, GT* xyz_grid, BCT* xyz_bc) const
{
//This sets MeshSize from the input file
bool havePsig = mybuilder->ReadGvectors_ESHDF();

//If this MeshSize is not initialized, use the meshsize set by the input based on FFT grid and meshfactor
if (MeshSize[0] == 0)
MeshSize = mybuilder->MeshSize;

app_log() << " Using meshsize=" << MeshSize << "\n vs input meshsize=" << mybuilder->MeshSize << std::endl;

for (int j = 0; j < 3; ++j)
{
xyz_grid[j].start = 0.0;
xyz_grid[j].end = 1.0;
xyz_grid[j].num = MeshSize[j];
xyz_grid[j].num = mybuilder->MeshSize[j];

if (halfg[j])
{
Expand All @@ -96,76 +97,74 @@ struct BsplineReaderBase
/** initialize twist-related data for N orbitals
*/
template<typename SPE>
inline void check_twists(SPE* bspline, const BandInfoGroup& bandgroup)
inline void check_twists(SPE& bspline, const BandInfoGroup& bandgroup) const
{
//init(orbitalSet,bspline);
bspline->PrimLattice = mybuilder->PrimCell;
bspline->GGt = dot(transpose(bspline->PrimLattice.G), bspline->PrimLattice.G);
bspline.PrimLattice = mybuilder->PrimCell;
bspline.GGt = dot(transpose(bspline.PrimLattice.G), bspline.PrimLattice.G);

int N = bandgroup.getNumDistinctOrbitals();
int numOrbs = bandgroup.getNumSPOs();

bspline->setOrbitalSetSize(numOrbs);
bspline->resizeStorage(N, N);
bspline.setOrbitalSetSize(numOrbs);
bspline.resizeStorage(N, N);

bspline->first_spo = bandgroup.getFirstSPO();
bspline->last_spo = bandgroup.getLastSPO();
bspline.first_spo = bandgroup.getFirstSPO();
bspline.last_spo = bandgroup.getLastSPO();

int num = 0;
const std::vector<BandInfo>& cur_bands = bandgroup.myBands;
for (int iorb = 0; iorb < N; iorb++)
{
int ti = cur_bands[iorb].TwistIndex;
bspline->kPoints[iorb] = mybuilder->PrimCell.k_cart(-mybuilder->primcell_kpoints[ti]);
bspline->MakeTwoCopies[iorb] = (num < (numOrbs - 1)) && cur_bands[iorb].MakeTwoCopies;
num += bspline->MakeTwoCopies[iorb] ? 2 : 1;
int ti = cur_bands[iorb].TwistIndex;
bspline.kPoints[iorb] = mybuilder->PrimCell.k_cart(-mybuilder->primcell_kpoints[ti]);
bspline.MakeTwoCopies[iorb] = (num < (numOrbs - 1)) && cur_bands[iorb].MakeTwoCopies;
num += bspline.MakeTwoCopies[iorb] ? 2 : 1;
}

app_log() << "NumDistinctOrbitals " << N << " numOrbs = " << numOrbs << std::endl;

bspline->HalfG = 0;
bspline.HalfG = 0;
TinyVector<int, 3> bconds = mybuilder->TargetPtcl.getLattice().BoxBConds;
if (!bspline->isComplex())
if (!bspline.isComplex())
{
//no k-point folding, single special k point (G, L ...)
TinyVector<double, 3> twist0 = mybuilder->primcell_kpoints[bandgroup.TwistIndex];
for (int i = 0; i < 3; i++)
if (bconds[i] && ((std::abs(std::abs(twist0[i]) - 0.5) < 1.0e-8)))
bspline->HalfG[i] = 1;
bspline.HalfG[i] = 1;
else
bspline->HalfG[i] = 0;
bspline.HalfG[i] = 0;
app_log() << " TwistIndex = " << cur_bands[0].TwistIndex << " TwistAngle " << twist0 << std::endl;
app_log() << " HalfG = " << bspline->HalfG << std::endl;
app_log() << " HalfG = " << bspline.HalfG << std::endl;
}
app_log().flush();
}

/** return the path name in hdf5
* @param ti twist index
* @param spin spin index
* @param ib band index
*/
inline std::string psi_g_path(int ti, int spin, int ib)
inline std::string psi_g_path(int ti, int spin, int ib) const
{
std::ostringstream path;
path << "/electrons/kpoint_" << ti << "/spin_" << spin << "/state_" << ib << "/psi_g";
return path.str();
}

/** return the path name in hdf5
* @param ti twist index
* @param spin spin index
* @param ib band index
*/
inline std::string psi_r_path(int ti, int spin, int ib)
inline std::string psi_r_path(int ti, int spin, int ib) const
{
std::ostringstream path;
path << "/electrons/kpoint_" << ti << "/spin_" << spin << "/state_" << ib << "/psi_r";
return path.str();
}

/** read/bcast psi_g
* @param ti twist index
* @param spin spin index
* @param ib band index
* @param cG psi_g as stored in hdf5
*/
void get_psi_g(int ti, int spin, int ib, Vector<std::complex<double>>& cG);

/** create the actual spline sets
*/
virtual std::unique_ptr<SPOSet> create_spline_set(const std::string& my_name,
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/BsplineSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class BsplineSet : public SPOSet
}

template<class BSPLINESPO>
friend struct SplineSetReader;
friend class SplineSetReader;
friend struct BsplineReaderBase;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ std::unique_ptr<SPOSet> EinsplineSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
auto OrbitalSet = MixedSplineReader->create_spline_set(spinSet, spo_cur);
if (!OrbitalSet)
myComm->barrier_and_abort("Failed to create SPOSet*");
app_log() << "Time spent in creating B-spline SPOs " << mytimer.elapsed() << "sec" << std::endl;
app_log() << "Time spent in creating B-spline SPOs " << mytimer.elapsed() << " sec" << std::endl;
OrbitalSet->finalizeConstruction();
SPOSetMap[aset] = OrbitalSet.get();
return OrbitalSet;
Expand Down
15 changes: 2 additions & 13 deletions src/QMCWaveFunctions/BsplineFactory/HybridRepCplx.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ template<typename SPLINEBASE>
class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typename SPLINEBASE::DataType>
{
public:
using SplineBase = SPLINEBASE;
using HYBRIDBASE = HybridRepCenterOrbitals<typename SPLINEBASE::DataType>;
using ST = typename SPLINEBASE::DataType;
using PointType = typename SPLINEBASE::PointType;
Expand Down Expand Up @@ -70,12 +71,6 @@ class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typenam

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<HybridRepCplx>(*this); }

inline void resizeStorage(size_t n, size_t nvals)
{
SPLINEBASE::resizeStorage(n, nvals);
HYBRIDBASE::resizeStorage(myV.size());
}

void bcast_tables(Communicate* comm)
{
SPLINEBASE::bcast_tables(comm);
Expand All @@ -92,12 +87,6 @@ class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typenam

bool write_splines(hdf_archive& h5f) { return HYBRIDBASE::write_splines(h5f) && SPLINEBASE::write_splines(h5f); }

inline void flush_zero()
{
//SPLINEBASE::flush_zero();
HYBRIDBASE::flush_zero();
}

void evaluateValue(const ParticleSet& P, const int iat, ValueVector& psi) override
{
HYBRIDBASE::evaluate_v(P, iat, myV, info);
Expand Down Expand Up @@ -242,7 +231,7 @@ class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
template<class BSPLINESPO>
friend class HybridRepSetReader;
template<class BSPLINESPO>
friend struct SplineSetReader;
friend class SplineSetReader;
friend struct BsplineReaderBase;
};

Expand Down
15 changes: 2 additions & 13 deletions src/QMCWaveFunctions/BsplineFactory/HybridRepReal.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ template<typename SPLINEBASE>
class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typename SPLINEBASE::DataType>
{
public:
using SplineBase = SPLINEBASE;
using HYBRIDBASE = HybridRepCenterOrbitals<typename SPLINEBASE::DataType>;
using ST = typename SPLINEBASE::DataType;
using PointType = typename SPLINEBASE::PointType;
Expand Down Expand Up @@ -72,12 +73,6 @@ class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typenam

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<HybridRepReal>(*this); }

inline void resizeStorage(size_t n, size_t nvals)
{
SPLINEBASE::resizeStorage(n, nvals);
HYBRIDBASE::resizeStorage(myV.size());
}

void bcast_tables(Communicate* comm)
{
SPLINEBASE::bcast_tables(comm);
Expand All @@ -90,12 +85,6 @@ class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
HYBRIDBASE::gather_atomic_tables(comm, SPLINEBASE::offset);
}

inline void flush_zero()
{
//SPLINEBASE::flush_zero();
HYBRIDBASE::flush_zero();
}

bool read_splines(hdf_archive& h5f) { return HYBRIDBASE::read_splines(h5f) && SPLINEBASE::read_splines(h5f); }

bool write_splines(hdf_archive& h5f) { return HYBRIDBASE::write_splines(h5f) && SPLINEBASE::write_splines(h5f); }
Expand Down Expand Up @@ -249,7 +238,7 @@ class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
template<class BSPLINESPO>
friend class HybridRepSetReader;
template<class BSPLINESPO>
friend struct SplineSetReader;
friend class SplineSetReader;
friend struct BsplineReaderBase;
};

Expand Down

0 comments on commit 4189f84

Please sign in to comment.