In [None]:
import dataset_nli
import torch
from transformers import BertTokenizer, BertModel
import random, os
import copy
import itertools
import numpy as np
import utils
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.metrics import classification_report
from LIM_bert import LIMBERTClassifier
from ii_benchmark import IIBenchmarkMoNli

# Seed the RNGs with utils.fix_random_seeds' default seed; later training cells
# re-seed explicitly per run via utils.fix_random_seeds(seed=seed).
utils.fix_random_seeds()

In [None]:
def get_eval_from_train_nli(iit_nli_dataset, n=1000, control=False):
    """Subsample an evaluation split from an IIT NLI (training) dataset.

    Args:
        iit_nli_dataset: 5-tuple ``(base, y_base, sources, y_IIT, intervention_ids)``
            in the layout returned by the ``get_IIT_nli_dataset_*`` helpers.
            ``base`` and every entry of ``sources`` are 2-tuples of per-example
            sequences; the three label/id entries are indexed with an index tensor.
        n: number of examples to draw.
        control: only ``True`` is implemented.

    Returns:
        A 5-tuple with the same layout as the input, restricted to the drawn
        examples. With a single source list, ``n`` examples are drawn uniformly
        at random. With several source lists, the data is assumed to consist of
        three equally sized consecutive buckets (as concatenated by
        ``get_IIT_nli_dataset_neghyp``); ``n // 3`` random offsets are drawn once
        and reused in every bucket.

    Raises:
        NotImplementedError: if ``control`` is False. Previously this branch
            silently returned ``None``, and callers then failed while unpacking.
        ValueError: if the data is bucketed and ``n`` is not divisible by 3.
    """
    if not control:
        raise NotImplementedError(
            "get_eval_from_train_nli only supports control=True"
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = iit_nli_dataset

    if len(sources_test) == 1:
        indices = torch.randperm(len(base_test[0]))[:n]  # randomize over all examples
    else:
        if n % 3 != 0:
            raise ValueError(f"n must be divisible by 3 for bucketed data, got {n}")
        sub_n = n // 3
        bucket_size = len(base_test[0]) // 3
        # Draw offsets once, then apply them in every bucket so that all three
        # buckets are equally represented.
        offsets = torch.randperm(bucket_size)[:sub_n]
        indices = torch.cat([offsets + b * bucket_size for b in range(3)])

    def _take(seq):
        # Select the drawn examples from a plain Python sequence.
        return [seq[ind] for ind in indices]

    return (
        (_take(base_test[0]), _take(base_test[1])),
        y_base_test[indices],
        [(_take(src[0]), _take(src[1])) for src in sources_test],
        y_IIT_test[indices],
        intervention_ids_test[indices],
    )

def get_IIT_nli_dataset_factual_pairs(
    data_size,
    tokenizer_name,
    split="train",
):
    """Build factual (base-only) MoNLI examples for the given split.

    The text fields of each example (presumably premise and hypothesis) are joined
    into one string and tokenized with the BERT tokenizer ``tokenizer_name``,
    padded/truncated to 128 tokens.

    Returns:
        ``(X_base, y_base)`` from ``IIT_MoNLIDataset.create_factual_pairs()``,
        with ``y_base`` converted to a tensor.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encode(X):
        # Add a period between the sentences unless the first already ends with one.
        separator = " " if X[0][-1] == "." else ". "
        encoded = tokenizer.batch_encode_plus(
            [separator.join(X)],
            max_length=128,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
        return (torch.tensor(encoded['input_ids']),
                torch.tensor(encoded['attention_mask']))

    monli = dataset_nli.IIT_MoNLIDataset(
        embed_func=encode,
        suffix=split,
        size=data_size)

    X_base, labels = monli.create_factual_pairs()
    return X_base, torch.tensor(labels)

def get_IIT_nli_dataset_neghyp_V1(
    data_size,
    tokenizer_name,
    split="train",
):
    """Build the neghyp_V1 IIT dataset for the given split.

    The text fields of each example are joined into one string and tokenized with
    the BERT tokenizer ``tokenizer_name`` (padded/truncated to 128 tokens).

    Returns:
        ``(X_base, y_base, X_sources, y_IIT, interventions)`` from
        ``IIT_MoNLIDataset.create_neghyp_V1()``. ``y_base``, ``y_IIT`` and
        ``interventions`` are converted to tensors.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encode(X):
        # Add a period between the sentences unless the first already ends with one.
        separator = " " if X[0][-1] == "." else ". "
        encoded = tokenizer.batch_encode_plus(
            [separator.join(X)],
            max_length=128,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
        return (torch.tensor(encoded['input_ids']),
                torch.tensor(encoded['attention_mask']))

    monli = dataset_nli.IIT_MoNLIDataset(
        embed_func=encode,
        suffix=split,
        size=data_size)

    X_base, base_labels, X_sources, iit_labels, intervention_ids = monli.create_neghyp_V1()
    return (
        X_base,
        torch.tensor(base_labels),
        X_sources,
        torch.tensor(iit_labels),
        torch.tensor(intervention_ids),
    )

def get_IIT_nli_dataset_neghyp_V2(
    data_size,
    tokenizer_name,
    split="train",
):
    """Build the neghyp_V2 IIT dataset for the given split.

    The text fields of each example are joined into one string and tokenized with
    the BERT tokenizer ``tokenizer_name`` (padded/truncated to 128 tokens).

    Returns:
        ``(X_base, y_base, X_sources, y_IIT, interventions)`` from
        ``IIT_MoNLIDataset.create_neghyp_V2()``. ``y_base``, ``y_IIT`` and
        ``interventions`` are converted to tensors.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encode(X):
        # Add a period between the sentences unless the first already ends with one.
        separator = " " if X[0][-1] == "." else ". "
        encoded = tokenizer.batch_encode_plus(
            [separator.join(X)],
            max_length=128,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
        return (torch.tensor(encoded['input_ids']),
                torch.tensor(encoded['attention_mask']))

    monli = dataset_nli.IIT_MoNLIDataset(
        embed_func=encode,
        suffix=split,
        size=data_size)

    X_base, base_labels, X_sources, iit_labels, intervention_ids = monli.create_neghyp_V2()
    return (
        X_base,
        torch.tensor(base_labels),
        X_sources,
        torch.tensor(iit_labels),
        torch.tensor(intervention_ids),
    )

def get_IIT_nli_dataset_neghyp_V1_V2(
    data_size,
    tokenizer_name,
    split="train",
):
    """Build the combined neghyp_V1_V2 IIT dataset for the given split.

    The text fields of each example are joined into one string and tokenized with
    the BERT tokenizer ``tokenizer_name`` (padded/truncated to 128 tokens).

    Returns:
        ``(X_base, y_base, X_sources, y_IIT, interventions)`` from
        ``IIT_MoNLIDataset.create_neghyp_V1_V2()``. ``y_base``, ``y_IIT`` and
        ``interventions`` are converted to tensors.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encode(X):
        # Add a period between the sentences unless the first already ends with one.
        separator = " " if X[0][-1] == "." else ". "
        encoded = tokenizer.batch_encode_plus(
            [separator.join(X)],
            max_length=128,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
        return (torch.tensor(encoded['input_ids']),
                torch.tensor(encoded['attention_mask']))

    monli = dataset_nli.IIT_MoNLIDataset(
        embed_func=encode,
        suffix=split,
        size=data_size)

    X_base, base_labels, X_sources, iit_labels, intervention_ids = monli.create_neghyp_V1_V2()
    return (
        X_base,
        torch.tensor(base_labels),
        X_sources,
        torch.tensor(iit_labels),
        torch.tensor(intervention_ids),
    )

def get_IIT_nli_dataset_neghyp(
    data_size,
    tokenizer_name,
    split="train",
):
    """Concatenate the V1, V2 and V1_V2 neghyp datasets into one IIT dataset.

    ``data_size`` must be divisible by 3. Each part gets ``data_size // 3``
    examples, and the parts are concatenated in the order V1, V2, V1_V2 (the
    bucket layout that ``get_eval_from_train_nli`` relies on).

    Returns:
        ``(X_base, y_base, X_sources, y_IIT, interventions)`` where ``X_sources``
        has two source slots.
    """
    assert data_size % 3 == 0
    part_size = data_size // 3
    # Build the parts in a fixed order; each builder may consume RNG state.
    v1, v2, both = [
        build(part_size, tokenizer_name, split)
        for build in (
            get_IIT_nli_dataset_neghyp_V1,
            get_IIT_nli_dataset_neghyp_V2,
            get_IIT_nli_dataset_neghyp_V1_V2,
        )
    ]

    base_field_0 = v1[0][0] + v2[0][0] + both[0][0]
    base_field_1 = v1[0][1] + v2[0][1] + both[0][1]
    X_base = (base_field_0, base_field_1)

    # Source slot 0: the first source list of every part.
    source_slot_0 = (v1[2][0][0] + v2[2][0][0] + both[2][0][0],
                     v1[2][0][1] + v2[2][0][1] + both[2][0][1])
    # Source slot 1: only the V1_V2 part contributes a second source. V1/V2
    # (presumably single-source) reuse their first source as a length filler.
    source_slot_1 = (v1[2][0][0] + v2[2][0][0] + both[2][1][0],
                     v1[2][0][1] + v2[2][0][1] + both[2][1][1])
    X_sources = [source_slot_0, source_slot_1]

    parts = (v1, v2, both)
    y_base = torch.cat([part[1] for part in parts])
    y_IIT = torch.cat([part[3] for part in parts])
    interventions = torch.cat([part[4] for part in parts])

    return X_base, y_base, X_sources, y_IIT, interventions

def get_IIT_nli_dataset_tokenidentity_V1(
    data_size,
    tokenizer_name,
    split="train",
):
    """Build the token-identity V1 IIT dataset for the given split.

    The text fields of each example are joined into one string and tokenized with
    the BERT tokenizer ``tokenizer_name`` (padded/truncated to 128 tokens).

    Returns:
        ``(X_base, y_base, X_sources, y_IIT, interventions)`` from
        ``IIT_MoNLIDataset.create_tokenidentity_V1()``. ``y_base``, ``y_IIT`` and
        ``interventions`` are converted to tensors.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def encode(X):
        # Add a period between the sentences unless the first already ends with one.
        separator = " " if X[0][-1] == "." else ". "
        encoded = tokenizer.batch_encode_plus(
            [separator.join(X)],
            max_length=128,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
        return (torch.tensor(encoded['input_ids']),
                torch.tensor(encoded['attention_mask']))

    monli = dataset_nli.IIT_MoNLIDataset(
        embed_func=encode,
        suffix=split,
        size=data_size)

    X_base, base_labels, X_sources, iit_labels, intervention_ids = monli.create_tokenidentity_V1()
    return (
        X_base,
        torch.tensor(base_labels),
        X_sources,
        torch.tensor(iit_labels),
        torch.tensor(intervention_ids),
    )

### Train Factual Models

In [None]:
# Factual-model configuration: BERT-base sized model (12 layers, 768 hidden),
# 10k training / 1k test examples, trained on the first CUDA device.
device = "cuda:0"
num_layers, hidden_dim = 12, 768
data_size, test_data_size = 10_000, 1_000

In [None]:
# Train one factual (non-IIT) model per seed. A seed is skipped if its final
# checkpoint already exists. Dataset sizes come from the config cell above,
# instead of literals that could silently drift from it.
for seed in [42, 66, 77]:  # list, not set: explicit, stable iteration order
    utils.fix_random_seeds(seed=seed)
    print(f"training factual model for seed={seed}")
    PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
    if os.path.isfile(PATH):
        print(f"Found trained model thus skip: {PATH}")
        continue
    benchmark = IIBenchmarkMoNli(
            variable_names=['LEX'],
            data_parameters={
                'train_size': data_size, 'test_size': test_data_size
            },
            model_parameters={
                'weights_name': 'ishan/bert-base-uncased-mnli',
                'max_length': 128,
                'n_classes': 2,
                'hidden_dim': hidden_dim,
                'target_layers' : [],
                # NOTE(review): 786 looks like a typo for 768 (hidden_dim); left
                # unchanged because its use inside IIBenchmarkMoNli is not visible here.
                'target_dims':{
                    "start" : 0,
                    "end" : 786,
                },
                'debug':False,
                'device': device
            },
            training_parameters={
                'warm_start': False, 'max_iter': 5, 'batch_size': 32, 'n_iter_no_change': 10000,
                'shuffle_train': True, 'eta': 0.00002, 'device': device, 'seed' : seed,
            },
            seed=seed
    )
    LIM_bert = benchmark.create_model()
    LIM_trainer = benchmark.create_classifier(LIM_bert)

    X_base_train, y_base_train = get_IIT_nli_dataset_factual_pairs(
        data_size=data_size,
        split="train",
        tokenizer_name=benchmark.model_parameters["weights_name"]
    )

    X_base_test, y_base_test = get_IIT_nli_dataset_factual_pairs(
        data_size=test_data_size,
        split="test",
        tokenizer_name=benchmark.model_parameters["weights_name"]
    )

    _ = LIM_trainer.fit(
        X_base_train,
        y_base_train,
        save_checkpoint_per_epoch_overwrite=True,
        save_checkpoint_prefix="./saved_models_nli/basemodel"
    )

    torch.cuda.empty_cache()
    preds = LIM_trainer.predict(X_base_test)
    print(classification_report(y_base_test, preds.cpu()))

    torch.save(LIM_bert.state_dict(), PATH)

### Train d-IIT Oracle (0, 1) Models

In [None]:
# d-IIT oracle configuration: same BERT-base sized model, with a larger dataset
# split into three equal buckets (24000 and 1920 are both divisible by 3).
device = "cuda:0"
num_layers, hidden_dim = 12, 768
data_size, test_data_size = 24_000, 1_920

In [None]:
# Oracle (SCM1) d-IIT runs on the concatenated V1/V2/V1_V2 neghyp data.
# Intervention ids 0 and 1 each target one block of `hidden_dim_per_concept`
# dims at `iit_layer`; id 2 targets both blocks. The factual/IIT F1 scores are
# now recorded in `oracle_results`. Previously the appends were commented out,
# so the computed reports were discarded.
oracle_results = []
for seed in [77]:  # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_neghyp(
        data_size=data_size,
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_neghyp(
        data_size=test_data_size,
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]:  # [64, 128, 256]
        for iit_layer in [6, 8, 10]:
            intervention_ids_to_coords = {
                0:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}],
                1:[{"layer":iit_layer, "start":hidden_dim_per_concept, "end":2*hidden_dim_per_concept}],
                2:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept},
                   {"layer":iit_layer, "start":hidden_dim_per_concept, "end":2*hidden_dim_per_concept}],
            }
            for i in [5]:  # epoch label; sweep [1, 2, 3, 4, 5]
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': 'ishan/bert-base-uncased-mnli',
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': 768,
                            'target_layers' : [iit_layer],
                            # NOTE(review): 786 looks like a typo for 768; left unchanged.
                            'target_dims':{
                                "start" : 0,
                                "end" : 786,
                            },
                            'debug':False,
                            'device': device,
                            'static_search': False,
                            'nested_disentangle_inplace': False
                        },
                        # NOTE(review): shuffle_train is False here but True in the
                        # control runs; the train data is ordered V1 | V2 | V1_V2.
                        training_parameters={
                            'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000,
                            'shuffle_train': False, 'eta': 0.002, 'device': device
                        },
                        seed=seed
                )
                LIM_bert = benchmark.create_model()
                # NOTE(review): the control cells load basemodel-last-...; confirm a
                # per-epoch basemodel-{i}-... checkpoint exists for this path.
                ORACLE_PATH = f"./saved_models_nli/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
                new_state_dict = {}
                for k, v in torch.load(ORACLE_PATH).items():
                    k_parts = k.split(".")
                    if "analysis_model" in k and int(k_parts[2]) > iit_layer:
                        # Layers after the intervention site shift by 2 in the analysis
                        # model (presumably the inserted rotation layers).
                        k_parts[2] = str(int(k_parts[2]) + 2)
                        k = ".".join(k_parts)
                    new_state_dict[k] = v
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                _ = LIM_trainer.fit(
                    X_base_train,
                    y_base_train,
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # train data eval
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train,
                    intervention_ids_test_train,
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test,
                    intervention_ids_test,
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                iit_layer_out = iit_layer + 1
                torch.save(
                    LIM_trainer.model.analysis_model.layers[iit_layer_out].weight,
                    f"./saved_models_nli/oracle-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )

                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    oracle_results.append([
                        seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                        eval_type, report["weighted avg"]["f1-score"],
                    ])

In [None]:
import numpy as np
from numpy import linalg as LA

def eigenvalues_to_angles(eigenvalues):
    """Return the complex argument (in radians, range [-pi, pi]) of each eigenvalue.

    ``np.angle`` computes ``arctan2(imag, real)``, so real inputs map to 0 or pi.
    The result is a list with one angle per input value, in input order.
    """
    return [np.angle(eigenvalue) for eigenvalue in eigenvalues]

def to_degree_angles(angles):
    """Convert radian angles to absolute degrees, returned as a set.

    Note: because the result is a set, equal magnitudes collapse. A conjugate
    pair (+theta, -theta) counts once, and so do repeated real eigenvalues at 0
    or 180 degrees, so multiplicities are not preserved.
    """
    return {abs(np.degrees(radians)) for radians in angles}

# Load a saved rotation matrix and eigen-decompose its leading 512x512 block
# (presumably 2 * hidden_dim_per_concept with k=256 — confirm).
# map_location="cpu": a matrix saved from a model on cuda:0 would otherwise be
# restored onto the GPU, and .numpy() raises on CUDA tensors.
# NOTE(review): the training cells save "<scm>-rotation_matrix-<layer>-...bin";
# confirm which file "rotation_matrix.bin" corresponds to.
R = torch.load("./saved_models_nli/rotation_matrix.bin", map_location="cpu")
w, v = LA.eig(R[:512, :512].detach().numpy())

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Histogram of the rotation angles (degrees) implied by the eigenvalues `w`.
plt.rcParams["font.family"] = "DejaVu Serif"
font = {'family' : 'DejaVu Serif',
        'size'   : 12}
plt.rc('font', **font)
params = {'mathtext.default': 'regular' }
plt.rcParams.update(params)

with plt.rc_context({
    'axes.edgecolor':'black', 'xtick.color':'black',
    'ytick.color':'black', 'figure.facecolor':'white'
}):
    fig, ax = plt.subplots(figsize=(5, 1.2))

    # to_degree_angles returns a set. seaborn routes data through pandas, which
    # rejects unordered sets, so pass a sorted list instead (same histogram).
    sns.histplot(
        sorted(to_degree_angles(eigenvalues_to_angles(w))),
        legend=False,
        ax=ax,
    )

    # ax.set_xlabel("Basis Vector Rotation Degree(s)", fontsize=14)
    ax.set_ylabel("Frequency", fontsize=14)

    # Thicken all four spines (the original set each one twice).
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_linewidth(2)
    ax.xaxis.grid(color='grey', linestyle='-.', linewidth=1, alpha=0.5)
    ax.yaxis.grid(color='grey', linestyle='-.', linewidth=1, alpha=0.5)

    ax.set_facecolor("white")

    ax.legend(labels=['MoNLI'], loc="lower right")
    plt.show()

In [None]:
# Tabulate the rows collected by the oracle training loop.
result_columns = [
    'seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
    'type', 'f1-score',
]
oracle_df_more = pd.DataFrame(oracle_results, columns=result_columns)

In [None]:
# Reload previously saved oracle results. The file was written with
# DataFrame.to_csv's default index=True, which reads back as an extra
# "Unnamed: 0" column. Drop such columns so they don't become NaN-filled
# columns after concatenating with newly computed results.
oracle_df = pd.read_csv("oracle_df.csv")
oracle_df = oracle_df.loc[:, ~oracle_df.columns.str.startswith("Unnamed")]

In [None]:
# Append the freshly computed oracle rows (oracle_df_more) to the loaded results.
# NOTE: not idempotent — re-running this cell without re-reading oracle_df
# appends oracle_df_more a second time.
oracle_df = pd.concat([oracle_df, oracle_df_more], ignore_index=True, sort=False)

In [None]:
# Oracle rows for layer 9, k=256, epoch 5.
oracle_df.query("iit_layer == 9 and hidden_dim_per_concept == 256 and epoch == 5")

In [None]:
# F1 by epoch for each evaluation type: oracle, hidden_dim=768, k=256, layer 9.
oracle_layer9 = oracle_df[
    (oracle_df["hidden_dim"] == 768)
    & (oracle_df["hidden_dim_per_concept"] == 256)
    & (oracle_df["iit_layer"] == 9)
]
sns.lineplot(
    data=oracle_layer9,
    x="epoch",
    y="f1-score",
    hue="type",
    style="type",
    dashes=False,
    markers=['o', 's', '^', 'D'],
    markersize=12,
    legend=True,
    alpha=0.8,
)

### Train d-IIT (1, ) Models

In [None]:
# Accumulates [seed, hidden_dim, hidden_dim_per_concept, iit_layer, epoch, type, f1-score]
# rows appended by the control_1 training loop.
control_1_results = []

In [None]:
# control_1 (SCM2) d-IIT runs: neghyp_V2 data only, with a single intervention
# (id 1) mapped onto the first `hidden_dim_per_concept` dims at `iit_layer`.
# NOTE(review): the datasets are built with 10000/1000 examples, but the
# benchmark's data_parameters read data_size/test_data_size from the oracle
# config cell (24000/1920 when cells run in order) — confirm which is intended.
for seed in (77,):  # sweep: 42, 66, 77
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_neghyp_V2(
        data_size=10000,
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    (base_test_train, y_base_test_train, sources_test_train,
     y_IIT_test_train, intervention_ids_test_train) = get_eval_from_train_nli(
        train_datasetIIT, 1000, control=True
    )

    (base_test, y_base_test, sources_test,
     y_IIT_test, intervention_ids_test) = get_IIT_nli_dataset_neghyp_V2(
        data_size=1000,
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in (256,):  # sweep: 32, 64, 128
        for iit_layer in (6, 8, 10):
            intervention_ids_to_coords = {
                1: [{"layer": iit_layer, "start": 0, "end": hidden_dim_per_concept}]
            }
            for i in (5,):  # epoch label; sweep 1..5
                torch.cuda.empty_cache()
                model_parameters = {
                    'weights_name': 'ishan/bert-base-uncased-mnli',
                    'max_length': 128,
                    'n_classes': 2,
                    'hidden_dim': 768,
                    'target_layers': [iit_layer],
                    # NOTE(review): 786 looks like a typo for 768; kept as-is.
                    'target_dims': {"start": 0, "end": 786},
                    'debug': False,
                    'device': device,
                    'static_search': False,
                    'nested_disentangle_inplace': False,
                }
                training_parameters = {
                    'warm_start': False, 'max_iter': 5, 'batch_size': 64,
                    'n_iter_no_change': 10000, 'shuffle_train': True,
                    'eta': 0.002, 'device': device,
                }
                benchmark = IIBenchmarkMoNli(
                    variable_names=['LEX'],
                    data_parameters={'train_size': data_size, 'test_size': test_data_size},
                    model_parameters=model_parameters,
                    training_parameters=training_parameters,
                    seed=seed,
                )
                LIM_bert = benchmark.create_model()

                # Remap the factual checkpoint into the analysis model: analysis
                # layers after iit_layer move up by 2 (presumably the inserted
                # rotation layers); everything else is copied unchanged.
                ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
                new_state_dict = {}
                for key, tensor in torch.load(ORACLE_PATH).items():
                    pieces = key.split(".")
                    if "analysis_model" in key and int(pieces[2]) > iit_layer:
                        pieces[2] = str(int(pieces[2]) + 2)
                        key = ".".join(pieces)
                    new_state_dict[key] = tensor
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                _ = LIM_trainer.fit(
                    X_base_train,
                    y_base_train,
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # Evaluate on the held-in subsample drawn from the training data.
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train,
                    intervention_ids_test_train, intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # Evaluate on the test split.
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test,
                    intervention_ids_test, intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                iit_layer_out = iit_layer + 1
                torch.save(
                    LIM_trainer.model.analysis_model.layers[iit_layer_out].weight,
                    f"./saved_models_nli/control_1-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )
                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    control_1_results.append([
                        seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                        eval_type, report["weighted avg"]["f1-score"],
                    ])

In [None]:
# Display the raw result rows collected by the control_1 training loop.
control_1_results

In [None]:
# Reload saved control_1 results. Drop any "Unnamed: N" column left by writing
# with to_csv's default index=True; this is a no-op if the file has none.
control_1_df = pd.read_csv("control_1_df.csv")
control_1_df = control_1_df.loc[:, ~control_1_df.columns.str.startswith("Unnamed")]

In [None]:
# F1 by epoch for each evaluation type: control_1, hidden_dim=768, k=256, layer 11.
control_1_layer11 = control_1_df[
    (control_1_df["hidden_dim"] == 768)
    & (control_1_df["hidden_dim_per_concept"] == 256)
    & (control_1_df["iit_layer"] == 11)
]
sns.lineplot(
    data=control_1_layer11,
    x="epoch",
    y="f1-score",
    hue="type",
    style="type",
    dashes=False,
    markers=['o', 's', '^', 'D'],
    markersize=12,
    legend=True,
    alpha=0.8,
)

### Train d-IIT token identity (1, ) Models

In [None]:
# control_token_1 (SCM3) d-IIT runs on the token-identity V1 data, with a single
# intervention (id 1) mapped onto the first `hidden_dim_per_concept` dims.
device = "cuda:0"
data_size = 10000
test_data_size = 1000
num_layers = 12
hidden_dim = 768

control_token_1_results = []
for seed in [77]:  # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    # Sizes come from the config above, instead of repeated literals that could drift from it.
    train_datasetIIT = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=data_size,
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=test_data_size,
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]:  # [64, 128, 256]
        for iit_layer in [6, 8, 10]:
            intervention_ids_to_coords = {
                1:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}]
            }
            for i in [5]:  # epoch label; sweep [1, 2, 3, 4, 5]
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': 'ishan/bert-base-uncased-mnli',
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': 768,
                            'target_layers' : [iit_layer],
                            # NOTE(review): 786 looks like a typo for 768; left unchanged.
                            'target_dims':{
                                "start" : 0,
                                "end" : 786,
                            },
                            'debug':False,
                            'device': device,
                            'static_search': False,
                            'nested_disentangle_inplace': False
                        },
                        training_parameters={
                            'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000,
                            'shuffle_train': True, 'eta': 0.002, 'device': device
                        },
                        seed=seed
                )
                LIM_bert = benchmark.create_model()
                ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
                new_state_dict = {}
                for k, v in torch.load(ORACLE_PATH).items():
                    k_parts = k.split(".")
                    if "analysis_model" in k and int(k_parts[2]) > iit_layer:
                        # Layers after the intervention site shift by 2 in the analysis
                        # model (presumably the inserted rotation layers).
                        k_parts[2] = str(int(k_parts[2]) + 2)
                        k = ".".join(k_parts)
                    new_state_dict[k] = v
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                _ = LIM_trainer.fit(
                    X_base_train,
                    y_base_train,
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # train data eval
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train,
                    intervention_ids_test_train,
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test,
                    intervention_ids_test,
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                iit_layer_out = iit_layer + 1
                torch.save(
                    LIM_trainer.model.analysis_model.layers[iit_layer_out].weight,
                    f"./saved_models_nli/control_token_1-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )
                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    control_token_1_results.append([
                        seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                        eval_type, report["weighted avg"]["f1-score"],
                    ])

In [None]:
# Display the raw result rows collected by the control_token_1 training loop.
control_token_1_results

In [None]:
# Tabulate the control_token_1 rows with the shared results schema.
token_result_columns = [
    'seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
    'type', 'f1-score',
]
control_token_1_df = pd.DataFrame(control_token_1_results, columns=token_result_columns)

In [None]:
# Persist the results without the RangeIndex. Writing the index would add an
# "Unnamed: 0" column when the file is read back with pd.read_csv.
control_token_1_df.to_csv("control_token_1_df.csv", index=False)

In [None]:
# Reload saved control_token_1 results. Drop any "Unnamed: N" column left by
# writing with to_csv's default index=True; this is a no-op if the file has none.
control_token_1_df = pd.read_csv("control_token_1_df.csv")
control_token_1_df = control_token_1_df.loc[:, ~control_token_1_df.columns.str.startswith("Unnamed")]

In [None]:
# F1 by epoch for each evaluation type: control_token_1, hidden_dim=768, k=256, layer 11.
token_layer11 = control_token_1_df[
    (control_token_1_df["hidden_dim"] == 768)
    & (control_token_1_df["hidden_dim_per_concept"] == 256)
    & (control_token_1_df["iit_layer"] == 11)
]
sns.lineplot(
    data=token_layer11,
    x="epoch",
    y="f1-score",
    hue="type",
    style="type",
    dashes=False,
    markers=['o', 's', '^', 'D'],
    markersize=12,
    legend=True,
    alpha=0.8,
)

In [None]:
def generate_main_result_table(
    oracle_df,
    control_1_df,
    control_token_1_df,
    epoch = 5,
    hidden_dim = 768,
    reduce=max,
    eval_setting="d-IIT Test",
    round_to=2,
    hidden_dims_per_concept=(64, 128, 256),
    iit_layers=(7, 9, 11),
):
    """Build the main results table: one row per concept size k.

    Columns are ``"L{layer};SCM{j}"`` where ``j`` indexes the three result
    frames in argument order (oracle -> SCM1, control_1 -> SCM2,
    control_token_1 -> SCM3) and ``layer`` runs over ``iit_layers``. Each cell
    aggregates the matching ``f1-score`` values with ``reduce`` (``max`` by
    default, i.e. the best value across seeds) and formats the result as a
    string with ``round_to`` decimals.

    Args:
        oracle_df, control_1_df, control_token_1_df: result frames with columns
            ``hidden_dim``, ``hidden_dim_per_concept``, ``iit_layer``,
            ``epoch``, ``type`` and ``f1-score``.
        epoch: epoch whose scores are reported.
        hidden_dim: model hidden size to select.
        reduce: callable applied to the list of matching scores.
        eval_setting: value of the ``type`` column to report.
        round_to: number of decimals in each formatted cell.
        hidden_dims_per_concept: values of ``hidden_dim_per_concept`` (rows).
        iit_layers: values of ``iit_layer`` (columns within each SCM).

    Returns:
        A ``pd.DataFrame`` of formatted score strings.

    Raises:
        ValueError: if a (k, layer, frame) cell has no matching rows.
    """
    frames = [oracle_df, control_1_df, control_token_1_df]
    rows = []
    for k in hidden_dims_per_concept:
        row_scores = []
        for scm_index, frame in enumerate(frames, start=1):
            for iit_layer in iit_layers:
                mask = (
                    (frame["hidden_dim_per_concept"] == k)
                    & (frame["iit_layer"] == iit_layer)
                    & (frame["hidden_dim"] == hidden_dim)
                    & (frame["epoch"] == epoch)
                    & (frame["type"] == eval_setting)
                )
                scores = frame.loc[mask, "f1-score"].tolist()
                if not scores:
                    # Fail with context instead of an opaque `max() arg is an empty sequence`.
                    raise ValueError(
                        f"no '{eval_setting}' scores for hidden_dim_per_concept={k}, "
                        f"iit_layer={iit_layer}, epoch={epoch}, hidden_dim={hidden_dim} "
                        f"in result frame SCM{scm_index}"
                    )
                score = reduce(scores)
                # Honor `round_to` (a hard-coded "%.2f" previously ignored it).
                row_scores.append(f"{score:.{round_to}f}")
        rows.append(row_scores)
    columns = [
        f"L{iit_layer};SCM{scm_index}"
        for scm_index in range(1, len(frames) + 1)
        for iit_layer in iit_layers
    ]
    return pd.DataFrame(rows, columns=columns)

In [None]:
# Render the main results table (best d-IIT test F1 per cell) as LaTeX.
main_result_table = generate_main_result_table(oracle_df, control_1_df, control_token_1_df)
main_result_table.to_latex()

In [None]:
# For each SCM, draw a grid of learning curves (rows: k, columns: IIT layer)
# and save it as ./fig/MoNLI-{hidden_dim}-{SCM}.png.
hidden_dims = [768]
SCMs = ["oracle", "control_1_df", "control_token_1_df"]
sns.set_palette("colorblind")

# Dict lookup instead of an if/elif chain: an unknown label raises KeyError
# rather than silently re-plotting whatever `df` the previous iteration left.
scm_to_df = {
    "oracle": oracle_df,
    "control_1_df": control_1_df,
    "control_token_1_df": control_token_1_df,
}

# Font settings are identical for every figure, so apply them once.
plt.rcParams["font.family"] = "DejaVu Serif"
font = {'family' : 'DejaVu Serif',
        'size'   : 20}
plt.rc('font', **font)

hidden_dim_per_concepts = [64, 128, 256]
iit_layers = [7, 9, 11]
ylabels = [f"k={k}" for k in hidden_dim_per_concepts]
plotted_types = ["Factual Train", "d-IIT Train", "Factual Test", "d-IIT Test"]
epoch_ticks = [1, 2, 3, 4, 5]
n_rows, n_cols = len(hidden_dim_per_concepts), len(iit_layers)

# savefig does not create missing directories.
os.makedirs("./fig", exist_ok=True)

for hidden_dim in hidden_dims:
    for SCM in SCMs:
        df = scm_to_df[SCM]
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(14,10))
        for i, hidden_dim_per_concept in enumerate(hidden_dim_per_concepts):
            for j, iit_layer in enumerate(iit_layers):
                df_toplot = df[
                    (df["hidden_dim_per_concept"] == hidden_dim_per_concept)
                    & (df["iit_layer"] == iit_layer)
                    & (df["hidden_dim"] == hidden_dim)
                    & df["type"].isin(plotted_types)
                ]
                ax = axes[i, j]
                sns.lineplot(
                    ax=ax,
                    data=df_toplot,
                    x="epoch", y="f1-score", hue="type", style="type",
                    dashes=False, markers=['o', 's', '^', 'D'], markersize=12, legend=False,
                    alpha=0.8
                )
                ax.set_ylim(0.5, 1.1)
                # Epoch tick labels only on the bottom row; the "Epoch" axis
                # label only under the middle column.
                if i == n_rows - 1:
                    ax.set(xlabel="Epoch" if j == n_cols // 2 else None, xticks=epoch_ticks, xticklabels=epoch_ticks)
                else:
                    ax.set(xlabel=None, xticks=epoch_ticks, xticklabels=[])
                # y labels (k) only on the left column.
                if j == 0:
                    ax.set(ylabel=ylabels[i])
                else:
                    ax.set(ylabel=None, yticklabels=[])

        for j, iit_layer in enumerate(iit_layers):
            axes[0, j].set_title(f"L{iit_layer}")

        # plt.legend targets the current axes (normally the last subplot created).
        plt.legend(loc='lower right', labels=['Task Acc. (Train)', 'Int. Acc. (Train)', 
                                               'Task Acc. (Test)', 'Int. Acc. (Test)'], fontsize=14)

        plt.savefig(f"./fig/MoNLI-{hidden_dim}-{SCM}.png",dpi=200, bbox_inches='tight')
        

In [None]:
device = "cuda:0"
data_size = 10000
test_data_size = 1000
num_layers = 12
hidden_dim = 768

# Window-position control. No training happens in this cell: for every
# (layer, start_index) the concept windows are shifted by `start_index` and
# the loaded base model is only evaluated with predict / iit_predict.
control_token_1_control_results = []
for seed in [66, 77]: # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=data_size, 
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Random subset of the training pairs, used for the "Train" metrics.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=test_data_size, 
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]: # [64, 128, 256]
        for start_index in [0, 64, 128, 256, 512]:
            # Both windows [start, start + k) and [start + k, start + 2k) must
            # lie inside the hidden state. Past `hidden_dim` the coordinates
            # would presumably be sliced to a truncated/empty range, making
            # the intervention meaningless (e.g. start_index=512 with k=256).
            if start_index + 2 * hidden_dim_per_concept > hidden_dim:
                print(f"skipping start_index = {start_index}: concept windows exceed hidden_dim = {hidden_dim}")
                continue
            for iit_layer in range(num_layers): # [6, 8, 10]
                print(f"searching with layer = {iit_layer}, start_index = {start_index}")
                intervention_ids_to_coords = {
                    0:[{"layer":iit_layer, "start":start_index, "end":hidden_dim_per_concept+start_index}],
                    1:[{"layer":iit_layer, "start":hidden_dim_per_concept+start_index, "end":2*hidden_dim_per_concept+start_index}],
                    2:[{"layer":iit_layer, "start":start_index, "end":hidden_dim_per_concept+start_index},
                       {"layer":iit_layer, "start":hidden_dim_per_concept+start_index, "end":2*hidden_dim_per_concept+start_index}],
                }
                for i in [5]: # 1, 2, 3, 4, 5
                    torch.cuda.empty_cache()
                    benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': 'ishan/bert-base-uncased-mnli',
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': 768,
                            'target_layers' : [iit_layer],
                            # NOTE(review): 786 looks like a typo for 768 (= hidden_dim); left unchanged — confirm.
                            'target_dims':{
                                "start" : 0,
                                "end" : 786,
                            },
                            'debug':False, 
                            'device': device,
                            'static_search': True
                        },
                        training_parameters={
                            'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000, 
                            'shuffle_train': False, 'eta': 0.002, 'device': device
                        },
                        seed=seed
                    )
                    LIM_bert = benchmark.create_model()
                    ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
                    new_state_dict = {}
                    for k, v in torch.load(ORACLE_PATH).items():
                        k_list = k.split(".")
                        if "analysis_model" in k and int(k_list[2]) > iit_layer:
                            # analysis_model layers after the intervention site
                            # are renumbered +2 (presumably for the layers the
                            # LIM model inserts there — confirm).
                            k_list[2] = str(int(k_list[2]) + 2)
                        new_state_dict[".".join(k_list)] = v
                    LIM_bert.load_state_dict(new_state_dict, strict=False)
                    LIM_trainer = benchmark.create_classifier(LIM_bert)
                    LIM_trainer.model.set_analysis_mode(True)

                    # train data eval
                    base_preds_train = LIM_trainer.predict(base_test_train)
                    IIT_preds_train = LIM_trainer.iit_predict(
                        base_test_train, sources_test_train, 
                        intervention_ids_test_train, 
                        intervention_ids_to_coords
                    )
                    r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                    r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                    # test data eval
                    base_preds = LIM_trainer.predict(base_test)
                    IIT_preds = LIM_trainer.iit_predict(
                        base_test, sources_test, 
                        intervention_ids_test, 
                        intervention_ids_to_coords
                    )
                    r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                    r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                    iit_layer_out = iit_layer + 1
                    for eval_type, report in (
                        ("Factual Train", r1_train),
                        ("d-IIT Train", r2_train),
                        ("Factual Test", r1),
                        ("d-IIT Test", r2),
                    ):
                        control_token_1_control_results.append(
                            [
                                seed, hidden_dim, start_index, hidden_dim_per_concept, iit_layer_out, i, 
                                eval_type, report["weighted avg"]["f1-score"]]
                        )

In [None]:
# One row per (seed, start_index, configuration, epoch, evaluation setting).
control_token_1_control_df = pd.DataFrame(
    data=control_token_1_control_results,
    columns=[
        "seed",
        "hidden_dim",
        "start_index",
        "hidden_dim_per_concept",
        "iit_layer",
        "epoch",
        "type",
        "f1-score",
    ],
)

In [None]:
# Best d-IIT train F1 at layer 7, over all seeds and start indices.
layer7_train_iit_scores = control_token_1_control_df.loc[
    (control_token_1_control_df["iit_layer"] == 7)
    & (control_token_1_control_df["type"] == "d-IIT Train"),
    "f1-score",
]
max(layer7_train_iit_scores)

In [None]:
# Best d-IIT train F1 at layer 9, over all seeds and start indices.
layer9_train_iit_scores = control_token_1_control_df.loc[
    (control_token_1_control_df["iit_layer"] == 9)
    & (control_token_1_control_df["type"] == "d-IIT Train"),
    "f1-score",
]
max(layer9_train_iit_scores)

In [None]:
# Best d-IIT train F1 at layer 11, over all seeds and start indices.
layer11_train_iit_scores = control_token_1_control_df.loc[
    (control_token_1_control_df["iit_layer"] == 11)
    & (control_token_1_control_df["type"] == "d-IIT Train"),
    "f1-score",
]
max(layer11_train_iit_scores)

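### Double Alignments

Fine-tune a second alignment on top of the saved first-stage model (`./saved_models_nli/first_stage.bin`) and report factual and d-IIT F1.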

In [None]:
# Shared configuration for the double-alignment experiment below.
device = "cuda:0"
hidden_dim = 768
num_layers = 12
data_size = 10000        # passed as the benchmark's train_size
test_data_size = 1000    # passed as the benchmark's test_size

control_token_1_double_results = []

In [None]:
# Second-stage ("double") alignment: start from the saved first-stage model,
# fine-tune with `fit`, then evaluate on a training subset and on test data.
for seed in [42, 66, 77]: # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=data_size, 
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Random subset of the training pairs, used for the "Train" metrics.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=test_data_size, 
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [32]: # [64, 128, 256]
        for iit_layer in [8]: # [6, 8, 10]
            # The second-stage intervention sits one layer above `iit_layer`
            # (presumably because first_stage.bin already has an inserted
            # layer — confirm). Keep it in its own name: the old code did
            # `iit_layer = iit_layer + 1` inside the epoch loop, which would
            # keep shifting the layer if more than one epoch were listed.
            nested_layer = iit_layer + 1
            intervention_ids_to_coords = {
                1:[{"layer":nested_layer, "start":0, "end":hidden_dim_per_concept}]
            }
            for i in [5]: # 1, 2, 3, 4, 5
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                    variable_names=['LEX'],
                    data_parameters={
                        'train_size': data_size, 'test_size': test_data_size
                    },
                    model_parameters={
                        'weights_name': 'ishan/bert-base-uncased-mnli',
                        'max_length': 128,
                        'n_classes': 2,
                        'hidden_dim': 768,
                        'target_layers' : [iit_layer],
                        # NOTE(review): 786 looks like a typo for 768 (= hidden_dim); left unchanged — confirm.
                        'target_dims':{
                            "start" : 0,
                            "end" : 786,
                        },
                        'debug':False, 
                        'device': device,
                        'static_search': False,
                        'nested_disentangle_inplace': True
                    },
                    training_parameters={
                        'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000, 
                        'shuffle_train': True, 'eta': 0.002, 'device': device
                    },
                    seed=seed
                )
                LIM_bert = benchmark.create_model()
                ORACLE_PATH = "./saved_models_nli/first_stage.bin"
                new_state_dict = {}
                for k, v in torch.load(ORACLE_PATH).items():
                    k_list = k.split(".")
                    if "analysis_model" in k and int(k_list[2]) > nested_layer:
                        # Renumber analysis_model layers past the intervention site by +2.
                        k_list[2] = str(int(k_list[2]) + 2)
                    new_state_dict[".".join(k_list)] = v
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True, layers=[nested_layer])

                _ = LIM_trainer.fit(
                    X_base_train, 
                    y_base_train, 
                    iit_data=iit_data,
                    intervention_ids_to_coords=intervention_ids_to_coords)

                # train data eval
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train, 
                    intervention_ids_test_train, 
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test, 
                    intervention_ids_test, 
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                iit_layer_out = nested_layer
                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    control_token_1_double_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i, 
                            eval_type, report["weighted avg"]["f1-score"]]
                    )

In [None]:
# One row per (seed, configuration, epoch, evaluation setting) of the second stage.
control_token_1_double_df = pd.DataFrame(
    data=control_token_1_double_results,
    columns=[
        "seed",
        "hidden_dim",
        "hidden_dim_per_concept",
        "iit_layer",
        "epoch",
        "type",
        "f1-score",
    ],
)

In [None]:
# Best second-stage d-IIT train F1 for k = 32, across seeds.
double_train_iit_scores = control_token_1_double_df.loc[
    (control_token_1_double_df["hidden_dim_per_concept"] == 32)
    & (control_token_1_double_df["type"] == "d-IIT Train"),
    "f1-score",
]
max(double_train_iit_scores)

In [None]:
# Full table of second-stage results (four evaluation settings per seed / configuration).
control_token_1_double_df

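### Evaluate the closest localist alignment

Greedily snap each learned rotation matrix to a signed permutation matrix (a "localist" alignment, where each rotated dimension maps to exactly one hidden unit), then re-evaluate factual and d-IIT F1.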

In [None]:
device = "cuda:0"
data_size = 24000
test_data_size = 1920
num_layers = 12
hidden_dim = 768

# Snap each trained oracle rotation matrix to a signed permutation
# ("localist" alignment) and re-evaluate factual and d-IIT F1.
oracle_results = []
for seed in [77]: # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_neghyp(
        data_size=data_size, 
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Random subset of the training pairs, used for the "Train" metrics.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_neghyp(
        data_size=test_data_size, 
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]: # [64, 128, 256]
        for iit_layer in [6, 8, 10]: # [6, 8, 10]
            intervention_ids_to_coords = {
                0:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}],
                1:[{"layer":iit_layer, "start":hidden_dim_per_concept, "end":2*hidden_dim_per_concept}],
                2:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept},
                   {"layer":iit_layer, "start":hidden_dim_per_concept, "end":2*hidden_dim_per_concept}],
            }
            for i in [5]: # 1, 2, 3, 4, 5
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                    variable_names=['LEX'],
                    data_parameters={
                        'train_size': data_size, 'test_size': test_data_size
                    },
                    model_parameters={
                        'weights_name': 'ishan/bert-base-uncased-mnli',
                        'max_length': 128,
                        'n_classes': 2,
                        'hidden_dim': 768,
                        'target_layers' : [iit_layer],
                        # NOTE(review): 786 looks like a typo for 768 (= hidden_dim); left unchanged — confirm.
                        'target_dims':{
                            "start" : 0,
                            "end" : 786,
                        },
                        'debug':False, 
                        'device': device,
                        'static_search': False,
                        'nested_disentangle_inplace': False
                    },
                    training_parameters={
                        'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000, 
                        'shuffle_train': False, 'eta': 0.002, 'device': device
                    },
                    seed=seed
                )
                LIM_bert = benchmark.create_model()
                ORACLE_PATH = f"./saved_models_nli/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
                new_state_dict = {}
                for k, v in torch.load(ORACLE_PATH).items():
                    k_list = k.split(".")
                    if "analysis_model" in k and int(k_list[2]) > iit_layer:
                        # Renumber analysis_model layers past the intervention
                        # site by +2, as in the other evaluation cells.
                        k_list[2] = str(int(k_list[2]) + 2)
                    new_state_dict[".".join(k_list)] = v
                # Load the renumbered dict. The raw checkpoint used to be
                # loaded here, leaving `new_state_dict` built but unused.
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)
                iit_layer_out = iit_layer + 1

                # load rotation matrix
                R = torch.load(
                    f"./saved_models_nli/oracle-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )
                # Greedy snap: repeatedly take the largest remaining |R| entry,
                # keep its sign, and retire its row and column.
                with torch.no_grad():
                    R = R.detach()
                    abs_R = torch.abs(R)
                    snapped_R = torch.zeros_like(R)
                    n_cols = R.shape[1]
                    # `_` rather than `i`: reusing `i` clobbered the epoch
                    # recorded in the results below.
                    for _ in range(min(R.shape)):
                        # argmax returns the first maximal entry in row-major order.
                        row, col = divmod(int(torch.argmax(abs_R)), n_cols)
                        snapped_R[row, col] = torch.sign(R[row, col])
                        # -1 is below every |R| entry, so retired rows/columns
                        # cannot be re-selected (0 could tie with exact zeros).
                        abs_R[row, :] = -1.
                        abs_R[:, col] = -1.
                snapped_R = snapped_R.to(LIM_trainer.model.analysis_model.layers[iit_layer_out].weight.device)
                LIM_trainer.model.analysis_model.layers[iit_layer_out].weight = snapped_R

                # train data eval
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train, 
                    intervention_ids_test_train, 
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test, 
                    intervention_ids_test, 
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    oracle_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i, 
                            eval_type, report["weighted avg"]["f1-score"]]
                    )

oracle_df = pd.DataFrame(
    oracle_results,
    columns =['seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch', 
              'type', 'f1-score']
)

In [None]:
# Only the d-IIT rows measured on the training subset.
oracle_df.loc[oracle_df["type"].eq("d-IIT Train")]

In [None]:
# Restate the configuration locally. The datasets below are built with
# 10000 / 1000 examples, but `data_size` / `test_data_size` would otherwise be
# inherited from the cell above (24000 / 1920) and passed to the benchmark.
device = "cuda:0"
data_size = 10000
test_data_size = 1000
num_layers = 12
hidden_dim = 768

# Snap each trained control_1 rotation matrix to a signed permutation
# ("localist" alignment) and re-evaluate factual and d-IIT F1.
control_1_results = []
for seed in [77]: # [42, 66, 77]
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_neghyp_V2(
        data_size=data_size, 
        split="train",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    # Random subset of the training pairs, used for the "Train" metrics.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_neghyp_V2(
        data_size=test_data_size, 
        split="test",
        tokenizer_name='ishan/bert-base-uncased-mnli'
    )
    for hidden_dim_per_concept in [256]: # [32, 64, 128]
        for iit_layer in [6, 8, 10]: # [6, 8, 10]
            intervention_ids_to_coords = {
                1:[{"layer":iit_layer, "start":0, "end":hidden_dim_per_concept}]
            }
            for i in [5]: # 1, 2, 3, 4, 5
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                    variable_names=['LEX'],
                    data_parameters={
                        'train_size': data_size, 'test_size': test_data_size
                    },
                    model_parameters={
                        'weights_name': 'ishan/bert-base-uncased-mnli',
                        'max_length': 128,
                        'n_classes': 2,
                        'hidden_dim': 768,
                        'target_layers' : [iit_layer],
                        # NOTE(review): 786 looks like a typo for 768 (= hidden_dim); left unchanged — confirm.
                        'target_dims':{
                            "start" : 0,
                            "end" : 786,
                        },
                        'debug':False, 
                        'device': device,
                        'static_search': False,
                        'nested_disentangle_inplace': False
                    },
                    training_parameters={
                        'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000, 
                        'shuffle_train': True, 'eta': 0.002, 'device': device
                    },
                    seed=seed
                )
                LIM_bert = benchmark.create_model()
                ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
                new_state_dict = {}
                for k, v in torch.load(ORACLE_PATH).items():
                    k_list = k.split(".")
                    if "analysis_model" in k and int(k_list[2]) > iit_layer:
                        # Renumber analysis_model layers past the intervention site by +2.
                        k_list[2] = str(int(k_list[2]) + 2)
                    new_state_dict[".".join(k_list)] = v
                LIM_bert.load_state_dict(new_state_dict, strict=False)
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                iit_layer_out = iit_layer + 1

                # load rotation matrix
                R = torch.load(
                    f"./saved_models_nli/control_1-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )
                # Greedy snap: repeatedly take the largest remaining |R| entry,
                # keep its sign, and retire its row and column.
                with torch.no_grad():
                    R = R.detach()
                    abs_R = torch.abs(R)
                    snapped_R = torch.zeros_like(R)
                    n_cols = R.shape[1]
                    # `_` rather than `i`: reusing `i` clobbered the epoch
                    # recorded in the results below.
                    for _ in range(min(R.shape)):
                        # argmax returns the first maximal entry in row-major order.
                        row, col = divmod(int(torch.argmax(abs_R)), n_cols)
                        snapped_R[row, col] = torch.sign(R[row, col])
                        # -1 is below every |R| entry, so retired rows/columns
                        # cannot be re-selected (0 could tie with exact zeros).
                        abs_R[row, :] = -1.
                        abs_R[:, col] = -1.
                snapped_R = snapped_R.to(LIM_trainer.model.analysis_model.layers[iit_layer_out].weight.device)
                LIM_trainer.model.analysis_model.layers[iit_layer_out].weight = snapped_R

                # train data eval
                base_preds_train = LIM_trainer.predict(base_test_train)
                IIT_preds_train = LIM_trainer.iit_predict(
                    base_test_train, sources_test_train, 
                    intervention_ids_test_train, 
                    intervention_ids_to_coords
                )
                r1_train = classification_report(y_base_test_train, base_preds_train.cpu(), output_dict=True)
                r2_train = classification_report(y_IIT_test_train, IIT_preds_train.cpu(), output_dict=True)

                # test data eval
                base_preds = LIM_trainer.predict(base_test)
                IIT_preds = LIM_trainer.iit_predict(
                    base_test, sources_test, 
                    intervention_ids_test, 
                    intervention_ids_to_coords
                )
                r1 = classification_report(y_base_test, base_preds.cpu(), output_dict=True)
                r2 = classification_report(y_IIT_test, IIT_preds.cpu(), output_dict=True)

                for eval_type, report in (
                    ("Factual Train", r1_train),
                    ("d-IIT Train", r2_train),
                    ("Factual Test", r1),
                    ("d-IIT Test", r2),
                ):
                    control_1_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i, 
                            eval_type, report["weighted avg"]["f1-score"]]
                    )

In [None]:
# Collect the snapped control_1 results and show only the d-IIT train rows.
control_1_columns = [
    "seed",
    "hidden_dim",
    "hidden_dim_per_concept",
    "iit_layer",
    "epoch",
    "type",
    "f1-score",
]
control_1_df = pd.DataFrame(data=control_1_results, columns=control_1_columns)
is_train_iit = control_1_df["type"] == "d-IIT Train"
control_1_df.loc[is_train_iit]

In [None]:
# Control experiment "token 1": for every seed / concept size / intervention
# layer, load the oracle model weights, install the saved rotation matrix
# snapped to a signed permutation, and record weighted-average F1 of factual
# and interchange-intervention (d-IIT) predictions on an eval subset drawn from
# the training data and on the test split.
device = "cuda:0"
data_size = 10000        # size of the generated IIT training set
test_data_size = 1000    # size of the IIT test split and of the train-eval subset
num_layers = 12
hidden_dim = 768
weights_name = 'ishan/bert-base-uncased-mnli'


def _remap_oracle_state_dict(oracle_state_dict, iit_layer, layer_shift=2):
    """Return a re-keyed copy of ``oracle_state_dict`` for the analysis model.

    Keys containing "analysis_model" whose third dot-separated field (a layer
    index) is greater than ``iit_layer`` get that index increased by
    ``layer_shift``; every other key is kept unchanged. Tensor values are shared,
    not copied.
    NOTE(review): presumably the shift accounts for layers the analysis model
    inserts after ``iit_layer`` — confirm against LIMBERTClassifier.
    """
    remapped = {}
    for key, value in oracle_state_dict.items():
        parts = key.split(".")
        if "analysis_model" in key and int(parts[2]) > iit_layer:
            parts[2] = str(int(parts[2]) + layer_shift)
            key = ".".join(parts)
        remapped[key] = value
    return remapped


def _snap_to_signed_permutation(rotation):
    """Greedily snap a square matrix to a signed permutation matrix.

    Repeatedly selects the largest-magnitude entry among rows and columns not
    yet used (ties go to the first entry in row-major order), marks it with 1
    and retires its row and column. The result carries the sign of ``rotation``
    at each selected entry. The input tensor is not modified.
    """
    remaining = torch.abs(rotation)
    snapped = torch.zeros_like(rotation)
    n_cols = rotation.shape[1]
    # Uses its own loop variable so it cannot clobber the sweep's `i`.
    for _ in range(rotation.shape[0]):
        row, col = divmod(int(torch.argmax(remaining)), n_cols)
        snapped[row, col] = 1.
        remaining[row, :] = 0.
        remaining[:, col] = 0.
    return snapped * torch.sign(rotation)


def _weighted_f1_scores(trainer, base, sources, intervention_ids, y_base, y_iit, coords):
    """Return (factual, d-IIT) weighted-average F1 of ``trainer`` on one eval set.

    Factual predictions come from ``trainer.predict(base)``; intervention
    predictions come from ``trainer.iit_predict`` with the given sources,
    intervention ids and intervention coordinates.
    """
    base_preds = trainer.predict(base)
    iit_preds = trainer.iit_predict(base, sources, intervention_ids, coords)
    base_report = classification_report(y_base, base_preds.cpu(), output_dict=True)
    iit_report = classification_report(y_iit, iit_preds.cpu(), output_dict=True)
    return (
        base_report["weighted avg"]["f1-score"],
        iit_report["weighted avg"]["f1-score"],
    )


control_token_1_results = []
for seed in (77,):  # full sweep: (42, 66, 77)
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=data_size,
        split="train",
        tokenizer_name=weights_name
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])

    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        get_eval_from_train_nli(
            train_datasetIIT, test_data_size, control=True
        )

    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = get_IIT_nli_dataset_tokenidentity_V1(
        data_size=test_data_size,
        split="test",
        tokenizer_name=weights_name
    )

    # The oracle checkpoint depends only on the seed, so load it once per seed
    # instead of once per (layer, i) configuration.
    ORACLE_PATH = f"./saved_models_nli/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
    oracle_state_dict = torch.load(ORACLE_PATH)

    for hidden_dim_per_concept in (256,):  # full sweep: (64, 128, 256)
        for iit_layer in [6, 8, 10]:
            intervention_ids_to_coords = {
                1: [{"layer": iit_layer, "start": 0, "end": hidden_dim_per_concept}]
            }
            # `i` is recorded as the fifth field of each result row
            # (the 'epoch' column of the summary frame). Full sweep: 1-5.
            for i in [5]:
                torch.cuda.empty_cache()
                benchmark = IIBenchmarkMoNli(
                        variable_names=['LEX'],
                        data_parameters={
                            'train_size': data_size, 'test_size': test_data_size
                        },
                        model_parameters={
                            'weights_name': weights_name,
                            'max_length': 128,
                            'n_classes': 2,
                            'hidden_dim': hidden_dim,
                            'target_layers' : [iit_layer],
                            'target_dims':{
                                "start" : 0,
                                # was the literal 786, a typo for the 768-dim hidden size
                                "end" : hidden_dim,
                            },
                            'debug':False,
                            'device': device,
                            'static_search': False,
                            'nested_disentangle_inplace': False
                        },
                        training_parameters={
                            'warm_start': False, 'max_iter': 5, 'batch_size': 64, 'n_iter_no_change': 10000,
                            'shuffle_train': True, 'eta': 0.002, 'device': device
                        },
                        seed=seed
                )
                LIM_bert = benchmark.create_model()
                LIM_bert.load_state_dict(
                    _remap_oracle_state_dict(oracle_state_dict, iit_layer), strict=False
                )
                LIM_trainer = benchmark.create_classifier(LIM_bert)
                LIM_trainer.model.set_analysis_mode(True)

                iit_layer_out = iit_layer + 1

                # Load the learned rotation and replace it with a signed
                # permutation chosen greedily from its largest-magnitude entries.
                R = torch.load(
                    f"./saved_models_nli/control_token_1-rotation_matrix-{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-{seed}.bin"
                )
                rotation_layer = LIM_trainer.model.analysis_model.layers[iit_layer_out]
                snapped_R = _snap_to_signed_permutation(R).to(rotation_layer.weight.device)
                # NOTE(review): assigning a plain tensor only works if `weight` is not a
                # registered nn.Parameter on this layer — confirm against LIMBERTClassifier.
                rotation_layer.weight = snapped_R

                factual_f1_train, iit_f1_train = _weighted_f1_scores(
                    LIM_trainer, base_test_train, sources_test_train,
                    intervention_ids_test_train, y_base_test_train, y_IIT_test_train,
                    intervention_ids_to_coords,
                )
                factual_f1_test, iit_f1_test = _weighted_f1_scores(
                    LIM_trainer, base_test, sources_test,
                    intervention_ids_test, y_base_test, y_IIT_test,
                    intervention_ids_to_coords,
                )

                for result_type, f1 in (
                    ("Factual Train", factual_f1_train),
                    ("d-IIT Train", iit_f1_train),
                    ("Factual Test", factual_f1_test),
                    ("d-IIT Test", iit_f1_test),
                ):
                    control_token_1_results.append(
                        [
                            seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                            result_type, f1
                        ]
                    )

In [None]:
# Tabulate the control-token-1 sweep results (one row per appended record) and
# display only the rows labelled "d-IIT Train".
control_token_1_df = pd.DataFrame(
    control_token_1_results,
    columns=[
        'seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer',
        'epoch', 'type', 'f1-score',
    ],
)
control_token_1_diit_train_mask = control_token_1_df["type"].eq("d-IIT Train")
control_token_1_df.loc[control_token_1_diit_train_mask]