In [1]:
import dataset_nli
import torch
from transformers import BertTokenizer, BertModel
import random, os
import copy
import itertools
import numpy as np
import utils

from sklearn.metrics import classification_report
from LIM_bert import LIMBERTClassifier
from ii_benchmark import IIBenchmarkMoNli

# Seed the random number generators once, with the helper's default seed,
# before any dataset or model is built. The training cells re-seed per run.
# NOTE(review): which libraries get seeded is decided inside
# utils.fix_random_seeds -- confirm there.
utils.fix_random_seeds()

In [74]:
def get_eval_from_train_nli(iit_nli_dataset, n=1000, control=False):
    """Draw a random evaluation subset out of an IIT training dataset.

    Parameters
    ----------
    iit_nli_dataset : 5-tuple
        ``(base, y_base, sources, y_IIT, intervention_ids)`` where ``base`` is
        a pair of equally long sequences, ``sources`` is a list of such pairs,
        and the remaining three entries can be indexed with an index tensor
        (the callers in this notebook pass ``torch`` tensors).
    n : int
        Number of examples to draw. When the dataset has more than one source
        it must be divisible by 3.
    control : bool
        Only ``control=True`` is implemented.

    Returns
    -------
    tuple or None
        A 5-tuple with the same layout as ``iit_nli_dataset`` restricted to
        the sampled rows, or ``None`` when ``control`` is false.

    Notes
    -----
    * With a single source the rows are sampled uniformly from the whole
      dataset.
    * Otherwise the dataset is treated as three equally sized consecutive
      buckets (this is how ``get_IIT_nli_dataset_neghyp`` concatenates its
      three parts) and the *same* random offsets are taken from every bucket.
    * Every source pair present in ``sources`` is carried over; the previous
      implementation hard-coded exactly two and dropped any others.
    """
    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = iit_nli_dataset

    if not control:
        # No non-control sampling scheme exists; keep returning None as before,
        # but say so explicitly instead of falling off the end of the function.
        return None

    if len(sources_test) == 1:
        # Single-source dataset: one global permutation, truncated to n.
        indices = torch.randperm(len(base_test[0]))[:n]
    else:
        assert n % 3 == 0
        sub_n = n // 3
        sub_dataset_size = len(base_test[0]) // 3
        # Randomise inside the first bucket, then mirror the same offsets
        # into the second and third bucket.
        bucket_offsets = torch.randperm(sub_dataset_size)[:sub_n]
        indices = torch.cat([
            bucket_offsets,
            bucket_offsets + sub_dataset_size,
            bucket_offsets + 2 * sub_dataset_size,
        ])

    positions = indices.tolist()

    def take(sequence):
        # Positional gather for the plain-sequence parts of the dataset.
        return [sequence[pos] for pos in positions]

    return (
        (take(base_test[0]), take(base_test[1])),
        y_base_test[indices],
        [(take(source[0]), take(source[1])) for source in sources_test],
        y_IIT_test[indices],
        intervention_ids_test[indices],
    )

def _make_sentence_encoder(tokenizer_name):
    """Build the ``embed_func`` handed to ``dataset_nli.IIT_MoNLIDataset``.

    Loads the BERT tokenizer named ``tokenizer_name`` once and returns a
    closure that tokenizes a single example.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encoding(X):
        """Tokenize one example ``X`` (a sequence of sentence strings).

        The sentences are joined into a single string: with ``". "`` when the
        first sentence does not already end in a full stop, with ``" "``
        otherwise. Returns ``(input_ids, attention_mask)`` as tensors, padded
        or truncated to 128 tokens.
        """
        # endswith() instead of X[0][-1] so an empty first sentence does not
        # raise IndexError; for non-empty sentences the result is identical.
        separator = " " if X[0].endswith(".") else ". "
        joined = [separator.join(X)]
        data = bert_tokenizer.batch_encode_plus(
                joined,
                max_length=128,
                add_special_tokens=True,
                padding='max_length',
                truncation=True,
                return_attention_mask=True)
        indices = torch.tensor(data['input_ids'])
        mask = torch.tensor(data['attention_mask'])
        return (indices, mask)

    return encoding


def _load_monli_dataset(data_size, tokenizer_name, split):
    """Create the ``IIT_MoNLIDataset`` for ``split`` with the shared encoder."""
    return dataset_nli.IIT_MoNLIDataset(
        embed_func=_make_sentence_encoder(tokenizer_name),
        suffix=split,
        size=data_size)


def _tensorize_iit_example_set(example_set):
    """Convert the label / intervention-id parts of an IIT 5-tuple to tensors.

    ``example_set`` is ``(X_base, y_base, X_sources, y_IIT, interventions)``;
    the inputs ``X_base`` and ``X_sources`` are passed through untouched.
    """
    X_base, y_base, X_sources, y_IIT, interventions = example_set
    y_base = torch.tensor(y_base)
    y_IIT = torch.tensor(y_IIT)
    interventions = torch.tensor(interventions)
    return X_base, y_base, X_sources, y_IIT, interventions


def get_IIT_nli_dataset_factual_pairs(
    data_size,
    tokenizer_name,
    split="train",
):
    """Return ``(X_base, y_base)`` from ``IIT_MoNLIDataset.create_factual_pairs``.

    ``data_size`` is forwarded as the dataset ``size``, ``split`` as its
    ``suffix``, and ``tokenizer_name`` selects the BERT tokenizer used to
    encode the sentences. ``y_base`` is returned as a tensor.
    """
    dataset = _load_monli_dataset(data_size, tokenizer_name, split)
    X_base, y_base = dataset.create_factual_pairs()
    y_base = torch.tensor(y_base)
    return X_base, y_base

def get_IIT_nli_dataset_neghyp_V1(
    data_size,
    tokenizer_name,
    split="train",
):
    """Return the IIT 5-tuple produced by ``IIT_MoNLIDataset.create_neghyp_V1``.

    The result is ``(X_base, y_base, X_sources, y_IIT, interventions)`` with
    the labels and intervention ids converted to tensors.
    """
    dataset = _load_monli_dataset(data_size, tokenizer_name, split)
    return _tensorize_iit_example_set(dataset.create_neghyp_V1())

def get_IIT_nli_dataset_neghyp_V2(
    data_size,
    tokenizer_name,
    split="train",
):
    """Return the IIT 5-tuple produced by ``IIT_MoNLIDataset.create_neghyp_V2``.

    The result is ``(X_base, y_base, X_sources, y_IIT, interventions)`` with
    the labels and intervention ids converted to tensors.
    """
    dataset = _load_monli_dataset(data_size, tokenizer_name, split)
    return _tensorize_iit_example_set(dataset.create_neghyp_V2())

def get_IIT_nli_dataset_neghyp_V1_V2(
    data_size,
    tokenizer_name,
    split="train",
):
    """Return the IIT 5-tuple produced by ``IIT_MoNLIDataset.create_neghyp_V1_V2``.

    The result is ``(X_base, y_base, X_sources, y_IIT, interventions)`` with
    the labels and intervention ids converted to tensors.
    """
    dataset = _load_monli_dataset(data_size, tokenizer_name, split)
    return _tensorize_iit_example_set(dataset.create_neghyp_V1_V2())

def get_IIT_nli_dataset_neghyp(
    data_size,
    tokenizer_name,
    split="train",
):
    """Concatenate the V1, V2 and V1+V2 neghyp datasets into one IIT 5-tuple.

    ``data_size`` must be divisible by 3; each of the three parts is built
    with ``data_size // 3`` examples, in the order V1, V2, V1_V2, and the
    parts are laid out back to back in that order.

    Two source slots are produced. The first takes source 0 of every part.
    The second takes source 1 of the V1_V2 part and, for the V1 and V2 rows,
    repeats their source 0 (those parts are only read at index 0 here).

    Returns ``(X_base, y_base, X_sources, y_IIT, interventions)``.
    """
    assert data_size % 3 == 0
    part_size = data_size // 3

    # Build order is kept as V1 -> V2 -> V1_V2.
    v1_part = get_IIT_nli_dataset_neghyp_V1(part_size, tokenizer_name, split)
    v2_part = get_IIT_nli_dataset_neghyp_V2(part_size, tokenizer_name, split)
    v1v2_part = get_IIT_nli_dataset_neghyp_V1_V2(part_size, tokenizer_name, split)
    parts = (v1_part, v2_part, v1v2_part)

    def stack(position):
        # Tensor-valued entries (labels, intervention ids) are joined with cat.
        return torch.cat(tuple(part[position] for part in parts))

    X_base = tuple(
        v1_part[0][field] + v2_part[0][field] + v1v2_part[0][field]
        for field in (0, 1)
    )

    X_sources = [
        tuple(
            v1_part[2][0][field] + v2_part[2][0][field] + v1v2_part[2][slot][field]
            for field in (0, 1)
        )
        for slot in (0, 1)
    ]

    y_base = stack(1)
    y_IIT = stack(3)
    interventions = stack(4)

    return X_base, y_base, X_sources, y_IIT, interventions

### Train Factual Models

In [50]:
# Settings for the factual-model training cell below.
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook does not crash on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
data_size = 10000       # passed to IIBenchmarkMoNli as 'train_size'
test_data_size = 1000   # passed to IIBenchmarkMoNli as 'test_size'
num_layers = 12         # only used to build checkpoint file names
hidden_dim = 768        # used in checkpoint file names and in the result rows

In [42]:
# Fine-tune one factual classifier per seed (fit on base examples only, no
# iit_data) and store its weights. A seed whose checkpoint file already exists
# is skipped.

# Create the checkpoint directory up front: otherwise a fresh checkout trains
# for every epoch and then fails when torch.save writes the final weights.
os.makedirs("./saved_models_nli", exist_ok=True)

# A tuple rather than a set, so the seeds are visited in the order written.
for seed in (42, 66, 77):
    utils.fix_random_seeds(seed=seed)
    print(f"training factual model for seed={seed}")
    PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
    if os.path.isfile(PATH):
        print(f"Found trained model thus skip: {PATH}")
        continue
    benchmark = IIBenchmarkMoNli(
            variable_names=['LEX'],
            data_parameters={
                'train_size': data_size, 'test_size': test_data_size
            },
            model_parameters={
                'weights_name': 'ishan/bert-base-uncased-mnli',
                'max_length': 128,
                'n_classes': 2,
                'hidden_dim': 768,
                'target_layers' : [],
                'target_dims':{
                    "start" : 0,
                    # NOTE(review): 786 exceeds hidden_dim (768) and looks like
                    # a transposed 768. Left unchanged because its effect
                    # inside LIM_bert is not visible here -- confirm.
                    "end" : 786,
                },
                'debug':False,
                'device': device
            },
            training_parameters={
                'warm_start': False, 'max_iter': 5, 'batch_size': 32, 'n_iter_no_change': 10000,
                'shuffle_train': True, 'eta': 0.00002, 'device': device, 'seed' : seed,
            },
            seed=seed
    )
    LIM_bert = benchmark.create_model()
    LIM_trainer = benchmark.create_classifier(LIM_bert)

    # NOTE(review): the literal sizes below equal data_size / test_data_size
    # from the config cell above; they are kept literal so this cell behaves
    # the same when cells are run out of order.
    X_base_train, y_base_train = get_IIT_nli_dataset_factual_pairs(
        data_size=10000,
        split="train",
        tokenizer_name=benchmark.model_parameters["weights_name"]
    )

    X_base_test, y_base_test = get_IIT_nli_dataset_factual_pairs(
        data_size=1000,
        split="test",
        tokenizer_name=benchmark.model_parameters["weights_name"]
    )

    _ = LIM_trainer.fit(
        X_base_train,
        y_base_train,
        save_checkpoint_per_epoch_overwrite=True,
        save_checkpoint_prefix="./saved_models_nli/basemodel"
    )

    torch.cuda.empty_cache()
    preds = LIM_trainer.predict(X_base_test)
    print(classification_report(y_base_test, preds.cpu()))

    torch.save(LIM_bert.state_dict(), PATH)

training factual model for seed=42
Found trained model thus skip: ./saved_models_nli/basemodel-last-12-768-42.bin
training factual model for seed=66
Found trained model thus skip: ./saved_models_nli/basemodel-last-12-768-66.bin
training factual model for seed=77
Found trained model thus skip: ./saved_models_nli/basemodel-last-12-768-77.bin


### Train d-IIT Oracle (0, 1) Models

In [86]:
# Settings for the d-IIT cells below (these overwrite the values from the
# factual-model section).
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook does not crash on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
data_size = 12000       # divisible by 3, as get_IIT_nli_dataset_neghyp asserts
test_data_size = 1440   # divisible by 3 for the same reason
num_layers = 12         # only used to build checkpoint file names
hidden_dim = 768        # used in checkpoint file names and in the result rows

In [55]:
# Train d-IIT models on the combined V1 / V2 / V1+V2 neghyp data, starting from
# a saved factual checkpoint, and record weighted-average F1 on a sample of the
# training data and on the test split.
# Each row of oracle_results is
#   [seed, hidden_dim, hidden_dim_per_concept, iit_layer + 1, i, split label, F1].
oracle_results = []

# The sweeps use ordered sequences (not sets) so the run order is the order
# written; the trailing comments give the full grids.
for seed in [42]: # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_neghyp(
        data_size=data_size,
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Evaluation subset drawn from the training data itself.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_neghyp(
        data_size=test_data_size,
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]: # [64, 128, 256]
        for iit_layer in [10]: # [6, 8, 10]
            # Intervention id -> slices of the layer's hidden vector:
            # 0 = first slice, 1 = second slice, 2 = both slices.
            first_slice = {"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}
            second_slice = {"layer":iit_layer, "start":hidden_dim_per_concept, "end":2*hidden_dim_per_concept}
            intervention_ids_to_coords = {
                0:[first_slice],
                1:[second_slice],
                2:[dict(first_slice), dict(second_slice)],
            }
            for i in [5]: # 1, 2, 3, 4, 5
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': 'ishan/bert-base-uncased-mnli',
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': 768,
                            'target_layers' : [iit_layer],
                            'target_dims':{
                                "start" : 0,
                                # NOTE(review): 786 exceeds hidden_dim (768)
                                # and looks like a transposed 768. Left
                                # unchanged -- confirm against LIM_bert.
                                "end" : 786,
                            },
                            'debug':False,
                            'device': device
                        },
                        training_parameters={
                            'warm_start': False, 'max_iter': 10, 'batch_size': 32, 'n_iter_no_change': 10000,
                            'shuffle_train': False, 'eta': 0.004, 'device': device
                        },
                        seed=seed
                )
                LIM_bert = benchmark.create_model()

                # Load the factual checkpoint numbered i and rename its keys.
                # Keys containing "analysis_model" carry a layer index in
                # their third dot-separated field; indices above iit_layer are
                # shifted by 2. NOTE(review): presumably because the model
                # built with target_layers=[iit_layer] has two extra modules
                # after that layer -- confirm against LIM_bert.
                new_state_dict = {}
                ORACLE_PATH = f"./saved_models_nli/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
                # map_location="cpu" lets a checkpoint saved from a GPU be read
                # on any machine; load_state_dict copies the values onto the
                # model's own device.
                for k, v in torch.load(ORACLE_PATH, map_location="cpu").items():
                    if "analysis_model" not in k:
                        new_state_dict[k] = v
                        continue
                    k_list = k.split(".")
                    layer_number = int(k_list[2])
                    if layer_number > iit_layer:
                        k_list[2] = str(layer_number + 2)
                    new_state_dict[".".join(k_list)] = v
                # strict=False: the key sets are not required to match exactly.
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                _ = LIM_trainer.fit(
                    X_base_train,
                    y_base_train,
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # train data eval
                base_preds_train = LIM_trainer.predict(
                    base_test_train
                )
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train,
                    intervention_ids_test_train,
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(
                    base_test
                )
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test,
                    intervention_ids_test,
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                # The layer is recorded as iit_layer + 1.
                iit_layer_out = iit_layer + 1
                for split_label, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    oracle_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                            split_label, report["weighted avg"]["f1-score"]]
                    )

Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finished epoch 5 of 5; error is 215.27628363668926

### Train d-IIT (1, ) Models

In [None]:
# Train d-IIT models on the V2 neghyp data only (a single intervention id, 1),
# starting from the saved factual checkpoint of each seed, and record
# weighted-average F1 on a sample of the training data and on the test split.
# Each row of control_1_results is
#   [seed, hidden_dim, hidden_dim_per_concept, iit_layer + 1, i, split label, F1].
control_1_results = []

# The sweeps use ordered sequences (not sets) so the run order -- and with it
# the random-number stream each run sees after the per-seed reseed -- is the
# order written.
for seed in (42, 66, 77):
    utils.fix_random_seeds(seed=seed)
    # NOTE(review): the datasets here use 10000 / 1000 examples while the
    # benchmark below is given data_size / test_data_size from the config
    # cell (12000 / 1440). Left unchanged -- confirm which is intended.
    train_datasetIIT = get_IIT_nli_dataset_neghyp_V2(
        data_size=10000,
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Evaluation subset drawn from the training data itself.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, 1000, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_neghyp_V2(
        data_size=1000,
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in (32, 64, 128):
        for iit_layer in [6, 8, 10]:
            # Intervention id 1 -> the first hidden_dim_per_concept dimensions
            # of the chosen layer.
            intervention_ids_to_coords = {
                1:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}]
            }
            # NOTE(review): every i loads the same "-last-" checkpoint below,
            # so i only distinguishes repeated runs -- confirm that this is
            # intended (the oracle cell loads basemodel-{i}- instead).
            for i in [1, 2, 3, 4, 5]:
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': 'ishan/bert-base-uncased-mnli',
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': 768,
                            'target_layers' : [iit_layer],
                            'target_dims':{
                                "start" : 0,
                                # NOTE(review): 786 exceeds hidden_dim (768)
                                # and looks like a transposed 768. Left
                                # unchanged -- confirm against LIM_bert.
                                "end" : 786,
                            },
                            'debug':False,
                            'device': device
                        },
                        training_parameters={
                            'warm_start': False, 'max_iter': 5, 'batch_size': 32, 'n_iter_no_change': 10000,
                            'shuffle_train': True, 'eta': 0.002, 'device': device
                        },
                        seed=seed
                )
                LIM_bert = benchmark.create_model()

                # Load the factual checkpoint and rename its keys. Keys
                # containing "analysis_model" carry a layer index in their
                # third dot-separated field; indices above iit_layer are
                # shifted by 2. NOTE(review): presumably because the model
                # built with target_layers=[iit_layer] has two extra modules
                # after that layer -- confirm against LIM_bert.
                new_state_dict = {}
                ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
                # map_location="cpu" lets a checkpoint saved from a GPU be read
                # on any machine; load_state_dict copies the values onto the
                # model's own device.
                for k, v in torch.load(ORACLE_PATH, map_location="cpu").items():
                    if "analysis_model" not in k:
                        new_state_dict[k] = v
                        continue
                    k_list = k.split(".")
                    layer_number = int(k_list[2])
                    if layer_number > iit_layer:
                        k_list[2] = str(layer_number + 2)
                    new_state_dict[".".join(k_list)] = v
                # strict=False: the key sets are not required to match exactly.
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                _ = LIM_trainer.fit(
                    X_base_train,
                    y_base_train,
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # train data eval
                base_preds_train = LIM_trainer.predict(
                    base_test_train
                )
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train,
                    intervention_ids_test_train,
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(
                    base_test
                )
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test,
                    intervention_ids_test,
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                # The layer is recorded as iit_layer + 1.
                iit_layer_out = iit_layer + 1
                for split_label, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    control_1_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                            split_label, report["weighted avg"]["f1-score"]]
                    )

Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finished epoch 5 of 5; error is 441.78408563137054Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification mod

Finished epoch 5 of 5; error is 375.02098411321647Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finished epoch 5 of 5; error is 8.8817834388464695Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g

Finished epoch 5 of 5; error is 59.111437168903655Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Finished epoch 5 of 5; error is 8.8516769353009348Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g