In [1]:
import math, random, time
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from collections import deque, defaultdict
import matplotlib.pyplot as plt

from causal_gym import HighwayPCH
from imitation.imitate import parse_graph, find_sequential_pi_backdoor, collect_expert_trajectories
from imitation.gym_gail.core_net import DiscreteActor, Critic, Discriminator
from imitation.gym_gail.causal_gail import *

In [2]:
num_steps = 25
seed = 1
train_eps = 1000
device = 'cpu'

In [3]:
env = HighwayPCH(num_steps=num_steps, seed=seed, render_mode='rgb_array')
num_actions = env.env.action_space.n

In [4]:
G = parse_graph(env.get_graph)
X = {f'X{t}' for t in range(num_steps)}
Y = f'Y{num_steps}'
obs_prefix = env.env.observed_unobserved_vars[0]

In [5]:
Z_sets = find_sequential_pi_backdoor(G, X, Y, obs_prefix)
categorical_dims = calc_categorical_dims(env)
dummy_obs, _ = env.reset(seed=seed)

encode, z_dim, union_tokens, var_dims = build_z_encoder(Z_sets, dummy_obs, categorical_dims)
dummy_z = encode(dummy_obs, 0)
print('z_dim =', z_dim, ' | encode(dummy_obs, 0).size =', int(np.asarray(dummy_z).size))

z_dim = 370  | encode(dummy_obs, 0).size = 370


In [None]:
records = collect_expert_trajectories(
    env,
    num_episodes=train_eps,
    max_steps=num_steps,
    behavioral_policy=None,
    seed=seed
)

Starting episode 1/1000...
  Episode 1 ended at step 25 (terminated: False, truncated: True).
Starting episode 2/1000...
  Episode 2 ended at step 25 (terminated: False, truncated: True).
Starting episode 3/1000...
  Episode 3 ended at step 25 (terminated: False, truncated: True).
Starting episode 4/1000...
  Episode 4 ended at step 10 (terminated: True, truncated: False).
Starting episode 5/1000...
  Episode 5 ended at step 25 (terminated: False, truncated: True).
Starting episode 6/1000...
  Episode 6 ended at step 17 (terminated: True, truncated: False).
Starting episode 7/1000...


In [None]:
actor = DiscreteActor(z_dim, num_actions, hidden_size=128).to(device)
critic = Critic(z_dim, hidden_size=128).to(device)
discriminator = Discriminator(z_dim + num_actions, hidden_size=128, dropout=0.2).to(device)

actor_optim = Adam(actor.parameters(), lr=3e-4)
critic_optim = Adam(critic.parameters(), lr=1e-4)
discriminator_optim = Adam(discriminator.parameters(), lr=5e-5)

In [None]:
Z_e, A_e, X_e = make_expert_batch(records, encode, num_actions)
print(Z_e.shape, A_e.shape, X_e.shape)

torch.Size([5142, 445]) torch.Size([5142]) torch.Size([5142, 450])


In [None]:
stats = one_training_round(
    env,
    actor,
    critic,
    discriminator,
    actor_optim,
    critic_optim,
    discriminator_optim,
    encode,
    num_actions,
    X_e,
    expert_records=None,
    gamma=0.99,
    gae_lambda=0.95,
    ppo_clip=0.2,
    epochs=4,
    minibatch_size=256,
    entropy_coeff=2e-2,
    value_coeff=0.5,
    max_grad_norm=0.5,
    normalize_adv=True,
    loss_type='bce',
    gp_lambda=10.0,
    d_updates=2,
    d_minibatch_size=256,
    use_gp=False,
    instance_noise_std=0.05,
    label_smoothing=0.0,
    max_steps=num_steps,
    num_episodes=train_eps,
    seed=seed
)

stats

{'avg_env_return': 226.13310888900185,
 'avg_D_reward': 0.6900997161865234,
 'ppo_actor_loss': -0.010333863086998463,
 'ppo_critic_loss': 2.1499789357185364,
 'ppo_entropy': 1.6092643439769745,
 'ppo_approx_kl': -0.00042894088647489614,
 'ppo_clip_frac': 0.0,
 'D_loss': 10.738677501678467,
 'D_real_mean': -0.0033468197216279805,
 'D_fake_mean': -0.006157793221063912,
 'D_gp': 0.9353475421667099,
 'D_accuracy': 0.5634014457464218,
 'ep_lens': [10,
  4,
  4,
  11,
  8,
  22,
  16,
  6,
  20,
  29,
  30,
  13,
  3,
  3,
  4,
  6,
  17,
  14,
  12,
  5,
  6,
  12,
  3,
  30,
  22,
  29,
  30,
  7,
  23,
  17],
 'n_steps': 416,
 'n_episodes': 30}

In [None]:
epochs = 120
log_every = 5
ret_ma = deque(maxlen=20)
dret_ma = deque(maxlen=20)

for it in range(1, epochs + 1):
    stats = one_training_round(
        env,
        actor,
        critic,
        discriminator,
        actor_optim,
        critic_optim,
        discriminator_optim,
        encode,
        num_actions,
        X_e,
        expert_records=None,
        gamma=0.99,
        gae_lambda=0.95,
        ppo_clip=0.2,
        epochs=4,
        minibatch_size=256,
        entropy_coeff=2e-2,
        value_coeff=0.5,
        max_grad_norm=0.5,
        normalize_adv=True,
        loss_type='bce',
        gp_lambda=10.0,
        d_updates=2,
        d_minibatch_size=256,
        use_gp=False,
        instance_noise_std=0.05,
        label_smoothing=0.0,
        max_steps=num_steps,
        num_episodes=train_eps,
        seed=seed
    )

    ret_ma.append(stats['avg_env_return'])
    dret_ma.append(stats['avg_D_reward'])

    if it % log_every == 0:
        print(
            f"[{it:03d}] "
            f"R_env(m)={np.mean(ret_ma):.3f}  "
            f"R_D(m)={np.mean(dret_ma):.3f}  "
            f"pi: L_actor={stats['ppo_actor_loss']:.3f}  "
            f"V: L_critic={stats['ppo_critic_loss']:.3f}  "
            f"D: L={stats['D_loss']:.3f}  acc={stats['D_accuracy']:.3f}"
        )

[005] R_env(m)=218.446  R_D(m)=0.689  π: L_actor=-0.016  V: L_critic=1.963  D: L=9.965  acc=0.604
[010] R_env(m)=306.771  R_D(m)=0.691  π: L_actor=-0.021  V: L_critic=0.689  D: L=5.970  acc=0.614
[015] R_env(m)=374.548  R_D(m)=0.693  π: L_actor=-0.006  V: L_critic=0.417  D: L=1.481  acc=0.614
[020] R_env(m)=406.108  R_D(m)=0.669  π: L_actor=-0.008  V: L_critic=0.450  D: L=1.364  acc=0.814
[025] R_env(m)=478.097  R_D(m)=0.627  π: L_actor=-0.009  V: L_critic=0.260  D: L=1.306  acc=0.925
[030] R_env(m)=500.146  R_D(m)=0.577  π: L_actor=-0.018  V: L_critic=0.203  D: L=1.231  acc=0.952
[035] R_env(m)=503.076  R_D(m)=0.511  π: L_actor=-0.001  V: L_critic=0.293  D: L=1.137  acc=0.970
[040] R_env(m)=508.143  R_D(m)=0.452  π: L_actor=-0.005  V: L_critic=0.111  D: L=1.013  acc=0.981
[045] R_env(m)=508.491  R_D(m)=0.407  π: L_actor=-0.010  V: L_critic=0.184  D: L=0.967  acc=0.980
[050] R_env(m)=488.536  R_D(m)=0.365  π: L_actor=-0.023  V: L_critic=0.191  D: L=0.949  acc=0.980
[055] R_env(m)=462.3

In [None]:
@torch.no_grad()
def eval_policy(env, actor, encode, num_episodes=1000, max_steps=num_steps, seed=None):
    actor.eval()
    returns = []

    for e in range(num_episodes):
        rs = None if seed is None else seed + e + 1000
        obs, _ = env.reset(seed=rs)
        t, done, ret = 0, False, 0.0

        while not done and t < max_steps:
            z = torch.from_numpy(encode(obs, t)).float().unsqueeze(0).to(next(actor.parameters()).device)
            a, _, _ = actor.act(z, deterministic=True)
            obs, r, terminated, truncated, _ = env.do(lambda _: int(a.item()), show_reward=True)

            ret += r
            t += 1
            done = terminated or truncated

        returns.append(ret)

    return float(np.mean(returns)), returns

_, imitator_rewards = eval_policy(env, actor, encode)

In [None]:
expert_rewards = []
reward = 0.0
last_ep = records[0]['episode']

for r in records:
    if r['episode'] != last_ep:
        expert_rewards.append(reward)
        reward = 0.0
        last_ep = r['episode']

    reward += r['reward']

print(f'Expert rewards: mean={np.mean(expert_rewards):.3f}, std={np.std(expert_rewards):.3f}')
print(f'Imitator rewards: mean={np.mean(imitator_rewards):.3f}, std={np.std(imitator_rewards):.3f}')

Expert rewards: mean=676.280, std=205.085
Imitator rewards: mean=270.594, std=264.848
