In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad, random
from functools import partial
from dataclasses import dataclass

from bayes.posterior import FlowBasedPosterior, PRNGKeyManager
from sinterp.interpolants import OneSidedLinear

# Debug toggle: uncomment to run every JAX op eagerly (no JIT compilation).
# jax.config.update('jax_disable_jit', True)
jax.config.update("jax_debug_nans", True)

# 1. Ground-truth tabular RL environment used to simulate the data.
N_STATES = 4
N_ACTIONS = 2
REWARD_NOISE_STD = 0.1  # std-dev of the Gaussian noise on every observed reward

# Seed 0 makes the ground truth reproducible. A single split yields the
# carry key plus one dedicated key for each ground-truth table.
key, T_key, R_key = random.split(random.PRNGKey(0), 3)

# Transition table p(s' | s, a), shape (S, A, S'): Gaussian logits pushed
# through a softmax over the s' axis so every (s, a) row sums to 1.
true_transition_logits = random.normal(T_key, (N_STATES, N_ACTIONS, N_STATES))
true_transition_matrix = jax.nn.softmax(true_transition_logits, axis=2)

# Mean reward table R(s, a), shape (S, A), drawn uniformly from [-5, 5).
true_reward_matrix = random.uniform(
    R_key, shape=(N_STATES, N_ACTIONS), minval=-5, maxval=5
)


# 2. Parameter space of the learned forward model.
REWARD_BOUND = 10.0  # learned rewards are squashed into (-REWARD_BOUND, REWARD_BOUND)

# Flat parameter count: S*A*S' transition logits followed by S*A reward
# parameters, i.e. S*A*(S + 1) values in total.
PARAM_DIM = N_STATES * N_ACTIONS * (N_STATES + 1)

@dataclass(init=True, repr=True, eq=True)
class RLTheta:
    """Decoded parameters of the tabular RL forward model.

    Built from a flat parameter vector by ``vec_to_theta_rl``.
    """

    # Log-probabilities log p(s' | s, a); shape (S, A, S') for one parameter vector.
    log_transition_matrix: jnp.ndarray
    # Expected immediate reward R(s, a); shape (S, A) for one parameter vector.
    reward_matrix: jnp.ndarray

def vec_to_theta_rl(h: jnp.ndarray) -> RLTheta:
    """
    Converts a flat parameter vector `h` into a structured RLTheta object.

    Layout of the last axis of `h` (length PARAM_DIM):
        [0, S*A*S')          transition logits, reshaped to (S, A, S')
        [S*A*S', PARAM_DIM)  unconstrained reward params, reshaped to (S, A)

    Any leading axes of `h` are treated as batch axes, so an array of shape
    (..., PARAM_DIM) yields fields of shape (..., S, A, S') and (..., S, A);
    e.g. a whole batch of posterior samples can be decoded in one call.
    A 1-D `h` gives exactly the unbatched result.

    The transition matrix is normalized with log_softmax over s', and the
    rewards are squashed into (-REWARD_BOUND, REWARD_BOUND) with tanh.
    A last axis of the wrong length fails in `reshape`.
    """
    batch_shape = tuple(h.shape[:-1])
    n_transition = N_STATES * N_ACTIONS * N_STATES

    # 1. Slice the last axis of `h` into its two parameter groups.
    transition_params = h[..., :n_transition].reshape(
        batch_shape + (N_STATES, N_ACTIONS, N_STATES)
    )
    # The reward slice is deliberately open-ended: a too-long `h` then fails
    # in reshape instead of silently dropping trailing entries.
    reward_params = h[..., n_transition:].reshape(batch_shape + (N_STATES, N_ACTIONS))

    # 2. log_softmax over s' is numerically stable and guarantees that, for
    # every (s, a), the probabilities exp(log p(s' | s, a)) sum to 1.
    log_transition_matrix = jax.nn.log_softmax(transition_params, axis=-1)

    # 3. tanh bounds each expected reward to (-REWARD_BOUND, REWARD_BOUND)
    # while keeping the map smooth in the unconstrained parameters.
    reward_matrix = REWARD_BOUND * jnp.tanh(reward_params)

    return RLTheta(log_transition_matrix, reward_matrix)

def forward_model_log_likelihood(theta: RLTheta, observation: jnp.ndarray):
    """
    Log-likelihood log p(s', r | s, a, theta) of a single transition record.

    `observation` is laid out as [s, a, r, s']. The s, a and s' entries are
    cast to integer indices here, because the record may be stored in a
    float dtype. The result is the categorical next-state term
    log p(s' | s, a) plus the Gaussian reward term
    log N(r; R(s, a), REWARD_NOISE_STD^2).

    NOTE(review): indices are not range-checked; JAX indexing clamps
    out-of-range values instead of raising.
    """
    state = observation[0].astype(int)
    action = observation[1].astype(int)
    reward = observation[2]
    next_state = observation[3].astype(int)

    # Categorical next-state term, read straight from the log-transition table.
    transition_term = theta.log_transition_matrix[state, action, next_state]

    # Gaussian reward term with the model's expected reward as its mean and a
    # fixed, known noise scale.
    reward_term = jax.scipy.stats.norm.logpdf(
        reward, loc=theta.reward_matrix[state, action], scale=REWARD_NOISE_STD
    )

    # The two factors multiply in probability space, so their logs add.
    return transition_term + reward_term

def build_total_log_likelihood_and_grad_rl(observations):
    """
    Returns ``(grad_fn, y_data)`` for the RL forward model.

    ``y_data`` is ``observations`` stacked into a single array with
    ``jnp.array``; each observation is expected to be an [s, a, r, s']
    record. ``grad_fn(h, y_data_batch)`` returns the gradient of the summed
    log-likelihood of the records in ``y_data_batch``, taken with respect to
    the flat parameter vector ``h`` (its first argument) only.
    """
    stacked = jnp.array(observations)

    def summed_log_likelihood(h, y_data_batch):
        """Sum of log p(y_i | h) over every record y_i in ``y_data_batch``."""
        params = vec_to_theta_rl(h)
        # Map the single-record likelihood over the leading (record) axis;
        # `params` is closed over rather than batched.
        per_record = vmap(lambda record: forward_model_log_likelihood(params, record))
        return per_record(y_data_batch).sum()

    # grad differentiates w.r.t. argument 0 by default, i.e. `h`.
    return grad(summed_log_likelihood), stacked

# 3. Initialize the Posterior Model
# The posterior gets its own key manager (seed 1). It is independent of the
# seed-0 `key` stream used for the ground truth and the simulation below.
key_manager = PRNGKeyManager(seed=1)
# interpolator = get_interp('OneSidedLinear') # Assuming this is available

# Initialize the posterior over the flat PARAMETER vector `h` of the RL model
# (length PARAM_DIM, decoded by vec_to_theta_rl). The builder function is
# passed in uncalled; presumably the posterior calls it with its accumulated
# observations — TODO confirm in bayes.posterior.
# NOTE(review): `distillation_threshold` is interpreted inside
# FlowBasedPosterior. The recorded output shows distillation starting while
# observations were still being added; confirm the exact trigger there.
posterior = FlowBasedPosterior(
    build_total_log_likelihood_and_grad=build_total_log_likelihood_and_grad_rl,
    dim=PARAM_DIM,
    key_manager=key_manager,
    interpolator=OneSidedLinear(),
    distillation_threshold=100
)

# 4. Simulate and add observations to the model
print("--- Generating and Adding RL Observations ---")
num_observations = 200
for _ in range(num_observations):
    # Split the carried `key` each step. The first output becomes the new
    # carry; each of the other four keys is consumed exactly once below.
    key, s_key, a_key, t_key, r_key = random.split(key, 5)
    # Sample a state and an action uniformly at random
    s = random.randint(s_key, (), 0, N_STATES)
    a = random.randint(a_key, (), 0, N_ACTIONS)
    
    # Sample next state from the true transition model
    s_prime = random.choice(t_key, N_STATES, p=true_transition_matrix[s, a])
    
    # Sample reward from the true reward model: R(s, a) + N(0, REWARD_NOISE_STD^2)
    r = random.normal(r_key) * REWARD_NOISE_STD + true_reward_matrix[s, a]
    
    # Record the tuple as one array [s, a, r, s']. jnp.array promotes the
    # integer entries to r's float dtype; forward_model_log_likelihood casts
    # them back to integer indices.
    posterior.add_observation(jnp.array([s, a, r, s_prime]))

# # 5. Find the MAP estimate of the parameters
# h_map, final_log_prob = find_map_with_overparameterization(
#     posterior=posterior,
#     key_manager=key_manager,
#     num_steps=5000,
#     learning_rate=1e-3
# )


# NOTE: the MAP search above is disabled, so despite its name `h_map` is NOT
# a MAP estimate. It is a single posterior draw: sample shape (1,), first
# element taken. The draw uses the carry `key` left over from the simulation
# loop, which has not been consumed yet.
h_map = posterior.sample(key, (1,))[0]

# 6. Compare the drawn parameters `h_map` with the ground-truth environment.
theta_map_rl = vec_to_theta_rl(h_map)

print("\n--- Verification of Learned RL Model ---")
# print(f"Final Average Log-Likelihood: {final_log_prob:.4f}")

found_reward_matrix = theta_map_rl.reward_matrix
# The model stores log-probabilities; exponentiate to get p(s' | s, a).
found_transition_matrix = jnp.exp(theta_map_rl.log_transition_matrix)

print(
    "\n--- Comparing True vs. Found Transition Matrix ---\n"
    "Note: Values are p(s' | s, a). Should be close if data is sufficient."
)
for s in range(N_STATES):
    for a in range(N_ACTIONS):
        # One print per (s, a) pair; the embedded newlines reproduce the
        # header line followed by the true row and the found row.
        print(
            f"\nState={s}, Action={a}\n"
            f"  True:    {true_transition_matrix[s, a]}\n"
            f"  Found:   {found_transition_matrix[s, a]}"
        )


print(
    "\n--- Comparing True vs. Found Reward Matrix ---\n"
    "Note: Values are E[r | s, a]. Should be close."
)
print("True:\n", true_reward_matrix)
print("Found:\n", found_reward_matrix)

2025-08-05 14:56:03.540229: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-05 14:56:03.540275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-05 14:56:03.541060: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Initialized with a one-sided framework and OneSidedLinear interpolator.
--- Generating and Adding RL Observations ---

--- Distilling Likelihood (Unified Stochastic Score-Matching) ---
Distillation Step 0, Unified Loss: 2.5303


KeyboardInterrupt: 


--- Verification of Learned RL Model ---

--- Comparing True vs. Found Transition Matrix ---
Note: Values are p(s' | s, a). Should be close if data is sufficient.

State=0, Action=0
  True:    [0.04047961 0.06079901 0.5718052  0.3269162 ]
  Found:   [0.22874711 0.09776904 0.12090193 0.55258197]

State=0, Action=1
  True:    [0.19146848 0.12623502 0.13012268 0.5521739 ]
  Found:   [0.18394393 0.12047469 0.12903614 0.56654525]

State=1, Action=0
  True:    [0.02222157 0.693535   0.06817152 0.21607192]
  Found:   [0.20910555 0.01902924 0.71104014 0.06082499]

State=1, Action=1
  True:    [0.09101926 0.1581395  0.23025165 0.5205896 ]
  Found:   [0.20085382 0.10000684 0.61061186 0.08852748]

State=2, Action=0
  True:    [0.17254084 0.5293649  0.21808279 0.08001144]
  Found:   [0.05014689 0.40201223 0.4023899  0.14545095]

State=2, Action=1
  True:    [0.04666347 0.10220193 0.5334075  0.31772703]
  Found:   [0.2986915  0.08569892 0.09943836 0.5161712 ]

State=3, Action=0
  True:    [0.03750