Notebook adapted from: https://github.com/WMD-group/SMACT/blob/master/docs/tutorials/filtering_icsd_oxidation_states.ipynb

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/WMD-group/SMACT/blob/master/docs/tutorials/filtering_icsd_oxidation_states.ipynb)

In [100]:
import json
import math
import re
from collections import defaultdict
from itertools import combinations_with_replacement, product
from pathlib import Path

import numpy as np
import pandas as pd
import smact
from pymatgen.core import Composition, Element, Species
from smact.utils.oxidation import ICSD24OxStatesFilter

from lemat_genbench.utils.oxidation_state import (
    build_oxi_dict,
    build_oxi_dict_probs,
    build_oxi_state_map,
    build_sorted_oxi_dict,
)

In [216]:
import smact
from smact.screening import smact_filter, smact_validity

In [214]:
space = smact.element_dictionary(["Mn"])
smact_elems = [e[1] for e in space.items()]

In [215]:
smact_elems[0].oxidation_states 

[-1, 1, 2, 3, 4, 5, 6, 7]

In [197]:


space = smact.element_dictionary(elem_symbols)
smact_elems = [e[1] for e in space.items()]

NameError: name 'elem_symbols' is not defined

In [2]:
# Initialise the oxidation state filter
ox_filter = ICSD24OxStatesFilter()

The above code output presents a Dataframe of the occurrences of elements in particular oxidation states mined from the ICSD in September 2024. Here the results count refers to the number of structures in the ICSD that contain the element in a given oxidation state. The mininig process was quite simple, as such there are many 0 values in the table. We can use the built-in method `get_species_occurrences_df` to return a dataframe of the elements, the ionic species reported for a given element, the number of occurrences of a species in the ICSD and the species' proportion with respect to all ionic species of a given element.

In [3]:
# Return the dataframe with non-zero results
test = ox_filter.get_species_occurrences_df(sort_by_occurrences=False)

In [4]:
test.element.unique()

array(['H', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si',
       'P', 'S', 'Cl', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co',
       'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr',
       'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In',
       'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd',
       'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu',
       'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb',
       'Bi', 'Po', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm',
       'Bk', 'Cf', 'Es'], dtype=object)

In [6]:
test["species_proportion_fraction"] = test["species_proportion (%)"]/100

In [7]:
test_subcols = test[["species", "species_proportion_fraction"]]

In [8]:
icsd_dict = test_subcols.set_index("species")["species_proportion_fraction"].to_dict()

In [11]:
for i in range(len(test)):
    print(test.iloc[i])

element                               H
species                              H-
results_count                      3442
species_proportion (%)         9.136274
species_proportion_fraction    0.091363
Name: 0, dtype: object
element                                H
species                               H+
results_count                      34232
species_proportion (%)         90.863726
species_proportion_fraction     0.908637
Name: 1, dtype: object
element                           Li
species                          Li+
results_count                  10513
species_proportion (%)         100.0
species_proportion_fraction      1.0
Name: 2, dtype: object
element                           Be
species                         Be2+
results_count                   1131
species_proportion (%)         100.0
species_proportion_fraction      1.0
Name: 3, dtype: object
element                               B
species                             B5-
results_count                         3
species_propo

In [13]:
with open("icsd_oxi_dict_probs.json", "w") as f:
    json.dump(icsd_dict, f, indent=4)

In [19]:
pattern = re.compile(r"([A-Za-z]+)(\d*)([+-])")

In [36]:
oxi_state_mapping = defaultdict(list)

In [39]:
for species in icsd_dict.keys():
    match = pattern.fullmatch(species)
    element, number, sign = match.groups()
    try:
        number = int(number)
    except ValueError:
        number = 1 
    if sign == "+":
        charge = number
    else:
        charge = -number
    oxi_state_mapping[element].append(charge)

In [44]:
with open("icsd_oxi_state_mapping.json", "w") as f:
    json.dump(oxi_state_mapping, f, indent=4)

In [173]:
with open("oxi_state_mapping.json", "rb") as f:
    lemat_oxi_state_mapping = json.load(f)

In [174]:
with open("oxi_dict_probs.json", "rb") as f:
    lemat_oxi_dict_probs = json.load(f)

In [176]:
icsd_dict

{'H-': 0.09136274353665658,
 'H+': 0.9086372564633433,
 'Li+': 1.0,
 'Be2+': 1.0,
 'B5-': 0.0004818503051718599,
 'B3-': 0.03726309026662383,
 'B2-': 0.0059428204304529395,
 'B-': 0.006103437198843559,
 'B+': 0.0056215868936717,
 'B2+': 0.004497269514937359,
 'B3+': 0.9400899453902987,
 'C4-': 0.0572118650464186,
 'C3-': 0.011849950939693564,
 'C2-': 0.09306362744358065,
 'C-': 0.02075628349309382,
 'C+': 0.016755981583515735,
 'C2+': 0.22590384179938108,
 'C3+': 0.10113970865725715,
 'C4+': 0.4733187410370594,
 'N5-': 0.0005673436968115284,
 'N4-': 0.00039714058776806993,
 'N3-': 0.7933166912515602,
 'N2-': 0.024452513332576874,
 'N-': 0.03670713718370589,
 'N+': 0.02121865426075116,
 'N2+': 0.006808124361738341,
 'N3+': 0.042040167933734246,
 'N4+': 0.0012481561329853626,
 'N5+': 0.07324407125836832,
 'O2-': 0.9955718300263987,
 'O-': 0.00440262283913821,
 'O2+': 2.5547134463084392e-05,
 'F-': 1.0,
 'Na+': 1.0,
 'Mg2+': 1.0,
 'Al3-': 0.00019969380283565196,
 'Al3+': 0.999800306197164

In [175]:
for key in oxi_state_mapping:
    if key in lemat_oxi_state_mapping.keys():
        pass
    else:
        lemat_oxi_state_mapping[key] = oxi_state_mapping[key]
        for charge in oxi_state_mapping[key]:
            new_key = str(key)
            if charge > 0:
                if charge > 1:
                    new_key += str(charge)
                new_key += "+"
            if charge < 0:
                if charge < -1: 
                    new_key += str(np.abs(charge))
                new_key += "-"
            print(new_key)
            lemat_oxi_dict_probs[new_key] = icsd_dict[new_key]

Kr2+
Tc+
Tc2+
Tc3+
Tc4+
Tc5+
Tc6+
Tc7+
Xe2+
Xe4+
Xe6+
Xe8+
Pm3+
Os+
Os2+
Os3+
Os4+
Os5+
Os6+
Os7+
Os8+
Pt2-
Pt+
Pt2+
Pt3+
Pt4+
Pt5+
Pt6+
Au-
Au+
Au2+
Au3+
Au5+
Po4+
Ra2+
Ac3+
Pa3+
Pa4+
Pa5+
Np2+
Np3+
Np4+
Np5+
Np6+
Np7+
Pu2+
Pu3+
Pu4+
Pu5+
Pu6+
Pu7+
Am2+
Am3+
Am4+
Am5+
Am6+
Cm3+
Cm4+
Bk3+
Bk4+
Cf3+
Es3+


In [179]:
with open("lemat_icsd_oxi_state_mapping.json", "w") as f:
    json.dump(lemat_oxi_state_mapping, f, indent=4)

In [180]:
with open("lemat_icsd_oxi_dict_probs.json", "w") as f:
    json.dump(lemat_oxi_dict_probs, f, indent=4)

In [181]:
with open("lemat_icsd_oxi_state_mapping.json", "rb") as f:
    lemat_oxi_state_mapping = json.load(f)

In [182]:
with open("lemat_icsd_oxi_dict_probs.json", "rb") as f:
    lemat_oxi_dict_probs = json.load(f)

In [189]:
def compositional_oxi_state_guesses(
    comp,
    all_oxi_states: bool,
    max_sites: int | None,
    oxi_states_override: dict[str, list] | None,
    target_charge: float,
) -> tuple[tuple, tuple, tuple]:
    """Utility operation for guessing oxidation states. 
    Adapted from the _get_oxi_state_guesses function from Pymatgen.core.Composition

    See `oxi_state_guesses` for full details. This operation does the
    calculation of the most likely oxidation states

    Args:
        comp: A Pymatgen composition object.
        oxi_states_override (dict): dict of str->list to override an element's common oxidation states, e.g.
            {"V": [2,3,4,5]}.
        target_charge (float): the desired total charge on the structure. Default is 0 signifying charge balance.
        all_oxi_states (bool): if True, all oxidation states of an element, even rare ones, are used in the search
            for guesses. However, the full oxidation state list is *very* inclusive and can produce nonsensical
            results. If False, the icsd_oxidation_states list is used when present, or the common_oxidation_states
            is used when icsd_oxidation_states is not present. These oxidation states lists comprise more
            commonly occurring oxidation states and results in more reliable guesses, albeit at the cost of
            missing some uncommon situations. The default is False.
        max_sites (int): if possible, will reduce Compositions to at most
            this many sites to speed up oxidation state guesses. If the
            composition cannot be reduced to this many sites a ValueError
            will be raised. Set to -1 to just reduce fully. If set to a
            number less than -1, the formula will be fully reduced but a
            ValueError will be thrown if the number of atoms in the reduced
            formula is greater than abs(max_sites).

    Returns:
        list[dict]: Each dict maps the element symbol to a list of
            oxidation states for each site of that element. For example, Fe3O4 could
            return a list of [2,2,2,3,3,3] for the oxidation states of the 6 Fe sites.
            If the composition is not charge balanced, an empty list is returned.
    """
    # Reduce Composition if necessary
    if max_sites and max_sites < 0:
        comp = comp.reduced_composition

        if max_sites < -1 and comp.num_atoms > abs(max_sites):
            raise ValueError(
                f"Composition {comp} cannot accommodate max_sites setting!"
            )

    elif max_sites and comp.num_atoms > max_sites:
        reduced_comp, reduced_factor = comp.get_reduced_composition_and_factor()
        if reduced_factor > 1:
            reduced_comp *= max(1, int(max_sites / reduced_comp.num_atoms))
            comp = reduced_comp  # as close to max_sites as possible
        if comp.num_atoms > max_sites:
            raise ValueError(
                f"Composition {comp} cannot accommodate max_sites setting!"
            )

    # Load prior probabilities of oxidation states, used to rank solutions

    with open("lemat_icsd_oxi_dict_probs.json", "r") as f:
        loaded_dict = json.load(f)
    type(comp).oxi_prob = loaded_dict
    oxi_states_override = oxi_states_override or {}
    # Assert Composition only has integer amounts
    if not all(amt == int(amt) for amt in comp.values()):
        raise ValueError(
            "Charge balance analysis requires integer values in Composition!"
        )

    # For each element, determine all possible sum of oxidations
    # (taking into account nsites for that particular element)
    el_amt = comp.get_el_amt_dict()
    n_sites = int(sum(el_amt.values()))
    elements = list(el_amt)
    el_sums: list = []  # matrix: dim1= el_idx, dim2=possible sums
    el_sum_scores: defaultdict = defaultdict(set)  # dict of el_idx, sum -> score
    el_best_oxid_combo: dict = {}  # dict of el_idx, sum -> oxid combo with best score
    for idx, el in enumerate(elements):
        el_sum_scores[idx] = {}
        el_best_oxid_combo[idx] = {}
        el_sums.append([])
        if oxi_states_override.get(el):
            oxids: list | tuple = oxi_states_override[el]
        elif all_oxi_states:
            oxids = Element(el).oxidation_states
        else:
            oxids = (
                Element(el).icsd_oxidation_states or Element(el).common_oxidation_states
            )

        # Get all possible combinations of oxidation states
        # and sum each combination
        for oxid_combo in combinations_with_replacement(oxids, int(el_amt[el])):
            # check to make sure none of the oxidation states deviate by more than 1 
            if max(oxid_combo) - min(oxid_combo) <= 1: 
                # print(oxid_combo)
                # List this sum as a possible option
                oxid_sum = sum(oxid_combo)
                if oxid_sum not in el_sums[idx]:
                    el_sums[idx].append(oxid_sum)
                # Determine how probable is this combo?
                scores = []
                for o in oxid_combo:
                    scores.append(type(comp).oxi_prob[str(Species(el, o))])
                # print(scores)
                score = math.prod(scores)
                # print(score)
                # If it is the most probable combo for a certain sum,
                # store the combination
                if oxid_sum not in el_sum_scores[idx] or score > el_sum_scores[idx].get(
                    oxid_sum, 0
                ):
                    if max(oxid_combo) - min(oxid_combo) > 1:
                        pass
                    else:
                        el_sum_scores[idx][oxid_sum] = score
                        el_best_oxid_combo[idx][oxid_sum] = oxid_combo
            else:
                pass
    
    # for i in el_sum_scores:
        # el_sum_scores[i] = {k: v / sum(el_sum_scores[i].values()) for k, v in el_sum_scores[i].items()}
        # print(sum(el_sum_scores[i].values()))
    
    # Determine which combination of oxidation states for each element
    # is the most probable

    el_sums = [[x for x in sublist if x != 0] for sublist in el_sums]
    
    all_sols = []  # will contain all solutions
    all_oxid_combo = []  # will contain the best combination of oxidation states for each site
    all_scores = []  # will contain a score for each solution
    for x in product(*el_sums):
        # Each x is a trial of one possible oxidation sum for each element
        if sum(x) == target_charge:  # charge balance condition
            el_sum_sol = dict(zip(elements, x, strict=True))  # element->oxid_sum
            # Normalize oxid_sum by amount to get avg oxid state
            sol = {el: v / el_amt[el] for el, v in el_sum_sol.items()}
            # Add the solution to the list of solutions

                        
            all_sols.append(sol)

            # Determine the score for this solution
            scores = []
            for idx, v in enumerate(x):
                scores.append(el_sum_scores[idx][v])
            # the score is the minimum of the scores of each of the oxidation states in the composition - the goal is to find a charge
            # balanced oxidation state which limits the occurance of very uncommon oxidation states
            # print(n_sites)
            print(scores)
            all_scores.append(math.prod(scores)**(1/n_sites))
            # all_scores.append(np.log(math.prod(scores))/n_sites)
            # Collect the combination of oxidation states for each site
            all_oxid_combo.append(
                {
                    e: el_best_oxid_combo[idx][v]
                    for idx, (e, v) in enumerate(zip(elements, x, strict=True))
                }
            )
    # Sort the solutions from highest to lowest score
    if all_scores:
        all_sols, all_oxid_combo = zip(
            *(
                (y, x)
                for (z, y, x) in sorted(
                    zip(all_scores, all_sols, all_oxid_combo, strict=True),
                    key=lambda pair: pair[0],
                    reverse=True,
                )
            ),
            strict=True,
        )
    return (
        tuple(all_sols),
        tuple(all_oxid_combo),
        tuple(sorted(all_scores, reverse=True)),
    )


In [190]:
with open("lemat_icsd_oxi_dict_probs.json", "r") as f:
    oxi_state_mapping = json.load(f)

In [194]:
oxi_state_mapping["Mo"]

[2, 3, 4, 5, 6]

In [191]:
comp = Composition("MoOsP2")

In [192]:
oxi_states_override = {}
for e in comp.elements:
    if str(e) in oxi_state_mapping:
        oxi_states_override[str(e)] = oxi_state_mapping[str(e)]

output = compositional_oxi_state_guesses(
    comp,
    all_oxi_states=False,
    max_sites=-1,
    target_charge=0,
    oxi_states_override=oxi_states_override,
)

[0.1436329560379558, 0.01845018450184502, 0.012271096619922659]
[0.1436329560379558, 0.13653136531365315, 0.012614544691108853]
[0.1436329560379558, 0.08487084870848709, 0.04071830247243536]
[0.1436329560379558, 0.1660516605166052, 0.13143400707956862]
[0.23576287944748608, 0.01845018450184502, 0.012614544691108853]
[0.23576287944748608, 0.13653136531365315, 0.04071830247243536]
[0.23576287944748608, 0.08487084870848709, 0.13143400707956862]
[0.22736596667798062, 0.01845018450184502, 0.04071830247243536]
[0.22736596667798062, 0.13653136531365315, 0.13143400707956862]
[0.12469931968773519, 0.01845018450184502, 0.13143400707956862]


In [193]:
output

(({'Mo': 4.0, 'Os': 2.0, 'P': -3.0},
  {'Mo': 2.0, 'Os': 4.0, 'P': -3.0},
  {'Mo': 3.0, 'Os': 3.0, 'P': -3.0},
  {'Mo': 3.0, 'Os': 2.0, 'P': -2.5},
  {'Mo': 2.0, 'Os': 3.0, 'P': -2.5},
  {'Mo': 5.0, 'Os': 1.0, 'P': -3.0},
  {'Mo': 2.0, 'Os': 2.0, 'P': -2.0},
  {'Mo': 4.0, 'Os': 1.0, 'P': -2.5},
  {'Mo': 3.0, 'Os': 1.0, 'P': -2.0},
  {'Mo': 2.0, 'Os': 1.0, 'P': -1.5}),
 ({'Mo': (4,), 'Os': (2,), 'P': (-3, -3)},
  {'Mo': (2,), 'Os': (4,), 'P': (-3, -3)},
  {'Mo': (3,), 'Os': (3,), 'P': (-3, -3)},
  {'Mo': (3,), 'Os': (2,), 'P': (-2, -3)},
  {'Mo': (2,), 'Os': (3,), 'P': (-2, -3)},
  {'Mo': (5,), 'Os': (1,), 'P': (-3, -3)},
  {'Mo': (2,), 'Os': (2,), 'P': (-2, -2)},
  {'Mo': (4,), 'Os': (1,), 'P': (-2, -3)},
  {'Mo': (3,), 'Os': (1,), 'P': (-2, -2)},
  {'Mo': (2,), 'Os': (1,), 'P': (-1, -2)}),
 (0.2527355938996254,
  0.2366199005636795,
  0.22645683341162362,
  0.19027178002738218,
  0.14926245381428271,
  0.131869128660858,
  0.1254121253513949,
  0.11432177500060454,
  0.086067052486061