In [1]:
from pyxtal import pyxtal
from pyxtal.lattice import Lattice
from pymatgen.core import Structure
import pandas as pd
import numpy as np
import json
from pyxtal.tolerance import Tol_matrix
from datetime import datetime
import os
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from ase.io import read
import matgl
from matgl.ext.ase import M3GNetCalculator #, Relaxer
from ase.spacegroup.symmetrize import FixSymmetry
from ase.optimize import FIRE
from ase.constraints import ExpCellFilter
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.optimize.optimize import Optimizer
from ase.optimize import FIRE, LBFGS, BFGSLineSearch
from ase.io import read, write
from ase.visualize import view
from pymatgen.analysis import structure_matcher
from pymatgen.io.ase import AseAtomsAdaptor
import torch
# from ase.visualize import view
# import ase.optimize as opt
print(torch.__version__)

2.0.1+cpu


# Load data

In [2]:
# get all dft_relaxed index
ind_lst = os.listdir('temp_files/test_compounds/')
ind_lst

['0', '1', '2']

In [3]:
def ase_relaxer(
    atoms_in: Atoms,
    calculator: Calculator,
    optimizer: Optimizer = FIRE,
    cell_filter = None,
    fix_symmetry: bool = True,
    fix_fractional: bool = False,
    hydrostatic_strain: bool = False,
    fmax: float = 0.05,
    steps_limit: int = 500,
    logfile: str = "-",
    wdir: str = "./",
) -> Atoms:
    atoms = atoms_in.copy()
    full_formula = atoms.get_chemical_formula(mode="metal")
    reduced_formula = atoms.get_chemical_formula(mode="metal", empirical=True)
    print(f'relaxing {reduced_formula}_{full_formula}')
    atoms.calc = calculator
    E0 = atoms.get_potential_energy()
    if fix_fractional:
        atoms.set_constraint([FixAtoms(indices=[atom.index for atom in atoms])])
    if fix_symmetry:
        atoms.set_constraint([FixSymmetry(atoms)])
    if cell_filter is not None:
        target = cell_filter(atoms, hydrostatic_strain=hydrostatic_strain)
    else:
        target = atoms
    now = datetime.now()
    strnow = now.strftime("%Y%m%d%H%M%S")
    opt = optimizer(target,
#                     maxstep = max_step,
                    trajectory=f'{wdir}/{reduced_formula}_{full_formula}_{strnow}.traj',
                    logfile=logfile,
                   )
    opt.run(fmax=fmax, steps=steps_limit)
    if cell_filter is None:
        write(filename=f'{wdir}/{reduced_formula}_{full_formula}_fix_cell_relaxed_{strnow}.cif',
          images=atoms,
         format="cif",
         )
    else:
        write(filename=f'{wdir}/{reduced_formula}_{full_formula}_relax_postitions_and_cell_{strnow}.cif',
          images=atoms,
         format="cif",
         )
    cell_diff = (atoms.cell.cellpar() / atoms_in.cell.cellpar() - 1.0) * 100
    E1 = atoms.get_potential_energy()
    # print("Optimized Cell         :", atoms.cell.cellpar())
    # print("Optimized Cell diff (%):", cell_diff)
    # print("Scaled positions       :\n", atoms.get_scaled_positions())
    # print(f"Potential energy before opt: {E0:.4f} eV")
    # print(f"Potential energy after  opt: {E1:.4f} eV")
    
    return atoms

In [4]:
def ase_m3gnet_relax(test_ind,strc_pre_scaled):
    # ase_m3gnet relaxation
    pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
    calculator = M3GNetCalculator(potential=pot,stress_weight=1 / 160.21766208)
    atoms_pyxtal_scaled = AseAtomsAdaptor.get_atoms(strc_pre_scaled)
    wdir = "temp_files/test_compounds/{}".format(str(test_ind))
    # fix cell relaxation
    strc_relax_fixcell = ase_relaxer(
        atoms_in=atoms_pyxtal_scaled,
        calculator=calculator,
        cell_filter=None,
        logfile=wdir+"/strc_volume-scaled_fix-cell_relaxed_atomic_postitions.log",
        wdir=wdir
    )

    # relax both cell and atoms
    strc_relax_cell = ase_relaxer(
        atoms_in=strc_relax_fixcell,
        calculator=calculator,
        cell_filter=ExpCellFilter,
        fix_fractional=False,
        logfile=wdir+"/strc_volume-scaled_relax_postitions_and_cell.log",
        wdir=wdir
    )
    
    return strc_relax_cell

In [5]:
import warnings
warnings.filterwarnings('ignore')

In [6]:
all_strc_m3gnet = {}

for i in ind_lst[:]:
    test_ind = int(i)
    print('test_ind',': ', test_ind)
    
    # read from past file
    wdir = "temp_files/test_compounds/{}".format(str(test_ind))
    file_name = os.listdir(wdir)[0]
    strc_pre_set = Structure.from_file(wdir+'/'+file_name)

    # m3gnet relax
    strc_relax_cell = ase_m3gnet_relax(test_ind,strc_pre_set)
    strc_m3gnet = AseAtomsAdaptor.get_structure(strc_relax_cell)
    strc_m3gnet = SpacegroupAnalyzer(structure=strc_m3gnet #, symprec=0.1
                  ).get_conventional_standard_structure()
    all_strc_m3gnet[test_ind] = strc_m3gnet

    print('-----------------------------------------------')


test_ind :  0
relaxing CaTiO3_CaTiO3
relaxing CaTiO3_CaTiO3
-----------------------------------------------
test_ind :  1
relaxing CaTiO3_Ca6Ti6O18
relaxing CaTiO3_Ca6Ti6O18
-----------------------------------------------
test_ind :  2
relaxing CaTiO3_Ca2Ti2O6
relaxing CaTiO3_Ca2Ti2O6
-----------------------------------------------


In [7]:
all_strc_m3gnet

{0: Structure Summary
 Lattice
     abc : 5.078614072321848 5.124552412011771 9.020916847074343
  angles : 90.0 90.0 90.0
  volume : 234.77498994304136
       A : 5.078614072321848 0.0 3.109754233886828e-16
       B : 8.240908803211169e-16 5.124552412011771 3.137883354216532e-16
       C : 0.0 0.0 9.020916847074343
     pbc : True True True
 PeriodicSite: Ca (8.026e-16, 4.991, 3.056e-16) [0.0, 0.9739, 0.0]
 PeriodicSite: Ca (2.539, 2.429, 3.042e-16) [0.5, 0.4739, 0.0]
 PeriodicSite: Ti (3.084e-18, 0.01918, 4.51) [0.0, 0.003743, 0.5]
 PeriodicSite: Ti (2.539, 2.581, 4.51) [0.5, 0.5037, 0.5]
 PeriodicSite: O (4.413e-16, 2.744, 1.68e-16) [0.0, 0.5355, 0.0]
 PeriodicSite: O (3.809, 1.301, 4.51) [0.75, 0.2538, 0.5]
 PeriodicSite: O (1.27, 1.301, 4.51) [0.25, 0.2538, 0.5]
 PeriodicSite: O (2.539, 0.182, 1.666e-16) [0.5, 0.03552, 0.0]
 PeriodicSite: O (1.27, 3.863, 4.51) [0.25, 0.7538, 0.5]
 PeriodicSite: O (3.809, 3.863, 4.51) [0.75, 0.7538, 0.5],
 1: Structure Summary
 Lattice
     abc : 14