Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions client/Dimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,50 @@ double Dimer::calcRotationalForceReturnCurvature(AtomMatrix &rotationalForce) {
posDimer = posCenter + direction * params.main_options.finiteDifference;
}

// Obtain forces for dimer and center (parallel if thread-safe)
// Obtain forces for dimer and center
matterDimer->setPositions(posDimer);
AtomMatrix forceA, forceCenter;
bool canParallel = pot->isThreadSafe() || pot->needsPerImageInstance();
if (params.main_options.parallel && canParallel) {
std::thread t([&] { forceA = matterDimer->getForces(); });
if (pot->supportsBatchEvaluation()) {
// Only batch systems that actually need recomputation.
bool centerDirty = matterCenter->needsForceUpdate();
bool dimerDirty = matterDimer->needsForceUpdate();

if (centerDirty && dimerDirty) {
// Both need eval -- batch together
auto nrs0 = matterCenter->getAtomicNrs();
auto nrs1 = matterDimer->getAtomicNrs();
auto box0 = matterCenter->getCell();
auto box1 = matterDimer->getCell();
const double *posVec[] = {matterCenter->getPositions().data(),
posDimer.data()};
const int *nrsVec[] = {nrs0.data(), nrs1.data()};
double *frcVec[] = {matterCenter->forcesData(),
matterDimer->forcesData()};
double energies[2], vars[2];
const double *boxVec[] = {box0.data(), box1.data()};
pot->forceBatch(2, nAtoms, posVec, nrsVec, frcVec, energies, vars,
boxVec);
matterCenter->setComputedPotential(energies[0], vars[0]);
matterDimer->setComputedPotential(energies[1], vars[1]);
} else if (dimerDirty) {
// Only dimer moved -- eval just dimer, center is cached
auto nrs = matterDimer->getAtomicNrs();
auto box = matterDimer->getCell();
const double *posVec[] = {posDimer.data()};
const int *nrsVec[] = {nrs.data()};
double *frcVec[] = {matterDimer->forcesData()};
double energies[1], vars[1];
const double *boxVec[] = {box.data()};
pot->forceBatch(1, nAtoms, posVec, nrsVec, frcVec, energies, vars,
boxVec);
matterDimer->setComputedPotential(energies[0], vars[0]);
} else if (centerDirty) {
// Only center moved (rare)
matterCenter->getForces(); // through computePotential
}
// else: both cached, nothing to do
forceCenter = matterCenter->getForces();
t.join();
forceA = matterDimer->getForces();
} else {
forceA = matterDimer->getForces();
forceCenter = matterCenter->getForces();
Expand Down
49 changes: 44 additions & 5 deletions client/ImprovedDimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,51 @@ void ImprovedDimer::compute(std::shared_ptr<Matter> matter,
x1_r = x0_r + tau * delta;
}

// Calculate gradients on x0 and x1 (parallel if possible)
// Calculate gradients on x0 and x1.
// Prefer batched evaluation when the potential supports it (single
// model.forward() call for both replicas, e.g. MetatomicPotential on GPU).
// Else fall back to thread-parallel when the potential is thread-safe or
// wants per-image instances. Otherwise sequential.
VectorXd g0, g1;
bool canParallel = pot->isThreadSafe() || pot->needsPerImageInstance();
if (params.main_options.parallel && canParallel) {
// std::thread instead of std::jthread (Apple Clang libc++). Use a guard
// so an exception from the foreground call still joins t0 before rethrow.
bool canParallel =
pot->isSharedInstanceThreadSafe() || pot->needsPerImageInstance();
if (pot->supportsBatchEvaluation()) {
long n = x0->numberOfAtoms();
bool x0dirty = x0->needsForceUpdate();
bool x1dirty = x1->needsForceUpdate();

if (x0dirty && x1dirty) {
auto nrs0 = x0->getAtomicNrs();
auto nrs1 = x1->getAtomicNrs();
auto box0 = x0->getCell();
auto box1 = x1->getCell();
const double *posVec[] = {x0->getPositions().data(),
x1->getPositions().data()};
const int *nrsVec[] = {nrs0.data(), nrs1.data()};
double *frcVec[] = {x0->forcesData(), x1->forcesData()};
double energies[2], vars[2];
const double *boxVec[] = {box0.data(), box1.data()};
pot->forceBatch(2, n, posVec, nrsVec, frcVec, energies, vars, boxVec);
x0->setComputedPotential(energies[0], vars[0]);
x1->setComputedPotential(energies[1], vars[1]);
} else if (x1dirty) {
auto nrs = x1->getAtomicNrs();
auto box = x1->getCell();
const double *posVec[] = {x1->getPositions().data()};
const int *nrsVec[] = {nrs.data()};
double *frcVec[] = {x1->forcesData()};
double energies[1], vars[1];
const double *boxVec[] = {box.data()};
pot->forceBatch(1, n, posVec, nrsVec, frcVec, energies, vars, boxVec);
x1->setComputedPotential(energies[0], vars[0]);
} else if (x0dirty) {
x0->getForcesRaw(); // through computePotential
}
g0 = -x0->getForcesV();
g1 = -x1->getForcesV();
} else if (params.main_options.parallel && canParallel) {
// std::thread instead of std::jthread (Apple Clang libc++). Guard so an
// exception from the foreground call still joins t0 before rethrow.
std::thread t0([&] { g0 = -x0->getForcesV(); });
try {
g1 = -x1->getForcesV();
Expand Down
16 changes: 16 additions & 0 deletions client/Matter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,22 @@ void Matter::setPotential(std::shared_ptr<Potential> pot) {
recomputeMaskedForces = true;
}

void Matter::setComputedPotential(double energy, double variance) {
potentialEnergy = energy;
energyVariance = variance;
recomputePotential = false;
recomputeMaskedForces = true;
forceCalls++;

// Apply the same net force removal as computePotential()
if (isFixed.sum() == 0 && removeNetForce) {
Vector3d tempForce = forces.colwise().sum() / nAtoms;
for (long int i = 0; i < nAtoms; i++) {
forces.row(i) -= tempForce.transpose();
}
}
}

size_t Matter::getPotentialCalls() const {
return this->potential->forceCallCounter;
}
Expand Down
11 changes: 11 additions & 0 deletions client/Matter.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ class Matter {
int isFixed); // set the atom to fixed (true) or movable (false)
double getEnergyVariance() const;
double getPotentialEnergy() const;

/// Reports the current state of the recomputePotential flag, i.e. whether the
/// stored energy/forces are considered stale and must be evaluated again.
/// NOTE(review): presumably raised whenever positions change -- the setters
/// are not visible from here; confirm.
[[nodiscard]] bool needsForceUpdate() const {
  return recomputePotential;
}

/// Raw, writable pointer to the underlying force storage, intended for
/// batched potential evaluations that fill the forces in place.
/// After writing through this pointer the caller must invoke
/// setComputedPotential() so the energy and bookkeeping flags are updated.
double *forcesData() {
  return forces.data();
}

/// Set energy/variance from an external batched evaluation whose forces were
/// already written into this object's force storage, and mark the potential
/// as up-to-date (recomputePotential = false).
/// Also sets recomputeMaskedForces, increments the force-call counter, and --
/// when isFixed sums to zero and net-force removal is enabled -- subtracts the
/// mean force from every atom's force row.
void setComputedPotential(double energy, double variance);
double getKineticEnergy() const;
double getMechanicalEnergy() const;

Expand Down
1 change: 1 addition & 0 deletions client/MinModeSaddleSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ MinModeSaddleSearch::MinModeSaddleSearch(std::shared_ptr<Matter> matterPassed,
matter{matterPassed} {
reactantEnergy = reactantEnergyPassed;
mode = modePassed;
initialTangent_ = modePassed;
status = STATUS_GOOD;
iteration = 0;

Expand Down
1 change: 1 addition & 0 deletions client/MinModeSaddleSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class MinModeSaddleSearch : public SaddleSearchMethod {

private:
AtomMatrix mode;
AtomMatrix initialTangent_;
std::shared_ptr<Matter> matter;
std::shared_ptr<EigenmodeStrategy>
minModeMethod; // shared with the objective func
Expand Down
34 changes: 24 additions & 10 deletions client/NEBOcinebController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ OCINEBController::fromParams(const Parameters &params) {
r.max_steps,
r.ci_stability_count,
r.angle_tol,
r.penalty.strength,
r.penalty.base,
params.neb_options.force_tolerance,
};
}
Expand Down Expand Up @@ -64,6 +62,7 @@ void OCINEBController::updateStability(long climbingImage) {
} else {
ciStabilityCounter_ = 0;
previousClimbingImage_ = climbingImage;
has_cached_mode_ = false;
}
}

Expand Down Expand Up @@ -104,12 +103,24 @@ OCINEBController::MMFResult OCINEBController::run(eonc::NudgedElasticBand &neb,
mmfResult, convForce, newForce, newForce / baseline_force_,
current_threshold_);
} else {
// On positive curvature (status=-2), the CI is at a minimum, not a
// saddle. Restore to pre-MMF position to prevent catastrophic force
// explosion. For other failures (alignment loss, force increase),
// let the NEB recover naturally -- the CI position may still be
// closer to the saddle than before.
if (mmfResult == -2) {
neb.path[neb.climbingImage]->setPositions(savedPositions);
neb.movedAfterForceCall = true;
has_cached_mode_ = false;
newForce = convForce;
}
updateThresholdBackoff(alignment);
QUILL_LOG_DEBUG(
log,
"MMF backoff (status={}). Force: {:.4f} -> {:.4f}, "
"Alignment: {:.3f}. New threshold: {:.4f} ({:.2f}x baseline)",
mmfResult, convForce, newForce, alignment, current_threshold_,
"Alignment: {:.3f}. {}New threshold: {:.4f} ({:.2f}x baseline)",
mmfResult, convForce, newForce, alignment,
mmfResult == -2 ? "Restored CI. " : "", current_threshold_,
current_threshold_ / baseline_force_);
}

Expand Down Expand Up @@ -138,7 +149,12 @@ int OCINEBController::runDimer(eonc::NudgedElasticBand &neb,
return -1;
}

AtomMatrix initialMode = *neb.tangent[neb.climbingImage];
AtomMatrix initialMode;
if (has_cached_mode_) {
initialMode = cached_mode_;
} else {
initialMode = *neb.tangent[neb.climbingImage];
}
double tangentNorm = initialMode.norm();
if (tangentNorm < 1e-8) {
QUILL_LOG_WARNING(log, "Tangent too small for MMF initialization");
Expand Down Expand Up @@ -186,6 +202,8 @@ int OCINEBController::runDimer(eonc::NudgedElasticBand &neb,
alignment, cfg_.angle_tol);
return -1;
}
cached_mode_ = finalModeMatrix;
has_cached_mode_ = true;
return 0;
} else if (minModeStatus == MinModeSaddleSearch::STATUS_BAD_MAX_ITERATIONS) {
return 1;
Expand All @@ -205,11 +223,7 @@ void OCINEBController::updateThresholdSuccess(double convForce,

void OCINEBController::updateThresholdBackoff(double alignment) {
double alpha = std::clamp(alignment, 0.0, 1.0);
double penalty_factor =
cfg_.penalty_base +
(1.0 - cfg_.penalty_base) * std::pow(alpha, cfg_.penalty_strength);
penalty_factor = std::clamp(penalty_factor, cfg_.penalty_base, 1.0);

double penalty_factor = 0.5 + 0.5 * alpha;
current_threshold_ = baseline_force_ * cfg_.trigger_factor * penalty_factor;
// Lower bound on the MMF trigger threshold. Scaled by force_tolerance so
// the MMF gate never collapses to zero when the NEB is already near
Expand Down
7 changes: 5 additions & 2 deletions client/NEBOcinebController.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
*/
#pragma once

#include "Eigen.h"
#include "Parameters.h"

namespace eonc {
Expand All @@ -31,8 +32,6 @@ class OCINEBController {
long max_steps;
long ci_stability_count;
double angle_tol;
double penalty_strength;
double penalty_base;
double force_tolerance;
};

Expand Down Expand Up @@ -65,6 +64,10 @@ class OCINEBController {
int ciStabilityCounter_{0};
int mmf_iterations_used_{0};

// Warm-start: cache converged eigenvector for next dimer call
bool has_cached_mode_{false};
AtomMatrix cached_mode_;

int runDimer(eonc::NudgedElasticBand &neb, double &alignment);
void updateThresholdSuccess(double convForce, double newForce);
void updateThresholdBackoff(double alignment);
Expand Down
Loading
Loading