# World Model Core
*Sean Steinle, Kiya Aminfar*

This notebook walks through the core aspects of world models, developing crucial pieces of code sequentially. Not that this code isn't meant for scale -- instead, this is for a demonstration of how we developed the code that we did.

## Table of Contents
1. [Collecting Rollout Data](#Collecting-Rollout-Data)
2. [Training the VAE](#Training-the-VAE)
3. [Training the MDN-RNN](#Training-the-MDN-RNN)
    - [Prepping Rollout Data for the MDN-RNN](#Prepping-Rollout-Data-for-the-MDN-RNN)
    - [Core Training](#Core-Training)
4. [Training the Controller](#Training-the-Controller)
5. [Early Results](#Early-Results)

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import os
import numpy as np

## Collecting Rollout Data

In [2]:
#Let's begin by creating an instance of our humanoid environment and checking out what basic observations look like.
env = gym.make('Humanoid-v5', render_mode="rgb_array")
obs, info = env.reset()

In [3]:
obs.shape, obs

((348,),
 array([ 1.39300131e+00,  9.95082371e-01,  3.87763039e-03, -2.46595386e-03,
         1.68996662e-03, -9.42309819e-03, -6.06253604e-03, -4.55935851e-03,
         1.81377671e-03,  9.27322084e-03,  1.20513263e-03,  9.15671483e-03,
        -5.34892928e-03, -2.89313763e-03, -4.67409161e-03,  4.80980385e-03,
         3.31908782e-04,  3.80452899e-03,  7.95388306e-03, -5.40410374e-03,
        -5.11927945e-03,  6.37732803e-03, -1.33475781e-03,  4.30260742e-03,
         1.09733892e-03, -1.90648726e-03, -1.09611434e-03, -9.08495783e-03,
         7.41654441e-03,  9.75929267e-03,  9.96725691e-03,  7.37010286e-03,
         4.84701979e-03, -1.81094195e-03, -1.26663294e-03,  5.95643963e-03,
        -7.36869602e-03,  8.03223366e-03,  1.65738149e-03,  6.56168602e-03,
         1.49662101e-03, -1.40926562e-03, -9.51879738e-03,  4.21087831e-03,
        -6.92943150e-03,  2.30151498e+00,  2.28678620e+00,  4.58295022e-02,
        -5.77727533e-04,  9.55042464e-02,  1.52667775e-02, -2.01205557e-01,
   

In [4]:
info

{'x_position': np.float64(-0.0016523157076843691),
 'y_position': np.float64(-0.007591421744894626),
 'tendon_length': array([0.0094839 , 0.00795158]),
 'tendon_velocity': array([-0.00637485,  0.00054431]),
 'distance_from_origin': np.float64(0.007769158983231034)}

As we can see, the humanoid environment gives us a TON of observations! We get dozens of variables representing various positions and velocities of body parts, the center of mass, and a lot of other variables I hardly understand. For an exhaustive list, see the [doc](https://gymnasium.farama.org/environments/mujoco/humanoid/#observation-space). The fact that there are so many variables here is what makes learning latent observations so obvious!

We also get some nice summary stats in info, but we aren't going to include them in our scrape.

In [11]:
def collect_rollout_data(env_name: str, out_dir: str, n_timesteps: int=10000, print_n_episodes: int=1000):
    """Simulates `n_timesteps` in the `env_name` environment, saving observations, rewards, and actions to a triplet of .npy files at `out_dir`."""
    env = gym.make(env_name, render_mode='rgb_array')
    obs, info = env.reset()
    observations, rewards, actions, episode_boundaries = [], [] , [], []
    episode_count = 0

    for timestep in range(n_timesteps):  # Run for n_timesteps or until the episode ends
        action = env.action_space.sample() #select random action
        obs, reward, terminated, truncated, info = env.step(action) #execute and get results
        observations.append(obs) #save observation
        rewards.append(reward) #save reward
        actions.append(action) #save action
        if terminated or truncated: #check for game over, if so reset env
            episode_boundaries.append(timestep) #save timestep of each episode's boundary
            episode_count+=1
            if episode_count % print_n_episodes == 0: print(f"finished {episode_count} episodes") #provide update on training
            observation, info = env.reset()
        env.close()
    np_obs, np_rewards, np_actions, np_episode_boundaries = np.array(observations), np.array(rewards), np.array(actions), np.array(episode_boundaries)
    print(f"observations has shape: {np_obs.shape}\trewards has shape: {np_rewards.shape}\tactions has shape: {np_actions.shape}\tboundaries has shape: {episode_boundaries}")
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_observations.npy', np_obs) #load with: new_obs = np.load("../data/processed/Humanoid-v5_10000_rollout_observations.npy")
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_rewards.npy', np_rewards)
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_actions.npy', np_actions)
    np.save(f'{out_dir}/{env_name}_{n_timesteps}_rollout_boundaries.npy', np_episode_boundaries)
    return np_obs, np_rewards, np_actions, np_episode_boundaries

In [10]:
humanoid_obs, humanoid_rewards, humanoid_actions, humanoid_boundaries = collect_rollout_data('Humanoid-v5', "../data/processed", 10000, 100)

finished 100 episodes
finished 200 episodes
finished 300 episodes
finished 400 episodes
observations has shape: (10000, 348)	rewards has shape: (10000,)	actions has shape: (10000, 17)


In [14]:
humanoid_rewards

array([4.9154845 , 4.91313937, 4.94133224, ..., 2.56266492, 3.89328895,
       5.00908928])

In [13]:
humanoid_actions

array([[-0.08674568,  0.18042727,  0.24072039, ...,  0.14242245,
        -0.16913189, -0.29231656],
       [-0.29700246, -0.08621656, -0.36000764, ...,  0.3989401 ,
         0.25781026, -0.16689824],
       [ 0.11839256,  0.27461025,  0.07927635, ...,  0.08514802,
         0.13366902, -0.23283356],
       ...,
       [-0.3481346 , -0.1078826 , -0.2464141 , ...,  0.327872  ,
         0.1691019 , -0.13490571],
       [ 0.22734375,  0.12003687, -0.34475392, ..., -0.2912965 ,
        -0.1930455 , -0.18568286],
       [ 0.38920617,  0.13925335, -0.09383011, ...,  0.02945893,
        -0.26485687,  0.13945284]], dtype=float32)

In [12]:
humanoid_boundaries

array([  24,   44,   63,   93,  117,  138,  182,  209,  226,  268,  298,
        315,  351,  371,  403,  422,  443,  463,  488,  517,  535,  557,
        576,  595,  614,  642,  662,  689,  707,  727,  745,  766,  787,
        807,  845,  864,  890,  916,  940,  958,  986, 1007, 1033, 1062,
       1089, 1110, 1131, 1153, 1174, 1193, 1213, 1238, 1263, 1288, 1324,
       1362, 1381, 1410, 1429, 1446, 1463, 1482, 1524, 1553, 1572, 1589,
       1622, 1642, 1662, 1679, 1704, 1736, 1756, 1786, 1805, 1826, 1854,
       1873, 1915, 1952, 1975, 1997, 2039, 2059, 2086, 2121, 2148, 2183,
       2210, 2239, 2258, 2309, 2355, 2374, 2400, 2430, 2449, 2483, 2521,
       2544, 2570, 2606, 2632, 2663, 2693, 2712, 2740, 2770, 2791, 2822,
       2845, 2866, 2889, 2909, 2928, 2959, 2977, 2996, 3016, 3037, 3059,
       3099, 3117, 3159, 3178, 3198, 3228, 3245, 3264, 3282, 3302, 3321,
       3338, 3366, 3387, 3428, 3446, 3473, 3503, 3535, 3556, 3578, 3596,
       3620, 3639, 3665, 3684, 3719, 3742, 3760, 37

## Training the VAE

Now that we have an easy function for gathering experiences in the environment, we need to train the VAE module of our world model which will compress the observation space into latent space with fewer dimensions.

Note that the original World Model implementation worked with tensorflow 1.18.0. This is incredibly outdated (worked with Python 3.5), so let's get a newer version (tensorflow 2.19.0). Additionally, we need to change the structure of the VAE from working with images to working with a vector of observation data! Luckily, ChatGPT is very good at updating code (or it will be very obvious if it is not!).

In [15]:
import tensorflow as tf
from tensorflow.keras import layers, Model, saving

@saving.register_keras_serializable()
class MLPVAE(Model):
    def __init__(self, input_dim=348, z_size=32, kl_tolerance=0.5):
        super(MLPVAE, self).__init__()
        self.z_size = z_size
        self.kl_tolerance = kl_tolerance

        # Encoder
        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(input_dim,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(2 * z_size),  # output both mu and logvar
        ])

        # Decoder
        self.decoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(z_size,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(256, activation='relu'),
            layers.Dense(input_dim, activation='linear'),  # output same shape as input
        ])

    def sample_z(self, mu, logvar):
        eps = tf.random.normal(shape=tf.shape(mu))
        sigma = tf.exp(0.5 * logvar)
        return mu + sigma * eps

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = tf.split(h, num_or_size_splits=2, axis=1)
        logvar = tf.clip_by_value(logvar, -10.0, 10.0)  # helps with exploding values
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def call(self, x):
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def compute_loss(self, x):
        x_recon, mu, logvar = self(x)
        recon_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x - x_recon), axis=1))
        kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
        kl_loss = tf.maximum(kl_loss, self.kl_tolerance * self.z_size)
        kl_loss = tf.reduce_mean(kl_loss)
        total_loss = recon_loss + kl_loss
        return total_loss, recon_loss, kl_loss

2025-05-09 10:27:00.743408: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-09 10:27:00.770285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746800820.791959   31576 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746800820.798624   31576 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746800820.816920   31576 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Let's also write a training function for convenience.

In [16]:
def create_dataset(x_train, batch_size=64, shuffle_buffer=10000):
    # Assuming x_train is a NumPy array of shape [n_samples, 348]
    dataset = tf.data.Dataset.from_tensor_slices(x_train.astype(np.float32))
    dataset = dataset.shuffle(shuffle_buffer).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

def train_vae(model, dataset, epochs=10, learning_rate=1e-4):
    optimizer = tf.keras.optimizers.Adam(learning_rate)

    for epoch in range(epochs):
        total_loss = 0.0
        total_batches = 0
        for x_batch in dataset:
            with tf.GradientTape() as tape:
                loss, recon_loss, kl_loss = model.compute_loss(x_batch)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            total_loss += loss.numpy()
            total_batches += 1

        avg_loss = total_loss / total_batches
        print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")


In [17]:
# x_train should be a NumPy array of shape (n_samples, 348)
x_train = humanoid_obs
x_train = (x_train - np.mean(x_train, axis=0)) / (np.std(x_train, axis=0) + 1e-6)

dataset = create_dataset(humanoid_obs, batch_size=64)
vae = MLPVAE(input_dim=348, z_size=32)
train_vae(vae, dataset, epochs=20)

2025-05-09 10:27:07.306695: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
2025-05-09 10:27:24.075387: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1: avg loss = 289024.7188


2025-05-09 10:27:40.264828: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2: avg loss = 123538.1016
Epoch 3: avg loss = 67821.8672


2025-05-09 10:28:09.838940: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 4: avg loss = 58037.1836
Epoch 5: avg loss = 52092.7578
Epoch 6: avg loss = 46350.1094
Epoch 7: avg loss = 39106.0586


2025-05-09 10:29:02.727875: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 8: avg loss = 33068.6367
Epoch 9: avg loss = 28321.2461
Epoch 10: avg loss = 24051.6816
Epoch 11: avg loss = 20940.8828
Epoch 12: avg loss = 18792.0293
Epoch 13: avg loss = 16970.7285
Epoch 14: avg loss = 15694.1816
Epoch 15: avg loss = 14791.6787


2025-05-09 10:30:57.359811: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 16: avg loss = 13650.3232
Epoch 17: avg loss = 12947.7314
Epoch 18: avg loss = 12331.4980
Epoch 19: avg loss = 11848.9102
Epoch 20: avg loss = 11423.7061


Loss is going down! At first I got a tons of NAN values, but it's because I wasn't normalizing the input data and I also needed to clip the logvar values we were getting as a result of the encoding process. If you get NANs again, a lower learning rate could help too. Onto saving the model!

In [18]:
vae.save_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #save ONLY weights -- much simpler than serializing the entire object

In [19]:
new_vae = MLPVAE(input_dim=348, z_size=32) #instantiate new model object 
new_vae(tf.zeros((1, 348))) #invoke it to build its shape 
new_vae.load_weights('../models/vae/humanoid_10000_vae_model.weights.h5') #now load weights into empty vector

In [20]:
new_vae

<MLPVAE name=mlpvae_1, built=True>

In [21]:
vae

<MLPVAE name=mlpvae, built=True>

## Training the MDN-RNN

Now that we have a model which captures observations, we're theoretically ~1/3 done with the project! I say theoretically because this was probably the easiest part of the project. Now onto the meat of world models: capturing the transitions of our environment and training the MDN-RNN!

### Prepping Rollout Data for the MDN-RNN

To train the MDN-RNN, we first need to enhance our basic rollout dataset with predictions of `mu` and `logvar` for each experience. Then we'll feed this information to the MDN-RNN.

In [22]:
#We can use the dataset records still in memory. Note that we only need the observations for now!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, humanoid_boundaries.shape

((10000, 348), (10000,), (10000, 17), (405,))

In [23]:
#We can also use the VAE still in memory!
vae

<MLPVAE name=mlpvae, built=True>

In [24]:
humanoid_obs[0]

array([ 1.40830562e+00,  9.99876874e-01, -9.51343451e-03,  1.18707539e-02,
       -3.84923369e-03,  1.15938890e-02, -3.00163036e-02, -4.25435428e-03,
        9.58761690e-04, -6.35710950e-02,  2.01602117e-02, -4.33043558e-02,
        5.35531591e-03,  8.45089699e-02,  2.45519732e-02, -2.40603192e-02,
       -1.01506478e-02,  1.46174317e-02, -8.65743713e-03,  1.21055954e-02,
       -1.53897781e-02, -2.20527455e-02, -2.35265818e-01,  3.60738168e-02,
       -1.82763673e-01, -3.03734843e-01,  1.64679568e+00, -4.48322248e-01,
        9.44238863e-01, -4.28256359e+00,  2.59820444e-01, -6.46090300e-02,
       -5.44796398e+00,  8.57359267e-01, -5.35796373e+00,  3.66024060e-01,
        6.90749771e+00,  1.91271460e+00, -3.74138601e+00, -1.69903784e+00,
        2.16734825e+00, -1.25681037e+00,  1.86683659e+00, -1.85226147e+00,
       -2.09367945e+00,  2.30252275e+00,  2.28360708e+00,  4.34351997e-02,
        7.03459228e-04,  3.90673226e-02, -4.73535097e-02, -9.34772193e-02,
        9.25679374e-02,  

In [39]:
#That's it! Just predict for our entire observation set.
mu, logvar = vae.encode(humanoid_obs)
mu, logvar = mu.numpy(), logvar.numpy() #cast to numpy (it's an EagerTensor)
mu.shape, logvar.shape #makes sense -- a mean and stdev for each dimen of latent space for each sample

((10000, 32), (10000, 32))

In [38]:
type(humanoid_obs), type(mu.numpy())

(numpy.ndarray, numpy.ndarray)

In [26]:
np.save(f'../data/processed/Humanoid-v5_10000_rollout_mu.npy', mu)
np.save(f'../data/processed/Humanoid-v5_10000_rollout_logvar.npy', logvar)

In [32]:
#So we created the mu and logvar predictions, but we also need to chunk our data into episodes which we can train our model on.
#This is where we can use our boundaries data!
humanoid_obs.shape, humanoid_rewards.shape, humanoid_actions.shape, mu.shape, logvar.shape, humanoid_boundaries.shape

((10000, 348),
 (10000,),
 (10000, 17),
 TensorShape([10000, 32]),
 TensorShape([10000, 32]),
 (405,))

In [43]:
#create list of episode dict which contain each datatype of interest
episodes = []
start_index = 0
for end_index in humanoid_boundaries:
    edict = {}
    edict['obs'] = humanoid_obs[start_index:end_index]
    edict['rewards'] = humanoid_rewards[start_index:end_index]
    edict['actions'] = humanoid_actions[start_index:end_index]
    edict['mu'] = mu[start_index:end_index]
    edict['logvar'] = logvar[start_index:end_index]
    episodes.append(edict)
    start_index = end_index #move window forward

In [44]:
#great, first dimension matches for all data!
episodes[0]['obs'].shape, episodes[0]['rewards'].shape, episodes[0]['actions'].shape, episodes[0]['mu'].shape, episodes[0]['logvar'].shape

((24, 348), (24,), (24, 17), (24, 32), (24, 32))

### Core Training

Now that we have our mu and logvar arrays and also episode-wise aggregations of all of our data, we can train our MDN-RNN!

In [78]:
class MDNRNN(tf.keras.Model):
    def __init__(self, latent_size, action_size, rnn_units=256, num_mixtures=5):
        super(MDNRNN, self).__init__()
        self.latent_size = latent_size
        self.action_size = action_size
        self.rnn_units = rnn_units
        self.num_mixtures = num_mixtures

        # Inputs: [z_t, a_t] concatenated
        self.input_layer = layers.Dense(rnn_units, activation='relu')

        # RNN core
        self.lstm = layers.LSTM(rnn_units, return_sequences=True, return_state=True)

        # Outputs: MDN parameters and optional done prediction
        output_dim = num_mixtures * (2 * latent_size + 1)  # means, stds, logmix
        self.mdn_output = layers.Dense(output_dim)
        self.done_output = layers.Dense(1, activation='sigmoid')  # optional 'done' signal

    def call(self, inputs, states=None, training=False):
        """
        inputs: (z_t, a_t) of shape (batch, seq_len, latent_size + action_size)
        states: initial LSTM states (optional)
        returns: ((pi, mu, log_sigma), done_pred, final_states)
        """
        x = self.input_layer(inputs)
        x, final_h, final_c = self.lstm(x, initial_state=states, training=training)
    
        mdn_params = self.mdn_output(x)  # shape: (B, T, M * (2D + 1))
        done_pred = self.done_output(x)
    
        # Reshape and split MDN params
        B, T = tf.shape(mdn_params)[0], tf.shape(mdn_params)[1]
        M = self.num_mixtures
        D = self.latent_size
    
        # Reshape to (B, T, M, 2D + 1)
        mdn_params = tf.reshape(mdn_params, [B, T, M, 2 * D + 1])
    
        # Split into pi, mu, log_sigma
        pi = mdn_params[..., 0]                 # shape (B, T, M)
        mu = mdn_params[..., 1:D+1]             # shape (B, T, M, D)
        log_sigma = mdn_params[..., D+1:]       # shape (B, T, M, D)
    
        return (pi, mu, log_sigma), done_pred, [final_h, final_c]

In [81]:
import tensorflow_probability as tfp
tfd = tfp.distributions

def compute_mdn_loss(targets, mdn_params):
    """
    targets: shape (B, T, D)
    mdn_params: tuple (pi, mu, log_sigma)
        - pi: (B, T, M)
        - mu, log_sigma: (B, T, M, D)
    Returns scalar loss (average NLL over batch and time)
    """
    pi, mu, log_sigma = mdn_params  # already split in model
    B, T, D = tf.shape(targets)[0], tf.shape(targets)[1], tf.shape(targets)[2]
    M = tf.shape(pi)[-1]

    # Expand targets to shape (B, T, 1, D) to broadcast against (B, T, M, D)
    targets_expanded = tf.expand_dims(targets, axis=2)

    # Compute component-wise log-probabilities
    dist = tfd.Normal(loc=mu, scale=tf.exp(log_sigma))  # shape: (B, T, M, D)
    log_prob = dist.log_prob(targets_expanded)  # (B, T, M, D)
    log_prob = tf.reduce_sum(log_prob, axis=-1)  # sum over D → shape: (B, T, M)

    # Apply log(pi) with log-sum-exp trick
    log_pi = tf.math.log_softmax(pi, axis=-1)  # shape: (B, T, M)
    log_mix = log_pi + log_prob               # (B, T, M)

    # LogSumExp over mixtures → log p(x)
    log_likelihood = tf.reduce_logsumexp(log_mix, axis=-1)  # shape: (B, T)

    # Negative log likelihood loss
    nll = -tf.reduce_mean(log_likelihood)  # scalar

    return nll

In [52]:
def random_batch(dataset, batch_size, seq_len):
    """
    dataset: list of episodes, each is a dict with keys 'z' and 'a'
    batch_size: number of sequences per batch
    seq_len: number of time steps (excluding the +1 for teacher forcing)
    
    Returns:
        z_batch: (B, T+1, latent_size)
        a_batch: (B, T+1, action_size)
    """
    z_batch = []
    a_batch = []

    for _ in range(batch_size):
        # Randomly pick an episode
        episode = np.random.choice(dataset)
        z_seq = episode['mu']
        a_seq = episode['actions']

        # Make sure it's long enough
        assert len(z_seq) >= seq_len + 1

        # Randomly choose start index
        start = np.random.randint(0, len(z_seq) - seq_len - 1)

        # Slice out a window of length seq_len + 1
        z_chunk = z_seq[start : start + seq_len + 1]
        a_chunk = a_seq[start : start + seq_len + 1]

        z_batch.append(z_chunk)
        a_batch.append(a_chunk)

    return np.array(z_batch), np.array(a_batch)

In [82]:
# Hyperparameters
learning_rate = 0.001
decay_rate = 0.9999
min_learning_rate = 0.0001
num_steps = 5000
z_dim = 32
action_dim = humanoid_actions.shape[1]
batch_size = 32
seq_len = 10 #has to be smaller than the shortest episode!

optimizer = tf.keras.optimizers.Adam(learning_rate)

# Create model
mdnrnn = MDNRNN(latent_size=z_dim, action_size=action_dim)

# Custom training step
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        mdn_params, done_preds, _ = mdnrnn(inputs, training=True)
        mdn_loss = compute_mdn_loss(targets, mdn_params)
        done_loss = tf.keras.losses.binary_crossentropy(done_targets, done_preds)
        loss = mdn_loss + tf.reduce_mean(done_loss)
    
    gradients = tape.gradient(loss, mdnrnn.trainable_variables)
    optimizer.apply_gradients(zip(gradients, mdnrnn.trainable_variables))
    return loss

# Training loop
for step in range(num_steps):
    # Simulate or load a batch of data
    raw_z, raw_a = random_batch(episodes, batch_size, seq_len)

    # inputs: [z_t, a_t] — everything except last time step
    # targets: [z_{t+1}] — prediction target is next latent
    inputs = tf.concat([raw_z[:, :-1, :], raw_a[:, :-1, :]], axis=-1) #returns tensor type
    targets = tf.convert_to_tensor(raw_z[:, 1:, :], dtype=tf.float32) #teacher forcing + convert to tensor type

    # Training step
    loss = train_step(inputs, targets)

    # Learning rate decay (optional)
    lr = max(min_learning_rate, learning_rate * (decay_rate ** step))
    optimizer.learning_rate = lr

    # Logging
    if step % 20 == 0:
        print(f"Step {step}, Loss: {loss:.4f}, LR: {lr:.6f}")

NameError: in user code:

    File "/tmp/ipykernel_31576/3262445165.py", line 22, in train_step  *
        done_loss = tf.keras.losses.binary_crossentropy(done_targets, done_preds)

    NameError: name 'done_targets' is not defined


In [None]:
#NEXT STEP: need to have one-hot episode boundary dataset which we can use in random_batch.
#then update with GPT's code