# In[ ]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataloader import load_data
import GPVarInf as GPVI
from scipy.stats import norm
import pandas as pd

# Restore the trained model checkpoint. The 'params' entry is a 0-d object
# array wrapping a dict (unwrapped with .item()), which is why allow_pickle is
# needed; unpickling can execute code, so only load checkpoints you trust.
# NOTE(review): np.load on an .npz returns a lazily-read NpzFile whose file
# handle stays open while model_data is alive; it is left open here in case
# later cells still index model_data.
model_path = "model_latest.npz"
model_data = np.load(model_path, allow_pickle=True)
params = {
    name: jnp.array(value)
    for name, value in model_data['params'].item().items()
}

# Dataset as returned by dataloader.load_data().
A, X_cov, condition_list = load_data()

# Rebuild the variational GP around the saved inducing points. The inducing
# point count is read off the leading axis of the saved variational mean.
n_inducing = params['q_mu'].shape[0]
model = GPVI.VariationalGP(
    inducing_points=model_data['inducing_points'],
    condition_list=condition_list,
    num_inducing=n_inducing,
)
model.set_params(params)

# Compute posterior mean and variance
def get_posterior_predictive(model, X, kernel_diag=None):
    """Return the marginal posterior of the latent GP at the rows of ``X``.

    Uses the non-whitened sparse variational parameterisation read off the
    model: ``q(u) = N(q_mu, S)`` with ``S = q_sqrt @ q_sqrt.T`` at the
    inducing points. With ``A = K_nm K_mm^{-1}``::

        mean = A @ q_mu
        var  = k(x, x) - diag(A K_mn) + diag(A S A^T)

    Args:
        model: object exposing ``compute_Kmm``, ``compute_Knm``, the kernel
            hyperparameters ``log_lengthscale`` / ``log_variance`` /
            ``log_scale`` and the variational parameters ``q_mu`` / ``q_sqrt``.
        X: test inputs, forwarded unchanged to ``model.compute_Knm``.
        kernel_diag: prior variance ``k(x, x)`` at each row of ``X`` (a scalar
            or an array of shape ``(N,)``). It is needed for the full posterior
            variance and cannot be derived from the two kernel methods used
            here.

    Returns:
        ``(f_mean, f_var)``: 1-D arrays with one entry per row of ``K_nm``.
        If ``kernel_diag`` is None, ``f_var`` is the legacy quantity
        ``diag(A K_mn) + diag(A S A^T)`` and a warning is emitted, because
        that is not the posterior variance.
    """
    K_mm = model.compute_Kmm(
        model.log_lengthscale,
        model.log_variance,
        model.log_scale
    )
    K_nm = model.compute_Knm(
        X,
        model.log_lengthscale,
        model.log_variance,
        model.log_scale
    )

    # A = K_nm K_mm^{-1}. K_mm is symmetric, so this equals (K_mm^{-1} K_mn)^T,
    # and one solve serves the mean and both variance terms.
    A = jnp.linalg.solve(K_mm, K_nm.T).T

    # Posterior mean
    f_mean = A @ model.q_mu

    # diag(K_nm K_mm^{-1} K_mn): prior variance captured by the inducing points.
    explained = jnp.sum(A * K_nm, axis=1)

    # diag(A S A^T) with S = q_sqrt q_sqrt^T. Name the product before squaring:
    # `@` and `*` share precedence, so `X @ Y * X @ Y` would parse as
    # `((X @ Y) * X) @ Y` rather than an element-wise square.
    AL = A @ model.q_sqrt
    variational = jnp.sum(AL * AL, axis=1)

    if kernel_diag is None:
        import warnings
        warnings.warn(
            "get_posterior_predictive: kernel_diag not given; returning "
            "diag(Q_nn) + diag(A S A^T), which is not the posterior variance "
            "(that is k(x, x) - diag(Q_nn) + diag(A S A^T)). Pass "
            "kernel_diag=k(x, x) for calibrated uncertainties.",
            stacklevel=2,
        )
        f_var = explained + variational
    else:
        f_var = jnp.asarray(kernel_diag) - explained + variational

    return f_mean.flatten(), f_var

# Evaluate the variational posterior at every row of X_cov.
f_mean, f_var = get_posterior_predictive(model, X_cov)

# One error-bar series per distinct value in column 0 of X_cov, plotted
# against column 1. Bars span +/- 2 posterior standard deviations (about 95.4%
# of a Gaussian marginal), which is what the "95%" in the title refers to.
errorbar_style = {'fmt': 'o', 'alpha': 0.5}
plt.figure(figsize=(15, 6))
unique_conditions = np.unique(X_cov[:, 0])
for condition in unique_conditions:
    mask = X_cov[:, 0] == condition
    plt.errorbar(
        X_cov[mask, 1],
        f_mean[mask],
        yerr=2 * np.sqrt(f_var[mask]),
        label=f'Condition {condition}',
        **errorbar_style,
    )
cond_ax = plt.gca()
cond_ax.set_xlabel('X coordinate')
cond_ax.set_ylabel('Predicted value')
cond_ax.set_title('Posterior Predictions with 95% Confidence Intervals')
# Anchor the legend just outside the right edge so it does not cover points.
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# Per-point summary: the three X_cov columns next to the posterior mean and
# posterior standard deviation.
uncertainty_df = pd.DataFrame(
    dict(
        Condition=X_cov[:, 0],
        X=X_cov[:, 1],
        Y=X_cov[:, 2],
        Mean=f_mean,
        Std=np.sqrt(f_var),
    )
)

# Average the std over all rows sharing an (X, Y) location (this pools the
# conditions) and draw it as a heatmap: rows are Y values, columns X values.
plt.figure(figsize=(12, 8))
pivot_table = pd.pivot_table(
    uncertainty_df,
    values='Std',
    index='Y',
    columns='X',
    aggfunc='mean',
)
heat_ax = sns.heatmap(pivot_table, cmap='viridis')
heat_ax.set_title('Model Uncertainty across Space')
heat_ax.set_xlabel('X coordinate')
heat_ax.set_ylabel('Y coordinate')
plt.show()

# Parity plot: observed values against posterior means, plus a dashed y = x
# reference line spanning the observed range.
# NOTE(review): assumes A.flatten() yields one value per row of X_cov, in the
# same order as f_mean; matplotlib rejects mismatched lengths otherwise.
truth_flat = A.flatten()
truth_lo, truth_hi = A.min(), A.max()
plt.figure(figsize=(10, 6))
plt.scatter(truth_flat, f_mean, alpha=0.5)
plt.plot([truth_lo, truth_hi], [truth_lo, truth_hi], 'r--')
plt.xlabel('True Values')
plt.ylabel('Predicted Values')
plt.title('Prediction vs Truth')
plt.show()

# Report the kernel hyperparameters on their natural (exponentiated) scale.
# Only the first element of each stored log-parameter array is shown.
print("Model Parameters:")
for display_name, param_key in (
    ("Lengthscale", "log_lengthscale"),
    ("Variance", "log_variance"),
    ("Scale", "log_scale"),
):
    print(f"{display_name}: {np.exp(params[param_key][0]):.3f}")