In [1]:
import torch
import random
import copy
import itertools
import numpy as np
from sklearn.metrics import classification_report
from iit import get_equality_dataset, get_IIT_equality_dataset, get_IIT_equality_dataset_both
from torch_deep_neural_classifier_iit import TorchDeepNeuralClassifierIIT, TorchDeepNeuralClassifier
from torch_rnn_classifier import TorchRNNClassifier
random.seed(42)

In [2]:
torch.__version__

'1.9.0'

# Table of Contents
1. [Hierarchical Equality Dataset](#Hierarchical-Equality-Dataset)
2. [High-Level Tree-Structured Algorithm](#The-High-Level-Tree-Structured-Algorithm)
3. [A Fully-Connected Feed-Forward Neural Network](#A-Fully-Connected-Feed-Forward-Neural-Network)
4. [Causal Abstraction](#Causal-Abstraction)
5. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))

## Hierarchical Equality Dataset  
[Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968)

We will use a hierarchical equality task to present IIT. We define the hierarchical equality task as follows: The input is two pairs of objects and the output is **true** if both pairs contain the same object or if both pairs contain different objects and **false** otherwise. For example, AABB and ABCD are both labeled **true** while ABCC and BBCD are both labeled **false**. 

## The High-Level Tree-Structured Algorithm

Let $\mathcal{A}$ be the simple tree structured algorithm that solves this task by applying a simple equality relation three times: Compute whether the first two inputs are equal, compute whether the second two inputs are equal, then compute whether
the truth-valued outputs of these first two computations are equal. We visually define $\mathcal{A}$ below and then define a python function that computes $\mathcal{A}$, possibly under an intervention that sets $V_1$ and/or $V_2$ to fixed values.

<img src="fig/IIT/PremackFunctions.png" width="500"/>
<img src="fig/IIT/PremackGraph.png" width="500"/>

In [3]:
def compute_A(input, intervention):
    graph = dict()
    for i, object in enumerate(input):
        graph["input" + str(i+1)] = object
    if "V1" in intervention:
        graph["V1"] = intervention["V1"]
    else:
        graph["V1"] = graph["input1"] == graph["input2"]
    if "V2" in intervention:
        graph["V2"] = intervention["V2"]
    else:
        graph["V2"] = graph["input3"] == graph["input4"]
    graph["output"] = graph["V1"] == graph["V2"]
    return graph

### The algorithm with no intervention

First, observe the behavior of the algorithm whhen we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackNoIntervention.png" width="500"/>

In [4]:
compute_A(("pentagon", "pentagon", "triangle", "square"), {})

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': True,
 'V2': False,
 'output': False}

### The algorithm with an intervention

Observe the behavior of the algorithm whhen we provide the input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to **False**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackIntervention.png" width="500"/>

In [5]:
compute_A(("square", "pentagon", "triangle", "triangle"), {"V1":True})

{'input1': 'square',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'triangle',
 'V1': True,
 'V2': True,
 'output': True}

### The algorithm with an interchange intervention

Finaally, observe the behavior of the algorithm when we provide the base input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to be the value it would be for the source input **(pentagon,pentagon, triangle, square)**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/algorithmII.png" width="800"/>

In [6]:
# seems like this interchange is working the other way around? (swapping V1 from right to left)

def compute_interchange_A(base,source, variable):
    return compute_A(base, {variable:compute_A(source, {})[variable]})
    
compute_interchange_A(("pentagon", "pentagon", "triangle", "square"), ("square", "pentagon", "triangle", "triangle"), "V1")

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': False,
 'V2': False,
 'output': True}

# A Fully-Connected Feed-Forward Neural Network

We will train a three layer feed-forward neural network on this task where each object has a random vector assigned to it and the objects in training are disjoint from the objects seen in testing.

<img src="fig/IIT/Network.png" width="800"/>

In [7]:
class InterventionableTorchDeepNeuralClassifier(TorchDeepNeuralClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def make_hook(self, gets, sets, layer):
        # a hook is some lambda function that manipulates the internal state of the model?
        def hook(model, input, output):
            layer_gets, layer_sets = [], []
            if gets is not None and layer in gets:
                layer_gets = gets[layer]
            if sets is not None and layer in sets:
                layer_sets = sets[layer]
            # get commands retrieve values from a computed hidden layer
            for get in layer_gets:
                self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
            # set commands fix values of a hidden layer
            # do these changes propagate?
            for set in layer_sets:
                output[:,set["start"]: set["end"]] = set["intervention"]
        return hook

    def _gets_sets(self,gets=None, sets = None):
        handlers = []
        for layer in range(len(self.layers)):
            hook = self.make_hook(gets,sets, layer)
            # what does it mean to register a forward hook? is a hook similar to a replacement/intervention
            # on a hidden layer?
            both_handler = self.layers[layer].register_forward_hook(hook)
            handlers.append(both_handler)
        return handlers

    # overwrites/fetches hidden layer for neural network?
    def retrieve_activations(self, input, get, sets):
        input = input.type(torch.FloatTensor).to(self.device)
        self.activation = dict()
        handlers = self._gets_sets({get["layer"]:[get]}, sets)
        logits = self.model(input) # what do we use the logits for? are we running our model through all layers?
        for handler in handlers:
            handler.remove()
        return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']

In [8]:

TRUE_LABEL = 1
FALSE_LABEL = 0

data_size = 1024 * 10
embedding_dim = 4
X_train, X_test, y_train, y_test, test_dataset = get_equality_dataset(embedding_dim,10000)

model = InterventionableTorchDeepNeuralClassifier(hidden_dim=4*embedding_dim, hidden_activation=torch.nn.ReLU(), num_layers=3)
_ = model.fit(X_train,y_train)


Stopping after epoch 649. Training loss did not improve more than tol=1e-05. Final error is 0.00046398123049584683.

Observe that this neural network achieves near perfect performance on its test set.

In [9]:
# double-checking: top example seems to be {A, A}, {B, C}? 
# (this can change depending on the order of the cells being run)
print(X_train[0], y_train[0])
preds = model.predict(X_train)
print("Train Results")
print(classification_report(y_train, preds))
preds = model.predict(X_test)

# i see near-perfect performance on the training set, but the performance on the test set doesn't look all that good...
# given the comment above, is this expected?
print("\n\n\nTest Results")
print(classification_report(y_test, preds))

tensor([ 0.4365, -0.0289, -0.2255,  0.3735,  0.4365, -0.0289, -0.2255,  0.3735,
         0.3634,  0.1236,  0.0243, -0.1352,  0.4411, -0.2673, -0.4682, -0.3980],
       dtype=torch.float64) 0
Train Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000




Test Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



### The network with no intervention

First, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We assign each shape a random vector.

In [10]:
pentagon = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
triangle = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
square = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]

# how can I tell the final prediction from the neural activations?

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    # build a get request for which layer and which values we want to see
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    # prints actual hidden values (activations) of the neural network layer
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, None))


Input: [[-0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, -0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, 0.42314000042542865, -0.44822583072577604, 0.3215207354592715, 0.35969762210932066, 0.23637433807549235, 0.4279018218631253, -0.0706330407827378, 0.3186659056305904]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[1.3238e-01, 1.0843e-01, 0.0000e+00, 1.7661e-01, 5.6381e-01, 0.0000e+00,
         4.7403e-01, 4.4808e-03, 3.3821e-01, 2.9191e-02, 2.3711e-05, 0.0000e+00,
         0.0000e+00, 1.0986e-03, 0.0000e+00, 5.7847e-01]],
       grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.6655, 0.3471, 0.5689, 1.0628, 0.6814, 0.0835, 0.4683, 0.0000,
         0.4746, 0.4940, 0.0000, 0.0000, 0.6

### The network with an intervention

Now, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with an intervention that zeros out five neurons after the first hidden layer.

In [11]:
# command to actually change the value of the first layer
# does this change the computations of the following layers? Oh! Is this why we re-compute the model (from start
# to finish) each time we call retreive_activations()?
set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": torch.tensor([[0 for _ in range(embedding_dim)]])}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    # we pass in the set request every call, but this will really only apply to layer 1
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, set_coord))

# where should I see the 0's pop up?
# this intervention didn't change the output of the model (but I imagine this is expected?)
    

Input: [[-0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, -0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, 0.42314000042542865, -0.44822583072577604, 0.3215207354592715, 0.35969762210932066, 0.23637433807549235, 0.4279018218631253, -0.0706330407827378, 0.3186659056305904]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[1.3238e-01, 1.0843e-01, 0.0000e+00, 1.7661e-01, 5.6381e-01, 0.0000e+00,
         4.7403e-01, 4.4808e-03, 3.3821e-01, 2.9191e-02, 2.3711e-05, 0.0000e+00,
         0.0000e+00, 1.0986e-03, 0.0000e+00, 5.7847e-01]],
       grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.6655, 0.3471, 0.5689, 1.0628, 0.6814, 0.0835, 0.4683, 0.0000,
         0.4746, 0.4940, 0.0000, 0.0000, 0.6

### The network with an interchange intervention

Finally, observe the behavior of the network when we provide the input **(square, pentagon, triangle, triangle)** with an intervention that sets five neurons after the first hidden layer to the values they achieve for the source input **(pentagon,pentagon, triangle, square)**.
<img src="fig/IIT/networkII.png" width="500"/>

In [12]:
# this time our set command depends upon a separate input (in this case, {P, P}, {T, S})
intervention = model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), {"layer":1, "start":0, "end":embedding_dim},None)

set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": intervention}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*square,*pentagon,*triangle,*triangle]]), get_coord, set_coord))

# how can I tell whether this changed the final output? (how is the final output, T/F, calculated?)


Input: [[-0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, -0.027239528153731096, -0.1708941925356774, 0.10471252264638642, -0.03486029578210181, 0.42314000042542865, -0.44822583072577604, 0.3215207354592715, 0.35969762210932066, 0.23637433807549235, 0.4279018218631253, -0.0706330407827378, 0.3186659056305904]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.5472, 0.0215, 0.0000, 0.1866, 0.5763, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2921, 0.4541, 0.2680, 0.0000, 0.6818, 0.0050, 0.0000]],
       grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[1.3066, 0.6663, 0.0000, 0.8273, 0.3432, 0.7402, 0.7459, 0.0000, 1.0878,
         0.0000, 0.0000, 0.0000, 0.5465, 0.0489, 0.3966, 0.3415]],
       grad_fn=<SliceBackward>)

Layer 2: Activati

# Causal Abstraction

We defined a **high-level tree structured agorithm** that solves the hierarchical equality task.

We trained a **low-level fully connected neural network** that solves the hierarchical equality task.

A formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree structured algorithm to be a **simplified and faithful description** of the neural network: 

**An algorithm is a causal abstraction of a neural network if and only if for all base and source inputs, the algorithm and network provides the same output under an aligned interchange intervention.**

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention training accuracy** for a high-level variable, which is the percentage of aligned interchange interventions that the network and algorithm produce the same output on. When the IIT accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm where only one high-level variable exists.

<img src="fig/IIT/alignment.png" width="500"/>

We compute the IIT accuracy on our toy domain where each entity is either a pentagon, square, or triangle.

In [13]:
# does our alignment correspond exactly to the picture above? seems to me that the dimensions are slightly different?

# we say that the first layer, which has a size of embedding_dim * 2 (in this case, 8),
# aligns with V1 on the left and V2 on the right? (which, visually, makes sense)
alignment = {"V1": {"layer":1, "start":0, "end":embedding_dim}, "V2": {"layer":1, "start":embedding_dim, "end":embedding_dim*2}}

# performs interchange intervention on a model, given base and source inputs, and the locations
# of where we extract the output and where we want to intervene
def interchange_intervention(model, base, source, int_coord, output_coord):
    # first run model on source input (start to finish), and retreive activation at the coordinate of intervention
    intervention = model.retrieve_activations(source, int_coord[1][0],None)
    # define our "set" query to be at the location of the intervention, and with the value from the base computation
    int_coord[1][0]["intervention"] = intervention
    # run the model (start to finish) on the base input with our intervention from the source input in place
    return model.retrieve_activations(base, output_coord, int_coord)

def convert_input(tensor, embedding_dim):
    return [tuple(tensor[0,embedding_dim*k:embedding_dim*(k+1)].flatten().tolist()) for k in range(4)]

def compute_IIT_accuracy(variable, model):
    labels = []
    predictions = []
    # generate all possible permutations of P, T, S in 4 locations
    # and iterate as pairs
    for base in itertools.product([pentagon, triangle, square], repeat=4):
        for source in itertools.product([pentagon, triangle, square], repeat=4):
            basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
            sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)],1)
            algorithm_output = compute_interchange_A(convert_input(basetensor, embedding_dim), convert_input(sourcetensor, embedding_dim), variable)
            if algorithm_output["output"]:   
                labels.append(TRUE_LABEL)
            else:
                labels.append(FALSE_LABEL)
            # our output comes from both activation values in the 3rd layer (how do we evaluate this as a T/F?)
            output_coord = {"layer":3, "start":0, "end":2}
            network_output = interchange_intervention(model, basetensor, sourcetensor,{1:[copy.deepcopy(alignment[variable])]}, output_coord).argmax(axis=1)
            predictions.append(int(network_output))
    return labels, predictions

Observe that we have low IIT accuracy for both **V1** and **V2**, meaning that under this alignment the neural network does not compute either variable. We have no evidence that this network computes simple equality relations to solve this hierarchical equality task.

In [14]:
# in IIT, do we care both about the original accuracy and the IIT accuracy, or the IIT accuracy alone?
# is it fair to interpret IIT accuracy as the extent to which our model is a causal abstraction of the causal
# graph we using to compute IIT accuracy?

print(classification_report(*compute_IIT_accuracy("V1", model)))
print(classification_report(*compute_IIT_accuracy("V2", model)))

              precision    recall  f1-score   support

           0       0.48      0.29      0.36      2916
           1       0.57      0.75      0.65      3645

    accuracy                           0.55      6561
   macro avg       0.53      0.52      0.50      6561
weighted avg       0.53      0.55      0.52      6561

              precision    recall  f1-score   support

           0       0.53      0.28      0.37      2916
           1       0.58      0.80      0.67      3645

    accuracy                           0.57      6561
   macro avg       0.55      0.54      0.52      6561
weighted avg       0.56      0.57      0.54      6561



# Interchange Intervention Training (IIT)

Original IIT [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826)

IIT for model distillation [ Wu\*,Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.02505)

Interchange intervention training is a method for training a neural network to conform to the causal structure of a high-level algorithm. Conceptually, it is a direct extension of the causal abstraction analysis we just performed, except instead of **evaluating** whether the neural network and algorithm produce the same outputs under aligned interchange interventions, we are now **training** the neural network to produce the output of the algorithm under aligned interchange interventions.


In [15]:
V1 = 0
V2 = 1
both = 2
# similar to our alignment in the IIT accuracy section?
# aligning V1 to left side of layer 1, and V2 to the right side
# we are defining both as a list with two values -- why not encode it as a single range from 0  to dim * 2?
id_to_coords = {V1:{1: [{"layer":1, "start":0, "end":embedding_dim}]}, \
    V2: {1: [{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}, \
    both: {1: [{"layer":1, "start":0, "end":embedding_dim},{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}}

# gives back an IIT dataset based off of the Premack dataset, coming up with 
# all possible permutations of same/different shape pairs and same/different base-source pairs?
X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions = get_IIT_equality_dataset("V1", embedding_dim ,data_size)

# this is a different model from the one we defined in the previous cell, but with a similar idea?
model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)
# model.fit() function internally calls on model.create_dataset(), which creates dataset in a way that pairs off
# source and base inputs?
_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)

# this is a runtime error I've also encountered in antra (with no change to the original code)
# could this be due to mismatching pytorch versions??

Stopping after epoch 652. Training loss did not improve more than tol=1e-05. Final error is 0.0011470201934571378.

In [25]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,data_size)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds1 = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds1))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds2 = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       0.59      0.74      0.66      5120
           1       0.65      0.49      0.56      5120

    accuracy                           0.62     10240
   macro avg       0.62      0.62      0.61     10240
weighted avg       0.62      0.62      0.61     10240



Observe that we now have perfect IIT accuracy **V1** meaning that under this alignment the neural network computes whether the first pair of inputs are equal. However, we still have low IIT accuracy for **V2**, meaning that under this alignment the neural network doesn't compute whether the second pair of inputs are equal.

This is expected, because we only trained the network to compute **V1**.

We can train the network to compute both **V1** and **V2**.

In [26]:
model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)


v1data = get_IIT_equality_dataset("V1", embedding_dim, data_size)
v2data = get_IIT_equality_dataset("V2", embedding_dim, data_size)
X_base_train = torch.cat([v1data[0],v2data[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][i],v2data[1][i]], dim=0) for i in range(len(v1data[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3]])
interventions = torch.cat([v1data[4],v2data[4]])

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)

Stopping after epoch 532. Training loss did not improve more than tol=1e-05. Final error is 0.03824407953652553.

In [27]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,data_size)

# training it once for causally abstracting V1
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))

# training it again for causally abstracting V2
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))

# seems that training one after the other doesn't "undo" previous learning - will this always be the case?
# does order matter, or is there anything we should watch out for here?


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240



In [28]:
print(classification_report(*compute_IIT_accuracy("V1", model.model)))
print(classification_report(*compute_IIT_accuracy("V2", model.model)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561



# Multisource IIT

We can also extend IIT to a setting where a base input has several source inputs. Consider an intervention to the high-level algorithm that fixes both intermediate variables. We can perform an interchange intervention on the neural network where the neurons aligned with the left intermediate variable have one source input and the neurons aligned with the right intermediate variable have a second source input.

In [19]:
# this time, intervening on V1 and V2 at the same time? but from two different sources (so we can still generate all
# possible pairs)
def compute_multisource_interchange_A(base,source,source2):
    return compute_A(base, {"V1":compute_A(source, {})["V1"], "V2":compute_A(source2, {})["V2"]})

def multisource_interchange_intervention(model, base, sources, coords, output_coord):
    source_activations = model.retrieve_activations(sources[0], coords[1][0],None)
    source_activations2 = model.retrieve_activations(sources[1], coords[1][1],None)
    coords = copy.deepcopy(coords)
    coords[1][0]["intervention"] = source_activations
    coords[1][1]["intervention"] = source_activations2
    return model.retrieve_activations(base, output_coord, coords)

def compute_multisource_IIT_accuracy(model, coords):
    labels = []
    predictions = []
    # iterate through all possible bases, and then through all possible V1 and V2 intervention values
    # seems a bit inefficient -- how does this approach compare to many single sources? where might this approach be better?
    for base in itertools.product([pentagon, triangle, square], repeat=4):
        for source in itertools.product([pentagon, triangle, square], repeat=4):
            for source2 in itertools.product([pentagon, triangle, square], repeat=4):
                basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
                sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)],1)
                sourcetensor2 = torch.cat([torch.tensor([source2[k]]) for k in range(4)],1)
                algorithm_output = compute_multisource_interchange_A(convert_input(basetensor, embedding_dim), convert_input(sourcetensor, embedding_dim),convert_input(sourcetensor2, embedding_dim))
                if algorithm_output["output"]:   
                    labels.append(TRUE_LABEL)
                else:
                    labels.append(FALSE_LABEL)
                get_coord = {"layer":3, "start":0, "end":2}
                network_output = multisource_interchange_intervention(model, basetensor, [sourcetensor,sourcetensor2], coords, get_coord).argmax(axis=1)
                predictions.append(int(network_output))
    return labels, predictions

In [20]:
# seems that model already does well for multisource IIT accuracy - will this generally be the case?
# how would you recommend choosing which nodes from the graph to train our model against?
sets = {1: [{"layer":1, "start":0, "end":embedding_dim},{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}
print(classification_report(*compute_multisource_IIT_accuracy(model.model, sets)))

              precision    recall  f1-score   support

           0       0.99      1.00      0.99    236196
           1       1.00      0.99      1.00    295245

    accuracy                           0.99    531441
   macro avg       0.99      1.00      0.99    531441
weighted avg       0.99      0.99      0.99    531441



In [21]:
v1data = get_IIT_equality_dataset("V1", embedding_dim ,data_size)
v2data = get_IIT_equality_dataset("V2", embedding_dim ,data_size)
bothdata = get_IIT_equality_dataset_both(embedding_dim ,data_size)
# are we training the model to become a causal abstraction for V1, V2, and both all at once?
# could we have done the same when considering V1 and V2?
X_base_train = torch.cat([v1data[0],v2data[0], bothdata[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][0],v2data[1][0], bothdata[1][i]], dim=0) for i in range(len(bothdata[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2],bothdata[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3], bothdata[3]])
interventions = torch.cat([v1data[4],v2data[4], bothdata[4]])

model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)


Stopping after epoch 610. Training loss did not improve more than tol=1e-05. Final error is 0.14826767449267209.

In [22]:
print(classification_report(*compute_multisource_IIT_accuracy(model.model, sets)))
print(classification_report(*compute_IIT_accuracy("V1", model.model)))
print(classification_report(*compute_IIT_accuracy("V2", model.model)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    236196
           1       1.00      1.00      1.00    295245

    accuracy                           1.00    531441
   macro avg       1.00      1.00      1.00    531441
weighted avg       1.00      1.00      1.00    531441

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561

