In [1]:
import torch
import random
import copy
import itertools
import numpy as np
import utils
from trainer import LIMTrainer

from sklearn.metrics import classification_report
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality

In [2]:
# Seed the random number generators through the project helper (presumably for
# reproducible runs -- confirm against utils.fix_random_seeds).
utils.fix_random_seeds()
# When True, the feed-forward data-loading cell requests tiny datasets
# (64 / 128 examples) instead of the full-size ones.
debug = True

def print_weights(LIM, n, verbose=False):
    """Print a per-layer summary of the parameters of a LIM model.

    For every entry of ``LIM.labeled_layers`` this prints the name, shape and
    ``requires_grad`` flag of each parameter of ``layer["model"]``. If the
    entry also has a ``"disentangle"`` module, it prints that module's
    ``weight.requires_grad`` flag and the distance between ``W.T @ W`` and the
    ``n x n`` identity, i.e. how far the weight is from being orthogonal.

    Args:
        LIM: object exposing ``labeled_layers``, an iterable of dicts with a
            ``"model"`` module and optionally a ``"disentangle"`` module.
        n: side length of the identity matrix the disentangle weight is
            compared against (must match the weight's column count).
        verbose: when True, also print the raw parameter values.
    """
    for i, layer in enumerate(LIM.labeled_layers):
        print(f"\n\n\n\n -----      Layer {i}      -----")
        for name, param in layer["model"].named_parameters():
            print(name, param.data.shape, param.requires_grad)
            if verbose:
                print(param.data)
        if "disentangle" in layer:
            w = layer["disentangle"].weight
            print(w.requires_grad)
            # Work on the raw values so this diagnostic never builds an
            # autograd graph, even when the weight is trainable.
            w_data = w.data
            if verbose:
                print(w_data)
            # Build the identity from `n` on the weight's own device/dtype
            # instead of hard-coding eye(16) on "cuda", so the check also
            # works on CPU and for other layer widths.
            identity = torch.eye(n, device=w_data.device, dtype=w_data.dtype)
            print(torch.dist(w_data.T @ w_data, identity))

# Training a Feed-Forward Neural Network with no IIT

In [3]:
# Size of one embedding; the classifier built in this notebook uses
# 4 * embedding_dim for both its input and hidden widths.
embedding_dim = 4

# Integer ids naming each kind of intervention.
V1, V2, both, control13 = 0, 1, 2, 3

# For each intervention id: the list of {layer, start, end} unit ranges it
# targets. How start/end are interpreted is defined by the trainer.
id_to_coords = {
    V1: [dict(layer=0, start=0, end=2 * embedding_dim)],
    V2: [dict(layer=0, start=2 * embedding_dim, end=4 * embedding_dim)],
    both: [
        dict(layer=0, start=0, end=2 * embedding_dim),
        dict(layer=0, start=2 * embedding_dim, end=4 * embedding_dim),
    ],
    control13: [dict(layer=0, start=0, end=2 * embedding_dim)],
}

In [4]:
# Choose which IIT equality dataset to build; `debug` shrinks it drastically.
control = False
if control:
    data_size = 128 if debug else 1280000
    iit_equality_dataset = dataset_equality.get_IIT_equality_dataset_control13(
        embedding_dim, data_size
    )
else:
    data_size = 64 if debug else 640000
    iit_equality_dataset = dataset_equality.get_IIT_equality_dataset_all(
        embedding_dim, data_size
    )

# The first two entries are the base inputs and labels; everything after them
# is kept together as the IIT data handed to the trainer.
X_base_train, y_base_train = iit_equality_dataset[0:2]
iit_data = tuple(iit_equality_dataset[2:])

In [5]:
# Two-layer ReLU classifier; input and hidden widths are both 4 * embedding_dim.
LIM = LIMDeepNeuralClassifier(
    input_dim=embedding_dim * 4,
    hidden_dim=embedding_dim * 4,
    num_layers=2,
    hidden_activation=torch.nn.ReLU(),
    n_classes=2,
)

In [6]:
# Trainer wrapping the feed-forward LIM. warm_start=True is set because
# fit() is called more than once on this trainer later in the notebook.
LIM_trainer = LIMTrainer(
    LIM,
    warm_start=True, max_iter=10, batch_size=64,
    n_iter_no_change=10000, shuffle_train=False, eta=0.001,
)

In [7]:
# Inspect parameter shapes and requires_grad flags before any training.
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.1129e-06, device='cuda:0')




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.5595e-06, device='cuda:0')




 -----      Layer 2      -----
weight torch.Size([2, 16]) True
bias torch.Size([2]) True


In [8]:
# Base-task training only: no interchange-intervention data is supplied
# (iit_data=None); the coordinate map is still passed through.
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=None, intervention_ids_to_coords=id_to_coords,
)

Finished epoch 10 of 10; error is 2.0710567235946655

In [9]:
# Base-task predictions on the training inputs, scored against their labels.
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds))

              precision    recall  f1-score   support

           0       0.77      0.11      0.19        94
           1       0.53      0.97      0.69        98

    accuracy                           0.55       192
   macro avg       0.65      0.54      0.44       192
weighted avg       0.65      0.55      0.44       192



In [10]:
# Interchange-intervention predictions on the training data. As used here,
# iit_data[0] is passed as the sources, iit_data[2] as the intervention ids,
# and iit_data[1] serves as the reference labels for the report.
IIT_preds = LIM_trainer.iit_predict(X_base_train, iit_data[0], iit_data[2] , id_to_coords, device="cpu")
print(classification_report(iit_data[1], IIT_preds))

              precision    recall  f1-score   support

           0       0.38      0.03      0.06        94
           1       0.51      0.95      0.66        98

    accuracy                           0.50       192
   macro avg       0.44      0.49      0.36       192
weighted avg       0.44      0.50      0.37       192



In [11]:
# Build a separate 1000-example evaluation set from the "both" generator.
datasetIIT = dataset_equality.get_IIT_equality_dataset_both(embedding_dim, 1000)

# The generator's result unpacks into exactly these five pieces.
base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = datasetIIT

# Plain (no-intervention) predictions on the evaluation inputs.
base_preds = LIM_trainer.predict(base_test, device="cpu")

# Predictions under the interventions described by intervention_ids_test.
IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")

In [12]:
# Base-task metrics on the evaluation set.
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       0.63      0.10      0.17       490
           1       0.52      0.95      0.67       510

    accuracy                           0.53      1000
   macro avg       0.57      0.52      0.42      1000
weighted avg       0.57      0.53      0.42      1000



In [13]:
# Interchange-intervention metrics on the evaluation set.
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.46      0.05      0.10       508
           1       0.49      0.93      0.64       492

    accuracy                           0.49      1000
   macro avg       0.47      0.49      0.37      1000
weighted avg       0.47      0.49      0.36      1000



# Set Analysis Mode to True and Verify Nothing Changes

In [14]:
# Switch the model into analysis mode. Per the recorded output of the call
# below, the layer parameters then report requires_grad=False and the
# "disentangle" weights report requires_grad=True.
LIM_trainer.model.set_analysis_mode(True)
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.1129e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.5595e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [15]:
# Re-run both kinds of prediction on the evaluation set after enabling
# analysis mode, for comparison with the reports printed before the switch.
base_preds = LIM_trainer.predict(base_test, device="cpu")

IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       0.63      0.10      0.17       490
           1       0.52      0.95      0.67       510

    accuracy                           0.53      1000
   macro avg       0.57      0.52      0.42      1000
weighted avg       0.57      0.53      0.42      1000



In [16]:
# Interchange-intervention metrics on the evaluation set in analysis mode.
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.49      0.07      0.13       508
           1       0.49      0.92      0.64       492

    accuracy                           0.49      1000
   macro avg       0.49      0.50      0.38      1000
weighted avg       0.49      0.49      0.38      1000



# Train a Basis Agnostic Alignment

In [17]:
# Fit again, this time supplying the IIT data. The print_weights outputs on
# either side of this cell show the layer parameters with requires_grad=False
# and the disentangle weights with requires_grad=True.
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=iit_data, intervention_ids_to_coords=id_to_coords,
)

Finished epoch 10 of 10; error is 4.149481773376465

In [18]:
# Inspect the parameters again after the alignment fit.
print_weights(LIM,embedding_dim*4)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.2363e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.6265e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [19]:
# Base-task predictions on the training inputs after the alignment fit.
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds))

              precision    recall  f1-score   support

           0       0.77      0.11      0.19        94
           1       0.53      0.97      0.69        98

    accuracy                           0.55       192
   macro avg       0.65      0.54      0.44       192
weighted avg       0.65      0.55      0.44       192



In [20]:
# Interchange-intervention predictions on the training data after the
# alignment fit; iit_data[1] supplies the reference labels.
IIT_preds = LIM_trainer.iit_predict(X_base_train, iit_data[0], iit_data[2] , id_to_coords, device="cpu")
print(classification_report(iit_data[1], IIT_preds))


              precision    recall  f1-score   support

           0       0.59      0.11      0.18        94
           1       0.52      0.93      0.67        98

    accuracy                           0.53       192
   macro avg       0.55      0.52      0.42       192
weighted avg       0.55      0.53      0.43       192



In [21]:
# Evaluation-set predictions (plain and under intervention) after the
# alignment fit; the base report is printed here.
base_preds = LIM_trainer.predict(base_test, device="cpu")

IIT_preds = LIM_trainer.iit_predict(base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu")
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           0       0.63      0.10      0.17       490
           1       0.52      0.95      0.67       510

    accuracy                           0.53      1000
   macro avg       0.57      0.52      0.42      1000
weighted avg       0.57      0.53      0.42      1000



In [22]:
# Interchange-intervention metrics on the evaluation set after alignment.
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       0.49      0.07      0.13       508
           1       0.49      0.92      0.64       492

    accuracy                           0.49      1000
   macro avg       0.49      0.50      0.38      1000
weighted avg       0.49      0.49      0.38      1000



# Training BERT with No IIT

In [23]:
from LIM_bert import LIMBERTClassifier
from trainer import BERTLIMTrainer
from transformers import BertModel, BertTokenizer

import torch
import random
import dataset_equality
from sklearn.metrics import classification_report
# Overrides the earlier `debug = True`: the BERT section below therefore uses
# the full dataset sizes (20000 train / 5000 test) rather than 12.
debug = False

In [24]:
# Name of the pretrained checkpoint used for both the tokenizer and the model.
weights_name = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)
# Number of output classes passed to the classifier.
n_classes = 2

In [25]:
# Shuffle the tokenizer vocabulary and split it 90/10 so the training and
# test datasets are built from disjoint token pools.
vocab = bert_tokenizer.get_vocab()
tokens = list(vocab)
random.shuffle(tokens)

cutoff = int(0.9 * len(tokens))

train_ids = bert_tokenizer.convert_tokens_to_ids(tokens[:cutoff])
test_ids = bert_tokenizer.convert_tokens_to_ids(tokens[cutoff:])



In [26]:
# Number of examples requested per split; tiny when debugging.
train_size, test_size = (12, 12) if debug else (20000, 5000)

# 768 is the embedding size passed to the dataset builder for the BERT runs.
iit_train_dataset = dataset_equality.get_IIT_equality_dataset_all(
    768, train_size, token_ids=train_ids
)
iit_test_dataset = dataset_equality.get_IIT_equality_dataset_all(
    768, test_size, token_ids=test_ids
)

# Each dataset is read positionally: entries 0-1 form the base inputs
# (presumably token ids and attention masks -- confirm against
# dataset_equality), 2 the base labels, 3-4 the source inputs, 5 the IIT
# labels and 6 the intervention ids.
X_base_train = iit_train_dataset[0], iit_train_dataset[1]
X_sources_train = iit_train_dataset[3], iit_train_dataset[4]
y_base_train = iit_train_dataset[2]
y_IIT_train = iit_train_dataset[5]
intervention_ids_train = iit_train_dataset[6]

X_base_test = iit_test_dataset[0], iit_test_dataset[1]
X_sources_test = iit_test_dataset[3], iit_test_dataset[4]
y_base_test = iit_test_dataset[2]
y_IIT_test = iit_test_dataset[5]
intervention_ids_test = iit_test_dataset[6]

In [27]:
# Load the pretrained encoder from the same checkpoint as the tokenizer.
bert = BertModel.from_pretrained(weights_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
# Sequence length handed to the BERT-based classifier.
max_length = 4
# Fixed typo: this was spelled `embeddin_dim`, which created an unused
# variable and left `embedding_dim` at the feed-forward section's value.
embedding_dim = 768
LIM = LIMBERTClassifier(n_classes,
                        bert,
                        max_length=max_length,
                        debug=debug,
                        use_wrapper=True)

In [29]:
# Trainer for the BERT-based LIM. warm_start=True is set because fit() is
# called a second time on this trainer for the alignment stage.
LIM_trainer = BERTLIMTrainer(
    LIM,
    warm_start=True, max_iter=3, batch_size=16,
    n_iter_no_change=10000, shuffle_train=True, eta=0.00001,
)

In [30]:
# Base-task fine-tuning only: neither IIT data nor a coordinate map is given.
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=None, intervention_ids_to_coords=None,
)

Finished epoch 3 of 3; error is 106.71799992746674

In [31]:
# NOTE: this cell used to loop over X_base_train calling `.cpu()` on every
# element and discarding the result. `Tensor.cpu()` is not in-place (it
# returns a tensor), so those loops had no effect; they also shadowed the
# builtin `input`. They are removed. If the inputs ever live on another
# device, rebind them explicitly (e.g. `x = x.cpu()`) before predicting.
preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00     30034
           1       1.00      1.00      1.00     29966

    accuracy                           1.00     60000
   macro avg       1.00      1.00      1.00     60000
weighted avg       1.00      1.00      1.00     60000



In [32]:
# Base-task predictions on the held-out token split; `.cpu()` hands the
# predictions to scikit-learn from CPU memory.
preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, preds.cpu()))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      7433
           1       1.00      1.00      1.00      7567

    accuracy                           1.00     15000
   macro avg       1.00      1.00      1.00     15000
weighted avg       1.00      1.00      1.00     15000



In [None]:
# Intervention ids (same values as in the feed-forward section).
V1 = 0
V2 = 1
both = 2
control13 = 3
embedding_dim = 768
# Index of the layer whose units the interventions target.
target_layer = 10
# For each intervention id: the {layer, start, end} unit ranges it targets.
id_to_coords = {
    V1: [{"layer": target_layer, "start": 0, "end": 2*embedding_dim}],
    V2: [{"layer": target_layer, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    both:[{"layer": target_layer, "start": 0, "end": 2*embedding_dim},
          {"layer": target_layer, "start":  2*embedding_dim, "end": 4*embedding_dim}],
    control13:[{"layer": target_layer, "start": 0, "end": 2*embedding_dim}]
}

# NOTE: the nested loops that called `.cpu()` on every source tensor were
# removed. `Tensor.cpu()` returns a tensor rather than moving it in place, so
# the discarded calls did nothing. The second loop also iterated
# X_sources_train[0] again instead of the masks in X_sources_train[1], and
# both loops shadowed the builtin `input`. Rebind the tensors explicitly if
# they ever need to be moved off another device.
IIT_preds = LIM_trainer.iit_predict(X_base_train,
                                    X_sources_train,
                                    intervention_ids_train,
                                    id_to_coords,
                                    device="cpu")
print(classification_report(y_IIT_train, IIT_preds))

# Set Analysis Mode to True and Verify Nothing Changes

In [None]:
# Prepare for alignment training: unfreeze the disentangling parameters of
# the target layer, then freeze the model parameters (behaviour as named by
# the model's methods -- see LIM_bert for the exact parameter sets).
LIM_trainer.model.unfreeze_disentangling_parameters(layer_num=target_layer)
LIM_trainer.model.freeze_model_parameters()

In [None]:
# NOTE: the loops that called `.cpu()` on each element of X_base_test and
# discarded the result were removed: `Tensor.cpu()` is not in-place, so they
# had no effect, and they shadowed the builtin `input`.
base_preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, base_preds.cpu()))

In [None]:
# Interchange-intervention predictions on the held-out split, before the
# alignment fit.
IIT_preds = LIM_trainer.iit_predict(
    X_base_test, X_sources_test, intervention_ids_test,
    id_to_coords, device="cpu",
)
print(classification_report(y_IIT_test, IIT_preds.cpu()))

# Train a Basis Agnostic Alignment

In [None]:
# Raise the epoch budget, then fit with the IIT data (sources, IIT labels,
# intervention ids) and the coordinate map for the target layer.
LIM_trainer.max_iter = 4
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=(X_sources_train, y_IIT_train, intervention_ids_train),
    intervention_ids_to_coords=id_to_coords,
)

In [None]:
# Base-task metrics on the training split after the alignment fit.
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
print(classification_report(y_base_train, base_preds.cpu()))

In [None]:
# Interchange-intervention predictions on the training split after the
# alignment fit.
IIT_preds = LIM_trainer.iit_predict(
    X_base_train, X_sources_train, intervention_ids_train,
    id_to_coords, device="cpu",
)


In [None]:
# Interchange-intervention metrics on the training split after alignment.
print(classification_report(y_IIT_train, IIT_preds.cpu()))

In [None]:
# Base-task metrics on the held-out split after the alignment fit.
base_preds = LIM_trainer.predict(X_base_test, device="cpu")
print(classification_report(y_base_test, base_preds.cpu()))

In [None]:
# Interchange-intervention metrics on the held-out split after alignment.
IIT_preds = LIM_trainer.iit_predict(
    X_base_test, X_sources_test, intervention_ids_test,
    id_to_coords, device="cpu",
)
print(classification_report(y_IIT_test, IIT_preds.cpu()))