
Question regarding comparison of ASE and torch-sim optimization #198

@t-reents

Description

Dear all,

First of all, thanks a lot for making this very nice effort open source!

I'm comparing the results of an optimization with ASE and torch-sim, and I'm wondering whether I should expect the differences that I observe (small numerical differences are of course expected). I took the frechet_cell_fire example from the autobatching_tutorial and used the corresponding FrechetCellFilter and FIRE implementations in ASE. You can find the initial structure and the final results obtained with torch-sim and ASE below.

The main thing I didn't expect when naively starting this comparison is that the angles and lattice constants obtained with torch-sim differ significantly from those of the initial structure and of the ASE result.

Is there something obvious that I'm missing, or are some aspects implemented in a different way?

I'm running on the main branch, i.e., the frechet_cell_fire should use the ASE flavor.
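(In case it matters for reproducing: this is how I would pin the flavor explicitly instead of relying on the default. A minimal sketch only; it assumes frechet_cell_fire exposes an md_flavor keyword on current main, and the exact argument name may differ in other versions.)

import torch_sim as ts  # same import as in the script below

# Sketch only: pin the FIRE flavor explicitly instead of relying on the default.
# Assumption: frechet_cell_fire accepts an md_flavor keyword on current main; the exact
# argument name may differ in other versions. mace_model is defined as in the script below.
fire_init, fire_update = ts.frechet_cell_fire(mace_model, md_flavor="ase_fire")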

Thanks a lot in advance!

Initial structure

[Structure Summary
Lattice
    abc : 3.211996 3.211996 3.211996
 angles : 59.99999999999999 59.99999999999999 59.99999999999999
 volume : 23.43203403396848
      A : 2.7816701328540017 0.0 1.6059980000000005
      B : 0.9272233776180007 2.622583751953532 1.6059980000000005
      C : 0.0 0.0 3.211996
    pbc : True True True
PeriodicSite: Os (2.782, 1.967, 1.606) [0.75, 0.75, -0.25]
PeriodicSite: N (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]]

torch-sim

[Structure Summary
Lattice
    abc : 3.2776661189912484 3.277666118991247 3.2519298596553377
 angles : 60.44244096341136 60.44244096341136 60.68463155343435
 volume : 24.995080279571642
      A : 2.8484921585612026 -0.013698735928738848 1.621419117573383
      B : 0.9365820940952077 2.6901504072766884 1.621419117573383
      C : -0.005213232697568147 -0.003686312192354059 3.251923591572417
    pbc : True True True
PeriodicSite: Os (2.803, 1.982, 1.611) [0.7403, 0.7403, -0.2429]
PeriodicSite: N (-0.02169, -0.01534, -0.004666) [-0.005725, -0.005725, 0.004274]]

ASE

[Structure Summary
Lattice
    abc : 3.282014574980207 3.2820145749802023 3.2820145749802028
 angles : 59.99999999999996 60.00000000000003 60.00000000000003
 volume : 24.998072211255
      A : 2.842307997523647 -1.7528979539778155e-15 1.641007287490103
      B : 0.9474359991745464 2.679753679026299 1.6410072874901025
      C : -8.083858586323889e-16 1.431220580133293e-15 3.2820145749802028
    pbc : True True True
PeriodicSite: Os (2.842, 2.01, 1.641) [0.75, 0.75, -0.25]
PeriodicSite: N (-1.188e-16, 5.225e-16, -9.997e-16) [-1.068e-16, 1.95e-16, -3.487e-16]]
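To make the discrepancy easier to see at a glance, this is roughly how I compare the three cells (a minimal sketch; initial_structure, ts_structure and ase_structure are placeholder names for the pymatgen Structure objects printed above):

from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

# Placeholder names for the three pymatgen Structure objects printed above.
for label, s in [
    ("initial", initial_structure),
    ("torch-sim", ts_structure),
    ("ASE", ase_structure),
]:
    a, b, c = s.lattice.abc
    alpha, beta, gamma = s.lattice.angles
    # Loose symmetry tolerance to check whether the relaxed cell keeps the initial symmetry
    spg = SpacegroupAnalyzer(s, symprec=1e-3).get_space_group_symbol()
    print(
        f"{label:>9}: abc = ({a:.4f}, {b:.4f}, {c:.4f}), "
        f"angles = ({alpha:.2f}, {beta:.2f}, {gamma:.2f}), spacegroup = {spg}"
    )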
Script to reproduce

import torch_sim as ts
import torch
from mace.calculators.foundations_models import mace_mp
from torch_sim.models.mace import MaceModel
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from pymatgen.io.ase import AseAtomsAdaptor
from torch_sim.io import state_to_structures
from pymatgen.core import Lattice, Structure


a = 3.211996
lattice = Lattice.from_parameters(a, a, a, 60, 60, 60)
species = ["Os", "N"]
frac_coords = [
    [0.75, 0.75, -0.25],
    [0.00, 0.00,  0.00]
]
structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False)

# Adapted from https://github.com/Radical-AI/torch-sim/blob/main/examples/tutorials/autobatching_tutorial.py#L234
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mace = mace_mp(model="large", return_raw_model=True, default_dtype="float64")
mace_model = MaceModel(model=mace, device=device)

state = ts.initialize_state(structure, device=device, dtype=torch.float64)

fire_init, fire_update = ts.frechet_cell_fire(mace_model)
fire_state = fire_init(state)

batcher = ts.InFlightAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=1000,
    max_iterations=300,
)

batcher.load_states(fire_state)
total_states = fire_state.n_batches
convergence_fn = ts.generate_force_convergence_fn(25e-4)

all_converged_states, convergence_tensor = [], None
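# Keep pulling batches until the batcher reports no more unconverged states;
# each batch is advanced by 10 FIRE steps before re-checking force convergence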
while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
    fire_state, converged_states = result
    all_converged_states.extend(converged_states)

    for _ in range(10):
        fire_state = fire_update(fire_state)

    convergence_tensor = convergence_fn(fire_state, None)

else:
    all_converged_states.extend(result[1])

final_states = batcher.restore_original_order(all_converged_states)


# ASE reference optimization

mace_calc = mace_mp(model="large", default_dtype='float64')
ase_relaxed = []

atoms = AseAtomsAdaptor.get_atoms(structure)
atoms_copy = atoms.copy()
atoms_copy.calc = mace_calc

optimizer = FIRE(FrechetCellFilter(atoms_copy), logfile=None)
optimizer.run(fmax=25e-4, steps=300)

if abs(atoms_copy.get_forces()).max() < 25e-4:
    ase_relaxed.append(atoms_copy)

ts_final_structures = []
for fs in final_states:
    ts_final_structures.extend(state_to_structures(fs))
    
print('\nInitial structure')
print([structure])
print('\ntorch-sim')
print(ts_final_structures)
print('\nASE')
print([AseAtomsAdaptor.get_structure(a) for a in ase_relaxed])
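As a sanity check on convergence, I also re-evaluate the torch-sim result with the same MACE calculator through ASE and look at the residual forces and stress (a minimal sketch; it reuses mace_calc and ts_final_structures from the script above):

import numpy as np

# Sketch only: re-evaluate the torch-sim-relaxed structure with the ASE MACE calculator.
# Reuses mace_calc and ts_final_structures from the script above.
ts_atoms = AseAtomsAdaptor.get_atoms(ts_final_structures[0])
ts_atoms.calc = mace_calc
max_force = np.abs(ts_atoms.get_forces()).max()    # eV/A
max_stress = np.abs(ts_atoms.get_stress()).max()   # eV/A^3, Voigt notation
print(f"torch-sim result re-evaluated with ASE: max|F| = {max_force:.2e} eV/A, "
      f"max|stress| = {max_stress:.2e} eV/A^3")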
