# What's a causal model?

Neural networks are complex and interesting systems, yet we have full control over their internals! Interpretability seeks to understand how manipulating a model's internals affects its output. The theory of causality, which formalizes cause and effect using causal graphs, is fundamental to the field of interpretability.

But what exactly is a causal model? In this tutorial, we walk through a simple causal model with the `CausalAbstraction` library. Along the way, we'll discuss:
* The components of causal graphs, and
* Operations we can perform on causal graphs.

## Let's set up a causal model

Imagine we trained a neural network to add three numbers, A, B, and C, together. How might it solve this problem?

We can formalize our **hypothesis** as a **causal graph**. The code below sets up a causal graph for an algorithm that adds the three numbers left to right. The causal model first adds A and B together, and then adds C to their sum.

In [1]:
# we use the CausalAbstraction library to set up a causal model
from causalab.causal.causal_model import CausalModel
from causalab.causal.trace import Mechanism, input_var
from causalab.causal.causal_viz import (
    display_structure,
    display_forward_pass,
    display_interchange,
)

DIGITS = list(range(0, 10))

######################
# 1. values          #
######################
# what values can our variables take on?
# in this case, let's imagine modular arithmetic, so all
# of our variables (except the raw input/output strings)
# can only take on a value from 0 to 9.
values = {
    "A": DIGITS,
    "B": DIGITS,
    "C": DIGITS,
    "A+B": DIGITS,
    "C'": DIGITS,
    "Y": DIGITS,
    "raw_input": None,
    "raw_output": None,
}

######################
# 2. mechanisms      #
######################
# HOW do the variables affect each other?
# A, B, C : input variables (no parents)
# A+B : computes A + B mod 10
# C' : copies over the value of C
# Y : computes A+B + C' mod 10 (NOTE: we're using the intermediate variables!)
# raw_input : represents the inputs A, B, C as a list
# raw_output : represents Y as an int
mechanisms = {
    # Input variables (no parents)
    "A": input_var(DIGITS),
    "B": input_var(DIGITS),
    "C": input_var(DIGITS),
    # Intermediate variables
    "A+B": Mechanism(parents=["A", "B"], compute=lambda t: (t["A"] + t["B"]) % 10),
    "C'": Mechanism(parents=["C"], compute=lambda t: t["C"]),
    "Y": Mechanism(parents=["A+B", "C'"], compute=lambda t: (t["A+B"] + t["C'"]) % 10),
    # Raw input/output
    "raw_input": Mechanism(
        parents=["A", "B", "C"], compute=lambda t: [t["A"], t["B"], t["C"]]
    ),
    "raw_output": Mechanism(parents=["Y"], compute=lambda t: t["Y"]),
}

# put it all together!
causal_model = CausalModel(mechanisms, values, id="Hierarchical addition")

Let's visualize our causal graph. Each arrow points from a variable to a variable it directly affects. If there is a directed path from variable $\alpha$ to variable $\beta$, then $\alpha$ _can have a causal effect on_ $\beta$. That is, changing the value of $\alpha$ may change the value of $\beta$.

In [2]:
display_structure(causal_model)
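To make the "path implies a possible effect" idea concrete, here is a minimal plain-Python sketch. It is not part of the `CausalAbstraction` library: `PARENTS` and `descendants` are hypothetical names, and the parent lists are copied by hand from the graph above. The function collects every variable reachable from a given variable by following arrows forward; those are exactly the variables an intervention on it could possibly change.

```python
# Hypothetical parent lists, copied by hand from the hierarchical-addition graph.
PARENTS = {
    "A": [], "B": [], "C": [],
    "A+B": ["A", "B"],
    "C'": ["C"],
    "Y": ["A+B", "C'"],
    "raw_input": ["A", "B", "C"],
    "raw_output": ["Y"],
}


def descendants(var, parents=PARENTS):
    """Return every variable reachable from `var` by following arrows forward."""
    # Invert the parent lists into child lists.
    children = {v: [] for v in parents}
    for child, ps in parents.items():
        for p in ps:
            children[p].append(child)
    # Depth-first search from `var`.
    seen, stack = set(), [var]
    while stack:
        for child in children[stack.pop()]:
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return seen


print(sorted(descendants("C")))  # ["C'", 'Y', 'raw_input', 'raw_output']
```

Note that `A+B` is not a descendant of `C`, so no change to `C` can ever affect it.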

We can pass inputs into our causal model and **run it "forward"** to simulate the algorithm's final output, as well as its intermediate values.

In our visualization, we represent inputs as dark teal nodes, and the intermediate/output values as light teal nodes.

In [3]:
# play around with different inputs!
inputs = {"A": 1, "B": 2, "C": 3}

display_forward_pass(causal_model, inputs)
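Conceptually, a forward pass just evaluates each variable's mechanism after all of its parents are known. The sketch below is a plain-Python illustration under that assumption, not the library's implementation; `MECHANISMS`, `ORDER`, and `run_forward` are hypothetical names, and the `raw_input`/`raw_output` variables are left out for brevity.

```python
# Hypothetical plain-Python stand-in for the hierarchical-addition model:
# each non-input variable maps to (parent names, function of the parents' values).
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
# A topological order: every variable comes after all of its parents.
ORDER = ["A+B", "C'", "Y"]


def run_forward(inputs):
    """Fill in every variable, parents first, starting from the input values."""
    values = dict(inputs)
    for var in ORDER:
        parents, compute = MECHANISMS[var]
        values[var] = compute(*(values[p] for p in parents))
    return values


trace = run_forward({"A": 1, "B": 2, "C": 3})
print(trace["A+B"], trace["C'"], trace["Y"])  # 3 3 6
```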

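We can represent any causal process as a causal graph, including a neural network! In the example below, each unit of a small three-layer network is a variable whose parents are the units of the previous layer (we round the activations to keep the values readable).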

In [4]:
import torch

l1 = torch.nn.Linear(3, 3)
l2 = torch.nn.Linear(3, 3)
l3 = torch.nn.Linear(3, 1)

variables = [
    "A",
    "B",
    "C",
    *[f"h1{i + 1}" for i in range(3)],
    *[f"h2{i + 1}" for i in range(3)],
    "h31",
    "raw_input",
    "raw_output",
]

values = {var: None for var in variables}

# Each unit's mechanism runs its layer on the previous layer's values and reads off its own (rounded) output
mechanisms = {
    # Input variables
    "A": input_var(DIGITS),
    "B": input_var(DIGITS),
    "C": input_var(DIGITS),
    # First hidden layer (depends on A, B, C)
    "h11": Mechanism(
        parents=["A", "B", "C"],
        compute=lambda t: l1(torch.FloatTensor([t["A"], t["B"], t["C"]]))
        .round()[0]
        .item(),
    ),
    "h12": Mechanism(
        parents=["A", "B", "C"],
        compute=lambda t: l1(torch.FloatTensor([t["A"], t["B"], t["C"]]))
        .round()[1]
        .item(),
    ),
    "h13": Mechanism(
        parents=["A", "B", "C"],
        compute=lambda t: l1(torch.FloatTensor([t["A"], t["B"], t["C"]]))
        .round()[2]
        .item(),
    ),
    # Second hidden layer (depends on h11, h12, h13)
    "h21": Mechanism(
        parents=["h11", "h12", "h13"],
        compute=lambda t: l2(torch.FloatTensor([t["h11"], t["h12"], t["h13"]]))
        .round()[0]
        .item(),
    ),
    "h22": Mechanism(
        parents=["h11", "h12", "h13"],
        compute=lambda t: l2(torch.FloatTensor([t["h11"], t["h12"], t["h13"]]))
        .round()[1]
        .item(),
    ),
    "h23": Mechanism(
        parents=["h11", "h12", "h13"],
        compute=lambda t: l2(torch.FloatTensor([t["h11"], t["h12"], t["h13"]]))
        .round()[2]
        .item(),
    ),
    # Output layer
    "h31": Mechanism(
        parents=["h21", "h22", "h23"],
        compute=lambda t: l3(torch.FloatTensor([t["h21"], t["h22"], t["h23"]]))
        .round()
        .item(),
    ),
    # Raw input/output
    "raw_input": Mechanism(
        parents=["A", "B", "C"], compute=lambda t: [t["A"], t["B"], t["C"]]
    ),
    "raw_output": Mechanism(parents=["h31"], compute=lambda t: t["h31"]),
}

neural_network = CausalModel(mechanisms, values, id="Neural network")

In [5]:
display_structure(neural_network)

Keep this in mind: we can apply the same operations to neural networks that we apply to simple causal models! For now, let's see what we can actually do with a causal graph.

## Operations on causal graphs

Now that we know what a causal model looks like, let's try playing around with it. The key operation we can perform on a causal graph is an **intervention**. When we perform an intervention, we **change the value of a variable**, overriding the mechanism that would normally compute it. This can have downstream effects on the rest of the model's computation.

### Intervening on a causal graph

For example, what if the model thought that 1 + 2 was actually 4 instead of 3?

In [6]:
inputs = {"A": 1, "B": 2, "C": 3}
# let's make the model think that 1 + 2 is 4!
intervention = {"A+B": 4}

# we "run the model forward" again,
# but this time with an intervention changing A + B to 4
display_forward_pass(causal_model, inputs, intervention=intervention)

As you might have predicted, the causal model's new answer is 7 instead of 6: with `A+B` fixed to 4 and `C'` still equal to 3, `Y` = 4 + 3 = 7.

*Visualization note: We color the intervention in dark magenta and the values it affects in either light magenta (if the intervention directly affected the value) or violet (if the value was caused by a mix of intervened and base values).*
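In code, an intervention amounts to skipping a variable's mechanism and using the fixed value instead, while every other variable is still computed from its parents. Here is a minimal plain-Python sketch of that idea (hypothetical helper names, not the `CausalAbstraction` API; `raw_input`/`raw_output` omitted):

```python
# Hypothetical plain-Python stand-in for the hierarchical-addition model.
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
ORDER = ["A+B", "C'", "Y"]  # parents always come first


def run_forward(inputs, intervention=None):
    """Forward pass in which intervened variables are fixed instead of computed."""
    intervention = intervention or {}
    # Input variables can be intervened on too.
    values = {var: intervention.get(var, val) for var, val in inputs.items()}
    for var in ORDER:
        if var in intervention:
            values[var] = intervention[var]  # ignore this variable's mechanism
        else:
            parents, compute = MECHANISMS[var]
            values[var] = compute(*(values[p] for p in parents))
    return values


base = {"A": 1, "B": 2, "C": 3}
print(run_forward(base)["Y"])  # 6
print(run_forward(base, {"A+B": 4})["Y"])  # 7
```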

We can choose any variable to intervene on! For example, we can intervene directly on the output.

In [7]:
inputs = {"A": 1, "B": 2, "C": 3}
# let's make the model think that the answer is 9
intervention = {"Y": 9}

display_forward_pass(causal_model, inputs, intervention=intervention)

*Visualization note: since `raw_output` depends only on the value of the intervened variable (and not on the base inputs), we color it in light magenta instead of violet.*

We can also intervene on parts of the input. For example, let's change just the value of the variable `C` in the input.

In [8]:
inputs = {"A": 1, "B": 2, "C": 3}
# let's replace the input C with 9
intervention = {"C": 9}

display_forward_pass(causal_model, inputs, intervention=intervention)

Of course, we could have predicted the model's output by just plugging in 9 for C. But the causal model still plays an important role here: it also predicts the value of every intermediate variable!
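We can check the claim about input interventions directly. In this plain-Python sketch (hypothetical helpers, not the library's API), intervening on the input `C` produces exactly the same trace, intermediate values included, as running the model on new inputs with `C = 9`:

```python
# Hypothetical plain-Python stand-in for the hierarchical-addition model.
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
ORDER = ["A+B", "C'", "Y"]  # parents always come first


def run_forward(inputs, intervention=None):
    """Forward pass in which intervened variables are fixed instead of computed."""
    intervention = intervention or {}
    values = {var: intervention.get(var, val) for var, val in inputs.items()}
    for var in ORDER:
        if var in intervention:
            values[var] = intervention[var]
        else:
            parents, compute = MECHANISMS[var]
            values[var] = compute(*(values[p] for p in parents))
    return values


base = {"A": 1, "B": 2, "C": 3}
intervened = run_forward(base, {"C": 9})  # intervene on the input C
replugged = run_forward({"A": 1, "B": 2, "C": 9})  # just change the input
print(intervened == replugged)  # True
print(intervened["C'"], intervened["Y"])  # 9 2
```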

We can also intervene on multiple variables at the same time. For example, let's see what happens when we set `A + B` to 8 and `C'` to 4.

In [9]:
inputs = {"A": 1, "B": 2, "C": 3}
# let's intervene on multiple intermediate variables at once
intervention = {"A+B": 8, "C'": 4}

display_forward_pass(causal_model, inputs, intervention=intervention)

Here, the interventions completely override the effect of the inputs on the output! Nevertheless, the causal model still predicts what the final output will be given our interventions: (8 + 4) mod 10 = 2.
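One way to see that the inputs no longer matter is to sweep over all 1,000 possible inputs under the same intervention. The sketch below uses the same hypothetical plain-Python helpers as before (not the library's API) and collects every distinct value of `Y`:

```python
from itertools import product

# Hypothetical plain-Python stand-in for the hierarchical-addition model.
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
ORDER = ["A+B", "C'", "Y"]  # parents always come first


def run_forward(inputs, intervention=None):
    """Forward pass in which intervened variables are fixed instead of computed."""
    intervention = intervention or {}
    values = {var: intervention.get(var, val) for var, val in inputs.items()}
    for var in ORDER:
        if var in intervention:
            values[var] = intervention[var]
        else:
            parents, compute = MECHANISMS[var]
            values[var] = compute(*(values[p] for p in parents))
    return values


intervention = {"A+B": 8, "C'": 4}
# Every parent of Y is intervened on, so Y no longer depends on A, B, or C.
outputs = {
    run_forward({"A": a, "B": b, "C": c}, intervention)["Y"]
    for a, b, c in product(range(10), repeat=3)
}
print(outputs)  # {2}
```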

### Interchange interventions

Interventions are powerful when we know the intermediate values of our causal model. However, in the case of complex systems such as neural networks, the intermediate values might not be predictable or interpretable. For example, it's hard to predict a reasonable value for an MLP activation.

A very powerful operation we can perform on causal models is an **interchange intervention**. In an interchange intervention, we **don't specify intermediate *values*** ourselves. Instead, for each variable we want to change, we choose a **separate set of inputs** (a *source* input), run the causal model on it, and let the model **infer** the value that variable takes. We then plug that value into the forward pass on our original (*base*) inputs.

For example, let's make the model think that `A + B` is 4 again. This time, we won't specify 4 up front. Instead, we'll **get the value for `A + B`** from a **separate forward pass** on the inputs `A = 3` and `B = 1`.

In [10]:
inputs = {"A": 1, "B": 2, "C": 3}

# let's change the value of A + B to be 3 + 1!
counterfactual_inputs = {
    "A+B": {"A": 3, "B": 1, "C": 3},
}

display_interchange(causal_model, inputs, counterfactual_inputs)

The result is the same as intervening on `A + B` directly and setting it to 4. But this time, we didn't need to "magically" come up with the number 4; instead, it came from the model's own computation on a separate set of inputs. By **interchanging** the value of a variable between two separate forward passes of our causal model, we can determine the model's **counterfactual behavior**: what it would output if that variable took on the value it has under a different input.
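An interchange intervention can be built from two ordinary operations: a forward pass on the source input to read off the variable's value, then an intervention on the base input using that value. Here is a minimal plain-Python sketch under that reading (the `interchange` helper and the other names are hypothetical, not the `CausalAbstraction` API):

```python
# Hypothetical plain-Python stand-in for the hierarchical-addition model.
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
ORDER = ["A+B", "C'", "Y"]  # parents always come first


def run_forward(inputs, intervention=None):
    """Forward pass in which intervened variables are fixed instead of computed."""
    intervention = intervention or {}
    values = {var: intervention.get(var, val) for var, val in inputs.items()}
    for var in ORDER:
        if var in intervention:
            values[var] = intervention[var]
        else:
            parents, compute = MECHANISMS[var]
            values[var] = compute(*(values[p] for p in parents))
    return values


def interchange(base, sources):
    """Set each listed variable to the value it takes on its own source input,
    then run the base input forward with those values fixed."""
    intervention = {var: run_forward(src)[var] for var, src in sources.items()}
    return run_forward(base, intervention)


base = {"A": 1, "B": 2, "C": 3}
result = interchange(base, {"A+B": {"A": 3, "B": 1, "C": 3}})
print(result["A+B"], result["Y"])  # 4 7
print(result == run_forward(base, {"A+B": 4}))  # True
```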

*Side note: just like with interventions, we can perform an interchange intervention on multiple variables! For each variable, we can specify a different source input that determines its value.*

In [11]:
inputs = {"A": 1, "B": 2, "C": 3}

# let's change the value of A + B and C' using different sources
counterfactual_inputs = {
    "A+B": {"A": 3, "B": 1, "C": 3},  # make A + B = 3 + 1
    "C'": {"A": 1, "B": 2, "C": 8},  # make C' = 8
}

display_interchange(causal_model, inputs, counterfactual_inputs)
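With several variables, each source input gets its own independent forward pass, so only the value of its own variable is used. The plain-Python sketch below (hypothetical helpers, not the library's API) shows this: changing `A` and `B` in the source for `C'` has no effect on the result.

```python
# Hypothetical plain-Python stand-in for the hierarchical-addition model.
MECHANISMS = {
    "A+B": (["A", "B"], lambda a, b: (a + b) % 10),
    "C'": (["C"], lambda c: c),
    "Y": (["A+B", "C'"], lambda ab, c_copy: (ab + c_copy) % 10),
}
ORDER = ["A+B", "C'", "Y"]  # parents always come first


def run_forward(inputs, intervention=None):
    """Forward pass in which intervened variables are fixed instead of computed."""
    intervention = intervention or {}
    values = {var: intervention.get(var, val) for var, val in inputs.items()}
    for var in ORDER:
        if var in intervention:
            values[var] = intervention[var]
        else:
            parents, compute = MECHANISMS[var]
            values[var] = compute(*(values[p] for p in parents))
    return values


def interchange(base, sources):
    """Read each variable's value from a separate run on its own source input."""
    intervention = {var: run_forward(src)[var] for var, src in sources.items()}
    return run_forward(base, intervention)


base = {"A": 1, "B": 2, "C": 3}
sources = {
    "A+B": {"A": 3, "B": 1, "C": 3},  # A + B = 3 + 1 = 4
    "C'": {"A": 1, "B": 2, "C": 8},  # C' = 8
}
result = interchange(base, sources)
print(result["A+B"], result["C'"], result["Y"])  # 4 8 2

# Only C matters in the source for C', so changing its A and B does nothing.
other = interchange(base, {**sources, "C'": {"A": 7, "B": 0, "C": 8}})
print(other == result)  # True
```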