# Learn a Reward Function using Maximum Causal Entropy Inverse Reinforcement Learning

MCE IRL only supports tabular environments.

In [None]:
from imitation.algorithms.mce_irl import MCEIRL
import gym
import imitation.envs.examples.model_envs
from imitation.algorithms import base

from imitation.data import rollout
from imitation.envs import resettable_env
from stable_baselines3.common.vec_env import DummyVecEnv
from imitation.rewards import reward_nets


# Gym id of the environment used throughout this notebook.
env_name = "imitation/CliffWorld15x6-v0"


def _make_extracted_venv(key, n_envs=4):
    """Return a ``DictExtractWrapper`` around ``n_envs`` freshly made environments.

    Every worker of the ``DummyVecEnv`` is created by its own
    ``gym.make(env_name)`` call.

    NOTE(review): presumably the wrapper keeps only the ``key`` entry of a
    dict-valued observation -- confirm against ``resettable_env``.
    """
    env_fns = [lambda: gym.make(env_name) for _ in range(n_envs)]
    return resettable_env.DictExtractWrapper(DummyVecEnv(env_fns), key)


# Single, non-vectorized instance: supplies the spaces for the reward network
# and is handed to MCEIRL directly.
env = gym.make(env_name)

# Two independent vectorized copies, differing only in the extracted key.
state_venv = _make_extracted_venv("state")
obs_venv = _make_extracted_venv("obs")

# Gather demonstration trajectories from the state-based environments until at
# least 10000 timesteps have been collected.
# NOTE(review): policy=None presumably selects a random policy -- confirm
# against the `rollout.generate_trajectories` documentation.
enough_timesteps = rollout.make_min_timesteps(10000)
trajs = rollout.generate_trajectories(
    policy=None,
    venv=state_venv,
    sample_until=enough_timesteps,
)

# The reward network ignores the action, the next state and the done flag, and
# is built with an empty list of hidden layer sizes.
reward_net_options = dict(
    use_action=False,
    use_next_state=False,
    use_done=False,
    hid_sizes=[],
)
reward_net = reward_nets.BasicRewardNet(
    env.pomdp_observation_space,
    env.action_space,
    **reward_net_options,
)

mce_irl = MCEIRL(trajs, env, reward_net, linf_eps=1e-3)

# Kept as the final expression of the cell so the notebook still displays
# whatever `train` returns.
mce_irl.train(max_iter=5000)

In [None]:
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

# Wrap the observation-based vectorized env, passing the reward network's
# `predict` method as the reward callable.
# NOTE(review): presumably the wrapper substitutes the environment's own reward
# with the output of this callable -- confirm against the imitation docs.
obs_env_with_learned_reward = RewardVecEnvWrapper(obs_venv, reward_net.predict)

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy


# PPO hyperparameters, collected in one place for easy tweaking.
ppo_hyperparams = dict(
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=3e-4,
    n_epochs=10,
    n_steps=64,
)

# The learner is trained on the environment wrapped with the learned reward.
learner = PPO(
    policy=MlpPolicy,
    env=obs_env_with_learned_reward,
    **ppo_hyperparams,
)


def _episode_stats_on_obs_venv(model):
    """Evaluate ``model`` for 100 episodes on the unwrapped ``obs_venv``.

    Returns the pair produced by ``evaluate_policy`` when called with
    ``return_episode_rewards=True``; callers below keep only the first item.
    """
    return evaluate_policy(model, obs_venv, 100, return_episode_rewards=True)


# Baseline: per-episode rewards of the freshly initialised learner.
learner_rewards_before_training, _ = _episode_stats_on_obs_venv(learner)

# Short training run; the trailing number is the alternative budget noted in
# the original notebook.
learner.learn(1000)  # 100000

# Same evaluation again, now with the trained learner.
learner_rewards_after_training, _ = _episode_stats_on_obs_venv(learner)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Report the mean per-episode reward before and after training. The values are
# labelled and printed in chronological order (before, then after) so the
# output is unambiguous and matches the order of the histogram legend.
mean_reward_before = np.mean(learner_rewards_before_training)
mean_reward_after = np.mean(learner_rewards_after_training)
print(f"Mean episode reward before training: {mean_reward_before}")
print(f"Mean episode reward after training: {mean_reward_after}")

# Side-by-side histogram of the per-episode rewards of the untrained and the
# trained learner.
plt.hist(
    [learner_rewards_before_training, learner_rewards_after_training],
    label=["untrained", "trained"],
)
plt.xlabel("episode reward")
plt.ylabel("number of episodes")
plt.legend()
plt.show()