In [1]:
import random
import pickle
import torch

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np

from collections import defaultdict

from causal_gym import AntMazePCH
from causal_rl.algo.imitation.imitate import *

  from pkg_resources import resource_stream, resource_exists


In [2]:
# Run models on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
num_steps = 100  # horizon: actions X0..X{num_steps-1}, outcome Y{num_steps}
seed = 0
hidden_dims = {}

# Seed every RNG the notebook may draw from: Python's `random`, NumPy's global
# generator (numpy is imported above, so it must be seeded too) and torch.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x7989c8194430>

In [4]:
# Environment built from the configuration above; the expert's `num_eps` is
# reused as the number of training episodes.
env = AntMazePCH(seed=seed, num_steps=num_steps, hidden_dims=hidden_dims)
expert = env.expert
train_eps = expert.num_eps
train_eps

1101

In [5]:
# Causal graph of the environment plus the variable names used throughout:
# one action node per step (X0..X{num_steps-1}), the final outcome
# Y{num_steps}, and the observed-variable prefixes reported by the env.
G = parse_graph(env.get_graph)
X = set(map('X{}'.format, range(num_steps)))
Y = 'Y' + str(num_steps)
obs_prefix = env.env.observed_unobserved_vars[0]

In [6]:
# Conditioning sets returned by find_sequential_pi_backdoor for the actions X
# and outcome Y of graph G, given the observed prefixes `obs_prefix`. A later
# cell iterates the result as {action name: conditioning vars} (trim_Z_sets)
# and shortens each set to a fixed lookback window.
Z_sets = find_sequential_pi_backdoor(G, X, Y, obs_prefix)
# Z_sets  (uncomment to inspect)

In [7]:
# Expert rollouts used as training data: `train_eps` episodes, capped by
# max_steps=num_steps, with behavioral_policy=None passed through unchanged.
records = collect_expert_trajectories(
    env,
    seed=seed,
    max_steps=num_steps,
    num_episodes=train_eps,
    behavioral_policy=None,
)

In [8]:
# "Naive" conditioning set for each action X_i: every observed non-action
# variable from steps 0..i plus every earlier action X_0..X_{i-1}.
non_action_prefixes = set(obs_prefix) - {'X'}
naive_Z_sets = {
    Xi: (
        {f'{o}{j}' for j in range(int(Xi[1:]) + 1) for o in non_action_prefixes}
        | {f'X{j}' for j in range(int(Xi[1:]))}
    )
    for Xi in X
}

In [9]:
# trim Z-sets to only include lookback of 5 steps
def trim_Z_sets(Z_sets, lookback=5):
    trimmed_Z_sets = {}
    for Xi, cond_vars in Z_sets.items():
        i = int(Xi[1:])
        
        if i < lookback:
            trimmed_Z_sets[Xi] = cond_vars.copy()
            continue

        min_t = i - lookback
        keep_vars = set()

        for var in cond_vars:
            step = int(var[1:])
            if step >= min_t:
                keep_vars.add(var)
        trimmed_Z_sets[Xi] = keep_vars
    return trimmed_Z_sets

lookback_steps = 5
Z_sets, naive_Z_sets = (
    trim_Z_sets(z, lookback=lookback_steps) for z in (Z_sets, naive_Z_sets)
)

In [None]:
hidden_size = 256
# Identical training configuration for both imitators; only the conditioning
# sets differ.
train_kwargs = dict(
    lr=3e-4,
    batch_size=512,
    seed=seed,
    patience=20,
    continuous=True,
    hidden_dim=hidden_size,
    device=device,
)
causal_policies = train_policies(env, records, Z_sets, **train_kwargs)
naive_policies = train_policies(env, records, naive_Z_sets, **train_kwargs)

In [None]:
# Expert baseline: the last entry of `info['Y']` seen for each episode (plotted
# below as the final cumulative reward). Later records of an episode overwrite
# earlier ones, so this assumes records are in step order within an episode —
# TODO confirm.
episode_rewards = {}
for rec in records:
    episode_rewards[rec['episode']] = rec['info']['Y'][-1]

num_eps = len(episode_rewards)
# Walk the episode ids that actually occur instead of assuming they are exactly
# 0..num_eps-1: a defaultdict lookup over range(num_eps) would silently insert
# 0.0 rewards for missing ids and drop real episodes under any other numbering.
expert_rewards = [episode_rewards[ep_id] for ep_id in sorted(episode_rewards)]

# Evaluate both imitators on the same number of episodes with the same seed.
causal_returns = eval_policy(env, causal_policies, num_episodes=num_eps, seed=seed, device=device)
naive_returns = eval_policy(env, naive_policies, num_episodes=num_eps, seed=seed, device=device)

causal_rewards = [ep['Y'][-1] for ep in causal_returns]
naive_rewards = [ep['Y'][-1] for ep in naive_returns]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(expert_rewards, label='Expert (behavioral)')
ax.plot(causal_rewards, label='Causal imitation')
ax.plot(naive_rewards, label='Naive behavior control')
ax.set_xlabel('Episode')
ax.set_ylabel('Final Cumulative Reward')
ax.set_title('Comparison of Expert vs. Causal vs. Naive Returns')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
bins = 20  # number of histogram bins

expert_mean = np.mean(expert_rewards)
causal_mean = np.mean(causal_rewards)
naive_mean = np.mean(naive_rewards)

# Shared bin edges for the overlaid histograms. Passing `bins=20` to each
# hist() call makes matplotlib derive separate edges from each dataset's own
# min/max, so overlapping bars would cover different reward ranges.
bin_edges = np.histogram_bin_edges(
    np.concatenate([np.ravel(causal_rewards), np.ravel(naive_rewards)]), bins=bins
)

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(causal_rewards, bins=bin_edges, alpha=0.6, label=f'Causal Imitation ({causal_mean:.1f})')
ax.hist(naive_rewards, bins=bin_edges, alpha=0.6, label=f'Naive Baseline ({naive_mean:.1f})')

# Expert mean as a vertical reference line.
ax.axvline(expert_mean, color='black', linestyle='--', linewidth=2, label=f'Expert Mean ({expert_mean:.1f})')

ax.set_xlabel('Final Cumulative Reward')
ax.set_ylabel('Number of Episodes')
ax.set_title('Episode Return Distributions')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
def _actions_match(pred, true, atol=None):
    """Return 1 if a predicted action matches the recorded one, else 0.

    Both actions are flattened through ``np.asarray`` so scalar (discrete) and
    vector (continuous) actions are compared alike; ``int(pred == true)`` raises
    for multi-element arrays. With ``atol=None`` the match must be exact;
    otherwise every element must agree within ``atol``. Assumes actions are
    convertible by ``np.asarray`` (e.g. not CUDA tensors) — TODO confirm.
    """
    pred_arr = np.asarray(pred).ravel()
    true_arr = np.asarray(true).ravel()
    if pred_arr.shape != true_arr.shape:
        return 0
    if atol is None:
        return int(np.array_equal(pred_arr, true_arr))
    return int(np.allclose(pred_arr, true_arr, rtol=0.0, atol=atol))


def policy_accuracy(records, policies, atol=None):
    """Measure how often the per-step policies reproduce the recorded actions.

    Args:
        records: iterable of dicts with at least 'step', 'obs' and 'action'.
        policies: mapping 'X{t}' -> callable(obs) returning an action.
        atol: None for exact matching (discrete actions); a float to accept
            continuous actions that agree elementwise within this tolerance.

    Returns:
        (overall_acc, per_step): overall_acc is correct/total (nan when no
        record had a policy); per_step maps step -> {'corr': int, 'total': int}.

    Also prints the overall accuracy and a per-step breakdown.
    """
    total, correct = 0, 0
    per_step = defaultdict(lambda: {'corr': 0, 'total': 0})

    for r in records:
        t = r['step']
        key = f'X{t}'
        if key not in policies:
            # No policy was trained for this step; the record is not scored.
            continue
        hit = _actions_match(policies[key](r['obs']), r['action'], atol)

        per_step[t]['total'] += 1
        per_step[t]['corr'] += hit
        total += 1
        correct += hit

    overall_acc = correct / total if total else float('nan')
    print(f"Overall accuracy: {overall_acc*100:.2f}% ({correct}/{total})")

    print("Per-step accuracy:")
    for t in sorted(per_step):
        ts = per_step[t]
        acc = ts['corr'] / ts['total']
        print(f"  step {t}: {acc*100:.2f}% ({ts['corr']}/{ts['total']})")

    print()
    return overall_acc, per_step

# Accuracy of each imitator against the expert actions stored in `records`.
(ci_acc, ci_step_acc), (bc_acc, bc_step_acc) = (
    policy_accuracy(records, step_policies)
    for step_policies in (causal_policies, naive_policies)
)

In [None]:
# Confusion matrices of each imitator's predicted action vs. the expert action
# recorded in `records`.
# NOTE(review): `action_space.n`, `unwrapped.action_type.actions` and
# `_meta_actions.ACTIONS_ALL` look like discrete-action attributes carried over
# from the Highway version of this notebook (cf. HighwayPCH in the last cells),
# while the policies were trained with continuous=True — confirm they exist on
# this environment before relying on this cell.
n_actions = env.env.action_space.n
labels = [*range(n_actions)]
action_keys = list(env.env.unwrapped.action_type.actions.keys())

y_true = [r['action'] for r in records]

# Query both imitators on every record (causal first, then naive). -1 marks a
# step with no trained policy; it is not in `labels`, so confusion_matrix
# leaves those samples out.
y_pred_causal, y_pred_naive = [], []
for r in records:
    key = f"X{r['step']}"
    for step_policies, preds in ((causal_policies, y_pred_causal), (naive_policies, y_pred_naive)):
        preds.append(step_policies[key](r['obs']) if key in step_policies else -1)

cm_causal = confusion_matrix(y_true, y_pred_causal, labels=labels)
cm_naive = confusion_matrix(y_true, y_pred_naive, labels=labels)

# Side-by-side panels, one per imitator.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
action_names = [env.env._meta_actions.ACTIONS_ALL[a] for a in action_keys]

panels = (
    (ax1, cm_causal, 'Blues', "Causal Imitator"),
    (ax2, cm_naive, 'Greens', "Naive Baseline"),
)
for ax, cm, cmap, title in panels:
    ConfusionMatrixDisplay(cm, display_labels=action_names).plot(ax=ax, cmap=cmap, xticks_rotation='vertical')
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Roll out each imitator in the environment, passing the same train_eps /
# num_steps / seed values used for the expert data above. These trajectories
# feed the episode-length and state-conditioned action comparisons below.
causal_records = collect_imitator_trajectories(env, causal_policies, train_eps, num_steps, seed)
naive_records = collect_imitator_trajectories(env, naive_policies, train_eps, num_steps, seed)

# The confusion matrices (cm_causal / cm_naive, with `labels` and
# `action_names`) are computed once, on the expert `records`, in the previous
# cell, and are deliberately not recomputed here. Repeating that code would
# only redraw the same expert-data comparison. The trajectories above
# presumably record each imitator's own actions, so they cannot serve as
# ground truth for a confusion matrix.

In [None]:
from collections import defaultdict

def avg_steps(records):
    """Average number of records per episode, or 0.0 when there are none.

    Every record counts as one step of the episode named by its 'episode' key.
    """
    rows = list(records)
    episodes = {row['episode'] for row in rows}
    return len(rows) / len(episodes) if episodes else 0.0

# Mean episode length: expert, causal imitator, naive imitator.
for recs in (records, causal_records, naive_records):
    print(avg_steps(recs))

In [None]:
from collections import Counter

def filtered(records, state: "list[tuple[str, int]]", observed=None):
    """Return the records whose tracked value of EVERY (var, val) in ``state`` matches.

    Each condition reads ``r[loc][var][index]`` where ``loc`` is 'obs' for
    variables in ``observed`` and 'info' otherwise, and ``index`` is -1 for
    'X'/'Y' (they have one less entry in the info) and -2 for everything else.
    Records whose history is too short for that index do not match. Each
    matching record appears once, in input order; an empty ``state`` matches
    every record.

    Args:
        records: iterable of record dicts with 'obs' and 'info' sub-dicts.
        state: sequence of (variable name, required value) pairs, all of which
            must hold.
        observed: collection of observed variable names; defaults to
            ``env.env.observed_unobserved_vars[0]`` from the global ``env``.
    """
    if observed is None:
        observed = env.env.observed_unobserved_vars[0]

    def _matches(r, var, val):
        index = -1 if var in ('X', 'Y') else -2  # X and Y have one less entry in the info
        loc = 'obs' if var in observed else 'info'
        history = r[loc][var]
        return -index <= len(history) and history[index] == val

    # Previously each satisfied condition appended the record again, so
    # multi-condition states behaved as a duplicated OR instead of an AND.
    return [r for r in records if all(_matches(r, var, val) for var, val in state)]

# Prepare data for plotting
labels = list(range(env.env.action_space.n))
action_names = [env.env._meta_actions.ACTIONS_ALL[a] for a in env.env.unwrapped.action_type.actions.keys()]

state1 = [('W', 1)]
state0 = [('W', 0)]


def _action_distribution(recs, state):
    """Normalised action frequencies among the records `filtered` keeps for `state`.

    Returns an empty dict when no record matches (nothing to normalise).
    """
    counts = Counter(r['action'] for r in filtered(recs, state))
    total = sum(counts.values())
    return {action: n / total for action, n in counts.items()}


# One action distribution per (policy source, conditioning state), printed as
# expert, causal, naive with state1 before state0. The suffix 1/0 matches the
# `*_counts1` / `*_counts0` names used by later cells.
sources = {'Expert': records, 'Causal': causal_records, 'Naive': naive_records}
states = {1: state1, 0: state0}
distributions = {}
for source, recs in sources.items():
    for suffix, state in states.items():
        distributions[source, suffix] = _action_distribution(recs, state)
        print(f'{source} Filtered ({state}):', distributions[source, suffix])

expert_counts1, expert_counts0 = distributions['Expert', 1], distributions['Expert', 0]
causal_counts1, causal_counts0 = distributions['Causal', 1], distributions['Causal', 0]
naive_counts1, naive_counts0 = distributions['Naive', 1], distributions['Naive', 0]

# Grid: one row per source, one column per state.
fig, axes = plt.subplots(3, 2, figsize=(8, 8), sharey=True)
for row, source in enumerate(sources):
    for col, (suffix, state) in enumerate(states.items()):
        ax = axes[row, col]
        ax.bar(action_names, [distributions[source, suffix].get(a, 0) for a in labels])
        ax.set_title(f'{source} Filtered ({state}):')
        ax.set_xlabel("Actions")
        ax.set_ylabel("Frequency")
        ax.set_xticks(action_names)
        ax.set_xticklabels(action_names, rotation=45, ha="right")

plt.tight_layout()
plt.show()

In [None]:
def calculate_variation(counts1, counts0, labels):
    """Absolute per-label difference between two frequency dicts.

    Labels missing from either dict count as frequency 0. The result maps every
    label in ``labels`` (in that order) to ``|counts1[label] - counts0[label]|``.
    """
    return {
        lab: abs(counts1.get(lab, 0) - counts0.get(lab, 0))
        for lab in labels
    }

# Compute variations: how much each source's action frequencies shift between
# the two conditioning states.
count_pairs = (
    (expert_counts1, expert_counts0),
    (causal_counts1, causal_counts0),
    (naive_counts1, naive_counts0),
)
expert_variation, causal_variation, naive_variation = (
    calculate_variation(c1, c0, labels) for c1, c0 in count_pairs
)

series = (
    ('Expert Variation', expert_variation),
    ('Causal Variation', causal_variation),
    ('Naive Variation', naive_variation),
)
for name, variation in series:
    print(f"{name}:", variation)

# Grouped bars: one group per action, one bar per source.
fig, ax = plt.subplots(figsize=(10, 6))
x_labels = [action_names[label] for label in labels]
width = 0.25
positions = np.arange(len(x_labels))

for offset, (name, variation) in zip((-width, 0, width), series):
    ax.bar(positions + offset, [variation[label] for label in labels], width, label=name)

ax.set_xticks(positions)
ax.set_xticklabels(x_labels, rotation=45, ha="right")
ax.set_xlabel('Actions')
ax.set_ylabel('Variation')
ax.set_title(f'Action Frequency Variation Across States ({state1} vs {state0})')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Correct predictions sit on each confusion matrix's diagonal; every
# off-diagonal entry is a misclassification.
n_labels = len(labels)
causal_block = cm_causal[:n_labels, :n_labels]
naive_block = cm_naive[:n_labels, :n_labels]

causal_correct = np.trace(causal_block)
naive_correct = np.trace(naive_block)

print(f"Causal Correct Predictions: {causal_correct}")
print(f"Naive Correct Predictions: {naive_correct}")

causal_incorrect = causal_block.sum() - causal_correct
naive_incorrect = naive_block.sum() - naive_correct

print(f"Causal Incorrect Predictions: {causal_incorrect}")
print(f"Naive Incorrect Predictions: {naive_incorrect}")

causal_correct / naive_correct, causal_incorrect / naive_incorrect

In [None]:
# Environment for a visual rollout. The policies were trained on AntMazePCH
# (HighwayPCH is never imported in this notebook), so the same class and
# configuration are used. The horizon must not exceed num_steps, because
# policies exist only for X0..X{num_steps-1}.
t = num_steps
# NOTE(review): assumes AntMazePCH accepts `render_mode` the way HighwayPCH
# did — confirm against the causal_gym signature.
testenv = AntMazePCH(num_steps=t, hidden_dims=hidden_dims, seed=seed, render_mode='human')

In [None]:
# Visual rollout of one imitator in `testenv`; set this to `naive_policies` to
# watch the naive baseline instead.
rollout_policies = causal_policies

try:
    obs, _ = testenv.reset()
    for step in range(t):
        action = rollout_policies[f'X{step}'](obs)
        obs, reward, terminated, truncated, info = testenv.do(action, show_reward=True)
        # Alternative kept from the original cell (env chooses the action):
        # action, obs, reward, terminated, truncated, info = testenv.see(show_reward=True)
        testenv.render()

        if terminated or truncated:
            break
finally:
    # Close the viewer on every exit path (early termination, full horizon or
    # an exception), not only when the episode terminates early.
    testenv.env._env.close()