# World Model Core
*Sean Steinle, Kiya Aminfar*

This notebook walks through the core aspects of world models, developing crucial pieces of code sequentially. Not that this code isn't meant for scale -- instead, this is for a demonstration of how we developed the code that we did.

## Table of Contents
1. [Collecting Rollout Data](#Collecting-Rollout-Data)
2. [Training the VAE](#Training-the-VAE)
3. [Training the MDN-RNN](#Training-the-MDN-RNN)
    - [Prepping Rollout Data for the MDN-RNN](#Prepping-Rollout-Data-for-the-MDN-RNN)
    - [Core Training](#Core-Training)
4. [Training the Controller](#Training-the-Controller)
5. [Early Results](#Early-Results)

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import os
import numpy as np

## Collecting Rollout Data

In [2]:
#Let's begin by creating an instance of our humanoid environment and checking out what basic observations look like.
env = gym.make('Humanoid-v5', render_mode="rgb_array")
obs, info = env.reset()

In [3]:
obs.shape, obs

((348,),
 array([ 1.40602619e+00,  9.90106099e-01,  7.21400582e-03, -4.67594716e-03,
         7.54542565e-03,  9.36836000e-03, -8.31537166e-03,  4.05115331e-03,
         4.02881107e-04,  7.90516910e-03, -6.83311819e-03, -9.89276723e-03,
        -6.14061298e-03, -5.75308330e-03, -9.91729475e-03,  6.02789791e-03,
         3.52052936e-05, -4.63146434e-03,  8.84241028e-03, -8.99571457e-03,
         9.77278327e-03, -5.49930428e-03, -1.87761082e-03, -4.74033658e-03,
        -5.01742490e-03,  8.56199520e-05, -1.80819407e-03,  5.26268603e-03,
         5.30459695e-03, -5.30271840e-03, -3.63168874e-03,  8.65317797e-03,
        -2.00539408e-04, -7.97576613e-04, -2.13111953e-03, -6.47277394e-03,
         5.64883141e-03, -3.62262650e-03, -4.14701912e-03,  2.13638882e-03,
         6.35422719e-03, -2.24358261e-03, -5.46226625e-03,  6.86606297e-03,
        -6.97945906e-03,  2.30250734e+00,  2.28859078e+00,  4.79432613e-02,
        -1.77998742e-03,  1.10994596e-01,  4.11934410e-02, -2.31683185e-01,
   

In [4]:
info

{'x_position': np.float64(-0.0005459767665555986),
 'y_position': np.float64(0.0012958979631623128),
 'tendon_length': array([ 0.01594519, -0.00305965]),
 'tendon_velocity': array([-0.00052439, -0.00133354]),
 'distance_from_origin': np.float64(0.00140621554555009)}

As we can see, the humanoid environment gives us a TON of observations! We get dozens of variables representing various positions and velocities of body parts, the center of mass, and a lot of other variables I hardly understand. For an exhaustive list, see the [doc](https://gymnasium.farama.org/environments/mujoco/humanoid/#observation-space). The fact that there are so many variables here is what makes learning latent observations so obvious!

We also get some nice summary stats in info, but we aren't going to include them in our scrape.

In [7]:
def collect_rollout_data(env_name: str, out_dir: str, n_timesteps: int=10000, print_n_episodes: int=1000):
    """Simulates `n_timesteps` in the `env_name` environment, saving observations, rewards, and actions to a triplet of .npy files at `out_dir`."""
    env = gym.make(env_name, render_mode='rgb_array')
    obs, info = env.reset()
    observations, rewards, actions, done = [], [] , [], []
    episode_count = 0

    for timestep in range(n_timesteps):  # Run for n_timesteps or until the episode ends
        action = env.action_space.sample() #select random action
        obs, reward, terminated, truncated, info = env.step(action) #execute and get results
        observations.append(obs) #save observation
        rewards.append(reward) #save reward
        actions.append(action) #save action
        done.append(terminated or truncated) #save timestep of each episode's boundary
        if terminated or truncated: #check for game over, if so reset env
            episode_count+=1
            if episode_count % print_n_episodes == 0: print(f"finished {episode_count} episodes") #provide update on training
            observation, info = env.reset()
        env.close()
    np_obs, np_rewards, np_actions, np_done = np.array(observations), np.array(rewards), np.array(actions), np.array(done)
    print(f"observations has shape: {np_obs.shape}\trewards has shape: {np_rewards.shape}\tactions has shape: {np_actions.shape}\tdone has shape: {np_done.shape}")
    try:
        os.mkdir(f'{out_dir}/{env_name}_{n_timesteps}')
        np.save(f'{out_dir}/{env_name}_{n_timesteps}/observations.npy', np_obs) #load with: new_obs = np.load("../data/processed/Humanoid-v5_10000_rollout_observations.npy")
        np.save(f'{out_dir}/{env_name}_{n_timesteps}/rewards.npy', np_rewards)
        np.save(f'{out_dir}/{env_name}_{n_timesteps}/actions.npy', np_actions)
        np.save(f'{out_dir}/{env_name}_{n_timesteps}/done.npy', np_done)
    except:
        print(f"couldn't save files!")
    return np_obs, np_rewards, np_actions, np_done

In [8]:
humanoid_obs, humanoid_rewards, humanoid_actions, humanoid_done = collect_rollout_data('Humanoid-v5', "../data/processed", 10000, 100)

finished 100 episodes
finished 200 episodes
finished 300 episodes
finished 400 episodes
observations has shape: (10000, 348)	rewards has shape: (10000,)	actions has shape: (10000, 17)	done has shape: (10000,)
couldn't save files!


In [9]:
humanoid_rewards

array([4.89430505, 4.88611642, 4.88518135, ..., 4.40739756, 4.67624074,
       4.73252955])

In [10]:
humanoid_actions

array([[ 0.36019275, -0.09605586, -0.18576859, ..., -0.0539204 ,
         0.08269034,  0.19883963],
       [ 0.32511103, -0.02739318, -0.05949283, ...,  0.07207756,
         0.30314222, -0.13952677],
       [-0.2139204 , -0.32731584, -0.3209193 , ...,  0.31290817,
         0.20741826,  0.25650188],
       ...,
       [ 0.01935349, -0.37806827, -0.31496435, ..., -0.28754547,
        -0.30477956, -0.22121327],
       [ 0.1330297 , -0.22771981,  0.33968493, ...,  0.34366018,
         0.11724761, -0.19986708],
       [ 0.30846435, -0.17688535, -0.25955734, ...,  0.36348575,
         0.17780969,  0.09435472]], dtype=float32)

In [11]:
humanoid_done, humanoid_done.sum()

(array([False, False, False, ..., False, False, False]), np.int64(423))

## Training the VAE

Now that we have an easy function for gathering experiences in the environment, we need to train the VAE module of our world model which will compress the observation space into latent space with fewer dimensions.

Note that the original World Model implementation worked with tensorflow 1.18.0. This is incredibly outdated (worked with Python 3.5), so let's get a newer version (tensorflow 2.19.0). Additionally, we need to change the structure of the VAE from working with images to working with a vector of observation data! Luckily, ChatGPT is very good at updating code (or it will be very obvious if it is not!).

In [12]:
import tensorflow as tf
from tensorflow.keras import layers, Model, saving

@saving.register_keras_serializable()
class MLPVAE(Model):
    def __init__(self, input_dim=348, z_size=32, kl_tolerance=0.5):
        super(MLPVAE, self).__init__()
        self.z_size = z_size
        self.kl_tolerance = kl_tolerance

        # Encoder
        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(input_dim,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(2 * z_size),  # output both mu and logvar
        ])

        # Decoder
        self.decoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(z_size,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(256, activation='relu'),
            layers.Dense(input_dim, activation='linear'),  # output same shape as input
        ])

    def sample_z(self, mu, logvar):
        eps = tf.random.normal(shape=tf.shape(mu))
        sigma = tf.exp(0.5 * logvar)
        return mu + sigma * eps

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = tf.split(h, num_or_size_splits=2, axis=1)
        logvar = tf.clip_by_value(logvar, -10.0, 10.0)  # helps with exploding values
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def call(self, x):
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def compute_loss(self, x):
        x_recon, mu, logvar = self(x)
        recon_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x - x_recon), axis=1))
        kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
        kl_loss = tf.maximum(kl_loss, self.kl_tolerance * self.z_size)
        kl_loss = tf.reduce_mean(kl_loss)
        total_loss = recon_loss + kl_loss
        return total_loss, recon_loss, kl_loss

2025-05-11 10:53:58.394557: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-11 10:53:58.410230: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746975238.427844   26892 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746975238.432813   26892 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746975238.446158   26892 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Let's also write a training function for convenience.

In [13]:
def create_dataset(x_train, batch_size=64, shuffle_buffer=10000):
    # Assuming x_train is a NumPy array of shape [n_samples, 348]
    dataset = tf.data.Dataset.from_tensor_slices(x_train.astype(np.float32))
    dataset = dataset.shuffle(shuffle_buffer).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

def train_vae(model, dataset, epochs=10, learning_rate=1e-4):
    optimizer = tf.keras.optimizers.Adam(learning_rate)

    for epoch in range(epochs):
        total_loss = 0.0
        total_batches = 0
        for x_batch in dataset:
            with tf.GradientTape() as tape:
                loss, recon_loss, kl_loss = model.compute_loss(x_batch)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            total_loss += loss.numpy()
            total_batches += 1

        avg_loss = total_loss / total_batches
        print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")


In [15]:
# x_train should be a NumPy array of shape (n_samples, 348)
x_train = humanoid_obs
x_train = (x_train - np.mean(x_train, axis=0)) / (np.std(x_train, axis=0) + 1e-6)

dataset = create_dataset(humanoid_obs, batch_size=64)
vae = MLPVAE(input_dim=348, z_size=32)
train_vae(vae, dataset, epochs=20)

Epoch 1: avg loss = 306568.8125


Loss is going down! At first I got a tons of NAN values, but it's because I wasn't normalizing the input data and I also needed to clip the logvar values we were getting as a result of the encoding process. If you get NANs again, a lower learning rate could help too. Onto saving the model!

In [16]:
vae.save_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #save ONLY weights -- much simpler than serializing the entire object

In [17]:
new_vae = MLPVAE(input_dim=348, z_size=32) #instantiate new model object 
new_vae(tf.zeros((1, 348))) #invoke it to build its shape 
new_vae.load_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #now load weights into empty vector

In [18]:
new_vae

<MLPVAE name=mlpvae_2, built=True>

In [19]:
vae

<MLPVAE name=mlpvae_1, built=True>

## Training the MDN-RNN

Now that we have a model which captures observations, we're theoretically ~1/3 done with the project! I say theoretically because this was probably the easiest part of the project. Now onto the meat of world models: capturing the transitions of our environment and training the MDN-RNN!

### Prepping Rollout Data for the MDN-RNN

To train the MDN-RNN, we first need to enhance our basic rollout dataset with predictions of `mu` and `logvar` for each experience. Then we'll feed this information to the MDN-RNN.

In [20]:
#We can use the dataset records still in memory. Note that we only need the observations for now!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, humanoid_done.shape

((10000, 348), (10000,), (10000, 17), (10000,))

In [21]:
#We can also use the VAE still in memory!
vae

<MLPVAE name=mlpvae_1, built=True>

In [22]:
humanoid_obs[0]

array([ 1.38886569e+00,  9.99897978e-01, -3.86473454e-03, -1.37366724e-02,
        6.33536879e-04,  4.72869318e-03,  1.23726442e-01, -2.71861447e-02,
        1.90465276e-02, -3.42330410e-02, -1.58563616e-01, -1.29209048e-01,
       -2.46233564e-02,  4.49637519e-02, -1.58103099e-01, -1.09594646e-01,
        2.24837397e-02, -2.31870835e-02,  8.42925457e-03, -3.01925045e-02,
        2.37384980e-02,  2.42204090e-02,  3.73407888e-01,  2.99910596e-02,
       -3.61643667e-01,  3.86360986e-01, -3.00772425e+00,  2.65944794e-01,
       -6.64950148e-01,  1.12069123e+01, -2.56462351e+00,  3.19647750e+00,
       -2.39332040e+00, -1.57812598e+01, -1.39240095e+01, -3.36641598e+00,
        3.38142258e+00, -1.56526288e+01, -1.26775093e+01,  1.35818218e+00,
       -2.73749089e+00,  8.54538295e-01, -1.72602532e+00,  2.20293519e+00,
        1.85266074e+00,  2.29332593e+00,  2.27452413e+00,  4.26554563e-02,
        7.56150550e-04,  4.57639885e-02, -3.49123619e-02, -9.35642063e-02,
        6.97772305e-02,  

In [23]:
#That's it! Just predict for our entire observation set.
mu, logvar = vae.encode(humanoid_obs)
humanoid_z = vae.sample_z(mu, logvar).numpy()
humanoid_z.shape

(10000, 32)

In [24]:
np.save(f'../data/processed/Humanoid-v5_10000/z.npy', humanoid_z)

In [25]:
#We can use the dataset records still in memory. Note that we only need the observations for now!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, humanoid_done.shape, humanoid_z.shape

((10000, 348), (10000,), (10000, 17), (10000,), (10000, 32))

### Core Training

Now that we have our mu and logvar arrays and also episode-wise aggregations of all of our data, we can train our MDN-RNN!

In [26]:
#CORE MODEL CLASS
class MDNRNN(tf.keras.Model):
    def __init__(self, latent_dim, action_dim, hidden_dim=256, num_mixtures=5):
        super().__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.input_dim = latent_dim + action_dim
        self.hidden_dim = hidden_dim
        self.num_mixtures = num_mixtures

        # LSTM
        self.lstm = layers.LSTM(hidden_dim, return_sequences=True, return_state=True)

        # MDN output: means, stddevs, and mixture weights for latent prediction
        self.mdn_dense = layers.Dense(num_mixtures * (2 * latent_dim + 1))

        # Predict reward (scalar)
        self.reward_dense = layers.Dense(1)

        # Predict done (binary classification)
        self.done_dense = layers.Dense(1, activation="sigmoid")

    def call(self, inputs, initial_state=None, training=False):
        """
        inputs: (batch, seq_len, latent_dim + action_dim)
        """
        lstm_out, h, c = self.lstm(inputs, initial_state=initial_state, training=training)

        mdn_out = self.mdn_dense(lstm_out)
        reward_pred = self.reward_dense(lstm_out)
        done_pred = self.done_dense(lstm_out)

        return mdn_out, reward_pred, done_pred, [h, c]

    def get_mdn_params(self, mdn_out):
        """Split MDN output into pi, mu, sigma."""
        out = tf.reshape(mdn_out, [-1, self.num_mixtures, 2 * self.latent_dim + 1])
        pi = out[:, :, 0]
        mu = out[:, :, 1 : 1 + self.latent_dim]
        log_sigma = out[:, :, 1 + self.latent_dim :]
        sigma = tf.exp(log_sigma)

        pi = tf.nn.softmax(pi, axis=-1)  # mixture weights
        return pi, mu, sigma


In [27]:
#LOSS IMPLEMENTATION
import tensorflow_probability as tfp

def mdn_loss(z_target, pi, mu, sigma, eps=1e-8):
    """
    z_target: [batch * seq_len, latent_dim]
    pi: [batch * seq_len, num_mixtures]
    mu: [batch * seq_len, num_mixtures, latent_dim]
    sigma: [batch * seq_len, num_mixtures, latent_dim]
    """
    # Expand target for broadcasting: [batch, 1, latent_dim]
    z_expanded = tf.expand_dims(z_target, axis=1)

    # Create component Gaussians
    normal_dist = tfp.distributions.Normal(loc=mu, scale=sigma)
    log_probs = normal_dist.log_prob(z_expanded)  # shape: [batch, num_mixtures, latent_dim]

    # Sum over latent_dim: total log prob of each mixture component
    log_probs = tf.reduce_sum(log_probs, axis=-1)  # shape: [batch, num_mixtures]

    # Weight by mixture coefficients
    weighted_log_probs = log_probs + tf.math.log(pi + eps)  # log(pi * P)
    
    # LogSumExp over mixture components to marginalize
    log_likelihood = tf.reduce_logsumexp(weighted_log_probs, axis=-1)  # shape: [batch]

    # Negative log-likelihood
    return -tf.reduce_mean(log_likelihood)

def combined_loss(z_target, pi, mu, sigma, reward_target, reward_pred, done_target, done_pred,
                  reward_weight=1.0, done_weight=1.0):
    loss_mdn = mdn_loss(z_target, pi, mu, sigma)
    loss_reward = tf.reduce_mean(tf.square(reward_target - reward_pred))
    loss_done = tf.reduce_mean(tf.keras.losses.binary_crossentropy(done_target, done_pred))

    return loss_mdn + reward_weight * loss_reward + done_weight * loss_done


In [28]:
#TRAINING FUNCTION
def train_mdnrnn(model, dataset, epochs=10, learning_rate=1e-4, reward_weight=1.0, done_weight=1.0):
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    
    for epoch in range(epochs):
        total_loss = 0.0
        total_mdn_loss = 0.0
        total_reward_loss = 0.0
        total_done_loss = 0.0
        total_batches = 0

        for z_action, (z_next, reward, done) in dataset:
            # Flatten z_action into shape (batch, seq_len, latent_dim + action_dim)
            with tf.GradientTape() as tape:
                # Forward pass through the model
                mdn_out, reward_pred, done_pred, _ = model(z_action, training=True)
                pi, mu, sigma = model.get_mdn_params(mdn_out)

                # Compute MDN, reward, and done losses
                loss_mdn = mdn_loss(tf.reshape(z_next, [-1, model.latent_dim]), pi, mu, sigma)
                loss_reward = tf.reduce_mean(tf.square(reward - reward_pred))
                loss_done = tf.reduce_mean(tf.keras.losses.binary_crossentropy(done, done_pred))

                # Total loss
                total_loss = loss_mdn + reward_weight * loss_reward + done_weight * loss_done

            # Compute gradients and apply
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Accumulate the loss values
            total_loss += total_loss.numpy()
            total_mdn_loss += loss_mdn.numpy()
            total_reward_loss += loss_reward.numpy()
            total_done_loss += loss_done.numpy()

            total_batches += 1

        # Compute and print average losses for the epoch
        avg_loss = total_loss / total_batches
        avg_mdn_loss = total_mdn_loss / total_batches
        avg_reward_loss = total_reward_loss / total_batches
        avg_done_loss = total_done_loss / total_batches

        print(f"Epoch {epoch+1}:")
        print(f"  avg loss = {avg_loss:.4f}, mdn_loss = {avg_mdn_loss:.4f}, "
              f"reward_loss = {avg_reward_loss:.4f}, done_loss = {avg_done_loss:.4f}")


In [29]:
#DATA HANDLING
def preprocess_data(observations, actions, rewards, dones, sequence_length):
    """
    Yields tuples of the form:
    z_action: (sequence_length, latent_dim + action_dim)
    targets:  (z_next, reward, done), each of shape (sequence_length, ...)
    """
    for i in range(len(observations) - sequence_length):
        z_seq = observations[i:i+sequence_length]
        a_seq = actions[i:i+sequence_length]
        z_action = np.concatenate([z_seq, a_seq], axis=-1).astype(np.float32)

        z_next = observations[i+1:i+sequence_length+1].astype(np.float32)
        reward = rewards[i+1:i+sequence_length+1].astype(np.float32)
        done = dones[i+1:i+sequence_length+1].astype(np.float32)

        yield (
            z_action.astype(np.float32),
            (
                z_next.astype(np.float32),
                reward[:, None].astype(np.float32),  # <–– reshape from (T,) to (T,1)
                done[:, None].astype(np.float32)     # <–– reshape from (T,) to (T,1)
            )
        )


sequence_length = 10
latent_dim = 32
action_dim = 17
batch_size = 32

train_dataset = tf.data.Dataset.from_generator(
    lambda: preprocess_data(humanoid_z, humanoid_actions, humanoid_rewards, humanoid_done, sequence_length),
    output_signature=(
        tf.TensorSpec(shape=(sequence_length, latent_dim + action_dim), dtype=tf.float32),
        (
            tf.TensorSpec(shape=(sequence_length, latent_dim), dtype=tf.float32),
            tf.TensorSpec(shape=(sequence_length, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(sequence_length, 1), dtype=tf.float32),
        )
    )
)

train_dataset = train_dataset.batch(batch_size).shuffle(1000)

In [30]:
mdnrnn = MDNRNN(latent_dim=32, action_dim=17)
train_mdnrnn(mdnrnn, train_dataset, 10)

2025-05-11 10:55:29.717311: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:13: Filling up shuffle buffer (this may take a while): 230 of 1000
2025-05-11 10:55:33.254045: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


Epoch 1:
  avg loss = 8.1963, mdn_loss = 4989.6450, reward_loss = 19.1286, done_loss = 0.4089


2025-05-11 10:56:57.988420: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


## Training the Controller

Now that we have a VAE and an MDN-RNN that are (presumably) learning, let's write the final piece of the puzzle -- the MDN-RNN!



In [31]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, saving
import gymnasium as gym

# --- Controller model ---
@saving.register_keras_serializable()
class LinearController(Model):
    def __init__(self, input_dim, action_dim):
        super(LinearController, self).__init__()
        self.linear = layers.Dense(action_dim, activation='tanh', use_bias=False)
        self.linear(tf.zeros((1, input_dim)))  # force build so weights are initialized

    def call(self, inputs):
        return self.linear(inputs)

    def get_weights_flat(self):
        return tf.reshape(self.linear.kernel, [-1])

    def set_weights_flat(self, flat_weights):
        new_weights = tf.reshape(flat_weights, self.linear.kernel.shape)
        self.linear.kernel.assign(new_weights)

# --- Evolution strategy ---
def evolve_controller(controller, vae, mdn_rnn, env, generations=10, pop_size=64, sigma=0.1, elite_frac=0.2):
    input_dim = controller.linear.kernel.shape[0]
    action_dim = controller.linear.kernel.shape[1]
    weight_dim = input_dim * action_dim

    base_weights = controller.get_weights_flat().numpy()
    elite_num = max(1, int(pop_size * elite_frac))

    for gen in range(generations):
        population = [base_weights + sigma * np.random.randn(weight_dim) for _ in range(pop_size)]
        scores = []

        for i, individual in enumerate(population):
            controller.set_weights_flat(individual)
            reward = evaluate_controller(controller, vae, mdn_rnn, env)
            scores.append((reward, individual))

        scores.sort(key=lambda x: -x[0])
        elites = [w for _, w in scores[:elite_num]]
        new_mean = np.mean(elites, axis=0)
        base_weights = new_mean

        print(f"Gen {gen+1}: Best score = {scores[0][0]:.2f}")

    controller.set_weights_flat(base_weights)
    return controller

In [32]:
def evaluate_controller(controller, vae, mdn_rnn, env, max_steps=1000):
    total_reward = 0
    obs, _ = env.reset()
    done = False
    h = tf.zeros((1, 256))  # RNN hidden state (adjust as needed)

    state = [h, h]  # Initialize state for MDN-RNN (h, c)
    for step in range(max_steps):
        x = tf.convert_to_tensor(obs[None, :], dtype=tf.float32)
        mu, _ = vae.encode(x)
        z = mu  # Use the mean (mu) as the latent vector

        zh = tf.concat([z, h], axis=1)  # Concatenate latent vector and hidden state
        action = controller(zh).numpy()[0]  # Get action from controller

        # Step in the environment
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Update hidden state using the MDN-RNN
        rnn_input = tf.concat([z, action[None]], axis=1)  # (latent_dim + action_dim)
        rnn_input = tf.expand_dims(rnn_input, axis=1)  # Shape becomes (1, 1, 50)

        mdn_out, reward_pred, done_pred, state = mdn_rnn(rnn_input, initial_state=state)

        total_reward += reward

        # Optionally: You can use reward_pred and done_pred for monitoring

        if done:
            break

    return total_reward

In [33]:
z_dim = 32
h_dim = 256
action_dim = env.action_space.shape[0]

controller = LinearController(input_dim=z_dim + h_dim, action_dim=action_dim)

# Assume `vae` and `mdn_rnn` are your pretrained models
trained_controller = evolve_controller(controller, vae, mdnrnn, env, generations=20)

Gen 1: Best score = 155.66


So now the controller is working. Also, we're missing the CMA-ES piece! We just do random search optimization right now.

## Early Results

With a very basic model trained, we can now compare the performance of our model with a random policy and a PPO policy. To do this, we'll have to wrap our model up in a nice interface that plays well with our `play_agent()` function from `./gymnasium_basics.ipynb`.

In [34]:
def make_world_model_policy(vae, mdn_rnn, controller):
    h = tf.zeros((1, 256))
    state = [h, h]

    def policy(obs):
        nonlocal state, h
        x = tf.convert_to_tensor(obs[None, :], dtype=tf.float32)
        mu, _ = vae.encode(x)
        z = mu

        zh = tf.concat([z, h], axis=1)
        action = controller(zh).numpy()[0]

        rnn_input = tf.concat([z, action[None]], axis=1)
        rnn_input = tf.expand_dims(rnn_input, axis=1)  # shape: (1, 1, 50)
        _, _, _, state = mdn_rnn(rnn_input, initial_state=state)
        h = state[0]

        return action

    return policy

In [35]:
#a slightly updated version of play_agent()
def play_agent(env_name, model=None, n_timesteps=50, render_mode="human"):
    """Plays games using a provided model or random policy."""
    env = gym.make(env_name, render_mode=render_mode)
    obs, info = env.reset()
    rewards = []
    episode_lengths = []
    steps = 0

    while steps < n_timesteps:
        env.render()
        if model is None: #random policy
            action = env.action_space.sample()
        elif hasattr(model, "predict"): #SB3 models use 'predict'
            action, _ = model.predict(obs)
        else: #our model uses `model(obs)`
            action = model(obs)  # assumes model is a callable function

        obs, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        steps += 1

        if terminated or truncated:
            episode_lengths.append(steps)
            obs, info = env.reset()

    env.close()
    return rewards, episode_lengths

In [36]:
wm_policy = make_world_model_policy(vae, mdnrnn, trained_controller)
rewards, lengths = play_agent(env_name="Humanoid-v5", model=wm_policy, n_timesteps=1000)
np.sum(rewards), np.mean(lengths) #avg reward, avg episode length of world model policy

KeyboardInterrupt: 

In [None]:
np.sum(rewards), np.mean(lengths)

In [None]:
rewards, lengths = play_agent(env_name="Humanoid-v5", model=None, n_timesteps=1000)
np.sum(rewards), np.mean(lengths) #avg reward, avg episode length of random policy

Okay, world models is still doing worse than random. Tough times! That said, we haven't done the magic of **scaling** yet! Let's break this bad boy up into a few scripts and hit the HPC!

### Epilogue: Saving MDN-RNNs and Controllers

In [None]:
mdnrnn.save_weights('../models/mdnrnn/humanoid_10000_mdnrnn_model.weights.h5') #save ONLY weights -- much simpler than serializing the entire object

In [None]:
new_mdnrnn = mdnrnn = MDNRNN(latent_dim=32, action_dim=17) #instantiate new model object 
new_mdnrnn(tf.zeros((1, 10, 49))) #invoke it to build its shape 
new_mdnrnn.load_weights('../models/mdnrnn/humanoid_10000_mdnrnn_model.weights.h5') #now load weights into empty vector

In [None]:
# Save the controller's weights
controller.save_weights('controller_weights.weights.h5')

In [None]:
# Create a new instance of the controller
new_controller = LinearController(input_dim=controller.linear.kernel.shape[0], action_dim=controller.linear.kernel.shape[1])

new_controller.build((1,controller.linear.kernel.shape[0]))

# Load the saved weights into the new controller instance
new_controller.load_weights('controller_weights.weights.h5')

In [None]:
wm_policy = make_world_model_policy(vae, mdnrnn, new_controller)
rewards, lengths = play_agent(env_name="Humanoid-v5", model=wm_policy, n_timesteps=1000)
np.sum(rewards), np.mean(lengths) #avg reward, avg episode length of world model policy