# Experiment 9.10 - Causal Behavioral Cloning

In [1]:
import numpy as np
import pandas as pd
from gymnasium import spaces

from causal_gym import Graph, SCM, PCH
from causal_rl.algo.imitation.imitate import *

  from pkg_resources import resource_stream, resource_exists


In [2]:
seed = 0

### Graph Definitions

In [3]:
# from Table 9.13
nodes = [{'name': n} for n in ['Z0', 'X0', 'X1', 'Y']]

g1_edges = [
    {'from_': 'Z0', 'to_': 'X0', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'X1', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'Y', 'type_': 'bidirected'},
    {'from_': 'X0', 'to_': 'X1', 'type_': 'directed'},
    {'from_': 'X1', 'to_': 'Y', 'type_': 'directed'}
]
g1_ordering = ['Z0', 'X0', 'X1', 'Y']
G1 = Graph(nodes=nodes, edges=g1_edges)

g2_edges = [
    {'from_': 'Z0', 'to_': 'X0', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'Y', 'type_': 'directed'},
    {'from_': 'X0', 'to_': 'X1', 'type_': 'directed'},
    {'from_': 'X1', 'to_': 'Y', 'type_': 'directed'}
]
g2_ordering = ['Z0', 'X0', 'X1', 'Y']
G2 = Graph(nodes=nodes, edges=g2_edges)

g3_edges = [
    {'from_': 'Z0', 'to_': 'X0', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'Y', 'type_': 'directed'},
    {'from_': 'X0', 'to_': 'X1', 'type_': 'directed'},
    {'from_': 'X1', 'to_': 'Y', 'type_': 'directed'}
]
g3_ordering = ['X0', 'Z0', 'X1', 'Y']
G3 = Graph(nodes=nodes, edges=g3_edges)

g4_edges = [
    {'from_': 'Z0', 'to_': 'X0', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'Y', 'type_': 'bidirected'},
    {'from_': 'Z0', 'to_': 'X1', 'type_': 'directed'},
    {'from_': 'X0', 'to_': 'Y', 'type_': 'directed'},
    {'from_': 'X1', 'to_': 'Y', 'type_': 'directed'}
]
g4_ordering = ['X0', 'Z0', 'X1', 'Y']
G4 = Graph(nodes=nodes, edges=g4_edges)

### SCM/PCH Definitions

In [4]:
class G1SCM(SCM):
    def __init__(self, graph, u_dists, seed=None):
        super().__init__()
        self.rng = np.random.default_rng(seed)

        self.graph = graph
        self.u_dists = u_dists

        self._Uzx0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self._Uzx1_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self._Uzy_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Z_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X1_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Y_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))

        self.Z_op_1 = self.rng.choice(['and', 'or', 'xor'])
        self.Z_op_2 = self.rng.choice(['and', 'or', 'xor'])
        self.X1_op = self.rng.choice(['and', 'or', 'xor'])
        self.Y_op = self.rng.choice(['and', 'or', 'xor'])

        self._Uzx0 = None
        self._Uzx1 = None
        self._Uzy = None
        self.Z = []
        self.X = []
        self._Y = None

        self.action_space = spaces.Discrete(2) # binary actions at each step
        self.observation_space = spaces.Dict({
            'Z': spaces.Sequence(spaces.Discrete(2)),
            'X': spaces.Sequence(spaces.Discrete(2))
        })

    def _sample_confounders(self):
        Uzx0 = self.rng.choice([0, 1], p=[1 - self.u_dists[0], self.u_dists[0]]) if not self._Uzx0_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dists[0], self.u_dists[0]])
        Uzx1 = self.rng.choice([0, 1], p=[1 - self.u_dists[1], self.u_dists[1]]) if not self._Uzx1_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dists[1], self.u_dists[1]])
        Uzy = self.rng.choice([0, 1], p=[1 - self.u_dists[2], self.u_dists[2]]) if not self._Uzy_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dists[2], self.u_dists[2]])
        return Uzx0, Uzx1, Uzy

    def _sample_z(self):
        z = self._Uzx0

        if self.Z_op_1 == 'and':
            z &= self._Uzx1
        elif self.Z_op_1 == 'or':
            z |= self._Uzx1
        else: # xor
            z ^= self._Uzx1

        if self.Z_op_2 == 'and':
            z &= self._Uzy
        elif self.Z_op_2 == 'or':
            z |= self._Uzy
        else: # xor
            z ^= self._Uzy

        return z if self.Z_sign else 1 - z

    def _obs(self):
        return {'Z': self.Z, 'X': self.X}

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        self._Uzx0, self._Uzx1, self._Uzy = self._sample_confounders()
        self.X = []
        self.Z = [self._sample_z()]
        self._t = 0

        return self._obs(), {'Y': []}

    def action(self):
        if self._t == 0:
            return self._Uzx0 if self.X0_sign else 1 - self._Uzx0
        else:
            x1 = self.X[0]
            if self.X1_op == 'and':
                x1 &= self._Uzx1
            elif self.X1_op == 'or':
                x1 |= self._Uzx1
            else: # xor
                x1 ^= self._Uzx1

            return x1 if self.X1_sign else 1 - x1

    def _reward(self):
        y = self.X[1]
        if self.Y_op == 'and':
            y &= self._Uzy
        elif self.Y_op == 'or':
            y |= self._Uzy
        else: # xor
            y ^= self._Uzy

        return y if self.Y_sign else 1 - y

    def step(self, action):
        if self._t == 0:
            X0 = action
            self.X.append(X0)
            self._t = 1
            return self._obs(), 0.0, False, False, {'Y': []}

        X1 = action
        self.X.append(X1)

        y = self._reward()
        self._t = 2
        return self._obs(), float(y), True, False, {'Y': [y]}

    @property
    def get_graph(self):
        return self.graph

In [5]:
class G2SCM(SCM):
    def __init__(self, graph, u_dist, seed=None):
        super().__init__()
        self.rng = np.random.default_rng(seed)

        self.graph = graph
        self.u_dist = u_dist

        self._U = None
        self.Z = []
        self.X = []
        self._Y = None

        self._U_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Z_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X1_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Y_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))

        self.Y_op = self.rng.choice(['and', 'or', 'xor'])

        self.action_space = spaces.Discrete(2) # binary actions at each step
        self.observation_space = spaces.Dict({
            'Z': spaces.Sequence(spaces.Discrete(2)),
            'X': spaces.Sequence(spaces.Discrete(2))
        })

    def _sample_confounders(self):
        return self.rng.choice([0, 1], p=[1 - self.u_dist, self.u_dist]) if not self._U_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dist, self.u_dist])
    
    def _sample_z(self):
        return self._U if not self.Z_sign else 1 - self._U

    def _obs(self):
        return {'Z': self.Z, 'X': self.X}

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        self._U = self._sample_confounders()
        self.X = []
        self.Z = [self._sample_z()]
        self._t = 0

        return self._obs(), {'Y': []}

    def action(self):
        if self._t == 0:
            return self._U if self.X0_sign else 1 - self._U
        else:
            return self.X[0] if self.X1_sign else 1 - self.X[0]
        
    def _reward(self):
        if not self.Y_sign:
            if self.Y_op == 'and':
                return 1 - (self.X[1] & self._U)
            elif self.Y_op == 'or':
                return 1 - (self.X[1] | self._U)
            else: # xor
                return 1 - (self.X[1] ^ self._U)
        else:
            if self.Y_op == 'and':
                return self.X[1] & self._U
            elif self.Y_op == 'or':
                return self.X[1] | self._U
            else: # xor
                return self.X[1] ^ self._U

    def step(self, action):
        if self._t == 0:
            X0 = action
            self.X.append(X0)
            self._t = 1
            return self._obs(), 0.0, False, False, {'Y': []}

        X1 = action
        self.X.append(X1)

        y = self._reward()
        self._t = 2
        return self._obs(), float(y), True, False, {'Y': [y]}

    @property
    def get_graph(self):
        return self.graph

In [6]:
class G3SCM(SCM):
    def __init__(self, graph, u_dist, seed=None):
        super().__init__()
        self.rng = np.random.default_rng(seed)

        self.graph = graph
        self.u_dist = u_dist

        self._U = None
        self.Z = []
        self.X = []
        self._Y = None

        self._U_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Z_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X1_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Y_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))

        self.Y_op = self.rng.choice(['and', 'or', 'xor'])

        self.action_space = spaces.Discrete(2) # binary actions at each step
        self.observation_space = spaces.Dict({
            'Z': spaces.Sequence(spaces.Discrete(2)),
            'X': spaces.Sequence(spaces.Discrete(2))
        })

    def _sample_confounders(self):
        return self.rng.choice([0, 1], p=[1 - self.u_dist, self.u_dist]) if not self._U_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dist, self.u_dist])
    
    def _sample_z(self):
        return self._U if self.Z_sign else 1 - self._U

    def _obs(self):
        return {'Z': self.Z, 'X': self.X}

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        self._U = self._sample_confounders()
        self.X = []
        self.Z = []
        self._t = 0

        return self._obs(), {'Y': []}

    def action(self):
        if self._t == 0:
            return self._U if self.X0_sign else 1 - self._U
        else:
            return self.X[0] if self.X1_sign else 1 - self.X[0]
        
    def _reward(self):
        if not self.Y_sign:
            if self.Y_op == 'and':
                return 1 - (self.X[1] & self._U)
            elif self.Y_op == 'or':
                return 1 - (self.X[1] | self._U)
            else: # xor
                return 1 - (self.X[1] ^ self._U)
        else:
            if self.Y_op == 'and':
                return self.X[1] & self._U
            elif self.Y_op == 'or':
                return self.X[1] | self._U
            else: # xor
                return self.X[1] ^ self._U

    def step(self, action):
        if self._t == 0:
            X0 = action
            self.X.append(X0)
            self.Z = [self._sample_z()]
            self._t = 1
            return self._obs(), 0.0, False, False, {'Y': []}

        X1 = action
        self.X.append(X1)

        y = self._reward()
        self._t = 2
        return self._obs(), float(y), True, False, {'Y': [y]}

    @property
    def get_graph(self):
        return self.graph

In [7]:
class G4SCM(SCM):
    def __init__(self, graph, u_dists, seed=None):
        super().__init__()
        self.rng = np.random.default_rng(seed)

        self.graph = graph
        self.u_dists = u_dists

        self._Uzx0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self._Uzy_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Z_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X0_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.X1_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))
        self.Y_sign = bool(self.rng.choice([0, 1], p=[0.5, 0.5]))

        self.Z_op = self.rng.choice(['and', 'or', 'xor'])
        self.Y_op_1 = self.rng.choice(['and', 'or', 'xor'])
        self.Y_op_2 = self.rng.choice(['and', 'or', 'xor'])

        self._Uzx0 = None
        self._Uzy = None
        self.Z = []
        self.X = []
        self._Y = None

        self.action_space = spaces.Discrete(2) # binary actions at each step
        self.observation_space = spaces.Dict({
            'Z': spaces.Sequence(spaces.Discrete(2)),
            'X': spaces.Sequence(spaces.Discrete(2))
        })

    def _sample_confounders(self):
        Uzx0 = self.rng.choice([0, 1], p=[1 - self.u_dists[0], self.u_dists[0]]) if not self._Uzx0_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dists[0], self.u_dists[0]])
        Uzy = self.rng.choice([0, 1], p=[1 - self.u_dists[1], self.u_dists[1]]) if not self._Uzy_sign else 1 - self.rng.choice([0, 1], p=[1 - self.u_dists[1], self.u_dists[1]])
        return Uzx0, Uzy
    
    def _sample_z(self):
        z = self._Uzx0

        if self.Z_op == 'and':
            z &= self._Uzy
        elif self.Z_op == 'or':
            z |= self._Uzy
        else: # xor
            z ^= self._Uzy

        return z if self.Z_sign else 1 - z

    def _obs(self):
        return {'Z': self.Z, 'X': self.X}

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        self._Uzx0, self._Uzy = self._sample_confounders()
        self.X = []
        self.Z = []
        self._t = 0

        return self._obs(), {'Y': []}

    def action(self):
        if self._t == 0:
            return self._Uzx0 if self.X0_sign else 1 - self._Uzx0
        else:
            return self.Z[0] if self.X1_sign else 1 - self.Z[0]
        
    def _reward(self):
        y = self.X[0]
        if self.Y_op_1 == 'and':
            y &= self.X[1]
        elif self.Y_op_1 == 'or':
            y |= self.X[1]
        else: # xor
            y ^= self.X[1]

        if self.Y_op_2 == 'and':
            y &= self._Uzy
        elif self.Y_op_2 == 'or':
            y |= self._Uzy
        else: # xor
            y ^= self._Uzy

        return y if self.Y_sign else 1 - y

    def step(self, action):
        if self._t == 0:
            X0 = action
            self.X.append(X0)
            self.Z = [self._sample_z()]
            self._t = 1
            return self._obs(), 0.0, False, False, {'Y': []}

        X1 = action
        self.X.append(X1)

        y = self._reward()
        self._t = 2
        return self._obs(), float(y), True, False, {'Y': [y]}

    @property
    def get_graph(self):
        return self.graph

In [8]:
class GXPCH(PCH):
    def __init__(self, x, graph, u_dists, seed=None):
        if x == 1:
            self.env = G1SCM(graph, u_dists, seed=seed)
        elif x == 2:
            self.env = G2SCM(graph, u_dists[0], seed=seed)
        elif x == 3:
            self.env = G3SCM(graph, u_dists[0], seed=seed)
        elif x == 4:
            self.env = G4SCM(graph, u_dists[:2], seed=seed)

        super().__init__()
        self.last_obs = None

    @property
    def get_graph(self):
        return self.env.get_graph

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        self.last_obs = obs
        return obs, info

    def see(self, behavioral_policy=None, show_reward=False):
        if self.last_obs is None:
            self.last_obs, _ = self.reset()

        if behavioral_policy is None:
            action = self.env.action()
        else:
            action = behavioral_policy(dict(self.last_obs))

        obs, reward, terminated, truncated, info = self.env.step(action)
        self.last_obs = obs
        info['natural_action'] = action
        return obs, reward, terminated, truncated, info

    def do(self, do_policy, show_reward=False):
        if self.last_obs is None:
            self.last_obs, _ = self.env.reset()

        action = do_policy(self.last_obs)
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.last_obs = obs
        info['action'] = action
        return obs, reward, terminated, truncated, info

### BC Method Suite Setup

In [9]:
# Causal BC
# G1, G2, and G3 are imitable; G4 is not
g1_Z_sets = {'X0': set(), 'X1': {'X0'}} # X1 = set() can also work here
g2_Z_sets = {'X0': {'Z0'}, 'X1': {'Z0'}}
g3_Z_sets = {'X0': set(), 'X1': {'Z0'}}

In [10]:
# BC - Observed Parents
g1_obs_parents_sets = {'X0': set(), 'X1': {'X0'}}
g2_obs_parents_sets = {'X0': set(), 'X1': {'X0'}}
g3_obs_parents_sets = {'X0': set(), 'X1': {'X0'}}
g4_obs_parents_sets = {'X0': set(), 'X1': {'Z0'}}

In [11]:
# BC - All Observed
g1_all_obs_sets = {'X0': {'Z0'}, 'X1': {'Z0', 'X0'}}
g2_all_obs_sets = {'X0': {'Z0'}, 'X1': {'Z0', 'X0'}}
g3_all_obs_sets = {'X0': set(), 'X1': {'Z0', 'X0'}}
g4_all_obs_sets = {'X0': set(), 'X1': {'Z0', 'X0'}}

### SCM Generation and BC Execution

In [12]:
def evaluate(graph_number, graph, Z_sets, num_scms=100, num_trajs=100, seed=None):
    gaps = []

    for k in range(num_scms):
        s = seed + k if seed is not None else None
        rng = np.random.default_rng(s)

        u_dists = [rng.uniform(0.0, 1.0) for _ in range(3)]
        env = GXPCH(graph_number, graph, u_dists, seed=s)

        # measure expert performance
        records = collect_expert_trajectories(env, num_episodes=num_trajs, max_steps=2, seed=s, show_progress=False)
        expert_EY = np.mean([r['reward'] for r in records if r['terminated']])

        policy = train_policies(env, records, Z_sets, max_epochs=100, seed=s)

        # measure imitator performance
        rollout = eval_policy(env, policy, num_episodes=num_trajs, seed=s)
        rewards = [ep['Y'][-1] for ep in rollout]
        imitator_EY = np.mean(rewards)

        gaps.append(abs(imitator_EY - expert_EY))

    print('finished', graph_number, Z_sets)
    return np.mean(gaps), np.std(gaps)

In [13]:
causal = {
    'G1': evaluate(1, G1, g1_Z_sets),
    'G2': evaluate(2, G2, g2_Z_sets),
    'G3': evaluate(3, G3, g3_Z_sets)
}

obs_parents = {
    'G1': evaluate(1, G1, g1_obs_parents_sets),
    'G2': evaluate(2, G2, g2_obs_parents_sets),
    'G3': evaluate(3, G3, g3_obs_parents_sets),
    'G4': evaluate(4, G4, g4_obs_parents_sets)
}

all_obs = {
    'G1': evaluate(1, G1, g1_all_obs_sets),
    'G2': evaluate(2, G2, g2_all_obs_sets),
    'G3': evaluate(3, G3, g3_all_obs_sets),
    'G4': evaluate(4, G4, g4_all_obs_sets)
}

finished 1 {'X0': set(), 'X1': {'X0'}}
finished 2 {'X0': {'Z0'}, 'X1': {'Z0'}}
finished 3 {'X0': set(), 'X1': {'Z0'}}
finished 1 {'X0': set(), 'X1': {'X0'}}
finished 2 {'X0': set(), 'X1': {'X0'}}
finished 3 {'X0': set(), 'X1': {'X0'}}
finished 4 {'X0': set(), 'X1': {'Z0'}}
finished 1 {'X0': {'Z0'}, 'X1': {'Z0', 'X0'}}
finished 2 {'X0': {'Z0'}, 'X1': {'Z0', 'X0'}}
finished 3 {'X0': set(), 'X1': {'Z0', 'X0'}}
finished 4 {'X0': set(), 'X1': {'Z0', 'X0'}}


### Results

In [14]:
def format(gap):
    if gap == 'Not Imitable':
        return gap
    
    mean, std = gap
    return f'{mean:.2f} ± {std:.2f}'

table_data = []
for graph in ['G1', 'G2', 'G3', 'G4']:
    row = {
        'Graph': graph,
        'Causal BC': format(causal.get(graph, 'Not Imitable')),
        'Obs Parents': format(obs_parents.get(graph, 'Not Imitable')),
        'All Observed': format(all_obs.get(graph, 'Not Imitable'))
    }
    table_data.append(row)

df = pd.DataFrame(table_data)
df = df[['Graph', 'Causal BC', 'Obs Parents', 'All Observed']]
df

Unnamed: 0,Graph,Causal BC,Obs Parents,All Observed
0,G1,0.09 ± 0.10,0.07 ± 0.07,0.11 ± 0.11
1,G2,0.02 ± 0.03,0.21 ± 0.14,0.01 ± 0.03
2,G3,0.01 ± 0.03,0.23 ± 0.15,0.12 ± 0.13
3,G4,Not Imitable,0.11 ± 0.11,0.10 ± 0.11
