In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import functools
import jax
import os
import pandas as pd
import mlflow
import pickle

from datetime import datetime
from jax import numpy as jnp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output
import mediapy

import brax

import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html

# other envs
from task_aware_skill_composition.brax.envs.point import Point
from task_aware_skill_composition.brax.envs.car import Car
from task_aware_skill_composition.brax.envs.drone import Drone
from task_aware_skill_composition.brax.envs.point import Point
from task_aware_skill_composition.brax.envs.doggo import Doggo

# tasks
from task_aware_skill_composition.brax.tasks import get_task

In [3]:
# Point MLflow at the tracking store that holds the training runs.
# The MLFLOW_TRACKING_URI environment variable, when set, takes precedence so the
# notebook is not tied to one user's home directory; otherwise the original
# local file store is used. (The previous f-string had no placeholders.)
mlflow.set_tracking_uri(
    os.environ.get("MLFLOW_TRACKING_URI", "file:///home/tassos/.local/share/mlflow")
)

In [4]:
# Physics backend name. In the code shown here it is only referenced by the
# commented-out direct environment constructors below.
backend = 'mjx'

# Alternative environments built directly from their classes (kept for quick switching).
# env = Car(backend=backend)
# env = Drone(backend=backend)
# env = Point(backend=backend)
# env = Doggo(backend=backend)

# Other alternatives. NOTE(review): AntMaze is not imported in this notebook.
# env = envs.get_environment(env_name="hopper", backend=backend)
# env = AntMaze(backend=backend)

# env_tag = type(env).__name__

# Look up the task for ("point", "two_goals") and use the environment attached to it.
# NOTE(review): presumably the arguments are (environment name, task name) — confirm
# against get_task.
task = get_task("point", "two_goals")
env = task.env

# JIT-compiled versions of the environment's reset and step functions, used by the
# cells below.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

2024-11-05 18:41:58.863064: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


# Visualizing Policy

In [5]:
from brax.training.agents.ppo import networks as ppo_networks
from jaxgcrl import networks as crl_networks
from brax.training.acme import running_statistics

In [6]:
# Fetch the policy parameters that were logged as an artifact of the chosen
# MLflow training run, then load them from the downloaded location.
training_run_id = "9a1bae66199d4e4aba59f14b333f2e5a"
logged_model_path = "runs:/{}/policy_params".format(training_run_id)
real_path = mlflow.artifacts.download_artifacts(artifact_uri=logged_model_path)
params = model.load_params(real_path)
# Alternative unpackings, for checkpoints that store several parameter groups:
# normalizer_params, policy_params = model.load_params(real_path)
# normalizer_params, policy_params, crl_critic_params = model.load_params(real_path)

In [7]:
# Read the run's logged parameters and pick the observation preprocessor:
# the running-statistics normaliser when the run recorded
# normalize_observations == "True" (a string comparison), otherwise a
# pass-through that returns its first argument unchanged.
run = mlflow.get_run(run_id=training_run_id)
normalize = (
    running_statistics.normalize
    if run.data.params["normalize_observations"] == "True"
    else (lambda x, y: x)
)

In [8]:
# Reset the environment: create a PRNG key from seed 0 and use it to produce the
# initial state.
# NOTE(review): this same `rng` is split again in the rollout cell after being
# consumed here; JAX's guidance is to avoid reusing a key — confirm this is intended.
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [9]:
# Build the PPO networks for this environment. The observation size is read off
# the reset state, and observations go through the `normalize` preprocessor.
ppo_network = ppo_networks.make_ppo_networks(
    observation_size=state.obs.shape[0],
    action_size=env.action_size,
    preprocess_observations_fn=normalize,
)
# Factory that turns a set of parameters into a policy function.
make_policy = ppo_networks.make_inference_fn(ppo_network)

# # Making the network
# crl_network = crl_networks.make_crl_networks(
#     env=env,
#     observation_size=env.observation_size,
#     action_size=env.action_size,
#     repr_dim=run.data.params["repr_dim"],
#     preprocess_observations_fn=normalize,
#     hidden_layer_sizes=[int(run.data.params["h_dim"])] * int(run.data.params["n_hidden"]),
#     use_ln=bool(run.data.params["use_ln"]),
# )
# make_policy = crl_networks.make_inference_fn(crl_network)

# make_policy = make_inference_fn

In [10]:
# Bind the loaded parameters to obtain the policy function, then JIT-compile it
# for use in the rollout loop.
inference_fn = make_policy(params)
jit_inference_fn = jax.jit(inference_fn)

In [11]:
# Roll the policy out from the reset state and collect the visited states.
# The loop works on its own copies (`rollout_state`, `rollout_rng`) so the
# notebook-level `state` and `rng` stay as the reset cell left them. Re-running
# this cell therefore reproduces the same trajectory instead of continuing from
# wherever the previous run stopped.
n_steps = 50
render_every = 1  # NOTE(review): not read in this cell; presumably a render stride

rollout_state, rollout_rng = state, rng
rollout = [rollout_state]

for _step in range(n_steps):
    # Fresh sub-key for sampling the action; the remainder carries to the next step.
    act_rng, rollout_rng = jax.random.split(rollout_rng)
    ctrl, _ = jit_inference_fn(rollout_state.obs, act_rng)
    rollout_state = jit_step(rollout_state, ctrl)

    # Stop at episode end; the terminal state itself is not added to the rollout.
    if rollout_state.done:
        break

    rollout.append(rollout_state)

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [14]:
# Render the rollout as an inline video from the 'overview' camera.
# Only every `render_every`-th state is drawn and the frame rate is divided by the
# same factor, so playback speed is unchanged; with render_every == 1 this is
# identical to rendering every state at 1 / env.dt frames per second.
mediapy.show_video(
    env.render(
        [s.pipeline_state for s in rollout[::render_every]],
        camera='overview'
    ),
    fps=1.0 / env.dt / render_every
)

0
This browser does not support the video tag.


In [13]:
# Interactive HTML view of a single state: the one after the first policy step, or
# the initial state when the rollout holds only one entry (plain `rollout[1]`
# would raise IndexError in that case).
HTML(html.render(env.sys, [rollout[min(1, len(rollout) - 1)].pipeline_state]))