### Setup

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import notebook_setup
from copy import deepcopy
import warnings, shutil, os, pickle
from types import SimpleNamespace
from tqdm.auto import tqdm, trange
import control
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import pandas as pd
import gym
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian, hessian
from commonml.helpers.logs import get_tensorboard_scalar_frame
from commonml.rl.ppo import returns
from commonml.stats.agg import mean_std

from systems.base import SystemEnv
from systems.plotting import (
    plot_env_response,
    multiple_response_plots
)
from systems.springmass import create_springmass, SpringMassEnv
from systems.lunarlander import LanderEnv

from rl import learn_rl, transform_rl_policy, evaluate_rl, load_agent, save_agent
from xform import (
    policy_transform, action_transform, get_transforms,
    pseudo_matrix, pseudo_matrix_from_data, ab_xform_from_pseudo_matrix,
    get_env_samples,
    dpolicy_dfa, dpolicy_dfb, err_inv
)

## System specification

### SpringMass

In [None]:
class NonLinearMatrix:
    """Matrix stand-in whose entries may depend on the vector it multiplies.

    ``M @ x`` passes a fresh copy of ``orig`` and ``x`` to ``func`` and then
    multiplies the matrix returned by ``func`` with ``x``. ``orig`` itself is
    never handed to ``func``, so it is not mutated by it.
    """

    def __init__(self, orig, func):
        # Template matrix (copied before every product).
        self.orig = orig
        # Callable ``(matrix_copy, x) -> matrix`` giving the effective matrix.
        self.func = func

    def __matmul__(self, *args):
        vec = args[0]
        effective = self.func(self.orig.copy(), vec)
        return effective @ vec
def make_xform(fraction, position):
    """Build a state-dependent matrix transform for ``NonLinearMatrix``.

    The returned ``xform(A, x)`` leaves ``A`` alone when ``|x[0]| > position``;
    otherwise it scales ``A[1, 0]`` in place by
    ``fraction ** (1 - |x[0] - position| / position)``. ``A`` is returned in
    both cases.
    """
    def xform(A, x):
        if abs(x[0]) > position:
            return A
        exponent = 1 - abs(x[0] - position) / position
        A[1, 0] *= fraction ** exponent
        return A
    return xform

In [None]:
# Spring-mass system parameters, forwarded to SpringMassEnv and create_springmass.
sys_kwargs = dict(k=4, m=0.2, df=0.01)
# PPO/training settings forwarded to learn_rl; note gamma=0 (no discounting of
# future rewards) for this system.
learn_kwargs = dict(steps=50_000, seed=0, learning_rate=2e-3,
                    n_steps=2048, batch_size=64, n_epochs=10,
                    gamma=0.)
# Weight matrices passed to SpringMassEnv as q (state, 2x2) and r (action, 1x1).
Q, R = np.asarray([[1,0], [0,1]]), np.asarray([[0.00001]])
# Rotation angle and scale reached at t=1 by the A- and B-perturbations.
angA, angB = np.pi/4, np.pi
scalarA, scalarB = 0.8, 0.5
# xformA(t) / xformB(t): transpose of the 2-D rotation matrix for angle t*ang,
# times a scaling interpolated linearly from I (t=0) to scalar*I (t=1).
# Both are the identity at t=0.
xformA = lambda t: np.asarray([[np.cos(t*angA), -np.sin(t*angA)],
                               [np.sin(t*angA), np.cos(t*angA)]]).T \
                    @ ((1-t)*np.eye(2) + t*scalarA*np.eye(2))
xformB = lambda t: np.asarray([[np.cos(t*angB), -np.sin(t*angB)],
                               [np.sin(t*angB), np.cos(t*angB)]]).T \
                    @ ((1-t)*np.eye(2) + t*scalarB*np.eye(2))
# Fixed initial state used for rollouts/plots of this system.
x0 = np.asarray([-0.5, 0], np.float32)
# Factory for the nominal (unperturbed) spring-mass environment.
make_env = lambda: SpringMassEnv(**sys_kwargs, q=Q, r=R, seed=0)
def make_xform_env(t):
    """Create a spring-mass env whose dynamics are perturbed by fraction ``t``.

    The nominal ``A`` and ``B`` matrices are left-multiplied by ``xformA(t)``
    and ``xformB(t)``. Both transforms are the identity at ``t=0``, so
    ``make_xform_env(0)`` matches ``make_env()``.

    (Previously this function was defined twice back-to-back with identical
    effective bodies; only one definition is kept.)
    """
    env = make_env()
    # Disabled alternative: a state-dependent fault on A instead of a linear one,
    #   env.system.A = NonLinearMatrix(env.system.A, make_xform(1-t, 1.))
    env.system.A = xformA(t) @ env.system.A
    env.system.B = xformB(t) @ env.system.B
    return env

env = make_env()
# NOTE(review): the name `sys` shadows the stdlib `sys` module; rename it if
# `import sys` is ever needed in this notebook.
sys = create_springmass(**sys_kwargs)
# 25 periods of the nominal env.
interval = env.period * 25
# Dedicated spring-mass env used in the approximation-error comparison later.
env_spr = make_env()

In [None]:
# Inspect the perturbed A matrix at fault level t=0.1.
make_xform_env(0.1).system.A

In [None]:
%matplotlib inline
# Response of `agent` on the spring-mass env perturbed at t=0.9, from x=[0.4, 0].
# NOTE(review): `agent` must already exist (it is trained/loaded in the
# Experiments section further down).
plot_env_response(make_xform_env(0.9), np.asarray([0.4, 0]), agent)

### LanderEnv

In [None]:
def make_env():
    """Return a freshly constructed, already-reset nominal ``LanderEnv``."""
    lander = LanderEnv()
    lander.reset()
    return lander
def make_xform_env(t):
    """Return a reset ``LanderEnv`` whose ``relative_power[1]`` shrinks with ``t``.

    Entry 1 of ``relative_power`` moves linearly from 1.0 (``t=0``) to 0.75
    (``t=1``); entry 0 stays at 1.0.
    """
    low, high, weakened = 0.75, 1., 1
    power = np.ones(2, np.float32)
    power[weakened] = low + (1 - t) * (high - low)
    env = LanderEnv(relative_power=power)
    env.reset()
    return env
# Training settings for the lander, forwarded to learn_rl.
learn_kwargs = dict(steps=500_000, seed=0, learning_rate=2e-3,
                    n_steps=4096, batch_size=256, n_epochs=20,
                    gamma=0.99)
interval = 30
# No fixed initial state: rollouts start from env.reset().
x0 = None
# Number of fault levels t = 0, 1/N, ..., (N-1)/N used in the sweeps.
N = 6
# Dedicated lander env used in the approximation-error comparison later.
env_lander = make_env()

In [None]:
# Nominal (t=0) lander response; plots state indices 0, 1, 4 labelled x, y, a.
env = make_xform_env(0.)
plot_env_response(env, None, agent,
                 state_idx=(0,1,4), state_names='xya', max_steps=1000)

### LunarLander

In [None]:
from systems.lunarlander import LunarLanderEnv
# Training settings for the lunar lander, forwarded to learn_rl.
learn_kwargs = dict(steps=100_000, seed=0, learning_rate=5e-3,
                    n_steps=1024, batch_size=256, n_epochs=5,
                    gamma=0.99)
def make_env():
    """Return a reset ``LunarLanderEnv`` created with ``seed=0``."""
    lunar = LunarLanderEnv(seed=0)
    lunar.reset()
    return lunar
def make_xform_env(t):
    """Return a reset ``LunarLanderEnv`` with ``relative_power[1]`` reduced by ``t``.

    ``relative_power[1]`` is set to ``0.75 + (1 - t) * 0.25``: 1.0 at ``t=0``
    (nominal) and 0.75 at ``t=1``. Other entries keep the env's own values.
    """
    env = make_env()
    min_power = 0.75
    max_power = 1.
    side = 1
    # NOTE(review): this mutates the env's own relative_power array in place;
    # confirm LunarLanderEnv does not share that array between instances.
    env.relative_power[side] = min_power + (1-t) * (max_power - min_power)
    env.reset()
    return env
# No fixed initial state: rollouts start from env.reset().
x0 = None
interval = 500
# Dedicated lunar-lander env used in the approximation-error comparison later.
env_lunar = make_env()

In [None]:
# Nominal (t=0) lunar-lander response; plots state indices 0, 1, 4 labelled x, y, a.
env = make_xform_env(0.)
plot_env_response(env, None, agent,
                 state_idx=(0,1,4), state_names='xya')

# Conditional Policy Adaptation

## Functions

### Utils

In [None]:
def sort_by(arr, by):
    """Reorder ``arr`` following the ascending order of the keys in ``by``.

    Returns a list, or an ndarray with ``arr``'s dtype when ``arr`` is an ndarray.
    """
    order = np.argsort(by)
    picked = list(map(arr.__getitem__, order))
    if not isinstance(arr, np.ndarray):
        return picked
    return np.asarray(picked, dtype=arr.dtype)

def save_res(res, dest):
    """Pickle an experiment-result namespace to ``dest`` without its agents.

    ``res.agents`` is temporarily replaced by an empty list while pickling,
    so the agents are not written to disk. The original list is always put
    back afterwards, even if pickling fails.

    Raises:
        AssertionError: if ``res`` is not a ``SimpleNamespace``.
        AttributeError: if ``res`` has no ``agents`` attribute.
    """
    assert isinstance(res, SimpleNamespace), 'Must be SimpleNamespace'
    # Read agents *before* the try: if this lookup fails there is nothing to
    # restore, and the finally clause must not touch an unbound name
    # (previously that masked the real error with UnboundLocalError).
    agents = res.agents
    res.agents = []
    try:
        with open(dest, 'wb') as f:
            pickle.dump(res, f)
    finally:
        res.agents = agents
def load_res(dest):
    """Unpickle and return the object stored at ``dest`` (e.g. from ``save_res``).

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(dest, 'rb') as fh:
        return pickle.load(fh)
def evaluate_trajectory(agent, env, x0=None, max_steps=np.inf, gamma=None):
    """Roll out ``agent`` in ``env``; return states, value estimates and returns.

    Args:
        agent: object with ``predict(obs)`` (first element is the action),
            ``policy.forward(obs)`` (element 1 is used as the value estimate)
            and ``policy.device``.
        env: environment exposing ``reset`` and ``step``.
        x0: optional initial state. If ``env.reset(x0)`` fails, the env is
            reset normally and that observation is used as the start state.
        max_steps: maximum number of environment steps to take.
        gamma: discount passed to ``returns``; defaults to the notebook
            global ``learn_kwargs['gamma']`` (read at call time).

    Returns:
        ``(states, values, ret)``:
        - ``states``: visited states including the initial one, squeezed.
        - ``values``: ``policy.forward(states)[1]`` as numpy.
        - ``ret``: the output of ``returns(rewards, dones, gamma,
          truncate=False)`` (presumably discounted returns-to-go — confirm).
    """
    if gamma is None:
        gamma = learn_kwargs['gamma']
    if x0 is not None:
        try:
            env.reset(x0)
        except Exception:
            print('Couldnt reset to x0')
            # Otherwise the rollout would continue from whatever state the env
            # was left in while still recording x0 as the first state.
            x0 = env.reset()
    else:
        x0 = env.reset()
    states, rewards, dones = [x0], [], [False]
    step = 0
    # `<` (not `<=`) so at most max_steps env steps are taken.
    while not dones[-1] and step < max_steps:
        step += 1
        action = agent.predict(states[-1])[0]
        state, reward, done, *_ = env.step(action)
        states.append(state)
        rewards.append(reward)
        dones.append(done)
    # Pad so rewards has one entry per recorded state.
    rewards.append(0.)
    states = np.asarray(states, np.float32)
    with torch.no_grad():
        states_ = torch.from_numpy(states).to(agent.policy.device)
        values = agent.policy.forward(states_)[1].detach().cpu().numpy().squeeze()
    ret = np.asarray(returns(rewards, dones, gamma, truncate=False))
    return states.squeeze(), values, ret

In [None]:
def plot_value(agent, *lims, policy=False):
    """Plot ``agent``'s value estimate over a 1-D or 2-D grid of states.

    Args:
        agent: object with ``policy.forward(obs)`` (element 1 is used as the
            value) and ``policy.device``.
        *lims: one entry per state dimension. Each is either an ndarray of
            grid points or a ``(start, stop[, num])`` tuple expanded with
            ``np.linspace`` (``num`` defaults to 25).
        policy: unused; kept for backward compatibility.

    Returns:
        The value grid ``V``, shaped like the meshgrid of the inputs. A line
        plot is drawn for one dimension, a 3-D surface for two.
    """
    axes_pts = []
    for lim in lims:
        if not isinstance(lim, np.ndarray):
            # Pass start/stop explicitly: ``np.linspace(*lim, num=...)`` raised
            # TypeError for 3-tuples because ``num`` was given twice.
            num = int(lim[2]) if len(lim) > 2 else 25
            lim = np.linspace(lim[0], lim[1], num=num)
        axes_pts.append(lim.squeeze())
    states = np.meshgrid(*axes_pts)
    ogshape = states[0].shape
    xin = np.hstack([s.reshape(-1, 1) for s in states])
    with torch.no_grad():
        xin = torch.from_numpy(xin).float().to(agent.policy.device)
        val = agent.policy.forward(xin)[1].detach().cpu().numpy()
    V = val.reshape(*ogshape)
    if len(lims) == 1:
        # Previously plotted against an undefined/global `x`.
        plt.plot(axes_pts[0], V)
        plt.xlabel('State')
        plt.ylabel('Value')
        plt.grid(True, 'both', 'both')
    elif len(lims) == 2:
        ax = plt.subplot(1, 1, 1, projection='3d')
        # Previously referenced an undefined `v`.
        ax.plot_surface(states[0], states[1], V)
    return V

In [None]:
def reset_agent(agent, params):
    """Restore policy parameters and reset the state/action transforms.

    Loads ``params`` into ``agent.policy``, zeroes ``state_xform`` in place and
    sets ``action_xform`` to an identity matrix of size
    ``len(action_xform)``. The new identity has the same dtype and device as
    the previous ``action_xform``.
    """
    policy = agent.policy
    policy.load_state_dict(params)
    # no_grad: in-place ops on a leaf tensor that requires grad would raise.
    with torch.no_grad():
        policy.state_xform *= 0
    old = policy.action_xform
    # Keep dtype/device; a bare torch.eye() is float32 on CPU and breaks
    # GPU policies or non-float32 transforms.
    policy.action_xform = torch.eye(len(old), dtype=old.dtype, device=old.device)

### Gradients

In [None]:
def dpolicy_df(
    A: np.ndarray, B: np.ndarray, F_A: np.ndarray, F_B: np.ndarray,
    x: np.ndarray, u: np.ndarray
) -> "tuple[np.ndarray, np.ndarray]":
    """Jacobians of the transformed action w.r.t. the transforms ``F_A`` and ``F_B``.

    The transformed action is
    ``pi = (pinv(F_B @ B) @ ((I - F_A) @ A @ x.T + B @ u.T)).T``.

    Args:
        A: source state matrix, square (``I`` is ``eye(len(A))``).
        B: source input matrix, ``(n, m)``.
        F_A, F_B: state- and input-matrix transforms, ``(n, n)``.
        x: state(s), ``(batch, n)``.
        u: action(s), ``(batch, m)``.

    Returns:
        ``(dpi/dF_A, dpi/dF_B)`` as numpy arrays. They use the
        ``torch.autograd.functional.jacobian`` layout
        ``pi.shape + F_A.shape`` and ``pi.shape + F_B.shape``.
        (The previous annotation wrongly declared a single ndarray.)
    """
    # pinv(F_B @ B) is constant w.r.t. F_A, so it is computed once here.
    F_BB_ = torch.from_numpy(np.linalg.pinv(F_B @ B)).float()
    A = torch.from_numpy(A).float()
    B = torch.from_numpy(B).float()
    F_A = torch.from_numpy(F_A).float()
    F_B = torch.from_numpy(F_B).float()
    x = torch.from_numpy(x).float()
    u = torch.from_numpy(u).float()
    I = torch.eye(len(A))

    def pi_of_fa(F_A_):
        return (F_BB_ @ ((I - F_A_) @ A @ x.T + B @ u.T)).T

    def pi_of_fb(F_B_):
        # pinv is recomputed in torch so the gradient flows through it.
        F_BB = torch.linalg.pinv(F_B_ @ B)
        return (F_BB @ ((I - F_A) @ A @ x.T + B @ u.T)).T

    dpidfa = jacobian(pi_of_fa, F_A)
    dpidfb = jacobian(pi_of_fb, F_B)
    return dpidfa.numpy(), dpidfb.numpy()

In [None]:
def sh(**kwargs):
    """Debug helper: print ``name shape`` for every keyword argument."""
    for name in kwargs:
        print(name, kwargs[name].shape)
def value_response(agent, A_s, B_s, F_A, F_B, x):
    """Estimate the value change caused by the transformed policy's action change.

    Args:
        agent: agent whose policy provides ``dvdpi_dobs`` and ``forward``.
        A_s, B_s: source system matrices.
        F_A, F_B: state/input transforms (as used by ``action_transform``).
        x: batch of states.

    Returns:
        ``(dvdfa, dvdfb, dvdu, dvdut, delta_v, delta_vt, delta_u)``:
        - ``dvdfa`` and ``dvdfb`` are always ``None`` (placeholders).
        - ``dvdu`` is dV/dx @ pinv(du/dx), computed per sample.
        - ``dvdut`` is ``dvdu @ pinv(pinv(F_B @ B_s) @ B_s)``.
        - ``delta_u`` is ``action_transform(...) - u_old``.
        - ``delta_v`` and ``delta_vt`` are ``dvdu`` and ``dvdut`` multiplied
          elementwise by ``delta_u``.
    """
    dvdx, dudx = agent.policy.dvdpi_dobs(x, deterministic=True)
    # detach().cpu() so this also works for GPU tensors / tensors requiring grad.
    dvdx = dvdx.detach().cpu().numpy()
    dudx = dudx.detach().cpu().numpy()
    # Per-sample pseudo-inverse of du/dx gives dx/du; chain rule gives dV/du.
    dxdu = np.stack([np.linalg.pinv(a) for a in dudx])
    dvdu = np.matmul(dvdx, dxdu)
    dvdut = np.matmul(dvdu,
                      np.linalg.pinv(np.linalg.pinv(F_B@B_s)@B_s))
    dvdfa, dvdfb = None, None
    obs = torch.from_numpy(np.atleast_2d(x)).float().to(agent.policy.device)
    with torch.no_grad():
        u_old = agent.policy.forward(obs, deterministic=True)[0].detach().cpu().numpy()
    u_new = action_transform(x, u_old, A_s, B_s, F_A, F_B)
    delta_u = u_new - u_old
    delta_v = dvdu * delta_u[:, None, :] # batch, project axis, action
    delta_vt = dvdut * delta_u[:, None, :] # batch, project axis, action
    return (dvdfa, dvdfb,
            dvdu, dvdut,
            delta_v, delta_vt,
            delta_u)

In [None]:
def dist_v_g(val, ret):
    """Return the Euclidean norm of the gap between value estimates and returns."""
    gap = np.subtract(np.asarray(val), np.asarray(ret))
    return np.linalg.norm(gap)

def _finetune_and_score(start_agent, t, tag, dirname, kwargs):
    """Fine-tune ``start_agent`` on ``make_xform_env(t)``.

    Returns ``(tuned_agent, tensorboard_frame, ||V - G||)`` for the tuned agent.
    """
    tuned = learn_rl(make_xform_env(t),
                     reuse_parameters_of=start_agent,
                     tensorboard_log=dirname + tag,
                     **kwargs)
    # learn_rl is assumed to log under 'tensorboard/' with a '_1' run suffix.
    frame = get_tensorboard_scalar_frame('tensorboard/' + dirname + tag + '_1')
    _, val, ret = evaluate_trajectory(tuned, make_xform_env(t), x0)
    return tuned, frame, dist_v_g(val, ret)


def get_value_change(agent, continue_rl=False, N=6, interval=500):
    """Measure ``||V - G||`` of ``agent`` across N increasingly perturbed envs.

    For each fault level ``t`` in ``0, 1/N, ..., (N-1)/N``:
    - evaluate ``agent`` on ``make_xform_env(t)``;
    - build a transformed agent with ``get_transforms``, using ``interval``
      steps of data;
    - record value-sensitivity norms from ``value_response``;
    - evaluate the transformed agent.

    If ``continue_rl`` is a step count, two more agents are fine-tuned for
    that many steps: the transformed agent ("new") and the original agent
    ("reused").

    Relies on the notebook globals ``make_env``, ``make_xform_env``, ``x0``
    and ``learn_kwargs``.

    Returns:
        SimpleNamespace of per-level lists: ``dist``, ``dist_xform``,
        ``dist_new``, ``dist_reused``, ``agents``, ``dfs_new``,
        ``dfs_reused``, ``infos``, ``delta_v``, ``delta_vt``, ``dvdu`` and
        ``dvdut``. ``delta_u`` is created but currently left empty.
    """
    res = SimpleNamespace()
    res.dist, res.dist_xform, res.dist_new, res.dist_reused, res.agents = [], [], [], [], []
    res.dfs_new, res.dfs_reused = [], []
    res.infos = []
    res.delta_v, res.delta_vt, res.dvdu, res.dvdut = [], [], [], []
    res.delta_u = []

    # np.arange(N) / N always yields exactly N levels; the float-step form
    # np.arange(0, 1, 1/N) can produce N+1 (e.g. N=49) due to rounding.
    for t in tqdm(np.arange(N) / N, leave=False):

        _, val, ret = evaluate_trajectory(agent, make_xform_env(t), x0)
        res.dist.append(dist_v_g(val, ret))

        state_xform, action_xform, info = get_transforms(
            agent, make_env(), make_xform_env(t),
            n_episodes_or_steps='steps',
            buffer_episodes=interval,
            data_driven_source=True
        )
        agent_xform = transform_rl_policy(agent, state_xform, action_xform, copy=True)
        res.infos.append(info)

        _, _, dvdu, dvdut, delta_v, delta_vt, _ = value_response(
            agent, info.A_s, info.B_s, info.F_A, info.F_B, info.x
        )
        res.delta_v.append(np.linalg.norm(delta_v, axis=2).squeeze())
        res.delta_vt.append(np.linalg.norm(delta_vt, axis=2).squeeze())
        res.dvdu.append(np.linalg.norm(dvdu, axis=2).squeeze())
        res.dvdut.append(np.linalg.norm(dvdut, axis=2).squeeze())

        _, val, ret = evaluate_trajectory(agent_xform, make_xform_env(t), x0)
        res.dist_xform.append(dist_v_g(val, ret))

        res.agents.append([agent, agent_xform])
        if continue_rl:
            kwargs = learn_kwargs.copy()
            kwargs['steps'] = continue_rl
            dirname = 'temp/' + str(os.getpid()) + '/'
            # Previously dist_new evaluated agent_xform instead of agent_new.
            agent_new, df_new, d_new = _finetune_and_score(
                agent_xform, t, 'new', dirname, kwargs)
            res.dfs_new.append(df_new)
            res.dist_new.append(d_new)

            agent_reused, df_reused, d_reused = _finetune_and_score(
                agent, t, 'reused', dirname, kwargs)
            res.dfs_reused.append(df_reused)
            res.dist_reused.append(d_reused)
            res.agents[-1].extend([agent_new, agent_reused])
            shutil.rmtree('tensorboard/'+dirname, ignore_errors=True)
    return res

## Experiments

In [None]:
# Load current_agent in case of notebook restart
# (file name is 'current_agent_' + the env's name, as written by the training cell below).
from stable_baselines3 import PPO
agent = PPO.load('current_agent_'+make_env().name)
# Move the state/action transform tensors onto the policy's device.
agent.policy.state_xform = agent.policy.state_xform.to(agent.policy.device)
agent.policy.action_xform = agent.policy.action_xform.to(agent.policy.device)

In [None]:
# Train a source agent on the nominal env and save it as 'current_agent_<env name>'.
env = make_env()
agent = learn_rl(env, tensorboard_log=env.name+'/tuning', **learn_kwargs)
agent.save('current_agent_'+make_env().name)

In [None]:
# Fault sweep with 50k fine-tuning steps per level; saved as 'res_<env name>'
# (save_res does not pickle the agents).
res = get_value_change(agent, continue_rl=50_000, N=N, interval=500)
save_res(res, 'res_'+make_env().name)

In [None]:
# Reload a previously saved sweep (its `agents` list is empty).
res = load_res('res_'+make_env().name)

### $V-G$ with env change

In [None]:
%matplotlib inline
# Reward curves: nominal training run, then fine-tuning after the fault.
# NOTE(review): the run directory 'tuning_18' is hard-coded; point it at the
# run that produced the current agent.
df = get_tensorboard_scalar_frame('tensorboard/'+make_env().name+'/tuning_18')
prev = df['rollout', 'ep_rew_mean'].to_numpy()
idx = df['rollout', 'ep_rew_mean'].index.to_numpy()

new, reused = [], []
for i in range(len(res.dfs_new)):
    # Levels with ||V-G|| > 0 use the transformed+tuned curve; otherwise the
    # reused (tuned-only) curve is used for both series.
    if res.dist[i] > 0:
        new.append(res.dfs_new[i]['rollout', 'ep_rew_mean'].to_numpy())
    else:
        new.append(res.dfs_reused[i]['rollout', 'ep_rew_mean'].to_numpy())
    reused.append(res.dfs_reused[i]['rollout', 'ep_rew_mean'].to_numpy())
# Shift fine-tuning step indices so they start where nominal training ended.
idx_new = res.dfs_new[0]['rollout', 'ep_rew_mean'].index.to_numpy()
idx_new = idx_new + idx[-1]
# NOTE(review): assumes all fine-tuning curves have the same length.
mn, sn = mean_std(new, axis=0)
mr, sr = mean_std(reused, axis=0)

plt.figure(figsize=(6,3))
plt.plot(idx, prev, label='Nominal RL')
plt.axvline(x=idx[-1], c='r', ls='--', label='Fault')
plt.plot(idx_new, mr, c='b', label='Tuned')
plt.fill_between(idx_new, mr+sr, mr-sr, color='b', alpha=0.4)
plt.plot(idx_new, mn, c='g', label='Transformed+Tuned')
plt.fill_between(idx_new, mn+sn, mn-sn, color='g', alpha=0.4)
plt.ylabel('Reward')
plt.xlabel('Training steps')
plt.legend()

In [None]:
%matplotlib inline
plt.figure(figsize=(6,3))
improvement = [d-dn for d, dn in zip(res.dist, res.dist_new)]
d_asc = sorted(res.dist)
plt.plot(d_asc, sort_by(res.dist, res.dist), c='b', ls=':', label='$||V_{\pi_s}-G_{\pi_s}||$')
plt.plot(d_asc, sort_by(res.dist_reused, res.dist), c='b', label='$||V_{\pi_s}-G_{\pi_s}||$ Tuned')
plt.plot(d_asc, sort_by(res.dist_xform, res.dist), c='g', ls=':', lw='3',
         label='$||V_{F_\pi(\pi_s)}-G_{F_\pi(\pi_s)}||$')
plt.plot(d_asc, sort_by(res.dist_new, res.dist), c='g', lw='3',
         label='$||V_{F_\pi(\pi_s)}-G_{F_\pi(\pi_s)}||$ Tuned')
plt.ylabel('$||V - G||$ on $T_t$')
plt.xlabel('$||V - G||$')
# plt.xticks(np.arange(len(res.dist)))
plt.legend()

### Quality of transformation

#### Sensitivity of value

In [None]:
# Histogram of dV/du . delta_u per fault level, ordered by ||V-G||.
plt.figure(figsize=(6,3))
# res.delta_v holds one array per fault level and they may differ in length,
# so reduce each array instead of calling np.min on the (ragged) list.
mindv = min(np.min(d) for d in res.delta_v)
maxdv = max(np.max(d) for d in res.delta_v)
idx = np.argsort(res.dist)
for rank, j in enumerate(idx):
    # Fade with rank; uses the actual number of levels rather than the global
    # N, which could make alpha negative if res came from a different N.
    plt.hist(res.delta_v[j], density=True, bins=20, range=(mindv, maxdv),
             alpha=(1-rank/len(idx)), label='$||V-G||=%.2f$' % res.dist[j])
plt.legend()
plt.ylabel('Density')
plt.xlabel(r'$\partial V /\partial u \cdot u$')

In [None]:
# Plot of gradient norms vs degrading system
x = np.atleast_2d(x0)
delta_vs, dvdus, dvduts = [], [], []
xs = []
# np.arange(N) / N always yields exactly N levels (float-step arange may not).
for i in np.arange(N) / N:
    state_xform, action_xform, info = get_transforms(
            agent=agent, env=make_env(), env_=make_xform_env(i),
            buffer_episodes=500, n_episodes_or_steps='steps',
            data_driven_source=False
        )
    # value_response returns 7 values in this order; the previous 6-name
    # unpacking raised ValueError and mixed up dvdut/delta_v.
    dvdfa, dvdfb, dvdu, dvdut, delta_v, delta_vt, delta_u = value_response(
        agent, info.A_s, info.B_s, info.F_A, info.F_B, info.x
    )
    dvdus.append(np.linalg.norm(dvdu, axis=2).squeeze())
    dvduts.append(np.linalg.norm(dvdut, axis=2).squeeze())
    delta_vs.append(delta_v.squeeze())
    xs.append(info.x.squeeze())

In [None]:
%matplotlib notebook
ax = plt.subplot(projection='3d' if len(res.infos[0].x[0])>=2 else None)
xs = np.asarray(xs)
idx = np.argsort(res.dist)
for i in idx:
    x, dv = res.infos[i].x, res.dvdut[i]
    ax.scatter(x[::5][:,0], x[::5][:,1], dv[::5],
               label='$||V-G||=%.2f$' % res.dist[i], alpha=0.6)
plt.legend()
ax.set_xlabel('$x_0$')
if len(res.infos[0].x[0])==2:
    ax.set_ylabel('$x_1$')
    ax.set_zlabel('$||\partial V / \partial u$||')
else:
    ax.set_ylabel('$||\partial V / \partial u$||')

#### Correlation between improvement, $V-G$, $\partial V / \partial u$

In [None]:
# Repeat the sweep 5 times without fine-tuning. For each run, values are
# sorted by the source ||V-G||:
#   D  = source ||V-G||
#   I  = ||V-G|| improvement of the transformed policy (dist - dist_xform)
#   V  = mean ||dV/du||
#   VT = mean ||dV/du_t||
D, I, V, VT = [], [], [], []
for t in trange(5, leave=False):
    res_ = get_value_change(agent, continue_rl=False, N=10, interval=500)
    D.append(sort_by(res_.dist, res_.dist))
    I.append(sort_by([d-dn for d, dn in zip(res_.dist, res_.dist_xform)], res_.dist))
    V.append([np.mean(a) for a in sort_by(res_.dvdu, res_.dist)])
    VT.append([np.mean(a) for a in sort_by(res_.dvdut, res_.dist)])
# NOTE(review): mean_std is called without axis here but with axis=0 in other
# cells; confirm its default reduces over the repeats.
mD, sd = mean_std(D)
mI, si = mean_std(I)
mVT, svt = mean_std(VT)
mV, sv = mean_std(V)

In [None]:
%matplotlib inline
plt.figure(figsize=(6,3))
plt.plot(mD, mI, c='c')
plt.fill_between(mD, mI+si, mI-si, alpha=0.2)
plt.ylabel('$\Delta \epsilon_G$', c='c')
plt.xlabel('$||V_s-G_{\pi_s,T_t}||$')
plt.twinx()
plt.plot(mD, mV, c='m', ls=':',
        label='$\partial V_s / \partial \pi_s$')
plt.fill_between(mD, mV+sv, mV-sv, alpha=0.2, color='m')
plt.plot(mD, mVT, c='m', ls='-',
        label='$\partial V_s / \partial \pi_t$')
plt.fill_between(mD, mVT+svt, mVT-svt, alpha=0.2, color='m')
plt.ylabel('$\partial V_s / \partial u$ on $T_t$', c='m')
# plt.yscale('log')
plt.legend(handles=plt.gca().lines)

#### Correlation between $\mathcal{D}$ and improvement

In [None]:
# Sweep the amount of transform-fitting data (`interval`, in steps).
intervals = np.arange(50, 2000, 200)

# Per interval, values are sorted by source ||V-G||: D = ||V-G||,
# I = improvement (dist - dist_xform), V/VT = mean ||dV/du||, ||dV/du_t||.
D, I, V, VT = [], [], [], []
for t in tqdm(intervals, leave=False):
    res_ = get_value_change(agent, continue_rl=False, N=10, interval=t)
    D.append(sort_by(res_.dist, res_.dist))
    I.append(sort_by([d-dn for d, dn in zip(res_.dist, res_.dist_xform)], res_.dist))
    V.append([np.mean(a) for a in sort_by(res_.dvdu, res_.dist)])
    VT.append([np.mean(a) for a in sort_by(res_.dvdut, res_.dist)])

# I: shape (len(intervals), 10); mD: ||V-G|| averaged over intervals.
I = np.asarray(I)
mD, sd = mean_std(D, axis=0)

In [None]:
%matplotlib notebook
ax = plt.subplot(projection='3d')
ax.plot_surface(*np.meshgrid(intervals, mD, indexing='ij'),
                np.clip(I, a_min=-100, a_max=None), cmap='coolwarm',
               alpha=0.5)
ax.set_xlabel('$\mathcal{D}$')
ax.set_ylabel('$V_s - G_{\pi_s,T_t}$')
ax.set_zlabel('$\Delta \epsilon_G$')

#### Approximation error

In [None]:
# Sweep data interval and record improvement plus the source/target
# approximation errors (`err_s`, `err_t` from get_transforms' info).
D, I, E_s, E_t = [], [], [], []
intervals = np.arange(50, 1000, 200)
for t in tqdm(intervals, leave=False):
    res_ = get_value_change(agent, continue_rl=False, N=10, interval=t)
    D.append(sort_by(res_.dist, res_.dist))
    I.append(sort_by([d-dn for d, dn in zip(res_.dist, res_.dist_xform)], res_.dist))
    E_s.append(sort_by([i.err_s for i in res_.infos], res_.dist))
    E_t.append(sort_by([i.err_t for i in res_.infos], res_.dist))

In [None]:
# NOTE(review): mean_std without axis (other cells pass axis=0); confirm intent.
mD, sd = mean_std(D)
I = np.asarray(I)
# Combined source+target approximation error, averaged over intervals (axis 0).
E = (np.asarray(E_s) + np.asarray(E_t)).mean(axis=0)

In [None]:
# Approximation error (fit error e + err_inv(B)) vs number of data steps for
# the three systems, over 5 trials. Requires the per-system envs created in
# the system-specification cells and the saved 'current_agent_*' files.
# NOTE(review): this rebinds the global `intervals` used by earlier cells.
E_spr, E_lander, E_lunar = [], [], []
intervals = [10, 20, 40, 80, 160, 320, 640, 1280]
agent_spr = load_agent('current_agent_SpringMassEnv')
agent_lander = load_agent('current_agent_LanderEnv')
agent_lunar = load_agent('current_agent_LunarLanderEnv')
for trial in trange(5, leave=False):
    E_spr.append([])
    E_lander.append([])
    E_lunar.append([])
    for i in intervals:
        # The third ab_xform_from_pseudo_matrix argument is a per-system
        # constant (0.01 / 0.1 / 0.04); its meaning is defined in `xform`.
        P, e, *_ = pseudo_matrix_from_data(env_spr, i, agent_spr, 'steps')
        _, B, *_ = ab_xform_from_pseudo_matrix(P, None, 0.01)
        E_spr[-1].append(e + err_inv(B))
        P, e, *_ = pseudo_matrix_from_data(env_lander, i, agent_lander, 'steps')
        _, B, *_ = ab_xform_from_pseudo_matrix(P, None, 0.1)
        E_lander[-1].append(e + err_inv(B))
        P, e, *_ = pseudo_matrix_from_data(env_lunar, i, agent_lunar, 'steps')
        _, B, *_ = ab_xform_from_pseudo_matrix(P, None, 0.04)
        E_lunar[-1].append(e + err_inv(B))

In [None]:
# Mean +/- std of the approximation error vs data size for each system.
mE_spr, sE_spr = mean_std(E_spr)
mE_lander, sE_lander = mean_std(E_lander)
mE_lunar, sE_lunar = mean_std(E_lunar)
plt.figure(figsize=(6,3))
for m_, s_, lbl in ((mE_spr, sE_spr, 'Spring-Mass system'),
                    (mE_lander, sE_lander, 'Lander system'),
                    (mE_lunar, sE_lunar, 'Lunar lander system')):
    l, = plt.plot(intervals, m_, label=lbl)
    plt.fill_between(intervals, m_-s_, m_+s_, color=l.get_color(), alpha=0.4)
plt.ylim(bottom=-0.01, top=0.5)
plt.legend(loc='upper right')
# Raw strings avoid invalid '\e' / '\m' escape sequences.
plt.ylabel(r'$\epsilon_D + \epsilon_i(B_t)$')
plt.xlabel(r'$\mathcal{D}$')

In [None]:
%matplotlib notebook
ax = plt.subplot(projection='3d')
ax.plot_surface(*np.meshgrid(intervals, E, indexing='ij'),
                np.clip(I, a_min=-100, a_max=None), cmap='coolwarm',
               alpha=0.5)
ax.set_xlabel('$\mathcal{D}$')
ax.set_ylabel('$\epsilon_A$')
ax.set_zlabel('$\Delta \epsilon_G$')

### Experiment

In [None]:
# experiment config

# Whether the knowledge of the source system is known,
# or approximated from sampled experiences
data_driven_source = True
# Whether to assume that the system transformations are known
# and not approximate
accurate_xfer = not data_driven_source
buffer_episodes = 5
# NOTE(review): depends on the global `env` from an earlier cell.
interval = 5 * env.period * buffer_episodes
name = env.__class__.__name__
# NOTE(review): because accurate_xfer == not data_driven_source, the first
# condition reduces to `data_driven_source`. The 'StochasticSource' and
# 'StochasticXfer' branches are therefore unreachable.
if data_driven_source and not accurate_xfer:
    name += 'StochasticAll'
elif data_driven_source:
    name += 'StochasticSource'
elif not accurate_xfer:
    name += 'StochasticXfer'

In [None]:
# train rl policy on original environment
# train rl policy on original environment
agent = learn_rl(make_env(), tensorboard_log=name+'/Source',
                 **learn_kwargs)

In [None]:
# NOTE(review): stale cell. Every make_xform_env above requires a fault level
# `t`, and agent_xform_tuned is only created by a later cell.
plot_env_response(make_xform_env(), x0, agent_xform_tuned)

TODO: change env_xforms so that B != 0.

TODO: diagnose why the number of training steps exceeds kwargs['steps'].

In [None]:
# fine-tune source policy on target environment
# fine-tune source policy on target environment
# NOTE(review): no reuse_parameters_of is passed (unlike other learn_rl calls),
# so confirm whether this actually starts from the source policy. Also,
# make_xform_env requires a fault level `t`.
agent_new = learn_rl(make_xform_env(), tensorboard_log=name+'/Target',
                 **learn_kwargs)

In [None]:
# fine-tune source policy on target environment (transforming the policy)
# NOTE(review): `concurrent_learn` is not imported in the visible notebook, and
# make_xform_env is passed uncalled here; confirm the expected signature.
agent_xform_tuned = concurrent_learn(agent, make_xform_env, interval,
                             xform_policy=True,
                             accurate_xfer=accurate_xfer,
                             tensorboard_log=name+'/XformedTuned',
                             **learn_kwargs)

In [None]:
# fine-tine the transformed policy, except xforms
agent_xform_tuned = learn_rl(
    make_xform_env(),
    reuse_parameters_of=agent_xform,
    learnable_transformation=False,
    tensorboard_log=name+'/XformedTuned', **learn_kwargs
)
print('state_xform', agent_xform_tuned.policy.state_xform)
print('action_xform', agent_xform_tuned.policy.action_xform)