# Setup Folders

In [1]:
import os

project_folder = "./"
data_folder = os.path.join(project_folder, "data")
vae_folder = os.path.join(project_folder, "weights/vae")
output_folder = os.path.join(project_folder, "weights/worldmodel")
os.makedirs(output_folder, exist_ok=True)

local_data_folder = None

# Training Settings

In [2]:
EPOCHS = 100
LR = 1e-3
TRAIN_RATIO = 0.8
BATCH_SIZE = 2
SEQ_LEN = 64
NUM_PRELOAD_FILES = 10
NUM_DATASET_WORKERS = 8


IMAGE_CHANNELS = 3
OBSERVATION_DIM = 64
OBSERVATION_REPRESENTATION_DIM = 32
VAE_HIDDEN_DIM = 1024
INPUT_STATE_DIM = 4
HIDDEN_DIM = 256
OUTPUT_STATE_DIM = 1
RNN_INPUT_DIM = OBSERVATION_REPRESENTATION_DIM + INPUT_STATE_DIM
RNN_OUTPUT_DIM = OBSERVATION_REPRESENTATION_DIM + OUTPUT_STATE_DIM

WANDB_PROJECT = "world-models-paper"
WANDB_RUN_NAME = "worldmodel"

LOG_LEVEL = "INFO"

# Setup

In [3]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.datasets.simulation_steps_dataset import SimulationStepsDataset
from src.models.vae import ConvVAE
from src.models.worldmodel import MdnRnn
from src.training.early_stopping import EarlyStopping
from src.training.worldmodel import WorldModelTrainer
from src.utils.torch import get_device
from src.utils.logging import get_logger
from src.utils.secrets import get_secret

In [4]:
logger = get_logger()

[32m2025-12-03 18:44:44[0m [1;30m[INFO][0m Logger initialized.


In [5]:
DEVICE = get_device()

[32m2025-12-03 18:44:44[0m [1;30m[INFO][0m Logger initialized.
[32m2025-12-03 18:44:44[0m [1;30m[INFO][0m Using device: mps:0


# Load Dataset

In [6]:
train_dataset, test_dataset = SimulationStepsDataset.train_test_split(data_folder,
                                                                      local_data_folder=local_data_folder,
                                                                      num_preloaded_files=NUM_PRELOAD_FILES,
                                                                      num_workers=NUM_DATASET_WORKERS,
                                                                      train_ratio=TRAIN_RATIO,
                                                                      shuffle_files=True,
                                                                      kwargs={"sequence_length": SEQ_LEN})

In [7]:
train_size = len(train_dataset)
test_size = len(test_dataset)
logger.info(f"Train: {train_size}")
logger.info(f"Test: {test_size}")

[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m Train: 18878
[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m Test: 4491


In [8]:
example = next(train_dataset)
example_observation = example[0][0].unflatten(0, (3, OBSERVATION_DIM, OBSERVATION_DIM))
example_input_state = example[0][1]
example_output_state = example[0][2]
logger.info(example_observation.shape)
logger.info(example_input_state.shape)
logger.info(example_output_state.shape)

[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m torch.Size([3, 64, 64])
[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m torch.Size([12288])
[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m torch.Size([12288])


In [9]:
def custom_collate_fn(batch):
    observations = torch.stack([item[0] for item in batch])
    input_states = torch.stack([item[1] for item in batch])
    output_states = torch.stack([item[2] for item in batch])
    return observations, input_states, output_states

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)

In [11]:
train_batches = len(train_dataloader)
test_batches = len(test_dataloader)
logger.info(f"Train batches: {train_batches}")
logger.info(f"Test batches: {test_batches}")

[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m Train batches: 9439
[32m2025-12-03 18:44:45[0m [1;30m[INFO][0m Test batches: 2246


# Train

In [12]:
vae = ConvVAE(image_channels=IMAGE_CHANNELS, h_dim=VAE_HIDDEN_DIM, z_dim=OBSERVATION_REPRESENTATION_DIM).to(DEVICE)
vae.load_state_dict(torch.load(os.path.join(vae_folder, f"model.pth"), map_location=DEVICE))
vae.eval()
vae.requires_grad_(False) # We want to train only the world model
for param in vae.parameters():
    param.requires_grad = False

In [None]:
model = MdnRnn(input_size=RNN_INPUT_DIM, hidden_size=HIDDEN_DIM, output_size=RNN_OUTPUT_DIM).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.01)

In [14]:
wandb_setup = {
    "api_key": get_secret('wandbApiKey'),
    "project": WANDB_PROJECT,
    "run_name": WANDB_RUN_NAME,
    "config": {
        "epochs": EPOCHS,
        "batch_size_loader": BATCH_SIZE,
        "learning_rate": LR,
        "train_ratio": TRAIN_RATIO,
        "hidden_dim": HIDDEN_DIM,
        "observation_representation_dim": OBSERVATION_REPRESENTATION_DIM,
        "architecture": "MDN-RNN",
        "train_dataset_size": train_size,
        "test_dataset_size": test_size,
        "train_batches": train_batches,
        "test_batches": test_batches,
        "preload_files": NUM_PRELOAD_FILES,
        "num_dataset_workers": NUM_DATASET_WORKERS
    }
}

In [15]:
trainer = WorldModelTrainer(model=model,
                            weights_folder=output_folder,
                            train_dataloader=train_dataloader,
                            optimizer=optimizer,
                            num_epochs=EPOCHS,
                            batch_size=BATCH_SIZE,
                            load_checkpoint=True,
                            max_norm=0.1,
                            device=DEVICE,
                            test_dataloader=test_dataloader,
                            early_stopper=early_stopping,
                            wandb_setup=wandb_setup,
                            logger=logger,
                            vae=vae,
                            observation_dim=OBSERVATION_DIM,
                            seq_len=SEQ_LEN)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/henriqueschmitz/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mschhenrique[0m ([33mschhenrique-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
trainer.train()

Epoch:   0%|          | 0/100 [00:00<?, ?epoch/s]

Train Epoch 1:   0%|          | 0/9439 [00:00<?, ?batch/s]

Test Epoch 1:   0%|          | 0/2246 [00:00<?, ?batch/s]

[32m2025-12-03 18:49:06[0m [1;30m[INFO][0m Epoch 1 Loss: 19.1988
[32m2025-12-03 18:49:53[0m [1;30m[INFO][0m Epoch 1 Loss: 19.0299
[32m2025-12-03 18:54:12[0m [1;30m[INFO][0m Epoch 2 Loss: 18.3039
[32m2025-12-03 18:54:57[0m [1;30m[INFO][0m Epoch 2 Loss: 18.0628
[32m2025-12-03 18:59:13[0m [1;30m[INFO][0m Epoch 3 Loss: 18.0085
[32m2025-12-03 18:59:54[0m [1;30m[INFO][0m Epoch 3 Loss: 17.6738
[32m2025-12-03 19:04:09[0m [1;30m[INFO][0m Epoch 4 Loss: 17.8440
[32m2025-12-03 19:04:54[0m [1;30m[INFO][0m Epoch 4 Loss: 17.9259
[32m2025-12-03 19:09:11[0m [1;30m[INFO][0m Epoch 5 Loss: 17.7236
[32m2025-12-03 19:09:57[0m [1;30m[INFO][0m Epoch 5 Loss: 17.3269
[32m2025-12-03 19:14:15[0m [1;30m[INFO][0m Epoch 6 Loss: 17.6148
[32m2025-12-03 19:15:00[0m [1;30m[INFO][0m Epoch 6 Loss: 18.0420
[32m2025-12-03 19:19:18[0m [1;30m[INFO][0m Epoch 7 Loss: 17.5254
[32m2025-12-03 19:20:04[0m [1;30m[INFO][0m Epoch 7 Loss: 17.4629
[32m2025-12-03 19:24:20[0m [1;3

# Testing

In [None]:
def predict_next_observation(model, vae, observation, input_state, hidden):
    model.eval()
    with torch.no_grad():
        observation_encoded, _, _ = vae.encode(observation.unsqueeze(0))
        input = torch.cat((observation_encoded, input_state.unsqueeze(0)), dim=1)
        pi, sigma, mu, new_hidden = model(input.unsqueeze(0), hidden)
        categorical = torch.distributions.Categorical(pi)
        k = categorical.sample().item()
        chosen_mu = mu[:, :, k, :]
        chosen_sigma = sigma[:, :, k, :]
        predicted_state_encoded = torch.distributions.Normal(chosen_mu, chosen_sigma).sample()
        predicted_observation_encoded = predicted_state_encoded[:, :, :OBSERVATION_REPRESENTATION_DIM]
        predicted_reward = predicted_state_encoded[:, :, -1]
        predicted_observation = vae.decode(predicted_observation_encoded.squeeze(1))
        return predicted_observation.squeeze(0), predicted_reward, new_hidden

In [None]:
def show_observation_comparison(actual, predicted):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(actual.permute(1, 2, 0).cpu().numpy())
    axes[0].set_title('Actual Observation')
    axes[0].axis('off')
    axes[1].imshow(predicted.permute(1, 2, 0).cpu().numpy())
    axes[1].set_title('Predicted Observation')
    axes[1].axis('off')
    plt.show()

In [None]:
with torch.no_grad():
    h0 = torch.zeros(1, 1, 256).to(DEVICE)
    c0 = torch.zeros(1, 1, 256).to(DEVICE)
    hidden = (h0, c0)
    predicted_observation = torch.zeros_like(example_observation)
    predicted_reward = 0
    observation_for_prediction = None

    example_batch = next(iter(test_dataloader))
    example_observation_sequences, example_input_state_sequences, example_output_state_sequences = example_batch
    example_observations = example_observation_sequences[0].to(DEVICE)
    example_input_states = example_input_state_sequences[0].to(DEVICE)

    for step in range(len(example_observations)):
        observation = example_observations[step]
        input_state = example_input_states[step]
        input_state[3] = predicted_reward
        if observation_for_prediction is None:
            observation_for_prediction = observation
        show_observation_comparison(observation, predicted_observation)
        time.sleep(0.1)
        predicted_observation, predicted_reward, hidden = predict_next_observation(model, vae, observation_for_prediction, input_state, hidden)
        observation_for_prediction = predicted_observation

KeyboardInterrupt: 

In [None]:
# import matplotlib.pyplot as plt
# from matplotlib import animation
# from IPython.display import HTML

# # Setup for prediction
# h0 = torch.zeros(1, 1, 256).to(DEVICE)
# c0 = torch.zeros(1, 1, 256).to(DEVICE)
# hidden = (h0, c0)
# predicted_observation = torch.zeros_like(example_observation)
# observation_for_prediction = None


# frames = []

# fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# plt.close()

# def draw_frame(step_data):
#     step, data = step_data
#     global observation_for_prediction, predicted_observation, hidden # Access globals or pass them in class
#     im1 = axes[0].imshow(observation.permute(1, 2, 0).cpu().numpy())
#     axes[0].set_title('Actual Observation')
#     axes[0].axis('off')
#     im2 = axes[1].imshow(predicted_observation.permute(1, 2, 0).cpu().numpy())
#     axes[1].set_title('Predicted Observation')
#     axes[1].axis('off')
#     return [im1, im2]

# # Let's use ArtistAnimation, it's often simpler for pre-computed loops or simpler logic.
# # Actually, capturing the plot as an image array for `imageio` might be robust but requires `imageio`.
# # Let's stick to `matplotlib.animation.ArtistAnimation`.

# # Re-thinking: The loop needs to run to generate predictions sequentially.
# # So, run the loop, collect the two images (actual, predicted) at each step as a list of artists [im1, im2], then animate.

# frames = []
# fig, axes = plt.subplots(1, 2, figsize=(10, 5))


# for step, data in enumerate(iter(test_dataset)):
#     if step >= SEQ_LEN:
#         break
#     observation, input_state, output_state = data
#     observation = observation.unsqueeze(0).unflatten(1, (3, OBSERVATION_DIM, OBSERVATION_DIM)).squeeze(0)
#     observation = observation.to(DEVICE)
#     input_state = input_state.to(DEVICE)

#     if observation_for_prediction is None:
#         observation_for_prediction = observation
#     im_actual = axes[0].imshow(observation.permute(1, 2, 0).cpu().numpy(), animated=True)
#     im_pred = axes[1].imshow(predicted_observation.permute(1, 2, 0).cpu().numpy(), animated=True)
#     frames.append([im_actual, im_pred])
#     predicted_observation, hidden = predict_next_observation(model, vae, observation_for_prediction, input_state, hidden)
#     observation_for_prediction = predicted_observation



In [None]:
# from IPython.display import HTML

# axes[0].set_title('Actual Observation')
# axes[0].axis('off')
# axes[1].set_title('Predicted Observation')
# axes[1].axis('off')
# anim = animation.ArtistAnimation(fig, frames, interval=100, blit=True, repeat_delay=1000)
# HTML(anim.to_jshtml())