In [1]:
import math, random, time
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from collections import deque

from causal_gym import HighwayPCH
from imitation.imitate import parse_graph, find_sequential_pi_backdoor, collect_expert_trajectories
from imitation.gym_gail.core_net import DiscreteActor, Critic, Discriminator
from imitation.gym_gail.causal_gail import *

In [None]:
# Experiment configuration -- everything a reader might want to tune.
num_steps = 30      # episode horizon; also sizes the graph nodes X0..X{num_steps-1} and Y{num_steps}
seed = 1            # base seed for env resets and the global RNGs below
expert_eps = 200    # episodes requested from collect_expert_trajectories
train_eps = 30      # episodes per one_training_round call
device = 'cpu'

# Seed every global RNG so network initialisation and any sampling done through
# random / numpy / torch are reproducible under Restart & Run All.
# (The environment itself is seeded separately via its `seed=` arguments.)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
# Highway environment; render_mode='rgb_array' presumably follows the gymnasium
# convention (frames returned as arrays rather than drawn to a window) -- confirm.
env = HighwayPCH(
    num_steps=num_steps,
    seed=seed,
    render_mode='rgb_array',
)
# The wrapper exposes the underlying env on `.env`; its action space has a
# finite size `.n` (consumed by DiscreteActor below).
inner_action_space = env.env.action_space
num_actions = inner_action_space.n

In [4]:
# Causal graph of the environment plus the node names the backdoor search needs.
G = parse_graph(env.get_graph)
# One X node per time step (presumably the per-step decisions) ...
X = set(map('X{}'.format, range(num_steps)))
# ... and the outcome node indexed by the horizon.
Y = 'Y' + str(num_steps)
# First entry of `observed_unobserved_vars`; passed through to the backdoor search.
obs_prefix = env.env.observed_unobserved_vars[0]

In [5]:
# Search for sequential pi-backdoor adjustment sets for X -> Y, then build an
# encoder `encode(obs, t)` that turns an observation into a flat feature vector
# over those sets. `z_dim` is the width the encoder reports.
Z_sets = find_sequential_pi_backdoor(G, X, Y, obs_prefix)
categorical_dims = calc_categorical_dims(env)
dummy_obs, _ = env.reset(seed=seed)

encode, z_dim, union_tokens, var_dims = build_z_encoder(Z_sets, dummy_obs, categorical_dims)
dummy_z = encode(dummy_obs, 0)
encoded_size = int(np.asarray(dummy_z).size)
print('z_dim =', z_dim, ' | encode(dummy_obs, 0).size =', encoded_size)
# Actor, critic and discriminator are all sized from z_dim, so fail fast here if
# the encoder's actual output width disagrees with the width it reports.
assert encoded_size == z_dim, f'encoder output size {encoded_size} != z_dim {z_dim}'

z_dim = 145  | encode(dummy_obs, 0).size = 145


In [None]:
# Expert demonstrations. behavioral_policy=None lets the collector fall back to
# its default policy (presumably the built-in expert -- confirm in imitation.imitate).
# NOTE(review): the saved output below reports 10 episodes although
# expert_eps = 200; it is stale and should be refreshed by re-running.
expert_collection_kwargs = dict(
    num_episodes=expert_eps,
    max_steps=num_steps,
    behavioral_policy=None,
    seed=seed,
)
records = collect_expert_trajectories(env, **expert_collection_kwargs)

Starting episode 1/10...
  Episode 1 ended at step 10 (terminated: False, truncated: True).
Starting episode 2/10...
  Episode 2 ended at step 10 (terminated: False, truncated: True).
Starting episode 3/10...
  Episode 3 ended at step 10 (terminated: False, truncated: True).
Starting episode 4/10...
  Episode 4 ended at step 10 (terminated: True, truncated: True).
Starting episode 5/10...
  Episode 5 ended at step 10 (terminated: False, truncated: True).
Starting episode 6/10...
  Episode 6 ended at step 10 (terminated: False, truncated: True).
Starting episode 7/10...
  Episode 7 ended at step 10 (terminated: False, truncated: True).
Starting episode 8/10...
  Episode 8 ended at step 10 (terminated: False, truncated: True).
Starting episode 9/10...
  Episode 9 ended at step 10 (terminated: False, truncated: True).
Starting episode 10/10...
  Episode 10 ended at step 7 (terminated: True, truncated: False).
Finished collecting expert trajectories.


In [None]:
# Networks: a discrete policy and a value function over the encoded state, and a
# discriminator over (state, action) features -- hence its z_dim + num_actions input.
net_width = 128
actor = DiscreteActor(z_dim, num_actions, hidden_size=net_width).to(device)
critic = Critic(z_dim, hidden_size=net_width).to(device)
discriminator = Discriminator(z_dim + num_actions, hidden_size=net_width, dropout=0.2).to(device)

# One Adam optimiser per network; the discriminator gets the smallest step size.
actor_optim, critic_optim, discriminator_optim = (
    Adam(actor.parameters(), lr=3e-4),
    Adam(critic.parameters(), lr=1e-3),
    Adam(discriminator.parameters(), lr=1e-4),
)

In [8]:
# Flatten the expert records into tensors: encoded states Z_e, actions A_e, and
# discriminator inputs X_e (presumably Z_e concatenated with one-hot actions --
# the saved output shows 150 = 145 + 5 columns).
Z_e, A_e, X_e = make_expert_batch(records, encode, num_actions)
print(Z_e.shape, A_e.shape, X_e.shape)
# The discriminator above was built for z_dim + num_actions inputs; catch a
# width mismatch here instead of deep inside one_training_round.
n_expert = Z_e.shape[0]
assert Z_e.shape == (n_expert, z_dim), f'Z_e has shape {tuple(Z_e.shape)}, expected (N, {z_dim})'
assert A_e.shape == (n_expert,), f'A_e has shape {tuple(A_e.shape)}, expected ({n_expert},)'
assert X_e.shape == (n_expert, z_dim + num_actions), (
    f'X_e has shape {tuple(X_e.shape)}, expected (N, {z_dim + num_actions})'
)

torch.Size([97, 145]) torch.Size([97]) torch.Size([97, 150])


In [None]:
# A single warm-up round with the base seed, before the main training loop.
# NOTE(review): the saved output below is stale (it logs 5 PPO epochs and 10
# episodes, while epochs=4 and train_eps=30 here); re-run to refresh it.
warmup_kwargs = dict(
    expert_records=None,
    # PPO / GAE
    gamma=0.99,
    gae_lambda=0.95,
    ppo_clip=0.2,
    epochs=4,
    minibatch_size=512,
    entropy_coeff=5e-3,
    value_coeff=0.5,
    max_grad_norm=0.5,
    normalize_adv=True,
    # discriminator
    loss_type='bce',
    gp_lambda=10.0,
    d_updates=4,
    d_minibatch_size=512,
    use_gp=True,
    instance_noise_std=0.0,
    label_smoothing=0.1,
    # rollouts
    max_steps=num_steps,
    num_episodes=train_eps,
    seed=seed,
)
stats = one_training_round(
    env, actor, critic, discriminator,
    actor_optim, critic_optim, discriminator_optim,
    encode, num_actions, X_e,
    **warmup_kwargs,
)

stats

Epoch 1/5 completed.
Epoch 2/5 completed.
Epoch 3/5 completed.
Epoch 4/5 completed.
Epoch 5/5 completed.


{'avg_env_return': 166.325410816414,
 'avg_D_reward': 0.6946890950202942,
 'ppo_actor_loss': -0.026710760965943336,
 'ppo_critic_loss': 1.1620875358581544,
 'ppo_entropy': 1.6088080167770387,
 'ppo_approx_kl': 0.004577630899079588,
 'ppo_clip_frac': 0.0,
 'D_loss': 10.692280451456705,
 'D_real_mean': 0.007911172385017077,
 'D_fake_mean': 0.0035715773701667786,
 'D_gp': 0.9308118224143982,
 'D_accuracy': 0.6078431407610575,
 'ep_lens': [10, 10, 10, 7, 5, 10, 3, 10, 10, 10],
 'n_steps': 85,
 'n_episodes': 10}

In [None]:
# Main training loop. Each call to one_training_round returns a stats dict; we
# keep 20-round moving windows of the env return and discriminator reward and
# print a summary every `log_every` rounds.
epochs = 120        # outer training rounds (distinct from the PPO `epochs=4` below)
log_every = 5
ret_ma = deque(maxlen=20)
dret_ma = deque(maxlen=20)

# Everything except the per-round seed is fixed, so build the kwargs once.
round_kwargs = dict(
    env=env,
    actor=actor, critic=critic, discriminator=discriminator,
    actor_optim=actor_optim, critic_optim=critic_optim, discriminator_optim=discriminator_optim,
    encode=encode, num_actions=num_actions,
    X_e=X_e,
    gamma=0.99, gae_lambda=0.95,
    ppo_clip=0.2, epochs=4, minibatch_size=512,
    entropy_coeff=5e-3, value_coeff=0.5, max_grad_norm=0.5,
    normalize_adv=True,
    loss_type='bce',
    gp_lambda=10.0, d_updates=4, d_minibatch_size=512, use_gp=True,
    instance_noise_std=0.0, label_smoothing=0.1,
    max_steps=num_steps, num_episodes=train_eps,
)

for it in range(1, epochs + 1):
    stats = one_training_round(**round_kwargs, seed=seed + it)

    ret_ma.append(stats['avg_env_return'])
    dret_ma.append(stats['avg_D_reward'])

    if it % log_every != 0:
        continue

    log_line = (
        f"[{it:03d}] "
        f"R_env(m)={np.mean(ret_ma):.3f}  "
        f"R_D(m)={np.mean(dret_ma):.3f}  "
        f"π: L_actor={stats['ppo_actor_loss']:.3f}  "
        f"V: L_critic={stats['ppo_critic_loss']:.3f}  "
        f"D: L={stats['D_loss']:.3f}  acc={stats['D_accuracy']:.3f}"
    )
    print(log_line)

Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
[005] R_env(m)=121.413  R_D(m)=0.699  π: L_actor=-0.020  V: L_critic=0.506  D: L=9.887  acc=0.512
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
Epoch 1/4 completed.
Epoch 2/4 completed.
Epoch 3/4 completed.
Epoch 4/4 completed.
[010] R_env(m)=124.381  R_D(m)=0.706  π: L_actor=-0.017  V: L_

KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def eval_policy(env, actor, encode, num_episodes=6, max_steps=num_steps, seed=None):
    """Roll out ``actor`` greedily and report the undiscounted environment return.

    Args:
        env: environment with ``reset(seed=...) -> (obs, info)`` and
            ``do(policy_fn, show_reward=...) -> (obs, reward, terminated, truncated, info)``.
        actor: policy whose ``act(z, deterministic=True)`` returns a 3-tuple whose
            first element is a single-element action tensor.
        encode: ``encode(obs, t)`` -> array-like feature vector for the actor.
        num_episodes: number of evaluation episodes.
        max_steps: per-episode step cap, applied in addition to env termination/truncation.
        seed: base seed; episode ``e`` is reset with ``seed + e``. ``None`` leaves
            seeding to the environment (episodes are then not reproducible).

    Returns:
        ``(mean_return, returns)``: the mean over episodes (``nan`` when
        ``num_episodes`` is 0) and the list of per-episode returns.
    """
    # The actor's device does not change during evaluation, so look it up once.
    device = next(actor.parameters()).device
    # Remember the caller's mode so evaluating mid-training does not silently
    # leave dropout/batch-norm layers switched to inference behaviour.
    was_training = actor.training
    actor.eval()
    returns = []

    try:
        for e in range(num_episodes):
            # `seed + e` with the default seed=None would raise TypeError.
            episode_seed = None if seed is None else seed + e
            obs, _ = env.reset(seed=episode_seed)
            t, done, ret = 0, False, 0.0

            while not done and t < max_steps:
                z = torch.as_tensor(np.asarray(encode(obs, t)), dtype=torch.float32)
                z = z.unsqueeze(0).to(device)
                a, _, _ = actor.act(z, deterministic=True)
                action = int(a.item())
                obs, r, terminated, truncated, _ = env.do(lambda _: action, show_reward=True)

                ret += r
                t += 1
                done = terminated or truncated

            returns.append(ret)
    finally:
        actor.train(was_training)

    mean_return = float(np.mean(returns)) if returns else float('nan')
    return mean_return, returns

# Seed the evaluation so the reported returns are reproducible. The original
# call relied on seed=None, which made `seed + e` inside eval_policy raise TypeError.
# NOTE(review): seeds seed..seed+5 may overlap seeds used during training.
eval_policy(env, actor, encode, seed=seed)