In [1]:
import torch
import random
import copy
import itertools
import numpy as np
import utils
from trainer import LIMTrainer

from sklearn.metrics import classification_report
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality

In [2]:
# Seed the RNGs through the project helper before any data or weights are created,
# so the notebook is reproducible on Restart & Run All.
# NOTE(review): presumably seeds Python/NumPy/torch — confirm in utils.fix_random_seeds.
utils.fix_random_seeds()

def print_weights(LIM, n, verbose=False):
    """Print parameter shapes and trainability for every labeled layer of ``LIM``.

    For each entry of ``LIM.labeled_layers`` this prints the name, shape and
    ``requires_grad`` flag of every parameter of ``layer["model"]`` (plus the
    values when ``verbose``). Layers that also carry a ``"disentangle"``
    module additionally get that module's weight ``requires_grad`` flag and
    ``torch.dist(W^T W, I_n)``, i.e. how far the weight is from orthogonal.

    Args:
        LIM: object exposing ``labeled_layers``, a sequence of dicts with a
            ``"model"`` module and optionally a ``"disentangle"`` module that
            has a ``weight`` attribute.
        n: size of the identity that ``W^T W`` is compared against (the
            number of columns of the disentangle weight).
        verbose: if True, also print the raw parameter tensors.
    """
    for i, layer in enumerate(LIM.labeled_layers):
        print(f"\n\n\n\n -----      Layer {i}      -----")
        for name, param in layer["model"].named_parameters():
            print(name, param.data.shape, param.requires_grad)
            if verbose:
                print(param.data)
        if "disentangle" in layer:
            w = layer["disentangle"].weight
            print(w.requires_grad)
            if verbose:
                print(w.data)
            # Build the identity from `n` on the weight's own device/dtype so the
            # orthogonality check works for any width and without CUDA.
            identity = torch.eye(n, device=w.device, dtype=w.dtype)
            print(torch.dist(w.T @ w.data, identity))

# Training a Model with no IIT

In [3]:
embedding_dim = 4

# Intervention ids, used as the keys of id_to_coords.
V1, V2, both, control13 = 0, 1, 2, 3


def _layer0_slice(start, end):
    """Return a fresh coordinate dict selecting units [start, end) of layer 0."""
    return {"layer": 0, "start": start, "end": end}


# Each intervention id maps to the list of layer-0 coordinate ranges it patches:
#   V1        -> [0, 2*embedding_dim)
#   V2        -> [2*embedding_dim, 4*embedding_dim)
#   both      -> both of the ranges above
#   control13 -> the same range as V1
half = 2 * embedding_dim
id_to_coords = {
    V1: [_layer0_slice(0, half)],
    V2: [_layer0_slice(half, 2 * half)],
    both: [_layer0_slice(0, half), _layer0_slice(half, 2 * half)],
    control13: [_layer0_slice(0, half)],
}

In [4]:
# Which IIT training set to build:
#   control=True -> the control13 dataset
#   AB=True      -> V1 and V2 intervention datasets concatenated (V1 rows first)
#   otherwise    -> the V1 intervention dataset only
control = False
AB = True


def concat_iit_datasets(first, second):
    """Concatenate two IIT datasets field by field along dim 0.

    Both datasets must have the same number of fields. Tensor fields are
    concatenated with ``torch.cat``; any other field (the sources field, a
    list of tensors) is concatenated element-wise, so every source tensor is
    kept rather than only the first one.

    Returns:
        A tuple with one merged field per input field; non-tensor fields come
        back as lists of tensors.

    Raises:
        ValueError: if the datasets, or their source lists, differ in length.
    """
    if len(first) != len(second):
        raise ValueError(
            f"IIT datasets have {len(first)} and {len(second)} fields"
        )
    merged = []
    for field_first, field_second in zip(first, second):
        if isinstance(field_first, torch.Tensor):
            merged.append(torch.cat((field_first, field_second)))
            continue
        if len(field_first) != len(field_second):
            raise ValueError("IIT datasets have different numbers of source tensors")
        merged.append([torch.cat(pair) for pair in zip(field_first, field_second)])
    return tuple(merged)


if control:
    data_size = 1280000
    iit_equality_dataset = dataset_equality.get_IIT_equality_dataset_control13(embedding_dim, data_size)
elif AB:
    # data_size is per intervention type here; the two datasets are concatenated.
    data_size = 640000
    iit_equality_datasetA = \
        dataset_equality.get_IIT_equality_dataset("V1", embedding_dim, data_size)
    iit_equality_datasetB = \
        dataset_equality.get_IIT_equality_dataset("V2", embedding_dim, data_size)
    iit_equality_dataset = concat_iit_datasets(iit_equality_datasetA, iit_equality_datasetB)
else:
    data_size = 640000
    iit_equality_dataset = \
        dataset_equality.get_IIT_equality_dataset("V1", embedding_dim, data_size)

# Fields 0-1 are the base inputs/labels; the rest is what fit()/iit_predict() consume as iit_data.
X_base_train, y_base_train = iit_equality_dataset[0:2]
iit_data = tuple(iit_equality_dataset[2:])
assert len(X_base_train) == len(y_base_train), "base inputs and labels have different lengths"



In [5]:
# Input and hidden widths are both 4 * embedding_dim; the id_to_coords ranges
# above index into [0, 4 * embedding_dim) of layer 0.
layer_width = 4 * embedding_dim
LIM = LIMDeepNeuralClassifier(
    input_dim=layer_width,
    hidden_dim=layer_width,
    num_layers=2,
    n_classes=2,
    hidden_activation=torch.nn.ReLU(),
)

In [6]:
# 10 epochs of mini-batches of 64; eta is presumably the learning rate.
# n_iter_no_change is far above max_iter, presumably so early stopping never fires.
# NOTE(review): shuffle_train=False while the AB training set is all V1 rows
# followed by all V2 rows — confirm the trainer does not need mixed batches.
LIM_trainer = LIMTrainer(
    LIM,
    max_iter=10,
    batch_size=64,
    eta=0.001,
    shuffle_train=False,
    warm_start=True,
    n_iter_no_change=10000,
)

In [7]:
# Parameter shapes / trainability before any training.
print_weights(LIM, n=4 * embedding_dim)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.0361e-06, device='cuda:0')




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) True
linear.bias torch.Size([16]) True
False
tensor(1.3227e-06, device='cuda:0')




 -----      Layer 2      -----
weight torch.Size([2, 16]) True
bias torch.Size([2]) True


In [8]:
# Baseline training on the base task only (no IIT data is passed).
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=None,
    intervention_ids_to_coords=id_to_coords,
)

Finished epoch 10 of 10; error is 15.062122988900228

In [9]:
# Base-task performance on the training inputs (no interventions).
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
base_report = classification_report(y_true=y_base_train, y_pred=base_preds)
print(base_report)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640000
           1       1.00      1.00      1.00    640000

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000



In [10]:
# IIT performance on the training interventions; iit_data starts with
# (sources, IIT labels, intervention ids).
sources_train, y_iit_train, intervention_ids_train = iit_data[:3]
IIT_preds = LIM_trainer.iit_predict(
    X_base_train, sources_train, intervention_ids_train, id_to_coords, device="cpu"
)
print(classification_report(y_true=y_iit_train, y_pred=IIT_preds))

              precision    recall  f1-score   support

           0       0.80      0.87      0.84    640000
           1       0.86      0.78      0.82    640000

    accuracy                           0.83   1280000
   macro avg       0.83      0.83      0.83   1280000
weighted avg       0.83      0.83      0.83   1280000



In [11]:
# Held-out evaluation set built with the "both" intervention type
# (presumably intervening on V1 and V2 together — see dataset_equality).
n_test = 1000
datasetIIT = dataset_equality.get_IIT_equality_dataset_both(embedding_dim, n_test)
(base_test, y_base_test,
 sources_test, y_IIT_test, intervention_ids_test) = datasetIIT

base_preds = LIM_trainer.predict(base_test, device="cpu")
IIT_preds = LIM_trainer.iit_predict(
    base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu"
)

In [12]:
# Base-task report on the held-out set.
print(classification_report(y_true=y_base_test, y_pred=base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [13]:
# IIT report on the held-out set.
print(classification_report(y_true=y_IIT_test, y_pred=IIT_preds))

              precision    recall  f1-score   support

           0       0.80      0.87      0.83       505
           1       0.85      0.78      0.81       495

    accuracy                           0.82      1000
   macro avg       0.83      0.82      0.82      1000
weighted avg       0.83      0.82      0.82      1000



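# Set Analysis Mode to True and Verify Base Predictions Are Unchanged

Turning on analysis mode should leave the base-task predictions untouched, and the base report below matches the one above. The IIT predictions do change, however: held-out IIT accuracy drops from 0.82 to 0.59. Presumably this is because interventions are now applied in the not-yet-trained disentangled basis.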

In [14]:
# Turn on analysis mode, then re-inspect which parameters are trainable.
LIM_trainer.model.set_analysis_mode(True)
print_weights(LIM, n=4 * embedding_dim)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.0361e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(1.3227e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [15]:
# Re-run both held-out evaluations now that analysis mode is on.
base_preds = LIM_trainer.predict(base_test, device="cpu")
IIT_preds = LIM_trainer.iit_predict(
    base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu"
)
print(classification_report(y_true=y_base_test, y_pred=base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [16]:
# IIT report on the held-out set in analysis mode.
print(classification_report(y_true=y_IIT_test, y_pred=IIT_preds))

              precision    recall  f1-score   support

           0       0.64      0.44      0.52       505
           1       0.57      0.75      0.64       495

    accuracy                           0.59      1000
   macro avg       0.60      0.59      0.58      1000
weighted avg       0.60      0.59      0.58      1000



# Train a Basis Agnostic Alignment

In [17]:
# IIT training of the alignment. In analysis mode the printout above shows only
# the disentangle weights as trainable, so this presumably learns the basis
# while leaving the classifier weights fixed.
_ = LIM_trainer.fit(
    X_base_train, y_base_train,
    iit_data=iit_data,
    intervention_ids_to_coords=id_to_coords,
)

Finished epoch 10 of 10; error is 397.85884997464746

In [18]:
# Check trainability flags and that the disentangle weights are still near-orthogonal.
print_weights(LIM, n=4 * embedding_dim)





 -----      Layer 0      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(2.1224e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 1      -----
linear.weight torch.Size([16, 16]) False
linear.bias torch.Size([16]) False
True
tensor(3.7412e-06, device='cuda:0', grad_fn=<DistBackward0>)




 -----      Layer 2      -----
weight torch.Size([2, 16]) False
bias torch.Size([2]) False


In [19]:
# Base-task performance on the training inputs after alignment training.
base_preds = LIM_trainer.predict(X_base_train, device="cpu")
base_report = classification_report(y_true=y_base_train, y_pred=base_preds)
print(base_report)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    640000
           1       1.00      1.00      1.00    640000

    accuracy                           1.00   1280000
   macro avg       1.00      1.00      1.00   1280000
weighted avg       1.00      1.00      1.00   1280000



In [20]:
# IIT performance on the training interventions after alignment training;
# iit_data starts with (sources, IIT labels, intervention ids).
sources_train, y_iit_train, intervention_ids_train = iit_data[:3]
IIT_preds = LIM_trainer.iit_predict(
    X_base_train, sources_train, intervention_ids_train, id_to_coords, device="cpu"
)
print(classification_report(y_true=y_iit_train, y_pred=IIT_preds))


              precision    recall  f1-score   support

           0       0.99      1.00      0.99    640000
           1       1.00      0.99      0.99    640000

    accuracy                           0.99   1280000
   macro avg       0.99      0.99      0.99   1280000
weighted avg       0.99      0.99      0.99   1280000



In [21]:
# Held-out evaluations after alignment training.
base_preds = LIM_trainer.predict(base_test, device="cpu")
IIT_preds = LIM_trainer.iit_predict(
    base_test, sources_test, intervention_ids_test, id_to_coords, device="cpu"
)
print(classification_report(y_true=y_base_test, y_pred=base_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [22]:
# Held-out IIT report after alignment training.
print(classification_report(y_true=y_IIT_test, y_pred=IIT_preds))

              precision    recall  f1-score   support

           0       0.98      0.99      0.99       505
           1       0.99      0.98      0.99       495

    accuracy                           0.99      1000
   macro avg       0.99      0.99      0.99      1000
weighted avg       0.99      0.99      0.99      1000

