In [28]:
import time
from tensordict import TensorDict
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.transforms import ToTensorImage, TransformedEnv
from matplotlib import pyplot as plt
from pathlib import Path
from tqdm import tqdm
from imageio.v3 import imread
from imageio import get_writer

In [29]:
def render_env(env_info: TensorDict, show_fig: bool = True, save_fig: bool = False, save_dir: Path = Path("outputs")) -> None:
    """Draw a single environment frame with its reward shown in the title.

    Args:
        env_info: Container queried with ``get("pixels")`` and ``get("reward")``.
            The pixel entry is permuted with ``(1, 2, 0)`` before plotting, so it
            is assumed to be a 3-D channels-first image tensor (TODO confirm
            against the transforms applied to the env). The reward entry must
            hold exactly one element because ``.item()`` is called on it.
        show_fig: When true, display the figure with ``plt.show()``.
        save_fig: When true, write the figure as ``<time.time()>.png`` into
            ``save_dir`` (the directory is created if missing).
        save_dir: Destination directory for saved frames.
    """
    # detach()/cpu() are no-ops for a plain CPU tensor, but without them
    # .numpy() raises for tensors that live on an accelerator or track gradients.
    env_pixels = env_info.get("pixels").detach().cpu().permute(1, 2, 0).numpy()
    reward = env_info.get('reward').item()

    fig, ax = plt.subplots()

    # Close the figure even when saving/showing fails; this function is called
    # in a loop and leaked figures accumulate in pyplot's global registry.
    try:
        ax.set_title(f"Reward: {reward:.2f}")
        ax.axis('off')
        ax.imshow(env_pixels)

        if save_fig:
            save_dir.mkdir(parents=True, exist_ok=True)

            # Size the figure from the image's pixel dimensions (width, height) / 100.
            fig.set_size_inches(env_pixels.shape[1] / 100.0, env_pixels.shape[0] / 100.0)

            fig.savefig(save_dir / f"{time.time()}.png", bbox_inches='tight', pad_inches=0)

        if show_fig:
            plt.show()
    finally:
        plt.close(fig)

In [30]:
def create_mp4(save_dir: Path = Path("outputs"), output_path: Path = Path("outputs"), fps=2):
    """Stitch every ``*.png`` in ``save_dir`` into one mp4 file.

    Frames are appended in sorted filename order, so timestamp-named frames
    end up in chronological order. The video is written to
    ``output_path / "<time.time()>.mp4"``.

    Args:
        save_dir: Directory that is searched (non-recursively) for ``*.png`` frames.
        output_path: Directory that receives the mp4; created if it does not exist.
        fps: Frame rate forwarded to the video writer.

    Returns:
        None. If no frames are found, a message is printed and nothing is written.
    """
    # Path.glob yields entries in arbitrary filesystem order; sort so the video
    # plays the frames in a deterministic (filename) order.
    image_paths = sorted(save_dir.glob("*.png"))

    if not image_paths:
        print(f"No images found in {save_dir}, skipping mp4 creation!")
        return

    output_path.mkdir(parents=True, exist_ok=True)
    video_path = output_path / f"{time.time()}.mp4"

    # The context manager closes the writer even if reading or encoding a frame
    # fails; frames are read one at a time instead of being held in memory.
    with get_writer(video_path, fps=fps) as writer:
        for path in image_paths:
            writer.append_data(imread(str(path)))

In [34]:
# Number of random steps to take; each step is saved as one PNG frame.
N_STEPS = 10

# dm_control "manipulator" / "insert_ball" task, configured for pixel-only output.
env = DMControlEnv(env_name="manipulator",
                   task_name="insert_ball",
                   from_pixels=True,
                   pixels_only=True)

# Release the environment even if a step or a render call raises.
try:
    env = TransformedEnv(env)
    # render_env permutes the pixels with (1, 2, 0), i.e. it expects a
    # channels-first image — presumably what ToTensorImage yields (TODO confirm).
    env.append_transform(ToTensorImage())

    env_info = env.reset()

    for _ in tqdm(range(N_STEPS)):
        env_info = env.rand_step()
        # The result of the step is stored under the 'next' entry.
        render_env(env_info['next'], show_fig=False, save_fig=True)

    # Video export is currently disabled; uncomment to stitch the saved frames.
    #create_mp4()
finally:
    env.close()

100%|██████████| 10/10 [00:02<00:00,  3.82it/s]
