In [None]:
#!/usr/bin/env python3
import copy
import os

import numpy as np
import matplotlib.pyplot as plt

from mp_api.client import MPRester
from pymatgen.core import Composition, Element
from pymatgen.analysis import pourbaix_diagram
from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter, PourbaixEntry

# ============================  USER INPUT  =======================================================
# Element symbols, comma-separated. Whitespace around each symbol is stripped and empty items are
# dropped, so "W, Al" gives ["W", "Al"] rather than ["W", " Al"].
elements = [s.strip() for s in input("Enter elements (comma-separated): ").split(",") if s.strip()]
if not elements:
    raise ValueError("Enter at least one element symbol.")
T = float(input("Enter temperature in K: "))
if not T > 0.0:  # also rejects NaN
    raise ValueError("Temperature must be > 0 K.")
conc = float(input("Enter concentration in mol/L (default = 1e-6): ") or "1e-6")
if not conc > 0.0:
    # Ion concentrations enter the energies through log10(concentration).
    raise ValueError("Concentration must be > 0 mol/L.")
# Read the Materials Project API key from the environment instead of committing it to source.
# If MP_API_KEY is unset, None is passed on and MPRester resolves a key from its own
# configuration (exact lookup depends on the installed mp_api version).
# NOTE(review): the key that used to be hard-coded here is exposed and should be revoked.
MP_API_KEY = os.environ.get("MP_API_KEY")

# ============================  CONSTANTS  ========================================================
R = 8.314462618        # molar gas constant, J/(mol*K)
F = 96485.33212        # Faraday constant, C/mol
KJMOL_PER_EV = 96.485  # kJ/mol corresponding to 1 eV per formula unit
T_REF = 298.15         # reference temperature, K


def is_ref(TK: float) -> bool:
    """Return True when TK equals the reference temperature T_REF to within 0.5 mK."""
    offset = TK - T_REF
    return -5e-4 < offset < 5e-4

# Nernst slope ln(10)*R*T/F (V per pH unit). At the reference temperature pymatgen's default
# 0.0591 V is kept so the 298.15 K diagram is unchanged; at any other T the exact
# temperature-dependent slope is used (ln(10) rather than the truncated 2.303).
pourbaix_diagram.PREFAC = 0.0591 if is_ref(T) else (np.log(10.0) * R * T / F)

# ============================  NUMERIC HELPERS  ==================================================
def cp_poly(A, B=0, C=0, D=0, E=0):
    """Build a heat-capacity callable Cp(T) = A + B*T + C*T**2 + D*T**3 + E/T**2.

    Coefficients are in whatever units the caller wants Cp in (the databases in this file use
    J/(mol*K)). T is converted to float and clamped to at least 1e-6 K so the E/T**2 term stays
    finite at T = 0.
    """
    def _cp_at(T):
        temp = max(float(T), 1e-6)
        polynomial_part = A + B * temp + C * temp * temp + D * temp**3
        return polynomial_part + E / (temp * temp)

    return _cp_at

def H_increment_kJmol(T_to, cp_func, Tref=T_REF, n=2001):
    """Return the enthalpy increment H(T_to) - H(Tref) = integral of Cp dT, in kJ/mol.

    Args:
        T_to: target temperature (K).
        cp_func: callable returning Cp in J/(mol*K) for a scalar temperature.
        Tref: lower/reference temperature (K), default T_REF.
        n: number of evenly spaced points used by the trapezoidal rule.

    Returns:
        The increment in kJ/mol; negative when T_to < Tref, 0.0 when they coincide.
    """
    if abs(T_to - Tref) < 1e-12:
        return 0.0
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    lo, hi = (Tref, T_to) if T_to > Tref else (T_to, Tref)
    grid = np.linspace(lo, hi, n)
    cp = np.array([cp_func(float(t)) for t in grid])  # J/(mol*K)
    dH = trapezoid(cp, grid) / 1000.0                 # J/mol -> kJ/mol
    return dH if T_to > Tref else -dH

def S_at_T_kJmolK(T_to, S298_kJmolK, cp_func, Tref=T_REF, n=2001):
    """Return the entropy at T_to, S(Tref) + integral of Cp/T dT, in kJ/(mol*K).

    Args:
        T_to: target temperature (K).
        S298_kJmolK: entropy at Tref in kJ/(mol*K).
        cp_func: callable returning Cp in J/(mol*K) for a scalar temperature.
        Tref: reference temperature (K), default T_REF.
        n: number of evenly spaced points used by the trapezoidal rule.

    Returns:
        S(T_to) in kJ/(mol*K); S298_kJmolK unchanged when T_to equals Tref.
    """
    if abs(T_to - Tref) < 1e-12:
        return S298_kJmolK
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    lo, hi = (Tref, T_to) if T_to > Tref else (T_to, Tref)
    grid = np.linspace(lo, hi, n)
    cp_over_T = np.array([cp_func(float(t)) / float(t) for t in grid])  # J/(mol*K^2)
    dS = trapezoid(cp_over_T, grid) / 1000.0                           # J -> kJ
    return (S298_kJmolK + dS) if T_to > Tref else (S298_kJmolK - dS)

# ============================  STANDARD STATES / THERMO DBs  ====================================
# Elements whose reference state is a diatomic molecule: symbol -> (species, atoms per species).
# Every other element is its own reference species with one atom.
STD_STATE = {
    "H": ("H2", 2),
    "N": ("N2", 2),
    "O": ("O2", 2),
    "F": ("F2", 2),
    "Cl": ("Cl2", 2),
    "Br": ("Br2", 2),
    "I": ("I2", 2),
}


def standard_species_for_element(el: Element):
    """Return (reference species, atoms per species) for a pymatgen Element, e.g. O -> ("O2", 2)."""
    symbol = el.symbol
    if symbol in STD_STATE:
        return STD_STATE[symbol]
    return (symbol, 1)

element_db = {
    # Elemental reference species (metal symbols, or diatomic species from STD_STATE).
    # S298 in kJ/(mol*K); "cp" is a Cp(T) callable returning J/(mol*K).
    "Al":  {"S298": 0.02832, "cp": cp_poly(24.2)},
    "Si":  {"S298": 0.01881, "cp": cp_poly(22.8)},
    "O2":  {"S298": 0.20515, "cp": cp_poly(29.4, 1.0e-2)},
    "H2":  {"S298": 0.13068, "cp": cp_poly(28.8)},
    "N2":  {"S298": 0.19161, "cp": cp_poly(29.1, 0.5e-2)},
    "Cl2": {"S298": 0.22308, "cp": cp_poly(33.9)},
    "W":   {"S298": 0.032,   "cp": cp_poly(24.0)},
    "Hf":  {"S298": 0.0436,  "cp": cp_poly(25.7)},
    "Ti":  {"S298": 0.031,   "cp": cp_poly(25.0)},
    "Ni":  {"S298": 0.0299,  "cp": cp_poly(25.99)},
    "Cu":  {"S298": 0.033,   "cp": cp_poly(24.0)},
    # Fe(cr) at 298.15 K: S = 27.32 J/(mol*K), Cp = 25.09 J/(mol*K) (NIST-JANAF tables).
    "Fe":  {"S298": 0.02732, "cp": cp_poly(25.09)},
}

# Compound thermochemistry, one row per formula:
#   (formula, dHf298 [kJ/mol], S298 [kJ/(mol*K)], Cp(T) callable [J/(mol*K)])
compound_db = {
    formula: {"dHf298": dHf298, "S298": S298, "cp": cp}
    for formula, dHf298, S298, cp in (
        ("Al2O3", -1675.7,  0.0509,  cp_poly(79.0, 3.0e-2)),
        ("SiO2",   -910.9,  0.0415,  cp_poly(45.0, 2.0e-2)),
        ("Al",        0.0,  0.02832, cp_poly(24.2)),
        ("Si",        0.0,  0.01881, cp_poly(22.8)),
        ("H2O",    -285.83, 0.06995, cp_poly(75.3)),
        ("WO3",    -842.9,  0.076,   cp_poly(73.3)),
        ("WO2",    -589.7,  0.051,   cp_poly(55.74)),
        ("W",         0.0,  0.032,   cp_poly(24.0)),
    )
}

# Auxiliary species of the aqueous reactions, one row per species:
#   (name, S298 [kJ/(mol*K)], Cp(T) callable [J/(mol*K)])
# H+ and e- are given zero entropy and zero heat capacity.
species_db = {
    name: {"S298": S298, "cp": cp}
    for name, S298, cp in (
        ("H2O(l)", 0.06995, cp_poly(75.3)),
        ("H+(aq)", 0.0,     cp_poly(0.0)),
        ("e-",     0.0,     cp_poly(0.0)),
    )
}

# Explicit aqueous-ion thermo, keyed like ion_key_from_entry (reduced formula, no spaces):
#   (key, S298 [kJ/(mol*K)], Cp(T) callable [J/(mol*K)])
# Add explicit ion thermo here when available; ions not listed fall back to the Criss–Cobble
# estimate in update_ion_energy, which additionally requires an ion_charge entry.
ion_db = {
    key: {"S298": S298, "cp": cp}
    for key, S298, cp in (
        ("AlO2",    -0.029, cp_poly(-33.41)),
        ("Al",      -0.338, cp_poly(-122.65)),
        ("Al(HO)4",  0.117, cp_poly(-160)),
    )
}

# Ion charge magnitudes, keyed like ion_db. Signs are not stored (WO4 and Al are both positive
# here). Used by the Criss–Cobble fallback in update_ion_energy.
ion_charge = {
    "WO4": 2,
    "HWO4": 1,
    "W2O7": 2,
    "Al(HO)4": 1,
    "AlOH": 1,
    "Al": 3,
    "SiO4": 4,
    "Si(HO)4": 0,
    "SiH7O6": 1,
}

# ============================  CRISS–COBBLE support  ============================================
CAL_TO_KJ = 4.184e-3  # kJ per thermochemical calorie
# Criss–Cobble parameters for the mean ion heat capacity between the reference temperature and
# the tabulated upper temperature T_C (°C):  Cp_avg = alpha + beta * S298,
# with alpha and S298 in cal/(mol*K) and beta dimensionless.
_CC_TABLE = {
    "oxyanion":      {"T_C": [60, 100, 150, 200], "alpha": [-127, -138, -133, -145], "beta": [1.96, 2.24, 2.27, 2.53]},
    "acid_oxyanion": {"T_C": [60, 100, 150, 200], "alpha": [-122, -135, -143, -152], "beta": [3.44, 3.97, 3.95, 4.24]},
}


def _cc_alpha_beta(category: str, T2_K: float):
    """Return (alpha, beta) for `category` at T2_K, linearly interpolated in °C.

    np.interp clamps outside the tabulated 60–200 °C range: below 60 °C the 60 °C values are
    used, above 200 °C the 200 °C values. Raises KeyError for an unknown category.
    """
    T2_C = T2_K - 273.15
    tbl = _CC_TABLE[category]
    return float(np.interp(T2_C, tbl["T_C"], tbl["alpha"])), float(np.interp(T2_C, tbl["T_C"], tbl["beta"]))


def _cc_S298_cal(Z: int, n_oxygen: int):
    """Estimate an oxyanion's conventional S298 in cal/(mol*K).

    Z is the charge magnitude and n_oxygen the number of O atoms. The correlation coefficients
    are J/(mol*K) values (182.0 J ≈ 43.5 cal, 195.0 J ≈ 46.6 cal), so the J result is converted
    to cal here. Callers multiply this value by CAL_TO_KJ and use it in alpha + beta*S, both of
    which expect cal/(mol*K).
    """
    S298_J = 182.0 - 195.0 * (Z - 0.28 * n_oxygen)
    return S298_J / 4.184

# ============================  ΔfH°(T) and ΔfG°(T)  =============================================
def elemental_terms_from_formula(comp: Composition):
    """Express a formula as moles of its elements' reference species.

    Each element amount is converted to moles of its reference species via
    standard_species_for_element (e.g. 2 O -> 1 O2). Amounts of the same species are merged,
    and each species is paired with its thermo from element_db.

    Returns:
        List of dicts {"nu": moles per formula unit, "cp": Cp callable, "S298": kJ/(mol*K)},
        in first-seen species order.

    Raises:
        ValueError: if a reference species has no entry in element_db.
    """
    species_moles = {}
    for symbol, amount in comp.get_el_amt_dict().items():
        species, atoms_per_species = standard_species_for_element(Element(symbol))
        species_moles[species] = species_moles.get(species, 0.0) + amount / atoms_per_species

    missing = [species for species in species_moles if species not in element_db]
    if missing:
        raise ValueError(f"Missing elemental thermo for '{missing[0]}'")

    return [
        {"nu": nu, "cp": element_db[species]["cp"], "S298": element_db[species]["S298"]}
        for species, nu in species_moles.items()
    ]

def delta_f_H_T_kJmol(TK, dHf298_kJmol, cp_compound, elem_terms):
    """Return the enthalpy of formation at TK in kJ/mol (Kirchhoff's law).

    dfH(TK) = dfH(298) + [H(TK) - H(298)]_compound - sum(nu_i * [H(TK) - H(298)]_element_i),
    with elem_terms as produced by elemental_terms_from_formula.
    """
    compound_shift = H_increment_kJmol(TK, cp_compound)
    elements_shift = sum(term["nu"] * H_increment_kJmol(TK, term["cp"]) for term in elem_terms)
    return dHf298_kJmol + compound_shift - elements_shift

def delta_f_G_T(TK, dHf298_kJmol, cp_compound, S298_comp_kJmolK, elem_terms):
    """Return the Gibbs energy of formation at TK in kJ/mol.

    dfG(TK) = dfH(TK) - TK * [S_compound(TK) - sum(nu_i * S_element_i(TK))], where dfH comes
    from delta_f_H_T_kJmol and each entropy (kJ/(mol*K)) is carried from T_REF to TK by
    S_at_T_kJmolK.
    """
    enthalpy = delta_f_H_T_kJmol(TK, dHf298_kJmol, cp_compound, elem_terms)
    entropy_compound = S_at_T_kJmolK(TK, S298_comp_kJmolK, cp_compound)
    entropy_elements = sum(term["nu"] * S_at_T_kJmolK(TK, term["S298"], term["cp"]) for term in elem_terms)
    delta_S = entropy_compound - entropy_elements
    return enthalpy - TK * delta_S

def solid_dG_kJ_per_fu(TK, rf):
    """Return dfG(rf, TK) - dfG(rf, T_REF) in kJ/mol per formula unit of `rf`.

    `rf` is a reduced formula looked up in compound_db. Formulas without data return 0.0,
    i.e. they receive no temperature correction.
    """
    thermo = compound_db.get(rf)
    if thermo is None:
        return 0.0
    terms = elemental_terms_from_formula(Composition(rf))
    args = (thermo["dHf298"], thermo["cp"], thermo["S298"], terms)
    G_ref = delta_f_G_T(T_REF, *args)
    G_T = delta_f_G_T(TK, *args)
    return G_T - G_ref

# ============================  μ_H2O(T)  =========================================================
MU_H2O_298_eV = -2.4583  # H2O chemical potential at 298.15 K handed to the Pourbaix code, eV per H2O


def mu_H2O_T_eV(TK: float) -> float:
    """Return the H2O chemical potential (eV per H2O) at TK.

    At the reference temperature this is MU_H2O_298_eV. Otherwise MU_H2O_298_eV is shifted by
    dfG(H2O, TK) - dfG(H2O, T_REF), computed from compound_db["H2O"] and its elemental
    reference terms and converted from kJ/mol to eV.
    """
    if not is_ref(TK):
        terms = elemental_terms_from_formula(Composition("H2O"))
        water = compound_db["H2O"]
        args = (water["dHf298"], water["cp"], water["S298"], terms)
        shift_kJ = delta_f_G_T(TK, *args) - delta_f_G_T(T_REF, *args)
        return MU_H2O_298_eV + shift_kJ / KJMOL_PER_EV
    return MU_H2O_298_eV

# mu_H2O_T_eV already returns MU_H2O_298_eV at the reference temperature, so no special case here.
pourbaix_diagram.MU_H2O = mu_H2O_T_eV(T)
print(
    f"\nSet PREFAC = {pourbaix_diagram.PREFAC:.5f} V and MU_H2O = {pourbaix_diagram.MU_H2O:.6f} eV/H2O at T = {T} K"
)

# ============================  ION TEMPERATURE UPDATE  ==========================================
DEBUG_IONS = False        # set True to print diagnostics from the ion updater
_debug_seen_ions = set()  # keys already reported, so each message is printed at most once


def _debug_ion_once(key, msg):
    """Print msg once per key, and only while DEBUG_IONS is enabled."""
    if DEBUG_IONS and key not in _debug_seen_ions:
        print(msg)
        _debug_seen_ions.add(key)

def avg_cp_kJmolK(cp_func, T2, T1=T_REF, n=1501):
    """Return the mean heat capacity over [T1, T2] in kJ/(mol*K): (H(T2) - H(T1)) / (T2 - T1).

    cp_func returns Cp in J/(mol*K). Returns 0.0 (not Cp) when T2 and T1 coincide.
    """
    span = T2 - T1
    if abs(span) < 1e-12:
        return 0.0
    return H_increment_kJmol(T2, cp_func, Tref=T1, n=n) / span

def ion_key_from_entry(e):
    """Return the ion_db / ion_charge lookup key: the entry's reduced formula with spaces removed."""
    formula = e.composition.reduced_formula
    return "".join(formula.split(" "))

def metals_from_entry(e):
    """Return {element symbol: amount} for every non-H, non-O element of the entry.

    Amounts come from the entry's reduced composition; only positive amounts are kept.

    Raises:
        ValueError: if no such element is present.
    """
    el_amounts = e.composition.reduced_composition.get_el_amt_dict()
    metals = {
        symbol: 0.0 + amount  # float result, as the original accumulator produced
        for symbol, amount in el_amounts.items()
        if symbol not in ("H", "O") and amount > 0
    }
    if metals:
        return metals
    raise ValueError(f"No metal found in ion {e.composition}")

def _ion_S298_and_Cp_avg(entry, key, T2):
    """Return (S298, mean Cp over [T_REF, T2]) of the ion in kJ/(mol*K), or None if no data.

    Uses the ion_db record when it has a callable "cp". Otherwise falls back to the Criss–Cobble
    estimate, which needs the ion's charge magnitude from ion_charge. Returns None (after an
    optional debug message) when neither source is available.
    """
    ion_thermo = ion_db.get(key)
    if ion_thermo is not None and callable(ion_thermo.get("cp", None)):
        return ion_thermo["S298"], avg_cp_kJmolK(ion_thermo["cp"], T2)

    el_amounts = entry.composition.reduced_composition.get_el_amt_dict()
    n_O = int(round(el_amounts.get("O", 0.0)))
    category = "acid_oxyanion" if "H" in el_amounts else "oxyanion"
    Z = ion_charge.get(key)
    if Z is None:
        _debug_ion_once(key, f"[ION DEBUG] {key}: no ion_db & no charge → skip correction")
        return None
    S298_cal = _cc_S298_cal(Z, n_O)
    a_cal, b = _cc_alpha_beta(category, T2)
    # NOTE(review): the Criss–Cobble correspondence is usually stated for *absolute* ion
    # entropies (conventional value shifted by the charge); the conventional estimate is used
    # directly here — confirm.
    Cp_avg_kJ = (a_cal + b * S298_cal) * CAL_TO_KJ
    return S298_cal * CAL_TO_KJ, Cp_avg_kJ


def update_ion_energy(entry, T2):
    """Return the ion entry's energy (eV) corrected from T_REF to T2.

    entry.energy is treated as the Gibbs energy of the formation reaction encoded by the
    entry's Pourbaix coefficients:

        metals + nH2O*H2O + npH*H+ + nPhi*e-  ->  ion

    A negative coefficient moves that species to the product side. For example, WO4 has
    nH2O=4, npH=-8 and nPhi=-6, giving W + 4 H2O -> WO4 + 8 H+ + 6 e-. With a constant mean
    heat-capacity change dCp over [T_REF, T2] the correction is

        dG(T2) - dG(T_REF) = dCp*dT - dS298*dT - T2*dCp*ln(T2/T_REF),   dT = T2 - T_REF

    Both dCp and dS298 are taken in the formation direction (ion minus reactants).

    Args:
        entry: ion Pourbaix entry (reads .composition, .energy, .nH2O, .npH, .nPhi).
        T2: target temperature in K.

    Returns:
        The corrected energy in eV, or entry.energy unchanged if the ion has no thermo data.

    Raises:
        ValueError: if T2 <= 0 or the ion contains no metal.
        KeyError: if a metal of the ion is missing from element_db.
    """
    if T2 <= 0.0:
        raise ValueError("Temperature must be > 0 K.")
    key = ion_key_from_entry(entry)
    metals = metals_from_entry(entry)
    ion_props = _ion_S298_and_Cp_avg(entry, key, T2)
    if ion_props is None:
        return entry.energy
    S298_ion, Cp_avg_ion = ion_props

    # (stoichiometric coefficient, thermo record) for every reactant of the formation reaction.
    reactants = [(amount, element_db[m]) for m, amount in metals.items()]
    reactants += [
        (entry.nH2O, species_db["H2O(l)"]),
        (entry.npH, species_db["H+(aq)"]),
        (entry.nPhi, species_db["e-"]),
    ]
    # NOTE(review): entry.energy is evaluated with the module's T-dependent MU_H2O, and the water
    # entropy/Cp also enter below, so the H2O contribution may be counted twice — confirm.
    dS298 = S298_ion - sum(nu * th["S298"] for nu, th in reactants)
    dCp_avg = Cp_avg_ion - sum(nu * avg_cp_kJmolK(th["cp"], T2) for nu, th in reactants)

    dT = T2 - T_REF
    dG_kJmol = dCp_avg * dT - dS298 * dT - T2 * dCp_avg * np.log(T2 / T_REF)
    return entry.energy + dG_kJmol / KJMOL_PER_EV  # new absolute energy (eV)

# ============================  FETCH ENTRIES  ====================================================
# Download all Pourbaix entries (solid and aqueous-ion) for the requested elements from the
# Materials Project; MPRester is used as a context manager around the single query.
with MPRester(MP_API_KEY) as mpr:
    entries = mpr.get_pourbaix_entries(elements)

# ============================  ION CONCENTRATIONS  ===============================================
# Every non-H/O element that occurs in an aqueous-ion entry gets the same concentration `conc`.
# H and O are excluded: they come from water/H+ rather than dissolved species, and listing them
# made the printout misleading. Sorting keeps the key order (and printout) deterministic.
solution_elements = set()
for entry in entries:
    if isinstance(entry, PourbaixEntry) and entry.phase_type == "Ion":
        solution_elements.update(
            el.symbol for el in entry.composition.elements if el.symbol not in ("H", "O")
        )
conc_dict = {el: conc for el in sorted(solution_elements)}
print(f"Using [C] = {conc:.1e} mol/L for ions: {', '.join(conc_dict) if conc_dict else '(none)'}")

# ============================  NEW: Option B — lock the stable SET to the 298 K diagram ==========
def ref_key(e: PourbaixEntry):
    """Return a hashable key identifying a Pourbaix entry independently of object identity.

    The key is (phase_type, reduced formula, npH, nPhi, nH2O, entry_id). It is used to match
    entries of the 298.15 K diagram against the fetched `entries`. Missing attributes default
    to None (phase_type, entry_id) or 0.0 (coefficients).

    A list-valued entry_id, as a combined multi-element entry may report, is converted to a
    tuple so the key stays hashable.
    """
    entry_id = getattr(e, "entry_id", None)
    if isinstance(entry_id, list):
        entry_id = tuple(entry_id)
    return (
        getattr(e, "phase_type", None),
        e.composition.reduced_formula,
        float(getattr(e, "npH", 0.0)),
        float(getattr(e, "nPhi", 0.0)),
        float(getattr(e, "nH2O", 0.0)),
        entry_id,
    )

# Temporarily force the reference constants to build the 298.15 K diagram.
_prefac_saved = pourbaix_diagram.PREFAC
_mu_h2o_saved = pourbaix_diagram.MU_H2O
pourbaix_diagram.PREFAC = 0.0591
pourbaix_diagram.MU_H2O = MU_H2O_298_eV
try:
    pbx_ref = PourbaixDiagram(entries, conc_dict=conc_dict)
    # Multi-element diagrams may report combined entries as stable; expand anything exposing an
    # `entry_list` into its constituents so they can be matched against `entries` below.
    keep_keys = {
        ref_key(member)
        for se in pbx_ref.stable_entries
        for member in (getattr(se, "entry_list", None) or [se])
    }
finally:
    # Restore the T-dependent constants for the real build, even if the reference build failed.
    pourbaix_diagram.PREFAC = _prefac_saved
    pourbaix_diagram.MU_H2O = _mu_h2o_saved

# Keep ONLY the entries that were stable (alone or as part of a stable combination) at 298 K.
entries_refset = [e for e in entries if ref_key(e) in keep_keys]

# ============================  APPLY FULL T-SHIFTS, but only on the ref-stable set ===============
def _retarget_energy(entry_obj, target_energy_eV):
    """Return a deep copy of `entry_obj` whose energy equals `target_energy_eV`.

    Inverts energy = uncorrected_energy + PREFAC*log10(concentration) - MU_H2O*nH2O using the
    module's current (temperature-dependent) PREFAC and MU_H2O, and stores the result in
    `uncorrected_energy`. The input entry is not modified.
    """
    clone = copy.deepcopy(entry_obj)
    conc_term = pourbaix_diagram.PREFAC * np.log10(clone.concentration)
    clone.uncorrected_energy = target_energy_eV - conc_term + pourbaix_diagram.MU_H2O * clone.nH2O
    return clone


if is_ref(T):
    # No correction at the reference temperature; copy so later changes cannot touch `entries`.
    entries_all_fixed = [copy.deepcopy(e) for e in entries_refset]
else:
    entries_all_fixed = []
    for e in entries_refset:
        if e.phase_type == "Solid":
            # NOTE(review): solid_dG_kJ_per_fu is per *reduced* formula unit; confirm e.energy is on
            # the same basis for every MP solid entry (it may be per full composition).
            dE_eV = solid_dG_kJ_per_fu(T, e.composition.reduced_formula) / KJMOL_PER_EV
            entries_all_fixed.append(_retarget_energy(e, e.energy + dE_eV))
        elif e.phase_type == "Ion":
            # update_ion_energy already returns the new absolute energy (eV).
            entries_all_fixed.append(_retarget_energy(e, update_ion_energy(e, T)))
        else:
            entries_all_fixed.append(copy.deepcopy(e))

# ============================  BUILD & PLOT  =====================================================
# Build the diagram from the temperature-corrected entries and draw it.
pbx = PourbaixDiagram(entries_all_fixed, conc_dict=conc_dict)
plotter = PourbaixPlotter(pbx)
ax = plotter.get_pourbaix_plot()

# Subtitle recording how this diagram differs from a plain 298 K build.
title_S = "(same stable set as 298 K; T-shifted μ_H2O, solids, ions)"
plot_title = f"Pourbaix Diagram at {T:.2f} K, [C] = {conc:g} mol/L {title_S}"
ax.set_title(plot_title)

plt.tight_layout()
plt.show()
