In [None]:
import jax
import jax.numpy as jnp
import liesel.model as lsl
import liesel.goose as gs
import liesel.contrib.splines as splines
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import numpy as np

In [None]:
from liesel.contrib.splines import equidistant_knots, basis_matrix
import tensorflow_probability.substrates.jax.bijectors as tfb
from liesel.distributions.mvn_degen import MultivariateNormalDegenerate
from sklearn.metrics import mean_squared_error
import numpy as np
import jax.numpy as jnp

In [None]:
# Generate Data for task
def generate_gaussian_data(n, seed, M=3, c_u=0.0):
    """
    Generate a synthetic covariate, noisy replicates of it, and a response.

    The true covariate is drawn as ``x_i ~ N(10, 5^2)``. For every observation
    ``i``, ``M`` replicates are drawn from a multivariate normal with mean
    ``x_i`` in every component and covariance ``sigma_sq_ui * Sigma_me``, where
    ``Sigma_me`` has ones on the diagonal and ``c_u`` everywhere else, and
    ``sigma_sq_ui`` is 1 for the first ``n // 2`` observations and 2 for the
    remaining ones. The response is ``y_i ~ N(sin(x_i), v_i)`` with the variance
    ``v_i`` picked at random from ``{0.3, 0.5}``.

    Args:
        n (int): The number of samples.
        seed (int): Seed passed to ``np.random.seed``. Note that this reseeds
            NumPy's *global* random state.
        M (int, optional): The number of replicates per sample. Default is 3.
        c_u (float, optional): Off-diagonal entry of the unscaled measurement
            error covariance matrix. Default is 0.0.

    Returns:
        tuple: ``(y, x, replicates, sigma_matrices, y_true)`` with

            - ``y``: noisy response, JAX array of shape ``(n,)``
            - ``x``: true covariate, JAX array of shape ``(n,)``
            - ``replicates``: noisy measurements of ``x``, JAX array of shape ``(n, M)``
            - ``sigma_matrices``: measurement error covariances, JAX array of shape ``(n, M, M)``
            - ``y_true``: noise-free response ``sin(x)``, NumPy array of shape ``(n,)``
    """

    # Set random seed for reproducibility (global NumPy state)
    np.random.seed(seed)

    # Sample values for true covariate
    x = np.random.normal(loc=10, scale=5, size=n)

    # Unscaled covariance: ones on the diagonal, c_u off the diagonal.
    # dtype=float keeps the matrix floating point even for an integer c_u.
    base_sigma_me = np.full((M, M), c_u, dtype=float)
    np.fill_diagonal(base_sigma_me, 1.0)

    # First half scaled by 1, second half by 2
    variance_factors = np.where(np.arange(n) < n // 2, 1.0, 2.0)
    sigma_matrices = variance_factors[:, None, None] * base_sigma_me  # (n, M, M)

    # Create M replicates of the true variable x. The draws stay in a loop, one
    # call per observation, so the random stream is consumed in a fixed order.
    replicates = np.empty((n, M))
    for i in range(n):
        replicates[i] = np.random.multivariate_normal(np.repeat(x[i], M), sigma_matrices[i])

    # Generate response variable y with heteroscedastic noise
    variances = np.random.choice([0.3, 0.5], size=n)
    y_true = np.sin(x)
    y = np.random.normal(loc=y_true, scale=np.sqrt(variances))

    # Convert to JAX arrays before returning (y_true is returned as NumPy array)
    return (
        jnp.array(y),
        jnp.array(x),
        jnp.array(replicates),
        jnp.array(sigma_matrices),
        y_true,
    )

In [None]:
def create_model(x, y, n_param_splines, sigma, x_tilde = None, sample_x = False):
  """
  Build a Liesel location-scale model for ``y`` with P-spline predictors.

  Both the mean and the log standard deviation of ``y`` are modelled as
  ``intercept + B(x) @ coef``, where ``B`` is a cubic B-spline basis and the
  coefficients get a degenerate multivariate normal prior built from a
  second-order difference penalty.

  If ``sample_x`` is True, the covariate is a latent parameter
  ``x_estimated`` (initialised with the row means of ``x_tilde``) with prior
  ``N(mu_x, tau2_x)``, and ``x_tilde`` is added as an observed variable with a
  multivariate normal measurement error model. Otherwise the basis matrices
  are computed from the fixed covariate ``x``.

  Args:
      x: Covariate values. Only used when ``sample_x`` is False. It is passed
          through ``.squeeze()`` before the basis matrix is evaluated.
      y: Observed response values.
      n_param_splines (int): Number of spline coefficients per predictor.
      sigma: Covariance matrices of the measurement error model. Only used when
          ``sample_x`` is True (presumably of shape ``(n, M, M)`` - confirm
          against the caller).
      x_tilde: Replicate measurements of the covariate, one row per
          observation. Required when ``sample_x`` is True.
      sample_x (bool): Whether the covariate is treated as a latent parameter.

  Returns:
      lsl.Model: Model containing ``y`` and, if ``sample_x`` is True, ``x_tilde``.

  Raises:
      ValueError: If ``sample_x`` is True but ``x_tilde`` is None.
  """
  spline_order = 3  # Cubic spline basis
  penalty_diff = 2  # 2nd-order difference penalty for smoothness

  if sample_x:
    if x_tilde is None:
      raise ValueError("x_tilde (the replicate measurements) is required when sample_x=True.")

    # Define hyperparameters for variance of x
    a_x = lsl.Var.new_param(0.001, name = "a_x")
    b_x = lsl.Var.new_param(0.001, name = "b_x")

    # Define prior for tau2_x using an Inverse Gamma distribution
    tau2_x_prior = lsl.Dist(tfd.InverseGamma, concentration = a_x, scale = b_x)
    tau2_x = lsl.Var.new_param(10.0, distribution = tau2_x_prior, name = "tau2_x")

    # tau2_x is a variance, but tfd.Normal is parameterised by the standard
    # deviation, so the prior of x needs the square root.
    tau_x = lsl.Var.new_calc(jnp.sqrt, tau2_x, name = "tau_x")

    # Define hyperparameters for mu_x (mean of x)
    location_mu_x_prior = lsl.Var.new_param(0.0, name = "location_mu_x_prior")
    scale_mu_x_prior = lsl.Var.new_param(1000.0, name= "scale_mu_x_prior")

    # Define mu_x as a parameter with a Normal prior
    mu_x_prior = lsl.Dist(tfd.Normal, loc = location_mu_x_prior, scale = scale_mu_x_prior)
    mu_x = lsl.Var.new_param(0.0, distribution = mu_x_prior, name = "mu_x")

    # Define prior distribution for x
    x_prior_dist = lsl.Dist(tfd.Normal, loc = mu_x, scale = tau_x)

    # The initial estimate of x is the mean of the replicates, kept as a column
    x_start = jnp.expand_dims(x_tilde.mean(axis=1), -1)
    x_estimated = lsl.Var.new_param(x_start, distribution = x_prior_dist, name="x_estimated")

    # Define likelihood model for measurement error
    measurement_dist = lsl.Dist(
        tfd.MultivariateNormalFullCovariance,
        loc=x_estimated,
        covariance_matrix=sigma
    )

    # Define x_tilde (observed replicates) as observed data in the probabilistic model
    x_tilde_var = lsl.Var.new_obs(
        value = x_tilde,
        distribution=measurement_dist,
        name="x_tilde"
    )

    # Knots are based on the mean of the replicates; the basis follows the latent x
    knot_source = x_start
    covariate = x_estimated
  else:
    # Knots and basis are both based on the fixed covariate
    knot_source = jnp.expand_dims(x, -1)
    covariate = x

  # Generate equidistant knots for the spline basis functions
  knots = equidistant_knots(knot_source, n_param=n_param_splines, order=spline_order)

  def spline_basis(x):
    # outer_ok=True allows evaluation outside the knot range
    return splines.basis_matrix(x.squeeze(), knots=knots, order=spline_order, outer_ok=True)

  # Spline basis matrices for the mean and the scale predictor
  basis_matrix_var_mu = lsl.Var.new_calc(spline_basis, x = covariate, name="basis_matrix_mu")
  basis_matrix_var_scale = lsl.Var.new_calc(spline_basis, x = covariate, name="basis_matrix_scale")

  # Define intercept parameters for the mean and scale functions in the spline model
  b0_mu = lsl.Var.new_param(0.0, name="b0_mu")
  b0_scale = lsl.Var.new_param(0.0, name="b0_scale")

  # Define hyperparameters for the variance of the mean function
  a_var_mu = lsl.Var(0.001, name = "a_mu")
  b_var_mu = lsl.Var(0.001, name = "b_mu")

  # Define hyperparameters for the variance of the scale function
  a_var_scale = lsl.Var(0.001, name = "a_scale")
  b_var_scale = lsl.Var(0.001, name = "b_scale")

  # Define prior tau2_mu distributions using Inverse Gamma
  prior_tau2_mu = lsl.Dist(tfd.InverseGamma, concentration=a_var_mu, scale=b_var_mu)
  tau2_mu = lsl.Var.new_param(10.0, distribution = prior_tau2_mu, name= "tau2_mu")

  # Define prior tau2_scale distributions using Inverse Gamma
  # (the variable is named "tau_scale"; the samplers refer to it by that name)
  prior_tau2_scale = lsl.Dist(tfd.InverseGamma, concentration = a_var_scale, scale = b_var_scale)
  tau2_scale = lsl.Var.new_param(10.0, distribution = prior_tau2_scale, name= "tau_scale")

  # Compute P-spline penalty matrix
  penalty = splines.pspline_penalty(d=n_param_splines, diff=penalty_diff)

  # Define penalty matrices for scale and mean functions
  penalty_scale = lsl.Var(penalty, name= "penalty_scale")
  penalty_mu = lsl.Var(penalty, name= "penalty_mu")

  # Eigenvalues of the penalty matrix. They are computed in float64 and compared
  # against a relative tolerance, because the eigenvalues of the null space come
  # out as small numerical noise rather than exact zeros.
  evals = np.linalg.eigvalsh(np.asarray(penalty, dtype=np.float64))
  tol = evals.max() * n_param_splines * np.finfo(np.float64).eps
  nonzero = evals > tol

  # Rank of the penalty matrix (number of non-zero eigenvalues)
  penalty_rank = jnp.asarray(int(nonzero.sum()))
  rank_scale = lsl.Value(penalty_rank, _name= "rank_scale")
  rank_mu = lsl.Value(penalty_rank, _name= "rank_mu")

  # Log pseudo-determinant of the penalty matrix (ignoring zero eigenvalues)
  log_pdet = jnp.asarray(float(np.log(evals[nonzero]).sum()))
  log_pdet_mu = lsl.Value(log_pdet, _name= "log_pdet_mu")
  log_pdet_scale = lsl.Value(log_pdet, _name= "log_pdet_scale")

  # Define prior distribution for spline coefficients (scale function)
  prior_coef_scale = lsl.Dist(
      MultivariateNormalDegenerate.from_penalty,
      loc=jnp.zeros(shape=(n_param_splines,)),
      var=tau2_scale,
      pen=penalty_scale,
      rank=rank_scale,
      log_pdet=log_pdet_scale
  )

  # Initialize spline coefficients for scale function
  start_value_scale = np.zeros(np.shape(penalty)[-1], np.float32)
  coef_scale = lsl.Var.new_param(start_value_scale, distribution= prior_coef_scale, name= "coef_scale")

  # Define prior distribution for spline coefficients (mean function)
  prior_coef_mu = lsl.Dist(
      MultivariateNormalDegenerate.from_penalty,
      loc=jnp.zeros(shape=(n_param_splines,)),
      var=tau2_mu,
      pen=penalty_mu,
      rank=rank_mu,
      log_pdet=log_pdet_mu
  )

  # Initialize spline coefficients for mean function
  start_value_mu = np.zeros(np.shape(penalty)[-1], np.float32)
  coef_mu = lsl.Var.new_param(start_value_mu, distribution= prior_coef_mu, name= "coef_mu")

  def pred_fn(beta0, spline_coef, basis_matrix):
    # Linear predictor: intercept plus spline contribution
    return beta0 + jnp.dot(basis_matrix, spline_coef)

  # Compute the linear predictor for the scale (standard deviation) of y
  scale_of_y = lsl.Var.new_calc(
      pred_fn,
      beta0=b0_scale,                       # Intercept for scale function
      spline_coef=coef_scale,               # Spline coefficients for scale
      basis_matrix=basis_matrix_var_scale,  # Basis matrix for scale
      name="scale_of_y"
  )

  # Transform scale_of_y to ensure positivity (exponential transformation)
  scale_of_y_transformed = lsl.Var.new_calc(jnp.exp, scale_of_y, name = "scale_of_y_transformed")

  # Compute the mean (mu) of y using the spline model
  mu_of_y = lsl.Var.new_calc(
      pred_fn,
      beta0=b0_mu,
      spline_coef=coef_mu,
      basis_matrix=basis_matrix_var_mu,
      name="mu_of_y"
  )

  # Define the likelihood distribution of y (Normal with estimated mean and scale)
  y_dist = lsl.Dist(
      tfd.Normal,
      loc=mu_of_y,
      scale=scale_of_y_transformed
  )

  # Define y as an observed variable with the specified distribution
  y_var = lsl.Var.new_obs(
      value=y,
      distribution=y_dist,
      name="y"
  )

  if sample_x:
    return lsl.Model([y_var, x_tilde_var])

  return lsl.Model([y_var])


In [None]:
def engine_builder(model, x_sample = False, seed = 2, num_chains = 4,
                   warmup_duration = 1000, posterior_duration = 5000,
                   thinning_posterior = 10):
  """
  Configure and build a goose sampling engine for a model from ``create_model``.

  Kernels:
      - IWLS kernels for ``coef_scale`` and ``coef_mu``
      - Gibbs kernels for the smoothing variances ``tau2_mu`` and ``tau_scale``
      - if ``x_sample`` is True: a Gibbs kernel for ``mu_x``, a random walk
        kernel for ``x_estimated`` and a Gibbs kernel for ``tau2_x``

  Args:
      model: The Liesel model to sample from. With ``x_sample=True`` it must
          contain the variables created by ``create_model(..., sample_x=True)``.
      x_sample (bool): Whether the latent covariate and its hyperparameters
          are sampled.
      seed (int): Seed of the engine builder. Default is 2.
      num_chains (int): Number of MCMC chains. Default is 4.
      warmup_duration (int): Number of warmup iterations. Default is 1000.
      posterior_duration (int): Number of posterior iterations. Default is 5000.
      thinning_posterior (int): Thinning applied to the posterior samples.
          Default is 10.

  Returns:
      The built goose engine; ``mu_of_y`` and ``scale_of_y_transformed`` are
      additionally included in the stored positions.
  """
  interface = gs.LieselInterface(model)

  def make_tau2_transition(suffix, target):
    """
    Create the Gibbs transition for the smoothing variance of one predictor.

    Args:
        suffix: Name suffix of the predictor's nodes ("mu" or "scale").
        target: Name of the variance variable that is updated.
    """
    position_keys = [
        f"a_{suffix}", f"b_{suffix}", f"rank_{suffix}",
        f"penalty_{suffix}", f"coef_{suffix}"
    ]

    def transition(prng_key, model_state):
      # Extract relevant values from model state
      pos = interface.extract_position(
          position_keys=position_keys,
          model_state=model_state
      )
      a_prior = pos[f"a_{suffix}"]
      b_prior = pos[f"b_{suffix}"]
      rank = pos[f"rank_{suffix}"]
      K = pos[f"penalty_{suffix}"]
      beta = pos[f"coef_{suffix}"]

      # Parameters of the inverse gamma full conditional
      a_gibbs = jnp.squeeze(a_prior + 0.5 * rank)
      b_gibbs = jnp.squeeze(b_prior + 0.5 * (beta @ K @ beta))

      # An inverse gamma draw is the scale divided by a gamma draw
      draw = b_gibbs / jax.random.gamma(prng_key, a_gibbs)

      return {target: draw}

    return transition

  transition_tau_mu = make_tau2_transition("mu", "tau2_mu")
  transition_tau_scale = make_tau2_transition("scale", "tau_scale")

  def transition_mu_x(prng_key, model_state):
    """
    Sample mu_x from its normal full conditional given x_estimated and tau2_x.

    Args:
        prng_key: The random number generator key for sampling.
        model_state: The current model state.

    Returns:
        dict: A dictionary containing the sampled mu_x.
    """
    # Extract relevant values from model state
    pos = interface.extract_position(
        position_keys=["x_estimated", "tau2_x", "location_mu_x_prior", "scale_mu_x_prior"],
        model_state=model_state
    )
    x = pos["x_estimated"]
    n = x.shape[0]
    tau2_x = pos["tau2_x"]
    prior_loc = pos["location_mu_x_prior"]
    # The prior of mu_x is parameterised by its standard deviation
    prior_var = pos["scale_mu_x_prior"] ** 2

    # Conjugate normal update for the mean with known variance tau2_x
    denominator = n * prior_var + tau2_x
    mu_mean = (n * jnp.mean(x) * prior_var + prior_loc * tau2_x) / denominator
    mu_std = jnp.sqrt(tau2_x * prior_var / denominator)

    # Sample mu_x from a normal distribution
    normal_sample = jax.random.normal(prng_key, (1,))
    mu_x = jnp.squeeze(mu_mean + mu_std * normal_sample)

    return {"mu_x": mu_x}

  def transition_tau2_x(prng_key, model_state):
    """
    Sample tau2_x from its inverse gamma full conditional.

    Args:
        prng_key: The random number generator key for sampling.
        model_state: The current model state.

    Returns:
        dict: A dictionary containing the sampled tau2_x.
    """
    # Extract relevant values from model state
    pos = interface.extract_position(
        position_keys=["a_x", "b_x", "x_estimated", "mu_x"],
        model_state=model_state
    )
    a_x = pos["a_x"]
    b_x = pos["b_x"]
    x = pos["x_estimated"]
    n = x.shape[0]
    mu_x = pos["mu_x"]

    # Compute the new alpha and beta for the inverse gamma distribution
    alpha_new = a_x + n / 2
    beta_new = b_x + ((x - mu_x)**2).sum() / 2

    # Sample tau2_x from the inverse gamma distribution
    tau2_x = jnp.squeeze(
        tfd.InverseGamma(concentration=alpha_new, scale=beta_new).sample(seed=prng_key)
    )

    return {"tau2_x": tau2_x}

  # Add kernels and return engine
  eb_sample = gs.EngineBuilder(seed = seed, num_chains = num_chains)
  eb_sample.set_model(interface)
  eb_sample.set_initial_values(model.state)

  eb_sample.add_kernel(gs.IWLSKernel(["coef_scale"]))
  eb_sample.add_kernel(gs.IWLSKernel(["coef_mu"]))
  eb_sample.add_kernel(gs.GibbsKernel(["tau2_mu"], transition_tau_mu))
  eb_sample.add_kernel(gs.GibbsKernel(["tau_scale"], transition_tau_scale))

  if x_sample:
    eb_sample.add_kernel(gs.GibbsKernel(["mu_x"], transition_mu_x))
    eb_sample.add_kernel(gs.RWKernel(["x_estimated"]))
    eb_sample.add_kernel(gs.GibbsKernel(["tau2_x"], transition_tau2_x))

  eb_sample.set_duration(
      warmup_duration = warmup_duration,
      posterior_duration = posterior_duration,
      thinning_posterior = thinning_posterior
  )

  # Also store the fitted mean and scale of y
  eb_sample.positions_included = ["mu_of_y", "scale_of_y_transformed"]

  engine = eb_sample.build()

  return engine

In [None]:
def test_sampling(num_sim, n, M, c_u, n_knots, x_sampling=False, x_naive=False):
    """
    Run a simulation study: generate data, fit the spline model, collect summaries.

    For every simulation a fresh data set is generated, a model is built with
    ``create_model``, sampled with the engine from ``engine_builder``, and the
    posterior means of the quantities of interest are stored.

    Parameters:
    - num_sim: Integer, number of simulations to run.
    - n: Integer, number of data points per simulation.
    - M: Integer, number of replicates per data point (see generate_gaussian_data).
    - c_u: Covariance parameter between replicates (see generate_gaussian_data).
    - n_knots: Integer, passed to create_model as the number of spline
      coefficients per predictor (``n_param_splines``).
    - x_sampling: Boolean, if True, the covariate is treated as latent and sampled.
    - x_naive: Boolean, if True, the mean of x_tilde is used as the covariate
      instead of the true x values. Has no effect on the fit when x_sampling
      is True, because the covariate is sampled in that case.

    Returns:
    - Dictionary with the keys
      "mu_coef_results", "scale_coef_results" (arrays of shape (n_knots, num_sim)),
      "tau2_mu_results", "tau_scale_results" (lists of length num_sim),
      "mu_mean", "sigma_mean", "x_values_generated" (arrays of shape (n, num_sim)),
      and, if x_sampling is True, additionally "tau2_x_results", "mu_x_results"
      (lists) and "sampled_x_values" (array of shape (n, num_sim)).
      Note that "x_values_generated" holds the replicate means of x_tilde.
    """

    np.random.seed(42)  # Set random seed for reproducibility

    # Draw all simulation seeds up front; generate_gaussian_data reseeds the
    # global NumPy state in every iteration.
    seeds_array = np.random.randint(low=2, high=1000, size=num_sim)

    # Initialize result storage for different computed parameters
    mu_coef_results = np.zeros((n_knots, num_sim))  # Posterior mean coefficients for mu
    scale_coef_results = np.zeros((n_knots, num_sim))  # Posterior mean coefficients for scale
    mu_mean_results = np.zeros((n, num_sim))  # Posterior mean of mu_of_y
    sigma_mean_results = np.zeros((n, num_sim))  # Posterior mean of the scale of y

    tau2_mu_results = []  # Posterior means of tau2_mu
    tau_scale_results = []  # Posterior means of tau_scale
    x_values_generated = np.zeros((n, num_sim))  # Replicate means of x_tilde

    # If x_sampling is enabled, initialize additional result storage
    if x_sampling:
        x_values = np.zeros((n, num_sim))  # Posterior means of the sampled x
        mu_x_results = []  # Posterior means of mu_x
        tau2_x_results = []  # Posterior means of tau2_x

    # Loop over simulations
    for i in range(num_sim):
        # Generate Gaussian data with given parameters and seed
        y, x, x_tilde, sigma, _ = generate_gaussian_data(n=n, seed=seeds_array[i], M=M, c_u=c_u)

        x_tilde_mean = x_tilde.mean(axis=1)

        # Naive approach: plug in the replicate mean instead of the true x
        model_x = jnp.expand_dims(x_tilde_mean, -1) if x_naive else x

        # x_tilde is always passed: create_model needs it whenever
        # sample_x is True and ignores it otherwise.
        model = create_model(
            x=model_x,
            y=y,
            sigma=sigma,
            n_param_splines=n_knots,
            x_tilde=x_tilde,
            sample_x=x_sampling
        )
        x_values_generated[:, i] = x_tilde_mean  # Store mean x_tilde values

        # Build engine and run sampling
        engine = engine_builder(model, x_sample=x_sampling)
        engine.sample_all_epochs()

        # Extract results from the engine
        sampling_results = engine.get_results()
        samples = sampling_results.get_posterior_samples()
        posterior_means = gs.Summary(sampling_results).quantities["mean"]

        # Store spline coefficient results
        mu_coef_results[:, i] = posterior_means["coef_mu"]
        scale_coef_results[:, i] = posterior_means["coef_scale"]

        # Store posterior mean estimates (averaged over chains and iterations)
        mu_mean_results[:, i] = samples["mu_of_y"].mean(axis=(0, 1))
        sigma_mean_results[:, i] = samples["scale_of_y_transformed"].mean(axis=(0, 1))

        # Store tau2_mu and tau_scale results
        tau2_mu_results.append(posterior_means["tau2_mu"])
        tau_scale_results.append(posterior_means["tau_scale"])

        # If x_sampling is enabled, store additional results
        if x_sampling:
            mu_x_results.append(posterior_means["mu_x"])
            tau2_x_results.append(posterior_means["tau2_x"])
            x_values[:, i] = posterior_means["x_estimated"].squeeze()

    # Compile results into a dictionary
    results = {
        "mu_coef_results": mu_coef_results,
        "scale_coef_results": scale_coef_results,
        "tau2_mu_results": tau2_mu_results,
        "tau_scale_results": tau_scale_results,
        "mu_mean": mu_mean_results,
        "sigma_mean": sigma_mean_results,
        "x_values_generated": x_values_generated
    }

    # If x_sampling is enabled, add x-related results
    if x_sampling:
        results.update({
            "tau2_x_results": tau2_x_results,
            "mu_x_results": mu_x_results,
            "sampled_x_values": x_values
        })

    return results
