In [18]:
import torch
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import gymnasium as gym
import numpy as np
from Models.VAE import VAE
from lightning.pytorch.loggers import WandbLogger
from pathlib import Path

In [19]:
def generate_trajectory(env_name, num_steps, repeat_action=1, seed=None):
    """Roll out a random policy in a Gymnasium environment and record it.

    A fresh environment is created with ``gym.make(env_name)``, reset, and
    stepped with actions drawn from ``env.action_space.sample()``. A new action
    is sampled every ``repeat_action`` steps and re-used in between. The rollout
    stops after ``num_steps`` steps or as soon as the episode ends (terminated
    *or* truncated), whichever comes first.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        num_steps: Maximum number of environment steps to record.
        repeat_action: Number of consecutive steps each sampled action is held
            for. Must be >= 1.
        seed: Optional seed for ``env.reset`` and the action space so the
            rollout is reproducible. ``None`` keeps the unseeded behaviour.

    Returns:
        Tuple ``(observations, actions, rewards, dones)`` of NumPy arrays, each
        with one leading entry per recorded step. ``observations[t]`` is the
        observation seen *before* step ``t``; ``dones[t]`` is True when step
        ``t`` ended the episode (terminated or truncated).

    Raises:
        ValueError: If ``repeat_action`` is < 1.
    """
    if repeat_action < 1:
        raise ValueError(f"repeat_action must be >= 1, got {repeat_action}")

    env = gym.make(env_name)
    try:
        observation, _ = env.reset(seed=seed)
        if seed is not None:
            # Seed the action space as well; otherwise sample() stays unseeded.
            env.action_space.seed(seed)

        observations = np.zeros((num_steps,) + env.observation_space.shape, dtype=np.float32)
        actions = np.zeros((num_steps,) + env.action_space.shape, dtype=np.float32)
        rewards = np.zeros(num_steps, dtype=np.float32)
        dones = np.zeros(num_steps, dtype=bool)

        # Tracked explicitly so num_steps == 0 (loop never runs) is well-defined.
        n_recorded = 0
        action = None
        for t in range(num_steps):
            # `[t]` rather than `[t, :]` so scalar observation spaces also work.
            observations[t] = observation

            if t % repeat_action == 0:
                # Random policy (for demonstration purposes).
                action = env.action_space.sample()
            actions[t] = action

            # Gymnasium returns (obs, reward, terminated, truncated, info); an
            # episode is over on either flag, so both must be honoured.
            observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            rewards[t] = reward
            dones[t] = done
            n_recorded = t + 1

            if done:
                break
    finally:
        # Release the environment even if a step raises.
        env.close()

    return (observations[:n_recorded], actions[:n_recorded],
            rewards[:n_recorded], dones[:n_recorded])



In [32]:

%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML

def display_trajectory_as_video(observations, fig_size=(8, 6), interval=50):
    """Render a stack of frames as an inline HTML5 video.

    Args:
        observations: Array of frames shaped ``(frames, height, width, channels)``
            (RGB) or ``(frames, height, width)`` (grayscale), on a 0..255 scale.
            Values outside that range are clipped.
        fig_size: Matplotlib figure size in inches.
        interval: Delay between frames in milliseconds.

    Returns:
        ``IPython.display.HTML`` wrapping the encoded video. Encoding goes
        through ``FuncAnimation.to_html5_video``, which needs a movie writer
        (e.g. ffmpeg) available to matplotlib.

    Raises:
        ValueError: If ``observations`` is not a non-empty 3-D/4-D frame stack.
    """
    video = np.asarray(observations)
    if video.ndim not in (3, 4) or video.shape[0] == 0:
        raise ValueError(
            f"expected a non-empty (frames, H, W[, C]) array, got shape {video.shape}")
    # imshow interprets integer image data on 0..255; clip + uint8 makes that explicit.
    video = np.clip(video, 0, 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_axis_off()
    # Grayscale frames need a colormap and a fixed 0..255 range so later frames
    # are not rescaled by the first frame's min/max; RGB frames need neither.
    gray_kwargs = {"cmap": "gray", "vmin": 0, "vmax": 255} if video.ndim == 3 else {}
    im = ax.imshow(video[0], **gray_kwargs)

    # Close the figure so Jupyter does not also render the static first frame.
    plt.close(fig)

    def init():
        im.set_data(video[0])
        return (im,)

    def animate(i):
        im.set_data(video[i])
        return (im,)

    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=video.shape[0], interval=interval)
    return HTML(anim.to_html5_video())


In [33]:
# Roll out `envs_n` random-policy episodes and gather each per-step field into a
# list holding one array per episode. The names `observations`, `actions`,
# `rewards` and `dones` are left bound to the last episode's arrays.
env_name = 'CarRacing-v2'
num_steps = 300
envs_n = 10
trajs = {key: [] for key in ('observations', 'actions', 'rewards', 'dones')}

for _ in range(envs_n):
    # repeat_action=100: each freshly sampled action is held for 100 steps.
    observations, actions, rewards, dones = generate_trajectory(
        env_name, num_steps, repeat_action=100)
    # dict preserves insertion order, so values() lines up with the tuple below.
    for field_list, field in zip(trajs.values(), (observations, actions, rewards, dones)):
        field_list.append(field)


In [34]:

# Tile all trajectories side by side along axis 2 (frame width for
# (frames, height, width, channels) arrays). Episodes may end early, leaving
# trajectories with different frame counts, and np.concatenate needs them
# equal, so crop every trajectory to the shortest one first.
min_len = min(len(obs) for obs in trajs['observations'])
grid = np.concatenate([obs[:min_len] for obs in trajs['observations']], axis=2)

display_trajectory_as_video(grid)


In [None]:

# Download the registered VAE checkpoint from the W&B model registry and load it.
wandb_logger = WandbLogger(log_model="all")
vae_checkpoint_reference = "team-good-models/model-registry/WorldModelVAE:v1"
vae_dir = wandb_logger.download_artifact(vae_checkpoint_reference, artifact_type="model")
encoding_model = VAE.load_from_checkpoint(Path(vae_dir) / "model.ckpt")
# The model is only used for inference below. Switch it to eval mode so any
# training-only layers (dropout, batch-norm statistic updates) behave
# deterministically. Presumably VAE is a LightningModule (it has
# load_from_checkpoint), so eval() is available. The trailing ';' suppresses
# the module repr.
encoding_model.eval();


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [None]:
from Utils.TransformerWrapper import TransformWrapper


# Apply the project's observation transform to each trajectory, then run the VAE.
trans_obs = [TransformWrapper.transform(obs) for obs in trajs['observations']]
# Inference only: no_grad avoids building an autograd graph for every batch of
# frames, and its outputs do not require grad, which any later NumPy
# conversion needs.
with torch.no_grad():
    recon_obs = [encoding_model(obs) for obs in trans_obs]

# NOTE(review): `observations` here is whatever the trajectory-collection loop
# left bound (only the last trajectory). `recon_obs` is a list with one model
# output per trajectory, and its type/layout is not visible in this notebook.
# This concatenation therefore likely fails or pairs the wrong data.
# TODO: convert each reconstruction back to (frames, H, W, C) and pair it with
# its matching entry of trajs['observations'].
grid = np.concatenate([observations, recon_obs, ], 2)

In [None]:
import cv2

# NOTE(review): `img_array` is not defined in any cell visible here. It is
# presumably a sequence of RGB frames shaped (frames, H, W, 3), e.g. `grid`.
# Confirm this; otherwise this cell raises NameError on a fresh kernel.
frames = np.asarray(img_array)
if frames.ndim != 4 or len(frames) == 0:
    raise ValueError(f"expected non-empty (frames, H, W, 3) RGB frames, got shape {frames.shape}")

# cv2.VideoWriter takes frameSize as (width, height), and every written frame
# must match it. Derive it from the data instead of hard-coding (96, 96),
# which would not match the wider tiled grids produced above.
height, width = frames.shape[1:3]
size = (width, height)

# 'mp4v' is the FourCC that matches an .mp4 container ('DIVX' targets AVI).
out = cv2.VideoWriter('project_brown.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
try:
    for frame in frames:
        # VideoWriter expects 8-bit frames. Float frames (generate_trajectory
        # stores float32) are assumed to be on a 0..255 scale.
        frame_u8 = np.clip(frame, 0, 255).astype(np.uint8)
        # OpenCV writes BGR, so convert from RGB before writing.
        bgr_frame = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2BGR)
        out.write(bgr_frame)
finally:
    # Always finalize the container, even if a frame fails to convert.
    out.release()