# The transfer study

## Init

In [None]:
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
import gymnasium as gym
from sb3_contrib import TQC
from huggingface_sb3 import load_from_hub

from matplotlib import pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats("pdf", "svg")

from pipoli.core import DimensionalPolicy, Dimension, Context, Policy
from pipoli.sources.sb3 import SB3Policy

from make_cheetah import make_cheetah, make_cheetah_xml

In [2]:
# Output layout: everything produced by this notebook lives under ./output,
# with the generated MuJoCo XML models and data files in sub-directories.
ROOT = Path("output")
XML_FILES = ROOT.joinpath("xml_files")
DATA = ROOT.joinpath("data")

In [9]:
class NaiveObsScaledActPolicy(Policy):
    """Run ``dim_pol`` on unscaled observations but rescale its actions.

    Observations from the target system are handed to the original policy
    as-is. Only the action is converted. It is first mapped with a transform
    built from the original policy's context. It is then mapped with a
    transform built from the target ``context``. Both transforms are made
    from ``dim_pol.act_dims`` relative to the ``base`` quantities.
    """

    def __init__(self, dim_pol, context, base):
        self.dim_pol = dim_pol
        self.context = context
        self.base = base
        action_dims = dim_pol.act_dims
        # make_transforms yields a pair of transforms (presumably
        # to-adimensional, from-adimensional). Keep the first half from the
        # source context and the second half from the target context.
        to_adim, _ = dim_pol.context.make_transforms(action_dims, base)
        _, from_adim = context.make_transforms(action_dims, base)
        self.orig_to_adim = to_adim
        self.adim_to_scale = from_adim

    def action(self, obs):
        """Scale only the action computed from a naive (unscaled) observation.

        Usually, the actuators of a system are sized accordingly. It would be
        unfair to not allow the policy to use the full range or to allow it
        to use too much.
        """
        adim_action = self.orig_to_adim(self.dim_pol.action(obs))
        return self.adim_to_scale(adim_action)

In [21]:
def evaluate_policy(context, xml_file, policy, relax_steps=100, max_steps=1000):
    """Run one rollout of ``policy`` on a HalfCheetah-v5 built from ``xml_file``.

    The reward weights come from ``context`` and reset noise is disabled
    (``reset_noise_scale=0``). For the first ``relax_steps`` steps a zero
    action is applied and the policy is not queried, but rewards are still
    accumulated. Afterwards ``policy.action(obs)`` drives the environment.
    The rollout stops after ``max_steps`` steps, or as soon as the
    environment terminates or truncates.

    Parameters
    ----------
    context : Context
        Provides ``forward_reward_weight`` and ``ctrl_cost_weight``.
    xml_file : str or path-like
        MuJoCo model file passed to ``gym.make``.
    policy : object
        Anything with an ``action(obs)`` method returning an action array.
    relax_steps : int
        Number of initial zero-action steps.
    max_steps : int
        Upper bound on the number of environment steps.

    Returns
    -------
    tuple
        ``(total_reward, total_forward_reward, total_ctrl_reward, is_flipped)``.
        ``is_flipped`` is 1 if ``|obs[1]|`` exceeds ``pi / 1.8`` (100
        degrees) at the final step, else 0.
    """
    forward_weight = context.value("forward_reward_weight")
    ctrl_weight = context.value("ctrl_cost_weight")
    env = gym.make(
        "HalfCheetah-v5",
        xml_file=xml_file,
        forward_reward_weight=forward_weight,
        ctrl_cost_weight=ctrl_weight,
        reset_noise_scale=0,
    )

    rewards = np.zeros(max_steps)
    fwd_rewards = np.zeros(max_steps)
    ctrl_rewards = np.zeros(max_steps)

    # Close the environment even if the policy or the simulation raises.
    try:
        obs, info = env.reset()
        # Size the idle action from the environment instead of hard-coding 6.
        action_shape = env.action_space.shape

        for step in range(max_steps):
            if step < relax_steps:
                act = np.zeros(action_shape)
            else:
                act = policy.action(obs)

            obs, rew, term, trunc, info = env.step(act)

            rewards[step] = rew
            fwd_rewards[step] = info["reward_forward"]
            ctrl_rewards[step] = info["reward_ctrl"]

            # Stepping past the end of an episode is not defined in Gymnasium.
            if term or trunc:
                break
    finally:
        env.close()

    # NOTE(review): obs[1] is presumably the torso pitch angle (HalfCheetah-v5
    # drops the x position from the observation by default) -- confirm.
    is_flipped = abs(obs[1]) > np.pi / 1.8

    return rewards.sum(), fwd_rewards.sum(), ctrl_rewards.sum(), int(is_flipped)

In [None]:
def heatmap(df, x, y, C, /, title=None, xlabel=None, ylabel=None, zlabel=None, xscale="linear", yscale="linear", **kwargs):
    """Draw ``df[C]`` as a pcolormesh over the (``x``, ``y``) grid in a new figure.

    The rows of ``df`` are pivoted into a grid with one cell per unique
    (x, y) pair. The grid may therefore be rectangular, and the rows may
    come in any order. Each (x, y) pair must appear at most once. Missing
    pairs are left empty (NaN).

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the columns ``x``, ``y`` and ``C``.
    x, y : str
        Column names used for the horizontal and vertical axes.
    C : str
        Column name of the value to color.
    title, xlabel, ylabel, zlabel : str, optional
        Plot title and axis/colorbar labels. Labels default to the column names.
    xscale, yscale : str
        Axis scales passed to ``plt.xscale`` / ``plt.yscale``.
    **kwargs
        Forwarded to ``plt.pcolormesh``.

    Raises
    ------
    ValueError
        If some (x, y) pair appears more than once.
    """
    # pivot sorts both axes and rejects duplicate (x, y) pairs. A square
    # sqrt(size) reshape silently fails on non-square grids or unsorted rows.
    grid = df.pivot(index=y, columns=x, values=C)
    # DataFrames filled row by row may hold object dtype; plot as floats.
    X = grid.columns.to_numpy(dtype=float)
    Y = grid.index.to_numpy(dtype=float)
    Z = grid.to_numpy(dtype=float)

    plt.figure()
    plt.pcolormesh(X, Y, Z, **kwargs)
    plt.xscale(xscale)
    plt.yscale(yscale)
    plt.title(title or "")
    plt.xlabel(xlabel or x)
    plt.ylabel(ylabel or y)
    plt.colorbar(label=zlabel or C)

## Parameters

In [None]:
# Base dimensions, each built from a 3-vector with one slot per base (M, L, T).
# Unit is the all-zero (dimensionless) vector.
M = Dimension([1, 0, 0])
L = Dimension([0, 1, 0])
T = Dimension([0, 0, 1])
BASE_DIMENSIONS = [M, L, T]
Unit = Dimension([0, 0, 0])

In [None]:
# Parameters of the original HalfCheetah model and of its reward, each written
# as `value | dimension` using the base dimensions M, L, T (presumably mass,
# length and time).
# NOTE(review): the physical meaning of the individual names (l0-l5, k0-k5,
# b0-b5, ...) is defined by make_cheetah, which is not visible here -- confirm.
original_context = Context.from_quantities(
    dt = 0.01                   | T,
    m = 14                      | M,
    g = 9.81                    | L/T**2,
    taumax = 1                  | M*L**2/T**2,
    d = 0.046                   | L,
    L = 0.5                     | L,
    Lh = 0.15                   | L,
    l0 = 0.145                  | L,
    l1 = 0.15                   | L,
    l2 = 0.094                  | L,
    l3 = 0.133                  | L,
    l4 = 0.106                  | L,
    l5 = 0.07                   | L,
    k0 = 240                    | M*L**2/T**2,
    k1 = 180                    | M*L**2/T**2,
    k2 = 120                    | M*L**2/T**2,
    k3 = 180                    | M*L**2/T**2,
    k4 = 120                    | M*L**2/T**2,
    k5 = 60                     | M*L**2/T**2,
    b0 = 6                      | M*L**2/T,
    b1 = 4.5                    | M*L**2/T,
    b2 = 3                      | M*L**2/T,
    b3 = 4.5                    | M*L**2/T,
    b4 = 3                      | M*L**2/T,
    b5 = 1.5                    | M*L**2/T,
    armature = 0.1              | M*L**2,
    damping = 0.01              | M*L**2/T,
    stiffness = 8               | M*L**2/T**2,
    # T/L is the inverse of a velocity (L/T), so this weight times a velocity
    # is dimensionless.
    forward_reward_weight = 1   | T/L,
    # T**4/M**2/L**4 is the inverse of a torque squared ((M*L**2/T**2)**2),
    # so this weight times a squared torque is dimensionless.
    ctrl_cost_weight = 0.1      | T**4/M**2/L**4,
)

In [None]:
# Download the pretrained TQC expert for HalfCheetah-v5.
halfcheetah_v5_tqc_expert = load_from_hub(
    repo_id="farama-minari/HalfCheetah-v5-TQC-expert",
    filename="halfcheetah-v5-TQC-expert.zip",
)
# A pickled callable in the archive (likely the learning-rate schedule) cannot
# be deserialized on this Python version. SB3 warned "code expected at most 16
# arguments, got 18". Schedules only matter for training, so replace them with
# constants via `custom_objects`, as SB3 suggests for cross-version loading.
model = TQC.load(
    halfcheetah_v5_tqc_expert,
    device="cpu",
    custom_objects={"learning_rate": 0.0, "lr_schedule": lambda _: 0.0},
)

# Wrap the SB3 model. The spaces describe what the network itself consumes
# (17 unbounded float64 observations) and produces (6 actions in [-1, 1]).
sb3_policy = SB3Policy(
    model,
    model_obs_space=gym.spaces.Box(-np.inf, np.inf, (17,), np.float64),
    model_act_space=gym.spaces.Box(-1.0, 1.0, (6,), np.float32),
    predict_kwargs=dict(deterministic=True)
)

# Attach dimensions to the policy's inputs and outputs in the original context.
# Observations: 1 length, 7 dimensionless values, 2 linear and 7 angular rates.
# This presumably matches HalfCheetah-v5's qpos[1:] + qvel layout -- confirm.
# Actions: 6 torques (M*L**2/T**2).
original_policy = DimensionalPolicy(
    sb3_policy,
    original_context,
    obs_dims=[L] + [Unit] * 7 + [L/T] * 2 + [1/T] * 7,
    act_dims=[M*L**2/T**2] * 6
)

Exception: code expected at most 16 arguments, got 18


In [None]:
# Quantities used as the similarity base.
base = ["m", "L", "g"]

max_scale = 4
min_scale = 1 / 4
num_1 = 5
num_2 = 5
num_3 = 5

# Sweep ranges for the similar contexts: num_i factors linearly spaced in
# [min_scale, max_scale], times the original value of the i-th base quantity.
b1_range, b2_range, b3_range = (
    np.linspace(min_scale, max_scale, num=count) * original_context.value(name)
    for name, count in zip(base, (num_1, num_2, num_3))
)

base_orig = np.array([original_context.value(name) for name in base])

# Every context value after scaling all base quantities by max_scale (`twice`)
# and by min_scale (`half`). The names predate max_scale = 4 / min_scale = 1/4.
twice = original_context.scale_to(base, max_scale * base_orig).values
half = original_context.scale_to(base, min_scale * base_orig).values

# Per-quantity scale factors. NaN ratios (e.g. 0/0 for quantities whose
# original value is 0) are mapped to 0.
upper = twice / original_context.values
upper = np.where(np.isnan(upper), 0, upper)
lower = half / original_context.values
lower = np.where(np.isnan(lower), 0, lower)

# Scaling the base up can shrink some quantities (e.g. negative exponents in
# their dimension), so order each pair as (lower, upper). The results are the
# factor bounds for random contexts around the original context.
to_switch = upper < lower
to_lower, to_upper = upper[to_switch], lower[to_switch]
upper, lower = np.where(to_switch, lower, upper), np.where(to_switch, upper, lower)

In [17]:
# Metadata describing the setup shared by the transfer studies: the scaling
# base, the original context, where the original policy was downloaded from,
# and the Gymnasium environment id.
common_metadata = {
    "base": base,
    "original_context": original_context,
    "original_policy": {
        "repo_id": "farama-minari/HalfCheetah-v5-TQC-expert",
        "filename": "halfcheetah-v5-TQC-expert.zip",
    },
    "env": "HalfCheetah-v5",
}

## Similar transfer

### Data generation

In [46]:
# One (context, xml_path) pair per point of the base-scaling grid; every
# context is obtained by scaling the base quantities of the original context.
# The "similar-" label prefix keeps these models apart from the quasi-similar
# study, which is generated from the very same (b1, b2, b3) grid. Without it,
# its files would overwrite these (assumes make_cheetah_xml names the file
# after this label -- confirm).
all_similar_contexts = []

for b1, b2, b3 in product(b1_range, b2_range, b3_range):
    context = original_context.scale_to(base, [b1, b2, b3])
    xml = make_cheetah_xml(context, f"similar-{b1}-{b2}-{b3}", outdir=XML_FILES)
    all_similar_contexts.append((context, xml))

In [47]:
# Evaluate three transfer strategies on every similar context:
#   scaled      -- original_policy.to_scaled(context, base)
#   naive       -- original_policy used unchanged
#   semi_scaled -- raw observations, rescaled actions (NaiveObsScaledActPolicy)
# Rows are collected in a list and turned into a DataFrame once. This avoids
# growing the frame row by row and lets pandas infer numeric column dtypes.
similar_rows = []

for context, xml in all_similar_contexts:
    distance = context.adimensional_distance(original_context, base)

    scaled_policy = original_policy.to_scaled(context, base)
    scaled_data = evaluate_policy(context, xml, scaled_policy)

    naive_policy = original_policy
    naive_data = evaluate_policy(context, xml, naive_policy)

    semi_scaled_policy = NaiveObsScaledActPolicy(original_policy, context, base)
    semi_scaled_data = evaluate_policy(context, xml, semi_scaled_policy)

    similar_rows.append(
        (context, Path(xml).read_text(), distance)
        + scaled_data
        + naive_data
        + semi_scaled_data
    )

similar_df = pd.DataFrame(
    similar_rows,
    columns=[
        "context", "xml", "distance",
        "scaled_reward", "scaled_fwd_reward", "scaled_ctrl_reward", "scaled_is_flipped",
        "naive_reward", "naive_fwd_reward", "naive_ctrl_reward", "naive_is_flipped",
        "semi_scaled_reward", "semi_scaled_fwd_reward", "semi_scaled_ctrl_reward", "semi_scaled_is_flipped"
    ],
)


In [48]:
# Display the similar-transfer results (one row per similar context).
similar_df

Unnamed: 0,context,xml,distance,scaled_reward,scaled_fwd_reward,scaled_ctrl_reward,scaled_is_flipped,naive_reward,naive_fwd_reward,naive_ctrl_reward,naive_is_flipped,semi_scaled_reward,semi_scaled_fwd_reward,semi_scaled_ctrl_reward,semi_scaled_is_flipped
0,"((L, Dimension([0 1 0]), 0.125), (Lh, Dimensio...",<!-- Generated Cheetah Model\n\n The state ...,0.000000e+00,14589.881049,14913.034644,-323.153595,0,-84730.601743,56.502847,-84787.104590,1,181.719942,207.286515,-25.566573,1
1,"((L, Dimension([0 1 0]), 0.125), (Lh, Dimensio...",<!-- Generated Cheetah Model\n\n The state ...,5.684342e-14,14589.881293,14913.034887,-323.153594,0,-67176.086256,776.984597,-67953.070853,0,594.833498,753.765985,-158.932487,1
2,"((L, Dimension([0 1 0]), 0.125), (Lh, Dimensio...",<!-- Generated Cheetah Model\n\n The state ...,1.398579e-17,14589.884067,14913.037650,-323.153583,0,-13367.430151,341.522363,-13708.952514,1,55.930318,124.647835,-68.717517,1
3,"((L, Dimension([0 1 0]), 0.125), (Lh, Dimensio...",<!-- Generated Cheetah Model\n\n The state ...,1.705303e-13,14589.884080,14913.037670,-323.153590,0,-10784.314681,255.702187,-11040.016868,0,129.918681,507.424494,-377.505813,0
4,"((L, Dimension([0 1 0]), 0.125), (Lh, Dimensio...",<!-- Generated Cheetah Model\n\n The state ...,0.000000e+00,14589.881049,14913.034644,-323.153595,0,-4246.847854,109.965864,-4356.813718,1,273.313027,650.054725,-376.741699,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120,"((L, Dimension([0 1 0]), 2.0), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,0.000000e+00,14586.999672,14910.166369,-323.166698,0,-28.189621,-1.188842,-27.000778,0,-420.800878,-4.472122,-416.328757,0
121,"((L, Dimension([0 1 0]), 2.0), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,5.684342e-14,14586.996221,14910.162939,-323.166717,0,-1.578389,-0.387727,-1.190662,0,-519.296550,1.266116,-520.562666,0
122,"((L, Dimension([0 1 0]), 2.0), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,1.398579e-17,14587.002104,14910.168789,-323.166685,0,-0.601586,-0.228192,-0.373394,0,-7.224159,443.170146,-450.394305,0
123,"((L, Dimension([0 1 0]), 2.0), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,1.705303e-13,14586.996637,14910.163342,-323.166705,0,-0.346575,-0.166509,-0.180067,0,-148.139042,272.930877,-421.069919,0


### Analysis

## Quasi-similar transfer

### Data generation

In [49]:
# Same base-scaling grid as the similar study, but L is then changed back to
# 0.5 (its original value). These contexts are therefore only quasi-similar.
# The "quasi-similar-" label prefix keeps these models from overwriting the
# similar study's files, which use the same (b1, b2, b3) values (assumes
# make_cheetah_xml names the file after this label -- confirm).
all_quasi_similar_contexts = []

for b1, b2, b3 in product(b1_range, b2_range, b3_range):
    context = original_context.scale_to(base, [b1, b2, b3]).change(L=0.5)
    xml = make_cheetah_xml(context, f"quasi-similar-{b1}-{b2}-{b3}", outdir=XML_FILES)
    all_quasi_similar_contexts.append((context, xml))

In [50]:
# Evaluate three transfer strategies on every quasi-similar context:
#   scaled      -- original_policy.to_scaled(context, base)
#   naive       -- original_policy used unchanged
#   semi_scaled -- raw observations, rescaled actions (NaiveObsScaledActPolicy)
# Rows are collected in a list and turned into a DataFrame once. This avoids
# growing the frame row by row and lets pandas infer numeric column dtypes.
quasi_similar_rows = []

for context, xml in all_quasi_similar_contexts:
    distance = context.adimensional_distance(original_context, base)

    scaled_policy = original_policy.to_scaled(context, base)
    scaled_data = evaluate_policy(context, xml, scaled_policy)

    naive_policy = original_policy
    naive_data = evaluate_policy(context, xml, naive_policy)

    semi_scaled_policy = NaiveObsScaledActPolicy(original_policy, context, base)
    semi_scaled_data = evaluate_policy(context, xml, semi_scaled_policy)

    quasi_similar_rows.append(
        (context, Path(xml).read_text(), distance)
        + scaled_data
        + naive_data
        + semi_scaled_data
    )

quasi_similar_df = pd.DataFrame(
    quasi_similar_rows,
    columns=[
        "context", "xml", "distance",
        "scaled_reward", "scaled_fwd_reward", "scaled_ctrl_reward", "scaled_is_flipped",
        "naive_reward", "naive_fwd_reward", "naive_ctrl_reward", "naive_is_flipped",
        "semi_scaled_reward", "semi_scaled_fwd_reward", "semi_scaled_ctrl_reward", "semi_scaled_is_flipped"
    ],
)



In [51]:
# Display the quasi-similar-transfer results (one row per quasi-similar context).
quasi_similar_df

Unnamed: 0,context,xml,distance,scaled_reward,scaled_fwd_reward,scaled_ctrl_reward,scaled_is_flipped,naive_reward,naive_fwd_reward,naive_ctrl_reward,naive_is_flipped,semi_scaled_reward,semi_scaled_fwd_reward,semi_scaled_ctrl_reward,semi_scaled_is_flipped
0,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,7073.355042,-6150.438703,-204.854643,-5945.584060,0,-1.634219e+06,-128.264667,-1.634091e+06,0,-6501.999393,-91.533427,-6410.465966,0
1,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,7073.355042,-6117.987314,-154.559388,-5963.427926,0,-6.637537e+04,-241.829110,-6.613355e+04,0,-6114.592791,-212.325991,-5902.266800,0
2,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,7073.355042,-6111.088569,-223.753511,-5887.335058,0,-2.068140e+04,-248.540526,-2.043286e+04,0,-5915.071952,-237.336258,-5677.735695,0
3,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,7073.355042,-6151.565478,-233.679372,-5917.886107,0,-1.011164e+04,-190.281621,-9.921354e+03,0,-6119.116144,-409.465410,-5709.650734,0
4,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,7073.355042,-6150.438703,-204.854643,-5945.584060,0,-6.703667e+03,-669.251766,-6.034416e+03,0,-6343.923049,-371.149278,-5972.773771,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,442.447145,26.403781,27.218977,-0.815196,1,2.640368e+01,27.219007,-8.153282e-01,1,26.403679,27.219007,-0.815328,1
121,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,442.447145,26.403781,27.218977,-0.815196,1,2.724775e+01,27.282139,-3.439216e-02,1,26.403800,27.218972,-0.815172,1
122,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,442.447145,26.403781,27.218977,-0.815196,1,2.727863e+01,27.289306,-1.067594e-02,1,26.403871,27.218948,-0.815077,1
123,"((L, Dimension([0 1 0]), 0.5), (Lh, Dimension(...",<!-- Generated Cheetah Model\n\n The state ...,442.447145,26.403781,27.218977,-0.815196,1,2.728694e+01,27.292064,-5.128261e-03,1,26.403930,27.218934,-0.815004,1


## Non-similar transfer

### Data generation

In [52]:
# 125 random contexts drawn around the original one. `lower` / `upper` are
# the per-quantity factor bounds computed in the Parameters section
# (presumably the sampling bounds used by sample_around -- confirm).
# Each model is labelled by its sampled base values. The "non-similar-" prefix
# keeps these files apart from the grid studies, consistent with their own
# prefixes (assumes make_cheetah_xml names the file after this label).
all_non_similar_contexts = []

for _ in range(125):
    context = original_context.sample_around(lower, upper)
    b1 = context.value(base[0])
    b2 = context.value(base[1])
    b3 = context.value(base[2])
    xml = make_cheetah_xml(context, f"non-similar-{b1}-{b2}-{b3}", outdir=XML_FILES)
    all_non_similar_contexts.append((context, xml))

In [53]:
# Evaluate three transfer strategies on every non-similar context:
#   scaled      -- original_policy.to_scaled(context, base)
#   naive       -- original_policy used unchanged
#   semi_scaled -- raw observations, rescaled actions (NaiveObsScaledActPolicy)
# Rows are collected in a list and turned into a DataFrame once. This avoids
# growing the frame row by row and lets pandas infer numeric column dtypes.
non_similar_rows = []

for context, xml in all_non_similar_contexts:
    distance = context.adimensional_distance(original_context, base)

    scaled_policy = original_policy.to_scaled(context, base)
    scaled_data = evaluate_policy(context, xml, scaled_policy)

    naive_policy = original_policy
    naive_data = evaluate_policy(context, xml, naive_policy)

    semi_scaled_policy = NaiveObsScaledActPolicy(original_policy, context, base)
    semi_scaled_data = evaluate_policy(context, xml, semi_scaled_policy)

    non_similar_rows.append(
        (context, Path(xml).read_text(), distance)
        + scaled_data
        + naive_data
        + semi_scaled_data
    )

non_similar_df = pd.DataFrame(
    non_similar_rows,
    columns=[
        "context", "xml", "distance",
        "scaled_reward", "scaled_fwd_reward", "scaled_ctrl_reward", "scaled_is_flipped",
        "naive_reward", "naive_fwd_reward", "naive_ctrl_reward", "naive_is_flipped",
        "semi_scaled_reward", "semi_scaled_fwd_reward", "semi_scaled_ctrl_reward", "semi_scaled_is_flipped"
    ],
)



In [54]:
# Display the non-similar-transfer results (one row per sampled context).
non_similar_df

Unnamed: 0,context,xml,distance,scaled_reward,scaled_fwd_reward,scaled_ctrl_reward,scaled_is_flipped,naive_reward,naive_fwd_reward,naive_ctrl_reward,naive_is_flipped,semi_scaled_reward,semi_scaled_fwd_reward,semi_scaled_ctrl_reward,semi_scaled_is_flipped
0,"((L, Dimension([0 1 0]), 1.7949332136854248), ...",<!-- Generated Cheetah Model\n\n The state ...,4.621409e+08,-3.082366e+08,63.298626,-3.082367e+08,0,-2.095418e+06,-6.361503,-2.095412e+06,0,-5.274399e+08,23.232175,-5.274399e+08,0
1,"((L, Dimension([0 1 0]), 1.8123922567171347), ...",<!-- Generated Cheetah Model\n\n The state ...,1.638706e+06,-1.224613e+06,10.077807,-1.224623e+06,0,-4.749346e+05,9.138269,-4.749438e+05,0,-1.865862e+06,8.314358,-1.865870e+06,0
2,"((L, Dimension([0 1 0]), 1.4208666614039984), ...",<!-- Generated Cheetah Model\n\n The state ...,9.114128e+07,-6.633835e+07,-776.381586,-6.633758e+07,0,-8.583364e+05,-0.993675,-8.583354e+05,0,-1.042933e+08,-4.229377,-1.042933e+08,0
3,"((L, Dimension([0 1 0]), 0.8691978851703743), ...",<!-- Generated Cheetah Model\n\n The state ...,1.040158e+08,-6.817916e+07,229.626464,-6.817939e+07,0,-1.378372e+06,-1.480340,-1.378371e+06,0,-9.907907e+07,11.429484,-9.907908e+07,0
4,"((L, Dimension([0 1 0]), 1.2185748718984508), ...",<!-- Generated Cheetah Model\n\n The state ...,4.132902e+07,-1.612057e+07,-788.092874,-1.611978e+07,0,-6.234018e+05,60.123096,-6.234619e+05,0,-3.507829e+07,21.629150,-3.507831e+07,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120,"((L, Dimension([0 1 0]), 0.712413554490486), (...",<!-- Generated Cheetah Model\n\n The state ...,3.614963e+08,-2.337059e+08,184.102454,-2.337061e+08,0,-1.085611e+06,40.036565,-1.085651e+06,0,-2.494381e+08,93.124519,-2.494382e+08,0
121,"((L, Dimension([0 1 0]), 0.569451612883271), (...",<!-- Generated Cheetah Model\n\n The state ...,7.799201e+07,-8.931261e+07,-80.796020,-8.931253e+07,1,-4.859437e+05,-80.794002,-4.858629e+05,1,-8.931261e+07,-80.796020,-8.931253e+07,1
122,"((L, Dimension([0 1 0]), 1.5049231621309906), ...",<!-- Generated Cheetah Model\n\n The state ...,6.761024e+08,-4.224404e+08,65.605292,-4.224404e+08,0,-1.678870e+06,-2.001587,-1.678868e+06,0,-6.879003e+08,-4.418256,-6.879003e+08,0
123,"((L, Dimension([0 1 0]), 1.1671443152026642), ...",<!-- Generated Cheetah Model\n\n The state ...,8.461367e+07,-6.751008e+07,-2.369061,-6.751008e+07,0,-9.509550e+05,-2.841510,-9.509521e+05,0,-8.430335e+07,11.050075,-8.430336e+07,0
