- - -
## Data Loading & Augmentation
- - -

In [None]:
import os
from io_agent.runner.iterative import augment_mujoco_dataset, registered_envs
from io_agent.plant.mujoco import Walker2dEnv, HopperEnv, HalfCheetahEnv


# Build the augmented dataset for every registered environment, skipping
# any environment whose output path already exists on disk.
for name, env_class in registered_envs.items():
    save_dir = f"./{name}_data/dataset"
    file_name = "rich_augmented"
    dataset_path = os.path.join(save_dir, file_name)

    if os.path.exists(dataset_path):
        continue

    env = env_class()
    augment_mujoco_dataset(env=env, save_dir=save_dir, file_name=file_name)


- - -
## Iterative IO Controller Training
- - -

In [None]:
import numpy as np
import torch
import multiprocessing

from io_agent.runner.iterative import run_iterative_io, IterativeIOArgs


# Run configuration: worker count, number of trials, seeding and device.
general_seed = 42
n_trials = 20
n_cpu = multiprocessing.cpu_count()
device = "cuda" if torch.cuda.is_available() else "cpu"

# A single draw yields `n_trials` trial seeds followed by one training seed.
seed_rng = np.random.default_rng(general_seed)
drawn_seeds = seed_rng.integers(low=0, high=2**30, size=n_trials + 1)
trial_seeds = list(drawn_seeds[:-1])
train_seed = drawn_seeds[-1]


# Experiments to run, keyed by the name passed to `run_iterative_io`.
# Only the Hopper entry is active; the Walker and Cheetah configurations are
# kept commented out and differ in `lr_exp_decay`, `data_size`, `eval_epochs`,
# `env_name` and `work_dir`.
experiment_args = {
    # "Walker-IO-1e4": IterativeIOArgs(
    #     lr_exp_decay=0.9975,
    #     learning_rate=5e-2,
    #     n_batch=64,
    #     data_size=int(1e4),
    #     eval_epochs=tuple(range(0, 2601, 100)),
    #     env_name="walker",
    #     work_dir="./walker_data"),
    "Hopper-IO-5e3": IterativeIOArgs(
        lr_exp_decay=0.995,
        learning_rate=5e-2,
        n_batch=64,
        data_size=int(5e3),
        env_name="hopper",
        # Epochs 0, 100, ..., 1000.
        eval_epochs=tuple(range(0, 1001, 100)),
        work_dir="./hopper_data"),
    # "Cheetah-IO-1e4": IterativeIOArgs(
    #     lr_exp_decay=0.9925,
    #     learning_rate=5e-2,
    #     n_batch=64,
    #     data_size=int(1e4),
    #     eval_epochs=tuple(range(0, 2601, 100)),
    #     env_name="cheetah",
    #     work_dir="./cheetah_data"),
}

# Run every configured experiment and keep the three returned statistics.
# The returned agent is bound to `iterative_io_agent` but not stored in
# `results`, so only the agent of the last experiment stays reachable.
results = {}
for key, args in experiment_args.items():
    run_output = run_iterative_io(
        args=args,
        seed=train_seed,
        trial_seeds=trial_seeds,
        name=key,
        device=device,
        verbose=True,
    )
    costs, epoch_losses, step_losses, iterative_io_agent = run_output
    results[key] = (costs, epoch_losses, step_losses)

- - -
## Visualizing the Training
- - -

In [None]:
# Smoothing with last n steps
from collections import deque
from itertools import chain

from io_agent.plant.mujoco import HopperEnv


env = HopperEnv()

# Smooth each experiment's scores by pooling the values of the current entry
# with those of up to four preceding entries, then convert every pooled value
# to a percentage via the environment's normalized score.
# NOTE(review): all experiments are normalized with this single HopperEnv
# instance; confirm that is intended before enabling non-Hopper experiments.
smooth_scores = {}
for name, scores in results.items():
    # A fresh window per experiment: sharing one deque across experiments
    # would mix the last entries of one run into the first entries of the next.
    queue = deque(maxlen=5)
    smooth_scores[name] = {}
    for key, values in scores[0].items():
        queue.append(values)
        smooth_scores[name][key] = [
            env.env.get_normalized_score(value) * 100
            for window_values in queue
            for value in window_values
        ]


In [None]:
import numpy as np
from collections import defaultdict

from io_agent.plotter import tube_figure_plt


def _indexed_from_one(series):
    """Return a dict mapping 1-based positions to the items of ``series``."""
    return {position: item for position, item in enumerate(series, start=1)}


# Smoothed episodic score (in percent), linear axes.
fig, axes = tube_figure_plt(
    cost_data=smooth_scores,
    title="",
    log_xaxis=False,
    log_yaxis=False,
    x_label="epoch",
    y_label="episodic score (%)",
    percentiles=(20, 80),
)

# Second entry of each result tuple, keyed from 1, on log-log axes.
epoch_loss_data = {name: _indexed_from_one(stats[1]) for name, stats in results.items()}
fig, axes = tube_figure_plt(
    cost_data=epoch_loss_data,
    title="",
    log_xaxis=True,
    log_yaxis=True,
    x_label="epoch",
    y_label="sub loss",
    percentiles=(20, 80),
)

# Third entry of each result tuple, keyed from 1, on log-log axes.
step_loss_data = {name: _indexed_from_one(stats[2]) for name, stats in results.items()}
fig, axes = tube_figure_plt(
    cost_data=step_loss_data,
    title="",
    log_xaxis=True,
    log_yaxis=True,
    x_label="gradient step",
    y_label="batch sub loss",
    percentiles=(20, 80),
)

- - -
## Offline RL Comparison
- - -

In [41]:
from typing import Tuple, Dict, Any
import os
import pandas as pd
import numpy as np
from collections import defaultdict

from io_agent.plotter import tube_figure_plt


# Alternative log directory (IQL comparison runs):
# log_dir = "offline_rl/logs/comparison/iql/hopper"
log_dir = "offline_rl/logs/lr_decay_0_9625/io/hopper"


def read_train_logs(log_dir: str) -> Dict[int, Dict[int, pd.DataFrame]]:
    """Load the latest training-progress CSV of every experiment/seed pair.

    The directory layout read here is
    ``<log_dir>/<experiment>/<seed>/<run>/record/policy_training_progress.csv``.
    For each seed directory the run whose name sorts last is used.

    Args:
        log_dir: Root directory containing one sub-directory per experiment.

    Returns:
        A ``defaultdict(dict)`` indexed as ``frames[experiment_id][seed_id]``.
        Both ids are parsed with ``int(name[5:])``.
        NOTE(review): this presumes a five-character prefix such as
        ``size_`` / ``seed_`` in the directory names; confirm against the
        code that writes the logs.

    Raises:
        FileNotFoundError: If ``log_dir`` or an expected CSV file is missing.
        ValueError: If a directory name does not end in an integer after its
            first five characters.
    """
    frames = defaultdict(dict)
    for exp_path_name in sorted(os.listdir(log_dir)):
        exp_path = os.path.join(log_dir, exp_path_name)
        if not os.path.isdir(exp_path):
            # Stray files next to the experiment directories are ignored.
            continue
        for seed_name in sorted(os.listdir(exp_path)):
            seed_path = os.path.join(exp_path, seed_name)
            if not os.path.isdir(seed_path):
                continue
            date_times = os.listdir(seed_path)
            if not date_times:
                # Seed directory exists but holds no run yet.
                continue
            last_exp_time = max(date_times)
            csv_path = os.path.join(
                seed_path, last_exp_time, "record/policy_training_progress.csv")
            frames[int(exp_path_name[5:])][int(seed_name[5:])] = pd.read_csv(csv_path)
    return frames


dataframes = read_train_logs(log_dir)




In [42]:

# Frames of the 5e3 experiment for the selected seeds.
seed_frames = [dataframes[int(5e3)][seed] for seed in (61,)]

# Rank evaluation points by their seed-averaged normalized reward (ascending).
reward_arrays = [np.array(frame["eval/normalized_episode_reward"]) for frame in seed_frames]
size = min(rewards.size for rewards in reward_arrays)
indices = np.argsort(np.stack([rewards[:size] for rewards in reward_arrays]).mean(0))

std_arrays = [np.array(frame["eval/normalized_episode_std"]) for frame in seed_frames]
size = min(stds.size for stds in std_arrays)
std_arrays = np.stack([stds[:size] for stds in std_arrays])

# Mean over the 50 highest-ranked evaluation points.
mean_rewards = np.stack([rewards[:size] for rewards in reward_arrays]).mean(0)
mean_rewards[indices[-50:]].mean()
# std_arrays[:, indices[-50:]].mean()




78.16546329383819

In [29]:
from typing import Any, Dict, List
from itertools import chain
import json
from IPython.display import display


def prepare_progress_data(file_names: List[str]) -> List[Dict[str, Any]]:
    """Read JSON-lines progress files into one flat list of records.

    Every non-blank line of every file is parsed as a JSON object. Each
    record is tagged with a ``"seed"`` entry holding the 1-based position of
    its file in ``file_names`` as a string; a ``"seed"`` key present in the
    JSON object itself takes precedence.

    Args:
        file_names: Paths of the JSON-lines files, one per seed.

    Returns:
        The records of all files, in file order and then line order.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    progress_dicts = []
    for index, file_name in enumerate(file_names, start=1):
        with open(file_name, "r") as fobj:
            # Stream the file instead of materializing it with readlines().
            for line in fobj:
                if not line.strip():
                    # Blank lines (e.g. trailing newlines) carry no record.
                    continue
                progress_dicts.append({"seed": f"{index}", **json.loads(line)})
    return progress_dicts

def vega_multi_seed_experiment(progress_dicts: List[Dict[str, Any]],
                               title: str = "Training Statistics",
                               x_key: str = "time/steps"
                               ) -> Dict[str, Any]:
    """Build a Vega-Lite v5 spec that plots one selectable field over ``x_key``.

    The field names are taken from the keys of the first record. All fields
    are folded into ``legend_name`` / ``legend_value`` pairs and filtered down
    to the field chosen in a drop-down (``selected_legend``). The chart layers
    an inter-quartile error band with a line of the mean value.

    Args:
        progress_dicts: Records to plot; embedded in the spec as inline data.
        title: Chart title.
        x_key: Name of the field used for the x axis.

    Returns:
        The Vega-Lite specification as a plain dictionary.

    Raises:
        IndexError: If ``progress_dicts`` is empty.
    """
    names = list(progress_dicts[0].keys())
    # Initial drop-down selection. It must be one of the offered options, so
    # use the value of ``x_key`` (not the literal text "x_key") and fall back
    # to the first field when ``x_key`` is not among the record keys.
    default_legend = x_key if x_key in names else names[0]
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "description": "Multi seed experiments figure",
        "width": 600,
        "height": 300,
        "data": {"values": progress_dicts},
        "params": [
            {
                "name": "selected_legend",
                "value": default_legend,
                "bind": {"input": "select", "options": names, "name": "Choose y axis:  "}
            },
        ],
        "title": {
            "text": title,
            "fontSize": 20,
            "fontWeight": 300,
        },
        "transform": [
            # Long format: one row per (record, field) pair ...
            {"fold": names, "as": ["legend_name", "legend_value"]},
            # ... then keep only the rows of the selected field.
            {"filter": {"field": "legend_name", "equal": {"expr": "selected_legend"}}},
        ],
        "encoding": {
            "x": {"field": x_key,
                  "type": "quantitative",
                  "axis": {
                          "offset": 10,
                          "titleFontSize": 14,
                          "titleFontWeight": 500,
                          "title": x_key,
                  }
                  },
        },
        "layer": [
            {
                "mark": {"type": "errorband",
                         "opacity": 0.5,
                         "interpolate": "basis",
                         "extent": "iqr",
                         "borders": {
                             "opacity": 0.2,
                             "strokeDash": [1, 1],
                             "color": "gray"
                         }},
                "encoding": {
                    "y": {"field": "legend_value",
                          "type": "quantitative",
                          "axis": {
                              "title": "values",
                              "offset": 10,
                              "titleFontSize": 14,
                              "titleFontWeight": 500,
                          }
                          },
                }
            },
            {
                "mark": {"type": "line", "opacity": 1.0, "interpolate": "basis"},
                "encoding": {
                    "y": {
                        "aggregate": "mean",
                        "field": "legend_value",
                    },
                    "color": {"field": "legend_name", "type": "nominal", "title": "-*-*-*-Traces-*-*-*-"}
                }
            },
            # Disabled layers: dashed lines for the per-x maximum and minimum.
            # {
            #     "mark": {"type": "line", "opacity": 0.5, "color": "gray", "strokeDash": [2, 3], "interpolate": "basis"},
            #     "encoding": {
            #         "y": {
            #             "aggregate": "max",
            #             "field": "legend_value",
            #         },
            #         "color": {"field": "legend_value", "type": "nominal", "axis": {"title": "Traces"}}
            #     }
            # },
            # {
            #     "mark": {"type": "line", "opacity": 0.5, "color": "gray", "strokeDash": [2, 3], "interpolate": "basis"},
            #     "encoding": {
            #         "y": {
            #             "aggregate": "min",
            #             "field": "legend_value",
            #         },
            #         "color": {"field": "legend_name", "type": "nominal", "axis": {"title": "Traces"}}
            #     }
            # },
        ]
    }


def vega_notebook_render(schema: Dict[str, Any]) -> None:
    """Show a Vega-Lite v5 spec in the notebook as a raw MIME bundle."""
    mime_bundle = {"application/vnd.vegalite.v5+json": schema}
    display(mime_bundle, raw=True)

# Flatten the per-seed frames of the 1e6 experiment into tagged records and
# render them with "timestep" on the x axis.
seed_records = {
    seed: dataframes[int(1e6)][seed].to_dict(orient="records")
    for seed in (61, 62, 63, 64)
}
data = [
    {"seed": seed, **item}
    for seed, records in seed_records.items()
    for item in records
]
seed_progress_spec = vega_multi_seed_experiment(data, x_key="timestep")
vega_notebook_render(seed_progress_spec)




In [12]:

# Truncate every seed's reward curve to the shortest length, average across
# seeds, and take the 1.0-quantile (i.e. the maximum) of the averaged curve.
reward_column = "eval/normalized_episode_reward"
arrays = [np.array(dataframes[int(5e3)][seed][reward_column]) for seed in (61, 62, 63, 64)]
size = min(array.size for array in arrays)
seed_mean_rewards = np.stack([array[:size] for array in arrays]).mean(0)
np.quantile(seed_mean_rewards, 1.0)

66.89749678685159

In [13]:
# Number of logged evaluation points per seed (uses `arrays` from the previous cell).
[array.size for array in arrays]

[1000, 1000, 1000, 1000]

In [None]:
!pip install jaxopt
!pip install jaxtyping
!pip install optax


In [None]:
import jax
import jaxopt
import jax.random as jrd
import jax.numpy as jnp


key = jrd.PRNGKey(42)
# Draw eigenvectors and eigenvalues from independent subkeys; feeding the
# same key to several samplers yields correlated draws.
eig_vec_key, eig_val_key = jrd.split(key)

eig_vecs = jrd.orthogonal(eig_vec_key, 3)
eig_vals = jrd.uniform(eig_val_key, (3,)) + 1.0001

# Sanity value: every column of an orthogonal matrix has unit 2-norm.
column_norms = jnp.linalg.norm(eig_vecs, ord=2, axis=0)

# Symmetric matrix with the prescribed spectrum; the displayed tuple lets the
# recovered eigen-decomposition be compared against the eigenvalues used.
x = eig_vecs @ jnp.diag(eig_vals) @ eig_vecs.T
jnp.linalg.eigh(x), eig_vals


In [None]:
# Shape check: a length-5 vector times a (5, 3) matrix gives a length-3 vector of 5s.
jnp.ones(5) @ jnp.ones((5, 3))

In [None]:
from typing import NamedTuple
import jax
import jax.numpy as jnp
import jax.random as jrd
import jaxopt
from jaxtyping import Array, Float


class IOParams(NamedTuple):
    """Pair of parameter matrices used by the minimizer in this notebook.

    The shape labels are those of the jaxtyping annotations: ``init_params``
    builds ``theta_uu`` as an (action_size, action_size) matrix and
    ``theta_su`` as a (state_size, action_size) matrix.
    """
    # Square matrix; passed to the QP solver as the quadratic term.
    theta_uu: Float[Array, "A A"]
    # Right-multiplies the state to form the QP's linear term.
    theta_su: Float[Array, "S A"]


def init_params(key: Array, action_size: int, state_size: int) -> IOParams:
    """Randomly initialize the parameter matrices.

    ``theta_uu`` is assembled from a random orthogonal eigenbasis and
    eigenvalues ``softplus(normal) + 1``, so it is symmetric with all
    eigenvalues at least 1. ``theta_su`` is filled with standard normal draws.

    The key is annotated with jaxtyping's ``Array`` because the
    ``jax.random.KeyArray`` alias no longer exists in current JAX releases and
    would make this definition fail with an ``AttributeError``.

    Args:
        key: PRNG key; split into three independent subkeys.
        action_size: Side length of ``theta_uu`` and column count of ``theta_su``.
        state_size: Row count of ``theta_su``.

    Returns:
        The initialized ``IOParams``.
    """
    eig_val_key, eig_vec_key, su_key = jrd.split(key, 3)

    eig_vecs = jrd.orthogonal(eig_vec_key, action_size)
    eig_vals = jax.nn.softplus(jrd.normal(
        eig_val_key, (action_size,))) + 1
    init_psd_matrix = eig_vecs @ jnp.diag(eig_vals) @ eig_vecs.T

    return IOParams(
        theta_uu=init_psd_matrix,
        theta_su=jrd.normal(
            su_key, (state_size, action_size))
    )

@jax.jit
def minimizer_action(param: IOParams,
                     state: Float[Array, "S"],
                     box_low: float = -1.0,
                     box_high: float = 1.0,
                     ) -> Float[Array, "A"]:
    """Solve the box-constrained QP defined by ``param`` for a single state.

    The solver receives ``param.theta_uu`` as the quadratic term and
    ``2 * state @ param.theta_su`` as the linear term, starts from the zero
    action, and bounds every action component by ``[box_low, box_high]``.

    Args:
        param: Parameter matrices of the objective.
        state: State (feature) vector.
        box_low: Lower bound applied to every action component.
        box_high: Upper bound applied to every action component.

    Returns:
        The solver's solution vector (one entry per action component).
    """
    action_size = param.theta_uu.shape[0]

    objective = (param.theta_uu, 2 * state @ param.theta_su)
    unit_box = jnp.ones(action_size)
    bounds = (unit_box * box_low, unit_box * box_high)

    solver = jaxopt.BoxCDQP(jit=True)
    solution = solver.run(jnp.zeros(action_size),
                          params_obj=objective,
                          params_ineq=bounds)
    return solution.params


key = jrd.PRNGKey(44)
param = init_params(key, 3, 500)

# Vectorize over the leading axis of the states; the parameters are shared.
batch_minimizer = jax.vmap(minimizer_action, in_axes=(None, 0))
state_batch = jrd.normal(key, (20, 500))
batch_minimizer(param, state_batch)
# Trailing None keeps the cell from displaying the result.
None

In [None]:
# Bilinear form u^T M v with an all-ones M: (0+1+2+3) * (0+1+2+3+4) = 60.
jnp.arange(4) @ jnp.ones((4, 5)) @ jnp.arange(5)

In [None]:
%%timeit
# Times the batched minimizer on 100 states; the state sampling is included in the timing.
batch_minimizer(param, jrd.normal(key, (100, 500)))

In [None]:

# Shape check for state @ (theta_su @ action); with all-ones inputs the result is 4 * 5 = 20.
state = jnp.ones(4)
action = jnp.ones(5)
theta_su = jnp.ones((4, 5))
state @ (theta_su @ action)

In [None]:
def loss(w, x, y):
    """Return the Euclidean norm of the linear-model residual ``x @ w - y``."""
    residual = x @ w - y
    return jnp.linalg.norm(residual)

param = jnp.zeros(4)
# Use separate subkeys so the features and targets are independent draws;
# sampling both from the same key would correlate them.
x_key, y_key = jrd.split(jrd.PRNGKey(42))
x = jrd.normal(x_key, (100, 4))
y = jrd.normal(y_key, (100,))

# One manual optimizer step: initialize the solver state, then update once.
opt = jaxopt.PolyakSGD(loss, maxiter=1, max_stepsize=1e-3)
state = opt.init_state(param, x, y)
param, state = opt.update(param, state, x, y)

In [None]:
# One more optimizer step; every re-run of this cell advances `param` and `state` again.
param, state = opt.update(param, state, x, y)

In [None]:
# vmap over axis 0: sums each of the 5 rows of a (5, 4) all-ones matrix, giving five 4s.
jax.vmap(lambda x: x.sum(), 0, 0)(jnp.ones((5, 4)))

In [None]:
key = jrd.PRNGKey(42)
# Independent subkeys for eigenvectors and eigenvalues (no key reuse).
eig_vec_key, eig_val_key = jrd.split(key)

eig_vecs = jrd.orthogonal(eig_vec_key, 4)
eig_vals = jax.nn.softplus(jrd.normal(eig_val_key, (4,))) + 1
# `eigh` returns eigenvalues in ascending order, so compare against the
# sorted eigenvalues that were used to build the matrix.
jnp.linalg.eigh(eig_vecs @ jnp.diag(eig_vals) @ eig_vecs.T)[0], jnp.sort(eig_vals)

In [None]:
from typing import NamedTuple
import jax
import jax.numpy as jnp
import jax.random as jrd
import jaxopt
from jaxtyping import Array, Float

from io_agent.control.jax_io import minimizer_action, init_params, q_fn, IOParams, batch_loss_fn, project_theta_uu


key = jrd.PRNGKey(42)
state_key, action_key, param_key = jrd.split(key, 3)

# 100 random states and 10000 uniformly sampled action candidates per state.
states = jrd.normal(state_key, (100, 277))
actions = jrd.uniform(action_key, (10000, 100, 3), minval=-1.0, maxval=1.0)

param = init_params(param_key, 3, 277)
# Rebuilt with a scale factor on theta_su (currently 1, i.e. unchanged).
param = IOParams(theta_uu=param.theta_uu, theta_su=param.theta_su * 1)

# Solver minimizer and its objective value for every state.
solve_per_state = jax.vmap(minimizer_action, (None, 0))
min_act = solve_per_state(param, states)
q_per_state = jax.vmap(q_fn, (None, 0, 0))
analy_min_q = q_per_state(param, states, min_act)

# Smallest objective value among the sampled candidates, per state.
q_per_candidate = jax.vmap(q_per_state, (None, None, 0))
empiric_min_q = q_per_candidate(param, states, actions).min(0)

# True when the solver is at least as good as every sampled candidate.
jnp.all(analy_min_q <= empiric_min_q)

In [None]:
from io_agent.control.iterative_io import IterativeIOController
from io_agent.utils import load_experiment
import numpy as np
import os
import torch


from io_agent.plant.mujoco import Walker2dEnv, HalfCheetahEnv, HopperEnv


env = HopperEnv()

# NOTE: rebinds `trial_seeds` (10 seeds here) from the training cell above.
rng = np.random.default_rng(42)
trial_seeds = rng.integers(0, 2**30, 10)

# Class name lower-cased with its last three characters dropped
# ("hopperenv" -> "hopper").
task_name = env.__class__.__name__.lower()[:-3]
experiment_path = os.path.join("./offline_rl/data", task_name, "rich_augmented")
# The variable is called `walker_data` although the environment is Hopper.
walker_data = load_experiment(experiment_path)
augmented_dataset = walker_data["augmented_dataset"]
feature_handler = walker_data["feature_handler"]

controller_options = dict(
    constraints=feature_handler.params.constraints,
    feature_handler=feature_handler,
    learning_rate=1.0,
    include_constraints=True,
    action_constraints_flag=True,
    state_constraints_flag=False,
    lr_exp_decay=1.0,
)
agent = IterativeIOController(**controller_options)


In [None]:
# Copy the torch agent's two parameter tensors into a JAX `IOParams`.
torch_weights = {
    "theta_uu": agent.th_theta_uu,
    "theta_su": agent.th_theta_su,
}
param = IOParams(**{
    field: jnp.array(tensor.detach().numpy())
    for field, tensor in torch_weights.items()
})

In [None]:
from io_agent.plant.mujoco import Walker2dEnv, HalfCheetahEnv, HopperEnv


# Re-creates the Hopper environment and displays its `task_name` attribute.
env = HopperEnv()

env.task_name

In [None]:
import jaxopt

# Plain gradient descent on the jitted batch loss: unit step size, acceleration disabled.
# (Alternative considered: opt=optax.adam(learning_rate=self.learning_rate).)
optimizer = jaxopt.GradientDescent(
    fun=jax.jit(batch_loss_fn),
    stepsize=1,
    acceleration=False,
)

# The solver state is initialized from all-zero placeholders built from
# `param`, the first state row and the first slice of `actions`.
zero_param = jax.tree_util.tree_map(jnp.zeros_like, param)
zero_states = jnp.zeros_like(states[:1])
zero_actions = jnp.zeros_like(actions[:1])
opt_state = optimizer.init_state(zero_param, zero_states, zero_actions)

In [None]:


# Evaluate the torch agent's loss on the JAX states and the first action
# sample, then back-propagate the mean loss.
torch_states = torch.from_numpy(np.array(states))
torch_actions = torch.from_numpy(np.array(actions[0]))
th_loss = agent.loss(torch_states, torch_actions)
th_loss.mean().backward()

In [None]:
# Apply one optimizer step on the torch agent, then call its theta_uu projection.
agent.train_optimizer.step()
agent.project_theta_uu()

In [None]:
# Difference between the torch agent's theta_uu and the projected JAX theta_uu.
# NOTE(review): `new_param` is only assigned in a cell further down, so this
# cell fails on a fresh top-to-bottom run.
jnp.array(agent.th_theta_uu.detach().numpy()) - project_theta_uu(new_param).theta_uu

In [None]:
# Projected theta_uu of the updated JAX parameters.
# NOTE(review): depends on `new_param`, which is assigned in a later cell.
project_theta_uu(new_param).theta_uu

In [None]:
# One gradient-descent step on the JAX side, using all states and the first
# action sample; rebinds `opt_state` to the returned solver state.
first_actions = actions[0]
new_param, opt_state = optimizer.update(param, opt_state, states, first_actions)

In [None]:
# Side-by-side display: updated JAX theta_uu and the torch agent's theta_uu.
new_param.theta_uu, jnp.array(agent.th_theta_uu.detach().numpy())

In [None]:
# Apply the imported theta_uu projection to the current `param`.
theta_param = project_theta_uu(param)

In [2]:
from optax import exponential_decay
import jaxopt
import jax.numpy as jnp


# Exponentially decaying step-size schedule: starts at 1e-1, rate 0.99, floor 1e-5.
# The second positional argument (1) is presumably optax's `transition_steps` — confirm.
# The name is spelled "schedular" and is referenced under that name by a later cell.
schedular = exponential_decay(1e-1, 1, decay_rate=0.99, end_value=1e-5)



In [14]:
# Bare attribute lookup; evaluating the cell displays the class object.
# NOTE(review): the stored output below is a ProxGradState, so it was produced
# by an earlier version of this cell.
jaxopt.AndersonAcceleration

ProxGradState(iter_num=Array(0, dtype=int32, weak_type=True), stepsize=Array(1., dtype=float32), error=Array(inf, dtype=float32), aux=None, velocity=Array([0., 0., 0., 0., 0.], dtype=float32), t=Array(1., dtype=float32, weak_type=True))

In [74]:

def loss(x):
    """Return the mean of ``x``.

    NOTE(review): this rebinds the name ``loss`` that an earlier cell defines
    with the signature ``loss(w, x, y)``.
    """
    return x.mean()


# Gradient descent driven by the step-size schedule; one update from the
# origin, then display the resulting solver state.
gd = jaxopt.GradientDescent(fun=loss, stepsize=schedular)
start_point = jnp.zeros(5)
state = gd.init_state(start_point)
_, state = gd.update(start_point, state)
state



ProxGradState(iter_num=Array(1, dtype=int32, weak_type=True), stepsize=Array(0.1, dtype=float32), error=Array(0.44721365, dtype=float32), aux=None, velocity=Array([-0.02, -0.02, -0.02, -0.02, -0.02], dtype=float32), t=Array(1.618034, dtype=float32, weak_type=True))

In [77]:
# One more update from the origin with the carried-over state; each re-run advances `state` again.
_, state = gd.update(jnp.zeros(5), state)
state


ProxGradState(iter_num=Array(4, dtype=int32, weak_type=True), stepsize=Array(0.0970299, dtype=float32), error=Array(2.7809048, dtype=float32), aux=None, velocity=Array([-0.18475662, -0.18475662, -0.18475662, -0.18475662, -0.18475662],      dtype=float32), t=Array(3.2948797, dtype=float32, weak_type=True))

In [None]:
%%bash
# Job submission commands. Without the %%bash cell magic these shell lines
# are parsed as Python and the cell fails with a SyntaxError.
# Arguments passed to batch.sh: <seed> <data size> <algorithm> <experiment name>.

# CQL and COMBO comparison runs; the && chain stops at the first failed submission.
sbatch batch.sh 61 5000 cql comparison && \
sbatch batch.sh 62 5000 cql comparison && \
sbatch batch.sh 63 5000 cql comparison && \
sbatch batch.sh 64 5000 cql comparison && \
sbatch batch.sh 61 1000000 cql comparison && \
sbatch batch.sh 62 1000000 cql comparison && \
sbatch batch.sh 63 1000000 cql comparison && \
sbatch batch.sh 64 1000000 cql comparison && \
sbatch batch.sh 61 5000 combo comparison && \
sbatch batch.sh 62 5000 combo comparison && \
sbatch batch.sh 63 5000 combo comparison && \
sbatch batch.sh 64 5000 combo comparison && \
sbatch batch.sh 61 1000000 combo comparison && \
sbatch batch.sh 62 1000000 combo comparison && \
sbatch batch.sh 63 1000000 combo comparison && \
sbatch batch.sh 64 1000000 combo comparison

# IO runs with the lr_decay_0_9625 setting (submitted independently of the chain above).
sbatch batch.sh 61 5000 io lr_decay_0_9625 && \
sbatch batch.sh 62 5000 io lr_decay_0_9625 && \
sbatch batch.sh 63 5000 io lr_decay_0_9625 && \
sbatch batch.sh 64 5000 io lr_decay_0_9625


