In [1]:
from typing import Callable

import gc
import os
import glob
import shutil
import subprocess

import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from matplotlib.pyplot import cm

In [2]:
end_padding_time = 2                # Number of seconds to repeat for the last frame
fps = 35                            # Frame rate, 35 for Doom's default settings
cmap = cm.jet                       # Colour map for semantic segmentation
save_path = "logs/media"            # Directory to save videos
tmp_path = f"{save_path}/tmp"       # Directory to save temporary images

In [3]:
def load_merge(path_: str, load_obs: bool=True, load_pos: bool=True, obs_transform: Callable=lambda x: x):
    obs, pos, ep_ends, weapon = [], [], [], []
    ep_offset = 0
    matching_files = glob.glob(path_)
    if len(matching_files) < 1:
        raise FileNotFoundError(f"No match for glob expression: {path_}")
    for f in tqdm(matching_files, desc="loading saved data"):
        x = np.load(f)
        if load_obs:
            obs.append(x["obs"])
            if "ep_ends" in x:
                ep_ends.append(np.asarray(x["ep_ends"], dtype=np.uint64) + ep_offset)
                ep_offset += len(obs[-1])
        if load_pos:
            pos.append(obs_transform(x["feats"]))
        if "weapon" in x:
            weapon.append(x["weapon"])
    return (np.concatenate(obs, axis=0) if load_obs else [], 
            np.concatenate(pos, axis=0) if load_pos else [],
            np.concatenate(ep_ends, axis=0) if ep_ends else [],
            np.concatenate(weapon, axis=0) if weapon else [])

In [4]:
def render_as_gif(obs: np.ndarray, ep_ends: np.ndarray, save_name: str, scale: int=0, ep_num_: int=None, format: str=None) -> list[Image.Image]:
    last_ep_end = 0

    for ep_num, ep in (pbar := tqdm(enumerate(ep_ends.tolist(), 1), total=ep_ends.shape[0], 
                                    desc="rendering ep 1")): #, ncols=100)):
        if ep_num_ is None or ep_num == ep_num_:
            frames = []
            for i in range(last_ep_end, ep):
                mapped = np.array(cmap(obs[i, 3, :, :])[:, :, :3] * 255, dtype=np.uint8)
                if format is None:
                    img = np.concatenate([obs[i, :3, :, :].transpose(1, 2, 0), mapped])
                elif format == "rgb":
                    img = obs[i, :3, :, :].transpose(1, 2, 0)
                elif format == "ss":
                    img = mapped
                if scale:
                    img = np.repeat(np.repeat(img, repeats=scale, axis=0), repeats=scale, axis=1)
                frames.append(Image.fromarray(img))
            frames += [frames[-1]] * (fps * end_padding_time).__ceil__()
            pbar.set_description(f"saving ep {ep_num}")

            gif: Image.Image = frames[0]
            gif.save(f"{save_path}/{save_name}_ep{ep_num}{'_' + format if format else ''}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1000/30, loop=0)
        last_ep_end = ep + 1

    return frames

In [5]:
save_name = "rtss_map1"
obs, _, ep_ends, _ = load_merge(f"logs/back3/{save_name}_small_ss_rgb_1e-3/record_0.npz", load_pos=False)
gc.collect()

loading saved data:   0%|          | 0/1 [00:00<?, ?it/s]

82

In [6]:
# render_as_gif(obs, ep_ends, save_name, ep_num_=5, format='rgb');

In [7]:
render_as_gif(obs, ep_ends, save_name, ep_num_=5, format='ss');

rendering ep 1:   0%|          | 0/5 [00:00<?, ?it/s]