In [33]:
# Google Colab Notebook: Ionic‐Params → DNN Prediction
# Fetch the pretrained artifacts from the v1.0 GitHub release: the fitted feature scaler
# (unpickled in cell 5) and the Keras model (loaded in cell 6).
# NOTE(review): the downloads are not checksum-verified before they are used.
!wget -q https://github.com/LIANGTING-WU/ML_Phase_Predictor/releases/download/v1.0/scaler.pkl -O scaler.pkl
!wget -q https://github.com/LIANGTING-WU/ML_Phase_Predictor/releases/download/v1.0/trained_model.h5 -O trained_model.h5
# Cell 1: Imports and TensorFlow log suppression (nothing is installed here)
import os, logging
# TF_CPP_MIN_LOG_LEVEL='2' hides TensorFlow's C++ INFO and WARNING messages; it has to be
# set before `import tensorflow` to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
# Also silence Python-side TensorFlow logging below ERROR level.
tf.get_logger().setLevel(logging.ERROR)

import math
import pickle
import re
import warnings

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import load_model

# Radii dictionary
# Ionic radius lookup: element symbol -> {oxidation state: radius}.
# The oxidation-state keys also define which valences an element may take:
# get_possible_valences and adjust_oxidation_states read the allowed valences from here,
# and process_sample reads "Na"[1] and "O"[-2] directly.
# NOTE(review): values look like Shannon ionic radii in pm (e.g. Na+ 102, O2- 140) —
# TODO confirm the source table and coordination number.
radii = {
    "Li": {1: 76},
    "Na": {1: 102},
    "Nt": {1: 102},  # NOTE(review): duplicates "Na"; what the "Nt" symbol stands for is not shown here — confirm
    "K": {1: 138},
    "Ca": {2: 100},
    "O":  {-2: 140},
    "Mg": {2: 72},
    "Ti": {2: 86, 3: 67, 4: 60.5},
    "V":  {2: 79, 3: 64, 4: 58, 5: 54},
    "Cr": {2: 73, 3: 61.5, 4: 55, 5: 49, 6: 44},
    "Mn": {2: 67, 3: 58, 4: 53},
    "Fe": {2: 61, 3: 55, 4: 58.5},
    "Co": {2: 65, 3: 54.5, 4: 53},
    "Ni": {2: 69, 3: 56, 4: 48},
    "Cu": {1: 77, 2: 73, 3: 54},
    "Zn": {2: 74},
    "Sr": {2: 118},
    "Ru": {3: 68, 4: 62, 5: 56.5},
    "Nb": {3: 72, 4: 68, 5: 64},
    "Te": {4: 97, 6: 56},
    "Mo": {3: 69, 4: 65, 5: 61, 6: 59},
    "Rh": {3: 66.5, 4: 60, 5: 55},
    "Lu": {3: 86.1},
    "Zr": {4: 72},
    "Pd": {1: 59, 2: 86, 3: 76, 4: 61.5},
    "Ir": {3: 68, 4: 62.5, 5: 57},
    "Pt": {2: 80, 4: 62.5, 5: 57},
    "Sn": {2: 109, 4: 69},
    "Bi": {3: 103, 5: 76},
    "Sb": {3: 76, 5: 60},
    "Al": {3: 53.5},
    "Er": {3: 103},
    "Tl": {3: 102.5},
    "Sc": {3: 88.5},
    "In": {3: 94},
    "Y":  {3: 104},
    "Ho": {3: 104.1},
    "Yb": {2: 116, 3: 100.8},
    "Pu": {3: 100, 4: 86},
    "Ag": {1: 115}
}

# Simplified Pauling electronegativity dictionary
# element symbol -> electronegativity.
# adjust_oxidation_states uses these values to decide which metal changes valence first:
# the lowest value is oxidized first and the highest value is reduced first.
# process_sample copies them into the per-species property table.
electronegativity = {
    "Li": 0.98,
    "Na": 0.93,
    "Nt": 0.93,  # NOTE(review): duplicates "Na" (see the matching "Nt" entry in radii) — confirm intent
    "K": 0.82,
    "Ca": 1.0,
    "O": 3.44,
    "Mg": 1.31,
    "Ti": 1.54,
    "V": 1.63,
    "Cr": 1.66,
    "Mn": 1.55,
    "Fe": 1.83,
    "Co": 1.88,
    "Ni": 1.91,
    "Cu": 1.90,
    "Zn": 1.65,
    "Sr": 0.95,
    "Ru": 2.20,
    "Nb": 1.60,
    "Te": 2.10,
    "Mo": 2.16,
    "Rh": 2.28,
    "Lu": 1.27,
    "Zr": 1.33,
    "Pd": 2.20,
    "Ir": 2.20,
    "Pt": 2.28,
    "Sn": 1.96,
    "Bi": 2.02,
    "Sb": 2.05,
    "Al": 1.61,
    "Er": 1.24,
    "Tl": 1.62,
    "Sc": 1.36,
    "In": 1.78,
    "Y": 1.22,
    "Ho": 1.23,
    "Yb": 1.10,
    "Pu": 1.28,
    "Ag": 1.93
}

# Ionization-energy lookup: element symbol -> {oxidation state: energy}.
# compute_ionic_params looks up ionization_energies[element][valence] for every species,
# weights it by the species amount, and treats missing entries as 0.0.
# NOTE(review): the values appear to be in eV, and the entry for valence v appears to be
# the (v+1)-th ionization energy (e.g. Li 1 -> 75.64 matches Li's second IE) — confirm.
ionization_energies = {
    "Li": {1: 75.640097},
    "Na": {1: 47.28636},
    "Nt": {1: 47.28636},
    "K": {1: 31.62500},
    "Ca": {2: 50.91316},
    "O":  {-2: 140},  # NOTE(review): 140 is O's value in `radii`, not an ionization energy — likely a placeholder; it only affects the "O" group IE, which cell 4 does not feed to the model
    "Mg": {2: 80.1436},
    "Ti": {2: 27.49171, 3: 43.26717, 4: 99.299},
    "V":  {2: 29.3111, 3: 46.709, 4: 65.28165, 5: 128.125},
    "Cr": {2: 30.959, 3: 49.16, 4: 69.46, 5: 90.6349, 6: 160.29},
    "Mn": {2: 33.668, 3: 51.21, 4: 72.41},
    "Fe": {2: 30.651, 3: 54.91, 4: 75},
    "Co": {2: 33.5, 3: 51.27, 4: 79.5},
    "Ni": {2: 35.187, 3: 54.92, 4: 76.06},
    "Cu": {1: 20.29239, 2: 36.841, 3: 57.38},
    "Zn": {2: 39.7233},
    "Sr": {2: 42.88353},
    "Ru": {3: 45, 4: 59, 5: 76},
    "Nb": {3: 37.611, 4: 50.5728, 5: 102.069},
    "Te": {4: 59.3, 6: 124.2},
    "Mo": {3: 40.33, 4: 54.417, 5: 68.82704, 6: 125.638},
    "Rh": {3: 42, 4: 63, 5: 80},
    "Lu": {3: 45.249},
    "Zr": {4: 80.348},
    "Pd": {1: 19.43, 2: 32.93, 3: 46, 4: 61},
    "Ir": {3: 40, 4: 57, 5: 72},
    "Pt": {2: 29, 4: 56, 5: 75},
    "Sn": {2: 30.506, 4: 77.03},
    "Bi": {3: 45.37, 5: 88.4},
    "Sb": {3: 43.804, 5: 99.51},
    "Al": {3: 119.9924},
    "Er": {3: 42.42},
    "Tl": {3: 51.14},
    "Sc": {3: 73.4894},
    "In": {3: 55.45},
    "Y":  {3: 60.6072},
    "Ho": {3: 42.52},
    "Yb": {2: 25.053, 3: 43.61},
    "Pu": {3: 35, 4: 49},
    "Ag": {1: 21.4844}
}

def check_elements_in_radii(data, radii_dict):
    """
    Report composition elements (oxygen excluded) that have no entry in `radii_dict`.

    Parameters
    ----------
    data : iterable
        Samples shaped ``[sample_id, phase, na_content, composition_dict]``; only the
        composition keys are inspected.
    radii_dict : dict
        Radius table keyed by element symbol.

    Prints either the list of missing elements or a confirmation line; returns None.
    """
    # Oxygen is skipped on purpose: it is handled separately from the metals.
    missing = {
        element
        for _sample_id, _phase, _na_content, composition in data
        for element in composition.keys()
        if element != 'O' and element not in radii_dict
    }

    if not missing:
        print("All elements found in 'data' are present in 'radii_dict'.")
        return

    print("The following elements were found in 'data' but are missing in 'radii_dict':")
    for element in missing:
        print(f" - {element}")


def get_possible_valences(element, radii_dict):
    """
    List the valences recorded for `element` in `radii_dict`, in ascending order.

    Returns an empty list when the element has no entry.
    """
    if element in radii_dict:
        # Iterating a dict yields its keys, i.e. the recorded valences.
        return sorted(radii_dict[element])
    return []

def choose_default_valence(element, valences):
    """
    Pick the starting valence of a metal from its list of possible valences.

    Rules, in order:
    1) +3 if it is available;
    2) otherwise the highest valence when every valence is below +3;
    3) otherwise the smallest valence above +3. This also covers the case where every
       valence is above +3.

    Raises
    ------
    ValueError
        If `valences` is empty.
    """
    if 3 in valences:
        return 3
    if len(valences) == 0:
        raise ValueError(f"No valence data for {element}")

    highest = max(valences)
    if highest < 3:
        return highest

    # +3 is absent and the maximum is not below it, so at least one valence exceeds +3.
    return min(v for v in valences if v > 3)


def compute_initial_metal_charge(metals, radii_dict):
    """
    Assign each metal its default valence and add up the resulting positive charge.

    Parameters
    ----------
    metals : dict
        {element: amount}.
    radii_dict : dict
        {element: {valence: radius}}; only the valence keys are consulted.

    Returns
    -------
    tuple
        (total_metal_charge, detail), where detail is
        {element: {"amount": amount, "valence": default_valence}}.

    Raises
    ------
    ValueError
        Propagated from choose_default_valence when a metal has no valence data.
    """
    total_charge = 0.0
    detail = {}
    for element, amount in metals.items():
        valence = choose_default_valence(element, get_possible_valences(element, radii_dict))
        detail[element] = {"amount": amount, "valence": valence}
        total_charge += valence * amount
    return total_charge, detail

def _shift_metal_valences(current_dist, possible_vals, needed, oxidize, tol=1e-12):
    """
    Step one metal through neighbouring valences until `needed` charge is covered.

    current_dist : {valence: amount} for a single metal; modified in place.
    possible_vals : list of the valences this metal is allowed to take.
    needed : remaining charge magnitude to add (oxidize=True) or remove (oxidize=False).
    oxidize : True moves the highest valence present up one allowed step; False moves the
        lowest valence present down one allowed step.

    A step moves the whole fraction when that does not exceed `needed`. Otherwise it moves
    exactly enough of the fraction to cover `needed` and stops.

    Returns the charge still needed after this metal: 0.0 after a partial shift, and the
    unchanged remainder when the metal runs out of allowed valences.
    """
    while True:
        if needed <= tol or not current_dist:
            return needed

        current_valence = max(current_dist) if oxidize else min(current_dist)
        fraction = current_dist[current_valence]

        if oxidize:
            candidates = [v for v in possible_vals if v > current_valence]
        else:
            candidates = [v for v in possible_vals if v < current_valence]
        if not candidates:
            # Already at the highest (or lowest) allowed valence.
            return needed

        # Immediate neighbour in the requested direction; step size is always positive.
        if oxidize:
            next_val = min(candidates)
            valence_step = next_val - current_valence
        else:
            next_val = max(candidates)
            valence_step = current_valence - next_val
        max_change = fraction * valence_step

        if max_change > needed:
            # Partial shift: move only the fraction that balances the charge exactly.
            fraction_to_shift = needed / valence_step
            fraction_left = fraction - fraction_to_shift
            if fraction_left < tol:
                del current_dist[current_valence]
            else:
                current_dist[current_valence] = fraction_left
            current_dist[next_val] = current_dist.get(next_val, 0.0) + fraction_to_shift
            return 0.0

        # Full shift, then keep pushing the same metal further.
        del current_dist[current_valence]
        current_dist[next_val] = current_dist.get(next_val, 0.0) + fraction
        needed -= max_change


def adjust_oxidation_states(metal_detail, charge_deficit, radii_dict, electronegativity_dict):
    """
    Shift metal oxidation states so the sample's net charge is (as far as possible) zero.

    Metals are handled one at a time. Each metal is pushed through consecutive allowed
    valences (the keys of radii_dict[element]) before the next metal is touched. Within a
    step, the whole fraction at the current valence moves if that does not overshoot the
    remaining charge; otherwise only the fraction needed to balance exactly is moved.

    Parameters
    ----------
    metal_detail : dict
        {element: {"amount": x, "valence": v}} holding each metal's starting valence.
        It is not modified.
    charge_deficit : float
        Net charge of the sample at the starting valences (cations + anions).
        < 0: positive charge is missing, so metals are oxidized, lowest electronegativity first.
        > 0: there is surplus positive charge, so metals are reduced, highest
        electronegativity first.
        |charge_deficit| < 1e-12: no change.
    radii_dict : dict
        {element: {valence: radius}}; only the valence keys are used. Metals absent from it
        keep their starting valence.
    electronegativity_dict : dict
        {element: electronegativity}; metals without an entry are processed last.

    Returns
    -------
    dict
        {element: {valence: amount}}. Each element's amounts still add up to its original
        amount, up to floating-point rounding.

    Warns
    -----
    RuntimeWarning
        If the allowed valences are exhausted before the charge is balanced. The partially
        adjusted distribution is still returned.
    """
    tol = 1e-12

    # Start from each metal's default valence, using fresh inner dicts so that
    # metal_detail itself is never mutated.
    distribution = {m: {info["valence"]: info["amount"]} for m, info in metal_detail.items()}

    if abs(charge_deficit) < tol:
        return distribution

    oxidize = charge_deficit < 0
    if oxidize:
        # Lower electronegativity first: these metals give up electrons most readily.
        metal_order = sorted(distribution, key=lambda x: electronegativity_dict.get(x, 9999))
    else:
        # Higher electronegativity first: these metals accept electrons most readily.
        metal_order = sorted(distribution,
                             key=lambda x: electronegativity_dict.get(x, 0),
                             reverse=True)

    remaining = abs(charge_deficit)
    for m in metal_order:
        if remaining <= tol:
            break
        possible_vals = sorted(radii_dict[m].keys()) if m in radii_dict else []
        remaining = _shift_metal_valences(distribution[m], possible_vals, remaining, oxidize, tol)

    if remaining > tol:
        # Previously this case returned silently with a charge-imbalanced distribution.
        warnings.warn(
            f"adjust_oxidation_states: {remaining:.6g} units of charge could not be balanced "
            f"by {'oxidation' if oxidize else 'reduction'}; the valences listed in "
            "radii_dict are exhausted.",
            RuntimeWarning,
            stacklevel=2,
        )
    return distribution



def process_sample(sample, radii_dict, en_dict):
    """
    Assign oxidation states to one sample and collect per-species properties.

    Parameters
    ----------
    sample : sequence of 4 items
        ``[sample_id, phase, na_amount, composition]``. ``composition`` maps element
        symbols to amounts; its ``'O'`` entry (if any) is the oxygen amount and every
        other entry is treated as a metal.
    radii_dict : dict
        ``{element: {valence: radius}}``. It must contain ``radii_dict["Na"][1]`` when
        Na is present and ``radii_dict["O"][-2]`` when O is present.
    en_dict : dict
        ``{element: electronegativity}``, with the same ``"Na"`` / ``"O"`` requirement.

    Returns
    -------
    results : dict
        Species name -> amount, e.g. ``{"Na+": 0.69, "O2-": 2.0, "Mn3+": 0.31, "Mn4+": 0.46}``.
        Insertion order is Na+, O2-, then the metals; downstream sums follow this order.
    properties : dict
        Species name -> ``{"radius": r, "electronegativity": en}`` for every species in
        ``results``, in the same order. A metal's radius is None when ``radii_dict`` has
        no value for that valence, and its electronegativity is None when ``en_dict`` has
        no entry.

    Raises
    ------
    ValueError
        Propagated when a metal has no valence data in ``radii_dict``.
    """
    _sample_id, _phase, na_amount, comp = sample
    o_amt = comp.get('O', 0)
    metals = {k: v for k, v in comp.items() if k != 'O'}

    # Default valences for every metal (+3 preferred; see choose_default_valence).
    initial_metal_charge, metal_detail = compute_initial_metal_charge(metals, radii_dict)

    # Net charge = metals + Na (+1 each) + O (-2 each).
    # net_charge > 0 means surplus positive charge (reduce metals);
    # net_charge < 0 means missing positive charge (oxidize metals).
    total_cation = initial_metal_charge + na_amount * 1.0
    net_charge = total_cation + o_amt * (-2)

    updated_distribution = adjust_oxidation_states(metal_detail, net_charge, radii_dict, en_dict)

    results = {}
    properties = {}

    # Na and O keep a single fixed valence.
    if na_amount > 1e-12:
        results["Na+"] = na_amount
        properties["Na+"] = {
            "radius": radii_dict["Na"][1],
            "electronegativity": en_dict["Na"]
        }

    if o_amt > 1e-12:
        results["O2-"] = o_amt
        properties["O2-"] = {
            "radius": radii_dict["O"][-2],
            "electronegativity": en_dict["O"]
        }

    # Metals: one species per (element, valence), e.g. {"Mn": {3: 0.31, 4: 0.46}}.
    # Properties come straight from (element, valence), so the name just built does not
    # need to be parsed back.
    for m, dist in updated_distribution.items():
        for val, amt in dist.items():
            if amt < 1e-12:
                continue
            species_name = f"{m}{val}+"
            results[species_name] = amt
            properties[species_name] = {
                "radius": radii_dict.get(m, {}).get(val),
                "electronegativity": en_dict.get(m, None)
            }

    return results, properties

def compute_ionic_params(final_dist, final_props, ionization_energies):
    """
    Aggregate per-species data into five descriptors for each of the Na, TM and O groups.

    Species are grouped as "Na+" -> "Na", "O2-" -> "O", and every other species -> "TM".
    This means non-transition-metal cations such as Mg2+ are also counted under "TM".

    For each group:
    1) amount           sum of x_i
    2) charge           sum of q_i * x_i
    3) radius           x-weighted mean of r_i
    4) IE               sum of IE_i * x_i, with IE_i = ionization_energies[element][q_i]
                        (0.0 when absent)
    5) ionic_potential  1000 * sum of q_i * x_i / r_i. q_i keeps its sign, so the O value
                        is negative.

    The "TM" group additionally gets "entropy" = -sum over elements of X * ln(X). Here X
    is the total amount of each TM-group element summed over its valences; the amounts
    are not normalized.

    Parameters
    ----------
    final_dist : dict
        Species name -> amount, e.g. {"Na+": 0.65, "Mn3+": 0.37, "Mn4+": 0.35, "O2-": 2.0}.
        Species with amount < 1e-12 are ignored.
    final_props : dict
        Species name -> {"radius": r, ...}, e.g. {"Na+": {"radius": 102, ...}}. A missing,
        None, or ~0 radius falls back to 1e-12.
    ionization_energies : dict
        Element -> {signed valence: ionization energy}, e.g. {"Mn": {3: 51.21, 4: 72.41}}.

    Returns
    -------
    dict
        {"Na": {...}, "TM": {...}, "O": {...}}, each with the keys "amount", "charge",
        "radius", "IE" and "ionic_potential". "TM" also has "entropy".

    Raises
    ------
    ValueError
        If a counted species name is not <letters><optional digits><'+' or '-'>,
        e.g. "Mn3+", "Na+", "O2-".
    KeyError
        If a counted species has no entry in final_props.
    """
    # <element><optional magnitude><sign>; a missing magnitude means 1 (e.g. "Na+").
    species_pattern = re.compile(r"([A-Za-z]+)(\d*)([+-])")

    # Per-group running sums; radius is accumulated as numerator/denominator of a weighted mean.
    result = {
        group: {"amount": 0.0, "charge": 0.0, "radius_num": 0.0, "radius_den": 0.0,
                "IE": 0.0, "ionic_pot": 0.0}
        for group in ("Na", "TM", "O")
    }

    # Total amount of each TM-group element across all of its valences (for the entropy term).
    tm_composition = {}

    for species, frac in final_dist.items():
        if frac < 1e-12:
            continue

        if species == "Na+":
            group_key = "Na"
        elif species == "O2-":
            group_key = "O"
        else:
            group_key = "TM"

        r = final_props[species].get("radius", None)
        if r is None or r < 1e-12:
            # Best-effort fallback so the q/r division below cannot fail.
            # NOTE(review): this inflates the ionic potential and drags the mean radius
            # toward 0; it is only reached when a radius is missing from the tables.
            r = 1e-12

        match = species_pattern.fullmatch(species)
        if match is None:
            # Without this check an unparseable name either crashed with UnboundLocalError
            # or silently reused the element/charge of the previous species.
            raise ValueError(
                f"Cannot parse charge from species name {species!r}; "
                "expected e.g. 'Mn3+', 'Na+' or 'O2-'"
            )
        element, digits, sign = match.groups()
        magnitude = int(digits) if digits else 1
        charge_val = magnitude if sign == "+" else -magnitude
        valence = charge_val  # e.g. "Mn3+" -> 3, "O2-" -> -2

        IEi = ionization_energies.get(element, {}).get(valence, 0.0)

        acc = result[group_key]
        acc["amount"] += frac
        acc["charge"] += charge_val * frac
        acc["radius_num"] += r * frac
        acc["radius_den"] += frac
        acc["IE"] += IEi * frac
        # Signed q_i * x_i / r_i: anions contribute negatively.
        acc["ionic_pot"] += (charge_val * frac / r)

        if group_key == "TM":
            tm_composition[element] = tm_composition.get(element, 0.0) + frac

    sum_xi_lnxi = 0.0
    for x in tm_composition.values():
        if x > 1e-12:
            sum_xi_lnxi += x * math.log(x)

    final_data = {}
    for gkey in ["Na", "TM", "O"]:
        acc = result[gkey]
        rad_den = acc["radius_den"]
        avg_radius = 0.0 if rad_den < 1e-12 else acc["radius_num"] / rad_den
        final_data[gkey] = {
            "amount": acc["amount"],
            "charge": acc["charge"],
            "radius": avg_radius,
            "IE": acc["IE"],
            "ionic_potential": acc["ionic_pot"] * 1000
        }

    final_data["TM"]["entropy"] = -sum_xi_lnxi
    return final_data



In [34]:
# Cell 3: User inputs new sample(s)
# Each sample is [sample_id, phase_label, na_amount, composition]. composition maps element
# symbols to amounts, and its 'O' entry holds the oxygen amount.
# Example format:
data = [
    [147473, 'O3', 1.0, {'Ni': 0.3, 'Fe': 0.2, 'Mn': 0.5, 'O': 2.0}],
    [36992, 'P2', 0.69, {'Mn': 0.77, 'Fe': 0.08, 'Mg': 0.15, 'O': 2.0}]
]

# Cell 4: Process each sample with process_sample + compute_ionic_params, then build a DataFrame

# Constants of the CP (cationic potential) feature. They are kept at their literal values
# so the features fed to the pretrained scaler/model are unchanged. With the radii table
# and the x1000 scaling in compute_ionic_params:
#   CP_NA_POTENTIAL_PER_NA ~= 1000 / 102            (Na+ ionic potential per unit Na, r(Na+) = 102)
#   CP_O_POTENTIAL         ~= 2 * 2.0 * 1000 / 140  (|O ionic potential| for 2.0 O, r(O2-) = 140)
# So CP ~= Na ionic potential * TM ionic potential / |O ionic potential|, but only when the
# sample contains exactly CP_O_AMOUNT oxygen. Any other oxygen amount triggers a warning.
CP_NA_POTENTIAL_PER_NA = 9.803922
CP_O_POTENTIAL = 28.571429
CP_O_AMOUNT = 2.0

records = []
for sample in data:
    # Calculate species distribution and properties
    dist, props = process_sample(sample, radii, electronegativity)
    ionic = compute_ionic_params(dist, props, ionization_energies)

    sample_id = sample[0]

    if abs(ionic['O']['amount'] - CP_O_AMOUNT) > 1e-9:
        warnings.warn(
            f"Sample {sample_id}: O amount {ionic['O']['amount']} != {CP_O_AMOUNT}; "
            "the hard-coded CP constants assume 2.0 O, so CP may be inconsistent.",
            RuntimeWarning,
        )

    # Compute features for DNN input
    Na_amount = ionic['Na']['amount']
    TM_radius = ionic['TM']['radius']
    TM_IE     = ionic['TM']['IE']
    TM_IP     = ionic['TM']['ionic_potential']
    TM_S      = ionic['TM']['entropy']
    CP        = Na_amount * CP_NA_POTENTIAL_PER_NA * TM_IP / CP_O_POTENTIAL

    # Collect into a record dict
    rec = {
        'sample_id': sample_id,
        'Na_amount': Na_amount,
        'TM_radius': TM_radius,
        'TM_IE':     TM_IE,
        'TM_IP':     TM_IP,
        'TM_S':      TM_S,
        'CP':        CP
    }
    records.append(rec)

# Convert records list to DataFrame
# This DataFrame contains the input features for prediction

df_new = pd.DataFrame(records)
print("New‐input features:")
print(df_new)


New‐input features:
   sample_id  Na_amount  TM_radius     TM_IE      TM_IP      TM_S         CP
0     147473       1.00      56.80  53.06300  52.842588  1.029653  18.132261
1      36992       0.69      57.56  65.59804  59.281767  0.687877  14.035831


In [35]:
# Cell 5: Load previously saved StandardScaler
# Assumes you saved the scaler during training as 'scaler.pkl'.
# SECURITY: pickle.load can execute arbitrary code embedded in the file. scaler.pkl is
# downloaded by the !wget in the first cell without any checksum check, so only run this
# against a release you trust.
# NOTE(review): the column order presumably must match the order used to fit the scaler
# and model — confirm against the training code.
FEATURE_COLUMNS = ['Na_amount', 'TM_radius', 'TM_IE', 'TM_IP', 'TM_S', 'CP']
DECISION_THRESHOLD = 0.5            # probability >= threshold -> class 1
PHASE_LABELS = {1: 'O3', 0: 'P2'}   # class index -> phase label

with open('scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

# Standardize new sample features
X_new = df_new[FEATURE_COLUMNS].values
X_new_scaled = scaler.transform(X_new)

# Cell 6: Load trained DNN model and predict
model = load_model('trained_model.h5', compile=False)
proba = model.predict(X_new_scaled)
# Flatten once; the result is assigned as a DataFrame column, so it must hold one value per sample.
y_proba = proba.ravel()
pred = (y_proba >= DECISION_THRESHOLD).astype(int)

# Append prediction results to DataFrame
df_new['y_proba'] = y_proba
df_new['y_pred'] = pred
# Add 'Phase' column mapping predictions to labels
df_new['Phase'] = df_new['y_pred'].map(PHASE_LABELS)

print("Prediction results:")
print(df_new[['sample_id','y_proba','y_pred','Phase']])

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 188ms/step
Prediction results:
   sample_id   y_proba  y_pred Phase
0     147473  0.999437       1    O3
1      36992  0.011626       0    P2
