In [17]:
import os
import tempfile
from pathlib import Path
from typing import Dict

import numpy as np
from mlflow.tracking import MlflowClient

# Default GIF frame rate (used as make_gif's `fps` default).
FPS = 10
# MLflow tracking store. Defaults to the absolute path of the mlruns folder at the
# repository root (absolute so MlflowClient won't resolve a relative path against the
# notebook's cwd). Set the standard MLFLOW_TRACKING_URI environment variable to use a
# different store/machine instead of editing this hard-coded, user-specific path.
TRACKING_URI = os.environ.get(
    'MLFLOW_TRACKING_URI',
    'file:///home/henriquesouza/IA367/IA367-pydreamer/mlruns',
)

# make_gif keeps only the first B entries along axis 0 and the first T entries along
# axis 1 of the predicted-image array before flattening them into GIF frames.
B, T = 5, 50

def download_artifact_npz(run_id, artifact_path, tracking_uri=None) -> Dict[str, np.ndarray]:
    """Download an ``.npz`` artifact from an MLflow run and load all of its arrays.

    Args:
        run_id: ID of the MLflow run that owns the artifact.
        artifact_path: Artifact path relative to the run's artifact root,
            e.g. ``'d2_wm_dream/0258001.npz'``.
        tracking_uri: MLflow tracking URI to query. ``None`` (default) uses the
            module-level ``TRACKING_URI`` as it is at call time.

    Returns:
        Dict mapping every array name stored in the archive to a fully loaded
        ``np.ndarray`` (safe to use after the temporary download is deleted).
    """
    client = MlflowClient(tracking_uri=tracking_uri or TRACKING_URI)
    with tempfile.TemporaryDirectory() as tmpdir:
        # NOTE(review): newer MLflow releases deprecate MlflowClient.download_artifacts
        # in favour of mlflow.artifacts.download_artifacts — consider migrating.
        local_path = client.download_artifacts(run_id, artifact_path, tmpdir)
        # The context manager closes the NpzFile (and its zip handle) before the
        # temp dir is removed; indexing each key materialises the arrays in memory.
        with np.load(local_path) as archive:
            return {name: archive[name] for name in archive.files}

def encode_gif(frames, fps):
    """Encode a sequence of frames as an animated GIF by piping them through ffmpeg.

    Args:
        frames: Non-empty sequence (or array) of images, each shaped
            ``(height, width, channels)`` with 1 (grayscale) or 3 (RGB) channels.
            Raw bytes are handed to ffmpeg as ``gray``/``rgb24``, so frames are
            presumably ``uint8`` — TODO confirm for every caller.
        fps: Frame rate used both for the raw input stream and the output GIF.

    Returns:
        The encoded GIF as ``bytes``.

    Raises:
        ValueError: if ``frames`` is empty or the channel count is not 1 or 3.
        IOError: if ffmpeg exits with a non-zero status; the message contains the
            command line and ffmpeg's stderr.
    """
    # Copyright Danijar
    from subprocess import Popen, PIPE
    if len(frames) == 0:
        raise ValueError('encode_gif() needs at least one frame')
    h, w, c = frames[0].shape
    pix_formats = {1: 'gray', 3: 'rgb24'}
    if c not in pix_formats:
        raise ValueError(f'unsupported channel count {c}; expected 1 or 3')
    pxfmt = pix_formats[c]
    # Argument list built directly (no string join/split round-trip).
    cmd = [
        'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
        '-r', f'{fps:.02f}', '-s', f'{w}x{h}', '-pix_fmt', pxfmt, '-i', '-',
        '-filter_complex',
        '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
        '-r', f'{fps:.02f}', '-f', 'gif', '-',
    ]
    # Pass all pixel data through communicate() rather than writing stdin in a
    # loop: communicate() drains stdout/stderr concurrently, so ffmpeg filling a
    # pipe buffer cannot deadlock the writer.
    raw = b''.join(image.tobytes() for image in frames)
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
    out, err = proc.communicate(raw)
    if proc.returncode:
        # cmd is a list, so ' '.join yields a readable command line.
        raise IOError('\n'.join([' '.join(cmd), err.decode('utf8', errors='replace')]))
    return out

def make_gif(env_name, run_id, step, fps=FPS, output_dir='myfigs'):
    """Download a ``d2_wm_dream`` artifact from MLflow and save its images as a GIF.

    Fetches ``d2_wm_dream/{step}.npz`` from run ``run_id``, takes its
    ``image_pred`` array, keeps the first ``B`` entries along axis 0 and the
    first ``T`` along axis 1, flattens those two axes into one frame sequence
    and writes it to ``{output_dir}/dream_{env_name}_{step}.gif``.

    Args:
        env_name: Label used only in the output file name.
        run_id: MLflow run ID that owns the artifact.
        step: Step identifier exactly as it appears in the artifact file name
            (e.g. ``'0258001'``); it is interpolated as-is, not reformatted.
        fps: GIF frame rate.
        output_dir: Destination directory; created (including parents) if missing.
    """
    # Create the destination folder (and any missing parents) if it doesn't exist
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dest_path = out_dir / f'dream_{env_name}_{step}.gif'
    artifact = f'd2_wm_dream/{step}.npz'

    print(f"Baixando artifact: {artifact}")
    print(f"Run ID: {run_id}")

    data = download_artifact_npz(run_id, artifact)
    img = data['image_pred']
    print(f"Dimensões da imagem: {img.shape}")

    # Assumes image_pred is laid out as (batch, time, H, W, C) — TODO confirm.
    # Keep the per-frame shape from the array instead of hard-coding 64x64x3.
    frames = img[:B, :T].reshape((-1,) + img.shape[2:])
    gif = encode_gif(frames, fps)

    dest_path.write_bytes(gif)

    print(f"GIF salvo em: {dest_path}")

In [18]:
# `step` must match an artifact file under d2_wm_dream/ (here d2_wm_dream/0258001.npz)
make_gif('pong', run_id='58d45ccf5a334b369cc2c57fc834332a', step='0258001')

Baixando artifact: d2_wm_dream/0258001.npz
Run ID: 58d45ccf5a334b369cc2c57fc834332a


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Dimensões da imagem: (32, 48, 64, 64, 3)
GIF salvo em: myfigs/dream_pong_0258001.gif
