# **DAS**: Uncover structure in a neural network

In this notebook, we run DAS (Distributed Alignment Search) over the internal states of our trained neural network to localize high-level variables (circularity, color, and/or area).

It covers two setups:
1. **Single-source DAS**: localizing several variables in one shared representation
2. **Multi-source DAS**: disentangling variables by localizing each variable in its own representation
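The core operation that DAS trains is an interchange intervention in a rotated basis. The hidden state is rotated with an orthogonal matrix. The first `k` rotated coordinates are copied from a source input's hidden state, and the result is rotated back. The NumPy sketch below assumes a single hidden vector and an already-learned orthogonal matrix `R`. The helper `interchange_intervention` is hypothetical and is not part of `pyvene`. In the notebook, `pv.RotatedSpaceIntervention` learns the rotation, and `subspace_partition` determines which rotated dimensions are involved.

```python
import numpy as np

def interchange_intervention(h_base, h_source, R, k):
    """Swap the first k rotated dimensions of h_base with those of h_source.

    h_base, h_source: hidden vectors of shape (d,)
    R: orthogonal (d, d) matrix; here rotated = R @ h, so its rows are the new basis
    k: number of rotated dimensions aligned with the high-level variable(s)
    """
    z_base = R @ h_base          # rotate base into the learned basis
    z_source = R @ h_source      # rotate source into the same basis
    z_new = z_base.copy()
    z_new[:k] = z_source[:k]     # interchange only the aligned subspace
    return R.T @ z_new           # rotate back (R^T = R^-1 for orthogonal R)

rng = np.random.default_rng(0)
d = 8
R, _ = np.linalg.qr(rng.normal(size=(d, d)))  # random orthogonal matrix as a stand-in
h_b, h_s = rng.normal(size=d), rng.normal(size=d)

print(np.allclose(interchange_intervention(h_b, h_s, R, 0), h_b))  # True: k=0 leaves base unchanged
print(np.allclose(interchange_intervention(h_b, h_s, R, d), h_s))  # True: k=d swaps everything
```

At intermediate `k`, only the part of the hidden state that lies in the aligned subspace changes. DAS searches for a rotation under which that part carries the high-level variable.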

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
import numpy as np
import torch

random.seed(42)
np.random.seed(42)
_ = torch.manual_seed(42)
# _ = torch.cuda.manual_seed(42) # only if using GPU

### 1. Single-source DAS: localize many variables in a single representation

We use DAS to find a representation that mediates the causal effect of one or more high-level variables (circularity, color, and/or area).

In [3]:
# `variables`: indices of the high-level variables to localize; `intervention_size`: number of rotated dimensions assigned to them
variables = [0, 1]
intervention_size = 64
n_train = 10000
n_test = 1000

Create the counterfactual training dataset.
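Conceptually, a single-source counterfactual example sets every selected variable to its value in one source input and leaves the other variables at their base values. The high-level model's output on those values is the target for the intervened network. The sketch below shows only the swap on a vector of high-level variable values. It assumes the ordering (circularity, color, area), which is inferred from the introduction rather than taken from `create_single_source_counterfactual_dataset`. The helper name is hypothetical, and the labeling function that turns these values into `y_counterfactual` is not shown in this notebook, so it is omitted here.

```python
import numpy as np

def single_source_counterfactual(base_vars, source_vars, variables):
    """Copy the selected high-level variables from one source into the base.

    base_vars, source_vars: high-level variable values, assumed ordered as
    (circularity, color, area); variables: indices being intervened on.
    """
    cf_vars = np.array(base_vars, dtype=float)           # np.array makes a copy
    cf_vars[variables] = np.asarray(source_vars, dtype=float)[variables]
    return cf_vars

base = [0.2, 0.9, 0.5]
source = [0.7, 0.1, 0.3]
print(single_source_counterfactual(base, source, [0, 1]))  # [0.7 0.1 0.5]
```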

In [4]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset

# first, create base dataset
images, labels = create_dataset(n_train)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_train
)

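Load the intervenable model with `pyvene`: set up a single rotated-space intervention on the output of the first convolutional layer. The first `intervention_size` (64) dimensions of the learned rotation are aligned with the chosen variables.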
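Because the rotation is learned by gradient descent, it has to stay orthogonal throughout training. One standard way to do this is the Cayley transform of a skew-symmetric matrix. The sketch below substitutes that technique as an illustration, and `pyvene`'s own parametrization may differ. The name `cayley_orthogonal` is hypothetical. The sizes mirror the config above: `hidden_size=448` (28 x 16) and `intervention_size=64`.

```python
import numpy as np

def cayley_orthogonal(params):
    """Map an unconstrained square matrix to an orthogonal one (Cayley transform)."""
    A = params - params.T                  # skew-symmetric: A^T = -A
    I = np.eye(A.shape[0])
    return np.linalg.solve(I + A, I - A)   # Q = (I + A)^-1 (I - A); I + A is always invertible

rng = np.random.default_rng(0)
hidden_size, intervention_size = 448, 64
Q = cayley_orthogonal(0.01 * rng.normal(size=(hidden_size, hidden_size)))
print(np.allclose(Q.T @ Q, np.eye(hidden_size)))  # True: Q is orthogonal

# With rotated = Q @ h, the first intervention_size rows span the aligned subspace,
# matching the [0, intervention_size) block of subspace_partition.
aligned_basis = Q[:intervention_size]
print(aligned_basis.shape)  # (64, 448)
```

Any unconstrained `params` maps to an orthogonal matrix, so an optimizer can update `params` freely without leaving the set of rotations.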

In [5]:
import pyvene as pv
from model_utils import PyTorchCNN
from das_utils import CNNConfig

# load base model
model = PyTorchCNN()
model.load_state_dict(torch.load('pytorch_models/amir_cnn_model.pth'))

model.config = CNNConfig(
    hidden_size=448 # batch x 28 x (28 x 16) -> batch x 28 x 448
)

# create a single rotated-space intervention on the first convolutional layer's output,
# aligning rotated dims [0, intervention_size) with the chosen variables
representations = [{
    "component": "conv1.output",
    "subspace_partition": [[0, intervention_size], [intervention_size, model.config.hidden_size]]
}]

pv_config = pv.IntervenableConfig(
    representations=representations,
    intervention_types=pv.RotatedSpaceIntervention
)
pv_model = pv.IntervenableModel(pv_config, model)

In [6]:
from das_utils import das_train

das_train(pv_model, X_base, X_sources, y_counterfactual, lr=0.0005, num_epochs=3, batch_size=256)

Training (Epoch 1): 100%|██████████| 40/40 [00:16<00:00,  2.40it/s, loss=0.27] 
Training (Epoch 2): 100%|██████████| 40/40 [00:16<00:00,  2.38it/s, loss=0.225]
Training (Epoch 3): 100%|██████████| 40/40 [00:16<00:00,  2.42it/s, loss=0.21] 


Evaluate interchange intervention accuracy on a new evaluation dataset
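Interchange intervention accuracy (IIA) is the fraction of evaluation examples for which the network predicts the counterfactual label when run on the base input with the learned intervention applied. The sketch below computes only the metric. It assumes the intervened logits are already available and that labels are class indices. `das_evaluate` may compute the metric differently, for example by accumulating per batch.

```python
import numpy as np

def interchange_intervention_accuracy(intervened_logits, y_counterfactual):
    """Fraction of examples where the intervened model predicts the counterfactual label."""
    predictions = np.argmax(intervened_logits, axis=-1)
    return float(np.mean(predictions == np.asarray(y_counterfactual)))

logits = np.array([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.1, 0.9]])
y_cf = [0, 1, 1, 1]
print(interchange_intervention_accuracy(logits, y_cf))  # 0.75
```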

In [7]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset
from das_utils import das_evaluate

# first, create base dataset
images, labels = create_dataset(n_test)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_test
)

# evaluate the accuracy of the model on the counterfactual dataset
das_evaluate(pv_model, X_base, X_sources, y_counterfactual)

Evaluating: 100%|██████████| 4/4 [00:00<00:00,  7.01it/s]


0.859000027179718

### 2. Multi-source DAS: localize different variables in different representations

We use DAS to find multiple representations that mediate the causal effects of the high-level variables (circularity, color, and/or area), with each representation dedicated to exactly one variable.
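With several variables, each variable gets its own disjoint block of rotated dimensions. That block is filled from the variable's own source input, and every other dimension keeps its base value. The NumPy sketch below assumes one hidden vector and one shared orthogonal matrix `R`. `multi_source_intervention` is a hypothetical helper, not a `pyvene` function.

```python
import numpy as np

def multi_source_intervention(h_base, h_sources, R, intervention_size):
    """Variable i's block of rotated dims comes from h_sources[i]; the rest stays from h_base.

    Block i covers rotated dims [i * intervention_size, (i + 1) * intervention_size).
    """
    z = R @ h_base
    for i, h_src in enumerate(h_sources):
        block = slice(i * intervention_size, (i + 1) * intervention_size)
        z[block] = (R @ h_src)[block]
    return R.T @ z

rng = np.random.default_rng(0)
d, size = 6, 2
R, _ = np.linalg.qr(rng.normal(size=(d, d)))   # stand-in for a learned rotation
h_b, h_s0, h_s1 = rng.normal(size=(3, d))

z_out = R @ multi_source_intervention(h_b, [h_s0, h_s1], R, size)
print(np.allclose(z_out[0:2], (R @ h_s0)[0:2]))  # True: variable 0's block from source 0
print(np.allclose(z_out[2:4], (R @ h_s1)[2:4]))  # True: variable 1's block from source 1
print(np.allclose(z_out[4:], (R @ h_b)[4:]))     # True: remaining dims from the base
```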

In [8]:
# `variables`: indices of the high-level variables to localize; `intervention_size`: number of rotated dimensions assigned to each variable
variables = [0, 1]
intervention_size = 64
n_train = 10000
n_test = 1000

Create the multi-source counterfactual training dataset.
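In the multi-source setting, each intervened variable conceptually takes its value from a different source input. As in the single-source case, the sketch below shows only the swap on high-level variable values. It assumes the ordering (circularity, color, area), which is inferred from the introduction rather than taken from `create_multi_source_counterfactual_dataset`. The helper name is hypothetical.

```python
import numpy as np

def multi_source_counterfactual(base_vars, source_vars_list, variables):
    """variables[i] takes its value from source_vars_list[i]; other variables keep base values."""
    cf_vars = np.array(base_vars, dtype=float)           # np.array makes a copy
    for var, source_vars in zip(variables, source_vars_list):
        cf_vars[var] = source_vars[var]
    return cf_vars

base = [0.2, 0.9, 0.5]
source_0 = [0.6, 0.4, 0.8]   # supplies variable 0
source_1 = [0.1, 0.3, 0.7]   # supplies variable 1
print(multi_source_counterfactual(base, [source_0, source_1], [0, 1]))  # [0.6 0.3 0.5]
```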

In [9]:
from data_utils import create_dataset
from counterfactual_data_utils import create_multi_source_counterfactual_dataset

# first, create base dataset
images, labels = create_dataset(n_train)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create multi-source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_multi_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_train
)

Load the intervenable model with `pyvene`: set up a separate intervention for each variable, but link them so that they share one rotation matrix and each targets a different, non-overlapping subspace of the rotated space.
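Sharing one rotation is what keeps the two interventions from interfering. Their blocks are disjoint coordinates of the same rotated basis, so applying them in either order gives the same result. The NumPy sketch below illustrates this with a hypothetical helper `swap_block`. It assumes single hidden vectors and a random orthogonal matrix standing in for the learned rotation.

```python
import numpy as np

def swap_block(h_base, h_source, R, start, stop):
    """Replace rotated dims [start, stop) of h_base with those of h_source."""
    z = R @ h_base
    z[start:stop] = (R @ h_source)[start:stop]
    return R.T @ z

rng = np.random.default_rng(1)
d, k = 8, 2
R, _ = np.linalg.qr(rng.normal(size=(d, d)))   # one rotation shared by both interventions
h_b, h_s0, h_s1 = rng.normal(size=(3, d))

# variable 0 uses rotated dims [0, k); variable 1 uses [k, 2k)
order_a = swap_block(swap_block(h_b, h_s0, R, 0, k), h_s1, R, k, 2 * k)
order_b = swap_block(swap_block(h_b, h_s1, R, k, 2 * k), h_s0, R, 0, k)
print(np.allclose(order_a, order_b))  # True: disjoint blocks of one basis do not interfere

# with two independent rotations, the subspaces can overlap, so the order can matter
R2, _ = np.linalg.qr(rng.normal(size=(d, d)))
unlinked_a = swap_block(swap_block(h_b, h_s0, R, 0, k), h_s1, R2, k, 2 * k)
unlinked_b = swap_block(swap_block(h_b, h_s1, R2, k, 2 * k), h_s0, R, 0, k)
print(np.allclose(unlinked_a, unlinked_b))
```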

In [10]:
import pyvene as pv
from model_utils import PyTorchCNN
from das_utils import CNNConfig

# load base model
model = PyTorchCNN()
model.load_state_dict(torch.load('pytorch_models/amir_cnn_model.pth'))

model.config = CNNConfig(
    hidden_size=448 # batch x 28 x (28 x 16) -> batch x 28 x 448
)

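# create two linked interventions on the first convolutional layer's output, one per variable;
# each lists its own block of rotated dims first: [0, intervention_size) and [intervention_size, 2 * intervention_size)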
representations = [
    {
        "component": "conv1.output",
        "subspace_partition": [[0, intervention_size], [intervention_size, intervention_size * 2], [intervention_size * 2, model.config.hidden_size]],
        "intervention_link_key": 0 # link interventions to use the same rotation matrix
    },
    {
        "component": "conv1.output",
        "subspace_partition": [[intervention_size, intervention_size * 2], [0, intervention_size], [intervention_size * 2, model.config.hidden_size]], 
        "intervention_link_key": 0
    }
]

pv_config = pv.IntervenableConfig(
    representations=representations,
    intervention_types=pv.RotatedSpaceIntervention
)
pv_model = pv.IntervenableModel(pv_config, model)

In [11]:
from das_utils import das_train

das_train(pv_model, X_base, X_sources, y_counterfactual, lr=0.0005, num_epochs=2, batch_size=256)

Training (Epoch 1): 100%|██████████| 40/40 [00:31<00:00,  1.27it/s, loss=0.191]
Training (Epoch 2): 100%|██████████| 40/40 [00:32<00:00,  1.23it/s, loss=0.18] 


Evaluate interchange intervention accuracy on a new evaluation dataset

In [12]:
from data_utils import create_dataset
from counterfactual_data_utils import create_single_source_counterfactual_dataset
from das_utils import das_evaluate

# first, create base dataset
images, labels = create_dataset(n_test)
images = images.reshape((-1, 1, 28, 28))
coefficients = np.array([0.4, 0.4, 0.4])

# create single source counterfactual dataset
X_base, X_sources, y_base, y_sources, y_counterfactual = create_single_source_counterfactual_dataset(
    variables, images, labels, coefficients, size=n_test
)

# evaluate the accuracy of the model on the counterfactual dataset
das_evaluate(pv_model, X_base, X_sources, y_counterfactual)

Evaluating: 100%|██████████| 4/4 [00:01<00:00,  3.91it/s]


0.8180000185966492