In [1]:
from random import randrange
import torch, gc
import random
import copy
import itertools
import numpy as np
import utils
from trainer import LIMTrainer
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report
from LIM_deep_neural_classifier import LIMDeepNeuralClassifier
import dataset_equality

gc.collect()
torch.cuda.empty_cache()
utils.fix_random_seeds()

def ab_c(a, b, c):
    return (a+b)*c

def get_IIT_arithmetic_dataset_factuals(
    variable_range, seed,
    data_size,
    combiner_func
):
    random.seed(seed)
    base = [[random.randint(
        variable_range[0], variable_range[1]
    ) for _ in range(3)] for _ in range(data_size)]
    base_y = []
    for b in base:
        b_y = combiner_func(*b)
        base_y += [b_y]
        
    return torch.tensor(base, dtype=torch.long), torch.tensor(base_y, dtype=torch.long)

def get_IIT_arithmetic_dataset_factual_pairs(
    variable_range, seed,
    data_size,
    combiner_func
):
    random.seed(seed)
    base = [[random.randint(
        variable_range[0], variable_range[1]
    ) for _ in range(3)] for _ in range(data_size)]
    base_y = []
    for b in base:
        b_y = combiner_func(*b)
        base_y += [b_y]
        
    source = [[random.randint(
        variable_range[0], variable_range[1]
    ) for _ in range(3)] for _ in range(data_size)]
    source_y = []
    for s in source:
        s_y = combiner_func(*s)
        source_y += [s_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(source_y, dtype=torch.long)

def get_IIT_arithmetic_dataset_sum_first_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    source_sum = []
    for s in source:
        s_sum = s[0]+s[1]
        source_sum += [s_sum]
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = source_sum[i]*base[i][-1]
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([0 for _ in range(len(base))], dtype=torch.long)

def get_IIT_arithmetic_dataset_sum_first_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = (base[i][0]+base[i][1])*source[i][-1]
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([1 for _ in range(len(base))], dtype=torch.long)

def get_IIT_arithmetic_dataset_prod_first_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = (base[i][1]*base[i][-1])+(source[i][0]*source[i][-1])
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([0 for _ in range(len(base))], dtype=torch.long)

def get_IIT_arithmetic_dataset_prod_first_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = (base[i][0]*base[i][-1])+(source[i][1]*source[i][-1])
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([1 for _ in range(len(base))], dtype=torch.long)

In [2]:
def get_IIT_arithmetic_dataset_control_V1(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = combiner_func(source[i][0], base[i][1], base[i][-1])
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([2 for _ in range(len(base))], dtype=torch.long)

def get_IIT_arithmetic_dataset_control_V2(
    variable_range, seed,
    data_size,
    combiner_func
):
    base, base_y, source, source_y = get_IIT_arithmetic_dataset_factual_pairs(
        variable_range, seed,
        data_size,
        combiner_func
    )
    
    counterfactual_y = []
    for i in range(len(base)):
        c_y = combiner_func(base[i][0], source[i][1], base[i][-1])
        counterfactual_y += [c_y]
        
    return torch.tensor(base, dtype=torch.long), \
        torch.tensor(base_y, dtype=torch.long), \
        torch.tensor(source, dtype=torch.long), \
        torch.tensor(counterfactual_y, dtype=torch.long), \
        torch.tensor([2 for _ in range(len(base))], dtype=torch.long)

In [3]:
variable_range = [1, 9]
X_base_train, y_base_train = get_IIT_arithmetic_dataset_factuals(
    variable_range, 42, 640000, ab_c
)
min_y = (variable_range[0]+variable_range[0])*variable_range[0]
max_y = (variable_range[1]+variable_range[1])*variable_range[1]
classes = sorted(set([i for i in range(min_y, max_y+1)]))
n_classes = (variable_range[1]+variable_range[1])*variable_range[1]
class2index = dict(zip(classes, range(n_classes)))

In [5]:
embedding_dim = 9
device = "cpu"
max_iter = 10

LIM = LIMDeepNeuralClassifier(
    hidden_dim=embedding_dim*3, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=2,
    input_dim=embedding_dim*3,
    n_classes=n_classes,
    device=device,
    vocab_size=(variable_range[1]-variable_range[0]+1),
    embed_dim=embedding_dim,
)

In [6]:
LIM_trainer = LIMTrainer(
    LIM,
    warm_start=True,
    max_iter=max_iter,
    batch_size=6400,
    n_iter_no_change=10000,
    shuffle_train=False,
    eta=0.001,
    save_checkpoint_per_epoch=False,
    input_as_ids=True,
    device=device,
    class2index=class2index,
)

In [7]:
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=None,
    intervention_ids_to_coords=None
)

Finished epoch 10 of 10; error is 7.665561750531197

In [8]:
base_test, y_base_test = get_IIT_arithmetic_dataset_factuals(
    variable_range, 42, 1000, ab_c
)

In [9]:
base_preds = LIM_trainer.predict(base_test, device="cpu")

In [10]:
print(classification_report(y_base_test, base_preds))

              precision    recall  f1-score   support

           2       1.00      1.00      1.00         1
           3       1.00      1.00      1.00         1
           4       1.00      1.00      1.00         4
           5       1.00      1.00      1.00         5
           6       1.00      1.00      1.00         9
           7       1.00      1.00      1.00         5
           8       1.00      1.00      1.00        10
           9       1.00      1.00      1.00         9
          10       1.00      1.00      1.00        19
          11       1.00      1.00      1.00         6
          12       1.00      1.00      1.00        23
          13       1.00      1.00      1.00         4
          14       1.00      1.00      1.00        18
          15       1.00      1.00      1.00        12
          16       1.00      1.00      1.00        11
          17       1.00      1.00      1.00         5
          18       1.00      1.00      1.00        28
          20       1.00    

In [11]:
PATH = f"./saved_models/basemodel.bin"
torch.save(LIM_trainer.model.state_dict(), PATH)

d-IIT

In [12]:
X_base_train, y_base_train, X_sources_train, y_IIT_train, intervention_ids_train = get_IIT_arithmetic_dataset_sum_first_V1(
    variable_range, 42, 1280000, ab_c
)



In [13]:
layer = 0
id_to_coords = {
    0: [{"layer": layer, "start": 0, "end": 2*embedding_dim}],
}
iit_data = ([X_sources_train], y_IIT_train, intervention_ids_train)
PATH = f"./saved_models/basemodel.bin"
LIM_trainer.model.load_state_dict(torch.load(PATH))
LIM_trainer.model.set_analysis_mode(True)

In [14]:
_ = LIM_trainer.fit(
    X_base_train, 
    y_base_train, 
    iit_data=iit_data,
    intervention_ids_to_coords=id_to_coords)

Finished epoch 10 of 10; error is 327.2347227334976

In [32]:
X_base_test, y_base_test, X_sources_test, y_IIT_test, intervention_ids_test = get_IIT_arithmetic_dataset_sum_first_V1(
    variable_range, 42, 1000, ab_c
)



In [33]:
IIT_preds = LIM_trainer.iit_predict(
    X_base_test, [X_sources_test], 
    intervention_ids_test, 
    id_to_coords, device="cpu"
)

In [34]:
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         0
           4       0.50      0.50      0.50         6
           5       0.00      0.00      0.00         2
           6       0.70      1.00      0.82         7
           7       0.67      0.50      0.57         4
           8       1.00      0.53      0.69        19
           9       0.63      0.77      0.69        22
          10       0.50      0.24      0.32        21
          11       0.30      0.23      0.26        13
          12       0.36      0.45      0.40        22
          13       0.20      0.25      0.22         4
          14       0.20      0.08      0.12        12
          15       0.64      0.50      0.56        18
          16       0.21      0.19      0.20        16
          17       1.00      0.50      0.67         2
          18       0.81      0.54      0.65        24
          20       0.27    