In [1]:
from imitation.imitate_single_step import *
from imitation.data import ExpertDataset
from imitation.gan import *
from causal_gym import HighwaySingleStepPCH

In [2]:
pch = HighwaySingleStepPCH()
nodes, base, conf = pch.env.get_graph
G = parse_graph(nodes, base, conf)

In [3]:
print(G.de)
print(G.be)
print(G.v)

[('X', 'Y'), ('Z', 'X'), ('Z', 'Y'), ('L', 'X'), ('L', 'W')]
[('W', 'Y')]
['L', 'W', 'Z', 'X', 'Y']


In [4]:
G_x_bar = G.subgraph(G.set_v, set(), {'X'})
d_separated(G_x_bar, {'X'}, {'Y'}, {'Z'})

True

In [5]:
G_bar_x = G.subgraph(G.set_v, {'X'}, set())
identify({'X'}, {'Y'}, G_bar_x).__str__()

'sum{Z}[[P(X,Z,W,L) / P(X,W,L)]sum{W}[P(L,W)P(Y,Z,X,W,L) / P(L)P(Z,X,L,W)]]'

In [6]:
observed, _ = pch.env.observed_unobserved_vars
Pa_pi = conditioning_set(G, 'X', 'Y', observed)
Pa_pi

{'W', 'Z'}

In [7]:
find_pi_backdoor(G, 'X', 'Y', Pa_pi)

{'Z'}

In [8]:
records = collect_expert_trajectories(pch, num_episodes=128)
records[0]

{'episode': 0,
 'step': 0,
 'obs': {'x': 29.1455588268693, 'z': 30.0, 'w': 0},
 'action': 2,
 'reward': True,
 'terminated': True,
 'truncated': True,
 'info': {'u': 1, 'l': 0, 'y': True}}

In [9]:
dataset_ci = ExpertDataset(records, cond_vars=['z'], action_var='action')
loader_ci = torch.utils.data.DataLoader(dataset_ci, batch_size=64, shuffle=True)

dataset_bc = ExpertDataset(records, cond_vars=['z', 'w'], action_var='action')
loader_bc = torch.utils.data.DataLoader(dataset_bc, batch_size=64, shuffle=True)

generator_ci = ConditionalGenerator(cond_dim=1, num_actions=pch.action_space.n, hidden_dim=32)
discriminator_ci = Discriminator(cond_dim=1, num_actions=pch.action_space.n, hidden_dim=32)
generator_ci = train_gan(generator_ci, discriminator_ci, loader_ci, epochs=60)
policy_ci = GANPolicy(generator_ci)

generator_bc = ConditionalGenerator(cond_dim=2, num_actions=pch.action_space.n, hidden_dim=32)
discriminator_bc = Discriminator(cond_dim=2, num_actions=pch.action_space.n, hidden_dim=32)
generator_bc = train_gan(generator_bc, discriminator_bc, loader_bc, epochs=60)
policy_bc = GANPolicy(generator_bc)

policy_ci_fn = lambda obs: policy_ci([obs['z']])
policy_bc_fn = lambda obs: policy_bc([obs['z'], obs['w']])

expert_rewards = [r['reward'] for r in records]
expert_dist = compute_distribution(expert_rewards, bins=[-0.5, 0.5, 1.5])

ci_rewards = rollout_policy(pch, policy_ci_fn, num_episodes=128)
ci_dist = compute_distribution(ci_rewards, bins=[-0.5, 0.5, 1.5])

bc_rewards = rollout_policy(pch, policy_bc_fn, num_episodes=128)
bc_dist = compute_distribution(bc_rewards, bins=[-0.5, 0.5, 1.5])

print('Causal GAN L1:\t', l1_distance(expert_dist, ci_dist))
print('Naive GAN l1:\t', l1_distance(expert_dist, bc_dist))

Causal GAN L1:	 0.0
Naive GAN l1:	 0.09375
