In [None]:
!pip install cantera
!pip install torchdiffeq
!pip install pandas

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
import cantera as ct
from sklearn.neighbors import KernelDensity
from scipy import stats
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}\n")

# Initial temperatures [K] assigned to each model group. The boundary values
# (1500 K and 1900 K) deliberately appear in two neighbouring groups.
TEMPERATURE_GROUPS = dict(
    high=[1900, 2100, 2300],
    medium=[1500, 1700, 1900],
    low=[900, 1100, 1300, 1500],
)

# =============================================
# 1. Training Data Generation with Cantera
# =============================================

def generate_training_data():
    """Simulate constant-pressure NH3/air reactors and sample their trajectories.

    One Cantera run is performed for every combination of the initial
    temperatures listed in ``TEMPERATURE_GROUPS`` and the equivalence ratios
    0.6, 0.8, ..., 1.4 at one atmosphere.

    Returns:
        tuple: ``(all_data, species)`` where ``species`` is the list of tracked
        species names and ``all_data`` holds one dict per condition with keys
        ``'phi'``, ``'T_init'``, ``'times'`` (list), ``'T'``, ``'Y'`` and
        ``'dYdt'`` (numpy arrays; ``Y``/``dYdt`` columns follow ``species``).
    """
    mechanism_file = '/content/sample_data/okafor.yaml'
    gas = ct.Solution(mechanism_file)

    equivalence_ratios = np.arange(0.6, 1.5, 0.2)  # phi = 0.6 to 1.4
    temperatures = sorted({t for group in TEMPERATURE_GROUPS.values() for t in group})
    N_samples = 10000  # target number of stored points per condition

    # Major species whose mass fractions and rates are recorded
    species = ['NH3', 'O2', 'H2', 'OH', 'H2O', 'N2', 'NO']

    def sample_state(reactor):
        """Return (Y, dY/dt) of the tracked species for the reactor's current state."""
        thermo = reactor.thermo
        # Molar production rate * molecular weight / density = dY/dt
        wdot = thermo.net_production_rates * thermo.molecular_weights / thermo.density
        mass_fractions = np.array([thermo[sp].Y[0] for sp in species])
        rates = np.array([wdot[thermo.species_index(sp)] for sp in species])
        return mass_fractions, rates

    all_data = []
    for T_init in temperatures:
        for phi in equivalence_ratios:
            # Initial mixture state
            gas.set_equivalence_ratio(phi, 'NH3', 'O2:1.0, N2:3.76')
            gas.TP = T_init, ct.one_atm

            reactor = ct.IdealGasConstPressureReactor(gas)
            sim = ct.ReactorNet([reactor])
            sim.rtol, sim.atol = 1e-12, 1e-16
            print(f"\nGenerating Training Data for (T, φ) = ({T_init}, {phi.round(1)})")

            # Simulated time span: fixed 5 s for the two coldest cases (their
            # ignition delays, ~10^3 s and ~14.5 s, are not computed here),
            # otherwise four ignition delays.
            if T_init in (900, 1100):
                t_elap = 5
            else:
                tau_ign = calculate_ignition_delay_time_dynamic(phi, T_init, ct.one_atm, mechanism_file)
                t_elap = 4 * tau_ign
            dt = t_elap / N_samples

            # First sample is the initial state
            Y0, dYdt0 = sample_state(reactor)
            times = [0.0]
            temps = [reactor.T]
            Y_hist = [Y0]
            dYdt_hist = [dYdt0]

            # Let the integrator pick its own steps; keep a sample whenever at
            # least dt has elapsed since the previously stored one.
            while sim.time < t_elap:
                sim.step()
                if sim.time - times[-1] >= dt:
                    times.append(sim.time)
                    temps.append(reactor.T)
                    Y_now, dYdt_now = sample_state(reactor)
                    Y_hist.append(Y_now)
                    dYdt_hist.append(dYdt_now)

            all_data.append({
                'phi': phi,
                'T_init': T_init,
                'times': times,
                'T': np.array(temps),
                'Y': np.array(Y_hist),
                'dYdt': np.array(dYdt_hist),
            })

    return all_data, species


def calculate_ignition_delay_time_dynamic(equivalence_ratio, temperature, pressure, mechanism):
    """
    Estimate the ignition delay time from the inflection point of T(t).

    For initial temperatures of 900 K and 1100 K the full temperature history
    is recorded and a sigmoid is fitted to it; the fitted midpoint ``t0`` is
    returned. For every other temperature the history is scanned on the fly
    and the first time at which dT/dt is positive while d2T/dt2 turns negative
    is returned.

    The reactor is advanced in fixed increments of 1e-6 s up to ``max_time``
    (25000 s at 900 K, 20 s at 1100 K, 1 s otherwise), so the loop performs up
    to ``max_time / 1e-6`` iterations.

    Parameters:
        equivalence_ratio (float): Equivalence ratio of the mixture. Plain
            Python floats and numpy scalars are both accepted.
        temperature (float): Initial temperature in K.
        pressure (float): Initial pressure in Pa.
        mechanism (str): Path to the chemical mechanism file.

    Returns:
        float: Ignition delay time in seconds. ``nan`` if the sigmoid fit
        fails; the final simulated time if no inflection point is detected.
    """
    from scipy.optimize import curve_fit

    # Sigmoid function for fitting
    def sigmoid(t, L, k, t0, C):
        return L / (1 + np.exp(-k * (t - t0))) + C

    # Formatted once for the diagnostic messages. A format spec works for both
    # Python floats and numpy scalars, whereas ``.round(1)`` only exists on the latter.
    phi_label = f"{equivalence_ratio:.1f}"

    # Initialize gas and reactor
    gas = ct.Solution(mechanism)
    gas.TP = temperature, pressure
    gas.set_equivalence_ratio(equivalence_ratio, fuel="NH3", oxidizer="O2:1, N2:3.76")

    reactor = ct.IdealGasConstPressureReactor(gas)
    sim = ct.ReactorNet([reactor])

    # Time and data collection
    t = 0.0
    dt = 1e-6
    if temperature == 900:
        max_time = 25000.0
    elif temperature == 1100:
        max_time = 20.0
    else:
        max_time = 1.0

    times = []
    temperatures = []

    # Simulation loop
    if temperature in [900, 1100]:
        while t < max_time:
            sim.advance(t + dt)
            t += dt
            times.append(t)
            temperatures.append(reactor.T)

        # Convert lists to arrays
        times = np.array(times)
        temperatures = np.array(temperatures)

        # Initial guess for sigmoid parameters (numpy reductions instead of the
        # Python builtins, which iterate element by element over the array)
        T_min = np.min(temperatures)
        L_guess = np.max(temperatures) - T_min
        k_guess = 1  # Guess steepness
        t0_guess = times[np.argmax(np.gradient(temperatures, times))]  # Approximate midpoint
        C_guess = T_min

        # Fit the sigmoid function to the data
        try:
            popt, _ = curve_fit(sigmoid, times, temperatures, p0=[L_guess, k_guess, t0_guess, C_guess])
            L, k, t0, C = popt  # Extract parameters
            tau_ign = t0  # The inflection point corresponds to t0
        except RuntimeError:
            print(f"Sigmoid fitting failed for T={temperature}, φ={phi_label}.")
            tau_ign = np.nan

        # Plot the results
        # plt.figure(figsize=(8, 5))
        # plt.plot(times, temperatures, label="Temperature Profile")
        # if not np.isnan(tau_ign):
        #     sigmoid_fit = sigmoid(times, *popt)
        #     plt.plot(times, sigmoid_fit, label="Sigmoid Fit", linestyle="--")
        #     plt.axvline(tau_ign, color="red", linestyle="--", label=f"Ignition Delay Time = {tau_ign:.4f} s")
        # plt.title(f"Temperature Profile (T={temperature} K, φ={phi_label})")
        # plt.xlabel("Time (s)")
        # plt.ylabel("Temperature (K)")
        # plt.legend()
        # plt.grid(True)
        # plt.show()

        return tau_ign

    else:
        # Variables to track derivatives and store data
        prev_temp = reactor.T
        prev_dT_dt = 0.0
        inflection_detected = False

        while t < max_time:
            sim.advance(t + dt)
            t += dt
            current_temp = reactor.T

            # Record data for plotting
            times.append(t)
            temperatures.append(current_temp)

            # Backward finite differences for the first and second derivatives
            dT_dt = (current_temp - prev_temp) / dt
            d2T_dt2 = (dT_dt - prev_dT_dt) / dt

            # Inflection point: temperature still rising but the rise starts to slow
            if not inflection_detected and prev_dT_dt > 0 and d2T_dt2 < 0:
                inflection_detected = True
                tau_ign = t
                break

            # Update previous values
            prev_temp = current_temp
            prev_dT_dt = dT_dt

        # If no inflection point was detected, use the final time as fallback
        if not inflection_detected:
            print(f"Warning: No inflection point found for T={temperature}, φ={phi_label}. Using {t:.2f}s as fallback.")
            tau_ign = t

        # Plot the temperature profile
        # plt.figure(figsize=(8, 5))
        # plt.plot(times, temperatures, label="Temperature Profile")
        # if inflection_detected:
        #     plt.axvline(tau_ign, color="red", linestyle="--", label=f"Ignition Delay Time = {tau_ign:.4f} s")
        # plt.title(f"Temperature Profile (T={temperature} K, φ={phi_label})")
        # plt.xlabel("Time (s)")
        # plt.ylabel("Temperature (K)")
        # plt.legend()
        # plt.grid(True)
        # plt.show()

        return tau_ign

# =============================================
# 2. Execution
# =============================================

# Step 1: run the Cantera simulations for every (T, phi) condition
print("Generating training data...")
all_data, species = generate_training_data()

# Each condition carries the species list so that later stages can map a
# species name to its column in the 'Y' / 'dYdt' arrays.
for data in all_data:
    data.update(species=species)
print('\nPrinting all_data[2]:\n', all_data[2], '\n')


In [None]:
# =============================================
# 3. Multi-scale Sampling with Density Weighting
# =============================================

def calculate_sampling_weights(data, target_species='NH3'):
    """Attach density-based sampling weights to every condition.

    A Gaussian kernel density estimate (bandwidth 0.01) is fitted to the
    mass fraction of ``target_species`` pooled over all conditions. Each
    sample receives the weight ``Q = P_min / P``, so samples from sparsely
    populated mass-fraction ranges are drawn more often.

    Args:
        data: Non-empty list of condition dicts, each with a 2-D ``'Y'`` array
            and a ``'species'`` list naming its columns.
        target_species: Species whose mass-fraction density drives the weights.

    Returns:
        list: New dicts, one per input condition and in the same order, equal
        to the input dict plus a ``'weights'`` array with one entry per row of
        ``'Y'``. The input dicts are not modified; this matters because the
        same condition dict can belong to two temperature groups.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("calculate_sampling_weights() requires at least one condition")

    # One column of target-species mass fractions per condition
    columns = [np.asarray(d['Y'])[:, d['species'].index(target_species)] for d in data]
    Y_target = np.concatenate(columns).reshape(-1, 1)

    # Kernel density estimation
    kde = KernelDensity(bandwidth=0.01, kernel='gaussian')
    kde.fit(Y_target)
    log_dens = kde.score_samples(Y_target)

    # Q = P_min / P, evaluated in log space so no intermediate density can
    # underflow to zero before the division.
    Q = np.exp(np.min(log_dens) - log_dens)

    # Hand each condition its slice of the pooled weight vector
    weighted_data = []
    start_idx = 0
    for d, column in zip(data, columns):
        end_idx = start_idx + len(column)
        weighted_data.append({**d, 'weights': Q[start_idx:end_idx]})
        start_idx = end_idx

    return weighted_data

# =============================================
# 4. Customized Box-Cox Transformation
# =============================================

class BoxCoxTransformer:
    """Box-Cox transform for mass fractions with a species-dependent argument.

    Consumption species (``NH3`` and ``O2``) are transformed as
    ``(Y**lam - 1) / lam``; every other species is treated as a production
    species and transformed as ``((1 - Y)**lam - 1) / lam``. For ``lam == 0``
    the Box-Cox limit ``log(.)`` is used.
    """

    def __init__(self, lambda_val=0.1):
        self.lambda_val = lambda_val
        self.consumption_species = ['NH3', 'O2']

    @staticmethod
    def _as_float_array(values):
        """Return ``values`` as an ndarray with a floating dtype.

        Floating inputs keep their dtype; anything else (e.g. integers) is
        promoted to float64 so that results are not truncated on assignment.
        """
        arr = np.asarray(values)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    def _forward(self, base):
        """Box-Cox map of ``base``; the logarithm when lambda is zero."""
        if self.lambda_val == 0:
            return np.log(base)
        return (base**self.lambda_val - 1) / self.lambda_val

    def _backward(self, z):
        """Inverse of :meth:`_forward`."""
        if self.lambda_val == 0:
            return np.exp(z)
        return (self.lambda_val * z + 1)**(1/self.lambda_val)

    def transform(self, Y, species):
        """Apply the customized Box-Cox transformation.

        Args:
            Y: 2-D array of mass fractions, one column per entry of ``species``.
            species: Species names, in column order.

        Returns:
            ndarray: Array of the same shape as ``Y`` with a floating dtype.
            Columns beyond ``len(species)`` are left at zero.
        """
        Y = self._as_float_array(Y)
        Y_transformed = np.zeros(Y.shape, dtype=Y.dtype)
        for i, sp in enumerate(species):
            if sp in self.consumption_species:
                # Consumption species transformation
                Y_transformed[:, i] = self._forward(Y[:, i])
            else:
                # Production species transformation
                Y_transformed[:, i] = self._forward(1 - Y[:, i])
        return Y_transformed

    def inverse_transform(self, Y_transformed, species):
        """Invert :meth:`transform`.

        Args:
            Y_transformed: 2-D array of transformed values, one column per
                entry of ``species``.
            species: Species names, in column order.

        Returns:
            ndarray: Mass fractions with the same shape as ``Y_transformed``.
        """
        Y_transformed = self._as_float_array(Y_transformed)
        Y = np.zeros(Y_transformed.shape, dtype=Y_transformed.dtype)
        for i, sp in enumerate(species):
            if sp in self.consumption_species:
                # Inverse for consumption species
                Y[:, i] = self._backward(Y_transformed[:, i])
            else:
                # Inverse for production species
                Y[:, i] = 1 - self._backward(Y_transformed[:, i])
        return Y

# =============================================
# 5. Neural ODE Model Definition
# =============================================

class GroupedODEFunc(nn.Module):
    """Scalar-output MLP with a parallel shallow branch, one per species and group.

    The main branch is an input block (Linear, LayerNorm, SiLU, Dropout 0.2),
    five hidden blocks (Linear, LayerNorm, SiLU) and a linear output. Its
    width is 192 for NH3, O2 and H2O and 384 for every other target species.
    The output of a second, three-layer branch fed with the same input is
    added to the main branch's output.
    """

    def __init__(self, input_dim, group_name, target_species):
        super().__init__()
        assert input_dim == 7, f"Model requires 7 input features, got {input_dim}"
        self.group_name = group_name
        self.target_species = target_species

        # Narrower network for the three species listed here
        if target_species in ('NH3', 'O2', 'H2O'):
            hidden_dim = 192
        else:
            hidden_dim = 384

        # Input block (the only one with dropout)
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.2),
        ]
        # Five identical hidden blocks
        for _ in range(5):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
            ])
        # Scalar output
        layers.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*layers)

        # Parallel branch whose output is added to the main branch
        if input_dim != 1:
            self.skip = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, 1),
            )
        else:
            self.skip = None

    def forward(self, t, x):
        """Return the prediction for ``x``; ``t`` is accepted but not used."""
        main = self.net(x)
        if self.skip is None:
            return main + 0
        return main + self.skip(x)

# =============================================
# 6. Training Procedure
# =============================================

def train_node_model(data, target_species, species_list, group_name, n_epochs=5000, batch_size=256):
    """Train one network that predicts the standardized dY/dt of ``target_species``.

    Each training sample has seven inputs: the Box-Cox-transformed mass
    fractions of every species in ``species_list`` except N2, followed by the
    temperature divided by the module-level ``max_temp``. The label is the
    target species' dY/dt, standardized with the mean and standard deviation
    of the pooled training labels.

    Reads the module-level names ``max_temp`` and ``device``.

    Args:
        data: Condition dicts with ``'Y'``, ``'dYdt'``, ``'T'`` and
            ``'weights'`` arrays (columns of ``Y``/``dYdt`` follow ``species_list``).
        target_species: Species whose rate is learned.
        species_list: Names of the columns of ``'Y'`` and ``'dYdt'``.
        group_name: Temperature-group label stored on the model.
        n_epochs: Number of optimisation steps; each step draws one batch.
        batch_size: Samples drawn (with replacement, proportional to
            ``'weights'``) per step.

    Returns:
        GroupedODEFunc: The trained model, switched to evaluation mode, with
        ``dYdt_mean`` / ``dYdt_std`` attributes for de-standardizing its output.
    """
    # Set up transformer
    transformer = BoxCoxTransformer(lambda_val=0.1)

    # Define ACTUAL species used in model (exclude N2)
    model_species = [sp for sp in species_list if sp != 'N2']
    assert len(model_species) == 6, "Should have exactly 6 species (excl. N2)"

    # Column lookups do not depend on the condition, so resolve them once
    model_columns = [species_list.index(sp) for sp in model_species]
    target_idx = species_list.index(target_species)

    # Prepare data
    X_list = []
    y_list = []
    weights_list = []
    for condition in data:
        # Transform mass fractions
        Y_transformed = transformer.transform(condition['Y'][:, model_columns], model_species)

        # Normalize temperature using calculated max_temp
        T_norm = condition['T'] / max_temp

        # Input features (Y and T), target (dY/dt of the target species), weights
        X_list.append(np.concatenate([Y_transformed, T_norm.reshape(-1, 1)], axis=1))
        y_list.append(condition['dYdt'][:, target_idx])
        weights_list.append(condition['weights'])

    # Concatenate all data
    X_all = np.concatenate(X_list)
    y_all = np.concatenate(y_list)
    weights_all = np.concatenate(weights_list)

    # # =============================================
    # # NEW: Data Export Section
    # # =============================================
    # def save_debug_data(X, y, species_list, target_species, group_name):
    #     import pandas as pd
    #     import os

    #     # Create debug directory if needed
    #     os.makedirs('debug_data', exist_ok=True)

    #     # Prepare column names
    #     input_species = [sp for sp in species_list if sp != 'N2']
    #     feature_columns = input_species + ['T_normalized']

    #     # Create DataFrames
    #     df_X = pd.DataFrame(X, columns=feature_columns)
    #     df_y = pd.DataFrame(y, columns=[f'd{target_species}_dt'])

    #     # Save to CSV
    #     timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    #     filename = f"debug_data/{target_species}_{group_name}_{timestamp}"

    #     df_X.to_csv(f"{filename}_X.csv", index=False)
    #     df_y.to_csv(f"{filename}_y.csv", index=False)

    #     # Save human-readable text version
    #     with open(f"{filename}_summary.txt", 'w') as f:
    #         f.write(f"=== Debug Data for {target_species} ({group_name} group) ===\n")
    #         f.write(f"Generated at: {timestamp}\n\n")

    #         f.write("=== Input Features (X_all) ===\n")
    #         f.write(f"Shape: {X.shape}\n")
    #         f.write("First 5 rows:\n")
    #         f.write(df_X.head().to_string() + "\n\n")

    #         f.write("=== Targets (y_all) ===\n")
    #         f.write(f"Shape: {y.shape}\n")
    #         f.write("Statistics:\n")
    #         f.write(df_y.describe().to_string() + "\n")

    #         f.write("\n=== Feature Statistics ===\n")
    #         f.write(df_X.describe().to_string())

    #     print(f"Debug data saved to {filename}_[X.csv|y.csv|summary.txt]")

    # # Call the debug function
    # save_debug_data(X_all, y_all, species_list, target_species, group_name)

    # Standardize dY/dt labels
    dYdt_mean, dYdt_std = np.mean(y_all), np.std(y_all)
    y_all = (y_all - dYdt_mean) / (dYdt_std + 1e-12)  # Prevent division by zero

    # Convert to PyTorch tensors
    X_tensor = torch.FloatTensor(X_all).to(device)
    y_tensor = torch.FloatTensor(y_all).to(device)

    # Create model
    input_dim = X_all.shape[1]
    assert input_dim == 7, f"Expected 7 input features, got {input_dim}"

    model = GroupedODEFunc(
        input_dim=input_dim,
        group_name=group_name,
        target_species=target_species).to(device)

    # Store for inverse transform during inference
    model.dYdt_mean = dYdt_mean
    model.dYdt_std = dYdt_std

    # Optimizer Initialization
    optimizer = optim.AdamW(model.parameters(), lr=5e-3)

    # Learning Rate Scheduling
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=100,
        verbose=True
    )

    # Weighted Loss Function
    def weighted_mse(pred, target):
        # Larger-magnitude (standardized) rates get a larger weight
        weights = torch.log1p(torch.abs(target)) + 1.0
        return torch.mean(weights * (pred - target)**2)

    loss_fn = weighted_mse

    # Sampling probabilities are the same for every step
    sampling_p = weights_all / weights_all.sum()

    # Training loop
    losses = []
    model.train()
    for _ in tqdm(range(n_epochs), desc=f"Training {target_species} NODE"):
        # Sample batch with weights
        batch_indices = np.random.choice(len(X_all), size=batch_size, p=sampling_p)
        X_batch = X_tensor[batch_indices]
        y_batch = y_tensor[batch_indices]

        # Forward pass; drop only the trailing feature axis so that a batch of
        # one sample keeps its batch dimension
        y_pred = model(None, X_batch).squeeze(-1)

        # Compute loss
        loss = loss_fn(y_pred, y_batch)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Adjust the learning rate only after the parameters were updated
        loss_value = loss.item()
        scheduler.step(loss_value)
        losses.append(loss_value)

    # Plot training loss
    plt.figure()
    plt.plot(losses)
    plt.yscale('log')
    plt.title(f"Training Loss for {target_species}")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.show()

    # Leave the model in evaluation mode: otherwise the Dropout layer stays
    # active and predictions (including those of a pickled copy) are random.
    model.eval()

    return model

def train_temperature_group_models(all_data, species_list):
    """Train one model per (species, temperature group) pair.

    Conditions are assigned to the ``high`` (T_init >= 1900 K), ``medium``
    (1500 K <= T_init <= 1900 K) and ``low`` (T_init <= 1500 K) groups; the
    boundary temperatures therefore belong to two groups. N2 gets no model.

    Args:
        all_data: Condition dicts as produced by the data-generation step.
        species_list: Names of the columns of each condition's ``'Y'`` array.

    Returns:
        dict: Trained models keyed by ``"<species>_<group>"``.
    """
    model_species = [sp for sp in species_list if sp != 'N2']
    print(f"Using species: {model_species} (N2 excluded)")

    # Group name -> membership test on the initial temperature, in training order
    group_rules = (
        ('high', lambda T: T >= 1900),
        ('medium', lambda T: 1500 <= T <= 1900),
        ('low', lambda T: T <= 1500),
    )
    grouped_data = {
        name: [condition for condition in all_data if belongs(condition['T_init'])]
        for name, belongs in group_rules
    }

    node_models = {}
    for group_name, _ in group_rules:
        print(f"\nTraining {group_name.upper()} temperature group models:\n")

        # Sampling weights are always based on the NH3 mass-fraction density
        weighted_group_data = calculate_sampling_weights(
            grouped_data[group_name], target_species='NH3')

        for target_species in model_species:
            print(f"\nTraining {target_species} ({group_name.upper()} group)")
            node_models[f"{target_species}_{group_name}"] = train_node_model(
                weighted_group_data, target_species, species_list, group_name=group_name)

    return node_models


In [None]:
# Step 2: Train with temperature grouping
print("Training temperature-grouped NODE models...\n")

# Largest temperature found in any trajectory. This is a plain module-level
# assignment; train_node_model reads the name as a global to normalize T.
max_temp = max(np.max(condition['T']) for condition in all_data)
print(f"Maximum temperature across all data points: {max_temp}")

node_models = train_temperature_group_models(all_data, species)


In [None]:
!zip -r /content/DATA.zip /content/debug_data/

In [None]:
# Step 3: Save all 18 models
# The dict of model objects itself is serialized (not just state_dicts), so
# the class definitions must be importable wherever the file is loaded.
torch.save(node_models, 'grouped_ammonia_models.pth')
print("Saved 18 models (6 species × 3 temp groups) to grouped_ammonia_models.pth")


In [None]:
# Step 4: Validation of NODE Model

# =============================================
# Constants
# =============================================
R_UNIV = 8.314  # Universal gas constant [J/(mol·K)]

# =============================================
# Thermodynamic Data (NASA polynomials for species)
# =============================================
# Layout of each entry (NASA 7-coefficient polynomials, two temperature ranges):
# 'coeffs' holds 14 numbers — the seven low-range coefficients followed by the
# seven high-range coefficients — and the temperatures are stored under
# separate keys.
# NASA_COEFFS = {
#     'SPECIES_NAME': {
#         'coeffs': [
#             # Low-T range coefficients (200-1000K) [a1, a2, a3, a4, a5, a6, a7]
#             # High-T range coefficients (1000-6000K) [b1, b2, b3, b4, b5, b6, b7]
#         ],
#         'T_low': 200.0,    # Minimum temperature for low-T coefficients (K)
#         'T_high': 1000.0,  # Transition temperature between low/high ranges (K)
#         'Mw': 17.031       # Molecular weight (g/mol)
#     }
# }

# NOTE(review): some hand-entered values look like placeholders (e.g. the
# round NH3 low-range a3..a5) — verify against the mechanism file before
# relying on them. A later cell rebinds NASA_COEFFS to values extracted from
# the mechanism.
NASA_COEFFS = {
    'H2': {
        'coeffs': [
            2.344331, 0.007980, -1.947815e-05, 2.015721e-08, -7.376118e-12, -917.935, 0.683010,
            3.337279, 0.002470, -8.180640e-07, 1.277629e-10, -6.378056e-15, -950.158, -3.205023
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 2.016  # Molecular weight [g/mol]
    },
    'O2': {
        'coeffs': [
            3.782456, -0.002996, 9.847302e-06, -9.681295e-09, 3.243728e-12, -1063.943, 3.657676,
            3.660961, 0.001656, -4.599965e-07, 6.669201e-11, -3.067911e-15, -1215.977, 3.415362
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 32.00
    },
    'H2O': {
        'coeffs': [
            4.198640, -0.002036, 4.301402e-06, -2.368140e-09, 5.087687e-13, -30293.726, -0.849009,
            3.033992, 0.002176, -1.640725e-07, -9.704198e-11, 1.682009e-14, -30004.488, 4.966770
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 18.015
    },
    'NH3': {
        'coeffs': [
            3.578323, -0.000610, 2.500000e-06, -2.000000e-09, 5.000000e-13, -10623.7, 2.203030,
            3.568874, 0.001366, -1.983749e-06, 1.426804e-09, -3.836744e-13, -1020.896, 5.872526
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 17.031
    },
    'OH': {
        'coeffs': [
            3.637268, 0.000506, -8.814150e-07, 9.804318e-10, -4.384130e-13, 3419.309, 2.932866,
            3.697578, 0.001012, -1.321330e-06, 1.410378e-09, -4.889708e-13, 3369.887, -1.272096
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 17.007
    },
    'NO': {
        'coeffs': [
            4.046189, -0.002470, 6.348420e-06, -5.961540e-09, 2.156450e-12, 983.261, 5.980056,
            3.531005, -0.000123, -1.182722e-06, 2.655356e-09, -1.322536e-12, 976.287, 6.500787
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 30.006
    },
    'N2': {
        'coeffs': [
            3.298677, 0.001408, -3.963222e-06, 5.641515e-09, -2.444854e-12, -1020.900, 3.950372,
            2.926640, 0.001487, -2.842380e-06, 3.365680e-09, -1.688258e-12, -905.851, 5.980528
        ],
        'T_low': 200.0, 'T_high': 1000.0, 'Mw': 28.013
    }
}

def compute_h_i(T, species):
    """Compute specific enthalpy h_i(T) [J/kg] for a species using NASA polynomials.

    Args:
        T (float): Temperature [K].
        species (str): Key into ``NASA_COEFFS``.

    Returns:
        float: Specific enthalpy [J/kg].
    """
    entry = NASA_COEFFS[species]
    coeffs = entry['coeffs']

    # Temperature at which the low-range coefficients hand over to the
    # high-range ones. An explicit 'T_mid' wins; otherwise 'T_high' is used,
    # which the hand-written table defines as the transition temperature.
    # ('T_low' is the lower validity limit, not the switch point.)
    T_mid = entry.get('T_mid', entry['T_high'])
    a = coeffs[:7] if T <= T_mid else coeffs[7:14]

    # Enthalpy polynomial: h_i(T)/RT = a1 + a2*T/2 + a3*T^2/3 + a4*T^3/4 + a5*T^4/5 + a6/T
    h_RT = (
        a[0] +
        a[1] * T / 2 +
        a[2] * T**2 / 3 +
        a[3] * T**3 / 4 +
        a[4] * T**4 / 5 +
        a[5] / T
    )
    h_i = h_RT * R_UNIV * T  # [J/mol]
    h_i /= entry['Mw'] / 1000.0  # Mw is in g/mol; convert to [J/kg]
    return h_i

def compute_cp_i(T, species):
    """Compute specific heat cp_i(T) [J/(kg·K)] for a species using NASA polynomials.

    Args:
        T (float): Temperature [K].
        species (str): Key into ``NASA_COEFFS``.

    Returns:
        float: Specific heat at constant pressure [J/(kg·K)].
    """
    entry = NASA_COEFFS[species]
    coeffs = entry['coeffs']

    # Temperature at which the low-range coefficients hand over to the
    # high-range ones. An explicit 'T_mid' wins; otherwise 'T_high' is used,
    # which the hand-written table defines as the transition temperature.
    # ('T_low' is the lower validity limit, not the switch point.)
    T_mid = entry.get('T_mid', entry['T_high'])
    a = coeffs[:7] if T <= T_mid else coeffs[7:14]

    # cp_i/R = a1 + a2*T + a3*T^2 + a4*T^3 + a5*T^4
    cp_R = (
        a[0] +
        a[1] * T +
        a[2] * T**2 +
        a[3] * T**3 +
        a[4] * T**4
    )
    cp_i = cp_R * R_UNIV  # [J/(mol·K)]
    cp_i /= entry['Mw'] / 1000.0  # Mw is in g/mol; convert to [J/(kg·K)]
    return cp_i

def update_temperature(Y, T, dYdt, species_list):
    """
    Evaluate dT/dt from the species rates and the mixture heat capacity.

    Computes ``-sum_i(dY_i/dt * h_i(T)) / sum_i(Y_i * cp_i(T))`` with both sums
    taken over ``species_list``. Despite its name, the function does not
    advance the temperature; it returns the time derivative.

    Args:
        Y (dict): Mass fractions {species: Y_i}.
        T (float): Current temperature [K].
        dYdt (dict): Rates of change of mass fractions {species: dY_i/dt}.
        species_list (list): Species names to include in both sums.

    Returns:
        float: dT/dt [K/s].
    """
    # Heat-release term: sum(dY_i/dt * h_i(T))
    heat_release = sum((dYdt[name] * compute_h_i(T, name) for name in species_list), 0.0)

    # Mixture heat capacity: sum(Y_i * cp_i(T))
    cp_mix = sum((Y[name] * compute_cp_i(T, name) for name in species_list), 0.0)

    return -heat_release / cp_mix


def validate_node_vs_cantera(node_models, T_init=1900, phi=1.2, t_end=5):
    """
    Compare NODE and Cantera simulations for a single initial condition.

    The NODE state vector holds the Box-Cox-transformed mass fractions of
    NH3, O2, H2, OH, H2O and NO followed by ``T / max_temp``. Below 1400 K the
    species rates come from Cantera (hybrid mode); above, from the trained
    models. N2 is reconstructed as the remainder ``1 - sum(Y)``.

    Reads the module-level names ``max_temp`` and ``device`` and switches every
    model in ``node_models`` to evaluation mode.

    Args:
        node_models: Dictionary of trained models (e.g., {'NH3_high': model, ...})
        T_init: Initial temperature (K)
        phi: Equivalence ratio
        t_end: Simulation end time (s)
    """
    mechanism = '/content/sample_data/okafor.yaml'
    species_order = ['NH3', 'O2', 'H2', 'OH', 'H2O', 'NO']
    plot_species = ['NH3', 'O2', 'H2O', 'N2']  # Key species to plot

    # --------------------------------------------------------------------------
    # 1. Set up Cantera simulation (Ground Truth)
    # --------------------------------------------------------------------------
    gas = ct.Solution(mechanism)
    gas.set_equivalence_ratio(phi, 'NH3', 'O2:1.0, N2:3.76')
    gas.TP = T_init, ct.one_atm

    # Initial mass fractions, read before anything can change the gas state
    Y0 = {sp: gas[sp].Y[0] for sp in species_order}

    # Same reactor model and tolerances as used to generate the training data
    # (constant pressure; rtol must stay above machine precision)
    reactor = ct.IdealGasConstPressureReactor(gas)
    sim = ct.ReactorNet([reactor])
    sim.rtol, sim.atol = 1e-12, 1e-16

    # Linearly spaced output times (the first point, t = 0, is not shown on
    # the logarithmic time axis of the plots)
    t_eval = np.linspace(0, t_end, 500)

    # Storage
    cantera_results = {
        'time': [],
        'T': [],
        'Y': {sp: [] for sp in plot_species}
    }

    # --------------------------------------------------------------------------
    # 2. Set up NODE simulation
    # --------------------------------------------------------------------------
    # Separate gas object for hybrid-mode rate evaluations, so the NODE
    # right-hand side never overwrites the state of the reactor's gas
    gas_node = ct.Solution(mechanism)

    # Dropout must be inactive while the models are used for prediction
    for trained_model in node_models.values():
        trained_model.eval()

    # Normalize initial state for NODE
    transformer = BoxCoxTransformer(lambda_val=0.1)
    lam = transformer.lambda_val
    is_consumed = np.array([sp in transformer.consumption_species for sp in species_order])
    Y0_transformed = transformer.transform(
        np.array([Y0[sp] for sp in species_order]).reshape(1, -1), species_order).flatten()
    T0_norm = T_init / max_temp  # Use max_temp from training

    # Initial state tensor
    x0 = torch.tensor(np.concatenate([Y0_transformed, [T0_norm]]), dtype=torch.float32).to(device)

    # Debug: Print shapes to verify
    print(f"Input shape to NODE: {x0.shape}")  # Should be (7,)

    # --------------------------------------------------------------------------
    # 3. Define NODE ODE system
    # --------------------------------------------------------------------------
    thermo_species = species_order + ['N2']

    def node_system(t, x):
        """Right-hand side d(state)/dt in the transformed/normalized variables."""
        # Ensure input is 2D [batch=1, features=7]
        x_in = x.unsqueeze(0) if x.dim() == 1 else x  # Shape: [1, 7]

        # Back to physical variables
        Z = x_in[0, :-1].detach().cpu().numpy().astype(np.float64)
        T = x_in[0, -1].item() * max_temp
        Y = transformer.inverse_transform(Z.reshape(1, -1), species_order).flatten()
        Y_N2 = max(1.0 - float(np.sum(Y)), 0.0)  # N2 closes the mass balance

        print(f"Called at t={float(t):.3e}s, T={T:.1f}K")

        # Hybrid mode: Use Cantera below threshold
        if T < 1400:
            # N2 has to be part of the composition; Cantera renormalizes the
            # mass fractions it is given, so leaving it out would inflate all others
            composition = dict(zip(species_order, np.clip(Y, 0.0, 1.0)))
            composition['N2'] = Y_N2
            gas_node.TPY = T, ct.one_atm, composition
            wdot = gas_node.net_production_rates * gas_node.molecular_weights / gas_node.density
            dYdt_phys = {sp: wdot[gas_node.species_index(sp)] for sp in species_order}
        else:
            # Select model group dynamically
            group = 'high' if T >= 1900 else 'medium' if T >= 1500 else 'low'

            # Predict dY/dt for each species and undo the label standardization
            dYdt_phys = {}
            with torch.no_grad():
                for sp in species_order:
                    model = node_models[f"{sp}_{group}"]
                    pred = model(None, x_in)
                    dYdt_phys[sp] = pred.item() * model.dYdt_std + model.dYdt_mean

        dYdt = np.array([dYdt_phys[sp] for sp in species_order], dtype=np.float64)

        # Temperature rate from thermodynamics. N2 is included so that the
        # mixture cp covers the whole mass; its rate follows from the closure
        # Y_N2 = 1 - sum(Y).
        Y_thermo = dict(zip(species_order, Y))
        Y_thermo['N2'] = Y_N2
        dYdt_thermo = dict(dYdt_phys)
        dYdt_thermo['N2'] = -float(np.sum(dYdt))
        dTdt = update_temperature(Y=Y_thermo, T=T, dYdt=dYdt_thermo, species_list=thermo_species)

        # Chain rule for the Box-Cox variables:
        #   consumption species  Z = (Y**lam - 1)/lam        ->  dZ/dt =  Y**(lam-1) * dY/dt
        #   production species   Z = ((1-Y)**lam - 1)/lam    ->  dZ/dt = -(1-Y)**(lam-1) * dY/dt
        base = np.clip(np.where(is_consumed, Y, 1.0 - Y), 1e-12, None)  # keep the power finite
        sign = np.where(is_consumed, 1.0, -1.0)
        dZdt = sign * base ** (lam - 1.0) * dYdt

        dTdt_norm = dTdt / max_temp  # Normalized derivative

        # Output must match input shape [7]
        return torch.tensor(np.concatenate([dZdt, [dTdt_norm]]), dtype=torch.float32).to(device)

    # --------------------------------------------------------------------------
    # 4. Run simulations
    # --------------------------------------------------------------------------
    # Integrate NODE; torchdiffeq's SciPy wrapper selects the integrator via
    # the "solver" option
    t_eval_tensor = torch.tensor(t_eval, dtype=torch.float32).to(device)
    node_sol = odeint(node_system, x0, t_eval_tensor, method="scipy_solver", rtol=1e-6, atol=1e-8, options={"solver": "BDF", "min_step": 1e-10})
    # print(f"Solver stats: {node_sol.stats}")
    print(node_sol)
    print(f"node_sol shape: {node_sol.shape}")

    # Run Cantera
    for t in t_eval:
        sim.advance(t)
        cantera_results['time'].append(t)
        cantera_results['T'].append(reactor.T)
        for sp in plot_species:
            cantera_results['Y'][sp].append(reactor.thermo[sp].Y[0])

    # Process NODE results with N2 reconstruction
    node_results = {'time': t_eval, 'T': [], 'Y': {sp: [] for sp in plot_species}}
    for i, t in enumerate(t_eval):
        Y_transformed = node_sol[i, :-1].cpu().numpy()
        Y = transformer.inverse_transform(Y_transformed.reshape(1, -1), species_order).flatten()
        node_results['T'].append(node_sol[i, -1].item() * max_temp)
        node_results['Y']['NH3'].append(Y[0])
        node_results['Y']['O2'].append(Y[1])
        node_results['Y']['H2O'].append(Y[4])       # H2O is 5th in species order
        node_results['Y']['N2'].append(1 - sum(Y))  # Mass conservation

    # --------------------------------------------------------------------------
    # 5. Plot comparison
    # --------------------------------------------------------------------------
    plt.figure(figsize=(10, 8))

    # Temperature plot
    plt.subplot(2, 1, 1)
    plt.plot(cantera_results['time'], cantera_results['T'], 'r--', label='Cantera')
    plt.plot(node_results['time'], node_results['T'], 'b-', label='NODE')
    plt.ylabel('Temperature (K)')
    plt.xscale('log')
    plt.legend()
    plt.title(f'NH3 Combustion (T={T_init}K, ϕ={phi})')

    # Mass fractions plot
    plt.subplot(2, 1, 2)
    for sp in plot_species:
        plt.plot(cantera_results['time'], cantera_results['Y'][sp], '--',
                label=f'Cantera {sp}')
        plt.plot(node_results['time'], node_results['Y'][sp], '-',
                label=f'NODE {sp}')
    plt.ylabel('Mass Fraction')
    plt.xlabel('Time (s)')
    plt.xscale('log')
    plt.legend()

    plt.tight_layout()
    plt.show()


# Example Usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location moves tensors that were saved on a GPU onto the device that is
# available now, so a file written on a GPU host also loads on a CPU-only one.
# NOTE: weights_only=False unpickles arbitrary objects — only load files you
# created yourself.
node_models = torch.load('grouped_ammonia_models.pth', map_location=device, weights_only=False)  # Loads the trained models
validate_node_vs_cantera(node_models, T_init=1300, phi=1.0)

# Show complete DataFrames when they are printed
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

In [None]:
def extract_nasa_coefficients(mechanism_path, species_list):
    """
    Extract NASA coefficients for specified species from a Cantera mechanism file.

    Species that cannot be looked up, or whose thermo model is not a two-range
    NASA-7 polynomial, are reported on stdout and left out of the result.

    Args:
        mechanism_path: Path to .yaml/.cti mechanism file
        species_list: List of species names (e.g., ['NH3','NO','OH','N2'])

    Returns:
        dict: ``{species: entry}`` where each entry has the keys
        ``'coeffs'`` (seven low-range then seven high-range coefficients),
        ``'T_low'`` (minimum valid temperature), ``'T_mid'`` (temperature at
        which the two ranges meet), ``'T_high'`` (maximum valid temperature)
        and ``'Mw'`` (molecular weight as reported by Cantera).
    """
    gas = ct.Solution(mechanism_path)
    nasa_data = {}

    for species in species_list:
        # Only the lookup is best-effort; a failure here means the species is
        # not available in this mechanism.
        try:
            thermo = gas.species(species).thermo
        except Exception as e:
            print(f"Error processing {species}: {str(e)}")
            continue

        if not isinstance(thermo, ct.NasaPoly2):
            print(f"Warning: {species} uses {type(thermo).__name__} thermo model (not NASA)")
            continue

        # NasaPoly2.coeffs has 15 entries: [0] is the midpoint temperature,
        # [1:8] the high-range and [8:15] the low-range coefficients.
        raw = np.asarray(thermo.coeffs, dtype=float)
        high_coeffs = raw[1:8]
        low_coeffs = raw[8:15]

        nasa_data[species] = {
            'coeffs': list(low_coeffs) + list(high_coeffs),
            'T_low': thermo.min_temp,
            'T_mid': float(raw[0]),
            'T_high': thermo.max_temp,
            'Mw': gas.molecular_weights[gas.species_index(species)]
        }

    return nasa_data

# Usage example:
mechanism_path = '/content/sample_data/okafor.yaml'  # Your mechanism file
species_of_interest = ['NH3', 'OH', 'NO', 'N2', 'H2', 'O2', 'H2O', 'N', 'NH', 'NH2', 'N2O', 'NO2']

# Rebinds the module-level NASA_COEFFS name: from here on, only the species
# that the extraction returned are available under that name.
NASA_COEFFS = extract_nasa_coefficients(mechanism_path, species_of_interest)    # Instead of manually writing the values