In [1]:
from typing import Callable

import gc
import os
import glob
import time
import shutil
import subprocess
import concurrent.futures

import numpy as np
from PIL import Image
from tqdm import tqdm
from matplotlib.pyplot import cm

from stats import DiscordWebhook
from parse_data.helper_func import load_merged_records

In [2]:
# Rendering settings shared by the cells below.

# Frame rate of the rendered videos, 35 for Doom's default settings.
fps = 35
# Number of seconds for which the last frame of an episode is repeated.
end_padding_time = 2
# Colour map used for the semantic segmentation channel.
cmap = cm.jet
# Directory in which the videos are saved.
save_path = "captures/media"
# Directory in which the temporary frame images are saved.
tmp_path = save_path + "/tmp"

In [3]:
# Shell command template used to turn saved frames into a video. The two %s slots are
# filled, in order, with the input frame pattern and the output mp4 path.
# NVENC is faster & on GPU but often less optimized in file size
conversion_command = (
    f"ffmpeg -framerate {fps} -i %s "
    "-c:v libx264 -crf 21 -preset fast "
    "-f mp4 -pix_fmt yuv420p %s"
)
# conversion_command = f"ffmpeg -framerate {fps} -hwaccel_output_format cuda -i %s -c:v h264_nvenc -crf 21 -preset fast -f mp4 -pix_fmt yuv420p %s"

In [4]:
def create_pillow_frames(last_ep_end: int, ep: int, obs: np.ndarray, img_type: str="rs", 
                         scale: int=0) -> list[Image.Image]:
    """Turn the observations of one episode into a list of Pillow images.

    Parameters
    ----------
    last_ep_end : int
        Index into ``obs`` of the first frame to render.
    ep : int
        Index into ``obs`` at which rendering stops (exclusive).
    obs : np.ndarray
        Observations indexed as ``obs[frame, channel, row, column]``. Channels 0-2 are used
        as the colour image, channel 3 is passed through the module-level ``cmap``.
    img_type : str
        One character per panel, in order: ``'r'`` adds the colour image, ``'s'`` adds the
        colour-mapped channel 3. Any other character is ignored. Several panels are
        concatenated along the first (row) axis of the image.
    scale : int
        When non-zero, every pixel is repeated ``scale`` times along both image axes.

    Returns
    -------
    list[Image.Image]
        One image per frame, followed by ``ceil(fps * end_padding_time)`` extra references
        to the last image so that the video lingers on its final frame.

    Raises
    ------
    ValueError
        If ``range(last_ep_end, ep)`` is empty, i.e. there is no frame to render.
    """
    import math  # local import: math is not among the notebook's top-level imports

    frames = []
    for i in range(last_ep_end, ep):
        img = []
        for c in img_type:
            if c == 'r':
                # (channel, row, column) -> (row, column, channel), the layout Pillow expects
                img.append(obs[i, :3, :, :].transpose(1, 2, 0))
            elif c == 's':
                # cmap returns RGBA floats; keep RGB and rescale to 8-bit values
                img.append(np.array(cmap(obs[i, 3, :, :])[:, :, :3] * 255, dtype=np.uint8))
        if len(img) == 1:
            img = img[0]
        else:
            img = np.concatenate(img)
        if scale:
            # Integer upscaling by repeating pixels along rows, then along columns
            img = np.repeat(np.repeat(img, repeats=scale, axis=0), repeats=scale, axis=1)
        frames.append(Image.fromarray(img))

    if not frames:
        # Fail with an explicit message instead of an IndexError on frames[-1] below
        raise ValueError(f"No frames to render: empty frame range [{last_ep_end}, {ep})")

    frames += [frames[-1]] * math.ceil(fps * end_padding_time)
    return frames

In [5]:
def t_subprocess_call(cmd: str) -> subprocess.CompletedProcess:
    """Thread-pool job: run ``cmd`` through the shell and wait for it to finish.

    The exit status is not checked here; it is available as ``returncode`` on the
    returned ``CompletedProcess``.
    """
    completed = subprocess.run(cmd, shell=True)
    return completed

def t_render(last_ep_end: int, ep: int, obs: np.ndarray, score: int, save_name: str, 
             ep_num: int, tmp_n_digits: int|str, img_type: str="rs", scale: int=0) -> subprocess.CompletedProcess:
    """Thread-pool job: render one episode from observations all the way to an mp4 file.

    Builds the frames with ``create_pillow_frames``, saves them as numbered PNG files
    under ``tmp_path``, removes any existing mp4 with the target name, then runs
    ``conversion_command`` on the saved frames.

    Parameters
    ----------
    last_ep_end, ep, obs, img_type, scale
        Forwarded unchanged to ``create_pillow_frames``.
    score
        Value embedded at the end of the output file name.
    save_name : str
        Path prefix of the output file; ``_ep<ep_num>_<score>.mp4`` is appended to it.
    ep_num : int
        Episode number, used in both the temporary frame names and the mp4 name.
    tmp_n_digits : int | str
        Minimum number of digits used for the zero-padded frame index in the PNG names.

    Returns
    -------
    subprocess.CompletedProcess
        Result of the conversion command; its exit status is not checked here.
    """
    frames = create_pillow_frames(last_ep_end, ep, obs, img_type, scale)

    tmp_template = f"{tmp_path}/{ep_num}_%0{tmp_n_digits}d.png"

    img = None
    for i, img in enumerate(frames):
        img.save(tmp_template % i)

    mp4_path = f"{save_name}_ep{ep_num:03d}_{score}.mp4"
    if os.path.exists(mp4_path):
        os.remove(mp4_path)

    # Release every reference to the frames before the conversion starts. Deleting only
    # the loop variable would leave all images alive through the `frames` list.
    del frames, img
    gc.collect()

    cmd = conversion_command % (tmp_template, mp4_path)
    return subprocess.run(cmd, shell=True)

In [6]:
def render_as_mp4(pool: concurrent.futures.ThreadPoolExecutor, obs: np.ndarray, ep_ends: np.ndarray, 
                  scores: np.ndarray, save_name: str, img_type: str="rs", parallelize: str="pillow", 
                  scale: int=0) -> dict[concurrent.futures.Future, int]:
    # We only know how many digits to reserve for temporary images at run time
    tmp_n_digits = str(len(str(obs.shape[0])))

    last_ep_end = 0
    worker_jobs = {}

    try:
        if os.path.exists(tmp_path):
            shutil.rmtree(tmp_path)
        os.makedirs(tmp_path)
    except:
        print(f"Failed attempt at cleaning tmp directory, consider manual cleaning later")
    
    num_eps = ep_ends.shape[0]

    first_pass = parallelize in {"ffmpeg", "pillow"}

    for ep_num, ep in (pbar := tqdm(enumerate(ep_ends.tolist(), 1), total=num_eps, ncols=100, leave=True, 
                                    desc="rendering ep 1" if first_pass else "submitting jobs")):
        score = int(scores[ep_num-1]) if len(scores) else ''
        pbar.update(0)
        match parallelize:
            case "ffmpeg":
                frames = create_pillow_frames(last_ep_end, ep, obs, img_type, scale)

                pbar.set_description(f"saving ep {ep_num}")
                tmp_template = f"{tmp_path}/{ep_num}_%0{tmp_n_digits}d.png"
                for i, img in enumerate(frames):
                    img.save(tmp_template %i)
                
                pbar.set_description(f"converting ep {ep_num}")
                mp4_path = f"{save_name}_ep{ep_num:03d}_{score}.mp4"
                if os.path.exists(mp4_path):
                    os.remove(mp4_path)

                cmd = conversion_command % (tmp_template, mp4_path)
                worker_jobs[pool.submit(t_subprocess_call, cmd)] = ep_num
            case "pillow":
                frames = create_pillow_frames(last_ep_end, ep, obs, img_type, scale)
                
                pbar.set_description(f"saving ep {ep_num}")
                tmp_template = f"{tmp_path}/{ep_num}_%0{tmp_n_digits}d.png"

                current_jobs = {
                    pool.submit(img.save, tmp_template %i) : (ep_num, i) for i, img in enumerate(frames)
                }
                worker_jobs.update(current_jobs)
                
                pbar.set_description(f"converting ep {ep_num}")
                mp4_path = f"{save_name}_ep{ep_num:03d}_{score}.mp4"
                if os.path.exists(mp4_path):
                    os.remove(mp4_path)

                cmd = conversion_command % (tmp_template, mp4_path)

                concurrent.futures.wait(current_jobs)
                worker_jobs[pool.submit(t_subprocess_call, cmd)] = ep_num
            case "all":
                worker_jobs[pool.submit(t_render, last_ep_end, ep, obs, score, save_name, ep_num, tmp_n_digits, img_type, scale)] = ep_num
            case _:
                raise NotImplementedError(f"Unknown parallelization option: \"{parallelize}\"")

        last_ep_end = ep + 1
        
        if ep_num != num_eps:
            pbar.set_description(f"rendering ep {ep_num+1}")
        else:
            pbar.set_description("rendering done")
    
    for job in tqdm(concurrent.futures.as_completed(worker_jobs), ncols=100,
                    total=len(worker_jobs), desc="finishing up"):
        job.result()
    
    try:
        if os.path.exists(tmp_path):
            shutil.rmtree(tmp_path)
    except:
        print(f"Failed attempt at cleaning tmp directory, consider manual cleaning later")

    return worker_jobs

In [7]:
# Names of all map variants to look for: for each of the 3 maps, the plain map and its
# "s", "a" and "w" variants, each immediately followed by its "rtss_" counterpart.
map_names = []
for map_idx in range(1, 4):
    plain_name = f"map{map_idx}"
    for suffix in ("", "s", "a", "w"):
        map_names += [plain_name + suffix, f"rtss_{plain_name}{suffix}"]

In [8]:
def render_all_mp4s_of_model(pool: concurrent.futures.ThreadPoolExecutor, model_name: str, model_dir: str, scale: int=0,
                             map_names: list[str]=map_names, img_type: str="rs", parallelize: str="pillow") -> list[str]:
    """Render the recorded episodes of one model on every map that has records.

    For each map name, records are looked up at
    ``captures/<map_name>/<model_dir>/record_*.npz``. Maps without records are skipped
    silently, maps whose output directory already contains files are skipped with an
    "Already done" message. Progress and failures are also reported through
    ``DiscordWebhook``. A failure on one map does not stop the remaining maps.

    Parameters
    ----------
    pool : concurrent.futures.ThreadPoolExecutor
        Executor forwarded to ``render_as_mp4``.
    model_name : str
        Name used for the output directory and file names. For maps starting with
        ``rtss_``, ``'ss'`` in the directory name is replaced by ``'rtss'``.
    model_dir : str
        Directory name of the model inside ``captures/<map_name>/``.
    scale : int
        Integer upscaling factor forwarded to ``render_as_mp4``.
    map_names : list[str]
        Map names to process, in order.
    img_type : str
        Panel selection forwarded to ``render_as_mp4``. Anything other than ``"rs"`` is
        saved in its own sub-directory (e.g. ``rgb`` for ``"r"``, ``ss`` for ``"s"``).
    parallelize : str
        Parallelization mode forwarded to ``render_as_mp4``.

    Returns
    -------
    list[str]
        Names of the maps that were rendered without raising an exception.
    """
    success = []
    for map_name in map_names:
        record_path = f"captures/{map_name}/{model_dir}/record_*.npz"
        if glob.glob(record_path):
            print(f"Loading {model_name}/{map_name}...", flush=True, end='\r')
            try:
                map_name_save = map_name
                if img_type == "rs":
                    sub_dir = ''
                else:
                    # Built outside the f-string: reusing the enclosing quote type inside
                    # an f-string expression is a syntax error before Python 3.12
                    img_type_dir = img_type.replace('r', 'rgb_').replace('s', 'ss_').strip('_')
                    sub_dir = f"/{img_type_dir}"
                if map_name.startswith("rtss_"):
                    map_name_save = map_name[5:]
                    save_name = f"{save_path}/{model_name.replace('ss', 'rtss')}{sub_dir}/{map_name_save}/"
                else:
                    save_name = f"{save_path}/{model_name}{sub_dir}/{map_name_save}/"
                if os.path.isdir(save_name) and os.listdir(save_name):
                    print(f"Already done: {model_name}/{map_name}", flush=True, end='\n')
                    continue
                os.makedirs(save_name, exist_ok=True)
                # From here on save_name is the file name prefix inside that directory
                save_name += f"{model_name}_{map_name_save}"
                obs, _, ep_ends, *_, scores = load_merged_records(record_path, load_pos=False, no_tqdm=True)
                gc.collect()
                t = time.time()
                DiscordWebhook.send_msg_no_instance(f"mp4 creation job of {model_name}/{map_name} ({img_type}) starts")
                print(f"Rendering {model_name}/{map_name}...", flush=True, end='\r')
                render_as_mp4(pool, obs, ep_ends, scores, save_name, img_type, parallelize, scale)
                t = round(time.time() - t)
                DiscordWebhook.send_msg_no_instance(f"mp4 creation job of {model_name}/{map_name} ({img_type}) done, took {t//60} min {t%60} sec")
                success.append(map_name)
            except Exception as e:
                print(f"Loading {model_name}/{map_name} failed", flush=True, end='\n')
                DiscordWebhook.send_msg_no_instance(f"mp4 creation job of {model_name}/{map_name} ({img_type}) failed")
                DiscordWebhook.send_error_no_instance(e)
            finally:
                # Unbind the loaded arrays so gc.collect() can actually reclaim them;
                # otherwise they stay alive while the next map is being loaded.
                obs = ep_ends = scores = None
                gc.collect()
    DiscordWebhook.send_msg_no_instance(f"All mp4 creation jobs of {model_name} ({img_type}) done!")
    return success

In [9]:
# Pillow's png saving is the bottleneck here, parallelizing it is significantly faster 
# (8x on my device)
# Alternative setting: parallelize, max_workers = "all", None
parallelize = "pillow"
max_workers = None

# One (model_name, model_dir, img_type) entry per rendering batch, processed in this order
render_batches = [
    ("ppo_ss_rgb", "ppo_ss_rgb_1e-3_best", "rs"),
    ("ppo_ss_4", "s4_ppo_ss_1e-3_final", "s"),
    ("ppo_rgb", "rgb_9e-4_", "r"),
    ("ppo_ss_4", "s4_ppo_ss_1e-3_final", "rs"),
]

with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
    for batch_model_name, batch_model_dir, batch_img_type in render_batches:
        render_all_mp4s_of_model(pool, batch_model_name, batch_model_dir,
                                 img_type=batch_img_type, parallelize=parallelize)

Already done: ppo_ss_rgb/map1
Already done: ppo_ss_rgb/rtss_map1
Already done: ppo_ss_rgb/rtss_map1a
Already done: ppo_ss_rgb/rtss_map1w
Already done: ppo_ss_rgb/map2s
Already done: ppo_ss_rgb/rtss_map2s
Rendering ppo_ss_4/map1...

rendering done: 100%|███████████████████████████████████████████████| 20/20 [02:39<00:00,  7.95s/it]
finishing up: 100%|███████████████████████████████████████▉| 98286/98288 [00:07<00:00, 10637.97it/s]

In [None]:
# Manual cleanup: remove the temporary frame directory if it is still present
# (render_as_mp4 asks for manual cleaning when its own cleanup attempt fails).
if os.path.exists(tmp_path):
    shutil.rmtree(tmp_path)