In [1]:
import os

# Pin the job to one GPU, numbering devices by PCI bus id so the index
# matches what nvidia-smi reports.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1",
)

In [2]:
from monty.serialization import dumpfn, loadfn
from collections import defaultdict

In [3]:
from ase.filters import FrechetCellFilter
from ase.optimize import LBFGS
from pymatgen.io.ase import AseAtomsAdaptor
from fairchem.core import pretrained_mlip, FAIRChemCalculator
# from mace.calculators import mace_mp # need e3nn 0.4.4
# from sevenn.calculator import SevenNetCalculator # need e3nn 0.5.0 later change numpy to 1.26.4

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch._dynamo

# Let torch dynamo fall back instead of raising when compilation fails
# (presumably needed for the MLIP model below — confirm).
dynamo_config = torch._dynamo.config
dynamo_config.suppress_errors = True

In [5]:
# Key into MODELS selecting the calculator; also used in every output path.
# Options in this notebook: "orb", "uma", "mace", "seven", "mattersim"
# (only "uma" is currently enabled in MODELS).
MODEL: str = "uma"

In [6]:
# Alternative calculators (need their own imports / environments before enabling):
# orb_ff = pretrained.orb_v3_conservative_inf_omat(device="cuda", precision="float32-high")  # or "float32-highest" / "float64"
# orb_calc = ORBCalculator(orb_ff, device="cuda")

uma_ff = pretrained_mlip.get_predict_unit("uma-m-1p1", device="cuda")
uma_calc = FAIRChemCalculator(uma_ff, task_name="omat")

# mace_calc = mace_mp(model="MACE-matpes-r2scan-omat-ft.model", dispersion=False, device='cuda')

# seven_calc = SevenNetCalculator('7net-mf-ompa', modal='omat24')

# mattersim_calc = MatterSimCalculator(device="cuda")

# Name -> ASE calculator; MODEL (config cell) selects which one is used.
MODELS = {
    # "orb": orb_calc,
    "uma": uma_calc,
    # "mace": mace_calc,
    # "seven": seven_calc,
    # "mattersim": mattersim_calc,
}

# Fail here with a clear message rather than a bare KeyError inside the first relaxation.
if MODEL not in MODELS:
    raise ValueError(
        f"MODEL={MODEL!r} has no enabled calculator; available: {sorted(MODELS)}"
    )

In [7]:
# Function used to relax initial primitive cell and for competing phase calculation, full relaxation using FrechetCellFilter
def relax_structure(structure, fmax=0.01, steps=1000):
    """Fully relax a structure (atomic positions and lattice) with the active MLIP.

    Used for the primitive host cells and competing phases. The cell degrees of
    freedom are exposed to the optimizer through ASE's FrechetCellFilter.

    Args:
        structure: pymatgen Structure to relax (converted to a new ASE Atoms object).
        fmax: force convergence threshold passed to LBFGS.run.
        steps: maximum number of LBFGS steps.

    Returns:
        (relaxed pymatgen Structure, potential energy of the relaxed atoms).
    """
    import warnings

    atoms = structure.to_ase_atoms()
    atoms.calc = MODELS[MODEL]

    opt = LBFGS(FrechetCellFilter(atoms))
    converged = opt.run(fmax=fmax, steps=steps)
    # Recent ASE returns a bool; older versions return None, so only warn on an explicit False.
    if converged is False:
        warnings.warn(
            f"LBFGS did not reach fmax={fmax} within {steps} steps for "
            f"{structure.composition.reduced_formula}"
        )

    relaxed_structure = AseAtomsAdaptor.get_structure(atoms)
    energy = atoms.get_potential_energy()

    return relaxed_structure, energy
    

In [8]:
# Function that relaxes defect and pristine supercells, ensuring a fixed lattice by no FrechetCellFilter
def relax_defect(structure, fmax=0.01, steps=1000):
    """Relax atomic positions only (fixed lattice) with the active MLIP.

    Used for pristine and defect supercells. No cell filter is applied, so the
    lattice of the input structure is kept.

    Args:
        structure: pymatgen Structure to relax.
        fmax: force convergence threshold passed to LBFGS.run.
        steps: maximum number of LBFGS steps.

    Returns:
        (relaxed pymatgen Structure, potential energy of the relaxed atoms).
    """
    import warnings

    atoms = structure.to_ase_atoms()
    atoms.calc = MODELS[MODEL]

    opt = LBFGS(atoms)
    converged = opt.run(fmax=fmax, steps=steps)
    # Recent ASE returns a bool; older versions return None, so only warn on an explicit False.
    if converged is False:
        warnings.warn(
            f"LBFGS did not reach fmax={fmax} within {steps} steps for "
            f"{structure.composition.reduced_formula}"
        )

    relaxed_structure = AseAtomsAdaptor.get_structure(atoms)
    energy = atoms.get_potential_energy()

    return relaxed_structure, energy

## Relaxation of Primitive Al Oxide Cells

In [8]:
# Candidate Al-oxide hosts for Cr substitution, keyed by material id.
matching_pairs_path = '../../data/matching_AlCr_ox.json'
matching_pairs = loadfn(matching_pairs_path)

In [None]:
# Exclude this host from the workflow. pop(..., None) keeps the cell safe to re-run
# (a bare `del` raises KeyError the second time).
# NOTE(review): presumably this is the Al2(SO4)3 entry that yields no Cr_Al_0 defect — confirm.
matching_pairs.pop('mp-554152', None)

In [3]:
# Fully relax (positions + cell) every primitive host and keep formula, energy, structure.
relaxed_Al_prim = {}
for mp_id, entry in matching_pairs.items():
    relaxed_structure, relaxed_energy = relax_structure(entry['structure'])
    relaxed_Al_prim[mp_id] = dict(
        formula=entry['formula'],
        energy=relaxed_energy,
        structure=relaxed_structure,
    )

NameError: name 'matching_pairs' is not defined

In [17]:
# The per-model directory may not exist yet (e.g. the first run with a new MODEL).
relaxed_prim_path = f'../../data/{MODEL}/{MODEL}_relaxed_Al_prim.json'
os.makedirs(os.path.dirname(relaxed_prim_path), exist_ok=True)
dumpfn(relaxed_Al_prim, relaxed_prim_path)

## Defect Creation and Relaxation

Load in all relaxed primitive Al oxide structures
<br>Loop through all Al oxides
<br>Create a function using doped and ShakeNBreak for each structure
<br>Create a function that loops over the unperturbed structure and all distortions and uses the MLIP to calculate energies - append structures and energies to the data dict
<br>Get the energy difference from unperturbed for each Al oxide and append it to a dictionary or list for the bar chart - can do this manually in the loop
<br>dumpfn the data dict and list of energy differences, or maybe add the energy difference to the data dict as a 'delta_e' key?
<br>Plot a bar chart for all metal oxides and their energy relative to unperturbed

In [9]:
from doped.generation import DefectsGenerator
from shakenbreak.input import Distortions
from pymatgen.entries.computed_entries import ComputedEntry
from doped.core import DefectEntry
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.core import Element

In [10]:
# Minimum periodic-image distance (Angstrom) passed to doped's supercell generation;
# also embedded in output file names.
supercell_size: float = 10.1

In [11]:
# Hubbard U values (eV) used by the Materials Project (MPRelaxSet) for oxides and
# fluorides. MaterialsProject2020Compatibility compares an entry's U values against
# these, so they must match exactly (Mn is 3.9 there; 3.98 is rejected).
MP_U_VALUES = {
    "Co": 3.32,
    "Cr": 3.7,
    "Fe": 5.3,
    "Mn": 3.9,
    "Mo": 4.38,
    "Ni": 6.2,
    "V": 3.25,
    "W": 6.2,
}
from pymatgen.core import Species

# MP applies +U only when the most electronegative element is O or F. Those are the
# two most electronegative elements, so "contains O or F" is the same test.
_U_ANIONS = {"O", "F"}


def add_parameters(entry):
    """Attach MP-style calculation parameters so MP2020 corrections can be applied.

    Sets entry.parameters to GGA+U (with per-element U values, 0.0 for elements
    without U) when the composition contains a U element AND O or F. Otherwise it
    is set to plain GGA.

    Args:
        entry: ComputedEntry whose composition may contain Element or Species.

    Returns:
        The same entry, mutated in place.
    """
    elements = [el.element if isinstance(el, Species) else el for el in entry.composition.elements]
    symbols = {el.symbol for el in elements}
    needs_u = bool(symbols & _U_ANIONS) and bool(symbols.intersection(MP_U_VALUES))
    if needs_u:
        parameters = {
            "run_type": "GGA+U",
            "hubbards": {el.name: MP_U_VALUES.get(el.symbol, 0.0) for el in elements},
            "is_hubbard": True,
        }
    else:
        parameters = {
            "run_type": "GGA",
            "hubbards": None,
            "is_hubbard": False,
        }
    entry.parameters = parameters
    return entry

In [12]:
def create_defect(structure, dopant="Cr", host="Al", min_image_distance=None):
    """Generate neutral dopant-on-host substitutions and their ShakeNBreak distortions.

    Args:
        structure: relaxed primitive host structure (pymatgen Structure).
        dopant: extrinsic element passed to doped (default "Cr").
        host: host element being substituted (default "Al"), i.e. "Cr_Al" defects.
        min_image_distance: supercell minimum image distance in Angstrom; defaults to
            the notebook-level ``supercell_size``.

    Returns:
        (unperturbed_structures, defect_distortions, bulk_supercell). The first two
        dicts are keyed by doped defect-entry name (e.g. "Cr_Al_0"). Both are empty
        when doped generates no neutral substitution of this type.
    """
    if min_image_distance is None:
        min_image_distance = supercell_size

    defect_gen = DefectsGenerator(
        structure,
        extrinsic=dopant,
        interstitial_gen_kwargs=False,  # no interstitials, only vacancies/substitutions
        supercell_gen_kwargs={'min_image_distance': min_image_distance},
    )

    # Neutral entries of this substitution, e.g. "Cr_Al_0" or site-labelled "Cr_Al_<site>_0".
    prefix = f"{dopant}_{host}_"
    valid_keys = [key for key in defect_gen.keys() if key.startswith(prefix) and key.endswith("_0")]
    if not valid_keys:
        # e.g. Al2(SO4)3 has no neutral Cr_Al entry: nothing to distort.
        return {}, {}, defect_gen.bulk_supercell

    distortions = Distortions(
        defect_entries={key: defect_gen[key] for key in valid_keys},
        # oxidation_states={} # predicted oxidation states are good for now
    )
    distorted_defects_dict, _distortion_metadata = distortions.apply_distortions()

    unperturbed_structures = {}
    defect_distortions = {}
    for key in valid_keys:
        # The distortion output is looked up by the name without the "_0" charge
        # suffix; charge key 0 selects the neutral state.
        charge_structures = distorted_defects_dict[key[:-2]]['charges'][0]['structures']
        unperturbed_structures[key] = charge_structures['Unperturbed']
        defect_distortions[key] = charge_structures['distortions']

    return unperturbed_structures, defect_distortions, defect_gen.bulk_supercell

In [13]:
def make_entry(structure, energy, material_id):
    """Wrap an MLIP energy in a ComputedEntry and apply MP2020 corrections.

    The structure is attached as ``entry.structure``. MP-style parameters are
    added via add_parameters. MaterialsProject2020Compatibility then processes
    the entry in place; on failure it only warns, and the entry is still returned.

    Args:
        structure: relaxed pymatgen Structure.
        energy: total energy to store on the entry.
        material_id: stored in the entry's data dict under 'material_id'.

    Returns:
        The ComputedEntry.
    """
    entry = add_parameters(
        _attach_structure(
            ComputedEntry(structure.composition, energy, data={'material_id': material_id}),
            structure,
        )
    )
    MaterialsProject2020Compatibility(
        check_potcar=False, check_potcar_hash=False
    ).process_entry(entry, on_error="warn")
    return entry


def _attach_structure(entry, structure):
    """Store the structure on the entry so later code can read entry.structure."""
    entry.structure = structure
    return entry

In [14]:
def _distortion_amount(label):
    """Convert a distortion label like 'Bond_Distortion_-30.0%' to a fraction (-0.3).

    'Rattled' (no bond distortion) maps to 0.
    """
    if label == 'Rattled':
        return 0
    return float(label.split("_")[-1].replace("%", "")) / 100


def get_energy(unperturbed_structures, distortions, supercell, material_id):
    """Relax the host supercell and every defect configuration, then collect energies.

    All relaxations use relax_defect (fixed lattice). Energies are the ComputedEntry
    energies produced by make_entry.

    Args:
        unperturbed_structures: {defect name: unperturbed defect supercell} from create_defect.
        distortions: {defect name: {distortion label: distorted structure}} from create_defect.
        supercell: pristine host supercell.
        material_id: id stored on every ComputedEntry.

    Returns:
        {'supercell': {'structure', 'energy'},
         <defect name>: {'unperturbed': {'structure', 'energy', 'distortion_amount'},
                         'distortions': {<label>: {'structure', 'energy',
                                                   'distortion_amount', 'delta_e'}}}}
        where 'delta_e' is the distorted minus the unperturbed energy for the same site.
    """
    def relaxed_entry(structure):
        # Fixed-lattice MLIP relaxation followed by entry creation/correction.
        relaxed_structure, energy = relax_defect(structure)
        return make_entry(relaxed_structure, energy, material_id)

    supercell_entry = relaxed_entry(supercell)
    data = {
        'supercell': {
            'structure': supercell_entry.structure,
            'energy': supercell_entry.energy,
        }
    }

    for site, unperturbed_structure in unperturbed_structures.items():
        unperturbed_entry = relaxed_entry(unperturbed_structure)
        unperturbed_energy = unperturbed_entry.energy
        site_data = {
            'unperturbed': {
                'structure': unperturbed_entry.structure,
                'energy': unperturbed_energy,
                'distortion_amount': 'N/A',
            },
            'distortions': {},
        }
        for label, structure in distortions[site].items():
            distorted_entry = relaxed_entry(structure)
            site_data['distortions'][label] = {
                'structure': distorted_entry.structure,
                'energy': distorted_entry.energy,
                'distortion_amount': _distortion_amount(label),
                'delta_e': distorted_entry.energy - unperturbed_energy,
            }
        data[site] = site_data

    return data

In [15]:
# Relaxed primitive hosts written earlier in this notebook, keyed by material id.
relaxed_prim_path = f'../../data/{MODEL}/{MODEL}_relaxed_Al_prim.json'
Al_prim = loadfn(relaxed_prim_path)

In [21]:
# For each host: build Cr_Al defects + distortions, relax everything, store under 'defects'.
for mp_id, compound in Al_prim.items():
    unperturbed_structures, distortions, supercell = create_defect(compound['structure'])
    compound['defects'] = get_energy(unperturbed_structures, distortions, supercell, mp_id)

Generating DefectEntry objects: 100.0%|████████████████████████████████████████████████████████████| [00:36,   2.71it/s]


Vacancies    Guessed Charges    Conv. Cell Coords    Wyckoff
-----------  -----------------  -------------------  ---------
v_Mn         [+1,0,-1,-2]       [0.000,0.000,0.000]  8a
v_Al         [+1,0,-1,-2,-3]    [0.625,0.625,0.625]  16d
v_O          [+2,+1,0,-1]       [0.392,0.392,0.392]  32e

Substitutions    Guessed Charges     Conv. Cell Coords    Wyckoff
---------------  ------------------  -------------------  ---------
Mn_Al            [+1,0,-1]           [0.625,0.625,0.625]  16d
Mn_O             [+5,+4,+3,+2,+1,0]  [0.392,0.392,0.392]  32e
Al_Mn            [+1,0]              [0.000,0.000,0.000]  8a
Al_O             [+5,+4,+3,+2,+1,0]  [0.392,0.392,0.392]  32e
O_Mn             [0,-1,-2,-3,-4]     [0.000,0.000,0.000]  8a
O_Al             [0,-1,-2,-3,-4,-5]  [0.625,0.625,0.625]  16d
Cr_Mn            [+4,+3,+2,+1,0]     [0.000,0.000,0.000]  8a
Cr_Al            [+3,+2,+1,0,-1]     [0.625,0.625,0.625]  16d
Cr_O             [+5,+4,+3,+2,+1,0]  [0.392,0.392,0.392]  32e

The number in t

Generating distorted defect structures...0.0%|                                                         | [00:00,  ?it/s]

[1m
Defect: Cr_Al[0m
[1mNumber of missing electrons in neutral state: 0[0m

Defect Cr_Al in charge state: 0. Number of distorted neighbours: 0


Generating distorted defect structures...100.0%|███████████████████████████████████████████████████| [00:00,   4.22it/s]


       Step     Time          Energy          fmax
LBFGS:    0 16:15:47     -637.992589        0.002540




       Step     Time          Energy          fmax
LBFGS:    0 16:15:50     -638.893886        1.511544
LBFGS:    1 16:15:54     -639.055164        0.993409
LBFGS:    2 16:15:58     -639.209133        0.257964
LBFGS:    3 16:16:02     -639.228588        0.212619
LBFGS:    4 16:16:06     -639.245067        0.192317
LBFGS:    5 16:16:10     -639.253032        0.162471
LBFGS:    6 16:16:14     -639.258854        0.131837
LBFGS:    7 16:16:18     -639.262493        0.138895
LBFGS:    8 16:16:22     -639.265606        0.114248
LBFGS:    9 16:16:26     -639.267887        0.066589
LBFGS:   10 16:16:30     -639.269352        0.035138
LBFGS:   11 16:16:34     -639.270244        0.028556
LBFGS:   12 16:16:38     -639.270809        0.033967
LBFGS:   13 16:16:42     -639.271175        0.033382
LBFGS:   14 16:16:46     -639.271427        0.017055
LBFGS:   15 16:16:50     -639.271564        0.012358
LBFGS:   16 16:16:54     -639.271633        0.014939
LBFGS:   17 16:16:58     -639.271663        0.01



       Step     Time          Energy          fmax
LBFGS:    0 16:17:11     -577.422138       17.509613
LBFGS:    1 16:17:15     -605.808315        6.822168
LBFGS:    2 16:17:19     -625.884228        3.812242
LBFGS:    3 16:17:23     -632.166546        3.504947
LBFGS:    4 16:17:27     -634.703449        2.561366
LBFGS:    5 16:17:31     -636.686771        1.920092
LBFGS:    6 16:17:34     -637.772358        1.119458
LBFGS:    7 16:17:38     -638.374248        0.872125
LBFGS:    8 16:17:42     -638.726078        0.650965
LBFGS:    9 16:17:46     -638.964199        0.458036
LBFGS:   10 16:17:50     -639.077167        0.360105
LBFGS:   11 16:17:54     -639.142536        0.356502
LBFGS:   12 16:17:58     -639.185581        0.258393
LBFGS:   13 16:18:02     -639.213680        0.166879
LBFGS:   14 16:18:06     -639.231960        0.147355
LBFGS:   15 16:18:10     -639.243343        0.131382
LBFGS:   16 16:18:14     -639.250538        0.096999
LBFGS:   17 16:18:18     -639.255779        0.08



In [22]:
# Record, per host, the lowest-energy configuration over all sites (unperturbed or
# distorted). Hosts already summarised are skipped, so re-running is harmless.
for mp_id, compound in Al_prim.items():
    if 'minimum energy site' in compound:
        continue

    best_energy, best_location = float('inf'), ''
    for site, site_data in compound['defects'].items():
        if site in ('supercell', 'formula'):
            continue
        for label, payload in site_data.items():
            if label == 'unperturbed':
                if payload['energy'] < best_energy:
                    best_energy, best_location = payload['energy'], [site, label]
                continue
            # label == 'distortions': payload maps distortion label -> result
            for dist_label, result in payload.items():
                if result['energy'] < best_energy:
                    best_energy, best_location = result['energy'], [site, label, dist_label]

    compound['minimum energy site'] = {
        'energy': best_energy,
        'information': best_location,
    }

In [28]:
# Persist hosts + defect energies; the file name encodes the model and supercell size.
defect_energies_path = f'../../data/{MODEL}/{MODEL}_corrected_doped_defect_energies_{supercell_size}A.json'
dumpfn(Al_prim, defect_energies_path)

In [None]:
# Reload the saved results so this section runs without redoing the relaxations.
data_load = loadfn(f'../../data/{MODEL}/{MODEL}_corrected_doped_defect_energies_{supercell_size}A.json')

compounds = []
delta_e = []
for mp_id, compound in data_load.items():
    # Pick the lowest-energy *distorted* configuration over all Cr_Al sites. Its
    # delta_e is relative to that site's unperturbed relaxation (may be + or -).
    min_energy = float('inf')
    min_delta_e = None
    for site, site_data in compound['defects'].items():
        if site == 'supercell':
            continue
        for distortion in site_data['distortions'].values():
            if distortion['energy'] < min_energy:
                min_energy = distortion['energy']
                min_delta_e = distortion['delta_e']
    if min_delta_e is None:
        # No defect sites/distortions for this host: skip rather than reuse a stale
        # value from the previous host, keeping both lists aligned.
        continue
    delta_e.append(min_delta_e)
    compounds.append(compound['formula'])

In [28]:
import matplotlib.pyplot as plt

In [None]:
fig, ax = plt.subplots()

# Plot at integer positions: polymorphs sharing a formula would otherwise be
# merged onto one categorical tick and their bars drawn on top of each other.
positions = list(range(len(compounds)))
bars = ax.bar(positions, delta_e)
ax.bar_label(bars, fmt='%.3f', fontsize=6)  # value on top of each bar
ax.set_xticks(positions)
ax.set_xticklabels(compounds, rotation=90)

# delta_e is a total supercell energy difference, not normalised per atom.
ax.set_title(f"Energy Difference of Rattled Structure with Respect to Unperturbed Structure with {MODEL}")
ax.set_ylabel(r"$\Delta E$ (eV)")
fig.tight_layout()

figure_path = f"../../figures/{MODEL}/{MODEL}_corrected_energy_difference_plot_{supercell_size}A.png"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
fig.savefig(figure_path, dpi=300, bbox_inches='tight')

plt.show()