- - -
## Data Loading & Augmentation
- - -

In [None]:
import os
from io_agent.runner.iterative import augment_mujoco_dataset
from io_agent.plant.mujoco import Walker2dEnv

# Location of the augmented Walker2d dataset on disk.
save_dir = "./walker_data/dataset"
file_name = "walker_rich_augmented"

# Build the dataset only when nothing exists at the target path yet, so
# re-running this cell does not regenerate it.
dataset_path = os.path.join(save_dir, file_name)
if not os.path.exists(dataset_path):
    env = Walker2dEnv()
    augment_mujoco_dataset(env=env, save_dir=save_dir, file_name=file_name)


- - -
## Iterative IO Controller Training
- - -

In [None]:
import datetime
import numpy as np
import torch
import multiprocessing

from io_agent.plant.mujoco import Walker2dEnv

from io_agent.utils import load_experiment
from io_agent.runner.basic import run_agent
from io_agent.runner.iterative import run_iterative_io, IterativeIOArgs


# Reload the stored experiment and pull out the two entries used for training.
walker_data = load_experiment("./walker_data/dataset/walker_rich_augmented")
feature_handler = walker_data["feature_handler"]
augmented_dataset = walker_data["augmented_dataset"]

# Fresh environment instance; passed to the training runner and reused by the
# plotting cell for score normalization.
env = Walker2dEnv()

# Run-wide configuration: hardware, number of trials, seeding, and a
# timestamp used to label output directories.
# NOTE(review): n_cpu is not read anywhere in the visible cells; kept in case
# another cell relies on it.
n_cpu = multiprocessing.cpu_count()
n_trials = 10
general_seed = 44

# A single master generator derives every other seed, so the whole run is
# reproducible from `general_seed` alone.
seed_rng = np.random.default_rng(general_seed)
trial_seeds = seed_rng.integers(0, 2**30, n_trials)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
timestamp = datetime.datetime.now().strftime("%y-%m%d-%H%M%S")

# The model seed is drawn from the same [0, 2**30) range as the trial seeds.
# The upper bound was previously written `2*30` (i.e. 60), which restricted
# the model generator to only 60 distinct seeds.
model_rng = np.random.default_rng(seed_rng.integers(0, 2**30))


# Experiment name -> training arguments. Each entry is trained in turn by the
# loop that follows.
experiment_args = {}
experiment_args["Iterative-IO-1e4"] = IterativeIOArgs(
    learning_rate=9e-3,
    lr_exp_decay=0.9975,
    data_size=10_000,
    n_batch=64,
    # Epochs 0, 40, 80, ..., 1600 (inclusive).
    eval_epochs=tuple(range(0, 1601, 40)),
)

# Output locations are the same for every experiment in this run.
# NOTE(review): the model directory is an absolute, machine-specific path.
model_save_dir = f"/mnt/DEPO/tok/sl-to-rl/walker/models/ablation/{timestamp}"
run_log_dir = f"./walker_data/logs_{timestamp}"

# Train each configured experiment and keep its metrics, keyed by experiment
# name, as (costs, epoch_losses, step_losses). The trained agent is returned
# as well but is not stored in `results`.
results = {}
for key, args in experiment_args.items():
    run_output = run_iterative_io(
        args=args,
        feature_handler=feature_handler,
        augmented_dataset=augmented_dataset,
        env=env,
        rng=model_rng,
        trial_seeds=trial_seeds,
        name=key,
        save_dir=model_save_dir,
        log_dir=run_log_dir,
        device=device,
        verbose=True,
    )
    costs, epoch_losses, step_losses, iterative_io_agent = run_output
    results[key] = (costs, epoch_losses, step_losses)

In [None]:
import numpy as np
import matplotlib.ticker as tck

from utils import steady_state_cost, load_experiment
from collections import defaultdict
from plotter import histogram_figure, histogram_figure_plt, tube_figure_plt


# Per experiment, map each evaluation epoch to its list of normalized scores
# scaled by 100 (plotted as a percentage). Element 0 of every `results` entry
# holds the per-epoch costs.
score_percent_by_run = {
    run_name: {
        epoch: [env.env.get_normalized_score(raw) * 100 for raw in raw_scores]
        for epoch, raw_scores in run_metrics[0].items()
    }
    for run_name, run_metrics in results.items()
}

fig, axes = tube_figure_plt(
    cost_data=score_percent_by_run,
    title="",
    log_xaxis=True,
    log_yaxis=False,
    x_label="epoch",
    y_label="episodic score (%)",
    percentiles=(20, 80),
)

# Per experiment, index the epoch losses (element 1 of every `results` entry)
# starting from 1, so the first entry stays visible on the log-scaled x axis.
epoch_loss_by_run = {
    run_name: dict(enumerate(run_metrics[1], start=1))
    for run_name, run_metrics in results.items()
}

fig, axes = tube_figure_plt(
    cost_data=epoch_loss_by_run,
    title="",
    log_xaxis=True,
    log_yaxis=True,
    x_label="epoch",
    y_label="sub loss",
    percentiles=(20, 80),
)

# Per experiment, index the step losses (element 2 of every `results` entry)
# starting from 1, so the first entry stays visible on the log-scaled x axis.
step_loss_by_run = {
    run_name: dict(enumerate(run_metrics[2], start=1))
    for run_name, run_metrics in results.items()
}

fig, axes = tube_figure_plt(
    cost_data=step_loss_by_run,
    title="",
    log_xaxis=True,
    log_yaxis=True,
    x_label="gradient step",
    y_label="batch sub loss",
    percentiles=(20, 80),
)