# Setup Folders

In [1]:
import os

# Every path below is relative to the notebook's working directory.
project_folder = "./"
data_folder, output_folder = (
    os.path.join(project_folder, "data"),
    os.path.join(project_folder, "weights/vae"),
)
# output_folder is handed to the trainer as its weights folder; create it up front
# (exist_ok=True makes re-running this cell harmless).
os.makedirs(output_folder, exist_ok=True)

# Passed to ObservationsDataset.train_test_split as local_data_folder; None = not set.
local_data_folder = None

# Training Settings

In [2]:
# ---- Optimisation ----------------------------------------------------------
EPOCHS = 100
LR = 1e-3
BATCH_SIZE = 4_096

# ---- Data loading ----------------------------------------------------------
TRAIN_RATIO = 0.8            # passed as train_ratio to ObservationsDataset.train_test_split
NUM_PRELOAD_FILES = 10       # passed as num_preloaded_files
NUM_DATASET_WORKERS = 8      # passed as num_workers

# ---- Model dimensions ------------------------------------------------------
IMAGE_CHANNELS = 3           # ConvVAE image_channels
OBSERVATION_DIM = 64         # side length of the observations (sample shape logged below is 3x64x64)
HIDDEN_DIM = 1024            # ConvVAE h_dim
REPRESENTATION_DIM = 32      # ConvVAE z_dim

# ---- Logging / experiment tracking -----------------------------------------
TRAIN_IMAGE_LOG_INTERVAL = 100
WANDB_PROJECT = "world-models-paper"
WANDB_RUN_NAME = "vae"
LOG_LEVEL = "INFO"

# Setup

In [3]:
# Install the packages listed in requirements.txt into this kernel's environment
# (%pip targets the running kernel; -q keeps pip's output quiet).
%pip install -q -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [4]:
from itertools import islice

import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from src.datasets.observations_dataset import ObservationsDataset
from src.models.vae import ConvVAE
from src.training.early_stopping import EarlyStopping
from src.training.vae import ConvVaeTrainer
from src.utils.torch import get_device
from src.utils.logging import get_logger
from src.utils.secrets import get_secret

In [5]:
# Notebook-wide logger at the level chosen in the settings cell; it is passed to the
# dataset, device helper and trainer below.
logger = get_logger(LOG_LEVEL)

[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Logger initialized.


In [6]:
# Torch device the model is moved to (the recorded run selected mps:0, as logged below).
DEVICE = get_device(logger)

[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Using device: mps:0


# Load Dataset

In [7]:
# Split the observation files under data_folder into train/test datasets.
# NOTE(review): shuffling, preloading and worker semantics are defined by
# ObservationsDataset.train_test_split — the flags below are simply forwarded.
train_dataset, test_dataset = ObservationsDataset.train_test_split(
    data_folder,
    train_ratio=TRAIN_RATIO,
    shuffle_files=True,
    shuffle_file_samples=True,
    local_data_folder=local_data_folder,
    num_preloaded_files=NUM_PRELOAD_FILES,
    num_workers=NUM_DATASET_WORKERS,
    logger=logger,
)

In [8]:
# Number of samples per split; also recorded in the W&B config further down.
train_size, test_size = len(train_dataset), len(test_dataset)
logger.info(f"Train: {train_size}")
logger.info(f"Test: {test_size}")

[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Train: 1239760
[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Test: 304384


In [9]:
# Peek at a single training sample to check its shape (this run logged torch.Size([3, 64, 64])).
# NOTE(review): next() is called on the dataset object itself, so this may consume the sample
# from the training stream before the DataLoader iterates it — confirm against ObservationsDataset.
example_observation = next(train_dataset)
logger.info(example_observation.shape)

[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m torch.Size([3, 64, 64])


In [10]:
# Both splits are batched identically; every other DataLoader option keeps its default.
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (train_dataset, test_dataset)
)

In [11]:
# Batches per epoch for each split; these match the progress-bar totals during training.
train_batches, test_batches = len(train_dataloader), len(test_dataloader)
logger.info(f"Train batches: {train_batches}")
logger.info(f"Test batches: {test_batches}")

[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Train batches: 303
[32m2025-12-03 13:46:19[0m [1;30m[INFO][0m Test batches: 75


# Train

In [None]:
# Convolutional VAE on the selected device, with Adam over all of its parameters.
model = ConvVAE(
    image_channels=IMAGE_CHANNELS,
    h_dim=HIDDEN_DIM,
    z_dim=REPRESENTATION_DIM,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# NOTE(review): the meaning of tolerance / min_delta is defined by EarlyStopping
# (presumably patience in epochs and minimum loss improvement) — confirm there.
early_stopping = EarlyStopping(min_delta=0.01, tolerance=5)

In [13]:
# Weights & Biases settings handed to the trainer. The API key is fetched via get_secret
# instead of being hard-coded; "config" records this run's hyperparameters and data sizes.
wandb_setup = {
    "api_key": get_secret('wandbApiKey'),
    "project": WANDB_PROJECT,
    "run_name": WANDB_RUN_NAME,
    "config": {
        "epochs": EPOCHS,
        "batch_size_loader": BATCH_SIZE,
        "learning_rate": LR,
        "train_ratio": TRAIN_RATIO,
        # Input / model settings from the settings cell, so they are stored with the run
        # (OBSERVATION_DIM and IMAGE_CHANNELS were previously not recorded anywhere).
        "image_channels": IMAGE_CHANNELS,
        "observation_dim": OBSERVATION_DIM,
        "hidden_dim": HIDDEN_DIM,
        "representation_dim": REPRESENTATION_DIM,
        "architecture": "CONV-VAE",
        "train_dataset_size": train_size,
        "test_dataset_size": test_size,
        "train_batches": train_batches,
        "test_batches": test_batches,
        "preload_files": NUM_PRELOAD_FILES,
        "num_dataset_workers": NUM_DATASET_WORKERS,
        "train_image_log_interval": TRAIN_IMAGE_LOG_INTERVAL,
    }
}

In [14]:
# Assemble the trainer. With load_checkpoint=True the recorded run resumed from a
# checkpoint found in output_folder (see the log line below).
trainer = ConvVaeTrainer(
    # what is optimised, and how
    model=model,
    optimizer=optimizer,
    device=DEVICE,
    num_epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    max_norm=0.1,  # NOTE(review): presumably a gradient-clipping norm — confirm in ConvVaeTrainer
    early_stopper=early_stopping,
    # data
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    # persistence and logging
    weights_folder=output_folder,
    load_checkpoint=True,
    wandb_setup=wandb_setup,
    logger=logger,
    train_image_log_interval=TRAIN_IMAGE_LOG_INTERVAL,
)

[32m2025-12-03 13:46:20[0m [1;30m[INFO][0m Resuming training from: ./weights/vae/epoch_1.pth
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/henriqueschmitz/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mschhenrique[0m ([33mschhenrique-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Run the training loop configured above; epoch progress bars and loss logs appear below.
trainer.train()

Epoch:   0%|          | 0/100 [00:00<?, ?epoch/s]

Train Epoch 2:   0%|          | 0/303 [00:00<?, ?batch/s]

Test Epoch 2:   0%|          | 0/75 [00:00<?, ?batch/s]

[32m2025-12-03 14:07:54[0m [1;30m[INFO][0m Epoch 2 Loss: 17.5037
[32m2025-12-03 14:10:36[0m [1;30m[INFO][0m Epoch 2 Loss: 17.0334
[32m2025-12-03 14:30:07[0m [1;30m[INFO][0m Epoch 3 Loss: 17.2877
[32m2025-12-03 14:32:50[0m [1;30m[INFO][0m Epoch 3 Loss: 17.5538
[32m2025-12-03 14:56:54[0m [1;30m[INFO][0m Epoch 4 Loss: 17.1905
[32m2025-12-03 14:59:38[0m [1;30m[INFO][0m Epoch 4 Loss: 16.9457
[32m2025-12-03 15:23:18[0m [1;30m[INFO][0m Epoch 5 Loss: 17.1288
[32m2025-12-03 15:26:11[0m [1;30m[INFO][0m Epoch 5 Loss: 16.9387
[32m2025-12-03 15:49:21[0m [1;30m[INFO][0m Epoch 6 Loss: 17.2421
[32m2025-12-03 15:52:23[0m [1;30m[INFO][0m Epoch 6 Loss: 17.2847
[32m2025-12-03 16:16:57[0m [1;30m[INFO][0m Epoch 7 Loss: 17.0444
[32m2025-12-03 16:19:55[0m [1;30m[INFO][0m Epoch 7 Loss: 17.0579
[32m2025-12-03 16:42:28[0m [1;30m[INFO][0m Epoch 8 Loss: 17.0337
[32m2025-12-03 16:45:09[0m [1;30m[INFO][0m Epoch 8 Loss: 16.8271
[32m2025-12-03 17:08:05[0m [1;3

# Testing

In [None]:
def test_observation(model, observation):
    model.eval()
    with torch.no_grad():
        z, _, _ = model.encode(observation.unsqueeze(0))
        decoded = model.decode(z)
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(observation.permute(1, 2, 0).cpu().numpy())
    axes[0].set_title('Original Observation')
    axes[0].axis('off')
    axes[1].imshow(decoded.squeeze(0).permute(1, 2, 0).cpu().numpy())
    axes[1].set_title('Decoded Observation')
    axes[1].axis('off')
    plt.show()

In [None]:
# Visually compare the first few test observations with their reconstructions.
NUM_RECONSTRUCTION_EXAMPLES = 10
# islice stops after the requested count without drawing an extra sample from the dataset.
for observation in islice(iter(test_dataset), NUM_RECONSTRUCTION_EXAMPLES):
    test_observation(model, observation)

RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device=mps:0')  must be on the same device