In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image, ImageDraw
from IPython import display
from lib.viz import showarray

from stable_baselines import SAC
from stable_baselines.common.vec_env import VecNormalize
from stable_baselines.common.cmd_util import make_vec_env

import os, time, json
# Point MLflow at the local SQLite tracking store. This is assigned before
# mlflow is imported and before the client below is constructed, so the
# client resolves its tracking URI from it.
os.environ["MLFLOW_TRACKING_URI"] = "sqlite:///mlruns/db.sqlite"
import mlflow
# Client used by the cells below to look up registered model versions.
mlflow_client = mlflow.tracking.MlflowClient()

from lib import eos
from lib.eos import EyeOnStickEnv
from lib.run import run_env_nsteps

In [None]:
# Number of joints of the EyeOnStick arm, and the MLflow registered model to
# visualise. model_version=None means "follow the latest version in stage
# 'None'", re-resolved on every pass of the loop below.
NJ = 4
model_name, model_version = f"eos.{NJ}J", None
model, model_source = None, None

# Metrics returned by the most recently *completed* run_env_nsteps call.
# displayfunc reads this module-level name on every call, and it is only
# rebound after a run finishes, so while a run is in progress the dashboard
# shows the previous run's metrics (nothing on the first pass).
metrics = {}


def displayfunc(img_array):
    """Show a rendered frame with a text dashboard of `metrics` stacked below it.

    Passed to run_env_nsteps as its display callback (presumably invoked once
    per rendered frame). `img_array` is assumed to be an H x W x 3 uint8
    array so it can be vstacked with the RGB dashboard — TODO confirm against
    EyeOnStickEnv's renderer.
    """
    display.clear_output(wait=True)

    # Blank RGB canvas with the same width and height as the frame.
    dashboard_img = Image.new('RGB', (img_array.shape[1], img_array.shape[0]))
    dashboard_draw = ImageDraw.Draw(dashboard_img)

    def draw_text(txt, vpos=0):
        # Advance one 10-px row, then draw; returns the row index just used.
        vpos += 1
        dashboard_draw.text((10, 10 * vpos), txt)
        return vpos

    vpos = 0
    for key, val in metrics.items():
        vpos = draw_text(f'{key:15s} {val:+.4f}', vpos=vpos)

    dashboard_img_array = np.asarray(dashboard_img)
    showarray(np.vstack((img_array, dashboard_img_array)))


env = None
while True:
    if model_version is not None:
        registered_model = mlflow_client.get_model_version(model_name, model_version)
    else:
        latest_versions = mlflow_client.get_latest_versions(model_name, stages=["None"])
        if not latest_versions:
            raise LookupError(
                f"No versions of registered model {model_name!r} found in stage 'None'"
            )
        registered_model = latest_versions[0]

    # The model is reloaded on every pass (presumably so a newly registered
    # version is shown without re-running the cell).
    model_source = registered_model.source
    actual_model_version = registered_model.version
    model = SAC.load(model_source)

    # Environment parameters are stored alongside the model as '<source>.json'.
    params_fname = f'{model_source}.json'
    with open(params_fname, 'r', encoding='utf-8') as fp:
        params = json.load(fp)

    # Release the previous pass's environment before building a new one so
    # environments do not accumulate. The current one is left open so it can
    # still be inspected after the loop is interrupted.
    if env is not None:
        env.close()
    env = make_vec_env(lambda: EyeOnStickEnv(NJ, params), n_envs=1)
    model.set_env(env)
    env.env_method('set_render_info', {'model_name': model_name, 'model_version': model_version, 'actual_model_version': actual_model_version})

    metrics, data = run_env_nsteps(env, model, params['MAX_NSTEPS'], displayfunc=displayfunc, wait=0.05)