# **IIT**: Induce structure in a neural network

In this notebook, we set up IIT (Interchange Intervention Training) over the internal states of our trained neural network in order to induce the model to localize high level variables (circularity, color, and/or area).

It goes through:
1. **Single source IIT**: localizing many variables in one representation
2. **Multi-source IIT**: disentangling variables by localizing different variables in different representations

In [1]:
# !pip install -r requirements.txt

In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
import random
import numpy as np
import torch

random.seed(0)
np.random.seed(0)
_ = torch.manual_seed(0)
_ = torch.cuda.manual_seed(0) # only if using GPU

## 1. Single source IIT: localize many variables in a single representation

We use IIT to train a neural network to mediate the causal effect of one or more high level variables (circularity, color, and/or area) in a single representation.

In [36]:
# toggle variables to select which variables to localize & toggle intervention_size to set the # of neurons assigned to the variables
variables = [0, 1]
intervention_size = 64
n_train = 10000
n_test = 1000

Create counterfactual dataset

In [37]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset

# first, create base dataset
images, labels = create_dataset(n_train)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_train
)

Load intervenable model with `pyvene`: set up a single intervention over the 1st 64 neurons of the output of the 1st convolutional layer

In [38]:
import pyvene as pv
from model_utils import PyTorchCNN
from das_utils import CNNConfig, CustomLowRankRotatedSpaceIntervention

# load base model
model = PyTorchCNN()
model.load_state_dict(torch.load('pytorch_models/amir_cnn_model.pth'))

model.config = CNNConfig(
    hidden_size=28*28*16 # batch x 28 x 28 x 16
)

intervention_size = 1

# create a single intervention on the first 64 neurons of the first convolutional layer
representations = [{
    "component": "conv1.output",
    "low_rank_dimension": intervention_size
}]

pv_config = pv.IntervenableConfig(
    representations=representations,
    intervention_types=CustomLowRankRotatedSpaceIntervention
)
pv_model = pv.IntervenableModel(pv_config, model)
pv_model.set_device('cuda')

In [39]:
from das_utils import iit_train

iit_train(pv_model, X_base.to('cuda'), X_sources.to('cuda'), y_counterfactual.to('cuda'), lr=0.0005, num_epochs=1, batch_size=256, subspaces=None)

Training (Epoch 1): 100%|██████████| 40/40 [01:12<00:00,  1.81s/it, loss=0.168]


Evaluate interchange intervention accuracy on a new evaluation dataset

In [40]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset
from das_utils import iit_evaluate

# first, create base dataset
images, labels = create_dataset(n_test)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_test
)

# evaluate the accuracy of the model on the counterfactual dataset
iit_evaluate(pv_model, X_base.to('cuda'), X_sources.to('cuda'), y_counterfactual.to('cuda'), subspaces=None)

Evaluating: 100%|██████████| 4/4 [00:00<00:00, 238.53it/s]


0.8260000348091125

In [41]:
from model_utils import evaluate
from data_utils import create_dataset

# create our own evaluation dataset
images, labels = create_dataset(n_test)
coefficients = np.array([0.4, 0.4, 0.4])

# batch, channel, height, width
X = torch.tensor(images.reshape((-1, 1, 28, 28))).float()
y = torch.tensor(np.matmul(labels, coefficients) > 0.6).float()

# evaluate model
accuracy = evaluate(pv_model.model, X.to('cuda'), y.to('cuda'))
print(f'Accuracy: {accuracy:.4f}')

Evaluating: 100%|██████████| 4/4 [00:00<?, ?it/s]

Accuracy: 0.8330





## 2. Multi-source IIT: localize different variables in different representations

We use IIT to update multiple representations to mediate the causal effect of one or more high level variables (circularity, color, and/or area), where each separate representation corresponds to a separate variable.

In [21]:
# toggle variables to select which variables to localize & toggle intervention_size to set the # of neurons assigned to each variable
variables = [0, 1]
intervention_size = 64
n_train = 10000
n_test = 1000

Create counterfactual data

In [22]:
from data_utils import create_dataset
from counterfactual_data_utils import create_multi_source_counterfactual_dataset

# first, create base dataset
images, labels = create_dataset(n_train)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create multi-source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_multi_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_train
)

Load intervenable model with `pyvene`: set up a separate intervention for each variable, but link them to use the same rotation matrix (so they can index different subspaces of the rotated neurons).

In [23]:
import pyvene as pv
from model_utils import PyTorchCNN
from das_utils import CNNConfig

# load base model
model = PyTorchCNN()
model.load_state_dict(torch.load('pytorch_models/amir_cnn_model.pth'))

model.config = CNNConfig(
    hidden_size=28 # batch x 28 x 28 x 16
)

intervention_size = 2

# create a single intervention on the first 64 neurons of the first convolutional layer
representations = [
    {
        "component": "conv1.output",
        "subspace_partition": [[0, intervention_size], [intervention_size, intervention_size * 2], [intervention_size * 2, model.config.hidden_size]],
        "intervention_link_key": 0 # link interventions to use the same rotation matrix
    },
    {
        "component": "conv1.output",
        "subspace_partition": [[intervention_size, intervention_size * 2], [0, intervention_size], [intervention_size * 2, model.config.hidden_size]], 
        "intervention_link_key": 0
    }
]

pv_config = pv.IntervenableConfig(
    representations=representations,
    intervention_types=pv.VanillaIntervention
)
pv_model = pv.IntervenableModel(pv_config, model)

In [24]:
from das_utils import iit_train

iit_train(pv_model, X_base, X_sources, y_counterfactual, lr=0.0005, num_epochs=5, batch_size=256)

Training (Epoch 1): 100%|██████████| 40/40 [00:04<00:00,  8.11it/s, loss=0.652]
Training (Epoch 2): 100%|██████████| 40/40 [00:04<00:00,  8.14it/s, loss=0.62] 
Training (Epoch 3): 100%|██████████| 40/40 [00:04<00:00,  8.08it/s, loss=0.617]
Training (Epoch 4): 100%|██████████| 40/40 [00:04<00:00,  8.17it/s, loss=0.627]
Training (Epoch 5): 100%|██████████| 40/40 [00:04<00:00,  8.27it/s, loss=0.623]


Evaluate interchange intervention accuracy on a new evaluation dataset

In [29]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset
from das_utils import iit_evaluate

# first, create base dataset
images, labels = create_dataset(n_test)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_test
)

# evaluate the accuracy of the model on the counterfactual dataset
iit_evaluate(pv_model, X_base, X_sources, y_counterfactual)

Evaluating: 100%|██████████| 4/4 [00:00<00:00, 23.73it/s]


0.5920000076293945