# World Model Core
*Sean Steinle, Kiya Aminfar*

This notebook walks through the core aspects of world models, developing crucial pieces of code sequentially. Not that this code isn't meant for scale -- instead, this is for a demonstration of how we developed the code that we did.

## Table of Contents
1. [Collecting Rollout Data](#Collecting-Rollout-Data)
2. [Training the VAE](#Training-the-VAE)
3. [Training the MDN-RNN](#Training-the-MDN-RNN)
    - [Prepping Rollout Data for the MDN-RNN](#Prepping-Rollout-Data-for-the-MDN-RNN)
    - [Core Training](#Core-Training)
4. [Training the Controller](#Training-the-Controller)
5. [Early Results](#Early-Results)

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import os
import numpy as np

## Collecting Rollout Data

In [2]:
#Let's begin by creating an instance of our humanoid environment and checking out what basic observations look like.
env = gym.make('Humanoid-v5', render_mode="rgb_array")
obs, info = env.reset()

In [3]:
obs.shape, obs

((348,),
 array([ 1.39268283e+00,  9.91937144e-01,  9.44525677e-03,  3.05778112e-03,
         4.25122819e-03,  3.89551230e-03, -6.64649138e-03, -3.33043362e-03,
         8.83578660e-03, -7.65481172e-03, -3.32984040e-03,  7.91978014e-03,
        -7.43426079e-03,  7.32167683e-03, -9.19272568e-03, -8.97806422e-04,
         2.29475980e-03, -2.74734731e-03, -6.14699041e-03,  1.34900528e-03,
         3.89232034e-03, -9.94205042e-04,  7.11374700e-03,  3.53623743e-03,
         9.81995706e-03,  3.03830396e-03, -9.44202346e-03,  9.03580781e-03,
        -1.21417030e-04,  5.73297512e-03, -9.94344368e-03,  3.64776655e-03,
         2.91604003e-03, -6.41789646e-03, -9.80838337e-03,  7.64872058e-03,
        -5.22639824e-03,  7.72532911e-03,  7.37694688e-03, -3.57145284e-03,
        -7.54092664e-03, -2.81269564e-03,  2.96084784e-03,  2.70164935e-03,
        -9.19551159e-03,  2.30438464e+00,  2.28720705e+00,  4.50605807e-02,
        -1.36298659e-03,  7.36661043e-02,  4.60364686e-02, -1.59461382e-01,
   

In [4]:
info

{'x_position': np.float64(0.005202448265680894),
 'y_position': np.float64(-0.0048011356635689765),
 'tendon_length': array([0.00829492, 0.01124962]),
 'tendon_velocity': array([-0.00034838, -0.00339049]),
 'distance_from_origin': np.float64(0.007079291745441775)}

As we can see, the humanoid environment gives us a TON of observations! We get dozens of variables representing various positions and velocities of body parts, the center of mass, and a lot of other variables I hardly understand. For an exhaustive list, see the [doc](https://gymnasium.farama.org/environments/mujoco/humanoid/#observation-space). The fact that there are so many variables here is what makes learning latent observations so obvious!

We also get some nice summary stats in info, but we aren't going to include them in our scrape.

In [5]:
def collect_rollout_data(env_name: str, out_dir: str, n_timesteps: int=10000, print_n_episodes: int=1000):
    """Simulates `n_timesteps` in the `env_name` environment, saving observations, rewards, and actions to a triplet of .npy files at `out_dir`."""
    env = gym.make(env_name, render_mode='rgb_array')
    obs, info = env.reset()
    observations, rewards, actions, done = [], [] , [], []
    episode_count = 0

    for timestep in range(n_timesteps):  # Run for n_timesteps or until the episode ends
        action = env.action_space.sample() #select random action
        obs, reward, terminated, truncated, info = env.step(action) #execute and get results
        observations.append(obs) #save observation
        rewards.append(reward) #save reward
        actions.append(action) #save action
        done.append(terminated or truncated) #save timestep of each episode's boundary
        if terminated or truncated: #check for game over, if so reset env
            episode_count+=1
            if episode_count % print_n_episodes == 0: print(f"finished {episode_count} episodes") #provide update on training
            observation, info = env.reset()
        env.close()
    np_obs, np_rewards, np_actions, np_done = np.array(observations), np.array(rewards), np.array(actions), np.array(done)
    print(f"observations has shape: {np_obs.shape}\trewards has shape: {np_rewards.shape}\tactions has shape: {np_actions.shape}\tdone has shape: {np_done.shape}")
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_observations.npy', np_obs) #load with: new_obs = np.load("../data/processed/Humanoid-v5_10000_rollout_observations.npy")
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_rewards.npy', np_rewards)
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_actions.npy', np_actions)
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_done.npy', np_done)
    return np_obs, np_rewards, np_actions, np_done

In [6]:
humanoid_obs, humanoid_rewards, humanoid_actions, humanoid_done = collect_rollout_data('Humanoid-v5', "../data/processed", 10000, 100)

finished 100 episodes
finished 200 episodes
finished 300 episodes
finished 400 episodes
observations has shape: (10000, 348)	rewards has shape: (10000,)	actions has shape: (10000, 17)	done has shape: (10000,)


In [7]:
humanoid_rewards

array([4.89290012, 4.92391502, 4.90372702, ..., 4.93088997, 4.99549046,
       4.98390821])

In [8]:
humanoid_actions

array([[ 0.23505302,  0.36002648, -0.29734674, ..., -0.05805098,
         0.24405803,  0.2011473 ],
       [ 0.38455474, -0.3011133 , -0.05101318, ..., -0.0141068 ,
        -0.16276172, -0.0362853 ],
       [ 0.29848835, -0.03167867,  0.21788624, ..., -0.36390272,
         0.21806365, -0.21161628],
       ...,
       [-0.04431577,  0.19833574, -0.29948762, ..., -0.01792738,
         0.31818205, -0.14132701],
       [-0.2908491 ,  0.21173128,  0.33720645, ...,  0.00166669,
         0.01706297, -0.27317977],
       [ 0.00698648,  0.09699604, -0.19836254, ...,  0.3916585 ,
         0.35469708, -0.10838389]], dtype=float32)

In [9]:
humanoid_done, humanoid_done.sum()

(array([False, False, False, ..., False, False, False]), np.int64(412))

## Training the VAE

Now that we have an easy function for gathering experiences in the environment, we need to train the VAE module of our world model which will compress the observation space into latent space with fewer dimensions.

Note that the original World Model implementation worked with tensorflow 1.18.0. This is incredibly outdated (worked with Python 3.5), so let's get a newer version (tensorflow 2.19.0). Additionally, we need to change the structure of the VAE from working with images to working with a vector of observation data! Luckily, ChatGPT is very good at updating code (or it will be very obvious if it is not!).

In [10]:
import tensorflow as tf
from tensorflow.keras import layers, Model, saving

@saving.register_keras_serializable()
class MLPVAE(Model):
    def __init__(self, input_dim=348, z_size=32, kl_tolerance=0.5):
        super(MLPVAE, self).__init__()
        self.z_size = z_size
        self.kl_tolerance = kl_tolerance

        # Encoder
        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(input_dim,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(2 * z_size),  # output both mu and logvar
        ])

        # Decoder
        self.decoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(z_size,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(256, activation='relu'),
            layers.Dense(input_dim, activation='linear'),  # output same shape as input
        ])

    def sample_z(self, mu, logvar):
        eps = tf.random.normal(shape=tf.shape(mu))
        sigma = tf.exp(0.5 * logvar)
        return mu + sigma * eps

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = tf.split(h, num_or_size_splits=2, axis=1)
        logvar = tf.clip_by_value(logvar, -10.0, 10.0)  # helps with exploding values
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def call(self, x):
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def compute_loss(self, x):
        x_recon, mu, logvar = self(x)
        recon_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x - x_recon), axis=1))
        kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
        kl_loss = tf.maximum(kl_loss, self.kl_tolerance * self.z_size)
        kl_loss = tf.reduce_mean(kl_loss)
        total_loss = recon_loss + kl_loss
        return total_loss, recon_loss, kl_loss

2025-05-10 10:06:02.310980: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-10 10:06:02.330789: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746885962.350748    6812 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746885962.355915    6812 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746885962.371254    6812 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Let's also write a training function for convenience.

In [11]:
def create_dataset(x_train, batch_size=64, shuffle_buffer=10000):
    # Assuming x_train is a NumPy array of shape [n_samples, 348]
    dataset = tf.data.Dataset.from_tensor_slices(x_train.astype(np.float32))
    dataset = dataset.shuffle(shuffle_buffer).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

def train_vae(model, dataset, epochs=10, learning_rate=1e-4):
    optimizer = tf.keras.optimizers.Adam(learning_rate)

    for epoch in range(epochs):
        total_loss = 0.0
        total_batches = 0
        for x_batch in dataset:
            with tf.GradientTape() as tape:
                loss, recon_loss, kl_loss = model.compute_loss(x_batch)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            total_loss += loss.numpy()
            total_batches += 1

        avg_loss = total_loss / total_batches
        print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")


In [12]:
# x_train should be a NumPy array of shape (n_samples, 348)
x_train = humanoid_obs
x_train = (x_train - np.mean(x_train, axis=0)) / (np.std(x_train, axis=0) + 1e-6)

dataset = create_dataset(humanoid_obs, batch_size=64)
vae = MLPVAE(input_dim=348, z_size=32)
train_vae(vae, dataset, epochs=20)

2025-05-10 10:06:08.435728: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
2025-05-10 10:06:25.713385: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1: avg loss = 263565.5625


2025-05-10 10:06:47.085948: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2: avg loss = 125894.2422
Epoch 3: avg loss = 68511.4688


2025-05-10 10:07:18.918455: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 4: avg loss = 58257.0000
Epoch 5: avg loss = 52543.2148
Epoch 6: avg loss = 46821.2383
Epoch 7: avg loss = 39593.5703


2025-05-10 10:08:16.371272: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 8: avg loss = 33299.0352
Epoch 9: avg loss = 28569.9590
Epoch 10: avg loss = 24761.4082
Epoch 11: avg loss = 22122.9727
Epoch 12: avg loss = 20141.1992
Epoch 13: avg loss = 18348.6875
Epoch 14: avg loss = 16714.6973
Epoch 15: avg loss = 15301.8691


2025-05-10 10:10:20.412629: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 16: avg loss = 14170.6035
Epoch 17: avg loss = 13216.3174
Epoch 18: avg loss = 12370.6289
Epoch 19: avg loss = 11716.9619
Epoch 20: avg loss = 11140.0439


Loss is going down! At first I got a tons of NAN values, but it's because I wasn't normalizing the input data and I also needed to clip the logvar values we were getting as a result of the encoding process. If you get NANs again, a lower learning rate could help too. Onto saving the model!

In [13]:
vae.save_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #save ONLY weights -- much simpler than serializing the entire object

In [14]:
new_vae = MLPVAE(input_dim=348, z_size=32) #instantiate new model object 
new_vae(tf.zeros((1, 348))) #invoke it to build its shape 
new_vae.load_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #now load weights into empty vector

In [15]:
new_vae

<MLPVAE name=mlpvae_1, built=True>

In [16]:
vae

<MLPVAE name=mlpvae, built=True>

## Training the MDN-RNN

Now that we have a model which captures observations, we're theoretically ~1/3 done with the project! I say theoretically because this was probably the easiest part of the project. Now onto the meat of world models: capturing the transitions of our environment and training the MDN-RNN!

### Prepping Rollout Data for the MDN-RNN

To train the MDN-RNN, we first need to enhance our basic rollout dataset with predictions of `mu` and `logvar` for each experience. Then we'll feed this information to the MDN-RNN.

In [17]:
#We can use the dataset records still in memory. Note that we only need the observations for now!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, humanoid_done.shape

((10000, 348), (10000,), (10000, 17), (10000,))

In [18]:
#We can also use the VAE still in memory!
vae

<MLPVAE name=mlpvae, built=True>

In [19]:
humanoid_obs[0]

array([ 1.40109648e+00,  9.99991324e-01, -2.09374503e-03, -2.97582050e-03,
        2.02814760e-03,  5.64103249e-02,  9.16979137e-03, -5.47811860e-03,
        1.02078401e-02, -8.49019545e-02, -1.71167293e-02, -4.61241440e-02,
       -1.58210413e-02,  3.47086999e-02, -3.93887371e-02, -9.62425597e-02,
        5.10217497e-04,  1.76552213e-02, -8.20650955e-03, -8.42943297e-03,
        1.19607831e-02,  2.47622933e-02, -1.12776856e-01, -1.95614432e-02,
       -1.53352677e-01,  6.11908684e-01, -2.56439194e-01, -1.31310136e+00,
        5.27343031e+00,  2.85599299e-01, -1.63061265e+00,  8.53757914e-01,
       -7.92272959e+00, -1.22734026e+00, -6.21301546e+00, -9.60282040e-01,
        3.39423908e+00, -4.65032510e+00, -1.02764365e+01,  1.05111100e+00,
        1.09217756e+00, -1.74632970e+00, -1.73952058e-02,  2.21747859e+00,
        1.67433277e+00,  2.30091924e+00,  2.28477017e+00,  4.42346651e-02,
        2.78891814e-04,  7.77510624e-02, -6.10290447e-03, -1.64634138e-01,
        1.14942495e-02,  

In [20]:
#That's it! Just predict for our entire observation set.
mu, logvar = vae.encode(humanoid_obs)
humanoid_z = vae.sample_z(mu, logvar).numpy()
humanoid_z.shape

(10000, 32)

In [21]:
np.save(f'../data/processed/Humanoid-v5_10000_rollout_z.npy', humanoid_z)

In [22]:
#We can use the dataset records still in memory. Note that we only need the observations for now!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, humanoid_done.shape, humanoid_z.shape

((10000, 348), (10000,), (10000, 17), (10000,), (10000, 32))

### Core Training

Now that we have our mu and logvar arrays and also episode-wise aggregations of all of our data, we can train our MDN-RNN!

In [23]:
#CORE MODEL CLASS
class MDNRNN(tf.keras.Model):
    def __init__(self, latent_dim, action_dim, hidden_dim=256, num_mixtures=5):
        super().__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.input_dim = latent_dim + action_dim
        self.hidden_dim = hidden_dim
        self.num_mixtures = num_mixtures

        # LSTM
        self.lstm = layers.LSTM(hidden_dim, return_sequences=True, return_state=True)

        # MDN output: means, stddevs, and mixture weights for latent prediction
        self.mdn_dense = layers.Dense(num_mixtures * (2 * latent_dim + 1))

        # Predict reward (scalar)
        self.reward_dense = layers.Dense(1)

        # Predict done (binary classification)
        self.done_dense = layers.Dense(1, activation="sigmoid")

    def call(self, inputs, initial_state=None, training=False):
        """
        inputs: (batch, seq_len, latent_dim + action_dim)
        """
        lstm_out, h, c = self.lstm(inputs, initial_state=initial_state, training=training)

        mdn_out = self.mdn_dense(lstm_out)
        reward_pred = self.reward_dense(lstm_out)
        done_pred = self.done_dense(lstm_out)

        return mdn_out, reward_pred, done_pred, [h, c]

    def get_mdn_params(self, mdn_out):
        """Split MDN output into pi, mu, sigma."""
        out = tf.reshape(mdn_out, [-1, self.num_mixtures, 2 * self.latent_dim + 1])
        pi = out[:, :, 0]
        mu = out[:, :, 1 : 1 + self.latent_dim]
        log_sigma = out[:, :, 1 + self.latent_dim :]
        sigma = tf.exp(log_sigma)

        pi = tf.nn.softmax(pi, axis=-1)  # mixture weights
        return pi, mu, sigma


In [24]:
#LOSS IMPLEMENTATION
import tensorflow_probability as tfp

def mdn_loss(z_target, pi, mu, sigma, eps=1e-8):
    """
    z_target: [batch * seq_len, latent_dim]
    pi: [batch * seq_len, num_mixtures]
    mu: [batch * seq_len, num_mixtures, latent_dim]
    sigma: [batch * seq_len, num_mixtures, latent_dim]
    """
    # Expand target for broadcasting: [batch, 1, latent_dim]
    z_expanded = tf.expand_dims(z_target, axis=1)

    # Create component Gaussians
    normal_dist = tfp.distributions.Normal(loc=mu, scale=sigma)
    log_probs = normal_dist.log_prob(z_expanded)  # shape: [batch, num_mixtures, latent_dim]

    # Sum over latent_dim: total log prob of each mixture component
    log_probs = tf.reduce_sum(log_probs, axis=-1)  # shape: [batch, num_mixtures]

    # Weight by mixture coefficients
    weighted_log_probs = log_probs + tf.math.log(pi + eps)  # log(pi * P)
    
    # LogSumExp over mixture components to marginalize
    log_likelihood = tf.reduce_logsumexp(weighted_log_probs, axis=-1)  # shape: [batch]

    # Negative log-likelihood
    return -tf.reduce_mean(log_likelihood)

def combined_loss(z_target, pi, mu, sigma, reward_target, reward_pred, done_target, done_pred,
                  reward_weight=1.0, done_weight=1.0):
    loss_mdn = mdn_loss(z_target, pi, mu, sigma)
    loss_reward = tf.reduce_mean(tf.square(reward_target - reward_pred))
    loss_done = tf.reduce_mean(tf.keras.losses.binary_crossentropy(done_target, done_pred))

    return loss_mdn + reward_weight * loss_reward + done_weight * loss_done


In [25]:
#TRAINING FUNCTION
def train_mdnrnn(model, dataset, epochs=10, learning_rate=1e-4, reward_weight=1.0, done_weight=1.0):
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    
    for epoch in range(epochs):
        total_loss = 0.0
        total_mdn_loss = 0.0
        total_reward_loss = 0.0
        total_done_loss = 0.0
        total_batches = 0

        for z_action, (z_next, reward, done) in dataset:
            # Flatten z_action into shape (batch, seq_len, latent_dim + action_dim)
            with tf.GradientTape() as tape:
                # Forward pass through the model
                mdn_out, reward_pred, done_pred, _ = model(z_action, training=True)
                pi, mu, sigma = model.get_mdn_params(mdn_out)

                # Compute MDN, reward, and done losses
                loss_mdn = mdn_loss(tf.reshape(z_next, [-1, model.latent_dim]), pi, mu, sigma)
                loss_reward = tf.reduce_mean(tf.square(reward - reward_pred))
                loss_done = tf.reduce_mean(tf.keras.losses.binary_crossentropy(done, done_pred))

                # Total loss
                total_loss = loss_mdn + reward_weight * loss_reward + done_weight * loss_done

            # Compute gradients and apply
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Accumulate the loss values
            total_loss += total_loss.numpy()
            total_mdn_loss += loss_mdn.numpy()
            total_reward_loss += loss_reward.numpy()
            total_done_loss += loss_done.numpy()

            total_batches += 1

        # Compute and print average losses for the epoch
        avg_loss = total_loss / total_batches
        avg_mdn_loss = total_mdn_loss / total_batches
        avg_reward_loss = total_reward_loss / total_batches
        avg_done_loss = total_done_loss / total_batches

        print(f"Epoch {epoch+1}:")
        print(f"  avg loss = {avg_loss:.4f}, mdn_loss = {avg_mdn_loss:.4f}, "
              f"reward_loss = {avg_reward_loss:.4f}, done_loss = {avg_done_loss:.4f}")


In [26]:
#DATA HANDLING
def preprocess_data(observations, actions, rewards, dones, sequence_length):
    """
    Yields tuples of the form:
    z_action: (sequence_length, latent_dim + action_dim)
    targets:  (z_next, reward, done), each of shape (sequence_length, ...)
    """
    for i in range(len(observations) - sequence_length):
        z_seq = observations[i:i+sequence_length]
        a_seq = actions[i:i+sequence_length]
        z_action = np.concatenate([z_seq, a_seq], axis=-1).astype(np.float32)

        z_next = observations[i+1:i+sequence_length+1].astype(np.float32)
        reward = rewards[i+1:i+sequence_length+1].astype(np.float32)
        done = dones[i+1:i+sequence_length+1].astype(np.float32)

        yield (
            z_action.astype(np.float32),
            (
                z_next.astype(np.float32),
                reward[:, None].astype(np.float32),  # <–– reshape from (T,) to (T,1)
                done[:, None].astype(np.float32)     # <–– reshape from (T,) to (T,1)
            )
        )


sequence_length = 10
latent_dim = 32
action_dim = 17
batch_size = 32

train_dataset = tf.data.Dataset.from_generator(
    lambda: preprocess_data(humanoid_z, humanoid_actions, humanoid_rewards, humanoid_done, sequence_length),
    output_signature=(
        tf.TensorSpec(shape=(sequence_length, latent_dim + action_dim), dtype=tf.float32),
        (
            tf.TensorSpec(shape=(sequence_length, latent_dim), dtype=tf.float32),
            tf.TensorSpec(shape=(sequence_length, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(sequence_length, 1), dtype=tf.float32),
        )
    )
)

train_dataset = train_dataset.batch(batch_size).shuffle(1000)

In [27]:
mdnrnn = MDNRNN(latent_dim=32, action_dim=17)
train_mdnrnn(mdnrnn, train_dataset)

2025-05-10 10:11:47.810847: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 241 of 1000
2025-05-10 10:11:50.953005: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 1:
  avg loss = 1.7256, mdn_loss = 1028.9547, reward_loss = 16.5348, done_loss = 0.5429


2025-05-10 10:13:31.184300: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 245 of 1000
2025-05-10 10:13:33.761952: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 2:
  avg loss = 1.1448, mdn_loss = 197.1293, reward_loss = 3.5318, done_loss = 0.2056


2025-05-10 10:15:09.544208: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 266 of 1000
2025-05-10 10:15:11.356743: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 3:
  avg loss = 1.1590, mdn_loss = 159.1100, reward_loss = 2.4223, done_loss = 0.1875


2025-05-10 10:16:52.642470: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 240 of 1000
2025-05-10 10:16:56.002466: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 4:
  avg loss = 0.9280, mdn_loss = 146.8151, reward_loss = 2.0830, done_loss = 0.1833


2025-05-10 10:18:37.506752: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 276 of 1000
2025-05-10 10:18:38.968064: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 5:
  avg loss = 0.8064, mdn_loss = 137.5134, reward_loss = 1.8499, done_loss = 0.1804


2025-05-10 10:20:23.664826: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 268 of 1000
2025-05-10 10:20:25.566168: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 6:
  avg loss = 0.8245, mdn_loss = 124.8104, reward_loss = 1.7106, done_loss = 0.1792


2025-05-10 10:22:00.674388: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 267 of 1000
2025-05-10 10:22:02.464169: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 7:
  avg loss = 0.6862, mdn_loss = 120.5293, reward_loss = 1.6048, done_loss = 0.1779


2025-05-10 10:23:37.590231: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 257 of 1000
2025-05-10 10:23:39.789504: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 8:
  avg loss = 0.7655, mdn_loss = 118.5863, reward_loss = 1.5204, done_loss = 0.1766


2025-05-10 10:25:09.073064: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 289 of 1000
2025-05-10 10:25:09.994345: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 9:
  avg loss = 0.6971, mdn_loss = 117.2180, reward_loss = 1.4526, done_loss = 0.1755


2025-05-10 10:26:32.048261: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 293 of 1000
2025-05-10 10:26:32.777614: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 10:
  avg loss = 0.7036, mdn_loss = 116.1478, reward_loss = 1.4004, done_loss = 0.1745


## Training the Controller

Now that we have a VAE and an MDN-RNN that are (presumably) learning, let's write the final piece of the puzzle -- the MDN-RNN!



In [56]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, saving
import gymnasium as gym

# --- Controller model ---
@saving.register_keras_serializable()
class LinearController(Model):
    def __init__(self, input_dim, action_dim):
        super(LinearController, self).__init__()
        self.linear = layers.Dense(action_dim, activation='tanh', use_bias=False)
        self.linear(tf.zeros((1, input_dim)))  # force build so weights are initialized

    def call(self, inputs):
        return self.linear(inputs)

    def get_weights_flat(self):
        return tf.reshape(self.linear.kernel, [-1])

    def set_weights_flat(self, flat_weights):
        new_weights = tf.reshape(flat_weights, self.linear.kernel.shape)
        self.linear.kernel.assign(new_weights)

# --- Evolution strategy ---
def evolve_controller(controller, vae, mdn_rnn, env, generations=10, pop_size=64, sigma=0.1, elite_frac=0.2):
    input_dim = controller.linear.kernel.shape[0]
    action_dim = controller.linear.kernel.shape[1]
    weight_dim = input_dim * action_dim

    base_weights = controller.get_weights_flat().numpy()
    elite_num = max(1, int(pop_size * elite_frac))

    for gen in range(generations):
        population = [base_weights + sigma * np.random.randn(weight_dim) for _ in range(pop_size)]
        scores = []

        for i, individual in enumerate(population):
            controller.set_weights_flat(individual)
            reward = evaluate_controller(controller, vae, mdn_rnn, env)
            scores.append((reward, individual))

        scores.sort(key=lambda x: -x[0])
        elites = [w for _, w in scores[:elite_num]]
        new_mean = np.mean(elites, axis=0)
        base_weights = new_mean

        print(f"Gen {gen+1}: Best score = {scores[0][0]:.2f}")

    controller.set_weights_flat(base_weights)
    return controller

In [67]:
def evaluate_controller(controller, vae, mdn_rnn, env, max_steps=1000):
    total_reward = 0
    obs, _ = env.reset()
    done = False
    h = tf.zeros((1, 256))  # RNN hidden state (adjust as needed)

    state = [h, h]  # Initialize state for MDN-RNN (h, c)
    for step in range(max_steps):
        x = tf.convert_to_tensor(obs[None, :], dtype=tf.float32)
        mu, _ = vae.encode(x)
        z = mu  # Use the mean (mu) as the latent vector

        zh = tf.concat([z, h], axis=1)  # Concatenate latent vector and hidden state
        action = controller(zh).numpy()[0]  # Get action from controller

        # Step in the environment
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Update hidden state using the MDN-RNN
        rnn_input = tf.concat([z, action[None]], axis=1)  # (latent_dim + action_dim)
        rnn_input = tf.expand_dims(rnn_input, axis=1)  # Shape becomes (1, 1, 50)

        mdn_out, reward_pred, done_pred, state = mdn_rnn(rnn_input, initial_state=state)

        total_reward += reward

        # Optionally: You can use reward_pred and done_pred for monitoring

        if done:
            break

    return total_reward

In [None]:
z_dim = 32
h_dim = 256
action_dim = env.action_space.shape[0]

controller = LinearController(input_dim=z_dim + h_dim, action_dim=action_dim)

# Assume `vae` and `mdn_rnn` are your pretrained models
trained_controller = evolve_controller(controller, vae, mdnrnn, env)

So right now, controller with JUST VAE is working. Also, we're missing the CMA-ES piece! We just do random search optimization right now.