In [None]:
from random import randrange
import torch, gc
import random
import copy
import itertools
import numpy as np
import utils
from trainer import LIMTrainer
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from sklearn.metrics import classification_report
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality
import os.path

# Free unreferenced Python objects and release memory held by PyTorch's CUDA
# caching allocator (a no-op on the CPU-only runs below) before training starts.
gc.collect()
torch.cuda.empty_cache()
# Seed RNGs with utils.fix_random_seeds' default seed (which libraries it seeds
# is defined in utils). The training/eval cells re-seed per run with seed=seed.
utils.fix_random_seeds()

import seaborn as sns

def ab_c(a, b, c):
    """High-level arithmetic model ``(a + b) * c``.

    Args:
        a, b, c: Operands (the dataset generators in this notebook pass ints).

    Returns:
        ``a + b`` multiplied by ``c``.
    """
    partial_sum = a + b
    return partial_sum * c

def get_IIT_arithmetic_dataset_factuals(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Sample random three-operand inputs together with their factual labels.

    Args:
        variable_range: ``[low, high]`` inclusive range for every operand.
        seed: Seed passed to ``random.seed`` before sampling.
        data_size: Number of input rows to draw.
        combiner_func: Label function called as ``combiner_func(a, b, c)``.

    Returns:
        ``(inputs, labels)`` as ``torch.long`` tensors; ``inputs`` has one row
        of three operands per example.
    """
    random.seed(seed)
    low, high = variable_range[0], variable_range[1]
    # Rows are drawn one after another, three operands each, so the random
    # stream is consumed in the same order as the factual-pairs generator.
    inputs = [
        [random.randint(low, high) for _ in range(3)]
        for _ in range(data_size)
    ]
    labels = [combiner_func(*row) for row in inputs]
    return (
        torch.tensor(inputs, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_factual_pairs(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Draw a base set and a source set of random inputs with factual labels.

    Both sets come from a single ``random.seed(seed)`` stream: all base rows
    first, then all source rows.

    Args:
        variable_range: ``[low, high]`` inclusive range for every operand.
        seed: Seed passed to ``random.seed``.
        data_size: Rows per set.
        combiner_func: Label function called as ``combiner_func(a, b, c)``.

    Returns:
        Plain Python lists ``(base, base_y, source, source_y)``.
    """
    random.seed(seed)
    low, high = variable_range[0], variable_range[1]

    def draw_rows():
        # data_size rows of three operands, consumed row-major.
        return [
            [random.randint(low, high) for _ in range(3)]
            for _ in range(data_size)
        ]

    base = draw_rows()
    base_y = [combiner_func(*row) for row in base]
    source = draw_rows()
    source_y = [combiner_func(*row) for row in source]
    return base, base_y, source, source_y

def get_IIT_arithmetic_dataset_sum_first_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Build IIT examples that intervene on V1 = a + b (intervention id 0).

    Args:
        variable_range: ``[low, high]`` inclusive operand range.
        seed: Seed for the factual-pair draw.
        data_size: Number of examples.
        combiner_func: Factual label function ``f(a, b, c)``.

    Returns:
        Tuple ``(base, base_y, [source, source], counterfactual_y, ids)`` of
        ``torch.long`` tensors. The counterfactual label is the source's
        ``a + b`` times the base's ``c``, and ``ids`` is all zeros. The source
        tensor is listed twice, presumably so this dataset can be concatenated
        with the two-source V1_V2 dataset.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = [
        (src_row[0] + src_row[1]) * base_row[-1]
        for base_row, src_row in zip(base, source)
    ]
    count = len(base)
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long) for _ in range(2)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([0] * count, dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_sum_first_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Build IIT examples that intervene on V2 = c (intervention id 1).

    Args:
        variable_range: ``[low, high]`` inclusive operand range.
        seed: Seed for the factual-pair draw.
        data_size: Number of examples.
        combiner_func: Factual label function ``f(a, b, c)``.

    Returns:
        Tuple ``(base, base_y, [source, source], counterfactual_y, ids)`` of
        ``torch.long`` tensors. The counterfactual label is the base's
        ``a + b`` times the source's ``c``, and ``ids`` is all ones.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = []
    for base_row, src_row in zip(base, source):
        kept_sum = base_row[0] + base_row[1]
        counterfactual_y.append(kept_sum * src_row[-1])
    count = len(base)
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long) for _ in range(2)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([1] * count, dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_sum_first_V1_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Build IIT examples that intervene on V1 and V2 together (id 2).

    Two factual-pair draws are made, with ``seed`` and ``seed + 1``. The first
    draw's source supplies V1 = a + b and the second draw's source supplies
    V2 = c; the counterfactual label is their product. Only the first draw's
    base is used.

    Returns:
        Tuple ``(base_1, base_y_1, [source_1, source_2], counterfactual_y,
        ids)`` of ``torch.long`` tensors, with every id equal to 2.
    """
    first_draw = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    second_draw = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed + 1, data_size, combiner_func
    )
    base_1, base_y_1, source_1, _ = first_draw
    source_2 = second_draw[2]

    counterfactual_y = [
        (v1_row[0] + v1_row[1]) * v2_row[-1]
        for v1_row, v2_row in zip(source_1, source_2)
    ]
    return (
        torch.tensor(base_1, dtype=torch.long),
        torch.tensor(base_y_1, dtype=torch.long),
        [torch.tensor(source_1, dtype=torch.long),
         torch.tensor(source_2, dtype=torch.long)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([2] * len(base_1), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_sum_first(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Concatenate the V1, V2 and V1+V2 IIT datasets for ``(a + b) * c``.

    Args:
        variable_range: ``[low, high]`` inclusive operand range.
        seed: Seed forwarded to each sub-dataset builder.
        data_size: Examples per sub-dataset.
        combiner_func: Factual label function ``f(a, b, c)``.

    Returns:
        List ``[base, base_y, [source_1, source_2], counterfactual_y,
        intervention_ids]``. Each tensor holds the V1 rows, then the V2 rows,
        then the V1_V2 rows.
    """
    parts = (
        get_IIT_arithmetic_dataset_sum_first_V1(
            variable_range, seed, data_size, combiner_func
        ),
        get_IIT_arithmetic_dataset_sum_first_V2(
            variable_range, seed, data_size, combiner_func
        ),
        get_IIT_arithmetic_dataset_sum_first_V1_V2(
            variable_range, seed, data_size, combiner_func
        ),
    )

    def cat_field(select):
        # Concatenate one field across the three sub-datasets, in order.
        return torch.cat(tuple(select(part) for part in parts))

    return [
        cat_field(lambda part: part[0]),
        cat_field(lambda part: part[1]),
        # Each source slot is taken from the matching slot of every
        # sub-dataset. The previous code read slot 0 of V1/V2 for the second
        # source, which was only correct because V1/V2 duplicate their source.
        [cat_field(lambda part: part[2][0]),
         cat_field(lambda part: part[2][1])],
        cat_field(lambda part: part[3]),
        cat_field(lambda part: part[4]),
    ]

def get_IIT_arithmetic_dataset_sum_first_control_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Control IIT dataset: intervene on a + b only, with intervention id 3.

    The labels are the same as in the V1 builder (source ``a + b`` times base
    ``c``), but only a single source tensor is returned.

    Returns:
        Tuple ``(base, base_y, [source], counterfactual_y, ids)`` of
        ``torch.long`` tensors, with every id equal to 3.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = []
    for base_row, src_row in zip(base, source):
        swapped_sum = src_row[0] + src_row[1]
        counterfactual_y.append(swapped_sum * base_row[-1])
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([3] * len(base), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_control_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Control IIT dataset: swap only operand ``a`` from the source (id 3).

    The counterfactual label is ``combiner_func(source_a, base_b, base_c)``.

    Returns:
        Tuple ``(base, base_y, [source], counterfactual_y, ids)`` of
        ``torch.long`` tensors, with every id equal to 3.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = [
        combiner_func(src_row[0], base_row[1], base_row[-1])
        for base_row, src_row in zip(base, source)
    ]
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([3] * len(base), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_prod_first_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    """IIT examples for the ``a*c + b*c`` model, intervening on V1 = a * c (id 0).

    The counterfactual label is the source's ``a * c`` plus the base's ``b * c``.

    Returns:
        Tuple ``(base, base_y, [source, source], counterfactual_y, ids)`` of
        ``torch.long`` tensors, with every id equal to 0.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = []
    for base_row, src_row in zip(base, source):
        counterfactual_y.append(src_row[0] * src_row[2] + base_row[1] * base_row[2])
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long) for _ in range(2)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([0] * len(base), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_prod_first_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    """IIT examples for the ``a*c + b*c`` model, intervening on V2 = b * c (id 1).

    The counterfactual label is the source's ``b * c`` plus the base's ``a * c``.

    Returns:
        Tuple ``(base, base_y, [source, source], counterfactual_y, ids)`` of
        ``torch.long`` tensors, with every id equal to 1.
    """
    base, base_y, source, _unused_source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    counterfactual_y = [
        src_row[1] * src_row[2] + base_row[0] * base_row[2]
        for base_row, src_row in zip(base, source)
    ]
    return (
        torch.tensor(base, dtype=torch.long),
        torch.tensor(base_y, dtype=torch.long),
        [torch.tensor(source, dtype=torch.long) for _ in range(2)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([1] * len(base), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_prod_first_V1_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    """IIT examples for the ``a*c + b*c`` model, intervening on V1 and V2 (id 2).

    Two factual-pair draws are made, with ``seed`` and ``seed + 1``. The first
    draw's source supplies V1 = a * c and the second draw's source supplies
    V2 = b * c; the counterfactual label is their sum. Only the first draw's
    base is used.

    Returns:
        Tuple ``(base_1, base_y_1, [source_1, source_2], counterfactual_y,
        ids)`` of ``torch.long`` tensors, with every id equal to 2.
    """
    first_draw = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed, data_size, combiner_func
    )
    second_draw = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed + 1, data_size, combiner_func
    )
    base_1, base_y_1, source_1, _ = first_draw
    source_2 = second_draw[2]

    counterfactual_y = [
        v1_row[0] * v1_row[2] + v2_row[1] * v2_row[2]
        for v1_row, v2_row in zip(source_1, source_2)
    ]
    return (
        torch.tensor(base_1, dtype=torch.long),
        torch.tensor(base_y_1, dtype=torch.long),
        [torch.tensor(source_1, dtype=torch.long),
         torch.tensor(source_2, dtype=torch.long)],
        torch.tensor(counterfactual_y, dtype=torch.long),
        torch.tensor([2] * len(base_1), dtype=torch.long),
    )

def get_IIT_arithmetic_dataset_prod_first(
    variable_range, seed,
    data_size,
    combiner_func
):
    """Concatenate the V1, V2 and V1+V2 IIT datasets for ``a*c + b*c``.

    Args:
        variable_range: ``[low, high]`` inclusive operand range.
        seed: Seed forwarded to each sub-dataset builder.
        data_size: Examples per sub-dataset.
        combiner_func: Factual label function ``f(a, b, c)``.

    Returns:
        List ``[base, base_y, [source_1, source_2], counterfactual_y,
        intervention_ids]``. Each tensor holds the V1 rows, then the V2 rows,
        then the V1_V2 rows.
    """
    parts = (
        get_IIT_arithmetic_dataset_prod_first_V1(
            variable_range, seed, data_size, combiner_func
        ),
        get_IIT_arithmetic_dataset_prod_first_V2(
            variable_range, seed, data_size, combiner_func
        ),
        get_IIT_arithmetic_dataset_prod_first_V1_V2(
            variable_range, seed, data_size, combiner_func
        ),
    )

    def cat_field(select):
        # Concatenate one field across the three sub-datasets, in order.
        return torch.cat(tuple(select(part) for part in parts))

    return [
        cat_field(lambda part: part[0]),
        cat_field(lambda part: part[1]),
        # Each source slot is taken from the matching slot of every
        # sub-dataset. The previous code read slot 0 of V1/V2 for the second
        # source, which was only correct because V1/V2 duplicate their source.
        [cat_field(lambda part: part[2][0]),
         cat_field(lambda part: part[2][1])],
        cat_field(lambda part: part[3]),
        cat_field(lambda part: part[4]),
    ]

### Train Factual Models

In [None]:
# Fixed variables shared by the training / evaluation cells below.
device = "cpu"
data_size = 6400          # factual training examples per model
num_layers = 3
max_iter = 500            # factual training epochs
variable_range = [1, 6]   # inclusive range each operand a, b, c is drawn from

# One embedding row per possible operand value; each input has three operands.
embedding_dim = variable_range[1] - variable_range[0] + 1
input_dim = 3 * embedding_dim

# Extremes of (a + b) * c, reached when all three operands share the min / max.
min_y = 2 * variable_range[0] * variable_range[0]
max_y = 2 * variable_range[1] * variable_range[1]
classes = list(range(min_y, max_y + 1))

# NOTE(review): n_classes equals max_y (72 here), one more than len(classes)
# (71), so the last output unit is never a target. It is left unchanged
# because the checkpoints loaded below were built with this head size.
n_classes = 2 * variable_range[1] * variable_range[1]
# Label value -> output index (zip truncates to the shorter of the two).
class2index = {label: index for index, label in zip(range(n_classes), classes)}

In [None]:
# Create the checkpoint directory up front; torch.save below (and any
# per-epoch checkpoints the trainer writes) would otherwise fail on a fresh checkout.
os.makedirs("./saved_models_arithmetic", exist_ok=True)

for seed in {42, 77, 88}:
    for hidden_dim in {18, 36}:
        utils.fix_random_seeds(seed=seed)
        print(f"training factual model for seed={seed}")
        # Factual data: data_size random (a, b, c) rows labelled with ab_c.
        X_base_train, y_base_train = get_IIT_arithmetic_dataset_factuals(
            variable_range, seed, data_size, ab_c
        )
        LIM = LIMDeepNeuralClassifier(
            hidden_dim=hidden_dim,
            hidden_activation=torch.nn.ReLU(),
            num_layers=num_layers,
            input_dim=input_dim,
            n_classes=n_classes,
            device=device,
            vocab_size=variable_range[1]-variable_range[0]+1,
            embed_dim=embedding_dim,
        )

        # NOTE(review): the oracle/control cells load
        # basemodel-{epoch}-{num_layers}-{hidden_dim}-{seed}.bin; presumably
        # those are written by save_checkpoint_per_epoch=True — confirm in LIMTrainer.
        LIM_trainer = LIMTrainer(
            LIM,
            warm_start=True,
            max_iter=max_iter,
            batch_size=6400,
            n_iter_no_change=10000,
            shuffle_train=False,
            eta=0.01,
            input_as_ids=True,
            device=device,
            class2index=class2index,
            save_checkpoint_per_epoch=True,
            seed=seed
        )

        # iit_data=None: presumably plain factual training with no interventions.
        _ = LIM_trainer.fit(
            X_base_train,
            y_base_train,
            iit_data=None,
            intervention_ids_to_coords=None
        )

        PATH = f"./saved_models_arithmetic/basemodel-last-{num_layers}-{hidden_dim}-{seed}.bin"
        torch.save(LIM_trainer.model.state_dict(), PATH)

### IIT Dataset Generation

### Train Oracle Models

In [None]:
dataset_fun = get_IIT_arithmetic_dataset_sum_first
iit_data_size = 64000
iit_max_iter = 50
model_c = 1
# One model per (seed, hidden_dim, hidden_dim_per_concept, iit_layer, epoch):
# 3 * 2 * 3 * 3 * 5 = 270.
total_model_c = 3 * 2 * 3 * 3 * 5
for seed in {42, 77, 88}:
    utils.fix_random_seeds(seed=seed)
    # IIT training data: V1 (a + b), V2 (c) and joint interventions on (a + b) * c.
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])
    for hidden_dim in {18, 36}:
        for hidden_dim_per_concept in {1, 2, 6}:
            for iit_layer in [0, 1, 2]:
                scale_factor = hidden_dim/9
                small_scale_factor = hidden_dim/18
                # Intervention id 0 (V1) uses units [0, v1_end) of iit_layer,
                # id 1 (V2) uses [v1_end, v2_end), and id 2 uses both ranges.
                v1_end = int(scale_factor*hidden_dim_per_concept)
                v2_end = int((scale_factor+small_scale_factor)*hidden_dim_per_concept)
                id_to_coords = {
                    0: [{"layer": iit_layer, "start": 0, "end": v1_end}],
                    1: [{"layer": iit_layer, "start": v1_end, "end": v2_end}],
                    2: [{"layer": iit_layer, "start": 0, "end": v1_end},
                        {"layer": iit_layer, "start": v1_end, "end": v2_end}],
                }
                print("id_to_coords: ", id_to_coords)
                for i in [100, 200, 300, 400, 500]:
                    iit_layer_out = iit_layer + 1
                    PATH = f"./saved_models_arithmetic/iit-oraclemodel-epoch{i}-"\
                           f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                           f"{seed}.bin"
                    if os.path.isfile(PATH):
                        print(f"Found trained model thus skip: {PATH}")
                        # Still advance the counter so N/total tracks the grid position.
                        model_c += 1
                        continue

                    print(f"Training {model_c}/{total_model_c} model aligned with oracle hlm with params:")
                    print(f"seed={seed}")
                    print(f"hidden_dim_per_concept={hidden_dim_per_concept}")
                    print(f"iit_layer={iit_layer}")
                    print(f"epoch={i}")
                    LIM = LIMDeepNeuralClassifier(
                        hidden_dim=hidden_dim,
                        hidden_activation=torch.nn.ReLU(),
                        num_layers=num_layers,
                        input_dim=input_dim,
                        n_classes=n_classes,
                        device=device,
                        vocab_size=variable_range[1]-variable_range[0]+1,
                        embed_dim=embedding_dim,
                    )
                    LIM_trainer = LIMTrainer(
                        LIM,
                        warm_start=True,
                        max_iter=iit_max_iter,
                        batch_size=6400,
                        n_iter_no_change=10000,
                        shuffle_train=False,
                        eta=0.01,
                        input_as_ids=True,
                        device=device,
                        class2index=class2index,
                        save_checkpoint_per_epoch=False,
                        seed=seed
                    )

                    # Start from the factual model's checkpoint at epoch i.
                    ORACLE_PATH = f"./saved_models_arithmetic/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
                    LIM_trainer.model.load_state_dict(torch.load(ORACLE_PATH, map_location=device))
                    LIM_trainer.model.set_analysis_mode(True)

                    _ = LIM_trainer.fit(
                        X_base_train,
                        y_base_train,
                        iit_data=iit_data,
                        intervention_ids_to_coords=id_to_coords)

                    torch.save(LIM_trainer.model.state_dict(), PATH)

                    model_c += 1

### Eval Oracle Models

In [None]:
oracle_results = []
dataset_fun = get_IIT_arithmetic_dataset_sum_first
iit_data_size = 64000
iit_max_iter = 50
model_c = 1
# One model per (seed, hidden_dim, hidden_dim_per_concept, iit_layer, epoch):
# 3 * 2 * 3 * 3 * 5 = 270.
total_model_c = 3 * 2 * 3 * 3 * 5
for seed in {42, 77, 88}:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    # Train-set evaluation: 6400 examples drawn from the training data by
    # utils.get_eval_from_train.
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        utils.get_eval_from_train(
            train_datasetIIT, 6400
        )
    # Test-set evaluation: a freshly generated dataset with seed + 1.
    # NOTE(review): with operands in [1, 6] there are only 6**3 = 216 distinct
    # inputs, so train and test inputs overlap regardless of seed.
    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = \
        dataset_fun(
            variable_range, seed+1, 6400, ab_c
        )
    for hidden_dim in {18, 36}:
        for hidden_dim_per_concept in {1, 2, 6}:
            for iit_layer in [0, 1, 2]:
                scale_factor = hidden_dim/9
                small_scale_factor = hidden_dim/18
                # Must match the coordinates used when these models were trained.
                v1_end = int(scale_factor*hidden_dim_per_concept)
                v2_end = int((scale_factor+small_scale_factor)*hidden_dim_per_concept)
                id_to_coords = {
                    0: [{"layer": iit_layer, "start": 0, "end": v1_end}],
                    1: [{"layer": iit_layer, "start": v1_end, "end": v2_end}],
                    2: [{"layer": iit_layer, "start": 0, "end": v1_end},
                        {"layer": iit_layer, "start": v1_end, "end": v2_end}],
                }
                for i in [100, 200, 300, 400, 500]:
                    print(f"Evaluating {model_c}/{total_model_c} model aligned with oracle hlm.")
                    LIM = LIMDeepNeuralClassifier(
                        hidden_dim=hidden_dim,
                        hidden_activation=torch.nn.ReLU(),
                        num_layers=num_layers,
                        input_dim=input_dim,
                        n_classes=n_classes,
                        device=device,
                        vocab_size=variable_range[1]-variable_range[0]+1,
                        embed_dim=embedding_dim,
                    )
                    LIM_trainer = LIMTrainer(
                        LIM,
                        warm_start=True,
                        max_iter=iit_max_iter,
                        batch_size=6400,
                        n_iter_no_change=10000,
                        shuffle_train=False,
                        eta=0.01,
                        input_as_ids=True,
                        device=device,
                        class2index=class2index,
                        save_checkpoint_per_epoch=False,
                        seed=seed
                    )

                    iit_layer_out = iit_layer + 1
                    PATH = f"./saved_models_arithmetic/iit-oraclemodel-epoch{i}-"\
                           f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                           f"{seed}.bin"
                    LIM_trainer.model.load_state_dict(torch.load(PATH, map_location=device))
                    LIM_trainer.model.set_analysis_mode(True)

                    # train data eval
                    base_preds_train = LIM_trainer.predict(
                        base_test_train, device="cpu"
                    )
                    IIT_preds_train = LIM_trainer.iit_predict(
                        base_test_train, sources_test_train,
                        intervention_ids_test_train,
                        id_to_coords, device="cpu"
                    )
                    r1_train = classification_report(y_base_test_train, base_preds_train, output_dict=True)
                    r2_train = classification_report(y_IIT_test_train, IIT_preds_train, output_dict=True)

                    # test data eval
                    base_preds = LIM_trainer.predict(
                        base_test, device="cpu"
                    )
                    IIT_preds = LIM_trainer.iit_predict(
                        base_test, sources_test,
                        intervention_ids_test,
                        id_to_coords, device="cpu"
                    )
                    r1 = classification_report(y_base_test, base_preds, output_dict=True)
                    r2 = classification_report(y_IIT_test, IIT_preds, output_dict=True)

                    # One row per (split, objective), in a fixed order.
                    for result_type, report in (
                        ("Factual Train", r1_train),
                        ("d-IIT Train", r2_train),
                        ("Factual Test", r1),
                        ("d-IIT Test", r2),
                    ):
                        oracle_results.append(
                            [
                                seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                                result_type, report["weighted avg"]["f1-score"]]
                        )
                    model_c += 1
oracle_df = pd.DataFrame(
    oracle_results,
    columns =['seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
              'type', 'f1-score']
)

### Train (0, 1) control model with d-IIT

In [None]:
dataset_fun = get_IIT_arithmetic_dataset_sum_first_control_V1
iit_data_size = 64000
iit_max_iter = 50
model_c = 1
# One model per (seed, hidden_dim, hidden_dim_per_concept, iit_layer, epoch):
# 3 * 2 * 3 * 3 * 5 = 270.
total_model_c = 3 * 2 * 3 * 3 * 5
for seed in {42, 77, 88}:
    utils.fix_random_seeds(seed=seed)
    # Control IIT data: only the a + b intervention (id 3), single source.
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])
    for hidden_dim in {18, 36}:
        for hidden_dim_per_concept in {1, 2, 6}:
            for iit_layer in [0, 1, 2]:
                scale_factor = hidden_dim/9
                control = 3  # intervention id emitted by the control dataset
                id_to_coords = {
                    control: [
                        {
                            "layer": iit_layer,
                            "start": 0,
                            "end": int(scale_factor*hidden_dim_per_concept)
                        }
                    ],
                }
                print("id_to_coords: ", id_to_coords)
                for i in [100, 200, 300, 400, 500]:
                    iit_layer_out = iit_layer + 1
                    PATH = f"./saved_models_arithmetic/iit-controlmodel-0-epoch{i}-"\
                           f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                           f"{seed}.bin"
                    if os.path.isfile(PATH):
                        print(f"Found trained model thus skip: {PATH}")
                        # Still advance the counter so N/total tracks the grid position.
                        model_c += 1
                        continue

                    print(f"Training {model_c}/{total_model_c} model aligned with control hlm with params:")
                    print(f"seed={seed}")
                    print(f"hidden_dim_per_concept={hidden_dim_per_concept}")
                    print(f"iit_layer={iit_layer}")
                    print(f"epoch={i}")
                    LIM = LIMDeepNeuralClassifier(
                        hidden_dim=hidden_dim,
                        hidden_activation=torch.nn.ReLU(),
                        num_layers=num_layers,
                        input_dim=input_dim,
                        n_classes=n_classes,
                        device=device,
                        vocab_size=variable_range[1]-variable_range[0]+1,
                        embed_dim=embedding_dim,
                    )
                    LIM_trainer = LIMTrainer(
                        LIM,
                        warm_start=True,
                        max_iter=iit_max_iter,
                        batch_size=6400,
                        n_iter_no_change=10000,
                        shuffle_train=False,
                        eta=0.01,
                        input_as_ids=True,
                        device=device,
                        class2index=class2index,
                        save_checkpoint_per_epoch=False,
                        seed=seed
                    )

                    # Start from the factual model's checkpoint at epoch i.
                    ORACLE_PATH = f"./saved_models_arithmetic/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
                    LIM_trainer.model.load_state_dict(torch.load(ORACLE_PATH, map_location=device))
                    LIM_trainer.model.set_analysis_mode(True)

                    _ = LIM_trainer.fit(
                        X_base_train,
                        y_base_train,
                        iit_data=iit_data,
                        intervention_ids_to_coords=id_to_coords)

                    torch.save(LIM_trainer.model.state_dict(), PATH)

                    model_c += 1

### Eval (0, 1) control model with d-IIT

In [None]:
control_0_results = []
dataset_fun = get_IIT_arithmetic_dataset_sum_first_control_V1
iit_data_size = 64000
iit_max_iter = 50
model_c = 1
# One model per (seed, hidden_dim, hidden_dim_per_concept, iit_layer, epoch):
# 3 * 2 * 3 * 3 * 5 = 270.
total_model_c = 3 * 2 * 3 * 3 * 5
for seed in {42, 77, 88}:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    # Train-set evaluation: 6400 examples drawn from the training data by
    # utils.get_eval_from_train (control=True for the single-source layout).
    base_test_train, y_base_test_train, sources_test_train, y_IIT_test_train, intervention_ids_test_train = \
        utils.get_eval_from_train(
            train_datasetIIT, 6400, control=True
        )
    # Test-set evaluation: a freshly generated dataset with seed + 1.
    base_test, y_base_test, sources_test, y_IIT_test, intervention_ids_test = \
        dataset_fun(
            variable_range, seed+1, 6400, ab_c
        )
    for hidden_dim in {18, 36}:
        for hidden_dim_per_concept in {1, 2, 6}:
            for iit_layer in [0, 1, 2]:
                scale_factor = hidden_dim/9
                control = 3  # intervention id emitted by the control dataset
                # Must match the coordinates used when these models were trained.
                id_to_coords = {
                    control: [
                        {
                            "layer": iit_layer,
                            "start": 0,
                            "end": int(scale_factor*hidden_dim_per_concept)
                        }
                    ],
                }
                for i in [100, 200, 300, 400, 500]:
                    print(f"Evaluating {model_c}/{total_model_c} model aligned with control hlm.")
                    LIM = LIMDeepNeuralClassifier(
                        hidden_dim=hidden_dim,
                        hidden_activation=torch.nn.ReLU(),
                        num_layers=num_layers,
                        input_dim=input_dim,
                        n_classes=n_classes,
                        device=device,
                        vocab_size=variable_range[1]-variable_range[0]+1,
                        embed_dim=embedding_dim,
                    )
                    LIM_trainer = LIMTrainer(
                        LIM,
                        warm_start=True,
                        max_iter=iit_max_iter,
                        batch_size=6400,
                        n_iter_no_change=10000,
                        shuffle_train=False,
                        eta=0.01,
                        input_as_ids=True,
                        device=device,
                        class2index=class2index,
                        save_checkpoint_per_epoch=False,
                        seed=seed
                    )

                    iit_layer_out = iit_layer + 1
                    PATH = f"./saved_models_arithmetic/iit-controlmodel-0-epoch{i}-"\
                           f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                           f"{seed}.bin"
                    LIM_trainer.model.load_state_dict(torch.load(PATH, map_location=device))
                    LIM_trainer.model.set_analysis_mode(True)

                    # train data eval
                    base_preds_train = LIM_trainer.predict(
                        base_test_train, device="cpu"
                    )
                    IIT_preds_train = LIM_trainer.iit_predict(
                        base_test_train, sources_test_train,
                        intervention_ids_test_train,
                        id_to_coords, device="cpu"
                    )
                    r1_train = classification_report(y_base_test_train, base_preds_train, output_dict=True)
                    r2_train = classification_report(y_IIT_test_train, IIT_preds_train, output_dict=True)

                    # test data eval
                    base_preds = LIM_trainer.predict(
                        base_test, device="cpu"
                    )
                    IIT_preds = LIM_trainer.iit_predict(
                        base_test, sources_test,
                        intervention_ids_test,
                        id_to_coords, device="cpu"
                    )
                    r1 = classification_report(y_base_test, base_preds, output_dict=True)
                    r2 = classification_report(y_IIT_test, IIT_preds, output_dict=True)

                    # One row per (split, objective), in a fixed order.
                    for result_type, report in (
                        ("Factual Train", r1_train),
                        ("d-IIT Train", r2_train),
                        ("Factual Test", r1),
                        ("d-IIT Test", r2),
                    ):
                        control_0_results.append(
                            [
                                seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                                result_type, report["weighted avg"]["f1-score"]]
                        )
                    model_c += 1
control_0_df = pd.DataFrame(
    control_0_results,
    columns =['seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
              'type', 'f1-score']
)

### Train (0, ) control model with d-IIT

In [None]:
# d-IIT fine-tuning of the control model (checkpoints "iit-controlmodel-1-*").
# Each run loads the base checkpoint basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin
# and trains with a single intervention id (3) mapped to one slice of layer `iit_layer`.
# Checkpoints that already exist on disk are skipped, so re-running resumes the sweep.
dataset_fun = get_IIT_arithmetic_dataset_control_V1
iit_data_size = 64000
iit_max_iter = 50

# Sweep grid; tuples (not sets) so the iteration order is explicit.
sweep_seeds = (42, 77, 88)
sweep_hidden_dims = (18, 36)
sweep_hidden_dims_per_concept = (1, 2, 6)
sweep_iit_layers = (0, 1, 2)
sweep_epochs = (100, 200, 300, 400, 500)
# Derived from the grid so the progress message cannot drift out of sync.
total_model_c = (
    len(sweep_seeds) * len(sweep_hidden_dims) * len(sweep_hidden_dims_per_concept)
    * len(sweep_iit_layers) * len(sweep_epochs)
)
model_c = 0
for seed in sweep_seeds:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])
    for hidden_dim, hidden_dim_per_concept, iit_layer in itertools.product(
        sweep_hidden_dims, sweep_hidden_dims_per_concept, sweep_iit_layers
    ):
        scale_factor = hidden_dim/9
        control = 3  # intervention id; the only key of id_to_coords
        id_to_coords = {
            control: [
                {
                    "layer": iit_layer,
                    "start": 0,
                    "end": int(scale_factor*hidden_dim_per_concept*0.5)
                }
            ],
        }
        print("id_to_coords: ", id_to_coords)
        # iit_layer + 1 is the value used in checkpoint names and result rows.
        iit_layer_out = iit_layer + 1
        for i in sweep_epochs:
            # Count every configuration, including skipped ones, so that
            # model_c stays aligned with total_model_c.
            model_c += 1
            PATH = f"./saved_models_arithmetic/iit-controlmodel-1-epoch{i}-"\
                   f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                   f"{seed}.bin"
            if os.path.isfile(PATH):
                print(f"Found trained model thus skip: {PATH}")
                continue

            print(f"Training {model_c}/{total_model_c} model aligned with oracle hlm with params:")
            print(f"seed={seed}")
            print(f"hidden_dim={hidden_dim}")
            print(f"hidden_dim_per_concept={hidden_dim_per_concept}")
            print(f"iit_layer={iit_layer}")
            print(f"epoch={i}")
            LIM = LIMDeepNeuralClassifier(
                hidden_dim=hidden_dim,
                hidden_activation=torch.nn.ReLU(),
                num_layers=num_layers,
                input_dim=input_dim,
                n_classes=n_classes,
                device=device,
                vocab_size=variable_range[1]-variable_range[0]+1,
                embed_dim=embedding_dim,
            )
            LIM_trainer = LIMTrainer(
                LIM,
                warm_start=True,
                max_iter=iit_max_iter,
                batch_size=6400,
                n_iter_no_change=10000,
                shuffle_train=False,
                eta=0.01,
                input_as_ids=True,
                device=device,
                class2index=class2index,
                save_checkpoint_per_epoch=False,
                seed=seed
            )

            # Start from the base checkpoint saved at epoch i.
            ORACLE_PATH = f"./saved_models_arithmetic/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
            LIM_trainer.model.load_state_dict(torch.load(ORACLE_PATH))
            LIM_trainer.model.set_analysis_mode(True)

            _ = LIM_trainer.fit(
                X_base_train,
                y_base_train,
                iit_data=iit_data,
                intervention_ids_to_coords=id_to_coords)

            torch.save(LIM_trainer.model.state_dict(), PATH)

### Eval (0, ) control model with d-IIT

In [None]:
# Evaluate every "iit-controlmodel-1-*" checkpoint. For each checkpoint, record the
# weighted-avg F1 of factual (predict) and interchange (iit_predict) predictions on
# (a) 6400 examples drawn from the training data by utils.get_eval_from_train and
# (b) a fresh 6400-example set generated with seed + 1.
control_1_results = []
dataset_fun = get_IIT_arithmetic_dataset_control_V1
iit_data_size = 64000
iit_max_iter = 50

# Sweep grid; must match the grid used when training these checkpoints.
sweep_seeds = (42, 77, 88)
sweep_hidden_dims = (18, 36)
sweep_hidden_dims_per_concept = (1, 2, 6)
sweep_iit_layers = (0, 1, 2)
sweep_epochs = (100, 200, 300, 400, 500)
total_model_c = (
    len(sweep_seeds) * len(sweep_hidden_dims) * len(sweep_hidden_dims_per_concept)
    * len(sweep_iit_layers) * len(sweep_epochs)
)
model_c = 0
for seed in sweep_seeds:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    (base_test_train, y_base_test_train, sources_test_train,
     y_IIT_test_train, intervention_ids_test_train) = utils.get_eval_from_train(
        train_datasetIIT, 6400, control=True
    )
    (base_test, y_base_test, sources_test,
     y_IIT_test, intervention_ids_test) = dataset_fun(
        variable_range, seed+1, 6400, ab_c
    )
    for hidden_dim, hidden_dim_per_concept, iit_layer in itertools.product(
        sweep_hidden_dims, sweep_hidden_dims_per_concept, sweep_iit_layers
    ):
        scale_factor = hidden_dim/9
        control = 3  # intervention id; the only key of id_to_coords
        id_to_coords = {
            control: [
                {
                    "layer": iit_layer,
                    "start": 0,
                    "end": int(scale_factor*hidden_dim_per_concept*0.5)
                }
            ],
        }
        iit_layer_out = iit_layer + 1
        for i in sweep_epochs:
            model_c += 1
            print(f"Evaluating {model_c}/{total_model_c} model aligned with oracle hlm.")
            LIM = LIMDeepNeuralClassifier(
                hidden_dim=hidden_dim,
                hidden_activation=torch.nn.ReLU(),
                num_layers=num_layers,
                input_dim=input_dim,
                n_classes=n_classes,
                device=device,
                vocab_size=variable_range[1]-variable_range[0]+1,
                embed_dim=embedding_dim,
            )
            LIM_trainer = LIMTrainer(
                LIM,
                warm_start=True,
                max_iter=iit_max_iter,
                batch_size=6400,
                n_iter_no_change=10000,
                shuffle_train=False,
                eta=0.01,
                input_as_ids=True,
                device=device,
                class2index=class2index,
                save_checkpoint_per_epoch=False,
                seed=seed
            )

            PATH = f"./saved_models_arithmetic/iit-controlmodel-1-epoch{i}-"\
                   f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                   f"{seed}.bin"
            LIM_trainer.model.load_state_dict(torch.load(PATH))
            LIM_trainer.model.set_analysis_mode(True)

            # train data eval
            base_preds_train = LIM_trainer.predict(
                base_test_train, device="cpu"
            )
            IIT_preds_train = LIM_trainer.iit_predict(
                base_test_train, sources_test_train,
                intervention_ids_test_train,
                id_to_coords, device="cpu"
            )
            r1_train = classification_report(y_base_test_train, base_preds_train, output_dict=True)
            r2_train = classification_report(y_IIT_test_train, IIT_preds_train, output_dict=True)

            # test data eval
            base_preds = LIM_trainer.predict(
                base_test, device="cpu"
            )
            IIT_preds = LIM_trainer.iit_predict(
                base_test, sources_test,
                intervention_ids_test,
                id_to_coords, device="cpu"
            )
            r1 = classification_report(y_base_test, base_preds, output_dict=True)
            r2 = classification_report(y_IIT_test, IIT_preds, output_dict=True)

            # One long-format row per evaluation type.
            for label, report in (
                ("Factual Train", r1_train),
                ("d-IIT Train", r2_train),
                ("Factual Test", r1),
                ("d-IIT Test", r2),
            ):
                control_1_results.append(
                    [
                        seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                        label, report["weighted avg"]["f1-score"]]
                )
control_1_df = pd.DataFrame(
    control_1_results,
    columns=['seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
             'type', 'f1-score']
)

### Train distributed rule oracle model

In [None]:
# d-IIT fine-tuning on get_IIT_arithmetic_dataset_prod_first (checkpoints
# "iit-prod-oraclemodel-*"). Intervention ids 0 and 1 each map to one contiguous
# slice of layer `iit_layer`; id 2 maps to both slices. Each run loads the base
# checkpoint saved at epoch i. Existing checkpoints are skipped, so re-running
# resumes the sweep.
dataset_fun = get_IIT_arithmetic_dataset_prod_first
iit_data_size = 64000
iit_max_iter = 50

# Sweep grid; tuples (not sets) so the iteration order is explicit.
sweep_seeds = (42, 77, 88)
sweep_hidden_dims = (18, 36)
sweep_hidden_dims_per_concept = (1, 2, 4.5)
sweep_iit_layers = (0, 1, 2)
sweep_epochs = (100, 200, 300, 400, 500)
# Derived from the grid; the previous hard-coded 135 under-reported the 270 runs.
total_model_c = (
    len(sweep_seeds) * len(sweep_hidden_dims) * len(sweep_hidden_dims_per_concept)
    * len(sweep_iit_layers) * len(sweep_epochs)
)
model_c = 0
for seed in sweep_seeds:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    X_base_train, y_base_train = train_datasetIIT[0:2]
    iit_data = tuple(train_datasetIIT[2:])
    for hidden_dim, hidden_dim_per_concept, iit_layer in itertools.product(
        sweep_hidden_dims, sweep_hidden_dims_per_concept, sweep_iit_layers
    ):
        scale_factor = hidden_dim/9
        # Slice bounds: [0, slice0_end) and [slice0_end, slice1_end).
        slice0_end = int(scale_factor*hidden_dim_per_concept)
        slice1_end = int((scale_factor*2)*hidden_dim_per_concept)
        id_to_coords = {
            0: [{"layer": iit_layer, "start": 0, "end": slice0_end}],
            1: [{"layer": iit_layer, "start": slice0_end, "end": slice1_end}],
            2: [{"layer": iit_layer, "start": 0, "end": slice0_end},
                {"layer": iit_layer, "start": slice0_end, "end": slice1_end}],
        }
        print("id_to_coords: ", id_to_coords)
        # iit_layer + 1 is the value used in checkpoint names and result rows.
        iit_layer_out = iit_layer + 1
        for i in sweep_epochs:
            # Count every configuration, including skipped ones, so that
            # model_c stays aligned with total_model_c.
            model_c += 1
            PATH = f"./saved_models_arithmetic/iit-prod-oraclemodel-epoch{i}-"\
                   f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                   f"{seed}.bin"
            if os.path.isfile(PATH):
                print(f"Found trained model thus skip: {PATH}")
                continue

            print(f"Training {model_c}/{total_model_c} model aligned with oracle hlm with params:")
            print(f"seed={seed}")
            print(f"hidden_dim={hidden_dim}")
            print(f"hidden_dim_per_concept={hidden_dim_per_concept}")
            print(f"iit_layer={iit_layer}")
            print(f"epoch={i}")
            LIM = LIMDeepNeuralClassifier(
                hidden_dim=hidden_dim,
                hidden_activation=torch.nn.ReLU(),
                num_layers=num_layers,
                input_dim=input_dim,
                n_classes=n_classes,
                device=device,
                vocab_size=variable_range[1]-variable_range[0]+1,
                embed_dim=embedding_dim,
            )
            LIM_trainer = LIMTrainer(
                LIM,
                warm_start=True,
                max_iter=iit_max_iter,
                batch_size=6400,
                n_iter_no_change=10000,
                shuffle_train=False,
                eta=0.01,
                input_as_ids=True,
                device=device,
                class2index=class2index,
                save_checkpoint_per_epoch=False,
                seed=seed
            )

            ORACLE_PATH = f"./saved_models_arithmetic/basemodel-{i}-{num_layers}-{hidden_dim}-{seed}.bin"
            LIM_trainer.model.load_state_dict(torch.load(ORACLE_PATH))
            LIM_trainer.model.set_analysis_mode(True)

            _ = LIM_trainer.fit(
                X_base_train,
                y_base_train,
                iit_data=iit_data,
                intervention_ids_to_coords=id_to_coords)

            torch.save(LIM_trainer.model.state_dict(), PATH)

### Eval distributed rule oracle model

In [None]:
# Evaluate every "iit-prod-oraclemodel-*" checkpoint. For each checkpoint, record the
# weighted-avg F1 of factual (predict) and interchange (iit_predict) predictions on
# (a) 6400 examples drawn from the training data by utils.get_eval_from_train and
# (b) a fresh 6400-example set generated with seed + 1.
prod_oracle_results = []
dataset_fun = get_IIT_arithmetic_dataset_prod_first
iit_data_size = 64000
iit_max_iter = 50

# Sweep grid; must match the grid used when training these checkpoints.
sweep_seeds = (42, 77, 88)
sweep_hidden_dims = (18, 36)
sweep_hidden_dims_per_concept = (1, 2, 4.5)
sweep_iit_layers = (0, 1, 2)
sweep_epochs = (100, 200, 300, 400, 500)
total_model_c = (
    len(sweep_seeds) * len(sweep_hidden_dims) * len(sweep_hidden_dims_per_concept)
    * len(sweep_iit_layers) * len(sweep_epochs)
)
model_c = 0
for seed in sweep_seeds:
    utils.fix_random_seeds(seed=seed)
    train_datasetIIT = dataset_fun(
        variable_range, seed, iit_data_size, ab_c
    )
    (base_test_train, y_base_test_train, sources_test_train,
     y_IIT_test_train, intervention_ids_test_train) = utils.get_eval_from_train(
        train_datasetIIT, 6400
    )
    (base_test, y_base_test, sources_test,
     y_IIT_test, intervention_ids_test) = dataset_fun(
        variable_range, seed+1, 6400, ab_c
    )
    for hidden_dim, hidden_dim_per_concept, iit_layer in itertools.product(
        sweep_hidden_dims, sweep_hidden_dims_per_concept, sweep_iit_layers
    ):
        scale_factor = hidden_dim/9
        # Same slices as in training: [0, slice0_end) and [slice0_end, slice1_end).
        slice0_end = int(scale_factor*hidden_dim_per_concept)
        slice1_end = int((scale_factor*2)*hidden_dim_per_concept)
        id_to_coords = {
            0: [{"layer": iit_layer, "start": 0, "end": slice0_end}],
            1: [{"layer": iit_layer, "start": slice0_end, "end": slice1_end}],
            2: [{"layer": iit_layer, "start": 0, "end": slice0_end},
                {"layer": iit_layer, "start": slice0_end, "end": slice1_end}],
        }
        iit_layer_out = iit_layer + 1
        for i in sweep_epochs:
            model_c += 1
            print(f"Evaluating {model_c}/{total_model_c} model aligned with oracle hlm.")
            LIM = LIMDeepNeuralClassifier(
                hidden_dim=hidden_dim,
                hidden_activation=torch.nn.ReLU(),
                num_layers=num_layers,
                input_dim=input_dim,
                n_classes=n_classes,
                device=device,
                vocab_size=variable_range[1]-variable_range[0]+1,
                embed_dim=embedding_dim,
            )
            LIM_trainer = LIMTrainer(
                LIM,
                warm_start=True,
                max_iter=iit_max_iter,
                batch_size=6400,
                n_iter_no_change=10000,
                shuffle_train=False,
                eta=0.01,
                input_as_ids=True,
                device=device,
                class2index=class2index,
                save_checkpoint_per_epoch=False,
                seed=seed
            )

            PATH = f"./saved_models_arithmetic/iit-prod-oraclemodel-epoch{i}-"\
                   f"{iit_layer_out}-{hidden_dim}-{hidden_dim_per_concept}-"\
                   f"{seed}.bin"
            LIM_trainer.model.load_state_dict(torch.load(PATH))
            LIM_trainer.model.set_analysis_mode(True)

            # train data eval
            base_preds_train = LIM_trainer.predict(
                base_test_train, device="cpu"
            )
            IIT_preds_train = LIM_trainer.iit_predict(
                base_test_train, sources_test_train,
                intervention_ids_test_train,
                id_to_coords, device="cpu"
            )
            r1_train = classification_report(y_base_test_train, base_preds_train, output_dict=True)
            r2_train = classification_report(y_IIT_test_train, IIT_preds_train, output_dict=True)

            # test data eval
            base_preds = LIM_trainer.predict(
                base_test, device="cpu"
            )
            IIT_preds = LIM_trainer.iit_predict(
                base_test, sources_test,
                intervention_ids_test,
                id_to_coords, device="cpu"
            )
            r1 = classification_report(y_base_test, base_preds, output_dict=True)
            r2 = classification_report(y_IIT_test, IIT_preds, output_dict=True)

            # One long-format row per evaluation type.
            for label, report in (
                ("Factual Train", r1_train),
                ("d-IIT Train", r2_train),
                ("Factual Test", r1),
                ("d-IIT Test", r2),
            ):
                prod_oracle_results.append(
                    [
                        seed, hidden_dim, hidden_dim_per_concept, iit_layer_out, i,
                        label, report["weighted avg"]["f1-score"]]
                )
prod_oracle_df = pd.DataFrame(
    prod_oracle_results,
    columns=['seed', 'hidden_dim', 'hidden_dim_per_concept', 'iit_layer', 'epoch',
             'type', 'f1-score']
)

### Plots: weighted F1-score vs. epoch (one line per evaluation type) for each model family

In [None]:
# Weighted F1 vs. epoch for control_1_df, restricted to one configuration.
# seaborn aggregates the seeds at each epoch into a mean line with a CI band.
ax = sns.lineplot(
    data=control_1_df[
        (control_1_df["hidden_dim"]==36)&
        (control_1_df["hidden_dim_per_concept"]==6)&
        (control_1_df["iit_layer"]==1)
    ],
    x="epoch", y="f1-score", hue="type", style="type",
    dashes=False, markers=['o', 's', '^', 'D'], markersize=12, legend=True,
    alpha=0.8
)
ax.set(
    title="control_1: hidden_dim=36, hidden_dim_per_concept=6, iit_layer=1",
    xlabel="epoch",
    ylabel="weighted F1-score",
);

In [None]:
# Weighted F1 vs. epoch for control_0_df, restricted to one configuration.
# seaborn aggregates the seeds at each epoch into a mean line with a CI band.
ax = sns.lineplot(
    data=control_0_df[
        (control_0_df["hidden_dim"]==36)&
        (control_0_df["hidden_dim_per_concept"]==6)&
        (control_0_df["iit_layer"]==1)
    ],
    x="epoch", y="f1-score", hue="type", style="type",
    dashes=False, markers=['o', 's', '^', 'D'], markersize=12, legend=True,
    alpha=0.8
)
ax.set(
    title="control_0: hidden_dim=36, hidden_dim_per_concept=6, iit_layer=1",
    xlabel="epoch",
    ylabel="weighted F1-score",
);

In [None]:
# Weighted F1 vs. epoch for oracle_df, restricted to one configuration.
# seaborn aggregates repeated rows at each epoch into a mean line with a CI band.
ax = sns.lineplot(
    data=oracle_df[
        (oracle_df["hidden_dim"]==36)&
        (oracle_df["hidden_dim_per_concept"]==6)&
        (oracle_df["iit_layer"]==1)
    ],
    x="epoch", y="f1-score", hue="type", style="type",
    dashes=False, markers=['o', 's', '^', 'D'], markersize=12, legend=True,
    alpha=0.8
)
ax.set(
    title="oracle: hidden_dim=36, hidden_dim_per_concept=6, iit_layer=1",
    xlabel="epoch",
    ylabel="weighted F1-score",
);

In [None]:
# Weighted F1 vs. epoch for prod_oracle_df, restricted to one configuration.
# This family was swept over hidden_dim_per_concept in {1, 2, 4.5}, hence 4.5 here.
# seaborn aggregates the seeds at each epoch into a mean line with a CI band.
ax = sns.lineplot(
    data=prod_oracle_df[
        (prod_oracle_df["hidden_dim"]==36)&
        (prod_oracle_df["hidden_dim_per_concept"]==4.5)&
        (prod_oracle_df["iit_layer"]==1)
    ],
    x="epoch", y="f1-score", hue="type", style="type",
    dashes=False, markers=['o', 's', '^', 'D'], markersize=12, legend=True,
    alpha=0.8
)
ax.set(
    title="prod oracle: hidden_dim=36, hidden_dim_per_concept=4.5, iit_layer=1",
    xlabel="epoch",
    ylabel="weighted F1-score",
);