# Interchange Intervention Training: Equality learning tasks

In [1]:
__author__ = "Atticus Geiger"
__version__ = "CS224u, Stanford, Spring 2022"

## Contents

1. [Overview](#Overview)
1. [Set-up](#Set-up)
1. [The hierarchical equality task](#The-hierarchical-equality-task)
1. [The high-level causal model](#The-high-level-causal-model)
    1. [The algorithm with no intervention](#The-algorithm-with-no-intervention)
    1. [The algorithm with an intervention](#The-algorithm-with-an-intervention)
    1. [The algorithm with an interchange intervention](#The-algorithm-with-an-interchange-intervention)
1. [A fully-connected feed-forward neural network](#A-fully-connected-feed-forward-neural-network)
    1. [Basic intervention: zeroing out part of a hidden layer](#Basic-intervention:-zeroing-out-part-of-a-hidden-layer)
    1. [An interchange intervention](#An-interchange-intervention)
1. [Causal abstraction](#Causal-abstraction)
1. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))
1. [Multisource IIT](#Multisource-IIT)

## Overview

This notebook is a hands-on introduction to __causal abstraction analysis__ and __interchange intervention training__ with neural networks.

In causal abstraction analysis, we assess whether trained models conform to high-level causal models that we specify, not just in terms of their input–output behavior, but also in terms of their internal dynamics. 

The core technique is the __interchange intervention__, in which we actively manipulate internal states in the high-level causal model and in the neural network to see whether the two models show the same behavior in these counterfactual states.

In interchange intervention training, we go beyond analysis by actively training networks to conform to the high-level causal model.

To motivate and illustrate these concepts, we're going to focus on a challenging hierarchical equality task, building on work by [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968).

## Set-up

In [2]:
import torch
import random
import copy
import itertools
import numpy as np
from sklearn.metrics import classification_report
from torch_deep_neural_classifier import TorchDeepNeuralClassifier
from torch_deep_neural_classifier_iit import TorchDeepNeuralClassifierIIT
import iit
import utils

In [3]:
utils.fix_random_seeds(44)

## The hierarchical equality task

This section builds on results presented in [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968). We will use a hierarchical equality task to present interchange intervention training (IIT). 

We define the hierarchical equality task as follows: the input is two pairs of objects, and the output is **True** if the objects are identical in both pairs or different in both pairs, and **False** otherwise.

For example, `AABB` and `ABCD` are both labeled **True**, while `ABCC` and `BBCD` are both labeled **False**.
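
As a minimal sketch of this labeling rule, the label can be computed for any four comparable objects. The function name `hierarchical_equality_label` is introduced here only for illustration and is not part of the notebook's code:

```python
def hierarchical_equality_label(example):
    """Return True if both pairs are 'same' or both pairs are 'different'."""
    first, second, third, fourth = example
    left_same = first == second
    right_same = third == fourth
    # The label is the equality of the two pair-level equality values.
    return left_same == right_same


for example in ("AABB", "ABCD", "ABCC", "BBCD"):
    print(example, hierarchical_equality_label(example))
```

The same two-stage structure (pair-level equality, then equality of the results) is what the high-level causal model in the next section makes explicit.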

## The high-level causal model

Let $\mathcal{A}$ be the simple tree-structured algorithm that solves this task by applying a simple equality relation three times: Compute whether the first two inputs are equal, compute whether the second two inputs are equal, then compute whether the truth-valued outputs of these first two computations are equal. Here's a visual depiction of the algorithm:

<img src="fig/IIT/PremackFunctions.png" width="500"/>
<img src="fig/IIT/PremackGraph.png" width="500"/>

And here's a Python implementation of $\mathcal{A}$ that supports the interventions we'll want to do:

In [4]:
def compute_A(ex, intervention):
    graph = {}
    for i, obj in enumerate(ex):
        graph["input" + str(i+1)] = obj
    if "V1" in intervention:
        graph["V1"] = intervention["V1"]
    else:
        graph["V1"] = graph["input1"] == graph["input2"]
    if "V2" in intervention:
        graph["V2"] = intervention["V2"]
    else:
        graph["V2"] = graph["input3"] == graph["input4"]
    graph["output"] = graph["V1"] == graph["V2"]
    return graph

### The algorithm with no intervention

Let's first observe the behavior of the algorithm when we provide the input **(pentagon, pentagon, triangle, square)** with no interventions. Here is a visual depiction:

<img src="fig/IIT/PremackNoIntervention.png" width="500"/>

And here is the computation using `compute_A`:

In [5]:
compute_A(("pentagon", "pentagon", "triangle", "square"), intervention={})

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': True,
 'V2': False,
 'output': False}

### The algorithm with an intervention

Let's now see the behavior of the algorithm when we provide the input **(square, pentagon, triangle, triangle)** with an intervention setting **V1** to **True**. First, a visual depiction:

<img src="fig/IIT/PremackIntervention.png" width="500"/>

And then the same computation with `compute_A`:

In [6]:
compute_A(
    ("square", "pentagon", "triangle", "triangle"), 
    intervention={"V1": True})

{'input1': 'square',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'triangle',
 'V1': True,
 'V2': True,
 'output': True}

Notice that, in this example, even though the left two inputs are not the same, the intervention has changed the intermediate prediction for those two inputs from **False** to **True**, and thus the algorithm outputs **True**, since its output is determined by **V1** and **V2**.

### The algorithm with an interchange intervention

Finally, let's observe the behavior of the algorithm when we provide the base input **(pentagon, pentagon, triangle, square)** with an intervention setting **V1** to the value it would have for the source input **(square, pentagon, triangle, triangle)**. Here's a diagram in which the dashed line indicates the interchange intervention:

<img src="fig/IIT/algorithmII.png" width="600"/>

And here is the corresponding interchange intervention in code:

In [7]:
def compute_interchange_A(base, source, variable):
    # Run the algorithm on `source`:
    src_output = compute_A(source, intervention={})
    # Get the source value for `variable`:
    val = src_output[variable]
    # Process `base` with the intervention setting `variable`
    # to the value it had in `source`:        
    return compute_A(base, intervention={variable: val})

In [8]:
compute_interchange_A(
    base=("pentagon", "pentagon", "triangle", "square"),    # base: T F ==> F
    source=("square", "pentagon", "triangle", "triangle"),  # source: F T ==> F
    variable="V1") # Will set base V1 to be source V1, leading to F F ==> T

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': False,
 'V2': False,
 'output': True}
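
Because only **V1** is replaced, the result of this interchange intervention can also be stated directly: the output is **True** exactly when the source's left pair and the base's right pair have the same equality value. The sketch below restates `compute_interchange_A` for the special case `variable="V1"`; the name `interchange_v1_output` is hypothetical and used only for illustration:

```python
def interchange_v1_output(base, source):
    """High-level output when V1 is taken from `source` and V2 from `base`."""
    v1_from_source = source[0] == source[1]
    v2_from_base = base[2] == base[3]
    return v1_from_source == v2_from_base


base = ("pentagon", "pentagon", "triangle", "square")
source = ("square", "pentagon", "triangle", "triangle")
print(interchange_v1_output(base, source))
```

Note that the left pair of the base and the right pair of the source play no role in the result.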

## A fully-connected feed-forward neural network

We've now seen how interventions work in our high-level causal model. We turn now to doing parallel work in our neural network, which will be a fully-connected feed-forward neural network with three hidden layers. The module `iit` provides some dataset functions for equality learning, which we use below to build our dataset. The following code simply extends `TorchDeepNeuralClassifier` with a method `retrieve_activations` that supports interventions on PyTorch computation graphs:

In [9]:
class InterventionableTorchDeepNeuralClassifier(TorchDeepNeuralClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def make_hook(self, gets, sets, layer):
        # Build a forward hook for `layer` that first applies any
        # interventions in `sets` and then records any sites in `gets`:
        def hook(module, inputs, output):
            layer_gets, layer_sets = [], []
            if gets is not None and layer in gets:
                layer_gets = gets[layer]
            if sets is not None and layer in sets:
                layer_sets = sets[layer]
            # Overwrite units start:end with the intervention values:
            for site in layer_sets:
                output = torch.cat(
                    [output[:, :site["start"]],
                     site["intervention"],
                     output[:, site["end"]:]],
                    dim=1)
            # Store the (possibly intervened-on) activations:
            for site in layer_gets:
                k = f'{site["layer"]}-{site["start"]}-{site["end"]}'
                self.activation[k] = output[:, site["start"]: site["end"]]
            # The returned value replaces the output of the layer:
            return output
        return hook

    def _gets_sets(self, gets=None, sets=None):
        # Register one hook per layer; the handlers are returned so
        # that the hooks can be removed after the forward pass:
        handlers = []
        for layer in range(len(self.layers)):
            hook = self.make_hook(gets, sets, layer)
            both_handler = self.layers[layer].register_forward_hook(hook)
            handlers.append(both_handler)
        return handlers

    def retrieve_activations(self, X, get, sets):
        if sets is not None and "intervention" in sets:
            sets["intervention"] = sets["intervention"].type(torch.FloatTensor).to(self.device)
        X = X.type(torch.FloatTensor).to(self.device)
        self.activation = {}
        get_val = {get["layer"]: [get]} if get is not None else None
        set_val = {sets["layer"]: [sets]} if sets is not None else None
        handlers = self._gets_sets(get_val, set_val)
        # The forward pass triggers the hooks, which fill `self.activation`:
        self.model(X)
        for handler in handlers:
            handler.remove()
        return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']
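
The class above relies on one property of PyTorch forward hooks: if a hook returns a value, that value replaces the module's output. The following stand-alone sketch shows that mechanism on a toy layer. The layer sizes, the helper name `make_zeroing_hook`, and the choice to zero out units are all assumptions made for illustration and are independent of `TorchDeepNeuralClassifier`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A toy layer standing in for one hidden layer of a classifier.
layer = nn.Sequential(nn.Linear(4, 6), nn.ReLU())


def make_zeroing_hook(start, end):
    """Return a forward hook that zeroes out units start:end."""
    def hook(module, inputs, output):
        zeros = torch.zeros(output.shape[0], end - start)
        # A non-None return value replaces the module's output.
        return torch.cat([output[:, :start], zeros, output[:, end:]], dim=1)
    return hook


x = torch.randn(2, 4)
original = layer(x)

handle = layer.register_forward_hook(make_zeroing_hook(0, 2))
patched = layer(x)   # Hook active: units 0 and 1 are zeroed.
handle.remove()
restored = layer(x)  # Hook removed: original behavior is back.
```

Removing the handle after the forward pass, as `retrieve_activations` does, ensures that later calls to the model are not affected by the intervention.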

In [10]:
embedding_dim = 4

n_examples = 10000

X_train, X_test, y_train, y_test, test_dataset = iit.get_equality_dataset(embedding_dim, n_examples)

The examples in this dataset are 16-dimensional vectors: the concatenation of four 4-dimensional vectors. Here's the first example with its label:
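
The construction of such an example can be sketched in plain Python. This is not the `iit.get_equality_dataset` implementation: the helper name `make_example`, the uniform sampling range, and the use of Python lists instead of tensors are all assumptions made for illustration:

```python
import random


def make_example(left_same, right_same, embedding_dim=4, rng=None):
    """Build one (vector, label) pair for the hierarchical equality task."""
    rng = rng or random

    def random_vector():
        # Assumed sampling range; the real dataset may use a different one.
        return [rng.uniform(-0.5, 0.5) for _ in range(embedding_dim)]

    a = random_vector()
    b = list(a) if left_same else random_vector()
    c = random_vector()
    d = list(c) if right_same else random_vector()
    x = a + b + c + d                  # Concatenate the four object vectors.
    y = int(left_same == right_same)   # 1 if the two pair-level values match.
    return x, y


x, y = make_example(left_same=True, right_same=False, rng=random.Random(0))
print(len(x), y)
```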

In [11]:
X_train[0], y_train[0]

(tensor([-0.2531, -0.1757,  0.0881,  0.4889, -0.3318, -0.3059,  0.3298,  0.1829,
          0.0713,  0.0626,  0.4520,  0.0696, -0.0203,  0.4501,  0.4336, -0.0723],
        dtype=torch.float64),
 1)

The label for this example is determined by whether the equality value for the first two inputs matches the equality value for the second two inputs:

In [12]:
left = torch.equal(
    X_train[0][: embedding_dim],
    X_train[0][embedding_dim: embedding_dim*2])

left

False

In [13]:
right = torch.equal(
    X_train[0][embedding_dim*2: embedding_dim*3],
    X_train[0][embedding_dim*3: ])

right

False

In [14]:
int(left == right)

1

Let's see how our model does out-of-the-box on this task:

In [15]:
hidden_dim = 4 * embedding_dim  # 4 inputs.

model = InterventionableTorchDeepNeuralClassifier(
    hidden_dim=hidden_dim, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=3)

_ = model.fit(X_train,y_train)

Stopping after epoch 727. Training loss did not improve more than tol=1e-05. Final error is 0.000661362395476317.

In [16]:
model.layers

[ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 Linear(in_features=16, out_features=2, bias=True)]

This neural network achieves near-perfect performance on its training set:

In [17]:

print("Train Results")

preds = model.predict(X_train)

print(classification_report(y_train, preds))

Train Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



It also generalizes perfectly to the test set:

In [18]:
print("Test Results")

preds = model.predict(X_test)

print(classification_report(y_test, preds))

Test Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



### Basic intervention: zeroing out part of a hidden layer

To begin to build towards the full interchange intervention, let's consider a simpler intervention, where we zero out the first `embedding_dim` neurons in layer 1 (`model.layers` is zero-indexed, so this is the second hidden layer).

Our basic inputs are random vectors. Here we define two different inputs for use in later examples. We'll use training examples so that we are sure to see the full logic of these interventions; the next section will consider test examples in the context of a full abstraction analysis:

In [19]:
a = X_train[0][: embedding_dim]
b = X_train[1][: embedding_dim]
c = X_train[2][: embedding_dim]

X_same_different = torch.cat((a, a, b, c)).unsqueeze(0)

X_different_same = torch.cat((a, b, c, c)).unsqueeze(0)

For the intervention, we first specify that we want it to target layer 1. So that we can study the full layer before and after the intervention, we specify the entire layer:

In [20]:
zeroing_get_coord = {
    "layer": 1, 
    "start": 0, 
    "end": embedding_dim*4}

Next, we specify the intervention itself: in layer 1, the first `embedding_dim` units will be set to 0:

In [21]:
zeroing_intervention = {
    "layer": 1, 
    "start": 0, 
    "end": embedding_dim, 
    "intervention": torch.zeros((1,embedding_dim))}

For the `X_same_different` input, the network computes the following values at our intervention site, without any intervention:

In [22]:
model.retrieve_activations(X_same_different, zeroing_get_coord, None)

tensor([[0.0344, 0.0000, 1.0244, 0.5173, 0.0000, 0.4466, 0.7505, 0.0000, 0.0725,
         0.3432, 0.0611, 0.2861, 0.5741, 0.7445, 0.6889, 0.0000]],
       device='cuda:0', grad_fn=<SliceBackward0>)

And here are the values computed with the intervention:

In [23]:
model.retrieve_activations(X_same_different, zeroing_get_coord, zeroing_intervention)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4466, 0.7505, 0.0000, 0.0725,
         0.3432, 0.0611, 0.2861, 0.5741, 0.7445, 0.6889, 0.0000]],
       device='cuda:0', grad_fn=<SliceBackward0>)

We can also see how the intervention affects outputs. To do that, we specify the final layer (the two logits) as the coordinate:

In [24]:
zeroing_output_coord = {
    "layer": 3, 
    "start": 0, 
    "end": 2}

Here are the outputs without an intervention:

In [25]:
model.retrieve_activations(X_same_different, zeroing_output_coord, sets=None)

tensor([[ 6.3725, -7.8416]], device='cuda:0', grad_fn=<SliceBackward0>)

And with the intervention we specified above:

In [26]:
model.retrieve_activations(X_same_different, zeroing_output_coord, zeroing_intervention)

tensor([[ 8.2483, -9.8024]], device='cuda:0', grad_fn=<SliceBackward0>)
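
To read these logits as predictions, we can take the argmax and, if desired, convert the logits to probabilities with a softmax. The sketch below does this in plain Python for the two logit pairs printed above; the helper names `softmax` and `predict` are introduced here for illustration only:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


before = [6.3725, -7.8416]   # No intervention.
after = [8.2483, -9.8024]    # With the zeroing intervention.

print(predict(before), predict(after))
print(softmax(before)[0], softmax(after)[0])
```

Both predictions are class `0`, so for this input the zeroing intervention changes the logits without changing the predicted label.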

### An interchange intervention

We're now ready to do a full interchange intervention. The only change from the above is that, instead of simply zeroing out some neurons, we'll replace them with the corresponding values determined by a distinct input.

We'll again target the first `embedding_dim` units in layer 1:

In [27]:
ii_coord = {"layer": 1, "start": 0, "end": embedding_dim}

For our **source** input, we'll use `X_different_same`. The first step is to get the activations for this input at our coordinate:

In [28]:
intervention_get = model.retrieve_activations(X_different_same, ii_coord, None)

intervention_get

tensor([[0.0000, 0.0000, 0.5353, 0.4133]], device='cuda:0',
       grad_fn=<SliceBackward0>)

Then we define the intervention using these values:

In [29]:
ii_set = {
    "layer": 1, 
    "start": 0, 
    "end": embedding_dim, 
    "intervention": intervention_get}

We now turn to our __base__ input, which will be `X_same_different`. With no intervention, this has the following values at our intervention site:

In [30]:
model.retrieve_activations(X_same_different, ii_coord, None)

tensor([[0.0344, 0.0000, 1.0244, 0.5173]], device='cuda:0',
       grad_fn=<SliceBackward0>)

And then we can verify that the intervention works as we intended it to; these values should be the same as `intervention_get` above:

In [31]:
model.retrieve_activations(X_same_different, ii_coord, ii_set)

tensor([[0.0000, 0.0000, 0.5353, 0.4133]], device='cuda:0',
       grad_fn=<SliceBackward0>)

Finally, we can see what the intervention does to the network's predictions. We specify the coordinates of the output logits:

In [32]:
ii_output_coord = {"layer": 3, "start": 0, "end": 2}

With no intervention, the input `X_same_different` delivers:

In [33]:
model.retrieve_activations(X_same_different, ii_output_coord, None)

tensor([[ 6.3725, -7.8416]], device='cuda:0', grad_fn=<SliceBackward0>)

With the intervention, that same input delivers:

In [34]:
model.retrieve_activations(X_same_different, ii_output_coord, ii_set)

tensor([[ 7.9475, -9.5349]], device='cuda:0', grad_fn=<SliceBackward0>)

If our target coordinates for the intervention were a modular encoding of the equality value for the first two inputs, then this intervention would have changed the network's prediction from `0` to `1`, since we would have effectively created a **(different, different)** input. The logits show that the prediction is still `0`, which suggests that our hypothesis about where this information is encoded is false. A full-fledged causal abstraction analysis will allow us to assess this more comprehensively.

## Causal abstraction

To recap:

1. We defined a **high-level causal model** (a tree-structured algorithm) that solves the hierarchical equality task.

1. We trained a **low-level fully connected neural network** that seeks to solve the hierarchical equality task.

1. We performed illustrative interventions on both of these models to begin to get a feel for whether the high-level model is an abstraction of the lower-level neural one.

The formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree-structured algorithm to be a **simplified and faithful description** of the neural network. 

In essence: a high-level model is a causal abstraction of a neural network if and only if there is an alignment between the two models under which, for all base and source inputs, the algorithm and the network produce the same output under aligned interchange interventions.

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention accuracy** (II accuracy) for a high-level variable: the percentage of aligned interchange interventions on which the network and the algorithm produce the same output. When the II accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm where only one high-level variable exists.
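
As a minimal sketch of the metric itself, II accuracy is the fraction of interchange interventions for which the network's prediction matches the algorithm's output. The helper name `ii_accuracy` is hypothetical; it assumes two equal-length lists of integer labels, one from the algorithm and one from the network:

```python
def ii_accuracy(labels, predictions):
    """Fraction of interventions where network and algorithm agree."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if not labels:
        raise ValueError("at least one intervention is required")
    matches = sum(int(label == pred) for label, pred in zip(labels, predictions))
    return matches / len(labels)


# Four interventions, three of which agree:
print(ii_accuracy([1, 0, 1, 1], [1, 0, 0, 1]))
```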

The first step is to specify an alignment:

In [35]:
alignment = {
    "V1": {"layer": 1, "start": 0, "end": embedding_dim}, 
    "V2": {"layer": 1, "start": embedding_dim, "end": embedding_dim*2}}

In essence, this reflects a hypothesis that we will find the equality label for the first two inputs in the first four neurons in layer 1, and that we'll find the equality label for the second two inputs in the next four neurons in layer 1. This is of course just one of a great many hypotheses we could state. 
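
To give a sense of how many alternative hypotheses there are, the sketch below enumerates a deliberately small family of alignments: **V1** and **V2** are placed in the same hidden layer, in two distinct, non-overlapping blocks of `block_size` contiguous units. The function name `candidate_alignments` and these restrictions are assumptions made for illustration; many other alignments (different block sizes, different layers for each variable, non-contiguous units) are possible:

```python
import itertools


def candidate_alignments(n_hidden_layers=3, hidden_dim=16, block_size=4):
    """Enumerate same-layer, non-overlapping block alignments for V1 and V2."""
    blocks = [(start, start + block_size)
              for start in range(0, hidden_dim, block_size)]
    alignments = []
    for layer in range(n_hidden_layers):
        # Ordered pairs of distinct blocks: one for V1, one for V2.
        for (s1, e1), (s2, e2) in itertools.permutations(blocks, 2):
            alignments.append({
                "V1": {"layer": layer, "start": s1, "end": e1},
                "V2": {"layer": layer, "start": s2, "end": e2}})
    return alignments


candidates = candidate_alignments()
print(len(candidates))
```

Even this restricted family contains dozens of candidates, and the alignment defined above is just one of them.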

The function `interchange_intervention` packages up the multi-step process we walked through above:

In [36]:
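def interchange_intervention(model, base, source, get_coord, output_coord):
    # Read the source activations at the target site:
    intervention = model.retrieve_activations(source, get_coord, None)
    # Copy the coordinate so that the caller's dict is not modified:
    set_coord = dict(get_coord, intervention=intervention)
    # Process `base` with the source activations fixed at the site:
    return model.retrieve_activations(base, output_coord, set_coord)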

In [37]:
output_coord = {"layer": 3, "start": 0, "end": 2}

Example: 

In [38]:
interchange_intervention(
    model, 
    base=X_same_different, 
    source=X_different_same, 
    get_coord=ii_coord, 
    output_coord=output_coord)

tensor([[ 7.9475, -9.5349]], device='cuda:0', grad_fn=<SliceBackward0>)

So that we can run our high-level model on our vector examples, we define a helper function to parse them into their component inputs:

In [39]:
def convert_input(tensor, embedding_dim):
    return [tuple(tensor[0, embedding_dim*k:embedding_dim*(k+1)].flatten().tolist()) 
            for k in range(4)]

Illustration:

In [40]:
compute_A(convert_input(X_same_different, embedding_dim), {})['output']

False

In [41]:
compute_A(convert_input(X_different_same, embedding_dim), {})['output']

False

In [42]:
compute_interchange_A(
    base=convert_input(X_same_different, embedding_dim),
    source=convert_input(X_different_same, embedding_dim),
    variable="V1")['output']

True

The function `compute_ii_accuracy` puts these pieces together in the context of a full evaluation on a set of examples:

In [43]:
def compute_ii_accuracy(X_assess, model, variable, output_coord):
    labels = []
    predictions = []
    for base, source in itertools.product(X_assess, repeat=2):
        base = base.unsqueeze(0)
        source = source.unsqueeze(0)
        algorithm_output = compute_interchange_A(
            convert_input(base, embedding_dim), 
            convert_input(source, embedding_dim), 
            variable)
        labels.append(int(algorithm_output["output"]))
        network_output = interchange_intervention(
            model, 
            base,
            source,
            alignment[variable],
            output_coord)
        pred = network_output.argmax(axis=1)
        predictions.append(int(pred))
    return labels, predictions

First, let's assess the hypothesis that **V1** is encoded at our chosen site:

In [44]:
print(classification_report(*compute_ii_accuracy(X_test[:100], model, "V1", output_coord)))

              precision    recall  f1-score   support

           0       0.52      0.61      0.56      5060
           1       0.51      0.42      0.46      4940

    accuracy                           0.51     10000
   macro avg       0.51      0.51      0.51     10000
weighted avg       0.51      0.51      0.51     10000



And then the corresponding assessment for **V2**:

In [45]:
print(classification_report(*compute_ii_accuracy(X_test[:100], model, "V2", output_coord)))

              precision    recall  f1-score   support

           0       0.61      0.47      0.53      5060
           1       0.56      0.69      0.62      4940

    accuracy                           0.58     10000
   macro avg       0.59      0.58      0.58     10000
weighted avg       0.59      0.58      0.58     10000



Interchange intervention accuracy is close to chance for both **V1** (0.51) and **V2** (0.58). Under this alignment, then, we have no evidence that the network computes simple equality relations to solve the hierarchical equality task. The goal of interchange intervention training is to change this. We turn to that method next.

## Interchange Intervention Training (IIT)

Original IIT: [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826)

IIT for model distillation: [Wu\*, Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.02505)

Interchange intervention training is a method for training a neural network to conform to the causal structure of a high-level algorithm. Conceptually, it is a direct extension of the causal abstraction analysis we just performed, except instead of **evaluating** whether the neural network and algorithm produce the same outputs under aligned interchange interventions, we are now **training** the neural network to produce the output of the algorithm under aligned interchange interventions.
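
The idea can be sketched as a training objective with two terms: a standard task loss on the base input, and a second loss that asks the network, under an interchange intervention, to predict the label the algorithm gives under the aligned intervention. The code below is a simplified illustration under stated assumptions and is not the `TorchDeepNeuralClassifierIIT` implementation: the toy network `TinyNet`, the function `iit_loss`, the equal weighting of the two terms, and the random data are all hypothetical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Toy network with one hidden layer that supports an interchange."""

    def __init__(self, in_dim=16, hidden_dim=16, n_classes=2):
        super().__init__()
        self.lower = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.upper = nn.Linear(hidden_dim, n_classes)

    def forward(self, x, source=None, start=0, end=0):
        h = self.lower(x)
        if source is not None:
            h_src = self.lower(source)
            # Replace units start:end of the base run with the source values.
            h = torch.cat([h[:, :start], h_src[:, start:end], h[:, end:]], dim=1)
        return self.upper(h)


def iit_loss(net, x_base, x_source, y_base, y_iit, start, end):
    """Task loss on the base input plus loss on the intervened prediction."""
    loss_fn = nn.CrossEntropyLoss()
    base_logits = net(x_base)
    iit_logits = net(x_base, source=x_source, start=start, end=end)
    return loss_fn(base_logits, y_base) + loss_fn(iit_logits, y_iit)


# Random stand-in data; real labels would come from the task and from
# running the high-level algorithm under the aligned intervention.
net = TinyNet()
x_base, x_source = torch.randn(8, 16), torch.randn(8, 16)
y_base = torch.randint(0, 2, (8,))
y_iit = torch.randint(0, 2, (8,))

loss = iit_loss(net, x_base, x_source, y_base, y_iit, start=0, end=4)
loss.backward()  # Gradients flow through both the base and source runs.
```

In this sketch the gradient of the second term reaches the weights through both the base and the source forward passes, which is what pushes the aligned units to carry the information the algorithm stores in the corresponding variable.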


In [46]:
V1 = 0
V2 = 1
both = 2
data_size = 10000

# Map each intervention id to the coordinates of the units it targets:
id_to_coords = {
    V1: {1: [{"layer": 1, "start": 0, "end": embedding_dim}]},
    V2: {1: [{"layer": 1, "start": embedding_dim, "end": embedding_dim*2}]},
    both: {1: [
        {"layer": 1, "start": 0, "end": embedding_dim},
        {"layer": 1, "start": embedding_dim, "end": embedding_dim*2}]}}

X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions = iit.get_IIT_equality_dataset(
    "V1", embedding_dim, data_size)

iit_model = TorchDeepNeuralClassifierIIT(
    hidden_dim=embedding_dim*4, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=3,
    id_to_coords=id_to_coords)

_ = iit_model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train,interventions)

Stopping after epoch 694. Training loss did not improve more than tol=1e-05. Final error is 0.0011611922091105953.

In [47]:
iit_model.layers

[ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 ActivationLayer(
   (linear): Linear(in_features=16, out_features=16, bias=True)
   (activation): ReLU()
 ),
 Linear(in_features=16, out_features=2, bias=True)]

In [48]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = iit.get_IIT_equality_dataset("V1", embedding_dim,data_size)

IIT_preds, base_preds = iit_model.model(iit_model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = iit.get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = iit_model.model(iit_model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

              precision    recall  f1-score   support

           0       0.53      0.50      0.51      5000
           1       0.53      0.56      0.54      5000

    accuracy                           0.53     10000
   macro avg       0.53      0.53      0.53     10000
weighted avg       0.53      0.53      0.53     10000



Observe that we now have perfect interchange intervention accuracy for **V1**, meaning that under this alignment the neural network computes whether the first two inputs are equal. However, we still have low interchange intervention accuracy for **V2**, meaning that under this alignment the neural network doesn't compute whether the second two inputs are equal.

This is expected, because we only trained the network to compute **V1**.

We can also train the network to compute both **V1** and **V2**:

In [None]:
model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)


v1data = iit.get_IIT_equality_dataset("V1", embedding_dim, data_size)
v2data = iit.get_IIT_equality_dataset("V2", embedding_dim, data_size)
X_base_train = torch.cat([v1data[0],v2data[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][i],v2data[1][i]], dim=0) for i in range(len(v1data[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3]])
interventions = torch.cat([v1data[4],v2data[4]])

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)




Finished epoch 305 of 1000; error is 0.8625252051278949

In [None]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = iit.get_IIT_equality_dataset("V1", embedding_dim,data_size)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = iit.get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))


## Multisource IIT

We can also extend IIT to a setting where a base input has several source inputs. Consider an intervention to the high-level algorithm that fixes both intermediate variables. We can perform an interchange intervention on the neural network where the neurons aligned with the left intermediate variable have one source input and the neurons aligned with the right intermediate variable have a second source input.
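
One consequence of fixing both intermediate variables is worth making explicit: the high-level output then depends only on the two sources, since **V1** comes from the first source and **V2** from the second, and the base input no longer affects the label. The sketch below checks this for symbolic inputs; the function name `multisource_output` is hypothetical and the three shape names are just strings:

```python
import itertools


def multisource_output(base, source_v1, source_v2):
    """High-level output when V1 and V2 are both fixed by interventions."""
    v1 = source_v1[0] == source_v1[1]   # V1 from the first source.
    v2 = source_v2[2] == source_v2[3]   # V2 from the second source.
    # `base` is accepted but unused: both of its intermediate values
    # have been overwritten.
    return v1 == v2


shapes = ["pentagon", "triangle", "square"]
source_v1 = ("square", "square", "triangle", "pentagon")
source_v2 = ("pentagon", "triangle", "square", "square")

outputs = {multisource_output(base, source_v1, source_v2)
           for base in itertools.product(shapes, repeat=4)}
print(outputs)
```

For the network, by contrast, the base input still matters, because all units outside the two intervention sites are computed from it.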

In [None]:

def compute_multisource_interchange_A(base, source, source2):
    # V1 is taken from `source` and V2 from `source2`:
    return compute_A(
        base,
        {"V1": compute_A(source, {})["V1"],
         "V2": compute_A(source2, {})["V2"]})

def multisource_interchange_intervention(model, base, sources, coords, output_coord):
    source_activations = model.retrieve_activations(sources[0], coords[1][0], None)
    source_activations2 = model.retrieve_activations(sources[1], coords[1][1], None)
    # Copy so that the caller's coordinates are not modified:
    coords = copy.deepcopy(coords)
    coords[1][0]["intervention"] = source_activations
    coords[1][1]["intervention"] = source_activations2
    return model.retrieve_activations(base, output_coord, coords)

# ASSUMPTION: `pentagon`, `triangle`, and `square` are not defined elsewhere
# in this notebook. Here they are taken to be random vectors of length
# `embedding_dim`; the sampling range is also an assumption:
pentagon, triangle, square = [
    [random.uniform(-0.5, 0.5) for _ in range(embedding_dim)]
    for _ in range(3)]

def compute_multisource_IIT_accuracy(model, coords):
    labels = []
    predictions = []
    output_coord = {"layer": 3, "start": 0, "end": 2}
    # All 4-tuples of shapes, used as base, first source, and second source:
    examples = list(itertools.product([pentagon, triangle, square], repeat=4))
    for base, source, source2 in itertools.product(examples, repeat=3):
        basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
        sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)], 1)
        sourcetensor2 = torch.cat([torch.tensor([source2[k]]) for k in range(4)], 1)
        algorithm_output = compute_multisource_interchange_A(
            convert_input(basetensor, embedding_dim),
            convert_input(sourcetensor, embedding_dim),
            convert_input(sourcetensor2, embedding_dim))
        # As in `compute_ii_accuracy`, True maps to 1 and False to 0:
        labels.append(int(algorithm_output["output"]))
        network_output = multisource_interchange_intervention(
            model, basetensor, [sourcetensor, sourcetensor2],
            coords, output_coord).argmax(axis=1)
        predictions.append(int(network_output))
    return labels, predictions

In [None]:
v1data = iit.get_IIT_equality_dataset("V1", embedding_dim ,data_size)
v2data = iit.get_IIT_equality_dataset("V2", embedding_dim ,data_size)
bothdata = iit.get_IIT_equality_dataset_both(embedding_dim ,data_size)
X_base_train = torch.cat([v1data[0],v2data[0], bothdata[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][0],v2data[1][0], bothdata[1][i]], dim=0) for i in range(len(bothdata[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2],bothdata[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3], bothdata[3]])
interventions = torch.cat([v1data[4],v2data[4], bothdata[4]])

model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)
