In [1]:
import torch
import random
import copy
import itertools
import numpy as np
import utils
from trainer import LIMTrainer

from sklearn.metrics import classification_report
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
utils.fix_random_seeds()
debug = True

def print_weights(LIM, n, verbose=False):
    for i, layer in enumerate(LIM.labeled_layers):
        print(f"\n\n\n\n -----      Layer {i}      -----")
        for name, param in layer["model"].named_parameters():
            print(name,param.data.shape, param.requires_grad)
            if verbose:
                print(param.data)
        if "disentangle" in layer:
            w = layer["disentangle"].weight
            print(w.requires_grad)
            if verbose:
                print(w.data)
            print(torch.dist(w.T @ w.data, torch.eye(16).to("cuda")))

# Training a Feed-Forward Neural Network with no IIT

In [3]:
embedding_dim = 4
V1 = 0
V2 = 1
both = 2
control = 3

id_to_coords = {
    V1: [{"layer": 1, "start": 0, "end": 2*embedding_dim}],
    V2: [{"layer": 1, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    both:[{"layer": 1, "start": 0, "end": 2*embedding_dim},{"layer": 0, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    control:[{"layer": 0, "start": 0, "end": 2*embedding_dim}]
}

In [4]:
if debug:
    data_size = 64
else:
    data_size = 640000

data_size = 640

iit_equality_dataset = \
    dataset_equality.get_IIT_equality_dataset_all(embedding_dim, 
                                                    data_size)

X_base_train, y_base_train = iit_equality_dataset[0:2]
iit_data = tuple(iit_equality_dataset[2:])

In [None]:
LIM = LIMDeepNeuralClassifier(
    hidden_dim=embedding_dim*4, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=2,
    input_dim=embedding_dim*4,
    n_classes=2,
    device="cuda:0"
)

In [8]:
LIM.device

device(type='cuda', index=0)

In [18]:
LIM_trainer = LIMTrainer(
    LIM,
    warm_start=True,
    max_iter=10,
    batch_size=64,
    n_iter_no_change=10000,
    shuffle_train=False,
    eta=0.001)

In [19]:
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.0292e-06, device='cuda:0')




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.3388e-06, device='cuda:0')




 -----      Layer 2      -----
weight torch.Size([2, 16]) True
bias torch.Size([2]) True


In [31]:
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=None,
    intervention_ids_to_coords=id_to_coords)

Finished epoch 10 of 10; error is 31.468547663436084

In [32]:
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    959839
           1       1.00      1.00      1.00    960161

    accuracy                           1.00   1920000
   macro avg       1.00      1.00      1.00   1920000
weighted avg       1.00      1.00      1.00   1920000



In [33]:
IIT_preds = LIM_trainer.iit_predict(X_base_train, iit_data[0], iit_data[2] , id_to_coords, device="cpu")
print(classification_report(iit_data[1], IIT_preds))

              precision    recall  f1-score   support

           0       0.54      0.41      0.46    960316
           1       0.52      0.65      0.58    959684

    accuracy                           0.53   1920000
   macro avg       0.53      0.53      0.52   1920000
weighted avg       0.53      0.53      0.52   1920000



In [34]:
datasetIIT = dataset_equality.get_IIT_equality_dataset_both(embedding_dim, 1000)

base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = datasetIIT

base_preds = LIM_trainer.predict(base_test, device="cpu")

IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")

In [35]:
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       518
           1       1.00      1.00      1.00       482

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [36]:
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.49      0.41      0.44       496
           1       0.50      0.58      0.53       504

    accuracy                           0.49      1000
   macro avg       0.49      0.49      0.49      1000
weighted avg       0.49      0.49      0.49      1000



# Set Analysis Mode to True and Verify Nothing Changes

In [37]:
LIM_trainer.model.set_analysis_mode(True)
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.2184e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.3366e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [38]:
base_preds = LIM_trainer.predict(base_test, device="cpu")

IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       518
           1       1.00      1.00      1.00       482

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [39]:
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.50      0.40      0.44       496
           1       0.51      0.61      0.55       504

    accuracy                           0.50      1000
   macro avg       0.50      0.50      0.50      1000
weighted avg       0.50      0.50      0.50      1000



In [40]:
PATH = "basemodel"
print(LIM_trainer.model.state_dict().keys())
torch.save(LIM_trainer.model.state_dict(), PATH)
LIM_trainer.model.load_state_dict(torch.load(PATH))


odict_keys(['model_layers.0.linear.weight', 'model_layers.0.linear.bias', 'model_layers.1.linear.weight', 'model_layers.1.linear.bias', 'model_layers.2.weight', 'model_layers.2.bias', 'analysis_model.0.linear.weight', 'analysis_model.0.linear.bias', 'analysis_model.1.parametrizations.weight.original', 'analysis_model.1.parametrizations.weight.0.base', 'analysis_model.2.lin_layer.parametrizations.weight.original', 'analysis_model.2.lin_layer.parametrizations.weight.0.base', 'analysis_model.3.linear.weight', 'analysis_model.3.linear.bias', 'analysis_model.4.parametrizations.weight.original', 'analysis_model.4.parametrizations.weight.0.base', 'analysis_model.5.lin_layer.parametrizations.weight.original', 'analysis_model.5.lin_layer.parametrizations.weight.0.base', 'analysis_model.6.weight', 'analysis_model.6.bias', 'normal_model.0.linear.weight', 'normal_model.0.linear.bias', 'normal_model.1.linear.weight', 'normal_model.1.linear.bias', 'normal_model.2.weight', 'normal_model.2.bias'])


<All keys matched successfully>

# Train a Basis Agnostic Alignment

In [41]:
id_to_coords = {
    V1: [{"layer": 1, "start": 0, "end": 2*embedding_dim}],
    V2: [{"layer": 1, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    both:[{"layer": 1, "start": 0, "end": 2*embedding_dim},{"layer": 0, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    control:[{"layer": 1, "start": 0, "end": 2*embedding_dim}]
}
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=iit_data,
    intervention_ids_to_coords=id_to_coords)

Finished epoch 10 of 10; error is 268353.20465135574

In [42]:
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(2.4018e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(2.9071e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [43]:
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    959839
           1       1.00      1.00      1.00    960161

    accuracy                           1.00   1920000
   macro avg       1.00      1.00      1.00   1920000
weighted avg       1.00      1.00      1.00   1920000



In [44]:
IIT_preds = LIM_trainer.iit_predict(X_base_train, iit_data[0], iit_data[2] , id_to_coords, device="cpu")
print(classification_report(iit_data[1], IIT_preds))


              precision    recall  f1-score   support

           0       0.49      0.40      0.44    960316
           1       0.50      0.59      0.54    959684

    accuracy                           0.49   1920000
   macro avg       0.49      0.49      0.49   1920000
weighted avg       0.49      0.49      0.49   1920000



In [45]:
base_preds = LIM_trainer.predict(base_test, device="cpu")

IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       518
           1       1.00      1.00      1.00       482

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [46]:
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.49      0.42      0.45       496
           1       0.50      0.58      0.54       504

    accuracy                           0.50      1000
   macro avg       0.50      0.50      0.50      1000
weighted avg       0.50      0.50      0.50      1000



# Systematic Evaluation with Control Experiments


In [24]:


def enum_keys():
    for left in [(), 0,1, (0,1)]:
        for right in [(), 0,1, (0,1)]:
            if (left,right) not in [((),()), 
                                    ((),(0,1)), 
                                    ((0,1),())]:
                yield {"left":left,
                       "right":right}


if not debug:
    data_size = 1280000
    for k in [0,1]:
        for key in enum_keys():
            
            print(f"\n\n\n\nControl Experiment \n Layer: {k}\n Key:{key}\n\n")
            train_datasetIIT = \
                dataset_equality.get_IIT_equality_dataset_control(key, 
                                                            embedding_dim,   
                                                            data_size)
            X_base_train, y_base_train = train_datasetIIT[0:2]
            iit_data = tuple(train_datasetIIT[2:])
            test_datasetIIT = \
                dataset_equality.get_IIT_equality_dataset_control(key, 
                                                            embedding_dim, 
                                                            1000)
            base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = test_datasetIIT


            if key["left"] in [0,1] and key["right"] in [0,1]:
                key_size = 2
            elif key["left"] == (0,1) and key["right"] == (0,1):
                key_size = 4
            elif key["left"] == () or key["right"] == ():
                key_size = 1
            else:
                key_size = 3
            id_to_coords = {
                control:[{"layer": k, "start": 0, "end": key_size*embedding_dim}]
            }
            
            LIM_trainer.model.load_state_dict(torch.load(PATH))

            _ = LIM_trainer.fit(
                X_base_train, 
                y_base_train, 
                iit_data=iit_data,
                intervention_ids_to_coords=id_to_coords)
            base_preds = LIM_trainer.predict(X_base_train, device="cpu")
            print(classification_report(y_base_train, base_preds))
            IIT_preds = LIM_trainer.iit_predict(X_base_train, iit_data[0], iit_data[2] , id_to_coords, device="cpu")
            print(classification_report(iit_data[1], IIT_preds))
            base_preds = LIM_trainer.predict(base_test, device="cpu")
            IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")
            print(classification_report(y_base_test, base_preds))
            print(classification_report(y_IIT_test, IIT_preds))

            





Control Experiment 
 Layer: 0
 Key:{'left': (), 'right': 0}




Finished epoch 10 of 10; error is 220117.40471696854

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639766
           1       1.00      1.00      1.00    640234

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.49      0.48      0.48    639869
           1       0.49      0.50      0.49    640131

    accuracy                           0.49   1280000
   macro avg       0.49      0.49      0.49   1280000
weighted avg       0.49      0.49      0.49   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       497
           1       1.00      1.00      1.00       503

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 220116.79436731339

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639515
           1       1.00      1.00      1.00    640485

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.50      0.45      0.47    639862
           1       0.50      0.54      0.52    640138

    accuracy                           0.50   1280000
   macro avg       0.50      0.50      0.50   1280000
weighted avg       0.50      0.50      0.50   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       510
           1       1.00      1.00      1.00       490

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 207670.32550239563

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640243
           1       1.00      1.00      1.00    639757

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.49      0.52      0.51    638913
           1       0.50      0.47      0.48    641087

    accuracy                           0.49   1280000
   macro avg       0.49      0.49      0.49   1280000
weighted avg       0.49      0.49      0.49   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       516
           1       1.00      1.00      1.00       484

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 192920.76829314232

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640460
           1       1.00      1.00      1.00    639540

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.51      0.47      0.49    640507
           1       0.51      0.54      0.53    639493

    accuracy                           0.51   1280000
   macro avg       0.51      0.51      0.51   1280000
weighted avg       0.51      0.51      0.51   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       483
           1       1.00      1.00      1.00       517

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 192843.94300985336

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640286
           1       1.00      1.00      1.00    639714

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.52      0.50      0.51    641138
           1       0.51      0.53      0.52    638862

    accuracy                           0.52   1280000
   macro avg       0.52      0.52      0.52   1280000
weighted avg       0.52      0.52      0.52   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       508
           1       1.00      1.00      1.00       492

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 214561.59795832634

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    641004
           1       1.00      1.00      1.00    638996

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.48      0.47      0.47    639762
           1       0.48      0.48      0.48    640238

    accuracy                           0.48   1280000
   macro avg       0.48      0.48      0.48   1280000
weighted avg       0.48      0.48      0.48   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       505
           1       1.00      1.00      1.00       495

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 206817.31916427612

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640407
           1       1.00      1.00      1.00    639593

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.49      0.49      0.49    639149
           1       0.49      0.49      0.49    640851

    accuracy                           0.49   1280000
   macro avg       0.49      0.49      0.49   1280000
weighted avg       0.49      0.49      0.49   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       517
           1       1.00      1.00      1.00       483

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 192324.12076485157

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639468
           1       1.00      1.00      1.00    640532

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.51      0.50      0.51    639018
           1       0.51      0.53      0.52    640982

    accuracy                           0.51   1280000
   macro avg       0.51      0.51      0.51   1280000
weighted avg       0.51      0.51      0.51   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       489
           1       1.00      1.00      1.00       511

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 192921.82252907753

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639960
           1       1.00      1.00      1.00    640040

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.52      0.49      0.50    640266
           1       0.51      0.54      0.53    639734

    accuracy                           0.52   1280000
   macro avg       0.52      0.52      0.51   1280000
weighted avg       0.52      0.52      0.51   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       484
           1       1.00      1.00      1.00       516

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 214622.39066910744

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639213
           1       1.00      1.00      1.00    640787

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.47      0.47      0.47    639547
           1       0.48      0.48      0.48    640453

    accuracy                           0.48   1280000
   macro avg       0.48      0.48      0.48   1280000
weighted avg       0.48      0.48      0.48   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       535
           1       1.00      1.00      1.00       465

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 206268.4818727978

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640070
           1       1.00      1.00      1.00    639930

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.48      0.47      0.48    638957
           1       0.48      0.50      0.49    641043

    accuracy                           0.48   1280000
   macro avg       0.48      0.48      0.48   1280000
weighted avg       0.48      0.48      0.48   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       489
           1       1.00      1.00      1.00       511

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 217133.8563835621

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639670
           1       1.00      1.00      1.00    640330

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.47      0.45      0.46    640052
           1       0.48      0.50      0.49    639948

    accuracy                           0.47   1280000
   macro avg       0.47      0.47      0.47   1280000
weighted avg       0.47      0.47      0.47   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       510
           1       1.00      1.00      1.00       490

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 31.213749848928273

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640592
           1       1.00      1.00      1.00    639408

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    641588
           1       1.00      1.00      1.00    638412

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       497
           1       1.00      1.00      1.00       503

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 195597.37672793865

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640024
           1       1.00      1.00      1.00    639976

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.51      0.40      0.45    640546
           1       0.50      0.61      0.55    639454

    accuracy                           0.50   1280000
   macro avg       0.50      0.50      0.50   1280000
weighted avg       0.50      0.50      0.50   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       501
           1       1.00      1.00      1.00       499

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 195872.3507528305

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639415
           1       1.00      1.00      1.00    640585

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.51      0.40      0.45    639974
           1       0.51      0.62      0.56    640026

    accuracy                           0.51   1280000
   macro avg       0.51      0.51      0.50   1280000
weighted avg       0.51      0.51      0.50   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       499
           1       1.00      1.00      1.00       501

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 195723.79802203178

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640540
           1       1.00      1.00      1.00    639460

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.52      0.37      0.43    639718
           1       0.51      0.65      0.57    640282

    accuracy                           0.51   1280000
   macro avg       0.51      0.51      0.50   1280000
weighted avg       0.51      0.51      0.50   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       519
           1       1.00      1.00      1.00       481

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 175666.71449959278

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640235
           1       1.00      1.00      1.00    639765

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.58      0.52      0.55    640232
           1       0.57      0.63      0.60    639768

    accuracy                           0.57   1280000
   macro avg       0.58      0.58      0.57   1280000
weighted avg       0.58      0.57      0.57   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 175798.73180627823

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640789
           1       1.00      1.00      1.00    639211

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.59      0.50      0.54    640437
           1       0.57      0.66      0.61    639563

    accuracy                           0.58   1280000
   macro avg       0.58      0.58      0.57   1280000
weighted avg       0.58      0.58      0.57   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       487
           1       1.00      1.00      1.00       513

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 222602.39616918564

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639300
           1       1.00      1.00      1.00    640700

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.44      0.38      0.41    640709
           1       0.45      0.51      0.48    639291

    accuracy                           0.44   1280000
   macro avg       0.44      0.44      0.44   1280000
weighted avg       0.44      0.44      0.44   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       482
           1       1.00      1.00      1.00       518

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 195894.61722552776

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639367
           1       1.00      1.00      1.00    640633

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.51      0.40      0.45    639965
           1       0.51      0.62      0.56    640035

    accuracy                           0.51   1280000
   macro avg       0.51      0.51      0.50   1280000
weighted avg       0.51      0.51      0.50   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 175971.83953773975

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639937
           1       1.00      1.00      1.00    640063

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.61      0.46      0.52    640321
           1       0.57      0.70      0.63    639679

    accuracy                           0.58   1280000
   macro avg       0.59      0.58      0.58   1280000
weighted avg       0.59      0.58      0.58   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       542
           1       1.00      1.00      1.00       458

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 175637.39994168282

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640060
           1       1.00      1.00      1.00    639940

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.61      0.45      0.52    639712
           1       0.57      0.71      0.63    640288

    accuracy                           0.58   1280000
   macro avg       0.59      0.58      0.58   1280000
weighted avg       0.59      0.58      0.58   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       508
           1       1.00      1.00      1.00       492

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 222719.59947657585

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640981
           1       1.00      1.00      1.00    639019

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.43      0.39      0.41    639974
           1       0.44      0.48      0.46    640026

    accuracy                           0.44   1280000
   macro avg       0.44      0.44      0.44   1280000
weighted avg       0.44      0.44      0.44   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       478
           1       1.00      1.00      1.00       522

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 222120.2862625122

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639833
           1       1.00      1.00      1.00    640167

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.44      0.39      0.41    640353
           1       0.45      0.49      0.47    639647

    accuracy                           0.44   1280000
   macro avg       0.44      0.44      0.44   1280000
weighted avg       0.44      0.44      0.44   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       511
           1       1.00      1.00      1.00       489

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 223148.11044883728

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639345
           1       1.00      1.00      1.00    640655

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       0.42      0.38      0.40    639261
           1       0.44      0.48      0.46    640739

    accuracy                           0.43   1280000
   macro avg       0.43      0.43      0.43   1280000
weighted avg       0.43      0.43      0.43   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       484
           1       1.00      1.00      1.00       516

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

Finished epoch 10 of 10; error is 31.056622040126967

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    638799
           1       1.00      1.00      1.00    641201

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    639823
           1       1.00      1.00      1.00    640177

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       509
           1       1.00      1.00      1.00       491

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000

              preci

# Training Bert with No IIT

In [1]:
from LIM_bert import LIMBERTClassifier
from trainer import BERTLIMTrainer
from transformers import BertModel, BertTokenizer

import torch
import random
import dataset_equality
from sklearn.metrics import classification_report
debug = True

In [2]:
weights_name = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)
n_classes = 2

In [3]:
vocab = bert_tokenizer.get_vocab()

tokens = list(vocab.keys())

random.shuffle(tokens)

cutoff = int(0.9 * len(tokens))



train_ids = bert_tokenizer.convert_tokens_to_ids(tokens[:cutoff])

test_ids = bert_tokenizer.convert_tokens_to_ids(tokens[cutoff:])



In [4]:
if debug:
    train_size = 12
else:
    train_size = 20000
    
if debug:
    test_size = 12
else:
    test_size = 5000


iit_train_dataset = dataset_equality.get_IIT_equality_dataset_all(768, 
                                                             train_size,
                                                             token_ids=train_ids)

iit_test_dataset = dataset_equality.get_IIT_equality_dataset_all(768, 
                                                             test_size, 
                                                             token_ids=test_ids)
X_base_train = (iit_train_dataset[0], iit_train_dataset[1])
y_base_train = iit_train_dataset[2]
X_sources_train = (iit_train_dataset[3], iit_train_dataset[4])
y_IIT_train = iit_train_dataset[5]
intervention_ids_train = iit_train_dataset[6]

X_base_test = (iit_test_dataset[0], iit_test_dataset[1])
y_base_test  = iit_test_dataset[2]
X_sources_test  = (iit_test_dataset[3], iit_test_dataset[4])
y_IIT_test  = iit_test_dataset[5]
intervention_ids_test  = iit_test_dataset[6]

In [5]:
bert = BertModel.from_pretrained(weights_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
max_length = 4
embeddin_dim = 768
LIM = LIMBERTClassifier(n_classes,
                        bert,
                        max_length=max_length,
                        debug=debug,
                        use_wrapper=True)

In [7]:
LIM_trainer = BERTLIMTrainer(
    LIM,
    warm_start=True,
    max_iter=3,
    batch_size=16,
    n_iter_no_change=10000,
    shuffle_train=True,
    eta=0.00001)


In [8]:
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=(X_sources_train,
             y_IIT_train,
             intervention_ids_train),
    intervention_ids_to_coords=None)

TypeError: 'NoneType' object is not subscriptable

In [None]:
for input in X_base_train[0]:
    input.cpu()
for mask in X_base_train[1]:
    mask.cpu()

preds = LIM_trainer.predict(X_base_train,device="cpu")
print(classification_report(y_base_train, preds))

In [None]:
preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, preds.cpu()))

In [None]:
V1 = 0
V2 = 1
both = 2
control = 3
embedding_dim = 768
target_layer = 0
id_to_coords = {
    V1: [{"layer": target_layer, "start": 0, "end": 2*embedding_dim}],
    V2: [{"layer": target_layer, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    both:[{"layer": target_layer, "start": 0, "end": 2*embedding_dim},
          {"layer": target_layer, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    control:[{"layer": target_layer, "start": 0, "end": 2*embedding_dim}]
}
for inputs in X_sources_train[0]:
    for input in inputs:
        input.cpu()
for masks in X_sources_train[0]:
    for mask in masks:
        mask.cpu()

IIT_preds = LIM_trainer.iit_predict(X_base_train,
                                    X_sources_train,
                                    intervention_ids_train,
                                    id_to_coords, 
                                    device="cpu")
print(classification_report(y_IIT_train, IIT_preds))

# Set Analysis Mode to True and Verify Nothing Changes

In [None]:
LIM_trainer.model.unfreeze_disentangling_parameters(layer_num=target_layer)
LIM_trainer.model.freeze_model_parameters()

In [None]:
for input in X_base_test[0]:
    input.cpu()
for mask in X_base_test[1]:
    mask.cpu()

base_preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, base_preds.cpu()))

In [None]:
IIT_preds = LIM_trainer.iit_predict(X_base_test,
                                    X_sources_test,
                                    intervention_ids_test,
                                    id_to_coords,
                                    device="cpu")
print(classification_report(y_IIT_test, IIT_preds.cpu()))

# Train a Basis Agnostic Alignment

In [None]:
LIM_trainer.model.bert.embeddings.dropout.p = 0.0

for l in LIM_trainer.model.bert.encoder.layers:
    l.layer.attention.self.dropout.p = 0.0
    l.layer.attention.output.dropout.p = 0.0
    l.layer.output.dropout.p = 0.0

for g in LIM_trainer.optimizer.param_groups:
    g['lr'] = 0.001

In [None]:


LIM_trainer.max_iter = 50
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=(X_sources_train,
             y_IIT_train,
             intervention_ids_train),
    intervention_ids_to_coords=id_to_coords)

In [None]:
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds.cpu()))

In [None]:
IIT_preds = LIM_trainer.iit_predict(X_base_train,
                                    X_sources_train,
                                    intervention_ids_train,
                                    id_to_coords,
                                    device="cpu")


In [None]:
print(classification_report(y_IIT_train, IIT_preds.cpu()))

In [None]:
base_preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, base_preds.cpu()))

In [None]:
IIT_preds = LIM_trainer.iit_predict(X_base_test,
                                    X_sources_test,
                                    intervention_ids_test,
                                    id_to_coords,
                                    device="cpu")
print(classification_report(y_IIT_test, IIT_preds.cpu()))