In [1]:
import torch
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import gymnasium as gym
import numpy as np
from Models.VAE import VAE
from lightning.pytorch.loggers import WandbLogger
from pathlib import Path

In [2]:
def generate_trajectory(env_name, num_steps, repeat_action=1):
    # Initialize the environment
    env = gym.make(env_name)
    # Reset the environment to get the initial observation
    observation, _ = env.reset()
    
    # Initialize arrays to store the trajectory
    observations = np.zeros((num_steps,) + env.observation_space.shape, dtype=np.float32)
    actions = np.zeros((num_steps,) + env.action_space.shape, dtype=np.float32)
    rewards = np.zeros(num_steps, dtype=np.float32)
    dones = np.zeros(num_steps, dtype=bool)
    for t in range(num_steps):
        # Store the current observation
        observations[t, :] = observation

        if t % repeat_action == 0:
            # Choose a random action (for demonstration purposes)
            action = env.action_space.sample()
        
        # Store the chosen action
        actions[t] = action
        
        # Perform the action and observe the next state and reward
        observation, reward, done, _, _ = env.step(action)
        
        # Store the reward and done flag
        rewards[t] = reward
        dones[t] = done
        
        if done:
            # If the episode is done, break out of the loop
            break
    
    # Close the environment
    env.close()
    
    # Return the trajectory
    return observations[:t+1], actions[:t+1], rewards[:t+1], dones[:t+1]



In [3]:

%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML

def display_trajectory_as_video(video, fig_size=(8, 6), margin=0.01, interval=50):
    # np array with shape (frames, height, width, channels)

    fig = plt.figure(figsize=fig_size)
    plt.subplots_adjust(left=margin, right=(1-margin), top=(1-margin), bottom=margin)
    im = plt.imshow(video[0,:,:,:])

    plt.close() # this is required to not display the generated image

    def init():
        im.set_data(video[0,:,:,:])

    def animate(i):
        im.set_data(video[i,:,:,:])
        return im

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                                interval=interval)
    return HTML(anim.to_html5_video())


In [9]:
env_name = 'CarRacing-v2'
num_steps = 1200
envs_n = 10
trajs = {'observations': [], 'actions': [], 'rewards': [], 'dones': []}
skip_frame = 4
repeat_action = 200
start_frame = 15

for _ in range(envs_n):
    observations, actions, rewards, dones = generate_trajectory(env_name, num_steps, repeat_action=repeat_action)

    trajs['observations'].append(observations[start_frame::skip_frame])
    trajs['actions'].append(actions[start_frame::skip_frame])
    trajs['rewards'].append(rewards[start_frame::skip_frame])
    trajs['dones'].append(dones[start_frame::skip_frame])


In [16]:

observations, actions, rewards, dones = generate_trajectory(env_name, num_steps, repeat_action=repeat_action)

trajs['observations'][3] = observations[start_frame::skip_frame]
trajs['actions'][3] = actions[start_frame::skip_frame]
trajs['rewards'][3] = rewards[start_frame::skip_frame]
trajs['dones'][3] = dones[start_frame::skip_frame]


In [17]:
trajs['observations'][3].shape

(297, 96, 96, 3)

In [18]:

grid = np.concatenate([obs for obs in trajs['observations']], axis=2)

grid = grid.astype(np.int16)
display_trajectory_as_video(grid, fig_size=(12, 3), margin=0.05)


In [29]:

wandb_logger = WandbLogger(log_model="all")
vae_checkpoint_reference = "team-good-models/lightning_logs/model-pbf8ubmb:v17"
vae_dir = wandb_logger.download_artifact(vae_checkpoint_reference, artifact_type="model")
encoding_model = VAE.load_from_checkpoint(Path(vae_dir) / "model.ckpt")


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [30]:
trajs['observations'][4].shape


(297, 96, 96, 3)

In [31]:
from Utils.TransformerWrapper import Crop, TransformWrapper
from torchvision import transforms

transform = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Grayscale(), 
        transforms.Normalize((0.5,), (0.5,)), 
        Crop(bottom=-15),
        transforms.Resize((64, 64), antialias=True),
        ])
trans_obs = []
for obs in trajs['observations']:
    trans_x = torch.stack([transform(x/255.) for x in obs])
    trans_obs.append(trans_x)
recon_obs = torch.concat([encoding_model(obs)[0] for obs in trans_obs], dim=3)
trans_obs = torch.concat(trans_obs, dim=3)

grid = torch.concat([trans_obs, recon_obs,], dim=-2).detach().numpy()
grid = np.moveaxis(grid, -3, -1)

display_trajectory_as_video(grid, fig_size=(12, 3), margin=0.05)


In [None]:
from Models.MDNRNN import MDNRNN


mdnrnn_checkpoint_reference = "team-good-models/model-registry/WorldModelMDNRNN:latest"
mdnrnn_dir = wandb_logger.download_artifact(mdnrnn_checkpoint_reference, artifact_type="model")
mdnrnn = MDNRNN.load_from_checkpoint(Path(mdnrnn_dir) / "model.ckpt", strict=False)


[34m[1mwandb[0m:   1 of 1 files downloaded.  
/home/jukebox/miniconda3/envs/skrl/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:188: Found keys that are not in the model state dict but in the checkpoint: ['encoding.conv1.weight', 'encoding.conv1.bias', 'encoding.conv2.weight', 'encoding.conv2.bias', 'encoding.conv3.weight', 'encoding.conv3.bias', 'encoding.conv4.weight', 'encoding.conv4.bias', 'encoding.fc_mu.weight', 'encoding.fc_mu.bias', 'encoding.fc_logsigma.weight', 'encoding.fc_logsigma.bias']


In [None]:
from Utils.TransformerWrapper import Crop, TransformWrapper
from torchvision import transforms

def predict_future(batch):
    _, _, latents = encoding_model.encoder(batch)
    hidden = mdnrnn.initial_state(batch_size=1)

    future_latent = []
    for action, latent in zip(actions, latents):            
        action = torch.Tensor(action).unsqueeze(dim=0)
        latent = latent.unsqueeze(dim=0)
        mu, sigma, logpi, r, d, hidden = mdnrnn.cell(action, latent, hidden)
        future_latent.append(mdnrnn.cell.sample(mu, sigma, logpi))
    
    future_latent = torch.concat(future_latent, dim=0)

    return encoding_model.decoder(future_latent)


trans_obs = []
for obs in trajs['observations']:
    trans_x = torch.stack([transform(x/255.) for x in obs])
    trans_obs.append(trans_x)
recon_obs = torch.concat([encoding_model(obs)[0] for obs in trans_obs], dim=3)
future_states = torch.concat([predict_future(obs) for obs in trans_obs], dim=3)
trans_obs = torch.concat(trans_obs, dim=3)


grid = torch.concat([trans_obs, recon_obs, future_states], dim=-2).detach().numpy()
grid = np.moveaxis(grid, -3, -1)

display_trajectory_as_video(grid, fig_size=(12, 6), margin=0.05, interval=100)
