In [1]:
import torch
import random
import copy
import itertools
import numpy as np
import utils
from trainer import BERTLIMTrainer

from sklearn.metrics import classification_report
from LIM_bert import LIMBERTClassifier
import dataset_nli

from transformers import BertModel, BertTokenizer
# Project helper; presumably seeds the Python / NumPy / torch RNGs so that the
# cells below are reproducible — confirm in utils.fix_random_seeds.
utils.fix_random_seeds()

In [2]:
# Notebook-wide configuration used by the encoding / model cells below.
n_classes = 2        # number of output labels (the reports below show classes 0 and 1)
max_length = 20      # sequence length every example is encoded to
weights_name = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)

In [3]:
def encoding(X, tokenizer=None, max_len=None):
    """Tokenize one example into fixed-length BERT input tensors.

    The elements of ``X`` are joined with single spaces and encoded as ONE
    sequence (special tokens added), padded/truncated to exactly ``max_len``
    tokens.

    NOTE(review): if ``X`` is a (premise, hypothesis) pair, joining with a
    space removes the [SEP] boundary and segment ids BERT normally uses for
    sentence-pair tasks; encoding ``[(premise, hypothesis)]`` as a text pair
    may be intended instead — confirm against dataset_nli.

    Parameters
    ----------
    X : iterable of str
        Text pieces of a single example.
    tokenizer : optional
        Object exposing ``batch_encode_plus``; defaults to the notebook-level
        ``bert_tokenizer``.
    max_len : int, optional
        Target sequence length; defaults to the notebook-level ``max_length``.

    Returns
    -------
    (indices, mask) : tuple of torch.Tensor
        Token ids and attention mask, each of shape ``(1, max_len)``.
    """
    # Fall back to the notebook globals lazily so existing callers that pass
    # only ``X`` keep working, while the function can also be reused/tested.
    if tokenizer is None:
        tokenizer = bert_tokenizer
    if max_len is None:
        max_len = max_length
    data = tokenizer.batch_encode_plus(
            [" ".join(X)],
            max_length=max_len,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True)
    indices = torch.tensor(data['input_ids'])
    mask = torch.tensor(data['attention_mask'])
    return (indices, mask)

# Load the MoNLI data, encoding every example with `encoding` above.
# NMoNLI is loaded as separate "train" and "test" splits; PMoNLI is loaded
# without a split argument.
# Keep this order: the loaders may draw on the seeded RNGs (TODO confirm), so
# reordering the calls could change the resulting data.
X_nmonli_train, y_nmonli_train = dataset_nli.get_NMoNLI_dataset(encoding, "train")

X_nmonli_test, y_nmonli_test = dataset_nli.get_NMoNLI_dataset(encoding, "test")

X_pmonli, y_pmonli = dataset_nli.get_PMoNLI_dataset(encoding)

In [4]:
# Pretrained encoder only: the "weights ... were not used" warning printed
# below lists the checkpoint's `cls.*` pre-training heads, which BertModel drops.
bert = BertModel.from_pretrained(weights_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Project classifier wrapping the BERT encoder; arguments are presumably
# (number of labels, encoder, sequence length) — confirm against LIM_bert.
# NOTE(review): debug=True looks like a leftover debugging flag (the tensor-shape
# printout during the IIT fit may come from it) — confirm and disable for final runs.
LIM = LIMBERTClassifier(n_classes, bert, max_length, debug=True)

In [6]:
# Training-loop settings for the LIM classifier.
# NOTE(review): the saved fit outputs report "epoch 2 of 2", which does not
# match max_iter=10; they appear to come from an earlier configuration.
LIM_trainer = BERTLIMTrainer(
    LIM,
    eta=0.0001,              # presumably the learning rate — confirm
    batch_size=16,
    max_iter=10,             # presumably the number of epochs
    shuffle_train=True,
    warm_start=True,         # presumably keeps weights across successive fit() calls
    n_iter_no_change=10000)  # very large; appears meant to effectively disable early stopping

In [7]:


# Train on NMoNLI-train plus all of PMoNLI. Each X_* has two components
# (presumably input ids and attention masks — confirm against dataset_nli),
# so the datasets are concatenated component by component.
X_monli_train = tuple((*X_nmonli_train[i], *X_pmonli[i]) for i in (0, 1))
y_monli_train = torch.cat([y_nmonli_train, y_pmonli])

# Catch misaligned inputs/labels before spending time on training.
assert len(X_monli_train[0]) == len(X_monli_train[1]) == len(y_monli_train), \
    "MoNLI training inputs and labels have mismatched lengths"

_ = LIM_trainer.fit(
    X_monli_train,
    y_monli_train,
    iit_data=None,
    intervention_ids_to_coords=None)

Finished epoch 2 of 2; error is 108.46288371086128

In [8]:
# Evaluate on the NMoNLI *training* split, which was part of X_monli_train.
# zero_division=0 reports never-predicted classes as 0.0 explicitly instead of
# emitting a warning for every affected metric.
# NOTE(review): the saved output shows only class 0 being predicted.
preds = LIM_trainer.predict(X_nmonli_train, device="cpu")
print(classification_report(y_nmonli_train, preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.50      1.00      0.67       501
           1       0.00      0.00      0.00       501

    accuracy                           0.50      1002
   macro avg       0.25      0.50      0.33      1002
weighted avg       0.25      0.50      0.33      1002



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [9]:
# Evaluate on PMoNLI. All of X_pmonli was included in X_monli_train, so this
# measures fit on training data, not generalization.
# zero_division=0 reports never-predicted classes as 0.0 without warnings.
preds = LIM_trainer.predict(X_pmonli, device="cpu")
print(classification_report(y_pmonli, preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.50      1.00      0.67       738
           1       0.00      0.00      0.00       738

    accuracy                           0.50      1476
   macro avg       0.25      0.50      0.33      1476
weighted avg       0.25      0.50      0.33      1476



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [10]:
# Evaluate on the held-out NMoNLI test split (not part of X_monli_train).
# zero_division=0 reports never-predicted classes as 0.0 without warnings.
preds = LIM_trainer.predict(X_nmonli_test, device="cpu")
print(classification_report(y_nmonli_test, preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.50      1.00      0.67       100
           1       0.00      0.00      0.00       100

    accuracy                           0.50       200
   macro avg       0.25      0.50      0.33       200
weighted avg       0.25      0.50      0.33       200



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Interchange Intervention Training (IIT)

In [11]:
# Build the IIT training data (10000 is passed as the size argument).
# The first two entries are the base inputs/labels; the remaining entries
# are the extra IIT data handed to fit()/iit_predict().
iit_MoNLI_dataset = dataset_nli.get_IIT_MoNLI_dataset(encoding,"train", 10000)

X_base_train, y_base_train, *iit_rest_train = iit_MoNLI_dataset
iit_data_train = tuple(iit_rest_train)

In [12]:
# Integer id of the single intervention variable used below.
LEXVAR = 0

# Where each intervention id is applied inside the model.
# NOTE(review): end=512 is larger than max_length (20) if start/end index token
# positions, and smaller than bert-base's hidden size (768) if they index hidden
# units — confirm which axis BERTLIMTrainer slices.
id_to_coords = {LEXVAR: [dict(layer=4, start=0, end=512)]}

# Presumably continues from the weights fit above, since the trainer was built
# with warm_start=True — confirm.
_ = LIM_trainer.fit(
    X_base_train,
    y_base_train,
    intervention_ids_to_coords=id_to_coords,
    iit_data=iit_data_train)

torch.Size([10008, 1, 20]) torch.Size([10008, 1, 20]) torch.Size([10008]) torch.Size([10008, 1, 20]) torch.Size([10008, 1, 20]) torch.Size([10008]) torch.Size([10008])


Finished epoch 2 of 2; error is 870.7685129642487

In [None]:
# Predict with iit_predict on the IIT training data and report the results.
# zero_division=0 reports never-predicted classes as 0.0 without warnings.
# NOTE(review): the result is named `base_preds`, but it is scored against
# iit_data_train[1] (entry 3 of the IIT dataset), not y_base_train. If
# iit_predict returns separate base and intervened predictions (e.g. a tuple),
# unpack them and score each against its own labels — confirm its return value.
base_preds = LIM_trainer.iit_predict(
                        X_base_train,
                        iit_data_train[0],
                        iit_data_train[2],
                        id_to_coords,
                        device="cpu")
print(classification_report(iit_data_train[1], base_preds, zero_division=0))

In [None]:
# Build the IIT test data under its own name. The original code reused
# `iit_MoNLI_dataset`, so re-running the training unpack cell afterwards would
# silently pick up test data.
iit_MoNLI_dataset_test = dataset_nli.get_IIT_MoNLI_dataset(encoding, "test", 10000)

X_base_test, y_base_test = iit_MoNLI_dataset_test[0:2]
iit_data_test = tuple(iit_MoNLI_dataset_test[2:])


In [None]:
# Predict with iit_predict on the IIT test data and report the results.
# zero_division=0 reports never-predicted classes as 0.0 without warnings.
# NOTE(review): as in the training evaluation, `base_preds` is scored against
# iit_data_test[1] rather than y_base_test — confirm what iit_predict returns
# and which labels each prediction set should be compared with.
base_preds = LIM_trainer.iit_predict(
                        X_base_test,
                        iit_data_test[0],
                        iit_data_test[2],
                        id_to_coords,
                        device="cpu")
print(classification_report(iit_data_test[1], base_preds, zero_division=0))