In [1]:
from pyxtal import pyxtal
from pyxtal.lattice import Lattice
from pymatgen.core import Structure
import pandas as pd
import numpy as np
import json
from pyxtal.tolerance import Tol_matrix
from datetime import datetime
import os
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from ase.io import read
import matgl
from matgl.ext.ase import M3GNetCalculator #, Relaxer
import mace
from mace.calculators import mace_mp
from chgnet.model import CHGNet
from chgnet.model import CHGNetCalculator

from ase.spacegroup.symmetrize import FixSymmetry
from ase.optimize import FIRE
from ase.constraints import ExpCellFilter
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.optimize.optimize import Optimizer
from ase.optimize import FIRE, LBFGS, BFGSLineSearch
from ase.io import read, write
from ase.visualize import view
from pymatgen.analysis import structure_matcher
from pymatgen.io.ase import AseAtomsAdaptor
import torch
# from ase.visualize import view
# import ase.optimize as opt
print(torch.__version__)

2.0.1+cpu


# Load IAPs

In [2]:
# # Use CHGNET
# calculator = CHGNetCalculator()

# Use MACE
calculator = mace_mp(model="medium", dispersion=False, default_dtype="float32", device='cpu')

# # Use M3GNET
# pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# calculator = M3GNetCalculator(potential=pot,stress_weight=1 / 160.21766208)

Using Materials Project MACE for MACECalculator with C:\Users\Raymo/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.


# Load data

In [3]:
# get all dft_relaxed index
ind_lst = os.listdir('temp_files/test_compounds/')
ind_lst

['1174', '1375']

In [4]:
def ase_relaxer(
    atoms_in: Atoms,
    calculator: Calculator,
    optimizer: Optimizer = FIRE,
    cell_filter = None,
    fix_symmetry: bool = True,
    fix_fractional: bool = False,
    hydrostatic_strain: bool = False,
    fmax: float = 0.05,
    steps_limit: int = 500,
    logfile: str = "-",
    wdir: str = "./",
) -> Atoms:
    atoms = atoms_in.copy()
    full_formula = atoms.get_chemical_formula(mode="metal")
    reduced_formula = atoms.get_chemical_formula(mode="metal", empirical=True)
    print(f'relaxing {reduced_formula}_{full_formula}')
    atoms.calc = calculator
    E0 = atoms.get_potential_energy()
    if fix_fractional:
        atoms.set_constraint([FixAtoms(indices=[atom.index for atom in atoms])])
    if fix_symmetry:
        atoms.set_constraint([FixSymmetry(atoms)])
    if cell_filter is not None:
        target = cell_filter(atoms, hydrostatic_strain=hydrostatic_strain)
    else:
        target = atoms
    now = datetime.now()
    strnow = now.strftime("%Y%m%d%H%M%S")
    opt = optimizer(target,
#                     maxstep = max_step,
                    trajectory=f'{wdir}/{reduced_formula}_{full_formula}_{strnow}.traj',
                    logfile=logfile,
                   )
    opt.run(fmax=fmax, steps=steps_limit)
    if cell_filter is None:
        write(filename=f'{wdir}/{reduced_formula}_{full_formula}_fix_cell_relaxed_{strnow}.cif',
          images=atoms,
         format="cif",
         )
    else:
        write(filename=f'{wdir}/{reduced_formula}_{full_formula}_relax_postitions_and_cell_{strnow}.cif',
          images=atoms,
         format="cif",
         )
    cell_diff = (atoms.cell.cellpar() / atoms_in.cell.cellpar() - 1.0) * 100
    E1 = atoms.get_potential_energy()
    # print("Optimized Cell         :", atoms.cell.cellpar())
    # print("Optimized Cell diff (%):", cell_diff)
    # print("Scaled positions       :\n", atoms.get_scaled_positions())
    # print(f"Potential energy before opt: {E0:.4f} eV")
    # print(f"Potential energy after  opt: {E1:.4f} eV")
    
    return atoms

In [5]:
def ase_iaps_relax(test_ind,strc_pre_scaled,calculator):
    # ase_relaxation
    atoms_pyxtal_scaled = AseAtomsAdaptor.get_atoms(strc_pre_scaled)
    wdir = "temp_files/test_compounds/{}".format(str(test_ind))
    # fix cell relaxation
    strc_relax_fixcell = ase_relaxer(
        atoms_in=atoms_pyxtal_scaled,
        calculator=calculator,
        cell_filter=None,
        logfile=wdir+"/strc_volume-scaled_fix-cell_relaxed_atomic_postitions.log",
        wdir=wdir
    )

    # relax both cell and atoms
    strc_relax_cell = ase_relaxer(
        atoms_in=strc_relax_fixcell,
        calculator=calculator,
        cell_filter=ExpCellFilter,
        fix_fractional=False,
        logfile=wdir+"/strc_volume-scaled_relax_postitions_and_cell.log",
        wdir=wdir
    )
    
    return strc_relax_cell

In [6]:
import warnings
warnings.filterwarnings('ignore')

In [7]:
all_strc_relaxed = {}

for i in ind_lst[:]:
    test_ind = int(i)
    print('test_ind',': ', test_ind)
    
    # read from past file
    wdir = "temp_files/test_compounds/{}".format(str(test_ind))
    file_name = os.listdir(wdir)[0]
    strc_pre_set = Structure.from_file(wdir+'/'+file_name)

    # ase_IAPs relax
    strc_relax_cell = ase_iaps_relax(test_ind,strc_pre_set,calculator)
    strc_relaxed = AseAtomsAdaptor.get_structure(strc_relax_cell)
    strc_relaxed = SpacegroupAnalyzer(structure=strc_relaxed #, symprec=0.1
                  ).get_conventional_standard_structure()
    all_strc_relaxed[test_ind] = strc_relaxed

    print('-----------------------------------------------')


test_ind :  1174
relaxing CaTiO3_Ca4Ti4O12
relaxing CaTiO3_Ca4Ti4O12
-----------------------------------------------
test_ind :  1375
relaxing CaTiO3_Ca4Ti4O12
relaxing CaTiO3_Ca4Ti4O12
-----------------------------------------------


In [8]:
all_strc_relaxed

{1174: Structure Summary
 Lattice
     abc : 5.399744914149461 5.507340693044186 7.690220209620782
  angles : 90.0 90.0 90.0
  volume : 228.69357500915558
       A : 5.399744914149461 0.0 3.3063901626626687e-16
       B : 8.85647930797018e-16 5.507340693044186 3.372273575775264e-16
       C : 0.0 0.0 7.690220209620782
     pbc : True True True
 PeriodicSite: Ca (3.427, 4.372, 4.775e-16) [0.6346, 0.7938, 0.0]
 PeriodicSite: Ca (0.8283, 1.617, 1.497e-16) [0.1534, 0.2936, 0.0]
 PeriodicSite: Ca (0.7267, 1.135, 3.845) [0.1346, 0.2062, 0.5]
 PeriodicSite: Ca (3.528, 3.89, 3.845) [0.6534, 0.7064, 0.5]
 PeriodicSite: Ti (3.486, 1.377, 1.925) [0.6457, 0.25, 0.2503]
 PeriodicSite: Ti (0.7866, 4.131, 5.77) [0.1457, 0.75, 0.7503]
 PeriodicSite: Ti (3.486, 1.377, 5.765) [0.6457, 0.25, 0.7497]
 PeriodicSite: Ti (0.7866, 4.131, 1.92) [0.1457, 0.75, 0.2497]
 PeriodicSite: O (4.6, 5.285, 2.23) [0.8518, 0.9596, 0.29]
 PeriodicSite: O (4.6, 5.285, 5.46) [0.8518, 0.9596, 0.71]
 PeriodicSite: O (3.058, 1.