Description
Dear all,
First of all, thanks a lot for making this very nice effort open source!
I'm comparing the result of the same optimization run with ASE and with torch-sim, and I'm wondering whether the differences I observe are expected (small numerical differences are, of course). I took the frechet_cell_fire example from the autobatching_tutorial and used the corresponding FrechetCellFilter and FIRE implementations in ASE. You can find the initial structure and the final results obtained with torch-sim and ASE below.
What I didn't expect when naively starting this comparison is that torch-sim breaks the rhombohedral symmetry of the initial cell: the lattice constants become unequal (3.278, 3.278 and 3.252 Å) and the angles move away from 60° (60.44°, 60.44° and 60.68°), whereas ASE keeps a = b = c and all angles at 60°, i.e., it only rescales the cell. The final volumes, on the other hand, are nearly identical (24.995 vs. 24.998 Å³).
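The symmetry breaking can be quantified directly from the lattice matrices in the structure summaries below. This sketch uses only NumPy; `lattice_parameters` is a small helper written for this comparison (not a pymatgen or torch-sim function) and assumes pymatgen's convention that alpha is the angle between B and C, beta between A and C, and gamma between A and B.

```python
import numpy as np


def lattice_parameters(cell):
    """Return lengths (a, b, c), angles (alpha, beta, gamma) in degrees and the
    volume of a 3x3 cell whose rows are the lattice vectors A, B, C."""
    cell = np.asarray(cell, dtype=float)
    lengths = np.linalg.norm(cell, axis=1)
    vec_a, vec_b, vec_c = cell

    def angle(u, v):
        cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
        return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

    # alpha = angle(B, C), beta = angle(A, C), gamma = angle(A, B)
    angles = np.array([angle(vec_b, vec_c), angle(vec_a, vec_c), angle(vec_a, vec_b)])
    volume = abs(np.linalg.det(cell))
    return lengths, angles, volume


# Lattice matrices copied from the structure summaries in this issue
initial = [[2.7816701328540017, 0.0, 1.6059980000000005],
           [0.9272233776180007, 2.622583751953532, 1.6059980000000005],
           [0.0, 0.0, 3.211996]]
ts_final = [[2.8484921585612026, -0.013698735928738848, 1.621419117573383],
            [0.9365820940952077, 2.6901504072766884, 1.621419117573383],
            [-0.005213232697568147, -0.003686312192354059, 3.251923591572417]]
ase_final = [[2.842307997523647, -1.7528979539778155e-15, 1.641007287490103],
             [0.9474359991745464, 2.679753679026299, 1.6410072874901025],
             [-8.083858586323889e-16, 1.431220580133293e-15, 3.2820145749802028]]

for name, cell in [("initial", initial), ("torch-sim", ts_final), ("ASE", ase_final)]:
    lengths, angles, volume = lattice_parameters(cell)
    # A rhombohedral cell has a == b == c and alpha == beta == gamma,
    # so both spreads should vanish if the symmetry is preserved.
    spread_len = lengths.max() - lengths.min()
    spread_ang = angles.max() - angles.min()
    print(f"{name:10s} V={volume:.4f}  length spread={spread_len:.2e}  angle spread={spread_ang:.2e}")
```

With these matrices, both spreads for ASE are at round-off level, while the torch-sim angle spread is about 0.24° even though the two volumes agree to roughly 0.01%.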
Is there something obvious I'm missing, or are some aspects implemented differently?
I'm running on the main branch, so frechet_cell_fire should be using the ASE flavor.
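For reference, this is how I understand the "ASE flavor" after reading ASE's `FrechetCellFilter`: the optimizer does not move the cell vectors directly but works on the matrix logarithm of the deformation gradient F (defined by cell = orig_cell @ F.T), scaled by `exp_cell_factor`, which I assume defaults to the number of atoms. The following is a standalone sketch of that mapping under those assumptions, not ASE's actual code, and it says nothing about how torch-sim implements it. Applied to the ASE result from this issue, it shows that ASE's relaxation is a pure isotropic deformation of the initial cell, so log F is a multiple of the identity.

```python
import numpy as np
from scipy.linalg import expm, logm


def deformation_gradient(orig_cell, cell):
    """F such that cell = orig_cell @ F.T (rows of each cell are lattice vectors)."""
    return np.linalg.solve(orig_cell, cell).T


def cell_to_dof(orig_cell, cell, exp_cell_factor):
    """Cell degrees of freedom seen by the optimizer: exp_cell_factor * log(F)."""
    return exp_cell_factor * np.real(logm(deformation_gradient(orig_cell, cell)))


def dof_to_cell(orig_cell, dof, exp_cell_factor):
    """Inverse map: F = exp(dof / exp_cell_factor), cell = orig_cell @ F.T."""
    return orig_cell @ expm(dof / exp_cell_factor).T


initial = np.array([[2.7816701328540017, 0.0, 1.6059980000000005],
                    [0.9272233776180007, 2.622583751953532, 1.6059980000000005],
                    [0.0, 0.0, 3.211996]])
ase_final = np.array([[2.842307997523647, -1.7528979539778155e-15, 1.641007287490103],
                      [0.9474359991745464, 2.679753679026299, 1.6410072874901025],
                      [-8.083858586323889e-16, 1.431220580133293e-15, 3.2820145749802028]])

n_atoms = 2  # assumption: exp_cell_factor defaults to the number of atoms (Os + N)
dof = cell_to_dof(initial, ase_final, n_atoms)
# For a pure isotropic scaling by s, this should be close to n_atoms * ln(s) * identity
print(np.round(dof, 6))
```

Because a scalar multiple of the identity commutes with everything, an update of this kind cannot introduce the unequal lengths or angles seen in the torch-sim result.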
Thanks a lot in advance!
Initial structure
[Structure Summary
Lattice
abc : 3.211996 3.211996 3.211996
angles : 59.99999999999999 59.99999999999999 59.99999999999999
volume : 23.43203403396848
A : 2.7816701328540017 0.0 1.6059980000000005
B : 0.9272233776180007 2.622583751953532 1.6059980000000005
C : 0.0 0.0 3.211996
pbc : True True True
PeriodicSite: Os (2.782, 1.967, 1.606) [0.75, 0.75, -0.25]
PeriodicSite: N (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]]
torch-sim
[Structure Summary
Lattice
abc : 3.2776661189912484 3.277666118991247 3.2519298596553377
angles : 60.44244096341136 60.44244096341136 60.68463155343435
volume : 24.995080279571642
A : 2.8484921585612026 -0.013698735928738848 1.621419117573383
B : 0.9365820940952077 2.6901504072766884 1.621419117573383
C : -0.005213232697568147 -0.003686312192354059 3.251923591572417
pbc : True True True
PeriodicSite: Os (2.803, 1.982, 1.611) [0.7403, 0.7403, -0.2429]
PeriodicSite: N (-0.02169, -0.01534, -0.004666) [-0.005725, -0.005725, 0.004274]]
ASE
[Structure Summary
Lattice
abc : 3.282014574980207 3.2820145749802023 3.2820145749802028
angles : 59.99999999999996 60.00000000000003 60.00000000000003
volume : 24.998072211255
A : 2.842307997523647 -1.7528979539778155e-15 1.641007287490103
B : 0.9474359991745464 2.679753679026299 1.6410072874901025
C : -8.083858586323889e-16 1.431220580133293e-15 3.2820145749802028
pbc : True True True
PeriodicSite: Os (2.842, 2.01, 1.641) [0.75, 0.75, -0.25]
PeriodicSite: N (-1.188e-16, 5.225e-16, -9.997e-16) [-1.068e-16, 1.95e-16, -3.487e-16]]
Script to reproduce
import torch_sim as ts
import torch
from mace.calculators.foundations_models import mace_mp
from torch_sim.models.mace import MaceModel
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from pymatgen.io.ase import AseAtomsAdaptor
from torch_sim.io import state_to_structures
from pymatgen.core import Lattice, Structure
a = 3.211996
lattice = Lattice.from_parameters(a, a, a, 60, 60, 60)
species = ["Os", "N"]
frac_coords = [
[0.75, 0.75, -0.25],
[0.00, 0.00, 0.00]
]
structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False)
# Adapted from https://github.com/Radical-AI/torch-sim/blob/main/examples/tutorials/autobatching_tutorial.py#L234
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mace = mace_mp(model="large", return_raw_model=True, default_dtype="float64")
mace_model = MaceModel(model=mace, device=device)
state = ts.initialize_state(structure, device=device, dtype=torch.float64)
fire_init, fire_update = ts.frechet_cell_fire(mace_model)
fire_state = fire_init(state)
batcher = ts.InFlightAutoBatcher(
model=mace_model,
memory_scales_with="n_atoms",
max_memory_scaler=1000,
max_iterations=300,
)
batcher.load_states(fire_state)
total_states = fire_state.n_batches
convergence_fn = ts.generate_force_convergence_fn(25e-4)
all_converged_states, convergence_tensor = [], None
while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
    fire_state, converged_states = result
    all_converged_states.extend(converged_states)
    for _ in range(10):
        fire_state = fire_update(fire_state)
    convergence_tensor = convergence_fn(fire_state, None)
else:
    all_converged_states.extend(result[1])
final_states = batcher.restore_original_order(all_converged_states)
# ASE reference optimization
mace_calc = mace_mp(model="large", default_dtype='float64')
ase_relaxed = []
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms_copy = atoms.copy()
atoms_copy.calc = mace_calc
optimizer = FIRE(FrechetCellFilter(atoms_copy), logfile=None)
optimizer.run(fmax=25e-4, steps=300)
if abs(atoms_copy.get_forces()).max() < 25e-4:
    ase_relaxed.append(atoms_copy)
ts_final_structures = []
for fs in final_states:
    ts_final_structures.extend(state_to_structures(fs))
print('\nInitial structure')
print([structure])
print('\ntorch-sim')
print(ts_final_structures)
print('\nASE')
print([AseAtomsAdaptor.get_structure(a) for a in ase_relaxed])
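One detail worth keeping in mind when comparing the two runs: as far as I understand ASE's optimizers, `fmax` is compared against the largest per-atom force norm, and when FIRE wraps a FrechetCellFilter the cell rows of the filter's generalized forces are included as well. The final check in the script, by contrast, uses the largest Cartesian component of the atomic forces only. I haven't checked which quantity `ts.generate_force_convergence_fn` uses, so that remains an open assumption. The forces below are hypothetical and only illustrate that the two criteria can disagree at the same threshold.

```python
import numpy as np


def fmax_norm(forces):
    """Largest per-atom force norm (the quantity ASE compares against fmax)."""
    forces = np.asarray(forces, dtype=float)
    return float(np.sqrt((forces ** 2).sum(axis=1).max()))


def fmax_component(forces):
    """Largest absolute Cartesian component (what the script's final check uses)."""
    return float(np.abs(np.asarray(forces, dtype=float)).max())


# Hypothetical forces in eV/Å, for illustration only
forces = np.array([[0.0015, 0.0015, 0.0015],
                   [0.0, 0.0, 0.0]])
threshold = 25e-4
print("component check passes:", fmax_component(forces) < threshold)
print("norm check passes:", fmax_norm(forces) < threshold)
```

Since the component-wise maximum never exceeds the per-atom norm, the script's final check is looser than ASE's own convergence test, so it cannot reject a structure that ASE reported as converged.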