In [None]:
# ----------------------------------------------------------
## Import Libraries
# ----------------------------------------------------------

import os
import sys
from pathlib import Path
import configparser

import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.stats import gaussian_kde
from scipy.integrate import solve_ivp

import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------------
## File Paths
# -----------------------------

# Directory the notebook / script was launched from
current_dir = Path.cwd()

# Project root.
#   - launched from inside '6-ResultAnalysis' (.ipynb, or .py on Oracle Cloud): one level up
#   - otherwise (e.g. .py on Windows): the working directory already is the root
if current_dir.name == '6-ResultAnalysis':
    top_dir = current_dir.parent
else:
    top_dir = current_dir

# Custom libraries (stored as str; appended to sys.path further down)
libs_dir = str(top_dir.joinpath("0-Libs"))
# Configuration files
config_dir = top_dir.joinpath("0-Config")

# Data sets: one sub-directory per data set, each path stored as str
data_dir = top_dir.joinpath("0-Data")
highres_dir = str(data_dir.joinpath("0-HighRes"))            # HighRes (0)
routine_dir = str(data_dir.joinpath("1-Routine"))            # Routine (1)
active_dir = str(data_dir.joinpath("2-Active"))              # Active (2)
long_active_dir = str(data_dir.joinpath("3-LongActive"))     # LongActive (3)
long_routine_dir = str(data_dir.joinpath("4-LongRoutine"))   # LongRoutine (4)

# Sensitivity analysis results
sensitivity_results_dir = top_dir.joinpath("2-Sensitivity", "Results")

# Saved sampling trace (NetCDF), stored as str
trace_file_path = str(top_dir.joinpath("5-Sampling", "Results", "trace.nc"))
# -----------------------------
## Import Libraries - Custom
# -----------------------------
# Make the project-local modules in libs_dir importable; this must run
# before the imports below, which are resolved through sys.path.
sys.path.append(libs_dir)

# Project-local modules (expected to live in libs_dir)
from plant_config import get_reactor_initial_values
from asm3_model import ode_system_wrapper
from asm3_model_sunode import ode_system_sunode
from Trace_Plot_functions import plot_data
from Trace_Plot_functions import sim_all_states
from Trace_Plot_functions import plot_all_inference_lines

# Progress message: all imports succeeded
print("Libraries Imported")

# Config File
config = configparser.ConfigParser()
config_file = config_dir / "config.ini"
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. An empty result therefore means the config was
# not loaded; fail here instead of with an opaque KeyError on the first lookup.
if not config.read(config_file):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_file}")

# Seed for NumPy's global random number generator
seed = int(config['OVERALL']['seed'])
np.random.seed(seed)
# Name of the data set to analyse (used further down to pick the data frames)
data_to_use_for_run = str(config['OVERALL']['data_to_use_for_run'])
# Reactor volumes
r1_V = float(config['REACTOR']['r1_V'])        # Volume of reactor 1
    # Parameter Ranges
def _read_range(option):
    """Return the comma-separated PARAMRANGES entry *option* as a tuple of floats."""
    return tuple(float(bound) for bound in config['PARAMRANGES'][option].split(','))


# One range tuple per model parameter, read from the [PARAMRANGES] section.
# NOTE: k_STO and K_STO are stored under 'small_k_STO' / 'big_K_STO' --
# presumably because ConfigParser lower-cases option names by default, so the
# two names would otherwise collide.
range_param_k_H = _read_range('range_param_k_H')
range_param_K_X = _read_range('range_param_K_X')
range_param_k_STO = _read_range('range_param_small_k_STO')
range_param_eta_NOX = _read_range('range_param_eta_NOX')
range_param_K_O2 = _read_range('range_param_K_O2')
range_param_K_NOX = _read_range('range_param_K_NOX')
range_param_K_S = _read_range('range_param_K_S')
range_param_K_STO = _read_range('range_param_big_K_STO')
range_param_mu_H = _read_range('range_param_mu_H')
range_param_K_NH4 = _read_range('range_param_K_NH4')
range_param_K_ALK = _read_range('range_param_K_ALK')
range_param_b_H_O2 = _read_range('range_param_b_H_O2')
range_param_b_H_NOX = _read_range('range_param_b_H_NOX')
range_param_b_STO_O2 = _read_range('range_param_b_STO_O2')
range_param_b_STO_NOX = _read_range('range_param_b_STO_NOX')
range_param_mu_A = _read_range('range_param_mu_A')
range_param_K_A_NH4 = _read_range('range_param_K_A_NH4')
range_param_K_A_O2 = _read_range('range_param_K_A_O2')
range_param_K_A_ALK = _read_range('range_param_K_A_ALK')
range_param_b_A_O2 = _read_range('range_param_b_A_O2')
range_param_b_A_NOX = _read_range('range_param_b_A_NOX')
range_param_f_S_I = _read_range('range_param_f_S_I')
range_param_Y_STO_O2 = _read_range('range_param_Y_STO_O2')
range_param_Y_STO_NOX = _read_range('range_param_Y_STO_NOX')
range_param_Y_H_O2 = _read_range('range_param_Y_H_O2')
range_param_Y_H_NOX = _read_range('range_param_Y_H_NOX')
range_param_Y_A = _read_range('range_param_Y_A')
range_param_f_X_I = _read_range('range_param_f_X_I')
range_param_i_N_S_I = _read_range('range_param_i_N_S_I')
range_param_i_N_S_S = _read_range('range_param_i_N_S_S')
range_param_i_N_X_I = _read_range('range_param_i_N_X_I')
range_param_i_N_X_S = _read_range('range_param_i_N_X_S')
range_param_i_N_BM = _read_range('range_param_i_N_BM')
range_param_i_SS_X_I = _read_range('range_param_i_SS_X_I')
range_param_i_SS_X_S = _read_range('range_param_i_SS_X_S')
range_param_i_SS_BM = _read_range('range_param_i_SS_BM')

# Parameter name -> range tuple (minimum and maximum values), in model parameter order
theta_ranges = dict(
    k_H=range_param_k_H,
    K_X=range_param_K_X,
    k_STO=range_param_k_STO,
    eta_NOX=range_param_eta_NOX,
    K_O2=range_param_K_O2,
    K_NOX=range_param_K_NOX,
    K_S=range_param_K_S,
    K_STO=range_param_K_STO,
    mu_H=range_param_mu_H,
    K_NH4=range_param_K_NH4,
    K_ALK=range_param_K_ALK,
    b_H_O2=range_param_b_H_O2,
    b_H_NOX=range_param_b_H_NOX,
    b_STO_O2=range_param_b_STO_O2,
    b_STO_NOX=range_param_b_STO_NOX,
    mu_A=range_param_mu_A,
    K_A_NH4=range_param_K_A_NH4,
    K_A_O2=range_param_K_A_O2,
    K_A_ALK=range_param_K_A_ALK,
    b_A_O2=range_param_b_A_O2,
    b_A_NOX=range_param_b_A_NOX,
    f_S_I=range_param_f_S_I,
    Y_STO_O2=range_param_Y_STO_O2,
    Y_STO_NOX=range_param_Y_STO_NOX,
    Y_H_O2=range_param_Y_H_O2,
    Y_H_NOX=range_param_Y_H_NOX,
    Y_A=range_param_Y_A,
    f_X_I=range_param_f_X_I,
    i_N_S_I=range_param_i_N_S_I,
    i_N_S_S=range_param_i_N_S_S,
    i_N_X_I=range_param_i_N_X_I,
    i_N_X_S=range_param_i_N_X_S,
    i_N_BM=range_param_i_N_BM,
    i_SS_X_I=range_param_i_SS_X_I,
    i_SS_X_S=range_param_i_SS_X_S,
    i_SS_BM=range_param_i_SS_BM,
)

    # True theta params
def _read_true_param(option):
    """Return the TRUEPARAMS entry *option* as a float."""
    return float(config['TRUEPARAMS'][option])


# One "true" value per model parameter, read from the [TRUEPARAMS] section.
# NOTE: k_STO and K_STO are stored under 'small_k_STO' / 'big_K_STO' --
# presumably because ConfigParser lower-cases option names by default, so the
# two names would otherwise collide.
true_param_k_H = _read_true_param('true_param_k_H')
true_param_K_X = _read_true_param('true_param_K_X')
true_param_k_STO = _read_true_param('true_param_small_k_STO')
true_param_eta_NOX = _read_true_param('true_param_eta_NOX')
true_param_K_O2 = _read_true_param('true_param_K_O2')
true_param_K_NOX = _read_true_param('true_param_K_NOX')
true_param_K_S = _read_true_param('true_param_K_S')
true_param_K_STO = _read_true_param('true_param_big_K_STO')
true_param_mu_H = _read_true_param('true_param_mu_H')
true_param_K_NH4 = _read_true_param('true_param_K_NH4')
true_param_K_ALK = _read_true_param('true_param_K_ALK')
true_param_b_H_O2 = _read_true_param('true_param_b_H_O2')
true_param_b_H_NOX = _read_true_param('true_param_b_H_NOX')
true_param_b_STO_O2 = _read_true_param('true_param_b_STO_O2')
true_param_b_STO_NOX = _read_true_param('true_param_b_STO_NOX')
true_param_mu_A = _read_true_param('true_param_mu_A')
true_param_K_A_NH4 = _read_true_param('true_param_K_A_NH4')
true_param_K_A_O2 = _read_true_param('true_param_K_A_O2')
true_param_K_A_ALK = _read_true_param('true_param_K_A_ALK')
true_param_b_A_O2 = _read_true_param('true_param_b_A_O2')
true_param_b_A_NOX = _read_true_param('true_param_b_A_NOX')
true_param_f_S_I = _read_true_param('true_param_f_S_I')
true_param_Y_STO_O2 = _read_true_param('true_param_Y_STO_O2')
true_param_Y_STO_NOX = _read_true_param('true_param_Y_STO_NOX')
true_param_Y_H_O2 = _read_true_param('true_param_Y_H_O2')
true_param_Y_H_NOX = _read_true_param('true_param_Y_H_NOX')
true_param_Y_A = _read_true_param('true_param_Y_A')
true_param_f_X_I = _read_true_param('true_param_f_X_I')
true_param_i_N_S_I = _read_true_param('true_param_i_N_S_I')
true_param_i_N_S_S = _read_true_param('true_param_i_N_S_S')
true_param_i_N_X_I = _read_true_param('true_param_i_N_X_I')
true_param_i_N_X_S = _read_true_param('true_param_i_N_X_S')
true_param_i_N_BM = _read_true_param('true_param_i_N_BM')
true_param_i_SS_X_I = _read_true_param('true_param_i_SS_X_I')
true_param_i_SS_X_S = _read_true_param('true_param_i_SS_X_S')
true_param_i_SS_BM = _read_true_param('true_param_i_SS_BM')

# Parameter name -> true value, in model parameter order
true_theta = dict(
    k_H=true_param_k_H,
    K_X=true_param_K_X,
    k_STO=true_param_k_STO,
    eta_NOX=true_param_eta_NOX,
    K_O2=true_param_K_O2,
    K_NOX=true_param_K_NOX,
    K_S=true_param_K_S,
    K_STO=true_param_K_STO,
    mu_H=true_param_mu_H,
    K_NH4=true_param_K_NH4,
    K_ALK=true_param_K_ALK,
    b_H_O2=true_param_b_H_O2,
    b_H_NOX=true_param_b_H_NOX,
    b_STO_O2=true_param_b_STO_O2,
    b_STO_NOX=true_param_b_STO_NOX,
    mu_A=true_param_mu_A,
    K_A_NH4=true_param_K_A_NH4,
    K_A_O2=true_param_K_A_O2,
    K_A_ALK=true_param_K_A_ALK,
    b_A_O2=true_param_b_A_O2,
    b_A_NOX=true_param_b_A_NOX,
    f_S_I=true_param_f_S_I,
    Y_STO_O2=true_param_Y_STO_O2,
    Y_STO_NOX=true_param_Y_STO_NOX,
    Y_H_O2=true_param_Y_H_O2,
    Y_H_NOX=true_param_Y_H_NOX,
    Y_A=true_param_Y_A,
    f_X_I=true_param_f_X_I,
    i_N_S_I=true_param_i_N_S_I,
    i_N_S_S=true_param_i_N_S_S,
    i_N_X_I=true_param_i_N_X_I,
    i_N_X_S=true_param_i_N_X_S,
    i_N_BM=true_param_i_N_BM,
    i_SS_X_I=true_param_i_SS_X_I,
    i_SS_X_S=true_param_i_SS_X_S,
    i_SS_BM=true_param_i_SS_BM,
)
# Same values as a 1-D array, in the dictionary's insertion order
true_theta_array = np.array([*true_theta.values()])

    # From identifiability
# NAASI threshold for identifiability
NAASI_threshold = float(config['IDENTIFIABILITY']['NAASI_threshold'])

# Sampling settings, all read from the [SAMPLING] section
_sampling_cfg = config['SAMPLING']
solver_method = str(_sampling_cfg['solver_method'])
config_tuning_samples = int(_sampling_cfg['tuning_samples'])
config_draw_samples = int(_sampling_cfg['draw_samples'])
config_sample_chains = int(_sampling_cfg['run_chains'])
config_sample_cores = int(_sampling_cfg['run_cores'])

# ----------------------------------------------------------
## Load Data from csv
# ----------------------------------------------------------

def _load_csv(folder, file_name):
    """Read the CSV file *file_name* located in *folder* into a DataFrame."""
    return pd.read_csv(f"{folder}/{file_name}")


# Each data set provides four tables: influent/effluent x states/compounds.

# HighRes (0)
data_highres_influent_states = _load_csv(highres_dir, "HighRes_Influent_States.csv")
data_highres_effluent_states = _load_csv(highres_dir, "HighRes_Effluent_States.csv")
data_highres_influent_compounds = _load_csv(highres_dir, "HighRes_Influent_Compounds.csv")
data_highres_effluent_compounds = _load_csv(highres_dir, "HighRes_Effluent_Compounds.csv")
# Routine (1)
data_routine_influent_states = _load_csv(routine_dir, "Routine_Influent_States.csv")
data_routine_effluent_states = _load_csv(routine_dir, "Routine_Effluent_States.csv")
data_routine_influent_compounds = _load_csv(routine_dir, "Routine_Influent_Compounds.csv")
data_routine_effluent_compounds = _load_csv(routine_dir, "Routine_Effluent_Compounds.csv")
# Active (2)
data_active_influent_states = _load_csv(active_dir, "Active_Influent_States.csv")
data_active_effluent_states = _load_csv(active_dir, "Active_Effluent_States.csv")
data_active_influent_compounds = _load_csv(active_dir, "Active_Influent_Compounds.csv")
data_active_effluent_compounds = _load_csv(active_dir, "Active_Effluent_Compounds.csv")
# LongActive (3)
data_longactive_influent_states = _load_csv(long_active_dir, "LongActive_Influent_States.csv")
data_longactive_effluent_states = _load_csv(long_active_dir, "LongActive_Effluent_States.csv")
data_longactive_influent_compounds = _load_csv(long_active_dir, "LongActive_Influent_Compounds.csv")
data_longactive_effluent_compounds = _load_csv(long_active_dir, "LongActive_Effluent_Compounds.csv")
# LongRoutine (4)
data_longroutine_influent_states = _load_csv(long_routine_dir, "LongRoutine_Influent_States.csv")
data_longroutine_effluent_states = _load_csv(long_routine_dir, "LongRoutine_Effluent_States.csv")
data_longroutine_influent_compounds = _load_csv(long_routine_dir, "LongRoutine_Influent_Compounds.csv")
data_longroutine_effluent_compounds = _load_csv(long_routine_dir, "LongRoutine_Effluent_Compounds.csv")

# ----------------------------------------------------------
## Data to use for this run (selected by `data_to_use_for_run` from the config)
# ----------------------------------------------------------

# Data set name -> DataFrame, one mapping per table type
data_mapping_influent_states = {
    'HighRes': data_highres_influent_states,
    'Routine': data_routine_influent_states,
    'Active': data_active_influent_states,
    'LongActive': data_longactive_influent_states,
    'LongRoutine': data_longroutine_influent_states,
}
data_mapping_effluent_states = {
    'HighRes': data_highres_effluent_states,
    'Routine': data_routine_effluent_states,
    'Active': data_active_effluent_states,
    'LongActive': data_longactive_effluent_states,
    'LongRoutine': data_longroutine_effluent_states,
}
data_mapping_influent_compounds = {
    'HighRes': data_highres_influent_compounds,
    'Routine': data_routine_influent_compounds,
    'Active': data_active_influent_compounds,
    'LongActive': data_longactive_influent_compounds,
    'LongRoutine': data_longroutine_influent_compounds,
}
data_mapping_effluent_compounds = {
    'HighRes': data_highres_effluent_compounds,
    'Routine': data_routine_effluent_compounds,
    'Active': data_active_effluent_compounds,
    'LongActive': data_longactive_effluent_compounds,
    'LongRoutine': data_longroutine_effluent_compounds,
}

# Pick the four tables of the configured data set. An unknown name is reported
# as a ValueError that names the offending value and lists the valid choices;
# `from None` hides the internal KeyError, which adds nothing for the user.
try:
    Data_Influent_states = data_mapping_influent_states[data_to_use_for_run]
    Data_Effluent_states = data_mapping_effluent_states[data_to_use_for_run]
    Data_Influent_compounds = data_mapping_influent_compounds[data_to_use_for_run]
    Data_Effluent_compounds = data_mapping_effluent_compounds[data_to_use_for_run]
except KeyError:
    valid_choices = ", ".join(data_mapping_influent_states)
    raise ValueError(
        f"Invalid data for sampling: {data_to_use_for_run!r}. Choose from {valid_choices}."
    ) from None


# ----------------------------------------------------------
print("Config and Data loaded")

# Matplotlib rcParams: use constrained layout for every figure
plt.rcParams.update({'figure.constrained_layout.use': True})


In [None]:
# ----------------------------------------------------------
## Plot trace from Main_Runtime_Python_Sunode.py
# ----------------------------------------------------------

# Load the saved trace. Fail early with a specific, descriptive error when the
# trace file has not been produced yet (FileNotFoundError is still caught by
# any existing `except Exception` handler).
if not os.path.exists(trace_file_path):
    raise FileNotFoundError(f"File {trace_file_path} does not exist. Please run MCMC_Sampling_scipy or MCMC_Sampling_sunode first.")
trace = az.from_netcdf(trace_file_path)

# ----------------------------------------------------------
## Trace Summary
# ----------------------------------------------------------

if __name__ == '__main__':
    # Print ArviZ's summary table for the loaded trace, framed by rule lines
    _rule = "------------------------------------"
    print("Summary of the trace")
    print(_rule)
    print(az.summary(trace))
    print(_rule)


# All parameters
"""
    k_H,        # 0
    K_X,        # 1
    k_STO,      # 2
    eta_NOX,    # 3
    K_O2,       # 4
    K_NOX,      # 5
    K_S,        # 6
    K_STO,      # 7
    mu_H,       # 8
    K_NH4,      # 9
    K_ALK,      # 10
    b_H_O2,     # 11
    b_H_NOX,    # 12
    b_STO_O2,   # 13
    b_STO_NOX,  # 14
    mu_A,       # 15
    K_A_NH4,    # 16
    K_A_O2,     # 17
    K_A_ALK,    # 18
    b_A_O2,     # 19
    b_A_NOX,    # 20
    f_S_I,      # 21
    Y_STO_O2,   # 22
    Y_STO_NOX,  # 23
    Y_H_O2,     # 24
    Y_H_NOX,    # 25
    Y_A,        # 26
    f_X_I,      # 27
    i_N_S_I,    # 28
    i_N_S_S,    # 29
    i_N_X_I,    # 30
    i_N_X_S,    # 31
    i_N_BM,     # 32
    i_SS_X_I,   # 33
    i_SS_X_S,   # 34
    i_SS_BM     # 35
    sigma_COD,  # 36
    sigma_NH4,  # 37
    sigma_NOx,  # 38
    sigma_TKN,  # 39
    sigma_Alkalinity,  # 40
    sigma_TSS   # 41
"""


# Model parameters in index order (position in this tuple == parameter index)
_PARAM_ORDER = (
    'k_H', 'K_X', 'k_STO', 'eta_NOX', 'K_O2', 'K_NOX',
    'K_S', 'K_STO', 'mu_H', 'K_NH4', 'K_ALK', 'b_H_O2',
    'b_H_NOX', 'b_STO_O2', 'b_STO_NOX', 'mu_A', 'K_A_NH4', 'K_A_O2',
    'K_A_ALK', 'b_A_O2', 'b_A_NOX', 'f_S_I', 'Y_STO_O2', 'Y_STO_NOX',
    'Y_H_O2', 'Y_H_NOX', 'Y_A', 'f_X_I', 'i_N_S_I', 'i_N_S_S',
    'i_N_X_I', 'i_N_X_S', 'i_N_BM', 'i_SS_X_I', 'i_SS_X_S', 'i_SS_BM',
)

# Parameter name -> index (0..35)
all_param_idx = {name: position for position, name in enumerate(_PARAM_ORDER)}

# Parameter name -> LaTeX (mathtext) label
theta_format_names = dict(
    k_H=r'$k_H$',
    K_X=r'$K_X$',
    k_STO=r'$k_{STO}$',
    eta_NOX=r'$\eta_{NOX}$',
    K_O2=r'$K_{O2}$',
    K_NOX=r'$K_{NOX}$',
    K_S=r'$K_S$',
    K_STO=r'$K_{STO}$',
    mu_H=r'$\mu_H$',
    K_NH4=r'$K_{NH4}$',
    K_ALK=r'$K_{ALK}$',
    b_H_O2=r'$b_{H,O2}$',
    b_H_NOX=r'$b_{H,NOX}$',
    b_STO_O2=r'$b_{STO,O2}$',
    b_STO_NOX=r'$b_{STO,NOX}$',
    mu_A=r'$\mu_A$',
    K_A_NH4=r'$K_{A,NH4}$',
    K_A_O2=r'$K_{A,O2}$',
    K_A_ALK=r'$K_{A,ALK}$',
    b_A_O2=r'$b_{A,O2}$',
    b_A_NOX=r'$b_{A,NOX}$',
    f_S_I=r'$f_{S,I}$',
    Y_STO_O2=r'$Y_{STO,O2}$',
    Y_STO_NOX=r'$Y_{STO,NOX}$',
    Y_H_O2=r'$Y_{H,O2}$',
    Y_H_NOX=r'$Y_{H,NOX}$',
    Y_A=r'$Y_A$',
    f_X_I=r'$f_{X,I}$',
    i_N_S_I=r'$i_{N,S,I}$',
    i_N_S_S=r'$i_{N,S,S}$',
    i_N_X_I=r'$i_{N,X,I}$',
    i_N_X_S=r'$i_{N,X,S}$',
    i_N_BM=r'$i_{N,BM}$',
    i_SS_X_I=r'$i_{SS,X,I}$',
    i_SS_X_S=r'$i_{SS,X,S}$',
    i_SS_BM=r'$i_{SS,BM}$',
)

In [None]:
## Plot Trace

"""    
    Fixed Parameters,Identifiable Parameters,Non-Identifiable Parameters,Special Case Parameters
    i_SS_BM,    b_H_O2,     eta_NOX,    i_N_BM
    i_SS_X_I,   b_STO_O2,   K_NOX,      K_A_ALK
    i_SS_X_S,   k_H,        K_O2,       K_A_O2
    f_S_I,      K_S,        k_STO,
    i_N_S_I,    K_X,        mu_A,
    K_NH4,      Y_A,        Y_H_NOX,
    i_N_X_I,    Y_H_O2,     Y_STO_O2,
    i_N_S_S,    Y_STO_NOX,  K_A_NH4,
    K_ALK,      mu_H,,
    K_STO,,,
    b_STO_NOX,,,
    b_A_NOX,,,
    b_H_NOX,,,
    f_X_I,,,
    b_A_O2,,,
    i_N_X_S,,,

    # Specify removed as included priors
    k_STO = pymc.TruncatedNormal("k_STO", mu=priors['k_STO'][0], sigma=priors['k_STO'][1], initval=priors['k_STO'][0], lower=0)
    mu_A = pymc.TruncatedNormal("mu_A", mu=priors['mu_A'][0], sigma=priors['mu_A'][1], initval=priors['mu_A'][0], lower=0)
    Y_H_NOX = pymc.TruncatedNormal("Y_H_NOX", mu=priors['Y_H_NOX'][0], sigma=priors['Y_H_NOX'][1], initval=priors['Y_H_NOX'][0], lower=0)

    K_NOX = pymc.Deterministic("K_NOX", 2.40490 - 0.33060*Y_H_NOX)
    eta_NOX = pymc.Deterministic("eta_NOX", 2.50301 + 0.03954*K_NOX - 1.51210*mu_A)
    K_A_NH4 = pymc.Deterministic("K_A_NH4", 359.75339 - 30.21659*k_STO + 3.73381*mu_A + 5.63579*Y_H_NOX)
    Y_STO_O2 = pymc.Deterministic("Y_STO_O2", -12.43310 - 13.91973*eta_NOX + 5.03741*k_STO - 18.46277*mu_A - 23.01354*Y_H_NOX - 0.06531*K_A_NH4)
    K_O2 = pymc.Deterministic("K_O2", -2209036.39118 - 342511.10302*eta_NOX + 238097.67369*K_NOX + 84906.11946*k_STO + 904827.76299*mu_A - 507888.47528*Y_H_NOX + 87835.96620*Y_STO_O2 + 232823.30551*K_A_NH4) 
    
"""

# Variables pulled from the trace for plotting and simulation (order is preserved)
params_to_extract = [
    'b_H_O2', 'eta_NOX',
    'b_STO_O2', 'K_NOX',
    'k_H', 'K_O2',
    'K_S', 'k_STO',
    'K_X', 'mu_A',
    'Y_A', 'Y_H_NOX',
    'Y_H_O2', 'Y_STO_O2',
    'Y_STO_NOX', 'K_A_NH4',
    'mu_H',
]


# params_to_extract = [
#     'k_H',
#     'K_X',
#     'k_STO',
#     'eta_NOX',
#     'K_O2',
#     'K_NOX',
#     'K_S',
#     'K_STO',
#     'mu_H',
#     'K_NH4',
#     'K_ALK',
#     'b_H_O2',
#     'b_H_NOX',
#     'b_STO_O2',
#     'b_STO_NOX',
#     'mu_A',
#     'K_A_NH4',
#     'K_A_O2',
#     'K_A_ALK',
#     'b_A_O2',
#     'b_A_NOX',
#     'f_S_I',
#     'Y_STO_O2',
#     'Y_STO_NOX',
#     'Y_H_O2',
#     'Y_H_NOX',
#     'Y_A',
#     'f_X_I',
#     'i_N_S_I',
#     'i_N_S_S',
#     'i_N_X_I',
#     'i_N_X_S',
#     'i_N_BM',
# ]

# Trace Plots of Parameters

In [None]:
# ----------------------------------------------------------
## Plot Trace - Trace
# ----------------------------------------------------------

# One trace figure per extracted variable
for i, var_name in enumerate(params_to_extract):
    # Reference line at the true value, when one is known for this variable
    if var_name in true_theta:
        lines = [(var_name, {}, [true_theta[var_name]])]
    else:
        lines = None

    ax = az.plot_trace(
        data=trace,
        var_names=var_name,
        lines=lines,
        kind='trace',
        compact=True,
        combined=True,
    )
    plt.suptitle(f"Trace Plot - {var_name}")
    plt.show()

In [None]:
# # ----------------------------------------------------------
# ## Plot Trace - Rank Bars
# # ----------------------------------------------------------

# for i, var_name in enumerate(params_to_extract):
#     # Create plot
#     lines = [(var_name, {}, [true_theta[var_name]])] if var_name in true_theta else None    # TODO: True value line
#     ax = az.plot_trace(data=trace, var_names=var_name, lines=lines, kind='rank_bars', compact=True, combined=True)   # TODO: Add lines to lines= for true value line 
#     plt.suptitle(f"Trace Plot - {var_name}")
#     plt.show()


In [None]:
# # ----------------------------------------------------------
# ## Plot Trace - Rank Vlines
# # ----------------------------------------------------------

# for i, var_name in enumerate(params_to_extract):
#     # Create plot
#     lines = [(var_name, {}, [true_theta[var_name]])] if var_name in true_theta else None    # TODO: True value line
#     ax = az.plot_trace(data=trace, var_names=var_name, lines=lines, kind='rank_vlines', compact=True, combined=True)   # TODO: Add lines to lines= for true value line 
#     plt.suptitle(f"Trace Plot - {var_name}")
#     plt.show()

In [None]:
# # ----------------------------------------------------------
# ## Plot Posterior - KDE
# # ----------------------------------------------------------

# for i, var_name in enumerate(params_to_extract):
#     # Create plot
#         # True value
#     # ref_val = theta_true[var_name] if var_name in theta_true else None
#     ax = az.plot_posterior(data=trace, var_names=var_name, kind='kde', hdi_prob='hide', point_estimate=None)
#     if var_name in true_theta:
#         ax.vlines(true_theta[var_name], *ax.get_ylim(), color='red', linestyle='--')
#     plt.legend(['Posterior', 'HDI', 'True',])
#     plt.suptitle(f"Trace Plot - {var_name}")
#     plt.show()

In [None]:
# # ----------------------------------------------------------
# ## Plot Posterior - Histogram
# # ----------------------------------------------------------

# for i, var_name in enumerate(params_to_extract):
#     # Create plot
#         # True value
#     # ref_val = theta_true[var_name] if var_name in theta_true else None
#     # Edit above line to exclude if the key does not exists
#     if var_name in true_theta:
#         # ax = az.plot_posterior(data=trace, var_names=var_name, kind='hist', hdi_prob=0.95, point_estimate=None, ref_val=true_theta[var_name], ref_val_color='red')
#         ax = az.plot_posterior(data=trace, var_names=var_name, kind='hist', hdi_prob=0.95, point_estimate=None)
#         ax.vlines(true_theta[var_name], *ax.get_ylim(), color='red', linestyle='--')
#         plt.legend(['95% HDI', 'True', 'Posterior'])
#     else:
#         ax = az.plot_posterior(data=trace, var_names=var_name, kind='hist', hdi_prob=0.95, point_estimate=None)
#         plt.legend(['95% HDI', 'Posterior'])
#         # ax.vlines(true_theta[var_name], *ax.get_ylim(), color='red', linestyle='--', label='True')
#     plt.suptitle(f"Trace Plot - {var_name}")
#     plt.show()


In [None]:
## Get parameter samples from trace (and any other parameters that might be fixed)

## Trace samples
# Number of inference lines to plot
num_inference_lines = 500
# Draw (num_inference_lines) random samples from the trace, restricted to the
# variables in params_to_extract. The configured seed is passed as `rng` so the
# random subset is reproducible between runs; az.extract draws the subset with
# its own generator, which np.random.seed() does not control.
trace_samples_df = az.extract(
    trace,
    var_names=params_to_extract,
    num_samples=num_inference_lines,
    rng=seed,
).to_dataframe()

# Remove the 'chain' and 'draw' columns if present (errors='ignore' tolerates
# their absence). The sigma_* variables need no dropping here because they are
# not part of params_to_extract.
trace_samples_df = trace_samples_df.drop(columns=['chain', 'draw'], errors='ignore')

# -----------------------------------------------------------
# Direct from '5-Sampling/MCMC_Sampling.py'
# -----------------------------------------------------------
alpha = 0.05
crit_value = norm.ppf(1 - (alpha / 2)) # For 95% CI


def _prior_from_range(bounds):
    """Return [mean, std] of the prior derived from a parameter range.

    The mean is the mean of *bounds*; the std is chosen so that the width of
    the range (bounds[1] - bounds[0]) equals 2 * crit_value * std.
    """
    lower, upper = bounds[0], bounds[1]
    return [np.mean(bounds), (upper - lower) / (2 * crit_value)]


# Priors for parameters -- [mean, std] per parameter
priors = {
    'k_H':       _prior_from_range(range_param_k_H),        # (g COD_X_S) / (g COD_X_H * d), Hydrolysis rate constant
    'K_X':       _prior_from_range(range_param_K_X),        # (g COD_X_S) / (g COD_X_H), Hydrolysis saturation constant
    'k_STO':     _prior_from_range(range_param_k_STO),      # (g COD_S_S) / (g COD_X_H * d), Storage rate constant
    'eta_NOX':   _prior_from_range(range_param_eta_NOX),    # - , Anoxic reduction factor
    'K_O2':      _prior_from_range(range_param_K_O2),       # (g O2 / m3), Saturation constant for S_O2
    'K_NOX':     _prior_from_range(range_param_K_NOX),      # (g (NO3-)-N / m3), Saturation constant for S_NOX
    'K_S':       _prior_from_range(range_param_K_S),        # (g COD_S_S / m3), Saturation constant for Substrate S_S
    'K_STO':     _prior_from_range(range_param_K_STO),      # (g COD_X_STO / g COD_X_H), Saturation constant for X_STO
    'mu_H':      _prior_from_range(range_param_mu_H),       # (d^-1), Heterotrophic maximum specific growth rate of X_H
    'K_NH4':     _prior_from_range(range_param_K_NH4),      # (g N / m3), Saturation constant for ammonium, S_NH4
    'K_ALK':     _prior_from_range(range_param_K_ALK),      # (mole HCO3- / m3), Saturation constant for alkalinity for X_H
    'b_H_O2':    _prior_from_range(range_param_b_H_O2),     # (d^-1), Aerobic endogenous respiration rate of X_H
    'b_H_NOX':   _prior_from_range(range_param_b_H_NOX),    # (d^-1), Anoxic endogenous respiration rate of X_H
    'b_STO_O2':  _prior_from_range(range_param_b_STO_O2),   # (d^-1), Aerobic endogenous respiration rate for X_STO
    'b_STO_NOX': _prior_from_range(range_param_b_STO_NOX),  # (d^-1), Anoxic endogenous respiration rate for X_STO
    'mu_A':      _prior_from_range(range_param_mu_A),       # (d^-1), Autotrophic maximum specific growth rate of X_A
    'K_A_NH4':   _prior_from_range(range_param_K_A_NH4),    # (g N / m3), Ammonium substrate saturation constant for X_A
    'K_A_O2':    _prior_from_range(range_param_K_A_O2),     # (g O2 / m3), Oxygen saturation for nitrifiers
    'K_A_ALK':   _prior_from_range(range_param_K_A_ALK),    # (mole HCO3- / m3), Bicarbonate saturation for nitrifiers
    'b_A_O2':    _prior_from_range(range_param_b_A_O2),     # (d^-1), Aerobic endogenous respiration rate of X_A
    'b_A_NOX':   _prior_from_range(range_param_b_A_NOX),    # (d^-1), Anoxic endogenous respiration rate of X_A
    'f_S_I':     _prior_from_range(range_param_f_S_I),      # (g COD_S_I) / (g COD_X_S), Production of S_I in hydrolysis
    'Y_STO_O2':  _prior_from_range(range_param_Y_STO_O2),   # (g COD_X_STO) / (g COD_S_S), Aerobic yield of stored product per S_S
    'Y_STO_NOX': _prior_from_range(range_param_Y_STO_NOX),  # (g COD_X_STO) / (g COD_S_S), Anoxic yield of stored product per S_S
    'Y_H_O2':    _prior_from_range(range_param_Y_H_O2),     # (g COD_X_H) / (g COD_S_STO), Aerobic yield of heterotrophic biomass
    'Y_H_NOX':   _prior_from_range(range_param_Y_H_NOX),    # (g COD_X_H) / (g COD_S_STO), Anoxic yield of heterotrophic biomass
    'Y_A':       _prior_from_range(range_param_Y_A),        # (g COD_X_A) / (g N_S_NOX), Yield of autotrophic biomass per NO3-N
    'f_X_I':     _prior_from_range(range_param_f_X_I),      # (g COD_X_I) / (g COD_X_BM), Production of X_I in endogenous respiration
    'i_N_S_I':   _prior_from_range(range_param_i_N_S_I),    # (g N) / (g COD_S_I), N content of S_I
    'i_N_S_S':   _prior_from_range(range_param_i_N_S_S),    # (g N) / (g COD_S_S), N content of S_S
    'i_N_X_I':   _prior_from_range(range_param_i_N_X_I),    # (g N) / (g COD_X_I), N content of X_I
    'i_N_X_S':   _prior_from_range(range_param_i_N_X_S),    # (g N) / (g COD_X_S), N content of X_S
    'i_N_BM':    _prior_from_range(range_param_i_N_BM),     # (g N) / (g COD_X_BM), N content of biomass, X_H, X_A
    'i_SS_X_I':  _prior_from_range(range_param_i_SS_X_I),   # (g SS) / (g COD_X_I), SS to COD ratio for X_I
    'i_SS_X_S':  _prior_from_range(range_param_i_SS_X_S),   # (g SS) / (g COD_X_S), SS to COD ratio for X_S
    'i_SS_BM':   _prior_from_range(range_param_i_SS_BM),    # (g SS) / (g COD_X_BM), SS to COD ratio for biomass, X_H, X_A
}


    # The following parameters are fixed to prior means:
        # i_SS_BM, i_SS_X_I, i_SS_X_S, f_S_I, i_N_S_I, K_NH4, 
        # i_N_X_I, i_N_S_S, K_ALK, K_STO, b_STO_NOX, b_A_NOX, 
        # b_H_NOX, f_X_I, b_A_O2, i_N_X_S
    # The following parameters are identifiable from the PLA graphs:
        # k_H, K_X, K_S, b_H_O2, b_STO_O2, Y_STO_NOX, Y_H_O2, Y_A
    # The others are non-identifiable: 
        # k_STO, eta_NOX, K_O2, K_NOX, mu_H, mu_A, K_A_NH4, Y_H_NOX,
    # MAYBE Special cases (for those going to 0):
        # K_A_O2, K_A_ALK, Y_STO_O2, i_N_BM


# -----------------------------------------------------------
## Sensitivity Analysis Results
# -----------------------------------------------------------

# NAASI sensitivity table, indexed by parameter name
NAASI_results = pd.read_csv(sensitivity_results_dir / 'NAASI_combined.csv').set_index('Parameter')

# Every parameter whose NAASI score is below the threshold is pinned to the
# first entry of its priors pair
params_to_always_fix = {}
for param in NAASI_results.index:
    if NAASI_results.loc[param, 'NAASI'] < NAASI_threshold:
        params_to_always_fix[param] = priors[param][0]

# Report which parameters are pinned and to what value
print('Parameters to always fix based on NAASI results (Param, value):')
for param, fixed_val in params_to_always_fix.items():
    print(f'    {param}:    {fixed_val}')

## Fixed parameters  - overwrite the normal distribution with fixed values
# Positions of the pinned parameters in the theta vector
always_fixed_param_idxs = [all_param_idx[name] for name in params_to_always_fix]
# Values used for the pinned parameters
    # TODO: MANUAL -- switch between prior-based and true values
# always_fixed_param_vals = [priors[name][0] for name in params_to_always_fix.keys()]
always_fixed_param_vals = [true_theta[name] for name in params_to_always_fix]

    # TODO: MANUAL
# Set always fixed params to empty
# always_fixed_param_idxs = []
# always_fixed_param_vals = []

# -------------------------------------------------------------
## Special Cases
# -------------------------------------------------------------
    # TODO: MANUAL
# Parameters that are neither sampled nor NAASI-fixed get these constant values
special_case_params = dict(
    K_A_O2=1e-6,
    K_A_ALK=1e-6,
    i_N_BM=1e-6,
)

In [None]:
# -------------------------------------------------------------
## Create theta_samples_df
# -------------------------------------------------------------

# Full theta table: one row per posterior sample, one column per parameter
#     Sampled parameters are copied from trace_samples_df
#     Non-sampled parameters get one constant value on every row
#         (params_to_always_fix is checked first, then special_case_params)
#     A parameter with no source at all is left as NaN
# The table is assembled column-wise in a single pass. Growing it with one
# pd.concat per sample copies the whole frame on every iteration and relies on
# the deprecated concat handling of empty frames to keep numeric dtypes.
n_theta_samples = len(trace_samples_df)
theta_columns = {}
for param in all_param_idx:
    if param in trace_samples_df.columns:
        theta_columns[param] = trace_samples_df[param].to_numpy()
    elif param in params_to_always_fix:
        theta_columns[param] = params_to_always_fix[param]  # scalar, broadcast to all rows
    elif param in special_case_params:
        theta_columns[param] = special_case_params[param]   # scalar, broadcast to all rows

# Column order follows priors; any filled parameter not in priors is appended after them
theta_column_order = list(priors) + [param for param in theta_columns if param not in priors]
theta_samples_df = pd.DataFrame(
    theta_columns,
    index=pd.RangeIndex(n_theta_samples),
).reindex(columns=theta_column_order)

In [None]:
# Create trace_samples_df for other datasets

# Get data for highres dataset
Data_Influent_states_HighRes = data_mapping_influent_states['HighRes']
Data_Influent_compounds_HighRes = data_mapping_influent_compounds['HighRes']
Data_Effluent_states_HighRes = data_mapping_effluent_states['HighRes']
Data_Effluent_compounds_HighRes = data_mapping_effluent_compounds['HighRes']

# Get data for Active dataset
Data_Influent_states_Active = data_mapping_influent_states['Active']
Data_Influent_compounds_Active = data_mapping_influent_compounds['Active']
Data_Effluent_states_Active = data_mapping_effluent_states['Active']
Data_Effluent_compounds_Active = data_mapping_effluent_compounds['Active']


# Trace files for the HighRes and Active sampling runs
trace_file_path_highres = str(top_dir / "5-Sampling" / "Results" / "HighRes" / "trace.nc")
trace_file_path_active = str(top_dir / "5-Sampling" / "Results" / "Active" / "trace.nc")

# Fail fast with the offending path: silently skipping a missing file would
# leave trace_highres / trace_active undefined and only fail on first use.
for _trace_path in (trace_file_path_highres, trace_file_path_active):
    if not os.path.exists(_trace_path):
        raise FileNotFoundError(f"Trace file not found: {_trace_path}")

trace_highres = az.from_netcdf(trace_file_path_highres)
trace_active = az.from_netcdf(trace_file_path_active)

"""
HighRes:
Fixed Parameters,Identifiable Parameters,Non-Identifiable Parameters,Special Case Parameters
i_SS_BM,    k_H,K_X,
i_SS_X_I,   k_STO,K_NOX,
i_SS_X_S,   eta_NOX,Y_H_NOX,
f_S_I,     K_S,,
i_N_S_I,   mu_H,,
K_NH4,     b_H_O2,,
i_N_X_I,   b_STO_O2,,
i_N_S_S,   mu_A,,
K_ALK,      K_A_NH4,,
K_STO,      K_A_O2,,
b_A_NOX,    K_A_ALK,,
i_N_X_S,    Y_STO_O2,,
b_H_NOX,    Y_STO_NOX,,
b_STO_NOX,  Y_H_O2,,
f_X_I,     Y_A,,
b_A_O2,   ,,
i_N_BM,   ,,
K_O2,    ,,

"""
# Alternative (explicit list of sampled HighRes parameters), kept for manual switching:
# params_to_extract_highres = [
#     'k_H', 'K_X', 'k_STO', 'K_NOX', 'eta_NOX', 'Y_H_NOX', 'K_S', 'mu_H', 'b_H_O2',
#     'b_STO_O2', 'mu_A', 'K_A_NH4', 'K_A_O2', 'K_A_ALK', 'Y_STO_O2', 'Y_STO_NOX',
#     'Y_H_O2', 'Y_A',
# ]

# Extract every parameter in priors except the three SS/COD ratio parameters
_highres_not_extracted = ['i_SS_BM', 'i_SS_X_I', 'i_SS_X_S']
params_to_extract_highres = [param for param in priors if param not in _highres_not_extracted]

# Draw num_inference_lines posterior samples and keep only the parameter columns
_highres_extract = az.extract(
    trace_highres,
    num_samples=num_inference_lines,
    var_names=params_to_extract_highres,
)
trace_samples_highres = _highres_extract.to_dataframe().drop(columns=['chain', 'draw'], errors='ignore')

"""
Active:
Fixed Parameters,Identifiable Parameters,Non-Identifiable Parameters,Special Case Parameters
i_SS_BM,    b_H_O2,eta_NOX,     i_N_BM
i_SS_X_I,   b_STO_O2,K_NOX,     K_A_ALK
i_SS_X_S,   k_H,K_O2,           K_A_O2
f_S_I,     K_S,k_STO,
i_N_S_I,   K_X,mu_A,
K_NH4,     Y_A,Y_H_NOX,
i_N_X_I,   Y_H_O2,Y_STO_O2,
i_N_S_S,   Y_STO_NOX,K_A_NH4,
K_ALK,    mu_H,,
K_STO,   ,,
b_STO_NOX,  ,,
b_A_NOX,,,
b_H_NOX,,,
f_X_I,,,
b_A_O2,,,
i_N_X_S,,,

"""

# Parameters read from the Active trace (same order as the table above:
# identifiable / non-identifiable pairs, row by row)
params_to_extract_active = [
    'b_H_O2', 'eta_NOX',
    'b_STO_O2', 'K_NOX',
    'k_H', 'K_O2',
    'K_S', 'k_STO',
    'K_X', 'mu_A',
    'Y_A', 'Y_H_NOX',
    'Y_H_O2', 'Y_STO_O2',
    'Y_STO_NOX', 'K_A_NH4',
    'mu_H',
]

# Draw num_inference_lines posterior samples and keep only the parameter columns
_active_extract = az.extract(
    trace_active,
    num_samples=num_inference_lines,
    var_names=params_to_extract_active,
)
trace_samples_active = _active_extract.to_dataframe().drop(columns=['chain', 'draw'], errors='ignore')


def _assemble_theta_samples(trace_samples):
    """Expand a table of sampled parameters into a full theta table.

    Each parameter in ``all_param_idx`` is taken from ``trace_samples`` when it
    has a column there; otherwise it is filled with its constant from
    ``params_to_always_fix`` or, failing that, ``special_case_params``. A
    parameter with no source at all is left as NaN.

    The frame is built column-wise in one pass. Appending one row at a time
    with ``pd.concat`` copies the whole frame on every iteration and relies on
    the deprecated concat handling of empty frames to keep numeric dtypes.

    Returns a DataFrame with one row per sample; columns follow ``priors``,
    and any filled parameter not in ``priors`` is appended after them.
    """
    columns = {}
    for param in all_param_idx:
        if param in trace_samples.columns:
            columns[param] = trace_samples[param].to_numpy()
        elif param in params_to_always_fix:
            columns[param] = params_to_always_fix[param]  # scalar, broadcast to all rows
        elif param in special_case_params:
            columns[param] = special_case_params[param]   # scalar, broadcast to all rows
    column_order = list(priors) + [param for param in columns if param not in priors]
    frame = pd.DataFrame(columns, index=pd.RangeIndex(len(trace_samples)))
    return frame.reindex(columns=column_order)


theta_samples_highres_df = _assemble_theta_samples(trace_samples_highres)
theta_samples_active_df = _assemble_theta_samples(trace_samples_active)

In [None]:
# # --------------------------------------------------------------
# ## Custom Parameter Plotting
# # --------------------------------------------------------------

# # Plot each parameter in theta_samples_df
#     # For parameters that have a single value, plot a vertical line at that value
#     # For parameters that have multiple values, plot a histogram of the values
# fontsize=24
# for param in theta_samples_df.columns:
#     # Create plot
#     plt.figure(figsize=(24, 10))
#     true_value = true_theta.get(param)
#     param_formatted_name = theta_format_names[param] if param in theta_format_names else param
#     if param in always_fixed_param_vals:
#         # If the parameter is fixed, plot a vertical line at the fixed value
#         plt.axvline(x=always_fixed_param_vals[param], color='black', linestyle='-', label='Fixed Value')
#     else:
#         samples = theta_samples_df[param]
#         # If the parameter is not fixed, plot histogram of samples, along with mean, mode, and median lines
#         plt.hist(samples, bins=30, density=True, alpha=0.5, color='black', label='Sampled Values') # samples
#             # Descriptive statistics
#         mean_value = samples.mean()
#         median_value = samples.median()
#             # 95% confidence interval
#         ci_lower = samples.quantile(0.025)
#         ci_upper = samples.quantile(0.975)
#             # Plot CI and descriptive statistics
#         plt.axvline(x=mean_value, color='blue', linestyle='--', label='Mean Sampled Value')
#         plt.axvline(x=median_value, color='orange', linestyle='--', label='Median Sampled Value')
#         plt.axvline(x=ci_lower, color='purple', linestyle='-.', label='95% CI Lower Bound')
#         plt.axvline(x=ci_upper, color='purple', linestyle='-.', label='95% CI Upper Bound')
#     plt.axvline(x=true_value, color='red', linestyle='-', label='True Value') if true_value is not None else None
#     plt.title(f"Parameter: {param_formatted_name}", fontsize=fontsize)
#     plt.xlabel(f'{param_formatted_name} distribution', fontsize=fontsize)
#     plt.ylabel('Density', fontsize=fontsize)
#     plt.title(f"Parameter: {param_formatted_name}", fontsize=fontsize)
#     plt.xticks(fontsize=fontsize)
#     plt.yticks(fontsize=fontsize)
#     plt.grid(False)
#     plt.legend(fontsize=fontsize)
#     plt.show()
    

In [None]:
# --------------------------------------------------------------
## Custom Parameter Plotting with KDE
# --------------------------------------------------------------

# One figure per parameter in theta_samples_df:
#     a single distinct value (fixed parameter)  -> vertical line at that value
#     several distinct values (sampled parameter) -> KDE with mean and interval bounds
# Fixed parameters are detected from the column itself. always_fixed_param_vals is a
# list of values, so testing a parameter name against it never matches.
fontsize = 24
for param in theta_samples_df.columns:
    samples = theta_samples_df[param].dropna()
    distinct_values = samples.unique()
    if len(distinct_values) == 0:
        # No value at all for this parameter: nothing to draw
        continue

    plt.figure(figsize=(12, 6))
    true_value = true_theta.get(param)
    param_formatted_name = theta_format_names.get(param, param)
    if len(distinct_values) == 1:
        plt.axvline(x=distinct_values[0], color='black', linestyle='-', label='Fixed Value', linewidth=3)
    else:
        samples = samples.astype(float)  # KDE needs a numeric dtype
        sns.kdeplot(samples, fill=True, color='blue', label='Sampled Density', linewidth=2)
        # Descriptive statistics
        mean_value = samples.mean()
        median_value = samples.median()
        # NOTE: these bounds are the 2.5% / 97.5% quantiles (equal-tailed interval),
        # reported under the "HDI" label
        hdi_lower = samples.quantile(0.025)
        hdi_upper = samples.quantile(0.975)
        # print hdi values and mean
        print(f"{param_formatted_name} - Mean: {mean_value}, Median: {median_value}, 95% HDI: [{hdi_lower}, {hdi_upper}], True Value: {true_value if true_value is not None else 'N/A'}")
        # Plot lines for mean and interval bounds (only the lower bound carries the legend entry)
        plt.axvline(x=mean_value, color='darkblue', linestyle='--', label='Mean Sampled Value', linewidth=3, dashes=(5, 2))
        # plt.axvline(x=median_value, color='red', linestyle='--', label='Median Sampled Value', linewidth=3, dashes=(5, 2))
        plt.axvline(x=hdi_lower, color='purple', linestyle='-.', label='95% HDI Bound', linewidth=3)
        plt.axvline(x=hdi_upper, color='purple', linestyle='-.', linewidth=3)
    if true_value is not None:
        plt.axvline(x=true_value, color='black', linestyle='-', label='True Value', linewidth=3)
    # plt.title(f"Parameter: {param_formatted_name}", fontsize=fontsize)
    plt.title("(a)", fontsize=fontsize)
    plt.xlabel(f'{param_formatted_name} Distribution', fontsize=fontsize)
    # plt.ylabel('Density', fontsize=fontsize)
    plt.ylabel('', fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    # plt.yticks(fontsize=fontsize)
    plt.yticks([])
    plt.grid(False)
    plt.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()


In [None]:
# Side-by-side posterior comparison, one figure per parameter:
#     left  (a): HighRes density only
#     right (b): Active density with the HighRes density overlaid
fontsize = 36
for param in theta_samples_active_df:
    # Drop NaNs to avoid KDE failures
    active_samples = theta_samples_active_df[param].dropna()
    highres_samples = theta_samples_highres_df[param].dropna()
    if active_samples.empty or highres_samples.empty:
        continue
    # A parameter held constant in the Active table has no density to compare, and
    # gaussian_kde cannot be fitted to it (singular data covariance), so skip it.
    if active_samples.nunique() < 2:
        continue
    active_samples = active_samples.astype(float)
    highres_samples = highres_samples.astype(float)

    true_value = true_theta.get(param)
    param_formatted_name = theta_format_names.get(param, param)

    # Compute KDE for Active to fix y-limits of the right-hand panel
    kde_active = gaussian_kde(active_samples)
    x_vals = np.linspace(
        min(active_samples.min(), highres_samples.min()),
        max(active_samples.max(), highres_samples.max()),
        1000
    )
    y_vals_active = kde_active(x_vals)
    y_max = y_vals_active.max()

    # Create 1-row, 2-column subplot (right panel slightly wider)
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(24, 10), gridspec_kw={'width_ratios': [1, 1.2]})

    # ----------------------------------------
    # LEFT PLOT: Highres KDE only
    # ----------------------------------------
    sns.kdeplot(
        highres_samples, fill=True, color='red',
        label='HF Sampled Density', linewidth=2, ax=ax1, clip=(0, None)
    )

    # Highres stats (bounds are the 2.5% / 97.5% quantiles, reported under the "HDI" label)
    highres_mean = highres_samples.mean()
    highres_hdi_lower = highres_samples.quantile(0.025)
    highres_hdi_upper = highres_samples.quantile(0.975)

    print(f"{param_formatted_name} - Highres Mean: {highres_mean}, 95% HDI: [{highres_hdi_lower}, {highres_hdi_upper}], True Value: {true_value if true_value is not None else 'N/A'}")
    if true_value is not None:
        ax1.axvline(true_value, color='green', linestyle='-', label='True Value', linewidth=4)
    ax1.axvline(highres_mean, color='darkred', linestyle='--', label='HF Mean Sampled Value', linewidth=3, dashes=(5, 2))
    ax1.axvline(highres_hdi_lower, color='darkorange', linestyle='-.', label='HF 95% HDI Bound', linewidth=3)
    ax1.axvline(highres_hdi_upper, color='darkorange', linestyle='-.', linewidth=3)
    ax1.set_title('(a)', fontsize=fontsize)
    ax1.set_xlabel('', fontsize=fontsize)
    ax1.set_ylabel('Density', fontsize=fontsize)
    ax1.tick_params(axis='both', labelsize=fontsize)
    ax1.grid(False)

    # ----------------------------------------
    # RIGHT PLOT: Active + Highres overlay
    # ----------------------------------------
    sns.kdeplot(
        active_samples, fill=True, color='blue',
        label='AC Sampled Density', linewidth=2, ax=ax2, clip=(0, None)
    )
    sns.kdeplot(
        highres_samples, fill=True, color='red',
        label='HF Sampled Density', linewidth=2, ax=ax2, clip=(0, None), alpha=0.4
    )

    # Active stats (bounds are the 2.5% / 97.5% quantiles, reported under the "HDI" label)
    active_mean = active_samples.mean()
    active_hdi_lower = active_samples.quantile(0.025)
    active_hdi_upper = active_samples.quantile(0.975)

    print(f"{param_formatted_name} - Active Mean: {active_mean}, 95% HDI: [{active_hdi_lower}, {active_hdi_upper}], True Value: {true_value if true_value is not None else 'N/A'}")

    ax2.axvline(active_mean, color='darkblue', linestyle='--', label='AC Mean Sampled Value', linewidth=3, dashes=(5, 2))
    ax2.axvline(active_hdi_lower, color='purple', linestyle='-.', label='AC 95% HDI Bound', linewidth=3)
    ax2.axvline(active_hdi_upper, color='purple', linestyle='-.', linewidth=3)
    if true_value is not None:
        ax2.axvline(true_value, color='green', linestyle='-', label='True Value', linewidth=3)
    ax2.set_xlabel('', fontsize=fontsize)
    # ax2.set_xlabel(f'{param_formatted_name} Distribution', fontsize=fontsize)
    ax2.set_ylabel('Density', fontsize=fontsize)
    ax2.set_title('(b)', fontsize=fontsize)
    # y-range is set from the Active density peak only
    ax2.set_ylim(0, y_max * 1.1)
    ax2.tick_params(axis='both', labelsize=fontsize)
    ax2.grid(False)

    # ----------------------------------------
    # Shared Legend
    # ----------------------------------------
    # Merge handles from both axes; keying by label keeps one entry per label
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()

    combined = dict(zip(labels1 + labels2, handles1 + handles2))
    fig.legend(
        combined.values(), combined.keys(),
        loc='lower center', ncol=2, fontsize=fontsize, frameon=False, bbox_to_anchor=(0.5, -0.12)
    )
    plt.suptitle(f'{param_formatted_name} Distribution', fontsize=fontsize + 4)
    plt.tight_layout(rect=[0, 0.18, 1, 1])  # Leave space for legend at bottom
    plt.show()

In [None]:
# ----------------------------------------------------------
## Inference plotting -- ODE model
# ----------------------------------------------------------

    # States
        # 0 S_O2
        # 1: S_I
        # 2: S_S
        # 3: S_NH4
        # 4: S_N2
        # 5: S_NOX
        # 6: S_ALK
        # 7: X_I
        # 8: X_S
        # 9: X_H
        # 10: X_STO
        # 11: X_A
        # 12: X_SS

    # Compounds
        # 0: Time
        # 1: Flowrate
        # 2: COD = S_I + S_S + X_I + X_S + X_STO + X_H + X_A
        # 3: NH4+NH3 = S_NH4
        # 4: NO3+NO2 = S_NOX
        # 5: TKN = S_NH4 + S_N2
        # 6: Alkalinity = S_ALK
        # 7: TSS = X_SS

    # Calculate compounds from states
        # Flowrate = influent flowrate
        # COD = S_I + S_S + X_I + X_S + X_STO + X_H + X_A
        # NH4+NH3 = S_NH4
        # NO3+NO2 = S_NOX
        # TKN = S_NH4 + S_N2
        # Alkalinity = S_ALK
        # TSS = X_SS

reactor_volumes = [
    r1_V    # m3, Reactor 1 Volume
]
# Convert the influent table once; doing it inside the ODE function would repeat
# the conversion on every right-hand-side evaluation of the solver.
influent_array = Data_Influent_states.to_numpy()
# Ode system
ode_system = lambda t, y, theta: ode_system_wrapper(t=t, y=y, theta=theta, influentData=influent_array, reactorVolumes=reactor_volumes)
#t_eval = np.linspace(min(Data_Influent_states['Time']), max(Data_Influent_states['Time']), 1000)
t_eval = np.linspace(0, 14, 1000) # TODO: MANUAL
t_span = (min(t_eval), max(t_eval))
y0 = get_reactor_initial_values(top_dir)

# Grab theta samples from the distribution given by trace file and simulate ODE model with each theta sample then plot those simulations (called inference lines)

# The flowrate column does not depend on theta, so interpolate it once for all samples
sol_flow = np.interp(t_eval, Data_Influent_states['Time'], Data_Effluent_states['Flowrate'])  # Interpolated Effluent flowrate

# One DataFrame per theta sample, each with columns
#     Time, Flowrate, COD, NH4+NH3, NO3+NO2, TKN, Alkalinity, TSS
# and one row per evaluated time point
inference_results = []  # List to hold results per sample
# Never index past the end of the sample table
n_inference_lines = min(num_inference_lines, len(theta_samples_df))
# Loop over inference lines
for i in range(n_inference_lines):
    theta_sample = theta_samples_df.iloc[i].to_numpy(dtype=float)
    # Simulate ODE model with theta sample (theta is passed through args)
    sol = solve_ivp(
        fun=ode_system,
        t_span=t_span,
        y0=y0, # Initial conditions
        t_eval=t_eval, # Time points to evaluate function at
        method='BDF',
        args=(theta_sample,),
    )
    # A failed integration returns a truncated solution; stop with the solver's
    # own message instead of a column-length mismatch further down.
    if not sol.success:
        raise RuntimeError(f"ODE solve failed for theta sample {i}: {sol.message}")

    # Compounds from states (state rows indexed as in the legend above)
    sol_COD = sol.y[1] + sol.y[2] + sol.y[7] + sol.y[8] + sol.y[10] + sol.y[9] + sol.y[11]
    sol_NH4 = sol.y[3]
    sol_NOx = sol.y[5]
    sol_TKN = sol.y[3] + sol.y[4]
    sol_Alkalinity = sol.y[6]
    sol_TSS = sol.y[12]

    sample_df = pd.DataFrame({
        'Time': sol.t,
        'Flowrate': sol_flow,
        'COD': sol_COD,
        'NH4+NH3': sol_NH4,
        'NO3+NO2': sol_NOx,
        'TKN': sol_TKN,
        'Alkalinity': sol_Alkalinity,
        'TSS': sol_TSS
    })
    inference_results.append(sample_df)

# End of loop over inference lines

# Calculate the interval band for each compound at each time point
# Percent cut from each tail: 0.5 -> 99% band, 2.5 -> 95% band
ci_tail_pct = 0.5
ci_df = pd.DataFrame({'Time': t_eval})
compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
for key in compound_keys:
    # Rows: theta samples, columns: time points
    samples = np.vstack([df[key].to_numpy() for df in inference_results])
    # axis=0 reduces across samples, leaving one value per time point
    mean = samples.mean(axis=0)
        # Non-percentile alternative, better for symmetric distributions:
    # std    = samples.std(axis=0, ddof=0)
    # lower  = mean - crit_value * std
    # upper  = mean + crit_value * std
        # Percentiles, better for asymmetric distributions
    lower, upper = np.percentile(samples, [ci_tail_pct, 100 - ci_tail_pct], axis=0)

    ci_df[f'{key}_lower'] = lower
    ci_df[f'{key}_mean']  = mean
    ci_df[f'{key}_upper'] = upper


## Underlying true values for compounds
# Same ODE system and time grid as the inference lines, driven by true_theta_array
sol_true = solve_ivp(
    fun=ode_system,
    t_span=t_span,
    y0=y0, # Initial conditions
    t_eval=t_eval, # Time points to evaluate function at
    method='BDF',
    args=(true_theta_array,) # Arguments to pass to the function
)
# A failed integration returns a truncated solution; stop with the solver's own
# message instead of a column-length mismatch when the table is built below.
if not sol_true.success:
    raise RuntimeError(f"ODE solve with the true parameter set failed: {sol_true.message}")

true_flow = np.interp(t_eval, Data_Influent_states['Time'], Data_Effluent_states['Flowrate'])
# Compounds from states (state rows indexed as in the legend above)
true_COD = sol_true.y[1] + sol_true.y[2] + sol_true.y[7] + sol_true.y[8] + sol_true.y[10] + sol_true.y[9] + sol_true.y[11]
true_NH4 = sol_true.y[3]
true_NOx = sol_true.y[5]
true_TKN = sol_true.y[3] + sol_true.y[4]
true_Alkalinity = sol_true.y[6]
true_TSS = sol_true.y[12]
# Create a DataFrame for the true values
true_values_df = pd.DataFrame({
    'Time': t_eval,
    'Flowrate': true_flow,
    'COD': true_COD,
    'NH4+NH3': true_NH4,
    'NO3+NO2': true_NOx,
    'TKN': true_TKN,
    'Alkalinity': true_Alkalinity,
    'TSS': true_TSS
})


In [None]:
# ----------------------------------------------------------
## Plot inference lines versus data
# ----------------------------------------------------------
fontsize = 24
markersize = 10
# Display names (mathtext) for each compound column
compound_format_names = {
    'COD': r'$COD$',
    'NH4+NH3': r'$NH_{4}+NH_{3}$',
    'NO3+NO2': r'$NO_{3}+NO_{2}$',
    'TKN': r'$TKN$',
    'Alkalinity': r'$Alkalinity$',
    'TSS': r'$TSS$',
}

# One figure per compound (COD, NH4+NH3, NO3+NO2, TKN, Alkalinity, TSS):
#     measured data, every simulated theta sample, and the true-parameter trajectory
for key in compound_keys:
    compound_formatted_name = compound_format_names.get(key, key)
    fig, ax = plt.subplots(figsize=(24, 10))

    # Measured effluent data as unconnected black dots
    data_style = dict(linestyle='None', marker='.', markersize=markersize, color='black')
    ax.plot(
        Data_Effluent_compounds['Time'],
        Data_Effluent_compounds[key],
        label=f'{compound_formatted_name} Data',
        **data_style,
    )

    # One faint line per theta sample; only the first one gets a legend entry
    for idx, sample_df in enumerate(inference_results):
        legend_kwargs = {'label': 'Inference Lines'} if idx == 0 else {}
        ax.plot(
            sample_df['Time'],
            sample_df[key],
            linestyle='-',
            alpha=0.05,
            color='blue',
            **legend_kwargs,
        )

    # Trajectory obtained with the true parameter values
    ax.plot(true_values_df['Time'], true_values_df[key], linestyle='--', color='red', label='True Value')

    # Titles, axis labels and tick sizes
    ax.set_title(f'Inference Lines – {compound_formatted_name}', fontsize=fontsize)
    ax.set_xlabel('Time (days)', fontsize=fontsize)
    ax.set_ylabel(f'{compound_formatted_name} (mg/L)', fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    # ax.set_ylim(bottom=0, top=Data_Effluent_compounds[key].max() * 1.1)
    ax.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()


In [None]:
# ------------------------------------------------------------
## Confidence Interval Plots for compounds
# ------------------------------------------------------------

# Per-compound y-axis limits; (None, None) leaves the axis autoscaled
# Alternative fixed limits, kept for manual switching:
# y_lims = {
#     'COD': (225, 255),
#     'NH4+NH3': (0, 2),
#     'NO3+NO2': (0, 10),
#     'TKN': (0, 10),
#     'Alkalinity': (0, 200),
#     'TSS': (0, 10)
# }
y_lims = {'COD': (225, 255)}
for _autoscaled_key in ('NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS'):
    y_lims[_autoscaled_key] = (None, None)

# One figure per compound: data, mean trajectory, interval band and true trajectory
for key in compound_keys:
    compound_formatted_name = compound_format_names.get(key, key)
    fig, ax = plt.subplots(figsize=(12, 8))

    # Measured effluent data as unconnected black dots
    ax.plot(
        Data_Effluent_compounds['Time'],
        Data_Effluent_compounds[key],
        linestyle='None', marker='.', markersize=markersize, color='black',
        label=f'{compound_formatted_name} Data',
    )

    # Mean of the simulated trajectories
    band_time = ci_df['Time']
    ax.plot(band_time, ci_df[f'{key}_mean'], linestyle='-', linewidth=2, color='blue', label='Mean')

    # Shaded band between the lower and upper bounds stored in ci_df
    ax.fill_between(
        band_time,
        ci_df[f'{key}_lower'],
        ci_df[f'{key}_upper'],
        alpha=0.2, color='blue',
        label='99% Credible Interval',
    )

    # Trajectory obtained with the true parameter values
    ax.plot(
        true_values_df['Time'],
        true_values_df[key],
        linestyle='--', linewidth=2, color='green',
        label='True Value',
    )

    # ax.set_title(f'{compound_formatted_name} – 99% Credbile Interval', fontsize=fontsize)
    ax.set_title('PCA Method - Active Case', fontsize=fontsize)
    ax.set_xlabel('Time (days)', fontsize=fontsize)
    ax.set_ylabel(f'{compound_formatted_name} (mg/L)', fontsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    # ax.set_ylim(bottom=0, top=Data_Effluent_compounds[key].max() * 1.1)
    ax.set_xlim(0, 14)
    ax.set_ylim(y_lims[key])
    ax.legend(fontsize=fontsize)
    plt.tight_layout()
    plt.show()
    

In [None]:
# -------------------------------------------------------------
## Inference plotting -- ODE model --- Active and HighRes
# -------------------------------------------------------------

def compute_inference_results(theta_samples_df, Data_Influent_states, Data_Effluent_states, t_eval, y0, reactor_volumes):
    """Simulate the ODE model once per theta sample and summarise the spread.

    Parameters
    ----------
    theta_samples_df : pandas.DataFrame
        One row per parameter sample; each row is passed to the model as a
        float array in column order.
    Data_Influent_states : pandas.DataFrame
        Influent table. Passed to ``ode_system_wrapper`` as a NumPy array; its
        ``'Time'`` column is also the abscissa of the flowrate interpolation.
    Data_Effluent_states : pandas.DataFrame
        Table whose ``'Flowrate'`` column is interpolated onto ``t_eval``.
    t_eval : array-like
        Time points at which the solution is stored; its min and max define
        the integration span.
    y0 : array-like
        Initial state vector.
    reactor_volumes : list
        Forwarded to ``ode_system_wrapper`` as ``reactorVolumes``.

    Returns
    -------
    inference_results : list of pandas.DataFrame
        One frame per sample with columns Time, Flowrate, COD, NH4+NH3,
        NO3+NO2, TKN, Alkalinity, TSS.
    ci_df : pandas.DataFrame
        ``Time`` plus ``<compound>_lower`` / ``_mean`` / ``_upper`` columns,
        where lower and upper are the 0.5th and 99.5th percentiles across
        samples at each time point.

    Raises
    ------
    ValueError
        If ``theta_samples_df`` has no rows.
    RuntimeError
        If the solver reports failure for any sample.
    """
    n_samples = len(theta_samples_df)
    if n_samples == 0:
        raise ValueError("theta_samples_df is empty: at least one parameter sample is required")

    t_span = (min(t_eval), max(t_eval))

    # Invariant across samples and solver steps: convert / interpolate once
    influent_array = Data_Influent_states.to_numpy()
    sol_flow = np.interp(t_eval, Data_Influent_states['Time'], Data_Effluent_states['Flowrate'])

    def rhs(t, y, theta):
        # Right-hand side handed to solve_ivp; theta arrives through args
        return ode_system_wrapper(
            t=t, y=y, theta=theta,
            influentData=influent_array,
            reactorVolumes=reactor_volumes
        )

    inference_results = []  # Stores all simulations
    for i in range(n_samples):
        theta_sample = theta_samples_df.iloc[i].to_numpy(dtype=float)
        sol = solve_ivp(
            fun=rhs,
            t_span=t_span,
            y0=y0,
            t_eval=t_eval,
            method='BDF',
            args=(theta_sample,),
        )
        # A failed integration returns a truncated solution, which would not
        # line up with t_eval when the frame is built below.
        if not sol.success:
            raise RuntimeError(f"ODE solve failed for theta sample {i}: {sol.message}")

        # Compounds assembled from rows of the state matrix
        sol_COD = sol.y[1] + sol.y[2] + sol.y[7] + sol.y[8] + sol.y[10] + sol.y[9] + sol.y[11]
        sol_NH4 = sol.y[3]
        sol_NOx = sol.y[5]
        sol_TKN = sol.y[3] + sol.y[4]
        sol_Alkalinity = sol.y[6]
        sol_TSS = sol.y[12]

        inference_results.append(pd.DataFrame({
            'Time': t_eval,
            'Flowrate': sol_flow,
            'COD': sol_COD,
            'NH4+NH3': sol_NH4,
            'NO3+NO2': sol_NOx,
            'TKN': sol_TKN,
            'Alkalinity': sol_Alkalinity,
            'TSS': sol_TSS
        }))

    # Interval band across samples at every time point
    ci_df = pd.DataFrame({'Time': t_eval})
    compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
    for key in compound_keys:
        # Rows: samples, columns: time points
        samples = np.vstack([df[key].to_numpy() for df in inference_results])
        mean = samples.mean(axis=0)
        # 95%
        # lower  = np.percentile(samples, 2.5, axis=0)
        # upper  = np.percentile(samples, 97.5, axis=0)
        # 99%
        lower = np.percentile(samples, 0.5, axis=0)
        upper = np.percentile(samples, 99.5, axis=0)

        ci_df[f'{key}_lower'] = lower
        ci_df[f'{key}_mean']  = mean
        ci_df[f'{key}_upper'] = upper

    return inference_results, ci_df

# Shared simulation inputs for the HighRes / Active runs
y0 = get_reactor_initial_values(top_dir)
reactor_volumes = [r1_V]    # m3, Reactor 1 Volume

# Evaluation grid  # TODO: MANUAL
#t_eval = np.linspace(min(Data_Influent_states['Time']), max(Data_Influent_states['Time']), 1000)
# t_eval = np.linspace(0, 14, 1000)
t_eval = np.linspace(0, 30, 1000)

t_span = (t_eval.min(), t_eval.max())

# Both runs are driven by the HighRes influent table.
# NOTE(review): the influent table is also passed as Data_Effluent_states, so the
# 'Flowrate' column of the results comes from the influent data -- confirm intended.

# HIGHRES
inference_results_highres, ci_df_highres = compute_inference_results(
    theta_samples_df=theta_samples_highres_df,
    Data_Influent_states=Data_Influent_states_HighRes,
    Data_Effluent_states=Data_Influent_states_HighRes,
    t_eval=t_eval,
    y0=y0,
    reactor_volumes=reactor_volumes,
)

# ACTIVE
inference_results_active, ci_df_active = compute_inference_results(
    theta_samples_df=theta_samples_active_df,
    Data_Influent_states=Data_Influent_states_HighRes,
    Data_Effluent_states=Data_Influent_states_HighRes,
    t_eval=t_eval,
    y0=y0,
    reactor_volumes=reactor_volumes,
)

# Underlying true values
ode_system = lambda t, y, theta: ode_system_wrapper(
    t=t, y=y, theta=theta,
    influentData=Data_Influent_states_HighRes.to_numpy(),
    reactorVolumes=reactor_volumes
)
sol_true = solve_ivp(
    fun=ode_system,
    t_span=t_span,
    y0=y0, # Initial conditions
    t_eval=t_eval, # Time points to evaluate function at
    method='BDF',
    args=(true_theta_array,) # Arguments to pass to the function
)
true_flow = np.interp(t_eval, Data_Influent_states_HighRes['Time'], Data_Effluent_states_HighRes['Flowrate'])
true_COD = sol_true.y[1] + sol_true.y[2] + sol_true.y[7] + sol_true.y[8] + sol_true.y[10] + sol_true.y[9] + sol_true.y[11]
true_NH4 = sol_true.y[3]
true_NOx = sol_true.y[5]
true_TKN = sol_true.y[3] + sol_true.y[4]
true_Alkalinity = sol_true.y[6]
true_TSS = sol_true.y[12]
# Create a DataFrame for the true values
true_values_df = pd.DataFrame({
    'Time': t_eval,
    'Flowrate': true_flow,
    'COD': true_COD,
    'NH4+NH3': true_NH4,
    'NO3+NO2': true_NOx,
    'TKN': true_TKN,
    'Alkalinity': true_Alkalinity,
    'TSS': true_TSS
})


In [None]:
# Text and marker sizes used throughout this figure set
fontsize = 24
markersize = 8

compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
# Mathtext display names; a key without an entry falls back to the raw key
compound_format_names = {
    'COD': 'COD',
    'NH4+NH3': 'NH$_4^+$ + NH$_3$',
    'NO3+NO2': 'NO$_3^-$ + NO$_2^-$',
    'TKN': 'TKN',
    'Alkalinity': 'Alkalinity',
    'TSS': 'TSS'
}

# Alternative fixed limits, kept for reference:
# y_lims = {
#     'COD': (225, 255),
#     'NH4+NH3': (0, 10),
#     'NO3+NO2': (0, 10),
#     'TKN': (0, 10),
#     'Alkalinity': (0, 200),
#     'TSS': (0, 10)
# }

# Limits handed to set_ylim per compound; only COD gets explicit bounds
y_lims = {compound: (None, None) for compound in compound_keys}
y_lims['COD'] = (225, 255)

# Artist styling shared by the top and bottom panels
mean_style = dict(linestyle='-', linewidth=2, color='blue', label='Mean')
band_style = dict(alpha=0.2, color='blue', label='99% Credible Interval')
truth_style = dict(linestyle='--', linewidth=2, color='red', label='True Value')

# One figure per compound: HighRes posterior on top (a), Active posterior below (b)
for key in compound_keys:
    compound_formatted_name = compound_format_names.get(key, key)
    data_label = f'{compound_formatted_name} Data'
    y_label = f'{compound_formatted_name} (mg/L)'

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8), sharex=True)

    # Panels are drawn top first, then bottom
    panels = (
        (ax1, ci_df_highres, '(a)'),
        (ax2, ci_df_active, '(b)'),
    )
    for panel_ax, ci_table, panel_title in panels:
        # Effluent observations (both panels show the HighRes data)
        panel_ax.plot(
            Data_Effluent_compounds_HighRes['Time'],
            Data_Effluent_compounds_HighRes[key],
            linestyle='None',
            marker='.',
            markersize=markersize,
            color='black',
            label=data_label
        )
        # Posterior mean, credible band, then the reference trajectory
        panel_ax.plot(ci_table['Time'], ci_table[f'{key}_mean'], **mean_style)
        panel_ax.fill_between(
            ci_table['Time'],
            ci_table[f'{key}_lower'],
            ci_table[f'{key}_upper'],
            **band_style
        )
        panel_ax.plot(true_values_df['Time'], true_values_df[key], **truth_style)

        # Only the bottom panel carries the x-axis label (x axis is shared)
        if panel_ax is ax2:
            panel_ax.set_xlabel('Time (days)', fontsize=fontsize)
        panel_ax.set_ylabel(y_label, fontsize=fontsize)
        panel_ax.set_ylim(y_lims[key])
        panel_ax.set_xlim(0, 14)
        panel_ax.set_title(panel_title, fontsize=fontsize)
        panel_ax.tick_params(axis='both', labelsize=fontsize)

    # Shared legend: one entry per distinct label across both panels
    legend_entries = {}
    for panel_ax in (ax1, ax2):
        panel_handles, panel_labels = panel_ax.get_legend_handles_labels()
        legend_entries.update(zip(panel_labels, panel_handles))
    fig.legend(
        legend_entries.values(), legend_entries.keys(),
        loc='lower center', ncol=2, fontsize=fontsize - 4, frameon=False, bbox_to_anchor=(0.5, -0.15)
    )

    plt.tight_layout()
    plt.show()


In [None]:
# Text and marker sizes for the large single-panel figures
fontsize = 32
markersize = 14

compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
# Mathtext display names; a key without an entry falls back to the raw key
compound_format_names = {
    'COD': 'COD',
    'NH4+NH3': 'NH$_4^+$ + NH$_3$',
    'NO3+NO2': 'NO$_3^-$ + NO$_2^-$',
    'TKN': 'TKN',
    'Alkalinity': 'Alkalinity',
    'TSS': 'TSS'
}

# Limits handed to set_ylim per compound; only COD gets explicit bounds
y_lims = dict.fromkeys(compound_keys, (None, None))
y_lims['COD'] = (225, 255)

# Dataset name -> (observations to scatter, credible-interval table to draw)
datasets = {
    "HighRes": (Data_Effluent_compounds_HighRes, ci_df_highres),
    "Active": (Data_Effluent_compounds_HighRes, ci_df_active)
}

# Styling shared by every figure produced in this cell
mean_style = dict(linestyle='-', linewidth=2, color='blue', label='Mean')
band_style = dict(alpha=0.2, color='blue', label='99% Credible Interval')
truth_style = dict(linestyle='--', linewidth=2, color='red', label='True Value')

# One standalone figure for every (dataset, compound) combination
for dataset_name, (data_df, ci_df) in datasets.items():
    is_highres = dataset_name == "HighRes"

    for key in compound_keys:
        compound_formatted_name = compound_format_names.get(key, key)

        fig, ax = plt.subplots(figsize=(20, 10))

        # Observations
        ax.plot(
            data_df['Time'],
            data_df[key],
            linestyle='None',
            marker='.',
            markersize=markersize,
            color='black',
            label=f'{compound_formatted_name} Data'
        )
        # Posterior mean and credible band
        ax.plot(ci_df['Time'], ci_df[f'{key}_mean'], **mean_style)
        ax.fill_between(
            ci_df['Time'],
            ci_df[f'{key}_lower'],
            ci_df[f'{key}_upper'],
            **band_style
        )
        # Reference trajectory
        ax.plot(true_values_df['Time'], true_values_df[key], **truth_style)

        # Axes formatting
        ax.set_xlabel('Time (days)', fontsize=fontsize)
        ax.set_ylabel(f'{compound_formatted_name} (mg/L)', fontsize=fontsize)
        ax.set_ylim(y_lims[key])
        ax.set_xlim(0, 14)

        # HighRes figures are tagged "(a) ... (HF)", everything else "(b) ... (AC)"
        panel_title = (
            f'(a) {compound_formatted_name} (HF)'
            if is_highres
            else f'(b) {compound_formatted_name} (AC)'
        )
        ax.set_title(panel_title, fontsize=fontsize)

        ax.tick_params(axis='both', labelsize=fontsize)
        ax.legend(fontsize=fontsize - 4)

        plt.tight_layout()
        plt.show()


In [None]:
# Text and marker sizes for the side-by-side figures
fontsize = 32
markersize = 14

compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
# Mathtext display names; a key without an entry falls back to the raw key
compound_format_names = {
    'COD': 'COD',
    'NH4+NH3': 'NH$_4^+$ + NH$_3$',
    'NO3+NO2': 'NO$_3^-$ + NO$_2^-$',
    'TKN': 'TKN',
    'Alkalinity': 'Alkalinity',
    'TSS': 'TSS'
}

# Explicit y-limits for every compound (applied on the left panel; y axis is shared)
y_lims = {
    'COD': (225, 255),
    'NH4+NH3': (0, 2),
    'NO3+NO2': (46, 54),
    'TKN': (15, 28),
    'Alkalinity': (2, 5),
    'TSS': (350, 390)
}

# Dataset name -> (observations to scatter, credible-interval table to draw)
datasets = {
    "HighRes": (Data_Effluent_compounds_HighRes, ci_df_highres),
    "Active": (Data_Effluent_compounds_HighRes, ci_df_active)
}

# Styling shared by both panels
mean_style = dict(linestyle='-', linewidth=2, color='blue', label='Mean')
band_style = dict(alpha=0.2, color='blue', label='99% Credible Interval')
truth_style = dict(linestyle='--', linewidth=2, color='red', label='True Value')

# One figure per compound: HF on the left (a), AC on the right (b)
for key in compound_keys:
    compound_formatted_name = compound_format_names.get(key, key)

    data_df, ci_df = datasets["HighRes"]
    active_data_df, active_ci_df = datasets["Active"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), sharey=True)

    # Left panel is completed before the right one is drawn
    panels = (
        (ax1, data_df, ci_df, f'(a) {compound_formatted_name} (HF)'),
        (ax2, active_data_df, active_ci_df, f'(b) {compound_formatted_name} (AC)'),
    )
    for panel_ax, obs_table, ci_table, panel_title in panels:
        # Observations
        panel_ax.plot(
            obs_table['Time'],
            obs_table[key],
            linestyle='None',
            marker='.',
            markersize=markersize,
            color='black',
            label=f'{compound_formatted_name} Data'
        )
        # Posterior mean, credible band, then the reference trajectory
        panel_ax.plot(ci_table['Time'], ci_table[f'{key}_mean'], **mean_style)
        panel_ax.fill_between(
            ci_table['Time'],
            ci_table[f'{key}_lower'],
            ci_table[f'{key}_upper'],
            **band_style
        )
        panel_ax.plot(true_values_df['Time'], true_values_df[key], **truth_style)

        panel_ax.set_xlabel('Time (days)', fontsize=fontsize)
        # y-label and y-limits are set on the left panel only
        if panel_ax is ax1:
            panel_ax.set_ylabel(f'{compound_formatted_name} (mg/L)', fontsize=fontsize)
            panel_ax.set_ylim(y_lims[key])
        panel_ax.set_xlim(0, 14)
        panel_ax.set_title(panel_title, fontsize=fontsize)
        panel_ax.tick_params(axis='both', labelsize=fontsize)

    # Shared legend: one entry per distinct label across both panels
    legend_entries = {}
    for panel_ax in (ax1, ax2):
        panel_handles, panel_labels = panel_ax.get_legend_handles_labels()
        legend_entries.update(zip(panel_labels, panel_handles))
    fig.legend(
        legend_entries.values(), legend_entries.keys(),
        loc='lower center', ncol=2, fontsize=fontsize - 4, frameon=False, bbox_to_anchor=(0.5, -0.18)
    )

    plt.tight_layout()
    plt.show()

In [None]:
# Text and marker sizes for the overlay figures
fontsize = 32
markersize = 14

compound_keys = ['COD', 'NH4+NH3', 'NO3+NO2', 'TKN', 'Alkalinity', 'TSS']
# Mathtext display names; a key without an entry falls back to the raw key
compound_format_names = {
    'COD': 'COD',
    'NH4+NH3': 'NH$_4^+$ + NH$_3$',
    'NO3+NO2': 'NO$_3^-$ + NO$_2^-$',
    'TKN': 'TKN',
    'Alkalinity': 'Alkalinity',
    'TSS': 'TSS'
}

# Limits handed to set_ylim per compound; only COD gets explicit bounds
y_lims = dict.fromkeys(compound_keys, (None, None))
y_lims['COD'] = (225, 270)

# Short dataset tag -> (observations, credible-interval table)
datasets = {
    "HF": (Data_Effluent_compounds_HighRes, ci_df_highres),
    "AC": (Data_Effluent_compounds_HighRes, ci_df_active)
}

# Line colour / dash pattern per dataset
styles = {
    "HF": {"color": "blue", "linestyle": "-"},
    "AC": {"color": "red", "linestyle": "--"}
}

# Hatch pattern per dataset so overlapping bands remain distinguishable
hatches = {
    "HF": "//",   # diagonal forward
    "AC": "\\\\"  # diagonal backward
}

# One figure per compound with both posteriors overlaid on the same axes
for key in compound_keys:
    compound_formatted_name = compound_format_names.get(key, key)
    mean_col, lower_col, upper_col = f'{key}_mean', f'{key}_lower', f'{key}_upper'

    fig, ax = plt.subplots(figsize=(20, 10))

    # Observations (black, drawn once since both datasets share them)
    ax.plot(
        Data_Effluent_compounds_HighRes['Time'],
        Data_Effluent_compounds_HighRes[key],
        linestyle='None',
        marker='.',
        markersize=markersize,
        color='black',
        label=f'Effluent {compound_formatted_name} Data'
    )

    # HF first, then AC: mean line followed by its hatched credible band
    for dataset_name, (data_df, ci_df) in datasets.items():
        dataset_color = styles[dataset_name]["color"]
        dataset_dash = styles[dataset_name]["linestyle"]
        time_axis = ci_df['Time']

        ax.plot(
            time_axis,
            ci_df[mean_col],
            linestyle=dataset_dash,
            linewidth=3,
            color=dataset_color,
            label=f'{dataset_name} Mean'
        )
        ax.fill_between(
            time_axis,
            ci_df[lower_col],
            ci_df[upper_col],
            facecolor=dataset_color,
            alpha=0.15,
            hatch=hatches[dataset_name],
            edgecolor=dataset_color,
            linewidth=0,
            label=f'{dataset_name} 99% Credible Interval'
        )

    # Reference trajectory
    ax.plot(
        true_values_df['Time'],
        true_values_df[key],
        linestyle=':',
        linewidth=3,
        color='black',
        label='True Value'
    )

    # Axes formatting
    ax.set_xlabel('Time (days)', fontsize=fontsize)
    ax.set_ylabel(f'{compound_formatted_name} (mg/L)', fontsize=fontsize)
    ax.set_ylim(y_lims[key])
    ax.set_xlim(0, 30)
    ax.set_title(f'{compound_formatted_name} - HF vs AC', fontsize=fontsize)

    ax.tick_params(axis='both', labelsize=fontsize)
    ax.legend(fontsize=fontsize - 4)

    plt.tight_layout()
    plt.show()
