In [None]:
import torch
from torch import nn
import torchvision.transforms as T
from IPython.display import display, HTML
from PIL import Image
import numpy as np
from pathlib import Path

from rollout_dataset import RolloutDataset,RolloutDataloader,Episode
from latent_dataset import LatentDataset,LatentDataloader,LatentEpisode

from vision import ConvVAE,VisionTrainer
from memory import MDN_RNN,MemoryTrainer
from controller import Controller, ControllerTrainer


Initial solution size: 867
(5_w,11)-aCMA-ES (mu_w=3.4,w_1=42%) in dimension 867 (seed=645381, Fri Dec  6 17:02:22 2024)
Iteration: 0


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
def create_dataset_gif(
    episode: Episode,
    save_path: Path = Path("media/rollout_dataset.gif"),
    duration: int = 200,
    scale_factor: int = 1,
):
    """Save the raw observation frames of ``episode`` as an animated GIF.

    Args:
        episode: Rollout episode whose ``observations`` are iterated along the
            first dimension, one image per step (presumably ``(T, C, H, W)`` —
            TODO confirm against ``RolloutDataset``).
        save_path: Destination of the GIF; missing parent directories are created.
        duration: Display time of each GIF frame in milliseconds.
        scale_factor: Each frame is resized to ``64 * scale_factor`` pixels square.

    Raises:
        ValueError: If the episode contains no observations.
    """
    # Frames are only converted to PIL images, so they never need to visit the GPU.
    frames = episode.observations.cpu()
    if frames.shape[0] == 0:
        raise ValueError("Cannot create a GIF from an episode with no observations.")

    frame_size = 64 * scale_factor
    resize = T.Resize((frame_size, frame_size))
    to_pil = T.ToPILImage()
    # convert("RGB") yields the same pixels as pasting onto a same-sized black RGB canvas.
    images = [resize(to_pil(frame)).convert("RGB") for frame in frames]

    save_path.parent.mkdir(parents=True, exist_ok=True)
    images[0].save(
        save_path,
        save_all=True,
        append_images=images[1:],
        duration=duration,
        loop=0,
    )
    print(f"Dataset GIF saved to {save_path}")
def create_vision_gif(
    episode: Episode,
    vision: ConvVAE,
    save_path: Path = Path("media/vae_reconstruction.gif"),
    duration: int = 60,
    scale_factor: int = 1,
    spacing: int = 1,
):
    """Save a GIF showing each observation next to its VAE reconstruction.

    Each GIF frame has two panels, left to right: the observation at step ``t``
    and the image decoded by ``vision`` from that observation's latent.

    Args:
        episode: Rollout episode whose ``observations`` are indexed one image per
            step along the first dimension.
        vision: VAE providing ``get_batched_latents`` and ``decoder``.
        save_path: Destination of the GIF; missing parent directories are created.
        duration: Display time of each GIF frame in milliseconds.
        scale_factor: Each panel is resized to ``64 * scale_factor`` pixels square.
        spacing: Width in pixels of the black gap between the two panels.

    Raises:
        ValueError: If the episode contains no observations.
    """
    if episode.observations.shape[0] == 0:
        raise ValueError("Cannot create a GIF from an episode with no observations.")

    observations = episode.observations.unsqueeze(0).to(DEVICE)
    # Visualisation only: skip autograd bookkeeping for the encode/decode pass.
    with torch.no_grad():
        latents = vision.get_batched_latents(observations)
        vae_reconstructions = vision.decoder(latents.squeeze(0))

    frame_size = 64 * scale_factor
    resize = T.Resize((frame_size, frame_size))
    to_pil = T.ToPILImage()
    # Two panels separated by one gap; no trailing black strip on the right edge.
    canvas_size = (2 * frame_size + spacing, frame_size)

    images = []
    for t in range(vae_reconstructions.shape[0]):
        original_img = resize(to_pil(observations[0, t].cpu()))
        vae_img = resize(to_pil(vae_reconstructions[t].cpu()))
        combined_img = Image.new("RGB", canvas_size, (0, 0, 0))
        combined_img.paste(original_img, (0, 0))
        combined_img.paste(vae_img, (frame_size + spacing, 0))
        images.append(combined_img)

    save_path.parent.mkdir(parents=True, exist_ok=True)
    images[0].save(
        save_path,
        save_all=True,
        append_images=images[1:],
        duration=duration,
        loop=0,
    )
    print(f"Vae reconstruction GIF saved to {save_path}")
def create_memory_gif(
    episode: Episode,
    vision: ConvVAE,
    memory: MDN_RNN,
    save_path: Path = Path("media/vision_memory_reconstruction.gif"),
    display_gif_in_notebook: bool = False,
    duration: int = 60,
    scale_factor: int = 1,
    spacing: int = 1,
):
    """Save a GIF comparing real frames, VAE reconstructions and MDN-RNN predictions.

    Each GIF frame has three panels, left to right:
    1. the observation at step ``t + 1``;
    2. its VAE reconstruction;
    3. the decoded latent that ``memory`` sampled after being fed the latent and
       action of step ``t``.

    The prediction is therefore shown next to the frame it is predicting.

    Args:
        episode: Rollout episode providing ``observations`` and ``actions``,
            indexed one entry per step along the first dimension.
        vision: VAE providing ``get_batched_latents`` and ``decoder``.
        memory: MDN-RNN called once per transition. The hidden/cell state it
            returns is passed into the next call.
        save_path: Destination of the GIF; missing parent directories are created.
        display_gif_in_notebook: If True, also render the saved GIF inline in
            the notebook output.
        duration: Display time of each GIF frame in milliseconds.
        scale_factor: Each panel is resized to ``64 * scale_factor`` pixels square.
        spacing: Width in pixels of the black gap between panels.

    Raises:
        ValueError: If the episode has fewer than two observations, i.e. there
            is no transition to predict.
    """
    if episode.observations.shape[0] < 2:
        raise ValueError(
            "create_memory_gif needs an episode with at least two observations."
        )

    observations = episode.observations.unsqueeze(0).to(DEVICE)
    actions = episode.actions.unsqueeze(0).to(DEVICE)

    # Visualisation only: avoid building an autograd graph across the whole rollout.
    with torch.no_grad():
        latents = vision.get_batched_latents(observations)

        # The first call receives (None, None), exactly as before. Afterwards, the
        # state returned by the previous step is fed back so the RNN sees the episode
        # history. Previously None was passed at every step and the returned state
        # was thrown away.
        hidden_state, cell_state = None, None
        predicted_latents = []
        # The loop stops one step early: the prediction made from step t targets t + 1.
        for t in range(latents.shape[1] - 1):
            pi, mu, sigma, hidden_state, cell_state = memory(
                latents[:, t, :], actions[:, t, :], hidden_state, cell_state
            )
            predicted_latents.append(memory.sample_latent(pi, mu, sigma))
        predicted_latents = torch.stack(predicted_latents, dim=1)

        vae_reconstructions = vision.decoder(latents.squeeze(0))
        memory_reconstructions = vision.decoder(predicted_latents.squeeze(0))

    frame_size = 64 * scale_factor
    resize = T.Resize((frame_size, frame_size))
    to_pil = T.ToPILImage()
    # Three panels separated by two gaps; no trailing black strip on the right edge.
    canvas_size = (3 * frame_size + 2 * spacing, frame_size)

    images = []
    for t in range(predicted_latents.shape[1]):
        # Pair prediction t with its target step t + 1.
        original_img = resize(to_pil(observations[0, t + 1].cpu()))
        vision_img = resize(to_pil(vae_reconstructions[t + 1].cpu()))
        memory_img = resize(to_pil(memory_reconstructions[t].cpu()))

        combined_img = Image.new("RGB", canvas_size, (0, 0, 0))
        combined_img.paste(original_img, (0, 0))
        combined_img.paste(vision_img, (frame_size + spacing, 0))
        combined_img.paste(memory_img, (2 * (frame_size + spacing), 0))
        images.append(combined_img)

    save_path.parent.mkdir(parents=True, exist_ok=True)
    images[0].save(
        save_path,
        save_all=True,
        append_images=images[1:],
        duration=duration,
        loop=0,
    )
    print(f"VAE and Memory reconstruction GIF saved to {save_path}")

    if display_gif_in_notebook:
        import base64  # local: only needed for inline display

        # Embed the bytes so the GIF renders regardless of the notebook's working dir.
        gif_b64 = base64.b64encode(save_path.read_bytes()).decode("ascii")
        display(HTML(f'<img src="data:image/gif;base64,{gif_b64}"/>'))

In [1]:
# Generate the rollouts once, then split the episodes 50/30/20 into
# train / test / validation sets that are re-loaded from their saved paths.
SPLIT_SEED = 42  # fixed so the split is identical on every Restart & Run All

rollout_dataset = RolloutDataset(
    "create",
    num_rollouts=5000,
    max_steps=500,
)
train_episodes, test_episodes, val_episodes = torch.utils.data.random_split(
    rollout_dataset,
    [0.5, 0.3, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
training_dataset, test_dataset, val_dataset = (
    RolloutDataset(
        "from",
        episodes=[rollout_dataset.episodes_paths[idx] for idx in subset.indices],
    )
    for subset in (train_episodes, test_episodes, val_episodes)
)
train_dataloader = RolloutDataloader(training_dataset, 64)
test_dataloader = RolloutDataloader(test_dataset, 64)
val_dataloader = RolloutDataloader(val_dataset, 64)

episode = rollout_dataset[0]
create_dataset_gif(episode)


NameError: name 'RolloutDataset' is not defined

In [5]:
# Train the VAE ("vision") on raw rollout frames, then visualise its reconstructions.
vision = ConvVAE().to(DEVICE)
vision_trainer = VisionTrainer()
vision_optimizer = torch.optim.Adam(vision.parameters())
vision_trainer.train(
    vision,
    train_dataloader,
    test_dataloader,
    vision_optimizer,
    epochs=1,
)

episode = rollout_dataset[0]
create_vision_gif(episode, vision)

Epoch 1/1 | Train Loss: 1.8611 | Test Loss: 3.0302
Model saved to models/vision.pt
Vae reconstruction GIF saved to vae_reconstruction.gif


In [None]:
# Encode every rollout split into latent space with the trained VAE, then
# train the MDN-RNN ("memory") on the latent sequences.
latent_training_set, latent_test_set, latent_val_set = (
    LatentDataset(split, vision, "create")
    for split in (training_dataset, test_dataset, val_dataset)
)

# Distinct names: reusing train_/test_dataloader would clobber the rollout
# loaders that the vision cell trains on. Previously the val loader was also
# assigned to test_dataloader, so memory was evaluated on the val split.
latent_train_dataloader = LatentDataloader(latent_training_set, 64)
latent_test_dataloader = LatentDataloader(latent_test_set, 64)
latent_val_dataloader = LatentDataloader(latent_val_set, 64)

memory = MDN_RNN().to(DEVICE)
memory_trainer = MemoryTrainer()
memory_trainer.train(
    memory,
    latent_train_dataloader,
    latent_test_dataloader,
    torch.optim.Adam(memory.parameters()),
    save_path=Path("models/memory_continuos.pt"),
)

episode = rollout_dataset[0]
create_memory_gif(episode, vision, memory)

In [None]:
# Controller search runs on the CPU: load the pretrained world model
# (vision + memory) and evolve a freshly initialised controller.
CONTROLLER_DEVICE = "cpu"

vision = ConvVAE.from_pretrained().to(CONTROLLER_DEVICE)
memory = MDN_RNN.from_pretrained().to(CONTROLLER_DEVICE)
controller = Controller().to(CONTROLLER_DEVICE)

controller_trainer = ControllerTrainer(
    controller,
    vision,
    memory,
    population_size=11,
)
controller_trainer.train(3)