# Automatic Reward Shaping from Confounded Offline Data
This notebook is based on our ICML 2025 [paper](https://openreview.net/forum?id=Hu7hUjEMiW).
Also see the Technical Report version [here](https://causalai.net/r123.pdf).

The task is to learn a potential function automatically from offline data for use in Potential-Based Reward Shaping (PBRS). The reward function after shaping is defined as
$$
Y' = Y + \gamma\phi(s') - \phi(s)
$$
where $Y$ is the original reward signal, $Y'$ is the shaped reward, $s$ and $s'$ are the current and next states, and $\gamma$ is the discount factor. $\phi(\cdot)$ is the potential function we aim to learn automatically from offline datasets.
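
As a minimal sketch of the shaping rule above, the snippet below computes $Y'$ for a single transition. The function name `shaped_reward` and the toy potential table are illustrative assumptions, not part of the paper's code base.

```python
def shaped_reward(y, phi_s, phi_s_next, gamma=0.99):
    """Return Y' = Y + gamma * phi(s') - phi(s) for one transition."""
    return y + gamma * phi_s_next - phi_s


# Hypothetical potential table over three states (illustrative values only).
phi = {0: 0.0, 1: 0.5, 2: 1.0}

# Transition from state 0 to state 1 with original reward -1.
y_prime = shaped_reward(y=-1.0, phi_s=phi[0], phi_s_next=phi[1], gamma=0.9)
print(y_prime)
```

Moving toward a state with a higher potential raises the shaped reward, which is how the potential guides exploration.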

Intuitively, one can use the optimal state values as the potential function. If the provided offline dataset is generated by a good enough policy, one can directly take the average cumulative return as the state-value estimate. However, when the offline dataset is confounded or the data-generating policy is suboptimal, such naive estimates are highly biased and can mislead policy training. See Examples 1 and 2 in the paper for more details.
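
To make the naive estimator concrete, here is a small sketch that averages the observed return-to-go for each visited state. The trajectory format (a list of `(state, reward)` pairs per episode) and the name `naive_state_values` are assumptions made for illustration; the datasets produced by `gen_dataset` later in this notebook may be structured differently.

```python
from collections import defaultdict


def naive_state_values(trajectories, gamma=1.0):
    """Average the observed discounted return-to-go over every visit to a state."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for episode in trajectories:
        g = 0.0
        # Walk backwards so the return-to-go accumulates in a single pass.
        for state, reward in reversed(episode):
            g = reward + gamma * g
            totals[state] += g
            counts[state] += 1
    return {s: totals[s] / counts[s] for s in totals}


# Two toy episodes, each a list of hypothetical (state, reward) pairs.
toy_data = [
    [("s0", 0.0), ("s1", 1.0)],
    [("s0", 0.0), ("s1", 0.0)],
]
values = naive_state_values(toy_data)
```

This estimator describes only the returns achieved under the behavioral policy. It makes no correction for unobserved confounders, which is why it can be badly biased in the settings discussed above.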

In this work, we use causal bounds to estimate an upper bound on the optimal interventional state values. We then use the estimated upper value bound as the potential function to train our online policy learner, Q-UCB.
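
The sketch below shows how a value upper bound could enter an online learner as the potential. It uses a plain tabular Q-learning update rather than the Q-UCB learner from the paper, and the array `phi`, the step size, and the convention of a zero potential at terminal states are all assumptions made for illustration.

```python
import numpy as np


def q_learning_step(q, phi, s, a, y, s_next, done, alpha=0.1, gamma=0.99):
    """One tabular Q-learning update on the shaped reward (not Q-UCB)."""
    # Assumed convention: the potential of a terminal state is treated as 0.
    phi_next = 0.0 if done else phi[s_next]
    y_shaped = y + gamma * phi_next - phi[s]
    bootstrap = 0.0 if done else gamma * np.max(q[s_next])
    q[s, a] += alpha * (y_shaped + bootstrap - q[s, a])
    return q


n_states, n_actions = 4, 2
q = np.zeros((n_states, n_actions))
# Stand-in for an estimated upper bound on the optimal state values.
phi = np.array([-3.0, -2.0, -1.0, 0.0])

q = q_learning_step(q, phi, s=0, a=1, y=-1.0, s_next=1, done=False)
```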

## Environment Definition and Data Generation
In this notebook, we replicate our experimental results in the WindyLavaCross (hard) environment, corresponding to Fig. 3(c) and Fig. 4(c).

In [1]:
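import json

import numpy as np
from causal_gym.core import Task
# The wildcard import is assumed to provide gym, SEEDS, KWARGS, WIND_DIST, the
# environment wrappers, and the behavioral-policy helpers used below.
from causal_rl.algo.reward_shaping.calculate_values import *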

# Define behavioral policy
BEHAVIORAL = {
    'Nowind-Empty-8x8-v0': {
        'good': good_bpolicy_emptyworld,
        'bad': lambda s, w: good_bpolicy_emptyworld(s, w) if np.random.rand() > .5 else np.random.choice(5),
        'random': lambda s, w: np.random.choice(5),
    },
    'MiniGrid-Empty-8x8-v0': {
        'good': good_bpolicy_emptyworld,
        'bad': lambda s, w: good_bpolicy_emptyworld(s, w) if np.random.rand() > .5 else np.random.choice(5),
        'random': lambda s, w: np.random.choice(5),
    },
    'Custom-LavaCrossing-easy-v0': {
        'good': good_bpolicy_lavacross,
        'bad': bad_bpolicy_lavacross,
        'random': lambda s, w: np.random.choice(5)
    },
    'Custom-LavaCrossing-hard-v0': {
        'good': good_bpolicy_lavacross_hard,
        'bad': bad_bpolicy_lavacross_hard,
        'random': lambda s, w: np.random.choice(5)
    },
    'Custom-LavaCrossing-extreme-v0': {
        'good': good_bpolicy_lavacross_extreme,
        'bad': bad_bpolicy_lavacross_extreme,
        'random': lambda s, w: np.random.choice(5)
    },
    'Custom-LavaCrossing-maze-v0': {
        'good': good_bpolicy_lavacross_maze,
        'bad': bad_bpolicy_lavacross_maze,
        'bad2': bad_bpolicy_lavacross_maze2
    },
    'Custom-LavaCrossing-maze-complex-v0': {
        'good': better_bpolicy_lavacross_maze_complex,
        'bad': good_bpolicy_lavacross_maze,
        'bad2': bad_bpolicy_lavacross_maze_complex
    }
}

env_name = 'Custom-LavaCrossing-extreme-v0'
for SEED in SEEDS:
    print('\n=======================================\n')
    print(f'Env: {env_name} Seed: {SEED}')
    # Initialize the environment
    env = gym.make(
        env_name, 
        agent_pov=False, 
        render_mode='rgb_array', 
        highlight=False, 
        **KWARGS[env_name]
    )
    # Use 'cool' as the learning regime: we need `do` (interventions) for online
    # learning and `see` (observations) for collecting offline data.
    windy_env = MiniGridActionRemapWrapper(WindyMiniGridPCH(
        env=env, 
        show_wind=True, 
        wind_dist=WIND_DIST[env_name],
        task=Task(learning_regime='cool'), 
    ))
    # Compute the optimal state values and Q-values over the interventional policy space
    opt_values, opt_qvalues = value_iteration(windy_env)
    print(f'Opt state values of {env_name}')
    print(np.transpose(opt_values))
    save_values(opt_values, f'OPTV-{env_name}-{SEED}')
    save_values(opt_qvalues, f'OPTQ-{env_name}-{SEED}')
    print('------------------------')

    bounds = []
    mixed_dataset = []
    for policy_name, bpolicy in BEHAVIORAL[env_name].items():
        dataset, behavioral_values = gen_dataset(windy_env, bpolicy, seed=SEED)
        mixed_dataset.extend(dataset)
        print(f'{policy_name} behavioral policy values')
        print(np.transpose(behavioral_values))
        save_values(behavioral_values, f'BEV-{policy_name}-{env_name}-{SEED}')
        print('------------------------')
        bound, state_count = approx_opt_value_upper_bound(
            windy_env,
            dataset,
            windy_env.state_space,
            windy_env.action_space.n,
            horizon=KWARGS[env_name]['max_episode_steps'],
            reward_upper_bound=0,
        )
        print(f'{policy_name} behavioral policy value bounds in {env_name}')
        print(np.transpose(bound))
        print('------------------------')
        bounds.append(bound)
        save_values(bound, f'BD-{policy_name}-{env_name}-{SEED}')

    with open(f'data/mixdata-{env_name}-{SEED}.json', 'w') as f:
        json.dump(mixed_dataset, f)
    final_bound = np.minimum.reduce(bounds)
    print(f'\nFinal Bound for {env_name}:')
    print(np.transpose(final_bound))
    save_values(final_bound, f'BD-FINAL-{env_name}-{SEED}')


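
The final step of the loop above combines the per-policy bounds with `np.minimum.reduce`. Each entry of each array is an upper bound on the same optimal state value, so the element-wise minimum is still a valid upper bound and is the tightest one available. A small sketch with made-up arrays (the values below are illustrative, not experiment outputs):

```python
import numpy as np

# Hypothetical upper bounds estimated from two different behavioral datasets.
bound_good = np.array([[0.0, -1.0], [-2.0, -3.0]])
bound_bad = np.array([[0.0, -2.0], [-1.0, -3.0]])

# Element-wise minimum across the list of bound arrays.
tightest = np.minimum.reduce([bound_good, bound_bad])
print(tightest)
```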