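# Transfer of dimensionless policies on the Half Cheetah

This notebook takes a pretrained SAC expert for HalfCheetah-v5 and rescales it through dimensional analysis to Half Cheetahs with different masses and lengths. It then compares the dimensionless rewards that the transferred policies obtain.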

In [1]:
import itertools
import pickle
from pathlib import Path

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
import gymnasium as gym
from stable_baselines3 import SAC
from huggingface_sb3 import load_from_hub

from pipoli.core import Dimension, Context, DimensionalPolicy, ScaledPolicy
from pipoli.sources.sb3 import SB3Policy

from make_cheetah import make_cheetah_xml

## Original Half Cheetah

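The cells below define the context of the original Half Cheetah: its base dimensions (mass $M$, length $L$, time $T$) and its physical parameters, each given with a dimension and a value.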

In [2]:
# Base dimensions, each given as a vector with one entry per base quantity (M, L, T).
M = Dimension([1, 0, 0])  # mass
L = Dimension([0, 1, 0])  # length
T = Dimension([0, 0, 1])  # time
base_dimensions = [M, L, T]

# Dimensionless quantity: every entry is zero.
Unit = Dimension([0, 0, 0])

In [3]:
# Physical parameters of the original Half Cheetah as (name, dimension, value) triples.
# The k* entries have the dimension of a torque and the b* entries of a torque times
# a time; presumably joint stiffnesses and dampings (TODO confirm in make_cheetah).
original_parameters = [
    ("dt", T, 0.01),
    ("m", M, 14),
    ("g", L/T**2, 9.81),
    ("taumax", M*L**2/T**2, 1),
    ("L", L, 0.5),
    ("Lh", L, 0.15),
    ("l0", L, 0.145),
    ("l1", L, 0.15),
    ("l2", L, 0.094),
    ("l3", L, 0.133),
    ("l4", L, 0.106),
    ("l5", L, 0.07),
    ("k0", M*L**2/T**2, 240),
    ("k1", M*L**2/T**2, 180),
    ("k2", M*L**2/T**2, 120),
    ("k3", M*L**2/T**2, 180),
    ("k4", M*L**2/T**2, 120),
    ("k5", M*L**2/T**2, 60),
    ("b0", M*L**2/T, 6),
    ("b1", M*L**2/T, 4.5),
    ("b2", M*L**2/T, 3),
    ("b3", M*L**2/T, 4.5),
    ("b4", M*L**2/T, 3),
    ("b5", M*L**2/T, 1.5),
]

# Transpose the triples into three parallel tuples (names, dimensions, values),
# which Context receives positionally after the base dimensions.
param_names, param_dims, param_values = zip(*original_parameters)
original_context = Context(base_dimensions, param_names, param_dims, param_values)

# MuJoCo model file of the original body.
original_cheetah_file = make_cheetah_xml(original_context, "original")

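Instantiate the original policy: the pretrained SAC expert for HalfCheetah-v5, wrapped with the physical dimensions of its observations and actions in the original context.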

In [4]:
# Pretrained SAC expert for HalfCheetah-v5, fetched from the Hugging Face Hub.
halfcheetah_v5_sac_expert = load_from_hub(
    repo_id="farama-minari/HalfCheetah-v5-SAC-expert",
    filename="halfcheetah-v5-sac-expert.zip",
)
model = SAC.load(halfcheetah_v5_sac_expert)

# Spaces the SB3 model works in: 17 observations, 6 actions in [-1, 1];
# actions are queried deterministically.
sb3_policy = SB3Policy(
    model,
    model_obs_space=gym.spaces.Box(-np.inf, np.inf, (17,), np.float64),
    model_act_space=gym.spaces.Box(-1.0, 1.0, (6,), np.float32),
    predict_kwargs={"deterministic": True},
)

# Dimension of each observation entry, in order:
#   z (L); the angle phi and the 6 joint angles theta_0..theta_5 (dimensionless);
#   x_dot and z_dot (L/T); phi_dot and the 6 joint angular velocities (1/T).
original_obs_dims = [L] + [Unit] * 7 + [L / T] * 2 + [1 / T] * 7
# Each of the 6 actions is a joint torque.
original_act_dims = [M * L**2 / T**2] * 6

original_policy = DimensionalPolicy(
    sb3_policy,
    original_context,
    obs_dims=original_obs_dims,
    act_dims=original_act_dims,
)

In [5]:
# NOTE(review): disabled exploration cell; it appears superseded by the sweep in
# "Make contexts". Other cells must not rely on the names below while this cell
# stays commented out.
# Beware the suffixes: "10" means a factor of 10 except in m10_L10 (factor 2 on
# both m and L), and "01" always means a factor of 0.98, not 0.1.
# base = ["m", "L", "g"]
# m = original_context.value("m")
# L_ = original_context.value("L")
# g = original_context.value("g")

# m10_context = original_context.scale_to(base, [m * 10, L_, g])
# m10_xml = make_cheetah_xml(m10_context, "m10")

# m01_context = original_context.scale_to(base, [m * .98, L_, g])
# m01_xml = make_cheetah_xml(m01_context, "m01")

# L10_context = original_context.scale_to(base, [m, L_ * 10, g])
# L10_xml = make_cheetah_xml(L10_context, "L10")

# L01_context = original_context.scale_to(base, [m, L_ * .98, g])
# L01_xml = make_cheetah_xml(L01_context, "L01")

# NOTE(review): despite its name, this context doubles m and L.
# m10_L10_context = original_context.scale_to(base, [m * 2, L_ * 2, g])
# m10_L10_xml = make_cheetah_xml(m10_L10_context, "m10_L10")

# m01_L10_context = original_context.scale_to(base, [m * .98, L_ * 10, g])
# m01_L10_xml = make_cheetah_xml(m01_L10_context, "m01_L10")

# m10_L01_context = original_context.scale_to(base, [m * 10, L_ * .98, g])
# m10_L01_xml = make_cheetah_xml(m10_L01_context, "m10_L01")

# m01_L01_context = original_context.scale_to(base, [m * .98, L_ * .98, g])
# m01_L01_xml = make_cheetah_xml(m01_L01_context, "m01_L01")

## Data generation

### Utilities

Helper functions used in the rest of the notebook; this section can be collapsed for better legibility.

In [6]:
def make_context_sweep(context, base, range1, range2, range3, outdir="./output"):
    """Build one scaled context, and its MuJoCo XML file, per point of a 3-D grid.

    Grid points ``(b1, b2, b3)`` run over the Cartesian product of ``range1``,
    ``range2`` and ``range3`` (last range varying fastest). Each point is passed to
    ``context.scale_to(base, [b1, b2, b3])``, so the ranges give the new values of
    the three quantities named in ``base``.

    Parameters
    ----------
    context : Context
        Reference context to rescale.
    base : sequence of str
        Names of the three quantities the sweep assigns new values to.
    range1, range2, range3 : iterable of float
        Values swept for ``base[0]``, ``base[1]`` and ``base[2]``.
    outdir : str or Path, default "./output"
        Directory passed to ``make_cheetah_xml``.

    Returns
    -------
    dict
        Maps each generated XML file to its scaled context, in sweep order.

    Raises
    ------
    ValueError
        If two grid points give the same file label. The label keeps only two
        decimals per value, so close values collide, and the later file would
        silently overwrite the earlier one.
    """
    contexts = {}
    seen_labels = set()

    for b1, b2, b3 in itertools.product(range1, range2, range3):
        label = f"{b1:.2f}-{b2:.2f}-{b3:.2f}"
        # Check before writing so an earlier XML file is never overwritten.
        if label in seen_labels:
            raise ValueError(
                f"grid point {(b1, b2, b3)} gives XML label {label!r}, already used "
                "by an earlier grid point; use more distant sweep values"
            )
        seen_labels.add(label)

        new_context = context.scale_to(base, [b1, b2, b3])
        xml_file = make_cheetah_xml(new_context, label, outdir=outdir)
        contexts[xml_file] = new_context

    return contexts

In [7]:
def record_episode(env, policy, nb_steps, close=True):
    """Roll out ``policy`` in ``env`` for at most ``nb_steps`` steps and record it.

    Parameters
    ----------
    env : gymnasium.Env
        Environment to run; it is reset at the start.
    policy : object
        Any object with an ``action(obs)`` method returning an action for ``env.step``.
    nb_steps : int
        Maximum number of steps to record; must be at least 1.
    close : bool, default True
        Whether to close ``env`` when the function returns or raises.

    Returns
    -------
    observations : np.ndarray, shape (nb_steps, *obs_shape)
        ``observations[i]`` is the observation the policy acted on at step ``i``.
    actions : np.ndarray, shape (nb_steps, *act_shape)
        ``actions[i]`` is the action taken at step ``i``.
    rewards : np.ndarray, shape (nb_steps,)
        Reward returned by step ``i``.
    infos : list
        Info dict returned by step ``i``.
    last_step : int
        Index of the last recorded step. If the episode ended early, entries
        after ``last_step`` keep their initial value (zeros, ``None`` in ``infos``).
    trunc : bool
        Truncation flag returned by the last step.

    Raises
    ------
    ValueError
        If ``nb_steps`` is smaller than 1 (there would be no last step).
    """
    try:
        if nb_steps < 1:
            raise ValueError(f"nb_steps must be at least 1, got {nb_steps}")

        observations = np.zeros((nb_steps,) + env.observation_space.shape)
        actions = np.zeros((nb_steps,) + env.action_space.shape)
        infos = [None] * nb_steps
        rewards = np.zeros(nb_steps)

        obs, info = env.reset()

        for i in range(nb_steps):
            act = policy.action(obs)
            observations[i] = obs
            actions[i] = act

            obs, rew, term, trunc, info = env.step(act)
            rewards[i] = rew
            infos[i] = info

            # Stepping past the end of an episode is undefined, so stop on
            # termination as well as truncation.
            if term or trunc:
                break
    finally:
        # Release the simulator even if the policy or the env raised.
        if close:
            env.close()

    return observations, actions, rewards, infos, i, trunc

### Make contexts

In [8]:
base = ["m", "L", "g"]

# Sweep mass and length over two decades (x0.1 ... x10 of the original values)
# on a logarithmic grid; gravity keeps its original value.
num = 5
sweep_factors = np.logspace(-1, 1, num=num)
ms = sweep_factors * original_context.value("m")
Ls = sweep_factors * original_context.value("L")
gs = [original_context.value("g")]

all_contexts = make_context_sweep(
    original_context, base, ms, Ls, gs, outdir="output/xml_files"
)

### Generate data

In [9]:
nb_steps = 1000
df_columns = ["m", "L", "g", "context", "observations", "actions", "rewards", "infos", "last_step"]

# One rollout per body: the original expert on the original body, then the expert
# rescaled to each swept context. Rows are indexed by the XML file name without ".xml".
runs = [("original", original_cheetah_file, original_context, original_policy)]
for xml_file, context in all_contexts.items():
    runs.append((
        Path(xml_file).name.removesuffix(".xml"),
        xml_file,
        context,
        original_policy.to_scaled(context=context, base=base),
    ))

# Collect plain rows and build the frame once instead of growing it with df.loc.
row_names, rows = [], []
for name, xml_file, context, policy in tqdm(runs, desc="rollouts"):
    env = gym.make("HalfCheetah-v5", xml_file=xml_file, render_mode=None)
    observations, actions, rewards, infos, last_step, _ = record_episode(env, policy, nb_steps)

    row_names.append(name)
    rows.append([
        context.value("m"), context.value("L"), context.value("g"),
        context, observations, actions, rewards, infos, last_step,
    ])

df = pd.DataFrame(rows, index=row_names, columns=df_columns)
df.attrs["max_steps"] = nb_steps
df.attrs["base"] = base

In [10]:
# Directory for the recorded rollouts. parents=True also creates "output/" when it
# does not exist yet, e.g. on a fresh checkout.
data_dir = Path("output/data")
data_dir.mkdir(parents=True, exist_ok=True)

In [11]:

df.to_pickle(data_dir.joinpath("all_data.pkl.gz"), compression="gzip")  # gzip, as the ".gz" suffix implies

## Data analysis

In [12]:
# Reload the rollouts saved above (a pickle written by this notebook, hence trusted),
# so the analysis can run without regenerating the data.
df = pd.read_pickle(data_dir.joinpath("all_data.pkl.gz"))
# The names of the base quantities are stored alongside the data in the frame's attrs.
base = df.attrs["base"]

### Score vs distance analysis

In [21]:
def adim_reward_forward(context, infos):
    """Dimensionless forward reward of every recorded step.

    ``info["reward_forward"]`` is treated as a velocity (dimension L/T). It is made
    dimensionless with the first transform ``context.make_transforms`` returns for
    the notebook-level ``base`` quantities.

    Parameters
    ----------
    context : Context
        Context the episode was recorded in.
    infos : sequence of dict or None
        Per-step info dicts, as stored by ``record_episode``. ``None`` entries
        (unused slots after an episode stopped early) are skipped.

    Returns
    -------
    np.ndarray
        One dimensionless forward reward per recorded step.
    """
    forward_reward_adim, _ = context.make_transforms([L / T], base)

    return np.array(
        [forward_reward_adim(info["reward_forward"]) for info in infos if info is not None]
    )

def adim_reward_ctrl(context, infos):
    """Dimensionless control reward of every recorded step.

    ``info["reward_ctrl"]`` is treated as having the dimension M L^2 / T^2. It is
    made dimensionless with the first transform ``context.make_transforms`` returns
    for the notebook-level ``base`` quantities.

    Parameters
    ----------
    context : Context
        Context the episode was recorded in.
    infos : sequence of dict or None
        Per-step info dicts, as stored by ``record_episode``. ``None`` entries
        (unused slots after an episode stopped early) are skipped.

    Returns
    -------
    np.ndarray
        One dimensionless control reward per recorded step.
    """
    # NOTE(review): Gymnasium's HalfCheetah control cost is quadratic in the action
    # (a torque here); confirm that M*L**2/T**2 is the intended dimension.
    control_cost_adim, _ = context.make_transforms([M*L**2/T**2], base)

    return np.array(
        [control_cost_adim(info["reward_ctrl"]) for info in infos if info is not None]
    )

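Build the score dataframe: one row per context, with its similarity to the original context and statistics of its dimensionless rewards.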

In [51]:
score_df = df.sort_values(by=["m", "L"]).loc[:, ["m", "L", "context", "infos"]]  # by mass, then length; scoring inputs only

In [52]:
# How far each context is from the original one, as measured by Context itself.
score_df["adimensional_distance_to_original"] = score_df["context"].map(
    lambda ctx: ctx.adimensional_distance(original_context, base)
)
score_df["cosine_similarity_to_original"] = score_df["context"].map(
    lambda ctx: ctx.cosine_similarity(original_context)
)

# Per-step dimensionless reward terms of each episode (one array per row).
per_row_rewards = {
    "rewards_forward": adim_reward_forward,
    "rewards_ctrl": adim_reward_ctrl,
}
for column, reward_fn in per_row_rewards.items():
    score_df[column] = score_df.apply(
        lambda row: reward_fn(row["context"], row["infos"]), axis=1
    )
score_df["rewards_total_per_step"] = score_df.apply(
    lambda row: row["rewards_forward"] + row["rewards_ctrl"], axis=1
)

# Scalar summaries of the per-step total reward.
per_step_stats = {
    "reward_total": np.sum,
    "mean_reward_per_step": np.mean,
    "std_reward_per_step": np.std,
    # Coefficient of variation: spread relative to the magnitude of the mean.
    "norm_std_reward_per_step": lambda r: np.std(r) / np.abs(np.mean(r)),
}
for column, stat in per_step_stats.items():
    score_df[column] = score_df["rewards_total_per_step"].map(stat)

In [53]:
# Keep only scalar columns: drop contexts, info lists and per-step reward arrays.
non_scalar_columns = ["context", "infos", "rewards_forward", "rewards_ctrl", "rewards_total_per_step"]
final_score_df = score_df.drop(non_scalar_columns, axis="columns")
final_score_df

Unnamed: 0,m,L,adimensional_distance_to_original,cosine_similarity_to_original,reward_total,mean_reward_per_step,std_reward_per_step,norm_std_reward_per_step
cheetah-1.40-0.05-9.81,1.4,0.05,0.0,0.396211,2.012858,0.002013,0.052626,26.145055
cheetah-1.40-0.16-9.81,1.4,0.158114,5.737111e-16,0.799913,0.912959,0.000913,0.03665,40.144385
cheetah-1.40-0.50-9.81,1.4,0.5,6.938894000000001e-17,0.975981,26.611935,0.026612,0.14272,5.363016
cheetah-1.40-1.58-9.81,1.4,1.581139,7.850462000000001e-17,0.997983,1.603383,0.001603,0.038448,23.979423
cheetah-1.40-5.00-9.81,1.4,5.0,1.144392e-16,0.997945,14.692877,0.014693,0.10873,7.400165
cheetah-4.43-0.05-9.81,4.427189,0.05,1.779831e-15,0.779898,-1.6333,-0.001633,0.033404,20.451915
cheetah-4.43-0.16-9.81,4.427189,0.158114,8.470017e-16,0.973443,-0.315135,-0.000315,0.033862,107.452518
cheetah-4.43-0.50-9.81,4.427189,0.5,8.418701e-16,0.998553,2239.932448,2.239932,1.845512,0.823914
cheetah-4.43-1.58-9.81,4.427189,1.581139,1.716588e-16,0.999511,9.344592,0.009345,0.109107,11.67596
cheetah-4.43-5.00-9.81,4.427189,5.0,1.280147e-15,0.997887,24.625685,0.024626,0.149459,6.069242


## Study of the observation and action domains in different contexts

In [None]:
# Data generation for the domain study: roll out the rescaled expert once in each
# swept context. For each XML file, keep the observations the policy acted on and
# the actions it produced; row i of both arrays belongs to the same step.
nb_steps = 1000
observations = {}
actions = {}

for xml, context in tqdm(all_contexts.items()):
    policy = original_policy.to_scaled(context, base)
    env = gym.make("HalfCheetah-v5", xml_file=xml, render_mode=None)

    # record_episode resets the env, stops when the episode ends and closes the env;
    # rows after an early stop stay at zero.
    obs_record, act_record, _, _, _, _ = record_episode(env, policy, nb_steps)
    observations[xml] = obs_record
    actions[xml] = act_record

In [None]:
# Context whose observation distribution is plotted. Any XML file stem from
# `all_contexts` works; this one has the original mass and length.
domain_context_name = "cheetah-14.00-0.50-9.81"
xml_by_name = {Path(xml).stem: xml for xml in observations}
domain_xml = xml_by_name[domain_context_name]

obs_labels = [
    "$z$", "$\\phi$", "$\\theta_0$", "$\\theta_1$", "$\\theta_2$", "$\\theta_3$", "$\\theta_4$", "$\\theta_5$",
    "$\\dot{x}$", "$\\dot{z}$", "$\\dot\\phi$", "$\\dot\\theta_0$", "$\\dot\\theta_1$", "$\\dot\\theta_2$",
    "$\\dot\\theta_3$", "$\\dot\\theta_4$", "$\\dot\\theta_5$",
]

fig, ax = plt.subplots()
ax.violinplot(observations[domain_xml])
ax.set_xticks(range(1, len(obs_labels) + 1))
ax.set_xticklabels(obs_labels)
ax.set_title(f"Observation distribution: {domain_context_name}");

In [None]:
# Context whose action distribution is plotted. Any XML file stem from
# `all_contexts` works; this one has the original mass and length.
domain_context_name = "cheetah-14.00-0.50-9.81"
xml_by_name = {Path(xml).stem: xml for xml in actions}
domain_xml = xml_by_name[domain_context_name]

act_labels = ["$\\tau_0$", "$\\tau_1$", "$\\tau_2$", "$\\tau_3$", "$\\tau_4$", "$\\tau_5$"]

fig, ax = plt.subplots()
ax.violinplot(actions[domain_xml])
ax.set_xticks(range(1, len(act_labels) + 1))
ax.set_xticklabels(act_labels)
ax.set_title(f"Action distribution: {domain_context_name}");

## Performance heatmap

Work in progress: for now, this section only evaluates the policy on a single rescaled context.

In [86]:
# Single-episode check on a body with twice the original mass and length, built
# here so that this cell does not depend on commented-out code.
test_context = original_context.scale_to(
    base,
    [original_context.value("m") * 2, original_context.value("L") * 2, original_context.value("g")],
)
test_context_xml = make_cheetah_xml(test_context, "m2_L2")

# False: run the original (unscaled) expert on the rescaled body, as a baseline.
# True: run the expert rescaled to `test_context`.
use_scaled_policy = False
test_policy = original_policy.to_scaled(test_context, base) if use_scaled_policy else original_policy

nb_test_steps = 500
env = gym.make("HalfCheetah-v5", xml_file=test_context_xml)
obss, acts, rews, infos, _, _ = record_episode(env, test_policy, nb_test_steps)
# obss, acts, rews, infos, _, _ = record_episode(env, original_policy.to_scaled(test_context, base), 500)

In [None]:
# Episode totals of the dimensionless reward terms for the check above, computed
# with the same helpers as the score table: (forward, control, total).
fwd_rews = adim_reward_forward(test_context, infos).sum()
ctl_costs = adim_reward_ctrl(test_context, infos).sum()

fwd_rews, ctl_costs, fwd_rews + ctl_costs