# In [None]:
import copy
import os

import numpy as np
import matplotlib.pyplot as plt

from mp_api.client import MPRester
from pymatgen.core import Composition, Element
from pymatgen.analysis import pourbaix_diagram
from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter, PourbaixEntry

# ============================
# ---- USER INPUT ------------
# ============================
# Element symbols are whitespace-stripped (so "Al, O" works) and empty items dropped.
elements = [
    sym.strip()
    for sym in input("Enter elements (comma-separated): ").split(",")
    if sym.strip()
]
if not elements:
    raise ValueError("Enter at least one element symbol.")
T = float(input("Enter temperature in K: "))
if T <= 0:
    raise ValueError(f"Temperature must be positive, got {T} K.")
conc = float(input("Enter concentration in mol/L (default = 1e-6): ").strip() or "1e-6")
if conc <= 0:
    raise ValueError(f"Concentration must be positive, got {conc} mol/L.")
# Never commit API keys to source control: read the key from the MP_API_KEY
# environment variable. If it is unset this is None, and MPRester presumably
# falls back to its own configuration lookup (confirm for your mp_api version).
MP_API_KEY = os.environ.get("MP_API_KEY")

# ============================
# ---- CONSTANTS -------------
# ============================
R = 8.314462618  # J/mol/K
F = 96485.33212  # C/mol
KJMOL_PER_EV = 96.485  # kJ/mol per eV
T1 = 298.15  # reference temp for entries

# Patch the Nernst slope ln(10)*R*T/F (V per pH unit) for temperature T.
# NOTE(review): this relies on pymatgen reading pourbaix_diagram.PREFAC at call
# time rather than copying it at import — confirm for the installed version.
PREFAC_T = np.log(10) * R * T / F
pourbaix_diagram.PREFAC = PREFAC_T
print(f"\nSet PREFAC = {PREFAC_T:.5f} V for T = {T} K")

# ============================
# ---- NUMERIC HELPERS -------
# ============================
def trapz_cp_and_cp_over_T(T_to, cp_func, Tref=298.15, n=2001):
    """Integrate Cp(T) and Cp(T)/T from ``Tref`` to ``T_to`` (trapezoid rule).

    Parameters
    ----------
    T_to : float
        Upper integration limit in K. It may lie below ``Tref``, in which case
        both integrals change sign.
    cp_func : callable
        Heat capacity Cp(T) in J/mol/K, evaluated at scalar temperatures.
    Tref : float
        Lower integration limit in K.
    n : int
        Number of grid points used for the trapezoid rule.

    Returns
    -------
    tuple of float
        ``(dH, dS)``: ΔH in kJ/mol and ΔS in kJ/mol/K. Both are 0.0 when the
        two limits coincide.

    Raises
    ------
    ValueError
        If either limit is not a positive absolute temperature (Cp/T would diverge).
    """
    if abs(T_to - Tref) < 1e-12:
        return 0.0, 0.0
    if T_to <= 0 or Tref <= 0:
        raise ValueError(f"Temperatures must be positive, got Tref={Tref}, T_to={T_to}")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
    # whichever this NumPy provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    grid = np.linspace(Tref, T_to, n)
    cp = np.array([cp_func(t) for t in grid])  # J/mol/K
    dH = integrate(cp, grid) / 1000.0          # kJ/mol
    dS = integrate(cp / grid, grid) / 1000.0   # kJ/mol/K
    return float(dH), float(dS)

def avg_cp_kJmolK(cp_func, T2, T1=T1, n=1501):
    """Return the mean heat capacity between ``T1`` and ``T2`` in kJ/mol/K.

    The mean is ΔH/ΔT, where ΔH is the trapezoid integral of ``cp_func``
    (J/mol/K) from ``T1`` to ``T2`` computed with ``n`` grid points.

    For a degenerate interval (T2 == T1) the mean reduces to its limit Cp(T1),
    so that value is returned rather than 0.
    """
    if abs(T2 - T1) < 1e-12:
        # lim ΔH/ΔT as ΔT -> 0 is the point value Cp(T1), converted J -> kJ.
        return cp_func(T1) / 1000.0
    dH_kJ, _ = trapz_cp_and_cp_over_T(T2, cp_func, Tref=T1, n=n)
    return dH_kJ / (T2 - T1)

def cp_poly(A,B=0,C=0,D=0,E=0):
    """Build the heat-capacity function Cp(T) = A + B*T + C*T^2 + D*T^3 + E/T^2.

    Coefficients should be chosen so that Cp comes out in J/mol/K. The result
    is a one-argument callable of the temperature T in K.
    """
    return lambda T: A + B*T + C*T*T + D*T*T*T + E/(T*T)

# ============================
# ---- STANDARD STATES MAP ---
# ============================
# Element symbol -> (standard-state species string, atoms of the element per species).
# Elements not listed here are treated as monatomic in their standard state.
STD_STATE = {
    "H": ("H2", 2),
    "N": ("N2", 2),
    "O": ("O2", 2),
    "F": ("F2", 2),
    "Cl": ("Cl2", 2),
    "Br": ("Br2", 2),
    "I": ("I2", 2),
}


def standard_species_for_element(el: Element):
    """Return ``(species, atoms_per_species)`` for the standard state of ``el``.

    Falls back to ``(el.symbol, 1)`` when the symbol is not in STD_STATE.
    """
    symbol = el.symbol
    if symbol in STD_STATE:
        return STD_STATE[symbol]
    return (symbol, 1)

# ============================
# ---- THERMO REGISTRIES -----
# ============================
# Elemental standard states (S°298 in kJ/mol/K; Cp(T) in J/mol/K)
# *** REPLACE placeholders with real data ***
# Keys are the standard-state species strings produced by
# standard_species_for_element (e.g. "O2", "H2", or a bare symbol such as "Al").
# elemental_terms_from_formula raises ValueError when a needed species is
# missing, and update_ion_energy looks this table up by the ion's metal symbol.
element_db = {
    "Al": {"S298": 0.02832, "cp": cp_poly(24.2)},            # Al(s) placeholder
    "Si": {"S298": 0.01881, "cp": cp_poly(22.8)},            # Si(s) placeholder
    "O2": {"S298": 0.20515, "cp": cp_poly(29.4, 1.0e-2)},    # O2(g) placeholder
    "H2": {"S298": 0.13068, "cp": cp_poly(28.8)},            # H2(g) placeholder
    "N2": {"S298": 0.19161, "cp": cp_poly(29.1, 0.5e-2)},    # N2(g) placeholder
    "Cl2": {"S298": 0.22308, "cp": cp_poly(33.9)},           # Cl2(g) placeholder
}

# Solids to temperature-correct, keyed by reduced formula (the solids step
# looks entries up via e.composition.reduced_formula).
# Provide ΔHf°(298) [kJ/mol-fu], S°(298) [kJ/mol/K] and Cp(T) [J/mol/K].
# The elemental reference terms are always inferred from the formula by
# elemental_terms_from_formula.
# NOTE(review): update_solid_energy does not read any "elements_standard"
# override key, so such a key would be ignored.
compound_db = {
    # EXAMPLES (placeholders)
    "Al2O3": {
        "dHf298": -1675.7, "S298": 0.0509, "cp": cp_poly(79.0, 3.0e-2),
        # elemental references are inferred from the formula
    },
    "SiO2": {
        "dHf298": -910.9,  "S298": 0.0415, "cp": cp_poly(45.0, 2.0e-2),
    },
    # An element in its own standard state: S298/Cp match element_db["Al"],
    # so the computed ΔfG°(T) - ΔfG°(298) is zero.
    "Al": {
        "dHf298": 0.0, "S298": 0.02832, "cp": cp_poly(24.2),
    },
}

# Aqueous/liquid reference species used in the ion reactions of update_ion_energy
species_db = {
    "H2O(l)": {"S298": 0.06995, "cp": cp_poly(75.3)},  # placeholders
    "H+(aq)": {"S298": 0.0,      "cp": cp_poly(0.0)},  # IUPAC convention
    "e-":     {"S298": 0.0,      "cp": cp_poly(0.0)},  # NOTE(review): zero entropy for e- is a simplification; confirm intended convention
}

# Ions to temperature-correct, keyed by ion_key_from_entry(entry), i.e. the
# entry's reduced formula with spaces removed.
# NOTE(review): for pymatgen ions the reduced formula may carry a charge suffix
# (e.g. "[+2]"); confirm the keys below match, otherwise no ion gets corrected.
# Provide S°(298) [kJ/mol/K] and Cp(T) [J/mol/K] per reduced formula unit.
ion_db = {
    # EXAMPLE (placeholder)
    "AlOH": {"S298": 0.010, "cp": cp_poly(90.0)},
}

# ============================
# ---- THERMO CORE FUNCS -----
# ============================
def elemental_terms_from_formula(comp: Composition):
    """Build the elemental-reference terms needed to form ``comp``.

    Each element amount in ``comp`` is converted to moles of that element's
    standard-state species (e.g. 3 O atoms -> 1.5 O2). Amounts are summed per
    species.

    Returns
    -------
    list of dict
        One ``{"nu", "cp", "S298"}`` dict per species, in first-seen order.

    Raises
    ------
    ValueError
        If a required standard-state species is absent from ``element_db``.
    """
    moles_of_species = {}
    for symbol, amount in comp.get_el_amt_dict().items():
        species, atoms_per_species = standard_species_for_element(Element(symbol))
        moles_of_species.setdefault(species, 0.0)
        moles_of_species[species] += amount / atoms_per_species

    terms = []
    for species, nu in moles_of_species.items():
        data = element_db.get(species)
        if data is None:
            raise ValueError(f"Missing elemental thermo for {species}")
        terms.append({"nu": nu, "cp": data["cp"], "S298": data["S298"]})
    return terms

def delta_f_G_T(TK, dHf298_kJmol, cp_compound, S298_comp_kJmolK, elem_terms):
    """Return the standard Gibbs energy of formation at ``TK`` in kJ/mol-fu.

    Evaluates ΔfG°(T) = ΔfH°(298) + ∫Cp_c dT - Σν∫Cp_e dT - T·[S_c(T) - Σν·S_e(T)].
    Every integral runs from 298.15 K to ``TK``. The elemental terms come from
    ``elem_terms`` (dicts with "nu", "cp", "S298"). Entropies are in kJ/mol/K.
    """
    ref = 298.15
    dH_compound, dS_compound = trapz_cp_and_cp_over_T(TK, cp_compound, ref)
    S_compound = S298_comp_kJmolK + dS_compound

    H_elements = 0.0
    S_elements = 0.0
    for term in elem_terms:
        dH_el, dS_el = trapz_cp_and_cp_over_T(TK, term["cp"], ref)
        H_elements += term["nu"] * dH_el
        S_elements += term["nu"] * (term["S298"] + dS_el)

    return dHf298_kJmol + dH_compound - H_elements - TK*(S_compound - S_elements)

def update_solid_energy(entry_energy_eV_fu, TK, rf):
    """Return a solid's energy shifted by ΔfG°(TK) - ΔfG°(298.15 K).

    Parameters
    ----------
    entry_energy_eV_fu : float
        Energy to correct, in eV per formula unit of ``rf``.
    TK : float
        Target temperature in K.
    rf : str
        Reduced formula used as the ``compound_db`` key.

    Returns
    -------
    float
        ``entry_energy_eV_fu`` plus the shift (kJ/mol converted with
        KJMOL_PER_EV), or the input unchanged when ``rf`` has no data.
    """
    data = compound_db.get(rf)
    if data is not None:
        refs = elemental_terms_from_formula(Composition(rf))
        g_ref = delta_f_G_T(298.15, data["dHf298"], data["cp"], data["S298"], refs)
        g_target = delta_f_G_T(TK, data["dHf298"], data["cp"], data["S298"], refs)
        return entry_energy_eV_fu + (g_target - g_ref) / KJMOL_PER_EV
    return entry_energy_eV_fu

def ion_key_from_entry(e):
    """Return the ``ion_db`` key for an entry: its reduced formula with spaces removed.

    NOTE(review): pymatgen reduced formulas normally contain no spaces, so the
    removal is likely a no-op. For ion compositions the string may also carry a
    charge suffix; confirm that ion_db keys use the same format.
    """
    formula = e.composition.reduced_formula
    return "".join(formula.split(" "))

def base_metal_from_entry(e):
    """Return ``(symbol, amount)`` for the first non-H, non-O element of the entry.

    Amounts come from the *reduced* composition. Only the first qualifying
    element is returned, so the entry is assumed to contain a single metal.

    Raises
    ------
    ValueError
        If no element other than H/O has a positive amount.
    """
    amounts = e.composition.reduced_composition.get_el_amt_dict()
    candidates = (
        (sym, n) for sym, n in amounts.items() if sym not in ("H", "O") and n > 0
    )
    found = next(candidates, None)
    if found is None:
        raise ValueError(f"Could not identify base metal in ion {e.composition}")
    return found

def update_ion_energy(entry, T2):
    """Return an ion entry's energy shifted from T1 (298.15 K) to ``T2``.

    The shift comes from the reaction that returns the ion to its references:

        ion + (-npH) H+ + (-nPhi) e-  ->  metal(s) + nH2O H2O(l)

    For this reaction, ΔG(T2) - ΔG(T1) = -ΔS°(T1)·ΔT + <ΔCp>·[ΔT - T2·ln(T2/T1)],
    where <ΔCp> is the mean ΔCp over [T1, T2]. The entry energy measures the
    ion relative to those products, i.e. it equals -ΔG of the reaction. It
    therefore moves by the *negative* of that ΔG shift.

    ``ion_db`` data are per reduced formula unit. They are scaled to the
    entry's own composition so that every stoichiometric term (including
    nH2O, npH and nPhi, which refer to the entry's composition) describes the
    same amount of ion.

    Parameters
    ----------
    entry : PourbaixEntry
        Ion entry; reads ``composition``, ``energy``, ``nH2O``, ``npH``, ``nPhi``.
    T2 : float
        Target temperature in K.

    Returns
    -------
    float
        Corrected energy in eV. ``entry.energy`` is returned unchanged when
        ``ion_db`` has no data for the ion.

    Raises
    ------
    ValueError
        If ``element_db`` has no data for the ion's metal.
    """
    key = ion_key_from_entry(entry)
    ion_thermo = ion_db.get(key)
    if ion_thermo is None:
        return entry.energy  # no data → leave as-is

    metal_sym, metal_per_fu = base_metal_from_entry(entry)
    metal = element_db.get(metal_sym)
    if metal is None:
        raise ValueError(f"Missing elemental thermo for {metal_sym}")
    # base_metal_from_entry reports the reduced amount; use the entry's own count
    # so the metal matches nH2O / npH / nPhi, which refer to the full composition.
    metal_count = entry.composition.get_el_amt_dict()[metal_sym]
    n_fu = metal_count / metal_per_fu  # reduced formula units in this entry

    water = species_db["H2O(l)"]
    proton = species_db["H+(aq)"]
    electron = species_db["e-"]
    # Reactant-side coefficients of H+ and e-; a negative value means the
    # species is actually produced, which the signed coefficient handles.
    nu_H = -entry.npH
    nu_e = -entry.nPhi

    # ΔS°(298) of the reaction above, kJ/mol/K
    S_prod = metal_count * metal["S298"] + entry.nH2O * water["S298"]
    S_reac = (
        n_fu * ion_thermo["S298"]
        + nu_H * proton["S298"]
        + nu_e * electron["S298"]
    )
    dS298 = S_prod - S_reac

    # mean ΔCp° over [T1, T2], kJ/mol/K
    Cp_prod = (
        metal_count * avg_cp_kJmolK(metal["cp"], T2)
        + entry.nH2O * avg_cp_kJmolK(water["cp"], T2)
    )
    Cp_reac = (
        n_fu * avg_cp_kJmolK(ion_thermo["cp"], T2)
        + nu_H * avg_cp_kJmolK(proton["cp"], T2)
        + nu_e * avg_cp_kJmolK(electron["cp"], T2)
    )
    dCp_avg = Cp_prod - Cp_reac

    dT = T2 - T1
    dG_rxn_kJmol = dCp_avg * (dT - T2 * np.log(T2 / T1)) - dS298 * dT
    # Entry energy = -ΔG_rxn, so it shifts opposite to the reaction ΔG.
    return entry.energy - dG_rxn_kJmol / KJMOL_PER_EV  # eV

# ============================
# ---- FETCH ENTRIES ----------
# ============================
# Query the Materials Project for Pourbaix entries of the requested elements,
# using MPRester as a context manager.
with MPRester(api_key=MP_API_KEY) as mpr:
    entries = mpr.get_pourbaix_entries(elements)

# ============================
# ---- ION CONCENTRATIONS ----
# ============================
# Collect the solute elements (everything except H and O) that appear in
# aqueous ions. H and O come from water/pH, so they are not given a dissolved
# concentration and are not listed in the printout.
solution_elements = set()
for entry in entries:
    if isinstance(entry, PourbaixEntry) and entry.phase_type == "Ion":
        solution_elements.update(
            el.symbol for el in entry.composition.elements if el.symbol not in ("H", "O")
        )
# Sorted so the printed list (and dict order) is deterministic.
conc_dict = {el: conc for el in sorted(solution_elements)}
print(f"Using [C] = {conc:.1e} mol/L for ions: {', '.join(conc_dict)}")

# ============================
# ---- APPLY: SOLIDS ---------
# ============================
# update_solid_energy works per reduced formula unit, but an entry may hold
# several formula units, so the per-fu shift is scaled by the entry's factor.
# The shift is written to uncorrected_energy.
# NOTE(review): pymatgen's PourbaixEntry derives .energy from
# .uncorrected_energy, so an ad-hoc attribute such as `_energy` would never be
# read — confirm against the installed pymatgen version.
entries_solid_fixed = []
for e in entries:
    if getattr(e, "phase_type", None) != "Solid":
        entries_solid_fixed.append(e)
        continue
    rf = e.composition.reduced_formula
    _, n_fu = e.composition.get_reduced_composition_and_factor()
    # Passing 0.0 yields just the shift ΔfG(T) - ΔfG(298) in eV/fu (0.0 if no data).
    shift_per_fu = update_solid_energy(0.0, T, rf)
    e2 = copy.deepcopy(e)
    e2.uncorrected_energy = e.uncorrected_energy + n_fu * shift_per_fu
    entries_solid_fixed.append(e2)

# ============================
# ---- APPLY: IONS -----------
# ============================
# update_ion_energy returns the corrected total entry energy; only its
# difference from e.energy (the temperature shift) is applied, and it is
# applied to uncorrected_energy.
# NOTE(review): pymatgen's PourbaixEntry derives .energy from
# .uncorrected_energy (plus concentration and water terms), so an ad-hoc
# attribute such as `_energy` would never be read — confirm for your version.
entries_all_fixed = []
for e in entries_solid_fixed:
    if getattr(e, "phase_type", None) != "Ion":
        entries_all_fixed.append(e)
        continue
    shift_eV = update_ion_energy(e, T) - e.energy
    e2 = copy.deepcopy(e)
    e2.uncorrected_energy = e.uncorrected_energy + shift_eV
    entries_all_fixed.append(e2)

# ============================
# ---- BUILD & PLOT ----------
# ============================
# Build the diagram from the temperature-adjusted entries with the per-element
# ion concentrations in conc_dict, then draw it.
pbx = PourbaixDiagram(entries_all_fixed, conc_dict=conc_dict)
plotter = PourbaixPlotter(pbx)
ax = plotter.get_pourbaix_plot()
ax.set_title(
    f"Pourbaix Diagram at {T} K, [C] = {conc} mol/L\n"
    "(solids & ions T-corrected)"
)
plt.tight_layout()
plt.show()
