# Analyze NetBERT improvements over BERT predictions
The goal is to extract the correct predictions of BERT-base on the dev set and pass only this subset to NetBERT, to check that it performs at least as well as BERT-base. Then, extract the wrong predictions of BERT-base and see where NetBERT improves: which specific cases, which classes in particular, and which types of sentences (badly written, unclear?).

In [22]:
import os
import json

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

## 1. Prepare eval datasets from bert-base predictions

In [25]:
def create_eval_dataset(infile, outfile, class_mappings, directory=None):
    """
    Build an evaluation dataset from a CSV of BERT-base predictions.

    Reads `infile`, adds a human-readable 'Class' column by mapping every
    'Class_id' through `class_mappings`, removes the id columns (and any
    'Prediction' column) and writes the result to `outfile`.

    Args:
        infile: Name of the predictions CSV inside `directory`; its first
            column is used as the index.
        outfile: Name of the CSV to write inside `directory`.
        class_mappings: Dict mapping class ids, as strings, to class names
            (e.g. the content of 'map_classes.json').
        directory: Folder holding `infile` and receiving `outfile`. When
            None, the module-level `dirpath` is used at call time.

    Returns:
        The evaluation DataFrame that was written to disk.

    Raises:
        KeyError: If a 'Class_id' has no entry in `class_mappings`, or if
            the 'Class_id'/'Prediction_id' columns are missing.
    """
    if directory is None:
        # NOTE: `dirpath` is reassigned by later notebook cells; pass
        # `directory` explicitly to avoid reading from the wrong folder.
        directory = dirpath

    df = pd.read_csv(os.path.join(directory, infile), index_col=0)

    # JSON object keys are always strings, hence str() on each id.
    df['Class'] = [class_mappings[str(class_id)] for class_id in df['Class_id']]

    # The evaluation set only needs the input and its true class: drop the
    # ids, plus a 'Prediction' column if the input file happens to have one.
    df = df.drop(columns=['Class_id', 'Prediction_id'])
    df = df.drop(columns=['Prediction'], errors='ignore')

    # Save dataset for evaluation with NetBERT.
    df.to_csv(os.path.join(directory, outfile))
    return df


# Load the class-id -> class-name mapping stored with the BERT-base outputs.
dirpath = '/raid/antoloui/Master-thesis/Code/Extrinsic_evaluation/Classification/output/bert_base_cased/'
# JSON text is UTF-8; don't depend on the platform's default encoding.
with open(os.path.join(dirpath, 'map_classes.json'), encoding='utf-8') as f:
    class_mappings = json.load(f)

# NOTE: create_eval_dataset reads/writes inside `dirpath` (set just above).

# Create eval dataset from BERT-base right predictions.
df_bert_right = create_eval_dataset(infile='preds_right.csv', outfile='eval_right_preds.csv', class_mappings=class_mappings)

# Create eval dataset from BERT-base wrong predictions.
df_bert_wrong = create_eval_dataset(infile='preds_wrong.csv', outfile='eval_wrong_preds.csv', class_mappings=class_mappings)

# Full test set = right + wrong subsets. Both inputs carry their own index,
# so re-number the rows 0..n-1 instead of keeping duplicate labels.
df_bert = pd.concat([df_bert_right, df_bert_wrong], ignore_index=True)
df_bert.to_csv(os.path.join(dirpath, 'eval_preds.csv'))

## 2. Analyze predictions of NetBERT

### 2.1. On queries correctly classified by BERT

### 2.2. On queries wrongly classified by BERT

## NetBERT predictions

In [None]:
# From here on, `dirpath` points at the NetBERT (checkpoint 1880000) output folder.
dirpath = ('/raid/antoloui/Master-thesis/Code/Extrinsic_evaluation/'
           'Classification/output/netbert-1880000/')

### On BERT right predictions only

In [None]:
# Build the "right predictions" eval set from the files in the current
# `dirpath` (the NetBERT output folder set in the previous cell), reusing
# create_eval_dataset rather than duplicating its mapping/drop/save logic.
# NOTE(review): despite the section title, this reads preds_right.csv from
# the NetBERT folder and overwrites the BERT-base `df_bert_right` built in
# section 1 — confirm this is intended.
df_bert_right = create_eval_dataset(infile='preds_right.csv', outfile='eval_right_preds.csv', class_mappings=class_mappings)
df_bert_right

In [21]:
def plot_confusion_matrix(cm, classes=None, outfile=None):
    """
    Plot a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: Square matrix (nested lists or array). Rows are drawn along the
            y axis ("True labels") and columns along the x axis
            ("Predicted labels").
        classes: Tick labels for rows and columns, in matrix order. Defaults
            to the integer indices 0..n-1.
        outfile: Optional path; when given, the figure is also saved there
            before being shown.
    """
    cm = np.asarray(cm)
    if classes is None:
        classes = list(range(cm.shape[0]))
    df_cm = pd.DataFrame(cm, index=classes, columns=classes)

    plt.figure(figsize=(10, 7))
    # seaborn is imported as `sns` at the top of the notebook.
    ax = sns.heatmap(df_cm, annot=True)

    # Rotate and right-align x labels so long class names don't overlap.
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=8, horizontalalignment='right', rotation=45)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=8)

    plt.title('Confusion matrix', fontsize=18)
    plt.ylabel('True labels', fontsize=12)
    plt.xlabel('Predicted labels', fontsize=12)
    plt.tight_layout()

    if outfile is not None:
        plt.savefig(outfile)
    plt.show()


# Load the NetBERT scores on the BERT-base right-predictions subset.
with open(os.path.join(dirpath, 'scores_bert_right_preds.json'), encoding='utf-8') as f:
    result = json.load(f)

# Accuracy
print("Accuracy: {}".format(result['Accuracy']))

# Confusion matrix, labelled with the class names.
# NOTE(review): assumes row/column i of the matrix corresponds to class id i
# in map_classes.json — confirm against the script that produced the scores.
conf_matrix = result['conf_matrix']
class_names = [class_mappings[str(i)] for i in range(len(conf_matrix))]
plot_confusion_matrix(conf_matrix, class_names)

{'Accuracy': 0.9631979695431472,
 'Precision': 0.9616735588794019,
 'Recall': 0.9619604922685113,
 'F1 score': 0.961536700021114,
 'MCC': 0.9540423108263494,
 'conf_matrix': [[0.9403973509933775,
   0.033112582781456956,
   0.026490066225165563,
   0.0,
   0.0],
  [0.021739130434782608,
   0.9565217391304348,
   0.0,
   0.014492753623188406,
   0.007246376811594203],
  [0.0, 0.046052631578947366, 0.9473684210526315, 0.0, 0.006578947368421052],
  [0.0, 0.017142857142857144, 0.0, 0.9771428571428571, 0.005714285714285714],
  [0.0, 0.0, 0.011627906976744186, 0.0, 0.9883720930232558]]}