# Experiment 9.11 - Causal GAIL

In [None]:
from gymnasium import spaces
from causal_gym import Graph, SCM, PCH
from imitation import *
from imitation.gym_gail.core_net import DiscreteActor, Critic, Discriminator
from imitation.gym_gail.causal_gail import *

In [None]:
# Causal diagram for the experiment: five named nodes plus a list of
# directed / bidirected edges, handed to causal_gym's Graph constructor.
nodes = [{'name': label} for label in ('Z0', 'X0', 'Z1', 'X1', 'Y')]

# Each triple is (source, target, edge type); expanded into the keyword-style
# dicts that Graph is given below.
edges = [
    {'from_': src, 'to_': dst, 'type_': kind}
    for src, dst, kind in (
        ('Z0', 'Z1', 'bidirected'),
        ('Z0', 'Y', 'bidirected'),
        ('Z1', 'X1', 'bidirected'),
        ('Z1', 'X1', 'directed'),
        ('X0', 'Y', 'directed'),
        ('X1', 'Y', 'directed'),
        ('Z1', 'Y', 'directed'),
    )
]

G = Graph(nodes=nodes, edges=edges)

In [None]:
class E911SCM(SCM):
    """Two-step binary SCM environment for Experiment 9.11.

    On ``reset`` four hidden binary variables ``U1..U4`` are sampled and two
    observed binary covariates ``Z0, Z1`` are derived from them by XOR.  The
    caller then supplies two binary actions, ``X0`` and ``X1``, through two
    consecutive ``step`` calls.  The first call yields reward ``0.0``; the
    second call yields the reward ``Y`` and marks the episode as terminated.

    ``step`` returns the 5-tuple
    ``(observation, reward, terminated, truncated, info)`` and ``reset``
    returns ``(observation, info)``.

    NOTE: calling ``step`` again after the episode has terminated is not
    guarded against; call ``reset`` first.
    """

    def __init__(self, graph, seed=None):
        """Create the environment.

        Args:
            graph: Causal graph object associated with this SCM; stored as-is
                and exposed through ``get_graph``.
            seed: Optional seed for the internal NumPy random generator.
        """
        super().__init__()
        self.rng = np.random.default_rng(seed)

        self.graph = graph

        self._U = []  # hidden variables U1, U2, U3, U4 (filled by reset)
        self.Z = []   # observed covariates Z0, Z1 (filled by reset)
        self.X = []   # actions taken so far in the current episode
        self._Y = []  # unused in this class; original note: "5-element vector w/ xors"

        # Step counter within the episode.  Initialised here so that
        # ``action``/``step`` do not fail with AttributeError when they are
        # reached before the first ``reset``.
        self._t = 0

        self.action_space = spaces.Discrete(2)  # binary actions at each step
        self.observation_space = spaces.Dict({
            'Z': spaces.Sequence(spaces.Discrete(2)),
            'X': spaces.Sequence(spaces.Discrete(2))
        })

    def _sample_confounders(self):
        """Draw the four hidden binary variables ``[U1, U2, U3, U4]``.

        P(U1=1)=0.8, P(U2=1)=0.8, P(U3=1)=0.2, P(U4=1)=0.1.
        """
        U1 = self.rng.choice([0, 1], p=[0.2, 0.8])
        U2 = self.rng.choice([0, 1], p=[0.2, 0.8])
        U3 = self.rng.choice([0, 1], p=[0.8, 0.2])
        U4 = self.rng.choice([0, 1], p=[0.9, 0.1])
        return [U1, U2, U3, U4]

    def _sample_z(self):
        """Return ``[Z0, Z1]`` with ``Z0 = U1^U3`` and ``Z1 = U1^U2^U4``."""
        return [self._U[0] ^ self._U[2], self._U[0] ^ self._U[1] ^ self._U[3]]

    def _obs(self):
        """Return the current observation as ``{'Z': [...], 'X': [...]}``.

        The lists are copies: ``step`` appends to ``self.X`` in place, so
        handing out the internal lists would retroactively change
        observations the caller has already stored.
        """
        return {'Z': list(self.Z), 'X': list(self.X)}

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: If given, the internal random generator is re-created from
                this seed before sampling.
            options: Accepted for interface compatibility; not used.

        Returns:
            ``(observation, info)`` where ``info`` is ``{'Y': []}``.
        """
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        self._U = self._sample_confounders()
        self.X = []
        self.Z = self._sample_z()
        self._t = 0

        return self._obs(), {'Y': []}

    def action(self):
        """Return an action from the environment's built-in policy.

        At the first step the action is 1 with probability 0.68.  At any
        later step it is ``U2 ^ Z1``, i.e. it depends on a hidden variable.
        """
        if self._t == 0:
            return self.rng.choice([0, 1], p=[0.32, 0.68])
        else:
            return self._U[1] ^ self.Z[1]

    def _reward(self):
        """Return ``Y = X0 ^ X1 ^ Z0 ^ Z1 ^ U4`` for the current episode."""
        return self.X[0] ^ self.X[1] ^ self.Z[0] ^ self.Z[1] ^ self._U[3]

    def step(self, action):
        """Record ``action`` and advance the episode by one step.

        Args:
            action: Binary action; stored as ``X0`` on the first call after
                ``reset`` and as ``X1`` on the second.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.  The
            first call returns reward ``0.0`` and ``info={'Y': []}``; the
            second returns ``float(Y)``, ``terminated=True`` and
            ``info={'Y': [Y]}``.
        """
        if self._t == 0:
            X0 = action
            self.X.append(X0)
            self._t = 1
            return self._obs(), 0.0, False, False, {'Y': []}

        X1 = action
        self.X.append(X1)

        y = self._reward()
        self._t = 2
        return self._obs(), float(y), True, False, {'Y': [y]}

    @property
    def get_graph(self):
        """The graph object passed to the constructor (read as an attribute)."""
        return self.graph