In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
from mp_api.client import MPRester
from pymatgen.analysis import pourbaix_diagram
from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter, PourbaixEntry
from pymatgen.core import Composition, Element

# ============================
# ---- USER INPUT ------------
# ============================
# Split on commas and drop surrounding whitespace / empty items, so "Al, O" -> ["Al", "O"]
# (an element like " O" would otherwise be passed to MPRester verbatim).
elements = [s.strip() for s in input("Enter elements (comma-separated): ").split(",") if s.strip()]
if not elements:
    raise ValueError("Enter at least one element symbol.")
T = float(input("Enter temperature in K: "))
if T <= 0.0:
    raise ValueError(f"Temperature must be > 0 K, got {T}.")
conc = float(input("Enter concentration in mol/L (default = 1e-6): ").strip() or "1e-6")
if conc <= 0.0:
    # the concentration enters the diagram through a log10 term
    raise ValueError(f"Concentration must be > 0 mol/L, got {conc}.")
# Do not hard-code API keys in source. Read the key from the MP_API_KEY environment
# variable; when it is unset this is None and MPRester(None) applies its own default
# key lookup.
# NOTE(review): the key previously hard-coded here should be treated as exposed and rotated.
MP_API_KEY = os.environ.get("MP_API_KEY")

# ============================
# ---- CONSTANTS -------------
# ============================
R = 8.314462618  # J/mol/K
F = 96485.33212  # C/mol
KJMOL_PER_EV = 96.485
T1 = 298.15  # reference temp for entries (K)

# Patch the module-level Nernst slope PREFAC in pymatgen.analysis.pourbaix_diagram
# (ln(10)·RT/F, in V per pH unit) so pH terms are evaluated at T rather than 298 K.
PREFAC_T = np.log(10.0) * R * T / F
pourbaix_diagram.PREFAC = PREFAC_T
print(f"\nSet PREFAC = {PREFAC_T:.5f} V for T = {T} K")

# ============================
# ---- DEBUG SWITCH ----------
# ============================
DEBUG_IONS = True  # flip to False to silence the per-ion diagnostics once verified
_debug_seen_ions = set()  # keys that have already produced a debug message
def _debug_ion_once(key, msg):
    """Print *msg* the first time *key* is reported, and only while DEBUG_IONS is on.

    Keys are remembered only when a message is actually printed, so a key reported
    while debugging is off can still print later after DEBUG_IONS is switched on.
    """
    if DEBUG_IONS and key not in _debug_seen_ions:
        print(msg)
        _debug_seen_ions.add(key)

# ============================
# ---- NUMERIC HELPERS -------
# ============================
def trapz_cp_and_cp_over_T(T_to, cp_func, Tref=298.15, n=2001):
    """Integrate Cp and Cp/T from Tref to T_to with the trapezoidal rule.

    Args:
        T_to: end temperature in K (> 0); may lie above or below Tref.
        cp_func: scalar callable Cp(T) in J/mol/K, evaluated once per grid point.
        Tref: start temperature in K (> 0).
        n: number of grid points.

    Returns:
        (ΔH, ΔS) = (∫Cp dT, ∫Cp/T dT) from Tref to T_to, in kJ/mol and kJ/mol/K.
        Both carry a negative sign when T_to < Tref; (0.0, 0.0) when the two
        temperatures coincide within 1e-12 K.

    Raises:
        ValueError: if either temperature is <= 0 K (the Cp/T integrand diverges).
    """
    if abs(T_to - Tref) < 1e-12:
        return 0.0, 0.0
    if T_to <= 0.0 or Tref <= 0.0:
        raise ValueError(f"Temperatures must be > 0 K, got T_to={T_to}, Tref={Tref}")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    # Integrate over an increasing grid and apply the direction sign afterwards.
    a, b = (Tref, T_to) if T_to > Tref else (T_to, Tref)
    grid = np.linspace(a, b, n)
    cp = np.array([cp_func(float(t)) for t in grid], dtype=float)  # J/mol/K
    dH = integrate(cp, grid) / 1000.0          # kJ/mol
    dS = integrate(cp / grid, grid) / 1000.0   # kJ/mol/K
    sign = 1.0 if T_to > Tref else -1.0
    return sign * dH, sign * dS

def avg_cp_kJmolK(cp_func, T2, T1=T1, n=1501):
    """Mean heat capacity over [T1, T2], i.e. ΔH(T1→T2) / (T2 − T1), in kJ/mol/K.

    cp_func returns Cp in J/mol/K. Returns 0.0 when T1 and T2 coincide within 1e-12 K.
    """
    span = T2 - T1
    if abs(span) < 1e-12:
        return 0.0
    enthalpy_kJ = trapz_cp_and_cp_over_T(T2, cp_func, Tref=T1, n=n)[0]
    return enthalpy_kJ / span

def cp_poly(A,B=0,C=0,D=0,E=0):
    """Build Cp(T) = A + B·T + C·T² + D·T³ + E/T² in J/mol/K as a callable.

    The returned function clamps T to at least 1e-6 K so the E/T² term stays finite.
    """
    def _cp_of_T(T):
        t = max(float(T), 1e-6)  # clamp: keeps E/T² finite at or below 0 K
        return A + B * t + C * t * t + D * t**3 + E / (t * t)
    return _cp_of_T

# ============================
# ---- STANDARD STATES MAP ---
# ============================
# Elements whose standard state is a diatomic molecule: symbol -> (species, atoms per species).
STD_STATE = dict(
    H=("H2", 2), N=("N2", 2), O=("O2", 2),
    F=("F2", 2), Cl=("Cl2", 2), Br=("Br2", 2), I=("I2", 2),
)
def standard_species_for_element(el: Element):
    """Return (standard-state species name, atoms per species) for *el*.

    Elements not listed in STD_STATE are treated as monatomic: (symbol, 1).
    """
    symbol = el.symbol
    if symbol in STD_STATE:
        return STD_STATE[symbol]
    return (symbol, 1)

# ============================
# ---- THERMO REGISTRIES -----
# ============================
# Standard-state reference species keyed by species name ("O2", not "O"):
# S°298 in kJ/mol/K and a Cp(T) callable in J/mol/K.
# *** REPLACE placeholders with real data for your system ***
element_db = dict(
    Al=dict(S298=0.02832, cp=cp_poly(24.2)),           # examples
    Si=dict(S298=0.01881, cp=cp_poly(22.8)),
    O2=dict(S298=0.20515, cp=cp_poly(29.4, 1.0e-2)),
    H2=dict(S298=0.13068, cp=cp_poly(28.8)),
    N2=dict(S298=0.19161, cp=cp_poly(29.1, 0.5e-2)),
    Cl2=dict(S298=0.22308, cp=cp_poly(33.9)),
    # W–O–H placeholders
    W=dict(S298=0.032, cp=cp_poly(24.0)),
)

# Solids to temperature-correct, keyed by pymatgen reduced formula.
# Per formula unit: ΔHf°(298) in kJ/mol, S°(298) in kJ/mol/K, Cp(T) callable in J/mol/K.
compound_db = dict(
    Al2O3=dict(dHf298=-1675.7, S298=0.0509,  cp=cp_poly(79.0, 3.0e-2)),
    SiO2=dict(dHf298=-910.9,   S298=0.0415,  cp=cp_poly(45.0, 2.0e-2)),
    Al=dict(dHf298=0.0,        S298=0.02832, cp=cp_poly(24.2)),
    Si=dict(dHf298=0.0,        S298=0.01881, cp=cp_poly(22.8)),
    # tungsten system (placeholders – put your trusted numbers if you have them)
    WO3=dict(dHf298=-842.9,    S298=0.076,   cp=cp_poly(73.3)),
    WO2=dict(dHf298=-589.7,    S298=0.051,   cp=cp_poly(55.74)),
    W=dict(dHf298=0.0,         S298=0.032,   cp=cp_poly(24.0)),
)

# Aqueous/liquid species that appear in the ion reactions: S°298 [kJ/mol/K], Cp(T) [J/mol/K].
species_db = {
    name: {"S298": s298, "cp": cp_poly(cp_const)}
    for name, s298, cp_const in (
        ("H2O(l)", 0.06995, 75.3),  # reasonable defaults
        ("H+(aq)", 0.0, 0.0),       # conventional zero for the aqueous proton
        ("e-", 0.0, 0.0),
    )
}

# Ions with their own thermo data (S°298 [kJ/mol/K], Cp(T) [J/mol/K]), keyed by reduced
# formula. Ions listed here bypass the Criss–Cobble fallback.
ion_db = {
    key: {"S298": s298, "cp": cp_poly(cp_const)}
    for key, s298, cp_const in (
        ("Si(HO)4", 0.189, 66.04),   # replace with measured data if available
        ("AlO2", -0.029, -33.41),    # AlO2^-
    )
}

# |charge| of each ion keyed by reduced formula; used only by the Criss–Cobble fallback.
# Charges follow from the usual oxidation states (Al +3, Si +4, W +6, O -2, H +1).
ion_charge = {
    "WO4": 2,      # WO4^2-
    "HWO4": 1,     # HWO4^-
    "W2O7": 2,     # W2O7^2-
    "H7SiO6": 1,   # H7SiO6^-
    "Si(HO)4": 0,  # Si(OH)4 (neutral)
    "Al(HO)4": 1,  # Al(OH)4^-
    "AlOH": 2,     # Al(OH)^2+ : +3 - 2 + 1 = +2
    "Al": 3,       # Al^3+
    "SiO4": 4,     # SiO4^4-
    "AlO2": 1,     # AlO2^-
    "SiH7O6": 1,   # same ion as H7SiO6^- under the alternative formula ordering
    # add silicon oxyanions etc. as needed
}

# ============================
# ---- CRISS–COBBLE (ions) ---
# ============================
# The Criss–Cobble parameters are tabulated in cal mol^-1 K^-1; 1 cal = 4.184 J.
CAL_TO_KJ = 4.184e-3

# α(t) [cal/mol/K] and β(t) [dimensionless] at the tabulated temperatures t (°C);
# _cc_alpha_beta interpolates linearly between these points.
_CC_TABLE = {
    "oxyanion": dict(
        T_C=[60, 100, 150, 200],
        alpha=[-127, -138, -133, -145],
        beta=[1.96, 2.24, 2.27, 2.53],
    ),
    "acid_oxyanion": dict(
        T_C=[60, 100, 150, 200],
        alpha=[-122, -135, -143, -152],
        beta=[3.44, 3.97, 3.95, 4.24],
    ),
}

def _cc_alpha_beta(category: str, T2_K: float):
    """Criss–Cobble α (cal/mol/K) and β (dimensionless) at temperature T2_K (in K).

    Values are linearly interpolated in °C from _CC_TABLE[category]; outside the
    tabulated range np.interp clamps to the end-point values.
    """
    t_celsius = T2_K - 273.15
    table = _CC_TABLE[category]
    alpha, beta = (
        float(np.interp(t_celsius, table["T_C"], table[column]))
        for column in ("alpha", "beta")
    )
    return alpha, beta

def _cc_S298_cal(Z: int, n_oxygen: int):
    """Oxyanion entropy estimate in cal/mol/K: S298 = 182 − 195·(Z − 0.28·n_O)."""
    effective_charge = Z - 0.28 * n_oxygen
    return 182.0 - 195.0 * effective_charge

# ============================
# ---- THERMO CORE FUNCS -----
# ============================
def elemental_terms_from_formula(comp: Composition):
    """Reference-state terms for forming *comp* from its elements.

    Each element is mapped to its standard-state species via
    standard_species_for_element (e.g. O -> O2 with two atoms per species), and the
    coefficients of identical species are summed.

    Returns:
        list of {"nu": coefficient, "cp": Cp callable, "S298": S°298} dicts, one per
        species, in first-seen order.

    Raises:
        ValueError: if a required species has no entry in element_db.
    """
    nu_by_species = {}
    for symbol, amount in comp.get_el_amt_dict().items():
        species, atoms_per_species = standard_species_for_element(Element(symbol))
        nu_by_species[species] = nu_by_species.get(species, 0.0) + amount / atoms_per_species

    for species in nu_by_species:
        if species not in element_db:
            raise ValueError(f"Missing elemental thermo for '{species}' (add it to element_db)")

    return [
        {"nu": nu, "cp": element_db[species]["cp"], "S298": element_db[species]["S298"]}
        for species, nu in nu_by_species.items()
    ]

def delta_f_G_T(TK, dHf298_kJmol, cp_compound, S298_comp_kJmolK, elem_terms):
    """Standard Gibbs energy of formation at TK, in kJ/mol per formula unit.

        ΔfH(TK) = ΔfH°298 + ∫Cp_compound dT − Σ ν_i ∫Cp_i dT
        ΔfS(TK) = [S°298 + ∫Cp/T dT]_compound − Σ ν_i [S°298 + ∫Cp/T dT]_i
        ΔfG(TK) = ΔfH(TK) − TK·ΔfS(TK)

    All integrals run from 298.15 K to TK; i ranges over *elem_terms*
    (dicts with "nu", "cp", "S298" as built by elemental_terms_from_formula).
    """
    compound_dH, compound_dS = trapz_cp_and_cp_over_T(TK, cp_compound, 298.15)

    element_dH = 0.0
    element_S = 0.0
    for term in elem_terms:
        nu = term["nu"]
        inc_H, inc_S = trapz_cp_and_cp_over_T(TK, term["cp"], 298.15)
        element_dH = element_dH + nu * inc_H
        element_S = element_S + nu * (term["S298"] + inc_S)

    delta_H = dHf298_kJmol + compound_dH - element_dH
    delta_S = (S298_comp_kJmolK + compound_dS) - element_S
    return delta_H - TK * delta_S

def update_solid_energy(entry_energy_eV_fu, TK, rf):
    """Shift an energy in eV per formula unit of *rf* from 298.15 K to TK.

    The shift is ΔfG(TK) − ΔfG(298.15) computed from compound_db data for the reduced
    formula *rf*, converted from kJ/mol to eV. Formulas absent from compound_db are
    returned unchanged.
    """
    data = compound_db.get(rf)
    if data is None:
        return entry_energy_eV_fu  # no data → leave as-is

    terms = elemental_terms_from_formula(Composition(rf))
    args = (data["dHf298"], data["cp"], data["S298"], terms)
    g_ref, g_T = (delta_f_G_T(temp, *args) for temp in (298.15, TK))
    return entry_energy_eV_fu + (g_T - g_ref) / KJMOL_PER_EV

def ion_key_from_entry(e):
    """Lookup key for the ion tables: the entry's reduced formula with all spaces removed
    (e.g. "WO4", "HWO4", "Si(HO)4")."""
    formula = e.composition.reduced_formula
    return "".join(formula.split(" "))

def metals_from_entry(e):
    """Map every element other than H and O in the ion's reduced composition to its amount.

    Elements with non-positive amounts are skipped.

    Raises:
        ValueError: if no such element is present.
    """
    amounts = e.composition.reduced_composition.get_el_amt_dict()
    metals = {}
    for symbol, amount in amounts.items():
        if symbol in ("H", "O") or not amount > 0:
            continue
        metals[symbol] = metals.get(symbol, 0.0) + amount
    if metals:
        return metals
    raise ValueError(f"No metal found in ion {e.composition}")

def _ion_S298_and_cpbar(entry, key, T2):
    """Return (S°298, Cp̄(T1→T2)) of the ion in kJ/mol/K plus a debug header, or None to skip.

    Uses ion_db when it holds a callable Cp for *key*; otherwise the Criss–Cobble
    oxyanion correlation with the |charge| from ion_charge. Returns None (after a
    debug note) when neither source covers the ion.
    """
    ion_thermo = ion_db.get(key)
    if ion_thermo is not None and callable(ion_thermo.get("cp", None)):
        # --- Database path (preferred) ---
        S298_kJ = ion_thermo["S298"]
        Cpbar_kJ = avg_cp_kJmolK(ion_thermo["cp"], T2)
        header = (
            f"[ION DEBUG] {key}: using ion_db\n"
            f"  T2 = {T2:.2f} K, ΔT = {T2 - T1:.2f} K\n"
            f"  S298(ion) = {S298_kJ:.6f} kJ/mol/K\n"
            f"  Cp̄(ion,298→T2) = {Cpbar_kJ:.6f} kJ/mol/K"
        )
        return S298_kJ, Cpbar_kJ, header

    # --- Criss–Cobble fallback for oxyanions / acid oxyanions ---
    amounts = entry.composition.reduced_composition.get_el_amt_dict()
    n_O = int(round(amounts.get("O", 0.0)))
    category = "acid_oxyanion" if "H" in amounts else "oxyanion"

    Z = ion_charge.get(key)
    if Z is None:
        _debug_ion_once(key, f"[ION DEBUG] {key}: no ion_db & no charge → skip correction")
        return None

    S298_cal = _cc_S298_cal(Z, n_O)
    S298_kJ = S298_cal * CAL_TO_KJ
    a_cal, b = _cc_alpha_beta(category, T2)
    Cpbar_kJ = (a_cal + b * S298_cal) * CAL_TO_KJ  # Criss–Cobble mean Cp, kJ/mol/K
    header = (
        f"[ION DEBUG] {key}: Criss–Cobble fallback ({category})\n"
        f"  T2 = {T2:.2f} K (T_C = {T2-273.15:.2f} °C), ΔT = {T2 - T1:.2f} K\n"
        f"  Z = {Z}, n_O = {n_O}\n"
        f"  S298(ion) = {S298_cal:.2f} cal/mol/K = {S298_kJ:.6f} kJ/mol/K\n"
        f"  α(T_C) = {a_cal:.2f} cal/mol/K,  β(T_C) = {b:.2f}\n"
        f"  Cp̄(ion,298→T2) = {Cpbar_kJ:.6f} kJ/mol/K"
    )
    return S298_kJ, Cpbar_kJ, header


def _formation_dS298_dCpbar(entry, S298_ion, Cpbar_ion, T2):
    """ΔS°298 and ΔCp̄(T1→T2), in kJ/mol/K, of the ion's formation reaction.

    Sign convention: (ion side) − (element side). For example, WO4^2- has
    nH2O = 4, npH = -8, nPhi = -6, i.e.

        W + 4 H2O(l)  ->  WO4^2- + 8 H+ + 6 e-

    A positive coefficient n of H2O / H+ / e- (entry.nH2O / entry.npH / entry.nPhi)
    puts n of that species on the element side; a negative one puts |n| on the ion
    side. Elements other than H and O enter through their standard-state species
    (elemental_terms_from_formula, e.g. Cl -> 1/2 Cl2).
    """
    S_ion_side, Cp_ion_side = S298_ion, Cpbar_ion
    S_elem_side, Cp_elem_side = 0.0, 0.0

    for n, species in ((entry.nH2O, "H2O(l)"), (entry.npH, "H+(aq)"), (entry.nPhi, "e-")):
        data = species_db[species]
        cpbar = avg_cp_kJmolK(data["cp"], T2)
        S_elem_side += max(n, 0.0) * data["S298"]
        Cp_elem_side += max(n, 0.0) * cpbar
        S_ion_side += max(-n, 0.0) * data["S298"]
        Cp_ion_side += max(-n, 0.0) * cpbar

    for term in elemental_terms_from_formula(Composition(metals_from_entry(entry))):
        S_elem_side += term["nu"] * term["S298"]
        Cp_elem_side += term["nu"] * avg_cp_kJmolK(term["cp"], T2)

    return S_ion_side - S_elem_side, Cp_ion_side - Cp_elem_side


def update_ion_energy(entry, T2):
    """Temperature-shift the energy of an aqueous-ion Pourbaix entry from T1 to T2.

    The shift is ΔG(T2) − ΔG(T1) of the ion's formation reaction (see
    _formation_dS298_dCpbar), evaluated with a constant, averaged ΔCp̄:

        ΔG(T2) − ΔG(T1) = ΔCp̄·ΔT − ΔS°298·ΔT − T2·ΔCp̄·ln(T2/T1)

    The ion's S°298 and Cp̄ come from ion_db when available (preferred); otherwise
    the Criss–Cobble correlation is used for ions whose charge is in ion_charge.
    Ions covered by neither are returned unchanged.

    Args:
        entry: Pourbaix ion entry (reads .energy, .nH2O, .npH, .nPhi, .composition).
        T2: target temperature in K.

    Returns:
        entry.energy plus the ΔG increment, in eV.

    Raises:
        ValueError: if T2 <= 0, the ion has no element besides H and O, or a needed
            standard-state species is missing from element_db.
    """
    if T2 <= 0.0:
        raise ValueError("Temperature must be > 0 K for the ln(T2/T1) term.")

    key = ion_key_from_entry(entry)   # e.g., "WO4", "HWO4"
    ion_props = _ion_S298_and_cpbar(entry, key, T2)
    if ion_props is None:
        return entry.energy
    S298_ion, Cpbar_ion, debug_header = ion_props

    # Formation convention (ion side − element side) so the increment has the same
    # sign as the formation energy it is added to.
    dS298, dCp_avg = _formation_dS298_dCpbar(entry, S298_ion, Cpbar_ion, T2)

    _debug_ion_once(
        key,
        debug_header + "\n"
        f"  ΔS298(formation) = {dS298:.6f} kJ/mol/K\n"
        f"  ΔCp̄(formation)  = {dCp_avg:.6f} kJ/mol/K",
    )

    # Final ΔG increment (kJ/mol → eV)
    dT = T2 - T1
    dG_kJmol = dCp_avg * dT - dS298 * dT - T2 * dCp_avg * np.log(T2 / T1)
    return entry.energy + dG_kJmol / KJMOL_PER_EV

# ============================
# ---- FETCH ENTRIES ----------
# ============================
with MPRester(MP_API_KEY) as mpr:
    entries = mpr.get_pourbaix_entries(elements)

# ============================
# ---- ION CONCENTRATIONS ----
# ============================
# pymatgen's PourbaixDiagram looks up conc_dict by each ion's element other than
# H and O, so only those elements are collected (H and O are not solute species).
solution_elements = set()
for entry in entries:
    if isinstance(entry, PourbaixEntry) and entry.phase_type == "Ion":
        solution_elements.update(
            el.symbol for el in entry.composition.elements if el.symbol not in ("H", "O")
        )
conc_dict = {el: conc for el in sorted(solution_elements)}  # sorted → stable printout
print(f"Using [C] = {conc:.1e} mol/L for ions: {', '.join(conc_dict) if conc_dict else '(none)'}")

# ============================
# ---- APPLY: SOLIDS ---------
# ============================
entries_solid_fixed = []
solid_changes = []
# print("\n[CHECK] Solid entries (rf, id, original energy eV/fu):")
# for e in entries:
#     if getattr(e, "phase_type", None) == "Solid":
#         print(f"  {e.composition.reduced_formula:6s}  {getattr(e, 'entry_id', '?'):>12s}  {e.energy: .6f}")

for e in entries:
    if getattr(e, "phase_type", None) != "Solid":
        entries_solid_fixed.append(e)
        continue

    rf = e.composition.reduced_formula
    # With a zero base energy, update_solid_energy returns just the ΔG shift per
    # formula unit of rf. The entry's energy refers to its full composition, which
    # may hold several formula units, so scale the shift accordingly.
    n_fu = e.composition.get_reduced_composition_and_factor()[1]
    shift = update_solid_energy(0.0, T, rf) * n_fu
    newE = e.energy + shift
    if abs(newE - e.energy) > 1e-9:
        solid_changes.append((rf, e.energy, newE))

    e2 = copy.deepcopy(e)
    # PourbaixEntry derives .energy from .uncorrected_energy, so the shift must be
    # applied there for the diagram to see it.
    e2.uncorrected_energy += shift
    # NOTE(review): kept for the debug prints further down; PourbaixEntry does not
    # appear to read these attributes itself.
    e2._energy = newE
    e2._g0 = newE
    entries_solid_fixed.append(e2)

#print(f"\n[SOLID DEBUG] corrected {len(solid_changes)} solids:")
#for rf, Eo, En in solid_changes:
#    print(f"  {rf}: Δ = {En - Eo:+.6f} eV/fu (from {Eo:.6f} → {En:.6f})")

# ============================
# ---- APPLY: IONS -----------
# ============================
entries_all_fixed = []
for e in entries_solid_fixed:
    if getattr(e, "phase_type", None) != "Ion":
        entries_all_fixed.append(e)
        continue

    newE = update_ion_energy(e, T)
    shift = newE - e.energy
    e2 = copy.deepcopy(e)
    # PourbaixEntry derives .energy from .uncorrected_energy, so the shift must be
    # applied there for the diagram to see it.
    e2.uncorrected_energy += shift
    # NOTE(review): kept for the debug print below; PourbaixEntry does not appear to
    # read these attributes itself.
    e2._energy = newE
    e2._g0 = newE
    entries_all_fixed.append(e2)

# print("\n[FINAL ENERGIES FED TO DIAGRAM]")
# for e in entries_all_fixed:
#     print(f"{e.phase_type:5s} {e.composition.reduced_formula:6s}  energy={e.energy: .6f}  g0={getattr(e,'_g0', None)}  id={getattr(e,'entry_id','?')}")

# ============================
# ---- BUILD & PLOT ----------
# ============================
# Build the diagram from the temperature-corrected entries and the ion concentrations.
pbx = PourbaixDiagram(entries_all_fixed, conc_dict=conc_dict)
# print("\n[STABLE ENTRIES SEEN BY PBX]")
# for se in pbx.stable_entries:
#     print(se)

plotter = PourbaixPlotter(pbx)
ax = plotter.get_pourbaix_plot()
ax.set_title(
    f"Pourbaix Diagram at {T} K, [C] = {conc} mol/L\n"
    "(solids & ions T-corrected; CC fallback for oxyanions)"
)
plt.tight_layout()
plt.show()
