# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from scipy.stats import gaussian_kde
import autograd.numpy as anp 
import sampyl as smp
from scipy.integrate import odeint
import pandas as pd
import sunode.wrappers.as_pytensor as sun

def mean_squared_error(y_true, y_pred):
    """
    Calculates the Mean Squared Error (MSE) between two arrays.

    Integer and boolean inputs are converted to float64 before the
    subtraction so that fixed-width integer arithmetic cannot wrap around
    (e.g. uint8 values 0 and 20 would otherwise yield a squared error of
    144 instead of 400). Floating-point inputs keep their dtype.

    Args:
    y_true (array-like): The true values.
    y_pred (array-like): The predicted values (must broadcast with y_true).

    Returns:
    float: The Mean Squared Error between the true and predicted values.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # 'b' = bool, 'i' = signed int, 'u' = unsigned int: all of these can
    # overflow (or, for bool, refuse to subtract), so promote them to float.
    if y_true.dtype.kind in "biu":
        y_true = y_true.astype(np.float64)
    if y_pred.dtype.kind in "biu":
        y_pred = y_pred.astype(np.float64)

    # Mean of the element-wise squared differences.
    return np.mean((y_true - y_pred) ** 2)



def plot_kde_matrix(data, titles, color='orange', name='SIP method'):
    """
    Plot a vertical stack of Gaussian kernel density estimates, one per column.

    Each column of `data` gets its own panel showing the filled KDE curve and
    a dashed red vertical line at the column mean (reported in the legend).
    The figure is displayed with `plt.show()`.

    Args:
    data (array-like): 2-D array of shape (n_samples, n_columns).
    titles (sequence of str): One label per column, used as the panel's y-label.
    color (str): Colour of the KDE line and fill.
    name (str): Figure super-title.

    Raises:
    ValueError: If len(titles) differs from the number of columns in data.
    """
    data = np.asarray(data)
    _, m = data.shape
    if len(titles) != m:
        raise ValueError("Length of titles vector must be equal to the number of columns in data.")

    # squeeze=False keeps `axes` a 2-D array even when m == 1; without it
    # plt.subplots returns a bare Axes object and indexing it fails.
    _, axes = plt.subplots(
        m, 1,
        figsize=(2.5, 1.5 * m),
        sharex='col',
        gridspec_kw={'hspace': 0},
        squeeze=False,
    )

    for ax, column, title in zip(axes[:, 0], data.T, titles):
        kde = gaussian_kde(column)
        x = np.linspace(np.min(column), np.max(column), 1000)
        density = kde(x)  # evaluate once, reuse for the line and the fill
        ax.plot(x, density, color=color)
        ax.fill_between(x, density, color=color, alpha=0.5)

        mean_val = np.mean(column)
        ax.axvline(mean_val, color='red', linestyle='--', linewidth=1, label=f'mean= {mean_val.round(2)}')
        ax.set_ylabel(title, rotation=0, labelpad=40)
        # Position legend in the top right corner of each subplot
        ax.legend(loc='upper right', draggable=True)
        ax.yaxis.set_label_position("left")

        # Keep only the bottom spine so the stacked panels read as one chart.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(True)
        ax.get_xaxis().set_visible(True)
        ax.tick_params(axis='x', which='both', bottom=True)
        ax.tick_params(axis='y', which='both', left=False, right=False)  # Remove vertical ticks
        ax.set_yticklabels([])  # Remove vertical tick labels

    plt.subplots_adjust(left=0.3)  # Adjust this as necessary to fit titles
    plt.xlabel('Value')  # Labels the current (last created) axes
    plt.suptitle(name)
    plt.show()

# Required modification to run NUTS, find_MAP, and Hamiltonian Monte Carlo
from autograd import numpy as anp
from autograd.extend import primitive, defvjp

@primitive
def safe_sqrt(x):
    """Square root of ``x``, evaluated as ``x ** 0.5``.

    Wrapped with autograd's ``primitive`` decorator; intended as a sqrt
    replacement that works when autograd passes ArrayBox values.
    """
    # Exponentiation stands in for a direct sqrt call.
    root = x ** 0.5
    return root

def grad_safe_sqrt(ans, x):
    """Build the gradient function for the square root evaluated at ``x``.

    Uses d(sqrt(x))/dx = 0.5 / sqrt(x) = 0.5 * x**(-0.5). ``ans`` (the
    forward result) is accepted but not used.

    Returns:
    callable: Maps an upstream gradient ``g`` to ``g * 0.5 * x**(-0.5)``.
    """
    def backward(g):
        # Chain rule: upstream gradient times the local derivative.
        return g * 0.5 * x**(-0.5)

    return backward

# Register grad_safe_sqrt with autograd as the gradient (VJP) rule for
# safe_sqrt, so differentiating through safe_sqrt uses the formula above.
defvjp(safe_sqrt, grad_safe_sqrt)

def safe_std(x, axis=None, keepdims=False):
    """Standard deviation of ``x`` built on ``safe_sqrt``.

    Computes the mean of the squared deviations from the mean (dividing by
    the number of elements) along ``axis`` and takes its square root with
    ``safe_sqrt``.

    Args:
    x: Input array.
    axis: Axis or axes to reduce over; None reduces over all elements.
    keepdims (bool): Whether the reduced axes are kept in the result.

    Returns:
    The standard deviation along ``axis``.
    """
    # The mean always keeps its dims so it broadcasts against x.
    centered = x - anp.mean(x, axis=axis, keepdims=True)
    variance = anp.mean(anp.square(centered), axis=axis, keepdims=keepdims)
    return safe_sqrt(variance)

def generate_samples(mean_or_lower, std_or_upper, n_samples, distribution='normal'):
    """Generate samples from a normal or uniform distribution.

    The two parameters may each be a scalar or an array-like; they are
    broadcast against each other. This includes the mixed case of one scalar
    and one array, which previously raised a TypeError when the scalar came
    first.

    Args:
    mean_or_lower: Mean (normal) or lower bound (uniform); scalar or array-like.
    std_or_upper: Standard deviation (normal) or upper bound (uniform);
        scalar or array-like.
    n_samples: Number of samples to draw for each parameter set.
    distribution (str): 'normal' or 'uniform'.

    Returns:
    numpy array: Shape (n_samples,) when both parameters are scalar,
    otherwise (n_samples, *broadcast_shape), e.g. (n_samples, k) for
    length-k parameter vectors.

    Raises:
    ValueError: If `distribution` is not 'normal' or 'uniform'.
    """
    if distribution not in ('normal', 'uniform'):
        raise ValueError("Unsupported distribution type. Choose 'normal' or 'uniform'.")

    # Shape shared by the two parameters after broadcasting; () for scalars.
    param_shape = np.broadcast(np.asarray(mean_or_lower), np.asarray(std_or_upper)).shape
    # Scalars: pass n_samples through untouched. Arrays: one row per sample.
    size = (n_samples, *param_shape) if param_shape else n_samples

    if distribution == 'normal':
        return np.random.normal(loc=mean_or_lower, scale=std_or_upper, size=size)
    return np.random.uniform(low=mean_or_lower, high=std_or_upper, size=size)

def expand_2D_array(arr):
    """Expand a 2D array into the Cartesian product of its columns.

    Every element of each column is paired with every element of all other
    columns. Rows are ordered like `itertools.product(*columns)`: the first
    column varies slowest and the last column varies fastest.

    Args:
    arr (array-like): 2-D input of shape (n_rows, n_cols).

    Returns:
    numpy array: Shape (n_rows ** n_cols, n_cols) with the dtype of `arr`.
    The result is always 2-D, including for an input with zero rows.
    """
    arr = np.asarray(arr)
    n_rows, n_cols = arr.shape[0], arr.shape[1]

    # Python ints, so the row count cannot overflow. With zero columns this
    # is 1: the product of nothing is a single empty combination.
    total = n_rows ** n_cols
    expanded = np.empty((total, n_cols), dtype=arr.dtype)

    for j in range(n_cols):
        # Each value of column j stays fixed while all later columns cycle...
        run_length = n_rows ** (n_cols - 1 - j)
        # ...and that whole pattern repeats once per combination of earlier columns.
        repeats = n_rows ** j
        expanded[:, j] = np.tile(np.repeat(arr[:, j], run_length), repeats)

    return expanded

def plot_paths(x_values, y_paths, show = False, color='blue', hl = 'r*', label = 'mean_path'):
    """Draw every row of y_paths against x_values, plus their column-wise mean.

    A new 10x6 figure is created on each call. Individual paths are drawn
    semi-transparent in `color`; the mean path is drawn with the format
    string `hl` and labelled `label` in the legend. The figure is only shown
    when `show` is true.
    """
    # Column-wise average across all paths.
    average_path = np.mean(y_paths, axis=0)

    plt.figure(figsize=(10, 6))
    ax = plt.gca()

    # Individual paths first so the mean is drawn on top of them.
    for single_path in y_paths:
        ax.plot(x_values, single_path, color=color, alpha=0.3)
    ax.plot(x_values, average_path, hl, label=label)

    # Title, axis labels and legend.
    ax.set_title('Multiple Y Paths with Mean Path')
    ax.set_xlabel('X Values')
    ax.set_ylabel('Y Values')
    ax.legend()

    if not show:
        return
    plt.show()