In [1]:
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report,accuracy_score, matthews_corrcoef, confusion_matrix
pd.set_option('display.max_colwidth', None)

In [1]:
from transformers import RobertaTokenizer

In [5]:
tok = RobertaTokenizer.from_pretrained("roberta-base-whole-word-masking")

HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/api/models/roberta-base-whole-word-masking

# Evaluate on Glue Diagnostic, "Ax"

This scripts compares predictions on the Glue Diagnostic set against the gold labels. The script expects that the predictions are on the format: index\tpredictions\n

That is:

1 contradiction

2 entailment


In [12]:
def get_dataframe(path: str) -> pd.DataFrame:
    return pd.read_csv(path, sep="\t")

In [13]:
def merge_df(predictions: pd.DataFrame, gold: pd.DataFrame = gold_df):
    final = gold.copy()
    final["prediction"] = predictions["prediction"]
    return final

In [19]:
def compare_full(df_gold: pd.DataFrame, predictions_path: str):
    predictions = get_dataframe(predictions_path)
    final = merge_df(predictions, df_gold)
    true_labels = list(final["Label"])
    pred_labels = list(final["prediction"])
    print(f"MCC score: {matthews_corrcoef(true_labels, pred_labels)}")
    print(classification_report(true_labels, pred_labels, target_names=["contradiction", "neutral", "entailment"]))

In [25]:
def compare_common_sense(df_gold: pd.DataFrame, predictions_path: str):
    predictions = get_dataframe(predictions_path)
    final = merge_df(predictions, df_gold)
    
    knowledge = final[~final['Knowledge'].isnull()]
    true_labels = list(knowledge["Label"])
    pred_labels = list(knowledge["prediction"])
    
    knowledge_common = knowledge[knowledge.Knowledge.str.contains('Common',case=False)]
    true_labels = list(knowledge_common["Label"])
    pred_labels = list(knowledge_common["prediction"])

    print(f"MCC score: {matthews_corrcoef(true_labels, pred_labels)}")
    print(classification_report(true_labels, pred_labels, target_names=["contradiction", "neutral", "entailment"]))

In [26]:
def compare_wold_knowledge(df_gold: pd.DataFrame, predictions_path: str):
    predictions = get_dataframe(predictions_path)
    final = merge_df(predictions, df_gold)
    
    knowledge = final[~final['Knowledge'].isnull()]
    true_labels = list(knowledge["Label"])
    pred_labels = list(knowledge["prediction"])
    
    knowledge_world = knowledge[knowledge.Knowledge.str.contains('World',case=False)]
    true_labels = list(knowledge_world["Label"])
    pred_labels = list(knowledge_world["prediction"])

    print(f"MCC score: {matthews_corrcoef(true_labels, pred_labels)}")
    print(classification_report(true_labels, pred_labels, target_names=["contradiction", "neutral", "entailment"]))

In [15]:
gold = get_dataframe("../data/diagnostic_gold.tsv")

### Full diagnostic

Bert base uncased, 3e5, 16 batch size, 3 epochs:

In [20]:
compare_full(gold, "../diagnostic_results/bert_base_uncased_3e5_16_3/results_ax.txt")

MCC score: 0.3715884275271643
               precision    recall  f1-score   support

contradiction       0.56      0.52      0.54       258
      neutral       0.63      0.75      0.69       460
   entailment       0.56      0.45      0.50       386

     accuracy                           0.59      1104
    macro avg       0.58      0.58      0.58      1104
 weighted avg       0.59      0.59      0.59      1104



Roberta base, 3e5, 16 batch size, 3 epochs:

In [29]:
compare_full(gold, "../diagnostic_results/roberta_base_3e5_16_3/results_ax.txt")

MCC score: 0.41188796925134874
               precision    recall  f1-score   support

contradiction       0.59      0.55      0.57       258
      neutral       0.64      0.78      0.71       460
   entailment       0.60      0.47      0.53       386

     accuracy                           0.62      1104
    macro avg       0.61      0.60      0.60      1104
 weighted avg       0.62      0.62      0.61      1104



### Common sense

Bert base uncased, 3e5, 16 batch size, 3 epochs:

In [24]:
compare_common_sense(gold, "../diagnostic_results/bert_base_uncased_3e5_16_3/results_ax.txt")

MCC score: 0.32741091196998096
               precision    recall  f1-score   support

contradiction       0.68      0.52      0.59        58
      neutral       0.60      0.50      0.54        56
   entailment       0.39      0.64      0.48        36

     accuracy                           0.54       150
    macro avg       0.56      0.55      0.54       150
 weighted avg       0.58      0.54      0.55       150



Roberta base, 3e5, 16 batch size, 3 epochs:

In [30]:
compare_common_sense(gold, "../diagnostic_results/roberta_base_3e5_16_3/results_ax.txt")

MCC score: 0.3708507737093856
               precision    recall  f1-score   support

contradiction       0.65      0.48      0.55        58
      neutral       0.62      0.59      0.61        56
   entailment       0.46      0.69      0.56        36

     accuracy                           0.57       150
    macro avg       0.58      0.59      0.57       150
 weighted avg       0.60      0.57      0.57       150



### World knowledge


Bert base uncased, 3e5, 16 batch size, 3 epochs:

In [27]:
compare_wold_knowledge(gold, "../diagnostic_results/bert_base_uncased_3e5_16_3/results_ax.txt")

MCC score: 0.13369507292599506
               precision    recall  f1-score   support

contradiction       0.23      0.19      0.21        32
      neutral       0.66      0.52      0.58        63
   entailment       0.33      0.49      0.39        39

     accuracy                           0.43       134
    macro avg       0.41      0.40      0.39       134
 weighted avg       0.46      0.43      0.44       134



Roberta base, 3e5, 16 batch size, 3 epochs:

In [31]:
compare_wold_knowledge(gold, "../diagnostic_results/roberta_base_3e5_16_3/results_ax.txt")

MCC score: 0.14983036699702876
               precision    recall  f1-score   support

contradiction       0.24      0.16      0.19        32
      neutral       0.66      0.56      0.60        63
   entailment       0.33      0.51      0.40        39

     accuracy                           0.45       134
    macro avg       0.41      0.41      0.40       134
 weighted avg       0.46      0.45      0.45       134

