# Interchange Intervention Training: Equality learning tasks

In [1]:
__author__ = "Atticus Geiger"
__version__ = "CS224u, Stanford, Spring 2022"

## Contents

1. [Overview](#Overview)
1. [Set-up](#Set-up)
1. [The hierarchical equality task](#The-hierarchical-equality-task)
1. [The high-level causal model](#The-high-level-causal-model)
    1. [The algorithm with no intervention](#The-algorithm-with-no-intervention)
    1. [The algorithm with an intervention](#The-algorithm-with-an-intervention)
    1. [The algorithm with an interchange intervention](#The-algorithm-with-an-interchange-intervention)
1. [A fully-connected feed-forward neural network](#A-fully-connected-feed-forward-neural-network)
    1. [Basic intervention: zeroing out part of a hidden layer](#Basic-intervention:-zeroing-out-part-of-a-hidden-layer)
    1. [An interchange intervention](#An-interchange-intervention)
1. [Causal abstraction](#Causal-abstraction)
    1. [Alignment](#Alignment)
    1. [Interchange intervention](#Interchange-intervention)
    1. [Evaluation](#Evaluation)
1. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))
    1. [IIT on variable V1](#IIT-on-variable-V1)
    1. [IIT on variables V1 and V2](#IIT-on-variables-V1-and-V2)

## Overview

This notebook is a hands-on introduction to __causal abstraction analysis__ and __interchange intervention training__ with neural networks.

In causal abstraction analysis, we assess whether trained models conform to high-level causal models that we specify, not just in terms of their input–output behavior, but also in terms of their internal dynamics. 

The core technique is the __interchange intervention__, in which we actively manipulate internal states in the high-level causal model and in the neural network to see whether the two models show the same behavior in these counterfactual states.

In interchange intervention training, we go beyond passive analysis by actively training networks to conform to the high-level causal model.

To motivate and illustrate these concepts, we're going to focus on a challenging hierarchical equality task, building on work by [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968).

## Set-up

In [4]:
import torch
import random
import copy
import itertools
import numpy as np
from sklearn.metrics import classification_report
import utils
import iit  # the `iit` module supplies the equality dataset helpers used below
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality

In [5]:
utils.fix_random_seeds()

## The hierarchical equality task

This section builds on results presented in [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968). We will use a hierarchical equality task ([Premack 1983](https://www.cambridge.org/core/services/aop-cambridge-core/content/view/7DF6F2D22838F7546AF7279679F3571D/S0140525X00015077a.pdf/div-class-title-the-codes-of-man-and-beasts-div.pdf)) to present interchange intervention training (IIT). 

We define the hierarchical equality task as follows: The input is two pairs of objects, and the output is **True** if both pairs contain two identical objects or both pairs contain two different objects, and **False** otherwise. For example, `AABB` and `ABCD` are both labeled **True**, while `ABCC` and `BBCD` are both labeled **False**.
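The labeling rule can be sketched directly on string examples like the ones above. This is a minimal illustration rather than part of the course code; the function name `hierarchical_equality_label` is hypothetical.

```python
def hierarchical_equality_label(example):
    """Return True if both pairs agree on whether their members match.

    `example` is any sequence of four comparable objects, e.g. "AABB".
    """
    first, second, third, fourth = example
    left_same = first == second     # Is the first pair an identical pair?
    right_same = third == fourth    # Is the second pair an identical pair?
    return left_same == right_same


for ex in ("AABB", "ABCD", "ABCC", "BBCD"):
    print(ex, hierarchical_equality_label(ex))
```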

## The high-level causal model

Let $\mathcal{A}$ be the simple tree-structured algorithm that solves this task by applying a simple equality relation three times: Compute whether the first two inputs are equal, compute whether the second two inputs are equal, then compute whether the truth-valued outputs of these first two computations are equal. Here's a visual depiction of the algorithm:

<img src="fig/IIT/PremackFunctions.png" width="500"/>
<img src="fig/IIT/PremackGraph.png" width="500"/>

And here's a Python implementation of $\mathcal{A}$ that supports the interventions we'll want to do:

In [6]:
def compute_A(ex, intervention):
    graph = {}
    for i, obj in enumerate(ex):
        graph["input" + str(i+1)] = obj
    if "V1" in intervention:
        graph["V1"] = intervention["V1"]
    else:
        graph["V1"] = graph["input1"] == graph["input2"]
    if "V2" in intervention:
        graph["V2"] = intervention["V2"]
    else:
        graph["V2"] = graph["input3"] == graph["input4"]
    graph["output"] = graph["V1"] == graph["V2"]
    return graph

### The algorithm with no intervention

Let's first observe the behavior of the algorithm when we provide the input **(pentagon, pentagon, triangle, square)** with no interventions. Here is a visual depiction:

<img src="fig/IIT/PremackNoIntervention.png" width="500"/>

And here is the computation using `compute_A`:

In [7]:
compute_A(
    ("pentagon", "pentagon", "triangle", "square"), 
    intervention={})

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': True,
 'V2': False,
 'output': False}

### The algorithm with an intervention

Let's now see the behavior of the algorithm when we provide the input **(square, pentagon, triangle, triangle)** with an intervention setting **V1** to **True**. First, a visual depiction:

<img src="fig/IIT/PremackIntervention.png" width="500"/>

And then the same computation with `compute_A`:

In [8]:
compute_A(
    ("square", "pentagon", "triangle", "triangle"), 
    intervention={"V1": True})

{'input1': 'square',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'triangle',
 'V1': True,
 'V2': True,
 'output': True}

Notice that, in this example, even though the left two inputs are not the same, the intervention has changed the intermediate prediction for those two inputs from **False** to **True**, and thus the algorithm outputs **True**, since its output is determined by **V1** and **V2**.

### The algorithm with an interchange intervention

Finally, let's observe the behavior of the algorithm when we provide the base input **(pentagon, pentagon, triangle, square)** with an intervention setting **V1** to the value it would have for the source input **(square, pentagon, triangle, triangle)**. Here's a diagram in which the dashed line indicates the interchange intervention:

<img src="fig/IIT/algorithmII.png" width="600"/>

And here is the corresponding interchange intervention in code:

In [9]:
def compute_interchange_A(base, source, variable):
    # Run the algorithm on `source`:
    src_output = compute_A(source, intervention={})
    # Get the source value for `variable`:
    val = src_output[variable]
    # Process `base` with the intervention setting `variable`
    # to the value it had in `source`:        
    return compute_A(base, intervention={variable: val})

In [10]:
compute_interchange_A(
    base=("pentagon", "pentagon", "triangle", "square"),    # base: T F ==> F
    source=("square", "pentagon", "triangle", "triangle"),  # source: F T ==> F
    variable="V1") # Will set base V1 to be source V1, leading to F F ==> T

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': False,
 'V2': False,
 'output': True}
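To see the general pattern behind this one example, the sketch below enumerates every base and source pair over a small vocabulary and checks that an interchange intervention on **V1** always yields `V1(source) == V2(base)`. The compact `algorithm` and `interchange` functions re-implement the cells above so that the block runs on its own; those names and the three-word vocabulary are assumptions made for illustration.

```python
import itertools

def algorithm(ex, intervention=None):
    """Compact re-implementation of the tree-structured algorithm."""
    intervention = intervention or {}
    v1 = intervention.get("V1", ex[0] == ex[1])
    v2 = intervention.get("V2", ex[2] == ex[3])
    return {"V1": v1, "V2": v2, "output": v1 == v2}

def interchange(base, source, variable):
    """Set `variable` in the base run to its value in the source run."""
    source_value = algorithm(source)[variable]
    return algorithm(base, {variable: source_value})

vocab = ["square", "pentagon", "triangle"]  # arbitrary toy vocabulary
examples = list(itertools.product(vocab, repeat=4))

changed = 0
for base, source in itertools.product(examples, repeat=2):
    result = interchange(base, source, "V1")
    # The counterfactual output depends only on the source's V1 and the base's V2:
    assert result["output"] == (algorithm(source)["V1"] == algorithm(base)["V2"])
    # Count the interchanges that flip the base example's original output:
    changed += result["output"] != algorithm(base)["output"]

print(f"{changed} of {len(examples) ** 2} interchanges change the base output")
```

The output flips exactly when the base and source disagree on **V1**, which is why interchange interventions are informative about what a variable contributes.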

## A fully-connected feed-forward neural network

We've now seen how interventions work in our high-level causal model. We turn now to doing parallel work in our neural network, which will be a fully-connected feed-forward neural network with three hidden layers. The following code simply extends `TorchDeepNeuralClassifier` with a method `retrieve_activations` that supports interventions on PyTorch computation graphs:

In [11]:
# Assumption: the base class lives in the CS224u course module
# `torch_deep_neural_classifier`; adjust the import if your copy differs.
from torch_deep_neural_classifier import TorchDeepNeuralClassifier

class InterventionableTorchDeepNeuralClassifier(TorchDeepNeuralClassifier):
    def __init__(self, **base_kwargs):
        super().__init__(**base_kwargs)

    def make_hook(self, gets, sets, layer):
        # Build a forward hook for `layer` that first applies any "set"
        # interventions and then records any requested "get" activations.
        def hook(model, input, output):
            layer_gets, layer_sets = [], []
            if gets is not None and layer in gets:
                layer_gets = gets[layer]
            if sets is not None and layer in sets:
                layer_sets = sets[layer]
            for set_spec in layer_sets:
                # Overwrite output[:, start:end] with the intervention values:
                output = torch.cat(
                    [output[:, :set_spec["start"]],
                     set_spec["intervention"],
                     output[:, set_spec["end"]:]],
                    dim=1)
            for get_spec in layer_gets:
                k = f'{get_spec["layer"]}-{get_spec["start"]}-{get_spec["end"]}'
                self.activation[k] = output[:, get_spec["start"]:get_spec["end"]]
            # Returning a value from a forward hook replaces the layer's output:
            return output
        return hook

    def _gets_sets(self, gets=None, sets=None):
        # Register one hook per layer and return the handles for later removal.
        handlers = []
        for layer in range(len(self.layers)):
            hook = self.make_hook(gets, sets, layer)
            both_handler = self.layers[layer].register_forward_hook(hook)
            handlers.append(both_handler)
        return handlers

    def retrieve_activations(self, X, get, sets):
        if sets is not None and "intervention" in sets:
            sets["intervention"] = sets["intervention"].type(torch.FloatTensor).to(self.device)
        X = X.type(torch.FloatTensor).to(self.device)
        self.activation = {}
        get_val = {get["layer"]: [get]} if get is not None else None
        set_val = {sets["layer"]: [sets]} if sets is not None else None
        handlers = self._gets_sets(get_val, set_val)
        # Run the forward pass; the hooks populate `self.activation`:
        self.model(X)
        for handler in handlers:
            handler.remove()
        return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']

The module `iit` provides some dataset functions for equality learning. Here we define a simple equality dataset:

In [10]:
embedding_dim = 4

n_examples = 10000

X_train, X_test, y_train, y_test, test_dataset = iit.get_equality_dataset(
    embedding_dim, n_examples)

The examples in this dataset are 16-dimensional vectors: the concatenation of four 4-dimensional vectors. Here's the first example with its label:

In [11]:
X_train[0], y_train[0]

(tensor([ 0.4365, -0.0289, -0.2255,  0.3735,  0.4365, -0.0289, -0.2255,  0.3735,
          0.3634,  0.1236,  0.0243, -0.1352,  0.4411, -0.2673, -0.4682, -0.3980],
        dtype=torch.float64),
 0)

The label for this example is determined by whether the equality value for the first two inputs matches the equality value for the second two inputs:

In [12]:
left = torch.equal(
    X_train[0][: embedding_dim],
    X_train[0][embedding_dim: embedding_dim*2])

left

True

In [13]:
right = torch.equal(
    X_train[0][embedding_dim*2: embedding_dim*3],
    X_train[0][embedding_dim*3: ])

right

False

In [14]:
int(left == right)

0
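For intuition about how a dataset with this structure could be produced, here is a rough pure-Python sketch. It is not the `iit` implementation, whose internals are not shown in this notebook; the function name `make_equality_example` and the uniform sampling range are assumptions made purely for illustration.

```python
import random

def make_equality_example(embedding_dim, left_same, right_same, rng):
    """Build one flat example (4 * embedding_dim floats) and its label."""
    def random_vector():
        return [rng.uniform(-0.5, 0.5) for _ in range(embedding_dim)]

    a = random_vector()
    b = list(a) if left_same else random_vector()
    c = random_vector()
    d = list(c) if right_same else random_vector()
    # The label is 1 when the two pair-level equality values match:
    label = int(left_same == right_same)
    return a + b + c + d, label

rng = random.Random(0)
embedding_dim = 4
example, label = make_equality_example(embedding_dim, True, False, rng)

# Recover the label from the vector alone, mirroring the cells above:
left = example[:embedding_dim] == example[embedding_dim:2 * embedding_dim]
right = example[2 * embedding_dim:3 * embedding_dim] == example[3 * embedding_dim:]
print(len(example), label, int(left == right))
```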

Let's see how our model does out-of-the-box on this task:

In [15]:
model = InterventionableTorchDeepNeuralClassifier(
    hidden_dim=embedding_dim * 4,
    hidden_activation=torch.nn.ReLU(), 
    num_layers=3)

_ = model.fit(X_train, y_train)

Stopping after epoch 822. Training loss did not improve more than tol=1e-05. Final error is 0.0004548490142042283.

This neural network achieves near-perfect performance on its training set:

In [16]:
train_preds = model.predict(X_train)

print("Train Results")
print(classification_report(y_train, train_preds))

Train Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



And it generalizes perfectly to the test set:

In [17]:
print("Test Results")

test_preds = model.predict(X_test)

print(classification_report(y_test, test_preds))

Test Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



Does it implement our high-level model of the problem, though?

### Basic intervention: zeroing out part of a hidden layer

To begin to build towards the full interchange intervention, let's consider a simpler intervention, where we zero out the first `embedding_dim` neurons in the first hidden layer.

Our basic inputs are random vectors, which we take from the training examples so that we are sure to see the full logic of these interventions (the next section will consider test examples in the context of a full abstraction analysis). From them, we define two different inputs for use in later examples, one with a **(same, different)** structure and one with a **(different, same)** structure:

In [18]:
a = X_train[0][: embedding_dim]
b = X_train[1][: embedding_dim]
c = X_train[2][: embedding_dim]

X_same_different = torch.cat((a, a, b, c)).unsqueeze(0)

X_different_same = torch.cat((a, b, c, c)).unsqueeze(0)

To inspect the effect of the intervention, we first specify the coordinates we want to read: the layer (here, layer 1) and the `start` and `end` indices of the units within that layer. So that we can study the full layer before and after the intervention, we specify the entire layer:

In [19]:
zeroing_get_coord = {
    "layer": 1,
    "start": 0,
    "end": embedding_dim*4
}

Next, we specify the intervention itself: in layer 1, the first `embedding_dim` units will be set to 0:

In [20]:
zeroing_intervention = {
    "layer": 1,
    "start": 0,  
    "end": embedding_dim, 
    "intervention": torch.zeros((1,embedding_dim))
}

For the `X_same_different` input, the network computes the following values at our intervention site, without any intervention:

In [21]:
model.retrieve_activations(X_same_different, zeroing_get_coord, None)

tensor([[0.0000, 0.8953, 0.0000, 2.0436, 0.3650, 0.6414, 2.0196, 1.5436, 0.9906,
         0.0000, 0.3958, 0.8037, 0.1184, 0.5085, 0.5110, 1.3489]],
       device='cuda:0', grad_fn=<SliceBackward0>)

And here are the values computed with the intervention:

In [22]:
model.retrieve_activations(X_same_different, zeroing_get_coord, zeroing_intervention)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.3650, 0.6414, 2.0196, 1.5436, 0.9906,
         0.0000, 0.3958, 0.8037, 0.1184, 0.5085, 0.5110, 1.3489]],
       device='cuda:0', grad_fn=<SliceBackward0>)

We can also see how the intervention affects outputs. To do that, we specify the final layer (the two logits) as the coordinate:

In [23]:
zeroing_output_coord = {
    "layer": 3, 
    "start": 0, 
    "end": 2}

Here are the outputs without an intervention:

In [24]:
model.retrieve_activations(X_same_different, zeroing_output_coord, sets=None)

tensor([[ 7.6271, -2.5820]], device='cuda:0', grad_fn=<SliceBackward0>)

And with the intervention we specified above:

In [25]:
model.retrieve_activations(X_same_different, zeroing_output_coord, zeroing_intervention)

tensor([[ 19.4618, -15.0455]], device='cuda:0', grad_fn=<SliceBackward0>)

### An interchange intervention

We're now ready to do a full interchange intervention. The only change from the above is that, instead of simply zeroing out some neurons, we'll replace them with the corresponding values determined by a distinct input.

We'll again target the first `embedding_dim` units in the first hidden layer:

In [26]:
ii_coord = {"layer": 1, "start": 0, "end": embedding_dim}

For our **source** input, we'll use `X_different_same`. The first step is to get the activations for this input at our coordinate:

In [27]:
intervention_get = model.retrieve_activations(X_different_same, ii_coord, None)

intervention_get

tensor([[0.3911, 0.2510, 0.0000, 1.2667]], device='cuda:0',
       grad_fn=<SliceBackward0>)

Then we define the intervention using these values:

In [28]:
ii_set = {
    "layer": 1, 
    "start": 0, 
    "end": embedding_dim, 
    "intervention": intervention_get}

We now turn to our __base__ input, which will be `X_same_different`. With no intervention, this has the following values at our intervention site:

In [29]:
model.retrieve_activations(X_same_different, ii_coord, None)

tensor([[0.0000, 0.8953, 0.0000, 2.0436]], device='cuda:0',
       grad_fn=<SliceBackward0>)

And then we can verify that the intervention works as we intended it to; these values should be the same as `intervention_get` above:

In [30]:
model.retrieve_activations(X_same_different, ii_coord, ii_set)

tensor([[0.3911, 0.2510, 0.0000, 1.2667]], device='cuda:0',
       grad_fn=<SliceBackward0>)

Finally, we can see what the intervention does to the network's predictions. We specify the coordinates of the output logits:

In [31]:
ii_output_coord = {"layer": 3, "start": 0, "end": 2}

With no intervention, the input `X_same_different` delivers:

In [32]:
model.retrieve_activations(X_same_different, ii_output_coord, None)

tensor([[ 7.6271, -2.5820]], device='cuda:0', grad_fn=<SliceBackward0>)

With the intervention, that same input delivers:

In [33]:
model.retrieve_activations(X_same_different, ii_output_coord, ii_set)

tensor([[ 22.5734, -17.7018]], device='cuda:0', grad_fn=<SliceBackward0>)

If our target coordinates for the intervention were a modular encoding of the equality value for the first two inputs, then this intervention would have changed the network's prediction from `0` to `1`, since we would have effectively created a **(different, different)** input. The logits above show that this did not happen: the first logit is still the larger one, so the prediction remains `0`. This suggests that our hypothesis about where this information is encoded is false. A full-fledged causal abstraction analysis will allow us to assess this more comprehensively.
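The mechanics of an interchange intervention do not depend on PyTorch. As a minimal sketch, the toy network below reads a slice of hidden activations from a source input and writes it into the base input's forward pass before the output layer is computed. It uses random weights and plain Python lists; all names and sizes are made up for illustration and are unrelated to the trained model above.

```python
import random

def relu(values):
    return [max(0.0, v) for v in values]

def linear(weights, inputs):
    """Multiply a (rows x cols) weight matrix by an input vector."""
    return [sum(w * x for w, x in zip(row, inputs)) for row in weights]

def forward(x, w_hidden, w_out, intervention=None):
    """Run the toy network, optionally overwriting hidden[start:end]."""
    hidden = relu(linear(w_hidden, x))
    if intervention is not None:
        start, end, values = intervention
        hidden = hidden[:start] + list(values) + hidden[end:]
    return hidden, linear(w_out, hidden)

rng = random.Random(0)
n_in, n_hidden, n_out = 4, 6, 2
w_hidden = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_hidden)]
w_out = [[rng.uniform(-1, 1) for _ in range(n_hidden)] for _ in range(n_out)]

base = [rng.uniform(-1, 1) for _ in range(n_in)]
source = [rng.uniform(-1, 1) for _ in range(n_in)]

start, end = 0, 2  # the intervention site: the first two hidden units
source_hidden, _ = forward(source, w_hidden, w_out)
base_hidden, base_logits = forward(base, w_hidden, w_out)
new_hidden, new_logits = forward(
    base, w_hidden, w_out,
    intervention=(start, end, source_hidden[start:end]))

print("base logits:      ", base_logits)
print("intervened logits:", new_logits)
```

Only the units inside the site take their values from the source; every other hidden unit keeps the value computed from the base input, and the logits are recomputed from the mixed hidden state.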

## Causal abstraction

To recap:

1. We defined a **high-level causal model** (a tree-structured algorithm) that solves the hierarchical equality task.

1. We trained a **low-level fully connected neural network** that seeks to solve the hierarchical equality task.

1. We performed illustrative interventions on both of these models to begin to get a feel for whether the high-level model is an abstraction of the lower-level neural one.

The formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree structured algorithm to be a **simplified and faithful description** of the neural network. 

In essence: a high-level model is a causal abstraction of a neural network if and only if, for some alignment between the two models, the algorithm and the network produce the same output under every aligned interchange intervention, that is, for all base and source inputs.

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention accuracy** (II accuracy) for a high-level variable: the percentage of aligned interchange interventions on which the network and the algorithm produce the same output. When the II accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm where only one high-level variable exists.
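As a small sketch of the metric itself: II accuracy is plain agreement between the algorithm's counterfactual labels and the network's counterfactual predictions. The helper name `ii_accuracy` and the toy values are hypothetical.

```python
def ii_accuracy(algorithm_labels, network_predictions):
    """Fraction of interchange interventions on which the two models agree."""
    if len(algorithm_labels) != len(network_predictions):
        raise ValueError("Expected one prediction per label.")
    if not algorithm_labels:
        raise ValueError("Need at least one interchange intervention.")
    matches = sum(
        int(label == pred)
        for label, pred in zip(algorithm_labels, network_predictions))
    return matches / len(algorithm_labels)

# Toy values, invented for illustration:
labels = [1, 0, 0, 1]
predictions = [1, 0, 1, 1]
print(ii_accuracy(labels, predictions))  # 3 of 4 agree
```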

### Alignment

The first step is to specify an alignment:

In [34]:
alignment = {
    "V1": {"layer": 1, "start": 0, "end": embedding_dim}, 
    "V2": {"layer": 1, "start": embedding_dim, "end": embedding_dim*2}}

In essence, this reflects a hypothesis that we will find the equality label for the first two inputs in the first four neurons in layer 1, and that we'll find the equality label for the second two inputs in the next four neurons in layer 1. This is of course just one of a great many hypotheses we could state. 
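Because an alignment is just a dictionary of coordinates, it is easy to state one that is malformed, for example with two variables claiming the same units. The following sanity check is a hypothetical helper, not part of the course code, and it assumes every aligned layer has width `hidden_dim`.

```python
def check_alignment(alignment, hidden_dim):
    """Raise ValueError if any site is malformed or two sites overlap."""
    occupied = {}  # (layer, unit index) -> variable name
    for variable, site in alignment.items():
        start, end = site["start"], site["end"]
        if not 0 <= start < end <= hidden_dim:
            raise ValueError(f"{variable}: bad span [{start}, {end})")
        for unit in range(start, end):
            key = (site["layer"], unit)
            if key in occupied:
                raise ValueError(
                    f"{variable} overlaps {occupied[key]} at {key}")
            occupied[key] = variable
    return True

embedding_dim = 4
alignment = {
    "V1": {"layer": 1, "start": 0, "end": embedding_dim},
    "V2": {"layer": 1, "start": embedding_dim, "end": embedding_dim * 2}}
print(check_alignment(alignment, hidden_dim=embedding_dim * 4))
```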

### Interchange intervention

The function `interchange_intervention` packages up the multi-step process we walked through above:

In [35]:
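def interchange_intervention(model, base, source, get_coord, output_coord):
    # Read the activations at `get_coord` for the source input:
    intervention = model.retrieve_activations(source, get_coord, None)
    # Copy the coordinate rather than mutating the caller's dict, then
    # attach the source activations as the values to write into the base run:
    set_coord = dict(get_coord, intervention=intervention)
    return model.retrieve_activations(base, output_coord, set_coord)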

In [36]:
output_coord = {"layer": 3, "start": 0, "end": 2}

Example: 

In [37]:
interchange_intervention(
    model, 
    base=X_same_different, 
    source=X_different_same, 
    get_coord=ii_coord, 
    output_coord=output_coord)

tensor([[ 22.5734, -17.7018]], device='cuda:0', grad_fn=<SliceBackward0>)

So that we can run our high-level model on our vector examples, we define a helper function to parse them into their component inputs:

In [38]:
def convert_input(tensor, embedding_dim):
    return [tuple(tensor[0, embedding_dim*k:embedding_dim*(k+1)].flatten().tolist()) 
            for k in range(4)]

Illustration:

In [39]:
compute_A(convert_input(X_same_different, embedding_dim), {})['output']

False

In [40]:
compute_A(convert_input(X_different_same, embedding_dim), {})['output']

False

Now if we perform an intervention on the **V1** variable with `X_different_same` as the base and `X_same_different` as the source, we effectively create a **(True, True)** example:

In [41]:
compute_interchange_A(
    convert_input(X_different_same, embedding_dim),
    convert_input(X_same_different, embedding_dim),
    variable="V1")['output']    

True

### Evaluation

The function `ii_evaluation` puts these pieces together in the context of a full evaluation on a set of examples:

In [42]:
def ii_evaluation(X_assess, model, variable, output_coord):
    labels = []
    predictions = []
    for base, source in itertools.product(X_assess, repeat=2):
        base = base.unsqueeze(0)
        source = source.unsqueeze(0)
        # Run the high-level model with the intervention:
        algorithm_output = compute_interchange_A(
            convert_input(base, embedding_dim), 
            convert_input(source, embedding_dim), 
            variable)
        # Get the high-level model's label:
        labels.append(int(algorithm_output["output"]))
        # Run the neural model with the intervention:
        network_output = interchange_intervention(
            model, 
            base,
            source,
            alignment[variable],
            output_coord)
        # Get the neural model's prediction with the intervention:
        pred = network_output.argmax(axis=1)
        predictions.append(int(pred))
    return labels, predictions

First, let's assess the hypothesis that **V1** is encoded at our chosen site, using a sample of test cases for efficiency:

In [43]:
print(classification_report(*ii_evaluation(X_test[: 100], model, "V1", output_coord)))

              precision    recall  f1-score   support

           0       0.53      0.41      0.46      5016
           1       0.52      0.63      0.57      4984

    accuracy                           0.52     10000
   macro avg       0.52      0.52      0.52     10000
weighted avg       0.52      0.52      0.52     10000



And then the corresponding assessment for **V2**:

In [44]:
print(classification_report(*ii_evaluation(X_test[: 100], model, "V2", output_coord)))

              precision    recall  f1-score   support

           0       0.55      0.71      0.62      5016
           1       0.58      0.41      0.48      4984

    accuracy                           0.56     10000
   macro avg       0.56      0.56      0.55     10000
weighted avg       0.56      0.56      0.55     10000



II accuracy is close to chance for both **V1** (0.52) and **V2** (0.56), meaning that, under this alignment, the neural network does not compute either variable at the hypothesized sites. In other words, we have no evidence that this network computes simple equality relations to solve this hierarchical equality task. The goal of interchange intervention training is to change this. We turn to that method next.

## Interchange Intervention Training (IIT)

Interchange Intervention Training (IIT) is a method for training a neural network to conform to the causal structure of a high-level algorithm. Conceptually, it is a direct extension of the causal abstraction analysis we just performed, except instead of **evaluating** whether the neural network and algorithm produce the same outputs under aligned interchange interventions, we are now **training** the neural network to produce the output of the algorithm under aligned interchange interventions.
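As a rough formalization (a sketch, not necessarily the exact objective implemented by the course code), write $\mathcal{N}_\theta$ for the network, $b$ and $s$ for a base and a source input, $y_b$ for the base label, and $\mathcal{N}_\theta^{S \leftarrow s}(b)$ for the network's output on $b$ after the aligned site $S$ has been overwritten with the value it takes on $s$. With $\mathcal{A}^{V \leftarrow s}(b)$ denoting the algorithm's output under the corresponding interchange on the high-level variable $V$, a combined training objective would be:

$$
\mathcal{L}(\theta) = \mathrm{CE}\big(\mathcal{N}_\theta(b),\, y_b\big) + \mathrm{CE}\big(\mathcal{N}_\theta^{S \leftarrow s}(b),\, \mathcal{A}^{V \leftarrow s}(b)\big)
$$

Here $\mathrm{CE}$ is the cross-entropy loss. The notation and the equal weighting of the two terms are assumptions made for illustration.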

IIT was developed by [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826), and it is used for model distillation by [Wu\*, Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.02505).

### IIT on variable V1

Our first intervention will target the first four dimensions of layer 1 in the network, but now we will be training the network in part to play the role of the **V1** variable:

In [45]:
V1 = 0

V1_id_to_coords = {
    V1: [{"layer": 1, "start": 0, "end": embedding_dim}]    
}

Next we create an equality dataset that includes examples for IIT training:

In [46]:
data_size = 10000

iit_equality_dataset = iit.get_IIT_equality_dataset("V1", embedding_dim, data_size)

X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions = iit_equality_dataset

This dataset has the following components:

* `X_base_train`: a regular set of train examples
* `y_base_train`: a regular set of train labels
* `X_sources_train`: a list of additional train sets (here, a singleton list of them) for counterfactuals
* `y_IIT_train`: a list of labels for the examples in `X_sources_train`
* `interventions`: a list of intervention sites (here, all `0` corresponding to our key for "V1")

Our model is a deep classifier like the one we used above, but now one that can do IIT:

In [47]:
iit_model = TorchDeepNeuralClassifierIIT(
    hidden_dim=embedding_dim*4, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=3,
    id_to_coords=V1_id_to_coords)

The model is fit using our IIT dataset:

In [48]:
_ = iit_model.fit(
    X_base_train, 
    X_sources_train, 
    y_base_train, 
    y_IIT_train, 
    interventions)

Finished epoch 1000 of 1000; error is 0.002717384573770687

To evaluate this model, we create a fresh IIT equality dataset consisting of 100 examples:

In [49]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = iit.get_IIT_equality_dataset(
    "V1", embedding_dim, 100)

In [50]:
IIT_preds, base_preds = iit_model.iit_predict(X_base_test, X_sources_test, interventions)

This IIT-trained model does well on a standard behavioral test:

In [51]:
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        48
           1       1.00      1.00      1.00        48

    accuracy                           1.00        96
   macro avg       1.00      1.00      1.00        96
weighted avg       1.00      1.00      1.00        96



Importantly, it _also_ performs perfectly on counterfactual examples – certainly a marked improvement over the model we studied above that did no IIT:

In [52]:
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        48
           1       1.00      1.00      1.00        48

    accuracy                           1.00        96
   macro avg       1.00      1.00      1.00        96
weighted avg       1.00      1.00      1.00        96



Of course, we did only one kind of IIT: we pushed the first `embedding_dim` neurons in layer 1 to conform to **V1** in the causal model. As a result, we still have low counterfactual accuracy for **V2**, meaning that, under this alignment, the neural network doesn't compute whether the second pair of inputs is equal:

In [53]:
X_base_test_V2, X_sources_test_V2, y_base_test_V2, y_IIT_test_V2, interventions_V2 = iit.get_IIT_equality_dataset(
    "V2", embedding_dim, data_size)

IIT_preds_V2, base_preds_V2 = iit_model.iit_predict(X_base_test_V2, X_sources_test_V2, interventions_V2)

print(classification_report(y_IIT_test_V2, IIT_preds_V2))

              precision    recall  f1-score   support

           0       0.50      0.50      0.50      5000
           1       0.50      0.50      0.50      5000

    accuracy                           0.50     10000
   macro avg       0.50      0.50      0.50     10000
weighted avg       0.50      0.50      0.50     10000



### IIT on variables V1 and V2

To address this, we can simply train the network to compute both **V1** and **V2**.

In [54]:
V2 = 1

id_to_coords = {
    V1: [{"layer": 1, "start": 0, "end": embedding_dim}], 
    V2: [{"layer": 1, "start": embedding_dim, "end": embedding_dim*2}]    
}

In [55]:
both_model = TorchDeepNeuralClassifierIIT(
    hidden_dim=embedding_dim*4, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=3, 
    id_to_coords=id_to_coords)

In [56]:
v1data = iit.get_IIT_equality_dataset("V1", embedding_dim, data_size)

v2data = iit.get_IIT_equality_dataset("V2", embedding_dim, data_size)

In [57]:
X_base_train_both = torch.cat([v1data[0], v2data[0]], dim=0)

X_sources_train_both = [torch.cat([v1data[1][i], v2data[1][i]], dim=0) 
                        for i in range(len(v1data[1]))] 

y_base_train_both = torch.cat([v1data[2], v2data[2]])
y_IIT_train_both = torch.cat([v1data[3], v2data[3]])

interventions_both = torch.cat([v1data[4], v2data[4]])

_ = both_model.fit(
    X_base_train_both, 
    X_sources_train_both, 
    y_base_train_both, 
    y_IIT_train_both, 
    interventions_both)

Finished epoch 1000 of 1000; error is 0.5991211635991931

In [58]:
IIT_preds, base_preds = both_model.iit_predict(
    X_base_test, X_sources_test, interventions)

print("Standard evaluation")
print(classification_report(y_base_test, base_preds))
      
print("V1 counterfactual evaluation")
print(classification_report(y_IIT_test, IIT_preds))

Standard evaluation
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        48
           1       1.00      1.00      1.00        48

    accuracy                           1.00        96
   macro avg       1.00      1.00      1.00        96
weighted avg       1.00      1.00      1.00        96

V1 counterfactual evaluation
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        48
           1       1.00      1.00      1.00        48

    accuracy                           1.00        96
   macro avg       1.00      1.00      1.00        96
weighted avg       1.00      1.00      1.00        96



In [59]:
IIT_preds_V2, base_preds_V2 = both_model.iit_predict(
    X_base_test_V2, X_sources_test_V2, interventions_V2)

print("V2 counterfactual evaluation")
print(classification_report(y_IIT_test_V2, IIT_preds_V2))

V2 counterfactual evaluation
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

