In [1]:
from typing import Dict
import tempfile
from pathlib import Path
import numpy as np
from mlflow.tracking import MlflowClient

FPS = 10
B, T = 5, 50

def download_artifact_npz(run_id, artifact_path) -> Dict[str, np.ndarray]:
    client = MlflowClient()
    with tempfile.TemporaryDirectory() as tmpdir:
        path = client.download_artifacts(run_id, artifact_path, tmpdir)
        with Path(path).open('rb') as f:
            data = np.load(f)
            return {k: data[k] for k in data.keys()}  # type: ignore

def encode_gif(frames, fps):
    # Copyright Danijar
    from subprocess import Popen, PIPE
    h, w, c = frames[0].shape
    pxfmt = {1: 'gray', 3: 'rgb24', 4: 'rgba24'}[c]
    cmd = ' '.join([
        'ffmpeg -y -f rawvideo -vcodec rawvideo',
        f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
        '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
        f'-r {fps:.02f} -f gif -'])
    proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
    for image in frames:
        proc.stdin.write(image.tobytes())  # type: ignore
    out, err = proc.communicate()
    if proc.returncode:
        raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')]))
    del proc
    return out

def make_gif(env_name, run_id, step, fps=FPS):
    dest_path = f'figures/dream_{env_name}_{step}.gif'
    artifact = f'd2_wm_dream/{step}.npz'
    data = download_artifact_npz(run_id, artifact)
    img = data['image_pred']
    print(img.shape)
    img = img[:B, :T].reshape((-1, 64, 64, 3))[:,:,:,:3]
    gif = encode_gif(img, fps)
    with Path(dest_path).open('wb') as f:
        f.write(gif)
        
def make_gif_onehot(env_name, run_id, step, fps=FPS):
    dest_path = f'figures/dream_{env_name}_{step}.gif'
    artifact = f'd2_wm_dream/{step}.npz'
    data = download_artifact_npz(run_id, artifact)
    img = data['image_pred']
    print(img.shape, type(img))
    
    print(img.sum(axis=-1))
    img = img.argmax(axis=-1)
    print(img.shape, img[0,0,:20,:20])
    img = np.repeat(np.expand_dims(img * 17, axis=-1), 3, axis=-1)
    print(img.shape, img.max())
    
    
    img = img[:B, :T].reshape((-1, 256, 256, 1))[:,:,:,:3]
    gif = encode_gif(img, fps)
    with Path(dest_path).open('wb') as f:
        f.write(gif)
        

In [2]:
import mlflow

mlflow.get_tracking_uri()

'file:///saivvy/pydreamer/results/atari/mlruns'

In [3]:
mlflow.set_tracking_uri('file:///saivvy/pydreamer/mlruns')

In [None]:
# Montezuma

make_gif('montezuma', '599e69d178ca4f65a10423d272f9f45d', '0500001')

In [None]:
# Breakout

make_gif('breakout', '83e5def4975242ccbf16a3ca8f62a674', '0500001')

In [None]:
# Space invaders

make_gif('invaders', '6d57d49ab844475cbb83b606816b01fe', '0500001')

In [None]:
# DMC quadruped

make_gif('quadruped', 'ff6cb24c04de4e6b821bb811c855d207', '0300001')

In [None]:
# DMLab goals small

make_gif('dmlab', '6f78cce067464e8aa4bcb6f35a1a4386', '0161001', fps=8)

In [None]:
# MiniWorld ScavengerHunt

make_gif('scavenger', '123b575400874f5db75ac7887f4e61c0', '0900001')

In [47]:
make_gif('pong', '3f452afca7204e5a882f68f8b19570eb', '0030001')

(8, 8, 256, 256, 3)


In [None]:
make_gif('miniworld', '7960374dccea44e99f1c574b8d4d3011', '0001001')

In [18]:
make_gif('carla', 'f0f03946308b4699979fee03a08f7e04', '0015001')

(2, 2, 64, 64, 3)


In [19]:
make_gif_onehot('carla', 'e4e9a316956245bdbafe432183e260b1', '0010001')

(2, 2, 128, 128, 13) <class 'numpy.ndarray'>
[[[[0.99999994 1.         1.0000002  ... 0.99999994 0.9999998
    1.        ]
   [1.0000001  0.99999994 1.0000001  ... 1.0000002  1.
    0.99999994]
   [1.0000001  1.         0.9999999  ... 1.         1.
    0.9999999 ]
   ...
   [1.0000001  1.         0.99999994 ... 1.         1.
    1.        ]
   [0.9999999  0.9999999  1.0000001  ... 0.9999999  0.9999999
    1.        ]
   [1.0000001  0.99999994 1.         ... 0.99999994 0.99999994
    1.        ]]

  [[1.0000001  1.0000001  0.99999994 ... 0.9999999  0.9999998
    1.0000001 ]
   [1.0000001  1.         0.9999999  ... 1.0000001  0.99999994
    0.9999999 ]
   [1.0000002  1.         0.99999994 ... 1.         0.9999999
    1.        ]
   ...
   [1.0000001  1.         0.9999999  ... 0.99999994 0.9999998
    1.        ]
   [0.99999994 1.0000001  0.99999994 ... 1.0000001  1.
    1.0000001 ]
   [0.9999998  1.0000002  0.99999994 ... 1.0000001  1.
    0.9999999 ]]]


 [[[0.9999999  0.9999999  1.0000