In [1]:
import copy
import itertools
import os
import random

import numpy as np
import torch
from sklearn.metrics import classification_report
from transformers import BertModel, BertTokenizer

import dataset_nli
import utils
from ii_benchmark import IIBenchmarkMoNli
from LIM_bert import LIMBERTClassifier

# Seed the RNGs via the project helper (using its default seed) before any
# stochastic work; the training-data cell re-seeds explicitly before sampling.
utils.fix_random_seeds()

In [2]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook can still execute (slowly) on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
data_size = 10000      # training-set size (the benchmark's `train_size`)
test_data_size = 1000  # test-set size (the benchmark's `test_size`)
seed = 42              # passed to the benchmark and embedded in checkpoint file names

In [21]:
def get_IIT_nli_dataset_factual_pairs(
    data_size,
    tokenizer_name,
    split="train",
):
    bert_tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    
    def encoding(X):
        if X[0][-1] != ".":
            input = [". ".join(X)]
        else:
            input = [" ".join(X)]
        data = bert_tokenizer.batch_encode_plus(
                input,
                max_length=128,
                add_special_tokens=True,
                padding='max_length',
                truncation=True,
                return_attention_mask=True)
        indices = torch.tensor(data['input_ids'])
        mask = torch.tensor(data['attention_mask'])
        return (indices, mask)
    
    dataset = dataset_nli.IIT_MoNLIDataset(
        embed_func=encoding,
        suffix=split,
        size=data_size)
    
    X_base, y_base = dataset.create_factual_pairs()
    y_base = torch.tensor(y_base)
    return X_base, y_base

def get_IIT_nli_dataset_neghyp_V2(
    data_size,
    tokenizer_name,
    split="train",
):
    bert_tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    
    def encoding(X):
        if X[0][-1] != ".":
            input = [". ".join(X)]
        else:
            input = [" ".join(X)]
        data = bert_tokenizer.batch_encode_plus(
                input,
                max_length=128,
                add_special_tokens=True,
                padding='max_length',
                truncation=True,
                return_attention_mask=True)
        indices = torch.tensor(data['input_ids'])
        mask = torch.tensor(data['attention_mask'])
        return (indices, mask)
    
    dataset = dataset_nli.IIT_MoNLIDataset(
        embed_func=encoding,
        suffix=split,
        size=data_size)
    
    X_base, y_base, X_sources,  y_IIT, interventions = dataset.create_neghyp_V2()
    y_base = torch.tensor(y_base)
    y_IIT = torch.tensor(y_IIT)
    interventions = torch.tensor(interventions)
    return X_base, y_base, X_sources,  y_IIT, interventions

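## Factual model training

Fine-tune BERT on the base (factual) MoNLI examples, evaluate it on the test split, and save the weights.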

In [35]:
# Benchmark for the factual (plain fine-tuning) run.
hidden_dim = 768  # hidden size of the bert-base checkpoint below
benchmark = IIBenchmarkMoNli(
    variable_names=['LEX'],
    data_parameters={
        'train_size': data_size, 'test_size': test_data_size
    },
    model_parameters={
        'weights_name': 'ishan/bert-base-uncased-mnli',
        'max_length': 128,
        'n_classes': 2,
        'hidden_dim': hidden_dim,
        'target_layers': [10],
        # Span the whole hidden vector; `end` was previously the typo 786.
        'target_dims': {
            "start": 0,
            "end": hidden_dim,
        },
        'debug': False,
        'device': device
    },
    training_parameters={
        'warm_start': False, 'max_iter': 3, 'batch_size': 32, 'n_iter_no_change': 10000,
        'shuffle_train': True, 'eta': 0.00002, 'device': device
    },
    seed=seed
)
LIM_bert = benchmark.create_model()
LIM_trainer = benchmark.create_classifier(LIM_bert)

Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [36]:
# Re-seed right before building the training data so any random sampling done
# during dataset construction is reproducible. Use the config constants rather
# than repeating their values here.
utils.fix_random_seeds(seed=seed)
train_datasetIIT = get_IIT_nli_dataset_neghyp_V2(
    data_size=data_size,
    split="train",
    tokenizer_name=benchmark.model_parameters["weights_name"]
)
# First two items are the factual inputs/labels; the rest are the IIT inputs
# (X_sources, y_IIT, interventions).
X_base_train, y_base_train = train_datasetIIT[0:2]
iit_data = tuple(train_datasetIIT[2:])

In [37]:
# Test split, sized by the config constant instead of a repeated literal.
test_datasetIIT = get_IIT_nli_dataset_neghyp_V2(
    data_size=test_data_size,
    split="test",
    tokenizer_name=benchmark.model_parameters["weights_name"]
)
# Factual inputs/labels first, then the IIT inputs (X_sources, y_IIT, interventions).
X_base_test, y_base_test = test_datasetIIT[0:2]
iit_data_test = tuple(test_datasetIIT[2:])

In [38]:
# Factual training: fit on the base examples only (no IIT data passed).
_ = LIM_trainer.fit(X_base_train, y_base_train)

Finished epoch 3 of 3; error is 1.2753275868890341

In [39]:
# Release cached CUDA memory left over from training, then score the test split.
torch.cuda.empty_cache()
preds = LIM_trainer.predict(X_base_test)
print(classification_report(y_true=y_base_test, y_pred=preds.cpu()))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [40]:
# Persist the factual model's weights; the D-IIT section reloads them from this path.
PATH = f"./saved_models_nli/basemodel-last-bert-{seed}.bin"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(LIM_bert.state_dict(), PATH)

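## D-IIT training

Reload the saved factual model and continue training it with the IIT objective (interventions on layer 10), then evaluate the intervened predictions on the test split.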

In [41]:
# Benchmark for the D-IIT run, initialized from the factual checkpoint.
hidden_dim = 768  # hidden size of the bert-base checkpoint below
benchmark = IIBenchmarkMoNli(
    variable_names=['LEX'],
    data_parameters={
        'train_size': data_size, 'test_size': test_data_size
    },
    model_parameters={
        'weights_name': 'ishan/bert-base-uncased-mnli',
        'max_length': 128,
        'n_classes': 2,
        'hidden_dim': hidden_dim,
        'target_layers': [10],
        # Span the whole hidden vector; `end` was previously the typo 786.
        'target_dims': {
            "start": 0,
            "end": hidden_dim,
        },
        'debug': False,
        'device': device
    },
    training_parameters={
        'warm_start': False, 'max_iter': 10, 'batch_size': 32, 'n_iter_no_change': 10000,
        # NOTE(review): eta is 100x the factual run's 2e-5 — confirm this is intended.
        'shuffle_train': True, 'eta': 0.002, 'device': device
    },
    seed=seed
)
LIM_bert = benchmark.create_model()
# Load the factual weights saved earlier. map_location puts the tensors on
# `device` regardless of which device the checkpoint was written from.
ORACLE_PATH = f"./saved_models_nli/basemodel-last-bert-{seed}.bin"
LIM_bert.load_state_dict(torch.load(ORACLE_PATH, map_location=device))
LIM_trainer = benchmark.create_classifier(LIM_bert)

Some weights of the model checkpoint at ishan/bert-base-uncased-mnli were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [42]:
# Sanity check: the reloaded checkpoint is expected to match the factual test scores.
# NOTE(review): analysis mode stays enabled for the IIT fit below — confirm intended.
LIM_trainer.model.set_analysis_mode(True)
torch.cuda.empty_cache()
preds = LIM_trainer.predict(X_base_test)
print(classification_report(y_true=y_base_test, y_pred=preds.cpu()))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       500
           1       1.00      1.00      1.00       500

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



In [None]:
# Intervention id -> list of hidden-state regions it targets. Id 1 covers
# dims 0..64 of layer 10 (whether `end` is exclusive is up to LIM_bert — confirm).
intervention_ids_to_coords = {
    1: [
        {"layer": 10, "start": 0, "end": 64},
    ],
}
_ = LIM_trainer.fit(
    X_base_train,
    y_base_train,
    iit_data=iit_data,
    intervention_ids_to_coords=intervention_ids_to_coords,
)

In [45]:
# IIT side of the test split: source inputs, IIT labels, and intervention ids.
(sources_test, y_IIT_test, intervention_ids_test) = iit_data_test

In [48]:
# Intervened predictions on the test split using the same coordinate map as
# training; device="cpu" presumably runs this inference off the GPU — confirm.
IIT_preds_test = LIM_trainer.iit_predict(
    X_base_test,
    sources_test,
    intervention_ids_test,
    intervention_ids_to_coords,
    device="cpu",
)

In [49]:
# Score the intervened predictions against the IIT labels.
print(classification_report(y_true=y_IIT_test, y_pred=IIT_preds_test.cpu()))

              precision    recall  f1-score   support

           0       0.95      0.94      0.94       500
           1       0.94      0.95      0.94       500

    accuracy                           0.94      1000
   macro avg       0.94      0.94      0.94      1000
weighted avg       0.94      0.94      0.94      1000

